2 pontos por GN⁺ 2023-08-10 | 1 comentários | Compartilhar no WhatsApp
  • Brian Kitano construiu uma versão reduzida do Llama usando TinyShakespeare e concluiu que, para implementar um artigo com segurança, o ideal é começar com um modelo pequeno, trocar os componentes um a um e treinar e avaliar a cada etapa
  • Ele preparou primeiro funções auxiliares de validação como divisão de dados, geração de batches, avaliação de perda e função de geração, confirmou com um modelo simples que a compilação e o treinamento funcionavam e só então adicionou os componentes do Llama
  • Ao adicionar RMSNorm, RoPE e SwiGLU em sequência, verificou se cada camada se comportava como esperado por meio de shape de tensores, propriedades das fórmulas e mapas de atenção
  • Ao remover a causal mask na atenção com RoPE, a perda de validação caiu até 0,16, mas a qualidade da geração piorou; a causa foi vazamento de informação por olhar tokens futuros
  • O Llama reduzido final tinha 4 blocos e cerca de 2,37 milhões de parâmetros, reduziu a perda de validação para cerca de 1,0, e também exigiu checagem do fluxo de gradientes e do agendamento da taxa de aprendizado

Começar pequeno e ganhar confiança de forma iterativa

  • O ponto central para implementar um artigo é começar com um modelo pequeno, trocar os componentes um a um e repetir treinamento e avaliação a cada mudança
  • Primeiro, ele preparou funções auxiliares para verificar o modelo de forma quantitativa
    • divisão de dados
    • loop de treinamento
    • visualização da perda
    • avaliação da perda de validação
  • Em vez de portar todos os componentes do artigo de uma vez, ele também criou uma função de avaliação qualitativa para observar os resultados de geração com um modelo simples e rápido que já conhecia bem
  • As camadas de tensores foram verificadas com .shape, assert e plt.imshow; em vez de partir direto para otimização de multiplicação de matrizes, ele primeiro conferiu manualmente os resultados esperados e depois usou funções do torch para ganhar eficiência
  • É preciso testar variando batch size, comprimento de sequência e dimensão do embedding; código que só funciona para um único tamanho pode quebrar na inferência

Dataset e configuração básica

  • O alvo da implementação é uma versão muito reduzida do Llama, da Meta AI, e os dados de treinamento vêm do TinyShakespeare
  • O Llama é treinado com 1,4T tokens, mas aqui foi usado o TinyShakespeare, com cerca de 1,11 milhão de caracteres
  • O Llama original usa o tokenizador byte-pair encoding SentencePiece, mas esta implementação usa um tokenizador em nível de caractere mais simples
    • vocabulary size de 65
    • como o dataset é pequeno, o armazenamento em memória não foi otimizado separadamente
  • Um dicionário MASTER_CONFIG gerencia configurações do modelo como vocab_size, batch_size, context_window e d_model
    • o objetivo é reduzir constantes e números mágicos e tornar o código mais legível
  • A função get_batches divide os dados em 80% treino, 10% validação e 10% teste, e cria a entrada x e o rótulo y deslocado em um caractere a partir de um ponto inicial aleatório

Confirmando compilação e treinamento com um modelo básico

  • O primeiro modelo foi o SimpleBrokenModel, composto por embedding e uma rede feed-forward simples
    • nn.Embedding
    • Linear
    • ReLU
    • Linear
  • Na implementação de um artigo, dizer que o modelo “funciona” exige duas condições
    • compilar: os shapes dos tensores precisam encaixar entre as camadas
    • treinar: a perda precisa realmente cair
  • A função evaluate_loss calcula a perda média amostrando 10 batches das divisões de treino e validação
  • Após 1000 epochs, o SimpleBrokenModel ficou com perda de validação em torno de 3,94, quase sem cair a partir da cross-entropy inicial de 4,17
  • O problema era que valores já processados por softmax estavam sendo passados para F.cross_entropy
    • O F.cross_entropy do PyTorch recebe diretamente logits não normalizados
    • Ao remover o softmax, o SimpleModel reduziu a perda de validação para cerca de 2,51
  • Depois disso, ele adicionou a função generate para inspecionar diretamente os caracteres produzidos pelo modelo; o modelo básico ainda não era ideal, mas já apresentava queda na perda de validação

Componente 1 do Llama: RMSNorm

  • Em relação ao Transformer original, o Llama usa três mudanças arquiteturais principais
    • RMSNorm com pre-normalization
    • Rotary embeddings
    • função de ativação SwiGLU
  • O Transformer original usa BatchNormalization, enquanto o Llama usa RMSNorm, que escala pelo variance sem centralizar o vetor
  • Enquanto o Transformer original aplica normalização na saída da camada de atenção em esquema post-normalization, o Llama aplica primeiro na entrada, em esquema pre-normalization
  • O RMSNorm implementado assume entrada com shape (batch, seq_len, d_model)
  • O resultado do RMSNorm foi testado pela propriedade de que a norma da camada se torna a raiz quadrada do número de elementos da camada
    • assert
    • comparação linha a linha
    • torch.allclose
  • O SimpleModel_RMS, que adiciona RMSNorm ao modelo básico, reduziu levemente a perda de validação para cerca de 2,5015

Componente 2 do Llama: RoPE e causal mask

  • RoPE é uma forma de codificação posicional para Transformers que representa a posição do token por meio da rotação do embedding
  • get_rotary_matrix cria uma matriz de rotação por posição para a janela de contexto e a dimensão do embedding
  • A implementação de RoPE foi testada com a seguinte propriedade
    • o produto interno entre dois vetores rotacionados nas posições m e n deve coincidir com a rotação relativa n-m
  • RoPEAttentionHead cria w_q, w_k e w_v, aplica a rotação RoPE em query e key, e depois usa F.scaled_dot_product_attention
  • É preciso tomar cuidado com a diferença de shape dos tensores entre treinamento e inferência
    • no treinamento, muitas vezes o shape segue a configuração, como (config['batch_size'], config['context_window'], config['d_model'])
    • na inferência, pode haver um único exemplo como (1, 1, config['d_model'])
    • dentro de forward, a indexação deve se basear no shape da entrada, e não nos valores fixos da configuração do modelo
  • O modelo que adicionou atenção multi-head com RoPE sem causal mask teve uma queda brusca da perda de validação até 0,1623, mas gerava saídas ruins como OOOO... e IIII...
  • Ao observar o mapa de atenção, ficou claro que todas as posições estavam consultando todas as outras, causando vazamento de informação ao olhar tokens futuros na previsão do próximo token
  • Ao trocar para RoPEMaskedAttentionHead, aplicando is_causal=True em F.scaled_dot_product_attention, a atenção na parte triangular superior, correspondente ao futuro, caiu quase a zero
  • Com a causal mask, a perda de validação foi para 2,0815, e com mais tempo de treinamento caiu até 1,8985

Componente 3 do Llama: SwiGLU e empilhamento de blocos

  • O Llama substitui a não linearidade ReLU pela função de ativação SwiGLU
  • O SwiGLU implementado é uma Swish-gated linear unit, usando duas transformações lineares e um parâmetro beta aprendível
  • O RopeModel, com SwiGLU na parte feed-forward, tinha 592.706 parâmetros e perda de validação em torno de 1,8963
  • Em seguida, foi criado o LlamaBlock, reunindo a seguinte estrutura em um único bloco
    • RMSNorm com pre-normalization
    • atenção multi-head com RoPE mascarada
    • residual connection
    • RMSNorm com pre-normalization
    • feed-forward com SwiGLU
    • residual connection
  • O modelo final Llama foi configurado com n_layers=4 e empilha 4 LlamaBlock com nn.Sequential baseado em OrderedDict
  • O modelo final tem 2.370.246 parâmetros, e os resultados de treinamento foram os seguintes
    • após o treinamento inicial com 4 camadas, perda de validação de 1,5532
    • após treinar mais por 10.000 epochs, perda de validação de 1,1479
    • após treinamento adicional, perda de validação de 0,9997
    • a perda em um batch da divisão de teste foi 1,2358

Resultado de geração e pontos de depuração

  • O modelo final gera nomes, quebras de linha e fragmentos de palavras parecidos com o estilo de Shakespeare, mas a qualidade real das sentenças ainda é limitada
  • A perda de cross-entropy pode ser interpretada intuitivamente do ponto de vista da escolha de tokens
    • a perda inicial de 4,17 é próxima de uma escolha aleatória com vocabulary size 65
    • a perda de 1,08 pode ser entendida como escolher aleatoriamente entre cerca de 2,9 tokens
  • O fluxo de gradientes foi verificado com a função show_grads
    • ela calcula a proporção de gradientes com valor absoluto pequeno em cada parâmetro
    • se a maioria dos gradientes dos parâmetros não estiver próxima de zero, o fluxo está em boas condições
  • O Llama original usa agendamento de aprendizado com Cosine Annealing, mas nesta implementação os resultados experimentais foram piores
  • No experimento com Cosine Annealing, mesmo com tolerance muito baixa, o attention bias quase não recebeu sinal; como a razão não ficou clara, na prática é mais seguro começar de forma simples

1 comentários

 
GN⁺ 2023-08-10
Opiniões no Hacker News
  • Parece haver um bug na implementação do SwiGLU: no artigo de referência, o beta da feed-forward network não é um valor treinável, mas uma constante, e é definido como FFnSwiGLU = Swish1...
    Com base na equação 6 de https://arxiv.org/pdf/2002.05202.pdf
    Na implementação oficial do Llama, o beta constante também foi removido: https://github.com/facebookresearch/llama/blob/main/llama/mo...
    Pelas linhas "feedforward.1.beta', 0.0" no log do blog, durante o treinamento o beta degenerou para 0, mas originalmente deveria ser a constante 1

    • Isso mostra como é difícil implementar corretamente uma rede neural Transformer. Dá para errar em várias etapas, e normalmente isso aparece apenas como “um desempenho um pouco pior que o original”, o que torna difícil ter certeza
      Muitas vezes a rede se adapta às mudanças, intencionais ou não, e, após o treinamento, várias variações de arquitetura acabam se comportando de forma parecida, então às vezes fica ambíguo se ela precisa coincidir exatamente com a original
      Uma forma de encontrar esse tipo de erro é fazer os valores de saída baterem exatamente com uma implementação de referência. Mesmo com pesos aleatórios, como nos modelos tiny-random da HuggingFace, a saída deve ser exatamente igual; se não for, é sinal de bug
      Só que esse método funciona bem apenas para bugs que aparecem durante a inferência; problemas que ocorrem apenas no processamento de dados, no otimizador ou durante o treinamento são mais difíceis de detectar
    • Em Transformers, acho que os valores de viés em geral não batem muito bem
      Pessoalmente, acho que é por causa de propriedades autorregressivas e semelhantes a ODE, mas não tenho certeza suficiente
  • O trabalho é excelente, mas os SimpleBrokenModel e SimpleModel iniciais têm bastante computação desperdiçada. A sequência é embedding 65 -> 128, linear 128 -> 128, ReLU, linear 128 -> 65; como não há não linearidade entre as duas primeiras camadas, e ambas são lineares, a segunda camada linear é essencialmente inútil
    Esse modelo acaba sendo equivalente a um MLP clássico de uma única camada oculta e, em termos de FLOPS, desperdiça 128*128=16k operações de um total de 128*128+65*128=24k

    • Parece que não sou o único que ainda está pegando o jeito das não linearidades. Fico curioso se a melhor correção aqui seria colocar ReLU ou SwiGLU entre o embedding e a primeira camada linear, ou simplesmente remover a camada linear
      A camada de embedding é uma estrutura especial que transforma índices de tokens em vetores de embedding, então imagino que não dê para removê-la
  • No geral, mostra bem os princípios básicos. Gosto especialmente de “use .shape religiosamente. assert e plt.imshow são seus amigos”, e as pré e pós-condições de shape devem sempre ser verificadas com assert
    Também fico curioso se bear ou typeguard dão suporte a esse tipo de verificação via decorators
    Mas a parte “escolha um modelo pequeno, simples e rápido e crie helpers para avaliá-lo qualitativamente” talvez quisesse dizer avaliação quantitativa. Assim você cria uma baseline numérica para comparar com técnicas mais avançadas
    O conselho de implementar os componentes do artigo um por um também deveria ser mais preciso. Artigos normalmente tentam várias mudanças de uma vez e depois mostram a contribuição de cada componente por meio de estudos de ablação; então acho melhor começar pelas mudanças centrais de arquitetura e, seguindo a ordem de maior impacto nos estudos de ablação e respeitando as dependências, avaliar cada mudança atômica

    • Em vez de bear ou typeguard, graças à https://peps.python.org/pep-0646/, parte disso pode ser colocada diretamente em anotações de tipo do Python
      Por exemplo, dá para expressar o shape por eixo no tipo, como ndarray[float, Dim1, *Shape], e sobrecarregar o shape de retorno conforme o valor de axis
    • Não conheço bem o PyTorch, mas, da última vez que verifiquei, ele não fazia isso; já o Jax dá suporte a verificações básicas de runtime de shapes de matrizes por meio de bear / typeguard
      Ainda assim, parece difícil o Python chegar ao nível da Julia. O sistema de tipos da Julia permite garantir com muito mais facilidade que os tamanhos das matrizes são compatíveis
  • Fico curioso sobre qual é o princípio para usar SwiGLU em vez de ReLU. Não sei se os autores simplesmente testaram todas as funções não lineares possíveis ou se há uma razão mais profunda

    • Como em muitas pesquisas, se não houver uma explicação clara sustentada por um estudo rigoroso, é bem provável que tenham feito uma busca por hill climbing aleatória com mudanças de uma linha que pareciam legais, e parado quando chegou a hora de escrever o artigo e fazer os estudos de ablação
  • Como o bearblog está sofrendo um DDoS, deixo o repositório: https://github.com/bkitano/llama-from-scratch

  • Como alguém que está aprendendo IA, tentei resumir brevemente os termos que aparecem no texto. Token é um identificador inteiro que representa um pedaço de texto e, em LLMs, costuma-se agrupar fragmentos de caracteres usados com frequência dentro de um vocabulário de tamanho limitado
    A função de perda é um valor que mede a diferença entre a previsão e a resposta correta, e quanto menor, melhor. PyTorch é uma biblioteca para lidar com tensores e redes neurais, e um tensor é um array numérico multidimensional que inclui escalares, vetores e matrizes
    Uma rede neural é uma estrutura de conexões entre neurônios com pesos e vieses, e uma camada linear é uma estrutura simples em que todas as entradas e saídas estão conectadas. ReLU é uma função de ativação como Math.max(0, x); como empilhar apenas camadas lineares acaba sendo equivalente a uma única função linear, ela introduz não linearidade para aumentar a capacidade de aprendizado
    O gradiente é uma quantidade de variação numérica calculada durante o treinamento para tornar o modelo mais preciso, e a normalização em lote é um método que ajuda o aprendizado ajustando os números em fluxo. A codificação posicional informa, por meio de vetores, as posições relativas dos tokens
    Em Python, o operador @ é um alias de __matmul__ e é usado para multiplicação de matrizes. Uma época é treinar uma vez sobre todo o dataset, e um lote é a quantidade de dados inserida de uma vez antes de atualizar os parâmetros
    Atenção é o componente central que faz um LLM funcionar: ela processa os tokens de entrada em paralelo para criar tensores intermediários, que depois são usados para gerar os tokens de saída

    • Fora da área, talvez as pessoas não saibam o que “Karpathy” significa. Se Andrej Karpathy for apresentado com contexto, como “comunicador científico e pesquisador”, fica mais claro que a ideia é consultar seus textos ou vídeos
    • Para iniciantes, é mais correto ver tokens não simplesmente como identificadores inteiros de pedaços de texto, mas como fragmentos de palavras comuns o suficiente para serem úteis por si só
      Por exemplo, writ, comum a writing, written e writer, pode se tornar um token, e writer pode ser tokenizado como writ e er
      Embedding é a etapa que transforma esses tokens em representações numéricas próprias
    • Compor funções lineares resulta novamente em uma função linear. Portanto, se tudo for linear, mesmo empilhando várias camadas, todas exceto uma acabam sendo desperdício; para evitar isso, é necessária não linearidade
    • Além da série de vídeos do Karpathy e do accompanying repo, fico curioso para saber se houve outros materiais ou livros especialmente úteis na jornada de aprendizado
    • Fico curioso sobre o que exatamente a normalização em lote faz e como ela ajuda
  • Se houver uma implementação existente do modelo e checkpoints, a maneira mais eficaz de verificar se a sua implementação está correta é carregar esse checkpoint e comparar os valores de saída
    Se a saída não bater, em geral é porque algum detalhe da implementação está errado, e dá para seguir sistematicamente cada camada até encontrar a diferença real. No processo, você também pode acabar descobrindo algo estranho na implementação existente
    Isso diz respeito ao modelo em si; o treinamento é um eixo separado. Ainda assim, se os hiperparâmetros estiverem mais ou menos parecidos, quando a implementação do modelo está correta, em geral as coisas ficam bem

  • Tanto a forma de ler artigos quanto o conteúdo desse artigo são bons, e também recomendo a série Makemore do Karpathy

  • Os conselhos resumidos são muito bons, e acho que a recomendação de fazer assert do shape dos tensores se aplica a qualquer biblioteca geral de álgebra linear. Ao escrever código complexo de álgebra linear, é muito importante avançar em pequenos passos e programar de forma defensiva
    Programar álgebra linear em linguagens mainstream é horrível porque não há verificação de shape em tempo de compilação. O shape de um tensor deveria fazer parte do tipo, e tentar multiplicar 3x4 por 3x4 sem transpor deveria sequer compilar
    Rodar um cálculo longo e depois falhar em uma operação por incompatibilidade de dimensões é realmente o pior caso
    Também acho que, nos tensores do PyTorch, o dispositivo deveria ser tipado estaticamente. Hoje, se você tenta multiplicar um tensor na memória da CPU por um tensor na memória da GPU, recebe um erro em tempo de execução