AlphaFold Ilustrado
(elanapearl.github.io)- O AlphaFold3 vai além de uma única proteína e tenta prever apenas a partir da sequência complexos com proteínas, ácidos nucleicos e pequenas moléculas juntos; por isso, a representação de entrada e a tokenização ficam muito mais complexas do que no AF2
- A entrada é dividida em representações single/pair no nível de token, representações no nível atômico, MSA e templates; aminoácidos e nucleotídeos padrão são tratados como 1 token, enquanto resíduos não padrão e outras moléculas são tratados como 1 token por átomo
- O trunk de aprendizado de representação melhora repetidamente a representação single s e a representação pair z por meio do módulo de template, módulo de MSA e do Pairformer, com pair-bias attention, operações de triângulo e recycling
- A predição de estrutura usa um modelo de difusão condicional sobre coordenadas atômicas no lugar da Invariant Point Attention do AF2, gerando atualizações das coordenadas de todos os átomos com augmentação de rotação/translação e denoising
- O treinamento combina distogram, diffusion e confidence loss e reaprende até representações unfolded em regiões de baixa confiança por meio de cross-distillation com resultados do AF2 e do AF-Multimer
Escopo de entrada e pipeline geral do AlphaFold3
- O objetivo do AlphaFold3 não é apenas prever sequências individuais de proteínas como o AF2, nem lidar apenas com complexos proteicos como o AF-Multimer, mas prever apenas a partir da sequência estruturas em que proteínas se ligam opcionalmente a outras proteínas, ácidos nucleicos e pequenas moléculas
- O significado de “token” muda conforme o tipo de entrada
- Proteína: 1 aminoácido padrão = 1 token
- DNA/RNA: 1 nucleotídeo padrão = 1 token
- Aminoácidos e nucleotídeos não padrão: 1 átomo = 1 token
- Outras moléculas: 1 átomo = 1 token
- Uma proteína com 35 aminoácidos padrão pode de fato ter mais de 600 átomos, mas é representada por 35 tokens, enquanto um ligand com 35 átomos é representado por 35 tokens
- O modelo é composto, em linhas gerais, por três etapas
- Input Preparation: converte a sequência fornecida pelo usuário e as sequências/estruturas relacionadas encontradas em tensores numéricos
- Representation Learning: atualiza a representação single e a representação pair com várias variações de attention
- Structure Prediction: prevê a estrutura por difusão condicional
- O complexo proteico é armazenado principalmente em duas representações
- single representation: representa todos os tokens do próprio complexo
- pair representation: representa relações como distância e possíveis interações entre todos os pares de tokens
- As principais dimensões de canal são
c_z=128,c_m=64,c_atom=128,c_atompair=16,c_token=768,c_s=384
Preparação da entrada: o processo de transformar a sequência em 6 tensores
- A entrada fornecida pelo usuário é convertida em 6 tensores que entram no trunk do modelo
- s: token-level single representation
- z: token-level pair representation
- q: atom-level single representation
- p: atom-level pair representation
- m: MSA representation
- t: template representation
-
Busca de MSA e templates
- O AF3 busca sequências semelhantes para proteínas e sequências de RNA, monta isso como MSA e inclui estruturas relacionadas como template
- O MSA alinha sequências de proteínas semelhantes encontradas em várias espécies, fornecendo ao modelo padrões de conservação em posições específicas e correlações de mudança entre posições diferentes
- Estruturas conhecidas de proteínas semelhantes são usadas para estimar a estrutura da proteína query, de forma parecida com homology modeling
- A busca não inclui treinamento, e são usados métodos baseados em HMM
jackhmmer,HHBlitsenhmmersão usados para pesquisar vários bancos de dados de proteínas e RNA, ehmmsearché usado para encontrar sequências semelhantes no Protein Data Bank- O tamanho do MSA é limitado a
N_MSA < 2^14por causa da complexidade computacional - Em cada chain de proteína, são selecionadas estruturas de alta qualidade, com no máximo 4 templates amostrados
- Em comparação com o AF-Multimer, o novo elemento adicionado à busca é que sequências de RNA também entram no escopo da busca
-
Forma de representação dos templates
- Na estrutura 3D do template, calcula-se a distância euclidiana entre cada par de tokens
- Para tokens com vários átomos, usa-se um “center atom” representativo
- Aminoácido: átomo
Cα - Nucleotídeo padrão: átomo
C1'
- Aminoácido: átomo
- O valor de distância não é contínuo, mas discretizado como distogram
- 38 bins de 3.15Å a 50.75Å
- 1 bin adicional para distâncias maiores do que isso
- Ao distogram são adicionadas informações de chain, se o token correspondente foi resolvido na crystal structure e informações de distância local dentro de cada aminoácido
- A matrix de template é mascarada para considerar apenas distâncias dentro da mesma chain, e a seleção de templates não busca obter informações de interação entre chains
Representação em nível atômico e o Atom Transformer
-
reference conformere representação em nível atômico- Para criar a representação single em nível atômico q, é calculado um reference conformer para cada aminoácido, nucleotídeo e ligante
- Um conformer é uma disposição 3D dos átomos de uma molécula, gerada por amostragem de rotações em torno de ligações simples
- Aminoácidos padrão usam conformers de baixa energia obtidos por lookup, e pequenas moléculas geram conformers 3D com o ETKDGv3 do RDKit
- A representação single em nível atômico c é formada combinando posições relativas do conformer, carga atômica, número atômico, identificadores e outros atributos
- c inicializa a representação pair em nível atômico p, e a máscara v é usada para conter apenas as distâncias entre átomos calculadas a partir do reference conformer
- q começa como uma cópia de c e depois é atualizada no Atom Transformer
-
Papel do Atom Transformer
- O Atom Transformer é um módulo que realiza atenção em nível atômico e atualiza q usando p e a representação original c
- c não é atualizada e é usada como uma conexão residual em direção à representação inicial
- A estrutura básica é semelhante à de um transformer, incluindo LayerNorm, atenção e transição MLP, mas cada etapa é ajustada por entradas adicionais de c e p
-
Adaptive LayerNorm
- A Adaptive LayerNorm gera
gammaebetaa partir de uma entrada auxiliar, em vez de aprendergammaebetafixos - No Atom Transformer, q é o alvo do reescalonamento, e os parâmetros de reescalonamento são previstos a partir da entrada auxiliar c
- A Adaptive LayerNorm gera
-
Attention with Pair Bias
- A atenção em nível atômico com pair bias é uma extensão da self-attention
- Query, key e value vêm todos da representação single q, mas após o produto escalar entre query e key, uma projeção linear da representação pair p é adicionada como bias
- Há fluxo de informação da representação pair para q, mas nesta etapa p não é atualizada com informação de q
- Um gate criado ao passar uma projeção adicional por uma sigmoid é multiplicado pelo resultado da atenção, controlando quais informações permanecem no fluxo residual
- Como o número de átomos pode ser muito maior que o número de tokens, é usada Sequence-local atom attention em vez de full attention
- Grupos locais de 32 átomos podem fazer attend a outros 128 átomos
-
Conditioned Gating e Transition
- O Conditioned Gating aplica aos dados um gate gerado a partir da matriz single original em nível atômico c
- O Conditioned Transition corresponde à MLP do transformer e recebe esse nome porque a Adaptive LayerNorm e o Conditional Gating dependem de c
- O AF3 usa SwiGLU no bloco de transition em vez de ReLU
- A transition baseada em ReLU no AF2 tem a estrutura de up-projection 4x, ReLU e down-projection
- No SwiGLU do AF3, uma não linearidade swish é aplicada a uma das duas up-projections, depois os resultados são multiplicados e então projetados de volta para baixo
Agregando representações atômicas em representações de token
- Como a etapa de aprendizado de representação passa a operar em nível de token, as representações em nível atômico são agregadas em representações em nível de token
- A representação em nível atômico é projetada para uma dimensão maior e, em seguida, é tirada a média dos átomos pertencentes ao mesmo token
- Essa agregação por média é aplicada quando vários átomos estão ligados a um token, como em aminoácidos padrão e nucleotídeos, enquanto entradas com 1 token por átomo são mantidas como estão
- Estatísticas obtidas da MSA também são combinadas à entrada single em nível de token
- tipo de aminoácido
- distribuição de aminoácidos da MSA naquela posição
- média de deleções daquele token
- Para tokens sem MSA, como átomos de ligantes, esses valores ficam em 0
- O s_inputs criado dessa forma passa por uma projeção para se tornar s_init, que é atualizado na etapa de aprendizado de representação
- A representação pair z_init é um tensor tridimensional que armazena relações para cada par de tokens, e cada z_i,j é um vetor de dimensão
c_z=128 - A inicialização de z_i,j soma projeções de s_i e s_j, relative positional encoding e informações de ligação entre tokens especificadas pelo usuário
Aprendizado de representação: Template, MSA e Pairformer
- O aprendizado de representação é o trunk que responde pela maior parte do cálculo do modelo, e seu objetivo é melhorar a representação single em nível de token s e a representação pair z
- A representação de sequência única não se refere apenas a uma sequência de proteína isolada, mas a uma sequência formada pela concatenação de todos os átomos ou tokens na estrutura
-
Módulo de Template
- Cada template passa por uma projeção linear e é somado à projeção linear da representação pair z
- A matriz combinada passa por uma Pairformer Stack
- Os resultados de vários templates são promediados e depois passam novamente por uma camada linear
- A camada linear final usa ReLU, um dos raros pontos em que o AF3 usa ReLU como não linearidade
-
Módulo de MSA
- O Módulo de MSA é muito semelhante ao Evoformer do AF2 e melhora simultaneamente a representação de MSA m e a representação pair z
- Em vez de usar todas as linhas da MSA, é feito um subsampling, e uma projeção da representação single é adicionada à MSA
- Outer Product Mean é a operação que insere informações da MSA na representação pair
- Para cada índice de token
i,j, é calculado o produto externo de m_s,i e m_s,j para todas as sequências evolutivas - Isso é promediado ao longo de toda a sequência, achatado e então projetado antes de ser somado a z_i,j
- É o único ponto do modelo em que informações entre sequências evolutivas são compartilhadas
- Para cada índice de token
- Row-wise gated self-attention using only pair bias atualiza a MSA usando a representação pair
- Em vez de criar scores de atenção com query e key, uma projeção matricial da representação pair z é usada como score de atenção entre tokens
- Como é aplicado independentemente a cada linha da MSA, nesta etapa não há compartilhamento de informação entre sequências evolutivas
- O módulo de MSA termina atualizando novamente a representação pair com triangle update e triangle attention
Pairformer e operações triangulares
- Depois de atualizar z com template e MSA, template e MSA não são mais usados, e apenas s e z entram no Pairformer
- O Pairformer gera os valores finais s_trunk e z_trunk por meio da repetição de 48 blocos
-
Intuição das operações triangulares
- triangle update e triangle attention são estruturas pensadas para incorporar ao modelo a intuição da desigualdade triangular
- Embora z_i,j do pair tensor não seja a própria distância física, ele contém a relação entre os tokens
iej, então as três relaçõesi-j,j-kei-ksão atualizadas para permanecerem consistentes entre si - A desigualdade triangular não é imposta diretamente dentro do modelo; ela é induzida pela forma de atualizar z_i,j observando todos os tripletos
(i,j,k) - z pode ser visto como uma matriz de adjacência direcionada, então as direções de outgoing edge e incoming edge são tratadas separadamente
-
Triangle Updates
- No outgoing update, cada z_i,j é atualizado usando outro elemento z_i,k da mesma linha e a terceira aresta z_j,k
- Na implementação, são criadas três projeções de z:
a,beg; depois, faz-se a multiplicação elemento a elemento entre a linhaie a linhaj, somando sobrek, e em seguida aplica-se o gateg - O incoming update tem a forma com linha e coluna trocadas: z_i,j é atualizado por meio de outros elementos z_k,j da mesma coluna e de z_k,i
-
Triangle Attention
- triangle attention é uma forma que adiciona o princípio triangular ao axial attention, que aplica attention independente às linhas e colunas de uma matriz 2D
- No caso “starting node”, a comparação query-key entre z_i,j e z_i,k recebe z_j,k como bias adicional
- No caso “ending node”, ele opera com base nas colunas, e o score de attention entre z_i,j e z_k,i recebe z_k,j como bias
-
Single Attention with Pair Bias
- Após o passo triangular e o transition block, a single representation s é atualizada por meio de single attention with pair bias usando a pair representation z já atualizada
- Como opera no nível de token, usa full attention em vez do block-wise sparse attention usado no nível de átomo
Previsão de estrutura: denoising de coordenadas atômicas por difusão
-
Modo básico do modelo de difusão
- O AF3 realiza a previsão final da estrutura com atom-level diffusion
- Um diffusion model adiciona random noise aos dados reais de forma gradual e treina o modelo para prever que tipo de noise foi adicionado
- Na inferência, ele começa com random noise completo e, a cada step, remove o noise previsto pelo modelo, gerando um datapoint com denoising
- A difusão condicional recebe como entrada a geração ruidosa atual, a representação do timestep atual e o vetor de condição, para produzir um resultado compatível com a condição
- No AF3, o alvo do denoising é a matriz x que contém as coordenadas
x,y,zde todos os átomos
-
Em vez do IPA do AF2, aumento por rotação e translação
- O AF3 não usa o Invariant Point Attention do AF2; em vez disso, aplica rotação e translação aleatórias ao complexo inteiro que está sendo previsto em cada timestep
- Esse aumento faz o modelo aprender que qualquer rotação ou translação continua válida para a mesma estrutura e é uma abordagem mais simples que o IPA do AF2
- A rotação é aplicada em torno da média das coordenadas de todos os átomos da geração atual, e a translação é amostrada como uma Gaussiana
N(0,1)em cada dimensão - Um pequeno noise também é adicionado às coordenadas para induzir gerações mais diversas
- Na inferência, várias gerações podem ser pontuadas pelo confidence head, e a geração com maior pontuação pode ser retornada
-
As quatro etapas do Diffusion Module
- Cada step de denoising usa várias representações de conditioning
- saídas do trunk s_trunk, z_trunk
- representações iniciais s_inputs, c_inputs criadas pelo input embedder
- O processo de difusão é composto por quatro etapas, alternando entre os espaços de token e de átomo
-
- preparar o tensor de conditioning no nível de token
-
- preparar o tensor de conditioning no nível de átomo, aplicar o Atom Transformer e agregá-lo ao nível de token
-
- aplicar attention no nível de token
-
- prever o noise update por átomo com attention no nível de átomo
-
- No conditioning em nível de token, z_trunk é combinado com o relative positional encoding e passado por um transition block
- À single representation, combinam-se s_inputs e s_trunk, e soma-se um Fourier embedding de acordo com o diffusion timestep
- Na etapa em nível de átomo, os valores iniciais c e p são atualizados com a representação atual em nível de token, e as coordenadas atuais x são escaladas pela variância dos dados para formar a coordenada adimensional r
- Na etapa final em nível de átomo, uma linear layer faz o mapping de q para
R^3, gerando o coordinate update r_update de todos os átomos - O update é então reescalado para x_update levando em conta a variância dos dados e o noise schedule, e aplicado às coordenadas atuais x_l
- Cada step de denoising usa várias representações de conditioning
Função de perda e confidence head
- A loss total é uma soma ponderada de três termos
L_loss = L_distogram * α_distogram + L_diffusion * α_diffusion + L_confidence * α_confidence
-
L_distogram
- L_distogram avalia a precisão do distograma previsto no nível de token
- Ao criar coordenadas de token a partir de coordenadas atômicas, são usadas as coordenadas do átomo central de cada token
- A distância do distograma é tratada como um valor categórico, e o distograma previsto é comparado ao distograma real com entropia cruzada
-
L_diffusion
- L_diffusion é uma soma ponderada de vários termos sobre posições atômicas
- L_MSE calcula o erro quadrático médio entre posições para todos os átomos, não apenas para o átomo central, e átomos de DNA, RNA e ligantes recebem peso maior
- L_bond é um termo adicional de MSE para aumentar a precisão do comprimento de ligação de pares de átomos incluídos em ligações proteína-ligante
- No estágio inicial de treinamento,
α_bond=0, então ele é introduzido depois - L_smooth_LDDT é uma loss que torna a precisão de distância local suave e diferenciável
- São usados quatro limiares: 4Å, 2Å, 1Å e 0.5Å
- Pares de átomos de nucleotídeos são ignorados se estiverem a mais de 30Å de distância
- Pares de átomos de proteína ou ligante são ignorados se estiverem a mais de 15Å de distância
-
L_confidence
- L_confidence não aumenta diretamente a precisão estrutural; em vez disso, treina o modelo para estimar a precisão de suas próprias previsões
- É composto por losses correspondentes a quatro métricas de confiança
- pLDDT: precisão de distância local para átomos próximos
- PAE: erro de alinhamento previsto para pares de tokens
- PDE: erro de distância previsto entre pares de tokens
- experimentally resolved prediction: previsão de se cada átomo foi resolvido na estrutura experimental
- Mesmo que a estrutura prevista seja imprecisa e o PAE seja alto, essa loss de PAE pode ser baixa se o modelo também prever um PAE alto
- A previsão de confiança é gerada em etapas intermediárias da diffusion
- O gradiente da confidence loss atualiza apenas a head de previsão de confiança e não afeta o restante do modelo
Técnicas adicionais de treinamento e eficiência
-
Recycling
- O AF3 usa weight recycling, assim como o AF2
- Em vez de tornar o modelo mais profundo, ele reutiliza os mesmos pesos várias vezes para melhorar gradualmente a representação
- A diffusion também usa informações de timestep na inferência e reutiliza os mesmos pesos a cada timestep, então já incorpora recycling
-
Cross-distillation
- O AF3 usa não apenas dados sintéticos de treinamento gerados por ele mesmo, mas também dados sintéticos criados por AF2 e AF-Multimer
- Após a mudança para geração baseada em diffusion, surgiu o problema de a forma “spaghetti”, que no AF2 permitia distinguir visualmente regiões de baixa confiança e desordenadas, desaparecer
- Ao incluir gerações de AF2 e AF-Multimer nos dados de treinamento do AF3, o AF3 aprende a produzir regiões unfolded nas áreas em que o AF2 não tinha confiança
- No conjunto de dados de distillation, ácidos nucleicos e pequenas moléculas que AF2 e AF-Multimer não conseguem processar são removidos
- Depois que o modelo anterior cria a estrutura prevista e ela é alinhada com a original, as moléculas removidas são adicionadas de volta
- Se as moléculas adicionadas de volta gerarem atom clash, a estrutura inteira é excluída para evitar que o modelo aprenda a permitir clashes
-
Cropping e estágios de treinamento
- O modelo em si não tem um limite explícito para o comprimento da sequência de entrada, mas várias operações crescem com
N_tokens^3, aumentando as exigências de memória e computação - Para eficiência, proteínas passam por random crop
- Como é preciso modelar interações entre várias chains, o crop deve incluir as chains em conjunto
- São usados três métodos de cropping
- contiguous cropping: seleção de sequências contíguas de aminoácidos em cada chain
- spatial cropping: seleção de aminoácidos com base na distância até um átomo de referência
- spatial interface cropping: seleção com base na distância até átomos da interface de ligação
- Um modelo treinado com random crop 384 também pode ser aplicado a sequências mais longas, mas, para aumentar a capacidade de lidar com sequências maiores, ele passa por fine-tuning repetido com comprimentos de sequência maiores
- O modelo em si não tem um limite explícito para o comprimento da sequência de entrada, mas várias operações crescem com
-
Clashing e tamanho de batch
- A loss do AF3 não inclui penalidade de clash para átomos sobrepostos
- Em teoria, o módulo estrutural baseado em diffusion pode prever dois átomos na mesma posição, mas, após o treinamento, esse problema é pequeno
- Uma penalidade de clashing é usada no ranking das estruturas geradas
- O processo de diffusion parece complexo, mas tem custo computacional menor que o trunk
- Para eficiência no treinamento, o tamanho do batch é ampliado após o trunk
- Cada estrutura de entrada passa uma vez por embedding e pelo trunk, e depois 48 estruturas independentes com data augmentation são treinadas em paralelo
Projeto do AF3 sob a perspectiva de ML
-
Estrutura semelhante a Retrieval-Augmented Generation
- A busca de MSA e de templates no AF3 tem características semelhantes ao RAG de modelos de linguagem
- Na área de AlphaFold, a abordagem de usar templates estruturais já vinha sendo usada há muito mais tempo como homology modeling do que com o termo RAG
- O AF3 reduziu, em relação ao AF2, o peso do processamento de MSA, mas ainda inclui MSA e templates
- Alguns modelos de predição de proteínas, como o ESMFold, removem o retrieval e usam fully parametric inference
-
Pair-Bias Attention
- O Pair-Bias Attention, que era um componente principal do AF2, é usado de forma mais ampla no AF3
- query, key e value vêm da mesma source, mas um termo de bias vindo de outra source é adicionado ao attention map
- Essa é uma forma mais leve de compartilhamento de informação do que full cross-attention
- Como a pair representation é naturalmente semelhante ao attention map, essa estrutura pode se encaixar bem na modelagem de proteínas
-
Redução do self-supervised training
- Modelos da linha ESM mostraram força ao substituir MSA embedding por self-supervised pre-training
- O AF2 tinha uma task adicional de prever masked tokens da MSA, mas isso foi removido no AF3
- O AF3 reduziu o compute dedicado ao processamento de MSA e não usa self-supervised language modeling pre-training para MSA
- Entre as possíveis razões estão: massive pre-training pode ter sido ineficiente em termos de uso de compute; um módulo pequeno de MSA pode ter sido melhor que embeddings pré-treinados; ou a combinação de embeddings pré-treinados com uma estrutura híbrida de atom-token misturando aminoácidos, DNA/RNA e ligands pode não ter funcionado bem
-
Mistura de Classification e Regression
- O AF3, assim como o AF2, usa MSE junto com binned classification loss
- Uma característica da classification loss é que, mesmo errando apenas um distogram bin, não há crédito, da mesma forma que em um erro muito distante
- A base para essa escolha de projeto não é clara, mas é possível que o gradient tenha sido mais estável do que com várias MSE loss
-
Elementos que lembram recurrent architecture
- O AF3 tem muitos elementos que lembram mais uma recurrent network do que um transformer comum
- O gating controla o fluxo de informação no residual stream e é semelhante aos gates de LSTM ou GRU
- recycling e diffusion aplicam repetidamente o mesmo weight para melhorar gradualmente a predição
- De forma semelhante a adaptive compute time, as atualizações iterativas se relacionam a uma estrutura capaz de aplicar mais processamento a entradas difíceis
- Em ablation do AF2, a importância de recycling apareceu, mas houve pouca discussão sobre a importância de gating
Ainda não há comentários.