7 pontos por GN⁺ 2025-02-07 | 1 comentários | Compartilhar no WhatsApp
  • Otimizar o desempenho de deep learning em grande escala pode parecer “alquimia”, mas na prática é possível melhorar a eficiência do modelo com princípios simples e compreensíveis
  • De um único acelerador a dezenas de milhares deles, princípios relativamente simples se aplicam em todos os casos, e entendê-los permite realizar tarefas úteis como:
    • Estimar aproximadamente o quão perto cada parte do modelo está do valor ótimo teórico
    • Criar uma base para escolher diferentes técnicas de paralelização em várias escalas
    • Estimar custo e tempo necessários para treinar e executar grandes modelos Transformer
    • Projetar algoritmos que aproveitem as características de um hardware específico
    • Projetar hardware com uma compreensão clara dos limites de desempenho dos algoritmos atuais
  • Conhecimentos prévios necessários
    • É necessário entender os conceitos básicos de LLMs e da arquitetura Transformer
    • Não é obrigatório entender como operações em larga escala funcionam
    • É melhor ainda se você tiver conhecimento básico de treinamento de LLMs e experiência com JAX
    • Recomenda-se consultar o post de blog sobre a arquitetura Transformer e os slides sobre escalabilidade de LLMs em JAX
  • Objetivos
    • Desenvolver a capacidade de estimar qual forma de paralelização do modelo funciona melhor no hardware disponível
    • Desenvolver a capacidade de calcular aproximadamente o tempo e o custo de treinamento e inferência

Por que isso importa

  • Até 3 ou 4 anos atrás, a maioria dos pesquisadores de ML não precisava conhecer a fundo esse tipo de otimização em grande escala
    • Hoje, até modelos “pequenos” operam próximos dos limites do hardware, então entender formas eficientes de trabalhar em larga escala se tornou essencial
    • A história de ML pode ser vista como um processo de avanço conjunto entre inovação em sistemas e melhorias de software
    • Como os modelos Transformer recentes usam o hardware até o limite, sem entender a eficiência do modelo há grande chance de que novas arquiteturas ou pesquisas fracassem na aplicação real
    • Mesmo que se obtenha 20% de melhora em benchmark, se a eficiência de hardware cair 20%, no fim a utilidade prática será baixa
  • O objetivo central do escalonamento de modelos é fazer com que o throughput aumente linearmente ao elevar o número de chips (aceleradores)
    • Isso é chamado de "strong scaling"
    • Adicionar chips reduz o tempo de computação, mas gera custo de comunicação entre chips
    • Se a comunicação levar mais tempo do que a computação, entra-se em um estado "communication bound", no qual o strong scaling se torna impossível
    • Se for possível prever onde esses gargalos vão surgir com base em um bom entendimento do hardware, dá para projetar ou reorganizar o modelo para evitá-los
  • O objetivo deste livro é explicar como o hardware de TPU (e GPU) funciona e como a arquitetura Transformer evoluiu para funcionar bem no hardware atual
    • A expectativa é que isso seja útil tanto para pesquisadores que projetam novas arquiteturas quanto para engenheiros que tentam executar os LLMs da geração atual com rapidez

Visão geral

  • Este texto é organizado da seguinte forma
  • A Seção 1 explica, por meio de análise roofline, os fatores que determinam os limites de desempenho do modelo (comunicação, computação e memória)
  • As Seções 2 e 3 tratam da estrutura interna de TPUs e GPUs e de como os chips se conectam entre si
    • Com isso, respondem a perguntas como
      • Quão rapidamente uma multiplicação de matrizes de determinado tamanho pode ser executada em teoria
      • Em que ponto a computação passa a ficar limitada pela largura de banda de memória ou de comunicação
      • Como um cluster de TPUs é conectado e quanto tempo aproximadamente leva para mover dados de um chip para outro
      • Como multiplicar matrizes distribuídas de forma eficiente
  • A Seção 4 aborda em detalhe a formulação da arquitetura Transformer (tamanhos de matrizes, número de parâmetros, FLOPs)
  • As Seções 5 e 7 são o núcleo e apresentam diferentes formas de paralelizar modelos em vários chips
    • Data parallel, Tensor parallel, Pipeline parallel, Expert parallel
    • Também cobre técnicas de economia de memória como ZeRO, Rematerialisation, Host offload e Gradient accumulation
  • As Seções 6 e 8 usam o modelo LLaMA-3 como exemplo de treinamento e inferência em TPUs, apresentando custo, tempo e formas de configuração na prática
  • Por fim, as Seções 9 e 10 mostram métodos práticos para fazer profiling, depuração e paralelização de modelos em JAX

Mais detalhes: resumo das principais seções do livro

1 comentários

 
GN⁺ 2025-02-07
Comentários do Hacker News
  • Há expectativa de que o JAX substitua o pytorch/cuda nos próximos anos. A questão do PTX com a equipe do Deepseek mostra o valor de investir em abordagens de nível mais baixo para extrair o máximo do desempenho do hardware
    • Isso foi usado internamente no Google como um manual para trabalho de desempenho. É surpreendente que tenha sido tornado público, mas parece que os detalhes relacionados ao Gemini foram removidos
    • Este guia é bom porque, graças ao JAX/XLA, dá para migrar diretamente para GPU
    • Há quem se pergunte por que o JAX usa tracing em vez de AST
    • Foi compartilhado um link para a thread do autor no Twitter
    • Há alguém procurando uma forma de converter um site Jekyll em PDF
    • Há elogios ao texto e agradecimentos
    • Há comentários perguntando como foram feitas as animações incríveis