Como escalar seu modelo: uma visão sistêmica de LLMs em TPUs
(jax-ml.github.io)- Otimizar o desempenho de deep learning em grande escala pode parecer “alquimia”, mas na prática é possível melhorar a eficiência do modelo com princípios simples e compreensíveis
- De um único acelerador a dezenas de milhares deles, princípios relativamente simples se aplicam em todos os casos, e entendê-los permite realizar tarefas úteis como:
- Estimar aproximadamente o quão perto cada parte do modelo está do valor ótimo teórico
- Criar uma base para escolher diferentes técnicas de paralelização em várias escalas
- Estimar custo e tempo necessários para treinar e executar grandes modelos Transformer
- Projetar algoritmos que aproveitem as características de um hardware específico
- Projetar hardware com uma compreensão clara dos limites de desempenho dos algoritmos atuais
- Conhecimentos prévios necessários
- É necessário entender os conceitos básicos de LLMs e da arquitetura Transformer
- Não é obrigatório entender como operações em larga escala funcionam
- É melhor ainda se você tiver conhecimento básico de treinamento de LLMs e experiência com JAX
- Recomenda-se consultar o post de blog sobre a arquitetura Transformer e os slides sobre escalabilidade de LLMs em JAX
- Objetivos
- Desenvolver a capacidade de estimar qual forma de paralelização do modelo funciona melhor no hardware disponível
- Desenvolver a capacidade de calcular aproximadamente o tempo e o custo de treinamento e inferência
Por que isso importa
- Até 3 ou 4 anos atrás, a maioria dos pesquisadores de ML não precisava conhecer a fundo esse tipo de otimização em grande escala
- Hoje, até modelos “pequenos” operam próximos dos limites do hardware, então entender formas eficientes de trabalhar em larga escala se tornou essencial
- A história de ML pode ser vista como um processo de avanço conjunto entre inovação em sistemas e melhorias de software
- Como os modelos Transformer recentes usam o hardware até o limite, sem entender a eficiência do modelo há grande chance de que novas arquiteturas ou pesquisas fracassem na aplicação real
- Mesmo que se obtenha 20% de melhora em benchmark, se a eficiência de hardware cair 20%, no fim a utilidade prática será baixa
- O objetivo central do escalonamento de modelos é fazer com que o throughput aumente linearmente ao elevar o número de chips (aceleradores)
- Isso é chamado de "strong scaling"
- Adicionar chips reduz o tempo de computação, mas gera custo de comunicação entre chips
- Se a comunicação levar mais tempo do que a computação, entra-se em um estado "communication bound", no qual o strong scaling se torna impossível
- Se for possível prever onde esses gargalos vão surgir com base em um bom entendimento do hardware, dá para projetar ou reorganizar o modelo para evitá-los
- O objetivo deste livro é explicar como o hardware de TPU (e GPU) funciona e como a arquitetura Transformer evoluiu para funcionar bem no hardware atual
- A expectativa é que isso seja útil tanto para pesquisadores que projetam novas arquiteturas quanto para engenheiros que tentam executar os LLMs da geração atual com rapidez
Visão geral
- Este texto é organizado da seguinte forma
- A Seção 1 explica, por meio de análise roofline, os fatores que determinam os limites de desempenho do modelo (comunicação, computação e memória)
- As Seções 2 e 3 tratam da estrutura interna de TPUs e GPUs e de como os chips se conectam entre si
- Com isso, respondem a perguntas como
- Quão rapidamente uma multiplicação de matrizes de determinado tamanho pode ser executada em teoria
- Em que ponto a computação passa a ficar limitada pela largura de banda de memória ou de comunicação
- Como um cluster de TPUs é conectado e quanto tempo aproximadamente leva para mover dados de um chip para outro
- Como multiplicar matrizes distribuídas de forma eficiente
- Com isso, respondem a perguntas como
- A Seção 4 aborda em detalhe a formulação da arquitetura Transformer (tamanhos de matrizes, número de parâmetros, FLOPs)
- As Seções 5 e 7 são o núcleo e apresentam diferentes formas de paralelizar modelos em vários chips
- Data parallel, Tensor parallel, Pipeline parallel, Expert parallel
- Também cobre técnicas de economia de memória como ZeRO, Rematerialisation, Host offload e Gradient accumulation
- As Seções 6 e 8 usam o modelo LLaMA-3 como exemplo de treinamento e inferência em TPUs, apresentando custo, tempo e formas de configuração na prática
- Por fim, as Seções 9 e 10 mostram métodos práticos para fazer profiling, depuração e paralelização de modelos em JAX
Mais detalhes: resumo das principais seções do livro
-
Parte 1: Preliminaries
-
Seção 1: Introdução simples à análise Roofline
- Os três fatores que limitam algoritmos: computação, comunicação e memória
- Aprende-se a partir disso como estimar o limite superior da velocidade de computação
-
Seção 2: Uma forma de olhar para TPUs
- Como a TPU realiza computação
- O que é uma estrutura de systolic array
- Entendimento básico de como a TPU oferece largura de banda de memória e comunicação
-
Seção 3: Matrizes distribuídas e multiplicação distribuída
- Técnica de armazenar os parâmetros do modelo distribuídos entre vários chips (Sharding)
- Como lidar com comunicação e gargalos que surgem em operações com matrizes distribuídas
-
-
Parte 2: Transformers
-
Seção 4: Organização das fórmulas necessárias de Transformer
- Que forma concreta a multiplicação de matrizes assume em Transformers
- Como calcular número de parâmetros, FLOPs, tamanho do cache KV etc.
- Entender quanta computação a operação de Attention exige em comparação com blocos Feed-Forward
-
Seção 5: Estratégias de paralelização para treinamento de Transformers
- Introdução a técnicas de Data parallel, Tensor parallel, Pipeline parallel e Expert parallel
- Medidas para economizar memória como ZeRO(FSDP), Rematerialisation, Gradient accumulation e Host offload
- Construção dos conceitos para configurar a paralelização de acordo com o tamanho do modelo e o número de chips
-
Seção 6: Aplicação do treinamento de LLaMA 3 em TPU
- Estimativa de tempo e custo assumindo o treinamento do modelo LLaMA 3 em um ambiente real de TPU
- Exemplos concretos de batch size, forma de paralelização e uso de memória
-
Seção 7: Tudo sobre inferência com Transformers
- Na inferência, a latência surge como um novo fator importante
- Uso de memória e problemas de comunicação causados por itens como o cache KV
- Discussão sobre como distribuir e conectar vários chips para servir o modelo
-
Seção 8: Aplicação de serving do LLaMA 3 em TPU
- Análise aproximada de trade-offs entre custo, latência e throughput ao assumir o serving do LLaMA 3 em TPU v5e
-
-
Parte 3: Practical Tutorials
-
Seção 9: Como fazer profiling de código em TPU
- Entender a stack JAX+XLA
- Identificação de problemas reais de degradação de desempenho e soluções
- Como usar o profiler do JAX/TensorBoard
-
Seção 10: Programando TPUs com JAX
- Como usar as APIs de paralelização (primitives) do JAX
- Aprendizado dos conceitos de computação paralela por meio de exemplos e exercícios
-
Seção 11: Conclusão e materiais adicionais
- Leituras adicionais sobre TPUs e LLMs
- Encerramento breve do conteúdo geral, com menção a perspectivas futuras
-
1 comentários
Comentários do Hacker News