Felafax BlogTune Llama3 405B on AMD MI300x (nossa jornada)
Introdução
- À medida que os modelos open source ficam maiores, cresce a necessidade de uma infraestrutura robusta para lidar com o treinamento de IA em larga escala
- A Felafax demonstrou a eficiência do hardware AMD ao fazer o ajuste fino do modelo LLaMA 3.1 405B em GPUs AMD
- Todo o trabalho foi disponibilizado como open source no GitHub
- As GPUs AMD MI300X oferecem alto desempenho em comparação com o hardware de IA da NVIDIA
- O projeto foi possível com o apoio da TensorWave
O que é JAX e por que escolhê-lo
- JAX é uma poderosa biblioteca de machine learning que combina uma API semelhante ao NumPy, diferenciação automática e o compilador XLA do Google
- Oferece excelentes APIs para paralelismo de modelos, sendo ideal para o treinamento de modelos em larga escala
Vantagens do JAX
- Funções puras: o JAX incentiva a escrita de funções puras, o que facilita compor, depurar e ler o código
- Paralelismo avançado: a API flexível de JIT do JAX oferece suporte a paralelismo avançado de dados e de modelos, essencial para treinamento em larga escala
- Base de código limpa: a filosofia de design do JAX incentiva a escrita de código portátil entre diferentes plataformas de hardware
Por que o JAX se destaca em hardware não NVIDIA
- Abordagem independente de hardware: o JAX usa o compilador XLA para compilar cálculos em uma representação intermediária independente de hardware
- Otimização independente de plataforma: o compilador XLA realiza otimizações independentemente do hardware
- Portabilidade simples: com JAX, a mudança de NVIDIA para AMD exige alterações mínimas no código
Configurando JAX em GPUs AMD
- Baixaram a imagem Docker, iniciaram o contêiner e verificaram a instalação
- Treinaram o modelo LLaMA 405B usando 8 GPUs AMD MI300X
Treinamento do LLaMA 405B: desempenho e escalabilidade
- Treinaram o modelo LLaMA 405B em GPUs AMD usando JAX
- Com ajuste fino via LoRA, ajustaram os pesos do modelo e os parâmetros LoRA com precisão bfloat16
- Tamanho do modelo: ocupa cerca de 800 GB de VRAM
- Pesos LoRA e estado do otimizador: ocupam cerca de 400 GB de VRAM
- Uso total de VRAM: cerca de 1200 GB
- Velocidade de treinamento: cerca de 35 tokens por segundo
- Eficiência de memória: mantida em cerca de 70%
- Escalabilidade: com JAX, escalou quase linearmente em 8 GPUs
Nossa configuração de treinamento
- Converteram o LLaMA 3.1 de PyTorch para JAX
- Distribuíram o modelo com eficiência por meio de carregamento do modelo e sharding de parâmetros
Sharding de parâmetros no JAX
- Usaram o recurso de malha de dispositivos do JAX para distribuir o modelo com eficiência entre 8 GPUs AMD
- Definiram regras de sharding de parâmetros para fragmentar a dimensão de cada tensor de acordo com os eixos da malha
Implementação do treinamento com LoRA
- LoRA reduz o número de parâmetros treináveis ao decompor as atualizações de pesos em matrizes de baixa ordem
- Implementaram uma camada
LoRADense incluindo parâmetros LoRA
- Distribuíram os parâmetros LoRA de forma eficiente para otimizar o uso de memória e a eficiência computacional
Conclusão
- A experiência de fazer ajuste fino do modelo LLaMA 3.1 405B com GPUs AMD e JAX foi muito positiva
- Aproveitaram os poderosos recursos de paralelismo do JAX e sua abordagem independente de hardware para distribuir o modelo com eficiência
- Demonstraram que as GPUs AMD são uma alternativa robusta para treinamento de IA em larga escala
- É possível conferir o código completo no repositório do GitHub e executá-lo por conta própria
Resumo do GN⁺
- Este artigo explica como treinar com eficiência modelos de IA em larga escala usando GPUs AMD e JAX
- Destaca o hardware AMD como uma alternativa com melhor custo-benefício em comparação com a NVIDIA
- A abordagem independente de hardware do JAX aumenta a portabilidade do código e facilita a manutenção
- Oferece informações úteis e código prático para quem tem interesse em treinamento de modelos em larga escala
- Projetos com funcionalidades semelhantes incluem CUDA e PyTorch, da NVIDIA
1 comentários
Comentário do Hacker News
Compartilha resultados do ajuste fino do modelo Llama3.1 405B em 8 GPUs AMD MI300x usando JAX
Sugere explorar formas de superar as limitações de memória e executar uma versão compilada com JIT
Compartilha experiência com GPUs AMD e suporte ROCm
Compartilha experiência testando o modelo 405B do ponto de vista de inferência
torch.cudanão é tão ruim assimPergunta sobre a ausência de dados de desempenho
Questiona por que o Obsidian (app de anotações) está fazendo isso
Pede ao @dang para incluir o nome de usuário na URL