Ajuste fino do Llama 405B com GPU AMD
(publish.obsidian.md)Felafax BlogTune Llama3 405B on AMD MI300x (nossa jornada)
Introdução
- À medida que os modelos open source ficam maiores, cresce a necessidade de uma infraestrutura robusta para lidar com o treinamento de IA em larga escala
- A Felafax demonstrou a eficiência do hardware AMD ao fazer o ajuste fino do modelo LLaMA 3.1 405B em GPUs AMD
- Todo o trabalho foi disponibilizado como open source no GitHub
- As GPUs AMD MI300X oferecem alto desempenho em comparação com o hardware de IA da NVIDIA
- O projeto foi possível com o apoio da TensorWave
O que é JAX e por que escolhê-lo
- JAX é uma poderosa biblioteca de machine learning que combina uma API semelhante ao NumPy, diferenciação automática e o compilador XLA do Google
- Oferece excelentes APIs para paralelismo de modelos, sendo ideal para o treinamento de modelos em larga escala
Vantagens do JAX
- Funções puras: o JAX incentiva a escrita de funções puras, o que facilita compor, depurar e ler o código
- Paralelismo avançado: a API flexível de JIT do JAX oferece suporte a paralelismo avançado de dados e de modelos, essencial para treinamento em larga escala
- Base de código limpa: a filosofia de design do JAX incentiva a escrita de código portátil entre diferentes plataformas de hardware
Por que o JAX se destaca em hardware não NVIDIA
- Abordagem independente de hardware: o JAX usa o compilador XLA para compilar cálculos em uma representação intermediária independente de hardware
- Otimização independente de plataforma: o compilador XLA realiza otimizações independentemente do hardware
- Portabilidade simples: com JAX, a mudança de NVIDIA para AMD exige alterações mínimas no código
Configurando JAX em GPUs AMD
- Baixaram a imagem Docker, iniciaram o contêiner e verificaram a instalação
- Treinaram o modelo LLaMA 405B usando 8 GPUs AMD MI300X
Treinamento do LLaMA 405B: desempenho e escalabilidade
- Treinaram o modelo LLaMA 405B em GPUs AMD usando JAX
- Com ajuste fino via LoRA, ajustaram os pesos do modelo e os parâmetros LoRA com precisão bfloat16
- Tamanho do modelo: ocupa cerca de 800 GB de VRAM
- Pesos LoRA e estado do otimizador: ocupam cerca de 400 GB de VRAM
- Uso total de VRAM: cerca de 1200 GB
- Velocidade de treinamento: cerca de 35 tokens por segundo
- Eficiência de memória: mantida em cerca de 70%
- Escalabilidade: com JAX, escalou quase linearmente em 8 GPUs
Nossa configuração de treinamento
- Converteram o LLaMA 3.1 de PyTorch para JAX
- Distribuíram o modelo com eficiência por meio de carregamento do modelo e sharding de parâmetros
Sharding de parâmetros no JAX
- Usaram o recurso de malha de dispositivos do JAX para distribuir o modelo com eficiência entre 8 GPUs AMD
- Definiram regras de sharding de parâmetros para fragmentar a dimensão de cada tensor de acordo com os eixos da malha
Implementação do treinamento com LoRA
- LoRA reduz o número de parâmetros treináveis ao decompor as atualizações de pesos em matrizes de baixa ordem
- Implementaram uma camada
LoRADenseincluindo parâmetros LoRA - Distribuíram os parâmetros LoRA de forma eficiente para otimizar o uso de memória e a eficiência computacional
Conclusão
- A experiência de fazer ajuste fino do modelo LLaMA 3.1 405B com GPUs AMD e JAX foi muito positiva
- Aproveitaram os poderosos recursos de paralelismo do JAX e sua abordagem independente de hardware para distribuir o modelo com eficiência
- Demonstraram que as GPUs AMD são uma alternativa robusta para treinamento de IA em larga escala
- É possível conferir o código completo no repositório do GitHub e executá-lo por conta própria
Resumo do GN⁺
- Este artigo explica como treinar com eficiência modelos de IA em larga escala usando GPUs AMD e JAX
- Destaca o hardware AMD como uma alternativa com melhor custo-benefício em comparação com a NVIDIA
- A abordagem independente de hardware do JAX aumenta a portabilidade do código e facilita a manutenção
- Oferece informações úteis e código prático para quem tem interesse em treinamento de modelos em larga escala
- Projetos com funcionalidades semelhantes incluem CUDA e PyTorch, da NVIDIA
1 comentários
Opiniões no Hacker News
Recentemente, fizemos o fine-tuning do modelo llama3.1 405B em 8 GPUs AMD MI300x usando JAX em vez de PyTorch
Tivemos bom desempenho graças às APIs avançadas de sharding do JAX, e resumimos no blog as técnicas de sharding usadas. Também abrimos o código: https://github.com/felafax/felafax
Somos uma pequena startup construindo infraestrutura de IA para fine-tuning e serving de LLMs em hardware que não é da NVIDIA (TPU, AMD, Trainium)
Muitas empresas tentam rodar PyTorch em GPUs AMD, mas o PyTorch é profundamente entrelaçado com o ecossistema da NVIDIA, como em
torch.cudaouscaled_dot_product_attention, então achamos que ele exige muito trabalho de “des-NVIDIAzação”Acreditamos que o JAX se adapta melhor a hardware que não é da NVIDIA, porque o código do modelo é compilado para um grafo HLO independente de hardware e, depois, o compilador XLA o otimiza e aplica otimizações específicas de hardware. O mesmo código JAX do LLaMA3 funcionou sem modificações no Google TPU e em GPUs AMD
A estratégia da empresa é primeiro portar os modelos para JAX e, depois, usar o framework JAX e kernels XLA para extrair o máximo desempenho de backends que não sejam NVIDIA. Por isso, primeiro migramos o Llama 3.1 de PyTorch para JAX, e o mesmo modelo JAX funciona bem em TPU e GPUs AMD
Pessoalmente, o principal motivo para eu usar PyTorch é que o modelo original foi feito em PyTorch. Mesmo que a lógica pareça igual entre versões diferentes do modelo, em uma escala enorme de dados, erros minúsculos de ponto flutuante podem se acumular e causar drift do modelo
Depurar esse tipo de divergência de precisão em modelos grandes foi algo mais parecido com um sofrimento maior que o décimo círculo do inferno
hipblaslte Composable Kernel FANão conheço muito bem JAX, mas acredito que boa parte do motivo pelo qual o desempenho de treino com PyTorch no MI300x é péssimo se deve à lentidão das bibliotecas ROCm usadas internamente
Quando digo funcionar, não quero dizer passar duas semanas lutando com drivers e depois ficar impossibilitado de atualizar o servidor de novo
Também tenho curiosidade sobre os problemas técnicos encontrados
Para ser claro, esse desempenho é bem ruim. Provavelmente parece ser porque não conseguiram fazer a compilação funcionar direito
No modelo 405B, eles chegam a 35 tokens/s, o que corresponde a cerca de 85 teraflops. Oito GPUs MI300x ficam na faixa de 10,4 petaflops, então o MFU é de cerca de 0,8%
Isso é 40 a 50 vezes menor que um desempenho de treino decente, de 30% a 40% de MFU, então a AMD provavelmente torce para que o gargalo seja a stack de software
A página do GitHub diz que é possível ajustar o LLaMa3.1 no Google Cloud TPU com custo 30% menor, mas não menciona desempenho
Excelente trabalho. Há cerca de um ano, mexi um pouco com GPUs AMD e suporte a ROCm, e ficou claro que a AMD ainda tinha um longo caminho para alcançar a Nvidia
A abordagem de escolher JAX é interessante, mas fico curioso sobre quais dificuldades vocês tiveram ao se afastar do PyTorch, que é quase a biblioteca padrão de machine learning
No começo, o objetivo era fazer fine-tuning do LLaMA 3 em TPU, mas o PyTorch XLA era desajeitado, então decidimos reescrever o modelo em JAX
Como mencionei, vemos o JAX como uma plataforma melhor para GPUs que não sejam da NVIDIA, e queremos construir infraestrutura para GPUs não NVIDIA sobre JAX+openXLA
Bom trabalho. No fim de semana passado, eu também estava mexendo com a parte de inferência do 405B [0]
Não estou convencido de que
torch.cudaseja tão ruim assim. O PyTorch para AMD faz essa conversão por baixo. Parece mais um problema de nomenclatura do que algo essencialNa prática, puxar o contêiner
rocm:pytorché tão fácil quanto puxar o contêinerrocm:jaxNão há muitos números publicados; fico curioso para saber qual foi o MFU
[0] https://x.com/HotAisle/status/1837580046732874026
O MFU precisa ser calculado. Detalhes de GPU e VRAM podem ser vistos no repositório: https://dub.sh/amd-405b-res
No próximo fim de semana, pretendo tentar rodar o treino de novo, compilando com JIT toda a etapa de treino, e calcular o MFU nessa ocasião
Quando medimos na ZML, a MI300X foi 30% mais rápida que a H100. São chips excelentes
Fico curioso se há algum provedor de nuvem onde seja possível alugar um host 8xAMD MI300
Uso muito AWS no trabalho e queria experimentar GPUs AMD
Onde estão os dados de desempenho?
Por causa do código e das restrições de VRAM, não conseguimos executar a versão compilada com JIT do modelo 405B. Precisamos investigar mais essa parte
A execução completa do treino foi feita no modo eager do JAX, então há bastante margem para melhorar o desempenho
Mesmo no modo eager, o uso de GPU ficou em geral por volta de 30% a 40%, o que é bem razoável. Com JIT, acho que o uso de GPU poderia subir facilmente para 50% a 60%
Se possível, seria interessante explorar uma forma de superar as limitações de memória e executar a versão compilada com JIT. Isso pode levar a melhorias adicionais de desempenho
Precisamos de etapa de treino compilada com JIT, carregamento de dados e sharding mais otimizados, acumulação de gradientes e activation checkpointing
Vamos continuar construindo, implementar todas as melhorias e publicar outro post no blog em breve
Fico curioso se a AMD chegou minimamente perto de extrair valor disso por meio de grandes pedidos de GPUs e escassez de oferta
Minha impressão é mais para “não”
O outro lado tem uma vantagem inicial enorme, e claramente há muito trabalho a fazer no lado de software. Vai levar tempo
Por que o Obsidian, que é um app de notas, está fazendo isso?