14 pontos por xguru 2024-08-19 | 8 comentários | Compartilhar no WhatsApp
  • O motivo pelo qual o PyTorch causa perda de produtividade e desperdício de tempo de desenvolvimento não é "porque o framework em si é ruim, mas porque ele não foi projetado para os casos de uso aos quais está sendo aplicado hoje"

A filosofia do PyTorch

  • A filosofia do PyTorch é ser dinâmico, fácil de depurar e pythônico
  • Em contraste, o TensorFlow 1.x tentou se tornar um framework estático, mas com bom desempenho, fazendo uso pesado do compilador XLA
  • Os desenvolvedores do TensorFlow perceberam que a comunidade não gostava da API 1.x, decidiram usar o Keras como interface principal e reduziram o papel do compilador XLA
  • O PyTorch manteve suas raízes e, diferentemente da abordagem estática e adiada do TensorFlow, adotou uma abordagem mais dinâmica de "execução imediata", em que torch.Tensor é avaliado imediatamente
  • Isso deu resultado, e muita pesquisa migrou para o PyTorch
  • Em 2021, com a chegada do GPT-3, desempenho e escalabilidade se tornaram preocupações centrais
  • O PyTorch respondeu razoavelmente bem a essa demanda, mas como não foi projetado com essa filosofia em mente, a dívida começou a se acumular e suas bases passaram a vacilar
  • Os desenvolvedores do PyTorch não queriam nenhum tipo de compromisso e escolheram seguir dois caminhos ao mesmo tempo
    • usar o compilador XLA como backend padrão, com alto desempenho e estabilidade
    • construir a stack torch.compile para dar aos usuários a liberdade de chamar o compilador quando necessário
  • A ausência de uma estratégia de longo prazo é um problema grave
  • O PyTorch não quer se comprometer com uma filosofia centrada em compilador, como a do JAX, mas também não parece haver uma boa alternativa
  • Qual é a solução dos concorrentes para esse problema?

Desenvolvimento baseado em compilador no JAX

  • O JAX aproveita o XLA, a poderosa stack de compilação do TensorFlow
  • O XLA é um compilador poderoso, mas tudo isso fica abstraído para o usuário final
  • Desde que uma função seja pura (pure), é possível usar o decorador @jax.jit para compilá-la em JIT e torná-la utilizável pelo XLA
  • O XLA valida se o grafo gerado está correto e, por trás dos panos, cuida do particionador GSPMD para paralelização automática com sharding no JAX, otimização de grafos, fusão de operadores e kernels, escalonamento para ocultação de latência, sobreposição assíncrona de comunicação, geração de código para outros backends como Triton etc.
  • Basta respeitar as restrições do JAX, e o XLA cuida do resto automaticamente
  • Por exemplo, ao paralelizar, não são necessários primitivos de comunicação como torch.distributed.barrier()
  • O suporte a DDP é possível com código simples
  • A abordagem do XLA é que o cálculo segue o sharding. Portanto, se um array de entrada for fragmentado ao longo de algum eixo, o XLA cuida automaticamente dos subcálculos
  • A ideia de "desenvolvimento baseado em compilador" é parecida com a forma como o compilador do Rust funciona
  • Limitações do PyTorch
    • Há insatisfação com a escolha dos desenvolvedores do PyTorch de integrar e depender de uma stack de compilação para novos recursos, em vez de manter a filosofia central de flexibilidade e liberdade
    • Segundo o roadmap oficial do PyTorch 2.x, há um plano de longo prazo claramente definido para integrar totalmente o XLA ao Torch
    • Isso é uma ideia terrível. É como dizer que forçar código C++ para dentro do compilador do Rust daria uma experiência melhor do que simplesmente usar o próprio Rust
    • O Torch, ao contrário do JAX, não foi projetado em torno do XLA
    • Se o PyTorch decidiu usar uma stack de compilação baseada em XLA, então o framework ideal não seria um projetado e construído especificamente em torno disso?
    • Mesmo que o PyTorch siga uma abordagem de "multi-backend", na qual se pode escolher o backend de compilação desejado, isso não agravaria o problema de fragmentação e não acabaria destruindo completamente a API ao tentar respeitar as limitações de todas as stacks de compilação?
    • Quem já usou o Torch/XLA em TPU sofre de um PTSD severo

Multi-Backend fracassou

  • O PyTorch tenta fazer tudo de uma vez e falha miseravelmente
  • A decisão de design de "multi-backend" piora esse problema de forma exponencial
  • Em teoria, parece que você pode escolher a stack que quiser, mas na prática é um caos emaranhado de tracebacks incompreensíveis e problemas de incompatibilidade
  • Restrições entre backends e conflito com a API do PyTorch
    • A dificuldade não está apenas em fazer esses backends funcionarem, mas no fato de que as restrições que eles exigem não combinam bem com a API flexível e pythônica do PyTorch
    • Existe um trade-off entre manter a consistência da API e seguir as limitações do backend
    • Como resultado, os desenvolvedores tentam depender mais de geração de código em vez de se integrar/comprometer de fato com um backend único
  • Ausência de estratégia no PyTorch
    • Como o PyTorch se recusa a aceitar trade-offs significativos, toda decisão parece um compromisso mal resolvido
    • Não há consistência nem uma estratégia geral
    • No fim, isso gera muita frustração para o usuário e passa a sensação de um amontoado de recursos que não combinam entre si
    • Não existe maneira mais rápida de matar um ecossistema
  • Por que não se deve seguir a abordagem do JAX
    • O PyTorch não deveria seguir a abordagem de "compilador e backend integrados" do JAX
    • Porque o JAX foi explicitamente projetado para funcionar com o XLA
    • Substituir o frontend do PyTorch pelo do JAX não pode ser a estratégia
    • É praticamente impossível inventar uma API melhor do que a do JAX com base no XLA
    • Não se critica que os desenvolvedores tentem ideias novas e diferentes
    • Mas, se o PyTorch quiser resistir ao teste do tempo, precisa se concentrar mais em fortalecer suas bases do que em oferecer novos recursos vistosos que desmoronam imediatamente fora das condições ideais dos tutoriais

A fragmentação do PyTorch e a programação funcional no JAX

  • A API funcional do JAX
    • As funções no JAX devem ser puras (pure), isto é, não devem ter efeitos colaterais globais
    • Como funções matemáticas, com os mesmos dados devem sempre retornar a mesma saída, independentemente do contexto de execução
    • Graças a essa filosofia de design, as funções do JAX são componíveis e interoperam bem entre si
    • A complexidade de desenvolvimento diminui, e as funções são definidas com assinaturas específicas e tarefas concretas bem definidas
    • Se os tipos forem respeitados, é garantido que a função funcionará imediatamente
    • Isso combina com os tipos de trabalho necessários em computação científica, especialmente em deep learning
  • Exemplo da API do optax
    • Graças à abordagem funcional, o optax tem algo chamado chain
    • Isso inclui várias funções aplicadas sequencialmente aos gradientes
    • O componente fundamental é GradientTransformation
    • Isso cria uma API poderosa e expressiva
    • Por exemplo, tarefas como fazer clipping de gradientes, calcular a EMA dos gradientes ou combinar otimizadores se tornam muito simples
  • Vantagens do design funcional
    • Outro resultado interessante do design funcional é o vmap
    • Ele significa map 'vectorized' e descreve exatamente essa função
    • É possível mapear tudo, e desde que seja vmap, o XLA faz automaticamente a fusão e a otimização
    • Ao escrever funções, não é preciso pensar na dimensão de batch
    • Basta aplicar vmap em todo o código
    • Isso significa que há menos necessidade de operações ein-*
    • Entender manipulações de tensores 2D/3D fica mais intuitivo, e a legibilidade também melhora muito
    • Como basta isolar os componentes individuais e raciocinar sobre eles, fica mais fácil escrever código complexo que realmente funciona
    • Se você respeitar as restrições de pureza e tiver apenas a assinatura correta, poderá aproveitar todos os outros benefícios, como componibilidade
  • Problemas do ecossistema PyTorch
    • No torch, sempre há a possibilidade de algo quebrar, independentemente da stack usada (FSDP + múltiplos nós + torch.compile etc.)
    • Muitas coisas precisam funcionar corretamente juntas, e se qualquer componente falhar, você vai passar até as 3 da manhã depurando
    • Como não é possível testar todas as combinações das dezenas de recursos que o PyTorch oferece, sempre haverá bugs que não foram descobertos durante o desenvolvimento
    • Sem um esforço considerável, é impossível escrever código que funcione bem
    • O ecossistema do torch ficou extremamente inchado e cheio de bugs
    • Como não existe uma abstração compartilhada, surgem novas bibliotecas e frameworks que não foram projetados para fazer interface com outras "soluções"
    • Isso logo degenera em um caos de dependências e requirements.txt
    • 70-80% das issues no GitHub e das discussões em fóruns existem simplesmente porque diferentes bibliotecas entram em erro entre si
    • Quase não há como resolver isso
  • Ausência de solução
    • Isso é um problema de OOP e de design
    • Algo básico e ao estilo PyTorch, como o PyTree, poderia ter ajudado a estabelecer uma base comum de abstração
    • Também não é possível adotar o paradigma de programação funcional
    • Se fizesse isso, acabaria convergindo para uma versão pior do JAX em desempenho, ao mesmo tempo em que quebraria a compatibilidade retroativa de todos os codebases existentes em torch
    • O PyTorch parece estar completamente quebrado nessa parte

A vantagem do JAX em reprodutibilidade

  • Tratamento de seed
    • O tratamento de seed no PyTorch não é ideal
    • Em geral, é preciso executar várias linhas de código
    • É fácil esquecer ou configurar errado
    • O JAX obriga a criar chaves explícitas e passá-las para todas as funções que precisam de aleatoriedade
    • Essa abordagem elimina completamente o problema, porque o RNG sempre é seeded estaticamente
    • Como o JAX tem sua própria versão do NumPy (jax.numpy), não é necessário configurar a seed separadamente
    • Essas pequenas decisões de QoL podem tornar muito melhor a experiência de uso do framework como um todo
  • Portabilidade
    • Um dos maiores problemas ao usar codebases em PyTorch é a falta de portabilidade
    • Codebases escritos para CUDA/GPU não funcionam bem quando executados em hardware não Nvidia, como TPU, NPU, AMD GPU etc.
    • É difícil portar código PyTorch escrito para 1 nó para múltiplos nós
    • Multi-node frequentemente exige dezenas de horas de desenvolvimento e mudanças significativas no código
    • A abordagem centrada em compilador do JAX leva vantagem aqui
    • O XLA cuida da troca entre backends de dispositivo e funciona bem em GPU/TPU/múltiplos nós/múltiplos slices com mudanças mínimas no código
    • Isso facilita para fornecedores de hardware oferecerem suporte aos seus dispositivos e torna mais simples alternar entre dispositivos
    • Nem todo mundo tem acesso ao mesmo hardware, então codebases portáveis em diferentes tipos de hardware podem ser um pequeno passo para tornar o deep learning mais acessível a iniciantes e intermediários
  • Escalonamento automático
    • Um codebase que consegue se autoescalar bem ajuda muito na reprodutibilidade
    • No caso ideal, isso deveria acontecer automaticamente com mudanças mínimas no código, sem depender de limites de rede
    • O JAX faz isso bem
    • Ao escrever código em JAX, não é necessário especificar primitivas de comunicação nem espalhar torch.distributed.barrier() por todo lado
    • O XLA insere isso automaticamente levando em conta o hardware disponível
    • Todo dispositivo que o JAX conseguir detectar é usado automaticamente, independentemente de rede, topologia, configuração etc.
    • Ele sincroniza e prepara o cálculo automaticamente e aplica passes de otimização para maximizar a execução assíncrona dos kernels e minimizar a latência
    • A única coisa que a pessoa precisa fazer é especificar o sharding dos tensores que deseja distribuir pelos dispositivos, como a dimensão de batch dos arrays de entrada
    • Por causa da abordagem do XLA de que "o cálculo segue o sharding", ele descobre automaticamente o restante
    • Isso permite executar facilmente experimentos validados em escala como hobby, testar e potencialmente iterar
    • Isso pode facilitar a redescoberta de ideias esquecidas e incentivar esses experimentos, já que fica fácil testá-los como funções em escala maior com esforço mínimo

Desvantagens do JAX

  • Estrutura de governança
    • Atualmente, o XLA está sob a governança do TensorFlow
    • Houve discussões sobre estabelecer um órgão organizacional separado, semelhante ao do PyTorch, mas pouco esforço concreto foi feito
    • A confiança no Google não é alta por causa da reputação de encerrar produtos impopulares
    • O JAX é tecnicamente um projeto da DeepMind e tem importância central para a estratégia geral de IA do Google, mas um arranjo de longo prazo pareceria trazer grandes benefícios para todo o ecossistema
    • Um órgão de governança separado daria direcionamento ao desenvolvimento do projeto
    • Isso forneceria uma estrutura concreta e evitaria muitos problemas de uma só vez ao separar o projeto da notória burocracia do Google
    • Não é que o JAX necessariamente precise desse tipo de estrutura formal, mas seria bom ter a garantia de que o desenvolvimento do JAX continuará por muito tempo, independentemente das decisões da alta gestão do Google
    • Isso claramente ajudaria na adoção por empresas e grandes laboratórios de pesquisa que hesitam em investir recursos na integração de uma ferramenta que um dia pode deixar de ser mantida
  • A transição open source do XLA
    • Durante muito tempo, o XLA foi um projeto de código fechado
    • No entanto, houve esforços para torná-lo open source, e hoje o OpenXLA mostra desempenho muito superior ao build interno do XLA
    • Ainda assim, a documentação sobre o interior do XLA continua insuficiente
    • A maior parte dos recursos são palestras ao vivo e alguns artigos ocasionais, frequentemente desatualizados
    • Se existisse um roadmap publicamente acessível sobre recursos planejados, seria mais fácil para as pessoas acompanhar o progresso e contribuir com o que acharem especialmente interessante
    • Seria bom ter mini posts de blog no estilo de Edward Yang analisando cada etapa da stack de compilação do XLA e explicando os detalhes, oferecendo uma forma para praticantes avaliarem melhor o que o XLA pode e não pode fazer
    • Entende-se que isso consome muitos recursos e que talvez esses recursos possam ser melhor usados em outro lugar, mas as pessoas confiam mais nas ferramentas quando as entendem, e isso teria efeitos positivos em cascata por todo o ecossistema, beneficiando todo mundo
  • Integração do ecossistema
    • O flax é um pé no sapato no ecossistema do JAX
    • Tem uma API pouco intuitiva, uma sintaxe concisa e é um inferno absoluto para iniciantes que vêm do PyTorch
    • É recomendável usar equinox
    • Houve tentativas da equipe de desenvolvimento de corrigir os problemas do flax, mas no fim isso é perda de tempo
    • Se você quer uma API no estilo do equinox, é melhor simplesmente usar equinox
    • Não há muita coisa em que o flax seja especialmente melhor, e não é difícil reproduzi-la com equinox
    • Atualmente, grande parte do ecossistema JAX é projetada em torno do flax
    • O equinox é interoperável com todas as bibliotecas porque faz interface fundamentalmente com PyTree, embora exija um pouco de eqx.partition e filter
    • Seria bom mudar o status quo. O equinox deveria receber suporte de primeira classe em todo lugar
    • Essa é uma opinião controversa, mas isso é a clássica falácia do custo afundado
    • O equinox funciona melhor da forma como o framework JAX sempre deveria ter funcionado
    • Como resumido na documentação do equinox, ao comparar equinox e flax, o equinox é melhor
    • É bom ver mantenedores do ecossistema JAX reconhecendo a popularidade do equinox e se ajustando a isso, mas seria ótimo ver também mais apoio oficial do Google e da equipe do flax
    • Se você quiser experimentar o JAX, é recomendável usar equinox
  • Pontas afiadas
    • Por causa de decisões de design da API e das restrições do XLA, o JAX tem algumas "pontas afiadas" às quais é preciso prestar atenção
    • Isso é explicado de forma muito concisa em uma documentação bem escrita
    • É recomendável lê-la pelo menos uma vez antes de usar o JAX
    • Como sempre, fazer RTFM vai economizar muito tempo e energia

Conclusão

  • Este post de blog tinha o objetivo de corrigir o mito frequentemente repetido de que o PyTorch seria o mais adequado para cargas reais de pesquisa, especialmente em GPU. Isso não é mais verdade
  • Na verdade, vai ao ponto de argumentar de forma extrema que portar todo o código em PyTorch para JAX seria enormemente benéfico para a área como um todo
    • paralelização automática, reprodutibilidade, API funcional limpa etc. não são recursos triviais e ajudariam muito muitos codebases de pesquisa
  • Se você quiser tornar esta área um pouco melhor, considere reescrever seu codebase em JAX

8 comentários

 
xguru 2024-08-25

O mundo continua girando. hehe

Comparação entre PyTorch e TensorFlow em 2022

 
hilft 2024-08-21

Vou continuar me virando com torch e onnx.

 
flrngel 2024-08-21

Texto escrito por um estudante de graduação... caramba

 
cosine20 2024-08-21

Se não fosse o Hugging Face, o PyTorch já estaria morto mesmo kkk

 
lemonmint 2024-08-19

Viva o JAX! Usei recentemente e gostei muito da API NNX.

 
stareta1202 2024-08-19

O maior problema do JAX é que ele é do Google. O Google é bastante famoso por abandonar open source (Tflite, android things, dart, angular, bazel etc.). Até o TensorFlow, em algum momento, começou a receber bem menos atualizações. Já o torch começou no Facebook, que opera um ecossistema enorme de open source, e vem sendo muito bem mantido, além de já ser administrado pela fundação do torch. Os pontos fracos do torch certamente fazem sentido em vários aspectos, mas, na questão de quem consegue manter esse open source de forma sustentável, o JAX parece já começar carregando um grande risco.

 
dalinaum 2024-08-20

Pelo menos o Dart parece que vai continuar bem vivo por um bom tempo com o Flutter.

 
ilotoki0804 2024-08-20

O Facebook, com React e Django, ainda parece contribuir continuamente com certa lealdade (?) para a própria stack tecnológica que usa, mas o Google parece descartar como um trapo velho assim que algo fica um pouco obsoleto...