-
Importância do Attention
- Attention é a camada central da arquitetura Transformer e causa gargalos em modelos de linguagem de grande porte e aplicações com contexto longo.
- FlashAttention e FlashAttention-2 abriram caminho para uma abordagem que acelera o Attention em GPUs ao minimizar leituras e escritas de memória.
- Com isso, o comprimento de contexto dos LLMs aumentou significativamente.
-
Principais tecnologias do FlashAttention-3
- Uso de assincronia: aproveita a assincronia dos Tensor Cores e do TMA para sobrepor toda a computação e a movimentação de dados.
- Operações por bloco: alterna multiplicação de matrizes e operações de softmax em nível de bloco.
- Processamento em baixa precisão: melhora o desempenho com suporte a baixa precisão FP8.
-
Melhorias de desempenho do FlashAttention-3
- Eficiência no uso da GPU: utiliza até 75% do desempenho máximo da GPU H100 e é de 1,5 a 2 vezes mais rápido que a versão anterior.
- Desempenho em baixa precisão: usa FP8 para aumentar a velocidade de processamento e reduzir o uso de memória.
- Processamento de contexto longo: acelera o mecanismo de Attention para processar textos mais longos com eficiência.
-
Resumo do FlashAttention
- FlashAttention reorganiza o cálculo de Attention e usa tiling e recomputação para aumentar muito a velocidade e reduzir o uso de memória.
- Com tiling, carrega blocos de entrada, executa Attention nesses blocos e depois atualiza a saída.
- Ao não gravar a matriz intermediária de Attention na memória, reduz o volume de leituras e escritas de memória.
-
Novos recursos de hardware da GPU Hopper
- WGMMA: usa novos Tensor Cores para fornecer alta taxa de processamento.
- TMA: unidade de hardware que acelera a transferência de dados entre memória global e memória compartilhada.
- FP8 de baixa precisão: usa FP8 para dobrar a taxa de processamento dos Tensor Cores.
-
Assincronia: sobreposição de GEMM e Softmax
- Necessidade da sobreposição: executa GEMM e softmax em paralelo para maximizar o desempenho.
- Agendamento ping-pong: dois grupos de warps alternam a execução de GEMM e softmax para melhorar o desempenho.
- Sobreposição dentro do grupo de warps: executa GEMM e softmax em paralelo dentro do mesmo grupo de warps para aumentar a taxa de processamento.
-
Baixa precisão: redução do erro de quantização com processamento incoerente
- Processamento incoerente: usa a transformada de Hadamard para reduzir o erro de quantização.
- Resultados experimentais: o processamento incoerente reduziu o erro de quantização em 2,6 vezes.
-
Benchmark de Attention
- FP16: cerca de 1,6 a 1,8 vezes mais rápido que o FlashAttention-2.
- FP8: alcança até 1,2 PFLOPS.
Resumo do GN⁺
- FlashAttention-3 melhora significativamente o desempenho do mecanismo de Attention ao aproveitar novos recursos de hardware das GPUs.
- Como consegue processar contextos longos com eficiência, maximiza o desempenho de modelos de linguagem de grande porte.
- Há grande chance de integração com frameworks importantes como PyTorch, o que deve ter forte impacto em pesquisas e aplicações de IA no futuro.
- Projetos com funcionalidades semelhantes incluem Triton e cuDNN.
1 comentários
Comentários no Hacker News
Parece que Tri Dao começou a trabalhar no FA3 em abril de 2022
Fica a curiosidade sobre o quanto o algoritmo Flash Attention depende do hardware
Fica a dúvida se compiladores conseguirão encontrar sozinhos otimizações como as do FlashAttention
Pedem que quem quiser portar para ROCm/AMD MI300x entre em contato
TMA (Tensor Memory Accelerator) é uma unidade de hardware que acelera a transferência de dados entre memória global e memória compartilhada
FlashAttention-3 é otimizado para GPUs Hopper (como a H100)
É mencionado que funções de ativação como sigmoid são muito lentas em LLMs modernos
Fica a dúvida sobre por que o Flash Attention é 5 vezes mais lento com masking variável do que sem ele
Fica a dúvida se FlashAttention pode substituir a operação de attention em LLMs
É necessário hardware caro