Llama do zero: como implementar um artigo sem sofrer
(blog.briankitano.com)- Brian Kitano construiu uma versão reduzida do Llama usando TinyShakespeare e concluiu que, para implementar um artigo com segurança, o ideal é começar com um modelo pequeno, trocar os componentes um a um e treinar e avaliar a cada etapa
- Ele preparou primeiro funções auxiliares de validação como divisão de dados, geração de batches, avaliação de perda e função de geração, confirmou com um modelo simples que a compilação e o treinamento funcionavam e só então adicionou os componentes do Llama
- Ao adicionar RMSNorm, RoPE e SwiGLU em sequência, verificou se cada camada se comportava como esperado por meio de shape de tensores, propriedades das fórmulas e mapas de atenção
- Ao remover a causal mask na atenção com RoPE, a perda de validação caiu até 0,16, mas a qualidade da geração piorou; a causa foi vazamento de informação por olhar tokens futuros
- O Llama reduzido final tinha 4 blocos e cerca de 2,37 milhões de parâmetros, reduziu a perda de validação para cerca de 1,0, e também exigiu checagem do fluxo de gradientes e do agendamento da taxa de aprendizado
Começar pequeno e ganhar confiança de forma iterativa
- O ponto central para implementar um artigo é começar com um modelo pequeno, trocar os componentes um a um e repetir treinamento e avaliação a cada mudança
- Primeiro, ele preparou funções auxiliares para verificar o modelo de forma quantitativa
- divisão de dados
- loop de treinamento
- visualização da perda
- avaliação da perda de validação
- Em vez de portar todos os componentes do artigo de uma vez, ele também criou uma função de avaliação qualitativa para observar os resultados de geração com um modelo simples e rápido que já conhecia bem
- As camadas de tensores foram verificadas com
.shape,asserteplt.imshow; em vez de partir direto para otimização de multiplicação de matrizes, ele primeiro conferiu manualmente os resultados esperados e depois usou funções dotorchpara ganhar eficiência - É preciso testar variando batch size, comprimento de sequência e dimensão do embedding; código que só funciona para um único tamanho pode quebrar na inferência
Dataset e configuração básica
- O alvo da implementação é uma versão muito reduzida do Llama, da Meta AI, e os dados de treinamento vêm do TinyShakespeare
- O Llama é treinado com 1,4T tokens, mas aqui foi usado o TinyShakespeare, com cerca de 1,11 milhão de caracteres
- O Llama original usa o tokenizador byte-pair encoding SentencePiece, mas esta implementação usa um tokenizador em nível de caractere mais simples
- vocabulary size de 65
- como o dataset é pequeno, o armazenamento em memória não foi otimizado separadamente
- Um dicionário
MASTER_CONFIGgerencia configurações do modelo comovocab_size,batch_size,context_windowed_model- o objetivo é reduzir constantes e números mágicos e tornar o código mais legível
- A função
get_batchesdivide os dados em 80% treino, 10% validação e 10% teste, e cria a entradaxe o rótuloydeslocado em um caractere a partir de um ponto inicial aleatório
Confirmando compilação e treinamento com um modelo básico
- O primeiro modelo foi o
SimpleBrokenModel, composto por embedding e uma rede feed-forward simplesnn.EmbeddingLinearReLULinear
- Na implementação de um artigo, dizer que o modelo “funciona” exige duas condições
- compilar: os shapes dos tensores precisam encaixar entre as camadas
- treinar: a perda precisa realmente cair
- A função
evaluate_losscalcula a perda média amostrando 10 batches das divisões de treino e validação - Após 1000 epochs, o
SimpleBrokenModelficou com perda de validação em torno de 3,94, quase sem cair a partir da cross-entropy inicial de 4,17 - O problema era que valores já processados por softmax estavam sendo passados para
F.cross_entropy- O
F.cross_entropydo PyTorch recebe diretamente logits não normalizados - Ao remover o softmax, o
SimpleModelreduziu a perda de validação para cerca de 2,51
- O
- Depois disso, ele adicionou a função
generatepara inspecionar diretamente os caracteres produzidos pelo modelo; o modelo básico ainda não era ideal, mas já apresentava queda na perda de validação
Componente 1 do Llama: RMSNorm
- Em relação ao Transformer original, o Llama usa três mudanças arquiteturais principais
- RMSNorm com pre-normalization
- Rotary embeddings
- função de ativação SwiGLU
- O Transformer original usa BatchNormalization, enquanto o Llama usa RMSNorm, que escala pelo variance sem centralizar o vetor
- Enquanto o Transformer original aplica normalização na saída da camada de atenção em esquema post-normalization, o Llama aplica primeiro na entrada, em esquema pre-normalization
- O
RMSNormimplementado assume entrada com shape(batch, seq_len, d_model) - O resultado do RMSNorm foi testado pela propriedade de que a norma da camada se torna a raiz quadrada do número de elementos da camada
assert- comparação linha a linha
torch.allclose
- O
SimpleModel_RMS, que adiciona RMSNorm ao modelo básico, reduziu levemente a perda de validação para cerca de 2,5015
Componente 2 do Llama: RoPE e causal mask
- RoPE é uma forma de codificação posicional para Transformers que representa a posição do token por meio da rotação do embedding
get_rotary_matrixcria uma matriz de rotação por posição para a janela de contexto e a dimensão do embedding- A implementação de RoPE foi testada com a seguinte propriedade
- o produto interno entre dois vetores rotacionados nas posições
mendeve coincidir com a rotação relativan-m
- o produto interno entre dois vetores rotacionados nas posições
RoPEAttentionHeadcriaw_q,w_kew_v, aplica a rotação RoPE em query e key, e depois usaF.scaled_dot_product_attention- É preciso tomar cuidado com a diferença de shape dos tensores entre treinamento e inferência
- no treinamento, muitas vezes o shape segue a configuração, como
(config['batch_size'], config['context_window'], config['d_model']) - na inferência, pode haver um único exemplo como
(1, 1, config['d_model']) - dentro de
forward, a indexação deve se basear no shape da entrada, e não nos valores fixos da configuração do modelo
- no treinamento, muitas vezes o shape segue a configuração, como
- O modelo que adicionou atenção multi-head com RoPE sem causal mask teve uma queda brusca da perda de validação até 0,1623, mas gerava saídas ruins como
OOOO...eIIII... - Ao observar o mapa de atenção, ficou claro que todas as posições estavam consultando todas as outras, causando vazamento de informação ao olhar tokens futuros na previsão do próximo token
- Ao trocar para
RoPEMaskedAttentionHead, aplicandois_causal=TrueemF.scaled_dot_product_attention, a atenção na parte triangular superior, correspondente ao futuro, caiu quase a zero - Com a causal mask, a perda de validação foi para 2,0815, e com mais tempo de treinamento caiu até 1,8985
Componente 3 do Llama: SwiGLU e empilhamento de blocos
- O Llama substitui a não linearidade ReLU pela função de ativação SwiGLU
- O
SwiGLUimplementado é uma Swish-gated linear unit, usando duas transformações lineares e um parâmetrobetaaprendível - O
RopeModel, com SwiGLU na parte feed-forward, tinha 592.706 parâmetros e perda de validação em torno de 1,8963 - Em seguida, foi criado o
LlamaBlock, reunindo a seguinte estrutura em um único bloco- RMSNorm com pre-normalization
- atenção multi-head com RoPE mascarada
- residual connection
- RMSNorm com pre-normalization
- feed-forward com SwiGLU
- residual connection
- O modelo final
Llamafoi configurado comn_layers=4e empilha 4LlamaBlockcomnn.Sequentialbaseado emOrderedDict - O modelo final tem 2.370.246 parâmetros, e os resultados de treinamento foram os seguintes
- após o treinamento inicial com 4 camadas, perda de validação de 1,5532
- após treinar mais por 10.000 epochs, perda de validação de 1,1479
- após treinamento adicional, perda de validação de 0,9997
- a perda em um batch da divisão de teste foi 1,2358
Resultado de geração e pontos de depuração
- O modelo final gera nomes, quebras de linha e fragmentos de palavras parecidos com o estilo de Shakespeare, mas a qualidade real das sentenças ainda é limitada
- A perda de cross-entropy pode ser interpretada intuitivamente do ponto de vista da escolha de tokens
- a perda inicial de 4,17 é próxima de uma escolha aleatória com vocabulary size 65
- a perda de 1,08 pode ser entendida como escolher aleatoriamente entre cerca de 2,9 tokens
- O fluxo de gradientes foi verificado com a função
show_grads- ela calcula a proporção de gradientes com valor absoluto pequeno em cada parâmetro
- se a maioria dos gradientes dos parâmetros não estiver próxima de zero, o fluxo está em boas condições
- O Llama original usa agendamento de aprendizado com Cosine Annealing, mas nesta implementação os resultados experimentais foram piores
- No experimento com Cosine Annealing, mesmo com tolerance muito baixa, o attention bias quase não recebeu sinal; como a razão não ficou clara, na prática é mais seguro começar de forma simples
1 comentários
Opiniões no Hacker News
Parece haver um bug na implementação do SwiGLU: no artigo de referência, o beta da feed-forward network não é um valor treinável, mas uma constante, e é definido como
FFnSwiGLU = Swish1...Com base na equação 6 de https://arxiv.org/pdf/2002.05202.pdf
Na implementação oficial do Llama, o beta constante também foi removido: https://github.com/facebookresearch/llama/blob/main/llama/mo...
Pelas linhas
"feedforward.1.beta', 0.0"no log do blog, durante o treinamento o beta degenerou para 0, mas originalmente deveria ser a constante 1Muitas vezes a rede se adapta às mudanças, intencionais ou não, e, após o treinamento, várias variações de arquitetura acabam se comportando de forma parecida, então às vezes fica ambíguo se ela precisa coincidir exatamente com a original
Uma forma de encontrar esse tipo de erro é fazer os valores de saída baterem exatamente com uma implementação de referência. Mesmo com pesos aleatórios, como nos modelos tiny-random da HuggingFace, a saída deve ser exatamente igual; se não for, é sinal de bug
Só que esse método funciona bem apenas para bugs que aparecem durante a inferência; problemas que ocorrem apenas no processamento de dados, no otimizador ou durante o treinamento são mais difíceis de detectar
Pessoalmente, acho que é por causa de propriedades autorregressivas e semelhantes a ODE, mas não tenho certeza suficiente
O trabalho é excelente, mas os
SimpleBrokenModeleSimpleModeliniciais têm bastante computação desperdiçada. A sequência éembedding 65 -> 128,linear 128 -> 128,ReLU,linear 128 -> 65; como não há não linearidade entre as duas primeiras camadas, e ambas são lineares, a segunda camada linear é essencialmente inútilEsse modelo acaba sendo equivalente a um MLP clássico de uma única camada oculta e, em termos de FLOPS, desperdiça
128*128=16koperações de um total de128*128+65*128=24kA camada de embedding é uma estrutura especial que transforma índices de tokens em vetores de embedding, então imagino que não dê para removê-la
No geral, mostra bem os princípios básicos. Gosto especialmente de “use
.shapereligiosamente.asserteplt.imshowsão seus amigos”, e as pré e pós-condições de shape devem sempre ser verificadas com assertTambém fico curioso se
bearoutypeguarddão suporte a esse tipo de verificação via decoratorsMas a parte “escolha um modelo pequeno, simples e rápido e crie helpers para avaliá-lo qualitativamente” talvez quisesse dizer avaliação quantitativa. Assim você cria uma baseline numérica para comparar com técnicas mais avançadas
O conselho de implementar os componentes do artigo um por um também deveria ser mais preciso. Artigos normalmente tentam várias mudanças de uma vez e depois mostram a contribuição de cada componente por meio de estudos de ablação; então acho melhor começar pelas mudanças centrais de arquitetura e, seguindo a ordem de maior impacto nos estudos de ablação e respeitando as dependências, avaliar cada mudança atômica
bearoutypeguard, graças à https://peps.python.org/pep-0646/, parte disso pode ser colocada diretamente em anotações de tipo do PythonPor exemplo, dá para expressar o shape por eixo no tipo, como
ndarray[float, Dim1, *Shape], e sobrecarregar o shape de retorno conforme o valor deaxisbear/typeguardAinda assim, parece difícil o Python chegar ao nível da Julia. O sistema de tipos da Julia permite garantir com muito mais facilidade que os tamanhos das matrizes são compatíveis
Fico curioso sobre qual é o princípio para usar SwiGLU em vez de ReLU. Não sei se os autores simplesmente testaram todas as funções não lineares possíveis ou se há uma razão mais profunda
Como o bearblog está sofrendo um DDoS, deixo o repositório: https://github.com/bkitano/llama-from-scratch
Como alguém que está aprendendo IA, tentei resumir brevemente os termos que aparecem no texto. Token é um identificador inteiro que representa um pedaço de texto e, em LLMs, costuma-se agrupar fragmentos de caracteres usados com frequência dentro de um vocabulário de tamanho limitado
A função de perda é um valor que mede a diferença entre a previsão e a resposta correta, e quanto menor, melhor. PyTorch é uma biblioteca para lidar com tensores e redes neurais, e um tensor é um array numérico multidimensional que inclui escalares, vetores e matrizes
Uma rede neural é uma estrutura de conexões entre neurônios com pesos e vieses, e uma camada linear é uma estrutura simples em que todas as entradas e saídas estão conectadas. ReLU é uma função de ativação como
Math.max(0, x); como empilhar apenas camadas lineares acaba sendo equivalente a uma única função linear, ela introduz não linearidade para aumentar a capacidade de aprendizadoO gradiente é uma quantidade de variação numérica calculada durante o treinamento para tornar o modelo mais preciso, e a normalização em lote é um método que ajuda o aprendizado ajustando os números em fluxo. A codificação posicional informa, por meio de vetores, as posições relativas dos tokens
Em Python, o operador
@é um alias de__matmul__e é usado para multiplicação de matrizes. Uma época é treinar uma vez sobre todo o dataset, e um lote é a quantidade de dados inserida de uma vez antes de atualizar os parâmetrosAtenção é o componente central que faz um LLM funcionar: ela processa os tokens de entrada em paralelo para criar tensores intermediários, que depois são usados para gerar os tokens de saída
Por exemplo,
writ, comum awriting,writtenewriter, pode se tornar um token, ewriterpode ser tokenizado comowriteerEmbedding é a etapa que transforma esses tokens em representações numéricas próprias
Se houver uma implementação existente do modelo e checkpoints, a maneira mais eficaz de verificar se a sua implementação está correta é carregar esse checkpoint e comparar os valores de saída
Se a saída não bater, em geral é porque algum detalhe da implementação está errado, e dá para seguir sistematicamente cada camada até encontrar a diferença real. No processo, você também pode acabar descobrindo algo estranho na implementação existente
Isso diz respeito ao modelo em si; o treinamento é um eixo separado. Ainda assim, se os hiperparâmetros estiverem mais ou menos parecidos, quando a implementação do modelo está correta, em geral as coisas ficam bem
Tanto a forma de ler artigos quanto o conteúdo desse artigo são bons, e também recomendo a série Makemore do Karpathy
Os conselhos resumidos são muito bons, e acho que a recomendação de fazer assert do shape dos tensores se aplica a qualquer biblioteca geral de álgebra linear. Ao escrever código complexo de álgebra linear, é muito importante avançar em pequenos passos e programar de forma defensiva
Programar álgebra linear em linguagens mainstream é horrível porque não há verificação de shape em tempo de compilação. O shape de um tensor deveria fazer parte do tipo, e tentar multiplicar
3x4por3x4sem transpor deveria sequer compilarRodar um cálculo longo e depois falhar em uma operação por incompatibilidade de dimensões é realmente o pior caso
Também acho que, nos tensores do PyTorch, o dispositivo deveria ser tipado estaticamente. Hoje, se você tenta multiplicar um tensor na memória da CPU por um tensor na memória da GPU, recebe um erro em tempo de execução