2 pontos por GN⁺ 2024-09-24 | 1 comentários | Compartilhar no WhatsApp

Felafax BlogTune Llama3 405B on AMD MI300x (nossa jornada)

Introdução

  • À medida que os modelos open source ficam maiores, cresce a necessidade de uma infraestrutura robusta para lidar com o treinamento de IA em larga escala
  • A Felafax demonstrou a eficiência do hardware AMD ao fazer o ajuste fino do modelo LLaMA 3.1 405B em GPUs AMD
  • Todo o trabalho foi disponibilizado como open source no GitHub
  • As GPUs AMD MI300X oferecem alto desempenho em comparação com o hardware de IA da NVIDIA
  • O projeto foi possível com o apoio da TensorWave

O que é JAX e por que escolhê-lo

  • JAX é uma poderosa biblioteca de machine learning que combina uma API semelhante ao NumPy, diferenciação automática e o compilador XLA do Google
  • Oferece excelentes APIs para paralelismo de modelos, sendo ideal para o treinamento de modelos em larga escala

Vantagens do JAX

  • Funções puras: o JAX incentiva a escrita de funções puras, o que facilita compor, depurar e ler o código
  • Paralelismo avançado: a API flexível de JIT do JAX oferece suporte a paralelismo avançado de dados e de modelos, essencial para treinamento em larga escala
  • Base de código limpa: a filosofia de design do JAX incentiva a escrita de código portátil entre diferentes plataformas de hardware

Por que o JAX se destaca em hardware não NVIDIA

  • Abordagem independente de hardware: o JAX usa o compilador XLA para compilar cálculos em uma representação intermediária independente de hardware
  • Otimização independente de plataforma: o compilador XLA realiza otimizações independentemente do hardware
  • Portabilidade simples: com JAX, a mudança de NVIDIA para AMD exige alterações mínimas no código

Configurando JAX em GPUs AMD

  • Baixaram a imagem Docker, iniciaram o contêiner e verificaram a instalação
  • Treinaram o modelo LLaMA 405B usando 8 GPUs AMD MI300X

Treinamento do LLaMA 405B: desempenho e escalabilidade

  • Treinaram o modelo LLaMA 405B em GPUs AMD usando JAX
  • Com ajuste fino via LoRA, ajustaram os pesos do modelo e os parâmetros LoRA com precisão bfloat16
  • Tamanho do modelo: ocupa cerca de 800 GB de VRAM
  • Pesos LoRA e estado do otimizador: ocupam cerca de 400 GB de VRAM
  • Uso total de VRAM: cerca de 1200 GB
  • Velocidade de treinamento: cerca de 35 tokens por segundo
  • Eficiência de memória: mantida em cerca de 70%
  • Escalabilidade: com JAX, escalou quase linearmente em 8 GPUs

Nossa configuração de treinamento

  • Converteram o LLaMA 3.1 de PyTorch para JAX
  • Distribuíram o modelo com eficiência por meio de carregamento do modelo e sharding de parâmetros

Sharding de parâmetros no JAX

  • Usaram o recurso de malha de dispositivos do JAX para distribuir o modelo com eficiência entre 8 GPUs AMD
  • Definiram regras de sharding de parâmetros para fragmentar a dimensão de cada tensor de acordo com os eixos da malha

Implementação do treinamento com LoRA

  • LoRA reduz o número de parâmetros treináveis ao decompor as atualizações de pesos em matrizes de baixa ordem
  • Implementaram uma camada LoRADense incluindo parâmetros LoRA
  • Distribuíram os parâmetros LoRA de forma eficiente para otimizar o uso de memória e a eficiência computacional

Conclusão

  • A experiência de fazer ajuste fino do modelo LLaMA 3.1 405B com GPUs AMD e JAX foi muito positiva
  • Aproveitaram os poderosos recursos de paralelismo do JAX e sua abordagem independente de hardware para distribuir o modelo com eficiência
  • Demonstraram que as GPUs AMD são uma alternativa robusta para treinamento de IA em larga escala
  • É possível conferir o código completo no repositório do GitHub e executá-lo por conta própria

Resumo do GN⁺

  • Este artigo explica como treinar com eficiência modelos de IA em larga escala usando GPUs AMD e JAX
  • Destaca o hardware AMD como uma alternativa com melhor custo-benefício em comparação com a NVIDIA
  • A abordagem independente de hardware do JAX aumenta a portabilidade do código e facilita a manutenção
  • Oferece informações úteis e código prático para quem tem interesse em treinamento de modelos em larga escala
  • Projetos com funcionalidades semelhantes incluem CUDA e PyTorch, da NVIDIA

1 comentários

 
GN⁺ 2024-09-24
Comentário do Hacker News
  • Compartilha resultados do ajuste fino do modelo Llama3.1 405B em 8 GPUs AMD MI300x usando JAX

    • Alcançou excelente desempenho graças à API avançada de sharding do JAX
    • Fornece links para o post do blog e para o código open source: link do GitHub
    • É uma startup que está construindo infraestrutura de IA para ajustar e servir LLMs em TPU, AMD e Trainium, e não em hardware NVIDIA
    • Muitas empresas tentam fazer o PyTorch rodar em GPUs AMD, mas considera que esse é um caminho difícil
    • O PyTorch está profundamente ligado ao ecossistema da NVIDIA, então fazê-lo funcionar em hardware não NVIDIA exige muitas modificações
    • Acredita que o JAX é mais adequado para hardware não NVIDIA
    • No JAX, o código do modelo de ML é compilado em um grafo HLO independente de hardware, e o compilador XLA faz as otimizações específicas para cada hardware
    • O mesmo código JAX pode rodar em Google TPU e GPU AMD sem alterações
    • A estratégia da empresa é portar modelos para JAX e usar kernels XLA para extrair o máximo desempenho em backends não NVIDIA
    • Portou o Llama 3.1 de PyTorch para JAX pela primeira vez, e agora o mesmo modelo JAX funciona bem em TPU e GPU AMD
    • Gostaria de ouvir opiniões sobre a visão e o repositório
  • Sugere explorar formas de superar as limitações de memória e executar uma versão compilada com JIT

    • Isso poderia trazer ganhos adicionais de desempenho
  • Compartilha experiência com GPUs AMD e suporte ROCm

    • Tentou usar GPUs AMD e suporte ROCm há 1 ano, mas sentiu que a AMD ainda está longe de alcançar a NVIDIA
    • Escolher JAX foi uma abordagem interessante, mas pergunta quais dificuldades houve ao sair do PyTorch
  • Compartilha experiência testando o modelo 405B do ponto de vista de inferência

    • Acha que torch.cuda não é tão ruim assim
    • Entende que é apenas uma questão de nome, já que a versão AMD do PyTorch faz essa tradução
    • Usar o contêiner rocm:pytorch é tão fácil quanto usar o contêiner rocm:jax
    • Aponta que não foram publicados muitos dados de desempenho
    • Pergunta sobre os números de MFU (taxa de utilização do modelo)
  • Pergunta sobre a ausência de dados de desempenho

    • Levanta dúvidas sobre a possibilidade de extrair valor com pedidos em grande volume de GPUs AMD
    • Fica com a impressão de que a resposta é "não"
  • Questiona por que o Obsidian (app de anotações) está fazendo isso

    • No começo, achou que era um post do Obsidian
    • Questiona por que ainda não distinguem GitHub.com de GitHub.io
  • Pede ao @dang para incluir o nome de usuário na URL

    • Este post é sobre um blog gerado por usuário, não sobre o próprio Obsidian