Stable Diffusion 3.5 é reimplementado do zero em PyTorch puro
(github.com/yousef-rafat)- O projeto miniDiffusion é um open source que reimplementa o modelo Stable Diffusion 3.5 do zero usando apenas PyTorch
- A estrutura do projeto se destaca por ser focada em fins educacionais e em experimentos e hacking
- Toda a base de código tem cerca de 2800 linhas, composta pelo mínimo de código necessário, do VAE ao DiT, além de scripts de treinamento e dataset
- Entre os principais componentes estão VAE, CLIP, codificadores de texto T5, transformer de difusão multimodal e atenção conjunta
- Ainda inclui funcionalidades experimentais, portanto precisa de mais testes
Introdução ao projeto miniDiffusion
miniDiffusion é um projeto open source que reimplementa os recursos centrais do Stable Diffusion 3.5 usando apenas PyTorch
Em comparação com o Stable Diffusion 3.5 original, este projeto tem as seguintes vantagens:
- A base de código tem cerca de 2.800 linhas, é pequena e muito adequada para analisar a arquitetura diretamente e aprender com ela
- Pode ser usada de forma útil em vários experimentos de machine learning e em hacking de modelos
- Tem pouquíssimas dependências e usa apenas o conjunto mínimo de bibliotecas
Estrutura principal e arquivos de composição
- dit.py : implementação principal do modelo Stable Diffusion
- dit_components.py : embeddings, normalização, patch embedding e funções auxiliares do DiT
- attention.py : implementação do algoritmo Joint Attention (atenção conjunta)
- noise.py : inclui o scheduler Euler ODE para Rectified Flow
- t5_encoder.py, clip.py : implementação dos codificadores de texto T5 e CLIP
- tokenizer.py : implementação dos tokenizers Byte-Pair e Unigram
- metrics.py : implementação da métrica de avaliação FID (Fréchet inception distance)
- common.py : fornece funções auxiliares necessárias para o treinamento
- common_ds.py : implementação de dataset iterável que converte imagens em dados de treino para o DiT
- pasta model : armazena checkpoints e logs do modelo após o treinamento
- pasta encoders : armazena checkpoints de módulos separados como VAE e CLIP
⚠️ Funcionalidades experimentais e necessidade de testes O miniDiffusion ainda inclui funcionalidades experimentais e precisa de mais testes
Composição detalhada por funcionalidade principal
Core Image Generation Modules
- Implementação de VAE, CLIP e codificadores de texto T5
- Implementação de tokenizers Byte-Pair e Unigram
Componentes do SD3
- Multi-Modal Diffusion Transformer Model
- Implementação de Flow-Matching Euler Scheduler
- Logit-Normal Sampling
- Introdução do algoritmo Joint Attention
Scripts de treinamento e inferência do modelo
- Fornece scripts de treinamento e inferência para SD3 (Stable Diffusion 3.5)
Licença
- Disponibilizado sob a licença MIT e criado para fins educacionais e experimentais
Significado e vantagens deste projeto open source
- Permite treinar e hackear diretamente, com PyTorch puro, uma arquitetura moderna de geração de imagens no nível do Stable Diffusion 3.5
- O código é conciso e independente, sendo otimizado para análise de arquitetura, ajuste de modelo e pesquisa de novos algoritmos
- Permite praticar diretamente técnicas modernas de multimodalidade, transformers e atenção
- Oferece uma base para experimentar com segurança, separada de projetos comerciais
1 comentários
Comentários do Hacker News
A implementação de referência do Flux tem uma estrutura realmente minimalista, então vale a pena dar uma olhada para quem tiver interesse
GitHub do Flux
A vantagem do projeto minRF é que ele usa rectified flow, o que facilita começar a treinar modelos pequenos de difusão
GitHub do minRF
A implementação de referência do Stable Diffusion 3.5 também foi escrita de forma bem enxuta, então é adequada como material de referência
GitHub do SD 3.5
Implementações de referência muitas vezes não são bem mantidas e têm muitos bugs
cudagraphse afinsFiquei curioso se isso significa que o projeto miniDiffusion usa o modelo Stable Diffusion 3.5
Código relacionado
O dataset de treinamento é muito pequeno e contém apenas fotos relacionadas a moda
Dataset de moda
Esse dataset serve para praticar fine-tuning de modelos de difusão
Fico me perguntando se usar PyTorch puro traz alguma vantagem de desempenho em GPUs que não sejam da NVIDIA, ou se o PyTorch é tão otimizado para CUDA que outros fabricantes de GPU não conseguem competir
O PyTorch funciona razoavelmente bem também no Apple Silicon
Também é possível rodar workloads de ML em dispositivos não NVIDIA, como AMD, via Vulkan
O suporte do PyTorch a ROCm está avançando muito lentamente, e mesmo quando funciona a velocidade é baixa
O PyTorch até funciona bem em ROCm, mas não sei se chega a funcionar em nível totalmente “equivalente”
No código PyTorch, em vez de
talvez fosse interessante tentar algo como
como sugestão
Fazendo isso, em vez de os parâmetros originais de q, k e v ficarem conectados de forma independente, os parâmetros entre q, k e v passam a ficar conectados
Parece ser um ótimo material para quem está aprendendo
Fico curioso se existe algum tutorial ou guia que até iniciantes consigam acompanhar
Há uma aula da fast.ai em que eles implementam Stable Diffusion diretamente
Fiquei me perguntando se isso quer dizer que é possível usar Stable Diffusion sem restrições de licença
Sendo sincero, talvez até com um pouco de vergonha, fico pensando no que exatamente ganhamos antes e depois de esses repositórios existirem
Pessoalmente sempre evitei construir modelos e só acompanhei os resultados de fora
Eu imaginava vagamente que já existiam scripts públicos de inferência/treinamento em PyTorch
Pelo menos os scripts de inferência eu supunha que viriam junto com a distribuição do modelo, e achei que também existiriam scripts para fine-tuning/treinamento
Não tenho certeza se este projeto é uma reescrita em estilo “clean room” ou “dirty room”, ou se até o código PyTorch existente já era tão complexo por causa de CUDA/C que uma versão em PyTorch puro faz uma grande diferença
Enfim, eu realmente não sei, então ficaria grato se alguém pudesse explicar
O principal valor deste projeto é ser uma “implementação com dependências mínimas”
transformers, o que na prática é bem complicado para uso profissionalA Stability AI distribui os modelos Stable Diffusion sob a Stability AI Community License, então, ao contrário da MIT, não é algo “totalmente livre”
Quando penso no SD 3.5 (ou em qualquer versão), considero que a parte central são os pesos gerados no processo de treinamento
Fico curioso sobre a usabilidade prática do código-fonte acadêmico original publicado pelo grupo CompViz da Ludwig Maximilian University
Também fiquei curioso se a implementação de diffusion transformer (DiT) daqui implementa corretamente cross-token attention como na versão completa do SD 3.5, ou se isso foi simplificado para facilitar a legibilidade do código