- Modelos de difusão são usados além da geração de imagens, em problemas que exigem amostragem de distribuições multimodais, como áudio, vídeo, 3D, design de proteínas e planejamento de trajetórias robóticas, e este tutorial conecta treinamento e amostragem sob uma perspectiva de otimização
- O processo de treinamento cria dados com ruído misturado, (x_\sigma=x_0+\sigma\epsilon), e minimiza o erro quadrático médio para que a rede neural (\epsilon_\theta(x,\sigma)) preveja a direção do ruído
- O denoiser treinado pode ser interpretado como uma projeção aproximada sobre o conjunto de dados (\mathcal{K}), e o denoiser ideal se conecta ao gradiente da função de distância quadrática suavizada por (\sigma)
- A amostragem DDIM pode ser vista como uma descida de gradiente aproximada sobre (f(x)=\frac{1}{2}\mathrm{dist}_{\mathcal{K}}(x)^2), e o cronograma de (\sigma_t) determina o número de iterações e o custo de avaliar o denoiser
- Ao combinar atualização por estimativa de gradiente com adição de ruído, é possível tratar DDIM, DDPM e o sampler aprimorado dos autores em conjunto com os parâmetros
gamemu, chegando a exemplos com modelo toy e latent diffusion
Modelos de difusão sob a ótica da otimização
- Modelos de difusão se destacam na geração de amostras a partir de distribuições multimodais, sendo aplicados não só em ferramentas de texto para imagem como Stable Diffusion, mas também em geração de áudio, vídeo, 3D, design de proteínas e planejamento de trajetórias robóticas
- A base teórica do tutorial é a interpretação por otimização do artigo da ICML 2024 e de um artigo relacionado
- A implementação toma como principal referência o
smalldiffusion, e o código do texto foi simplificado em relação à biblioteca original para fins didáticos
Treinamento: previsão da direção do ruído
- O objetivo de um modelo de difusão é aprender o conjunto de dados (\mathcal{K}) a partir de exemplos de treino e gerar amostras desse conjunto
- No caso de imagens, (\mathcal{K} \subset \mathbb{R}^{c\times h \times w}) é o conjunto de valores de pixel que correspondem a imagens realistas
- A mesma estrutura também se aplica a domínios discretos como áudio, vídeo, trajetórias robóticas e texto
- O procedimento de treinamento pode ser visto em três etapas
- Amostrar (x_0 \sim \mathcal{K}), (\sigma) e (\epsilon \sim N(0,I))
- Criar dados com ruído misturado com (x_\sigma=x_0+\sigma\epsilon)
- Minimizar a perda quadrática para que (\epsilon_\theta(x_\sigma,\sigma)) preveja (\epsilon)
- No código,
training_loopcriasigmaeepscomgenerate_train_samplepara cada batchx0e otimiza o MSE entre a saída demodel(x0 + sigma * eps, sigma)eeps - Em vez de amostrar (\sigma) uniformemente em um intervalo contínuo, ele é retirado de um cronograma de (\sigma) discretizado em (N) valores
- A classe
Scheduleencapsula a lista desigmaspossíveis e amostra valores por batch durante o treino - O exemplo do texto usa
ScheduleLogLinear(N, sigma_min=0.02, sigma_max=10) ScheduleDDPMé um cronograma para modelos de difusão no espaço de pixels, eScheduleLDMé voltado a modelos de latent diffusion como Stable Diffusion
- A classe
Exemplo toy de Swissroll
- O conjunto de dados toy é o conjunto espiralado de pontos usado em um dos primeiros artigos de difusão, Sohl-Dickstein et al. 2015, com (\mathcal{K}\subset\mathbb{R}^2)
- Em um dataset simples, o denoiser é implementado como uma MLP
- A entrada concatena (x\in\mathbb{R}^2) com um embedding bidimensional de (\sigma)
- A saída é a predição do ruído (\epsilon\in\mathbb{R}^2)
- Muitos modelos de difusão usam embeddings posicionais senoidais para (\sigma), mas neste exemplo um embedding bidimensional simples também funciona bem
- A configuração de treino do exemplo usa
ScheduleLogLinear(N=200, sigma_min=0.005, sigma_max=10)eepochs=15000 - O denoiser treinado pode ser visualizado como um campo vetorial ao se plotar (x-\sigma\epsilon_\theta(x,\sigma))
- Quando (\sigma) é grande, o denoiser tende a prever a média dos dados
- Quando (\sigma) é baixo e a entrada (x) está próxima dos dados, ele prevê pontos de dados reais
Interpretando denoising como projeção
- A função de distância ao conjunto de dados (\mathcal{K}) é definida como (\mathrm{dist}_{\mathcal{K}}(x)=\min{|x-x_0|:x_0\in\mathcal{K}})
- A projeção de (x), (\mathrm{proj}_{\mathcal{K}}(x)), é o conjunto de pontos em (\mathcal{K}) que atinge essa distância
- Se (\mathcal{K}) é um conjunto fechado, (x\notin\mathcal{K}) e a projeção é única, então o gradiente da função de distância quadrática é (x-\mathrm{proj}_{\mathcal{K}}(x))
- Como a função de distância (\mathrm{dist}_{\mathcal{K}}) não é diferenciável em todo lugar, introduz-se uma função de distância quadrática suavizada por (\sigma), usando softmin no lugar de
min - O gradiente da função de distância suavizada aponta para uma média ponderada dos pontos de (\mathcal{K}), com pesos determinados por (x)
Denoiser ideal e modelo de erro relativo
- O denoiser ideal (\epsilon^*) é aquele que minimiza exatamente a perda de treinamento para um dado (\sigma)
- Se os dados seguem uma distribuição uniforme discreta sobre um conjunto finito (\mathcal{K}), o denoiser ideal pode ser expresso em forma fechada
- O peso de cada ponto de dados é determinado pela distância entre (x_\sigma) e esse ponto
- Em datasets pequenos, ele pode ser calculado diretamente com
IdealDenoiser
- Em dados toy, o denoiser ideal aponta para a média dos dados quando (\sigma) é grande e para o ponto de dados mais próximo quando (\sigma) é pequeno
- O teorema central estabelece a relação (\frac{1}{2}\nabla_x \mathrm{dist}^2_{\mathcal{K}}(x,\sigma)=\sigma\epsilon^*(x,\sigma)) para todo (\sigma>0) e (x\in\mathbb{R}^n)
- O modelo de erro relativo usa a condição de que (x-\sigma\epsilon_\theta(x,\sigma)) aproxime bem (\mathrm{proj}_{\mathcal{K}}(x))
- Ele se aplica quando (\sqrt{n}\sigma) estima bem (\mathrm{dist}_{\mathcal{K}}(x)) dentro de um fator constante
- Assume-se que o erro é limitado por (\eta\mathrm{dist}_{\mathcal{K}}(x))
- Em baixo ruído, sob a manifold hypothesis, a maior parte do ruído adicional é ortogonal à variedade dos dados, então o denoising aproxima uma projeção
- Em alto ruído, se (\sigma) é maior que o diâmetro de (\mathcal{K}), até um denoiser que prevê a média ponderada dos dados apresenta pequeno erro relativo
- O CIFAR-10 tem tamanho que ainda permite calcular o denoiser ideal, e os experimentos mostram pequeno erro relativo entre a projeção exata ao longo da trajetória de amostragem e a saída do denoiser ideal
Amostragem: denoising iterativo e DDIM
- Com um denoiser treinado, é possível prever (x_0) a partir de (x_t) com ruído e nível de ruído (\sigma_t) usando (\hat{x}0^t=x_t-\sigma_t\epsilon\theta(x_t,\sigma_t))
- O ponto inicial escolhe (\sigma_T) suficientemente grande em relação ao diâmetro de (\mathcal{K}), e (x_T) é amostrado independentemente de (N(0,\sigma_T)), ficando longe de (\mathcal{K})
- Em alto ruído, uma única chamada ao denoiser pode ter grande erro absoluto mesmo que o erro relativo seja pequeno, e a previsão do denoiser ideal tende a ficar próxima da média dos dados
- Por isso, a amostragem chama o denoiser repetidamente ao longo de um cronograma de (\sigma_t), produzindo a sequência (x_T,\ldots,x_0)
- A atualização (x_{t-1}=x_t-(\sigma_t-\sigma_{t-1})\epsilon_\theta(x_t,\sigma_t)) é equivalente ao algoritmo de amostragem DDIM determinístico após uma mudança de coordenadas
- A prova de equivalência com DDIM está no Apêndice A do artigo
DDIM como minimização de distância
- O DDIM pode ser interpretado como uma descida de gradiente aproximada sobre (f(x)=\frac{1}{2}\mathrm{dist}_{\mathcal{K}}(x)^2)
- O tamanho do passo é (1-\sigma_{t-1}/\sigma_t)
- (\nabla f(x_t)) é estimado por (\epsilon_\theta(x_t,\sigma_t))
- O cronograma de (\sigma_t) determina o número e o tamanho dos passos de gradiente durante a amostragem
- Se houver poucos passos, (\mathrm{dist}_{\mathcal{K}}(x_t)) pode não diminuir e a convergência pode falhar
- Muitos passos pequenos aumentam o número de avaliações do denoiser e, portanto, o custo computacional
- Um cronograma admissible é aquele em que, a cada iteração, (\sqrt{n}\sigma_t) permanece dentro de um fator constante de (\mathrm{dist}_{\mathcal{K}}(x_t))
- Uma sequência log-linear de (\sigma_t) decrescendo geometricamente é um cronograma admissible
- Segundo o teorema, se (\nabla\mathrm{dist}{\mathcal{K}}(x)) existe nos (x_t) gerados por DDIM e (\mathrm{dist}{\mathcal{K}}(x_T)=\sqrt{n}\sigma_T), então (x_t) é gerado por descida de gradiente sobre a função de distância quadrática e a relação (\mathrm{dist}_{\mathcal{K}}(x_t)/\sqrt{n}\approx\sigma_t) é mantida
- No exemplo toy, implementa-se um sampler DDIM de 20 passos por subamostragem do cronograma log-linear original; a maioria das amostras fica próxima dos dados originais, mas ainda há espaço para melhorias
Sampler aprimorado baseado em estimativa de gradiente
- Aproveitando o fato de que (\nabla\mathrm{dist}{\mathcal{K}}(x)) permanece invariante entre (x) e (\mathrm{proj}{\mathcal{K}}(x)), usa-se uma atualização que mistura a estimativa atual com a anterior
- A atualização (\bar{\epsilon}t=\gamma\epsilon\theta(x_t,\sigma_t)+(1-\gamma)\epsilon_\theta(x_{t+1},\sigma_{t+1})) corrige o erro do passo anterior com a estimativa atual
- Em amostras do modelo toy, esse método converge mais rápido que DDIM e gera amostras mais próximas dos dados originais
- Em comparação com DDIM, esse sampler pode ser interpretado como a adição de momentum, e a trajetória pode sofrer overshoot, mas também convergir mais rapidamente
- Adicionar ruído durante a geração melhora empiricamente a qualidade da amostragem
- Para manter o cronograma original de (\sigma_t), primeiro faz-se denoising até um (\sigma_{t'}) menor e depois adiciona-se novamente ruído (w_t\sim N(0,I))
- Quando (\mu=\frac{1}{2}), o DDPM sampler é recuperado exatamente
- A atualização completa (x_{t-1}=x_t-(\sigma_t-\sigma_{t'})\bar{\epsilon}_t+\eta w_t) generaliza três samplers
- DDIM:
gam=1, mu=0 - DDPM:
gam=1, mu=0.5 - Sampler por estimativa de gradiente:
gam=2, mu=0
- DDIM:
Modelos maiores e materiais de referência
- O código de treinamento anterior pode ser usado não só com dados toy, mas também para treinar modelos de difusão de imagem do zero
- O exemplo com FashionMNIST treina no dataset FashionMNIST e é apresentado como um exemplo que alcança a 2ª melhor pontuação em FID no leaderboard do Papers with Code
- O código de amostragem também pode ser usado sem modificações em modelos pré-treinados de latent diffusion
- O exemplo usa
ScheduleLDM(1000)eModelLatentDiffusion('stabilityai/stable-diffusion-2-1-base') - A condição de texto é
An astronaut riding a horse, e após amostragem com 50 passos de (\sigma), o latent é decodificado
- O exemplo usa
- O efeito do termo de momentum (\gamma) é mostrado em visualizações comparativas de geração de texto para imagem em alta resolução
- Materiais adicionais recomendados
- What are diffusion models: introdução a modelos de difusão sob a perspectiva de tempo discreto, revertendo um processo de Markov
- Generative modeling by estimating gradients of the data distribution: introdução a modelos de difusão sob a perspectiva de tempo contínuo, revertendo equações diferenciais estocásticas
- The annotated diffusion model: explicação detalhada de implementação de modelos de difusão em PyTorch
1 comentários
Comentários do Hacker News
Posso responder a perguntas, se houver.
Gostei especialmente da discussão sobre trajetórias, porque ela ajuda a motivar a compreensão de pontos em que muita gente tem dificuldade, como schedulers. Embora não seja tão completo quanto os textos do Song ou da Lilian, é muito mais acessível, então pretendo recomendá-lo a outras pessoas.
Como referência, um amigo escreveu há algum tempo uma implementação mínima de difusão que é um pouco mais “completa” do ponto de vista de DDPM e foi útil: https://github.com/VSehwag/minimal-diffusion/
Como alguém que já experimentou um pouco com o procedimento de amostragem no Stable Diffusion, também gostaria de ver uma comparação de tempo de convergência e número de etapas em relação ao DDIM. Fico curioso se há alguma relação entre momentum, convergência e erro. Por exemplo, seria bom ter uma comparação mostrando se um sampler com momentum em 16 etapas é quase equivalente ao DDIM em 20 etapas ± um termo de erro.
get_sigma_embeds(batches, sigma)não usa a primeira entrada. Fico curioso se a intenção era fazer broadcast desigmano formato(batches, 1).Ele entra muito mais a fundo nos detalhes matemáticos e ainda vem com uma implementação mínima, muito fácil de entender, com menos de 500 linhas.
Seria ótimo se isso também fosse estendido para a versão de transformers de difusão que move o Sora e outros modelos de geração de vídeo. Daria para combinar este texto com https://jaykmody.com/blog/gpt-from-scratch/ e criar um artigo introdutório “transformers de difusão do zero”.
Por outro lado, se você quiser se aprofundar de verdade, recomendo ler os trabalhos de Kingma, Gao, Ricky Tian Qi Chen e dos alunos de Max Welling (Tomczak é pós-doutorando, Hoogeboom etc.), além do trabalho de Aapo Hyvärinen, que é um herói pouco reconhecido. Um exemplo de um trabalho relativamente mais leve de Kingma & Gao, também relacionado ao artigo do SD3, está aqui: https://arxiv.org/abs/2303.00848
O ponto negativo é que a dependência de conhecer e entender pesquisas anteriores é grande, o que reduz a acessibilidade; mas também é difícil chamar isso de uma crítica significativa, porque é pesquisa, não material educativo para o público geral.
n_embd; o processo de difusão em si pode continuar igual.[1] https://yang-song.net/blog/2021/score/
[2] https://lilianweng.github.io/posts/2021-07-11-diffusion-mode...
Do nosso ponto de vista, o motivo pelo qual modelos de difusão são fáceis de treinar é que eles usam um objetivo de treinamento que prevê o gradiente de uma função de distância suavizada em vez de prever o gradiente da função de distância exata. A amostragem de modelos de difusão é parecida com executar várias etapas aproximadas de gradiente.
Para entender modelos de difusão mais a fundo, recomendo ler todos esses posts de blog e aprender as diferentes interpretações.
Ainda assim, a abordagem deste texto parece permitir experimentos mais interessantes, como análise de erro do denoiser.
[1] https://arxiv.org/pdf/2305.03486.pdf
Por exemplo, por que é difícil para geradores de imagem criar teclas de piano? Parece que, para produzir a estrutura em que teclas pretas alternam em grupos de duas e três, seria necessário representar melhor restrições de distância intermediárias.