- Quando um LLM genérico é excessivo para tarefas especializadas, fazer fine-tuning direto no Llama-2 pode melhorar ao mesmo tempo qualidade, custo e latência com um modelo menor e mais barato
- Após o fine-tuning, o Llama-2 13B teve aumento de precisão em representações funcionais do ViGGO de 58%→98%, em geração de SQL de 42%→89% e no GSM8k de 28%→47%
- Em tarefas nas quais o formato de saída é importante, como ViGGO e geração de SQL, modelos menores do Llama-2 superaram o GPT-4, mas em raciocínio matemático não chegaram ao nível do GPT-4
- Os experimentos foram feitos com scripts baseados em Ray Train, Ray Data, DeepSpeed e Accelerate; os modelos 7B e 13B foram treinados em 16xA10G, e o 70B em 32xA10G
- A chave do ganho de desempenho foi qualidade dos dados e pipeline de avaliação, mais do que o tamanho do modelo; é preciso comparar, por tarefa, o trade-off de custo e qualidade entre prompt engineering e fine-tuning
Efeito do fine-tuning em três tarefas
- Grandes modelos genéricos como GPT-4 e Claude-2 são úteis para prototipagem rápida, mas podem ser exagerados em custo e desempenho para demandas de escopo estreito, como resumo e classificação de tickets de suporte
- O experimento compara o quanto melhora ao ajustar modelos Llama-2 para três tarefas realistas com fine-tuning de parâmetros completos
- ViGGO: extração de representações funcionais a partir de texto não estruturado
- SQL-create-context: geração de SQL a partir de linguagem natural e contexto de
CREATE TABLE - GSM8k: resolução de problemas matemáticos de nível fundamental
- No caso do Llama-2 13B, a mudança de precisão foi a seguinte
- Representação funcional do ViGGO: 58% → 98%
- Geração de SQL: 42% → 89%
- GSM8k: 28% → 47%
- Em ViGGO e geração de SQL, modelos menores do Llama-2 tiveram resultados melhores que o GPT-4; em tarefas de raciocínio matemático como GSM8k, mesmo após o fine-tuning, eles não alcançaram o desempenho do GPT-4
Método de fine-tuning e infraestrutura de treinamento
- As três tarefas usaram fine-tuning padrão de parâmetros completos
- O treinamento foi feito com previsão do próximo token
- Todos os parâmetros do modelo foram alvo de atualização por gradiente
- Abordagens como LoRA ou congelamento de parte dos blocos transformer ficaram fora do escopo do experimento
- Os scripts do experimento foram construídos sobre Ray Train, Ray Data, DeepSpeed e Accelerate
- Há suporte para executar Llama-2 7B, 13B e 70B
- O TorchTrainer do Ray Train distribui o loop de treinamento entre vários processos worker e recursos de GPU
- O particionamento dos dados é tratado pelo Ray Train, e cada worker acessa seu fragmento atribuído com
session.get_dataset_shard("train")esession.get_dataset_shard("valid")
- O particionamento do modelo foi tratado com DeepSpeed ZeRO stage 3 e offloading do estado do otimizador
- Como os fragmentos do modelo ficam distribuídos entre vários workers, quando é necessário acessar o modelo completo, como ao salvar checkpoints, é preciso desembrulhá-lo com
accelerator.unwrap_model(model)
- Como os fragmentos do modelo ficam distribuídos entre vários workers, quando é necessário acessar o modelo completo, como ao salvar checkpoints, é preciso desembrulhá-lo com
- Os recursos computacionais foram os seguintes
- 7B e 13B: 16xA10G
- 70B: 32xA10G, em 4 instâncias
g5.48xlarge - Com Ray, o fine-tuning de parâmetros completos não exige necessariamente A100
- O treinamento foi executado por até 10 epochs, e foi escolhido o checkpoint com menor perplexity no conjunto de validação
Fixando a estrutura de entrada e saída com tokens especiais
- Os dados de fine-tuning representam a estrutura da tarefa com tokens especiais em vez de prompts em forma de instrução
- Exemplo:
<START_Q>{question}<END_Q><START_A>{answer}<END_A>
- Exemplo:
- Os tokens especiais ajudam o modelo a distinguir as seções de entrada e saída e a aprender com clareza onde a geração deve parar
- No exemplo,
<END_A>é definido como stopping token para interromper a saída ao concluir a tarefa
- No exemplo,
- O tokenizer do Llama produz, por padrão, 32.000 IDs de token
- Ao adicionar quatro tokens especiais, ele passa a produzir 32.004 IDs
<START_Q>recebe um novo ID como 32000,<END_Q>recebe 32001, e assim por diante
- O script adiciona os tokens especiais com
tokenizer.add_tokens(special_tokens, special_tokens=True)e cria novos parâmetros treináveis commodel.resize_token_embeddings(len(tokenizer))
ViGGO: convertendo texto não estruturado em representação funcional
- O ViGGO é originalmente um dataset em inglês que transforma representações funcionais baseadas em pares atributo-valor em texto natural; no experimento, a direção foi invertida para converter texto não estruturado em representação funcional estruturada
- O domínio é o de opiniões sobre videogames
- A representação resultante pode ser usada para indexação e aplicações posteriores
- O modelo precisa gerar a função e os valores de atributos adequados à frase
- Entre as funções candidatas estão
inform,request,give_opinion,confirm,verify_attribute,suggest,request_explanation,recommenderequest_attribute - Entre os atributos candidatos estão
name,release_year,esrb,genres,platforms,available_on_steam,has_linux_release,has_mac_release,specifier,rating,player_perspective,has_multiplayer,developer,exp_release_dateetc.
- Entre as funções candidatas estão
- Para a entrada de exemplo
What's a really fast-paced game with multiplayer that you like to play?, a saída esperada érequest(has_multiplayer[yes], specifier[fast-paced]) - Modelos genéricos não seguiam bem o formato de saída desejado, e havia também o problema de o tempo de processamento da entrada ser maior que o de geração da saída por causa do contexto longo de entrada
- Essa tarefa depende mais de reconhecimento de padrões e compreensão básica da linguagem do que de raciocínio lógico complexo
- É uma grounded task, em que todos os fatos necessários estão contidos na entrada
- O fato de few-shot prompting ajudar é visto como um sinal de que modelos menores do Llama-2 também podem melhorar com fine-tuning
Avaliação e resultados do ViGGO
- A avaliação não usa apenas correspondência exata de caracteres
- Verifica se a função gerada está correta
- Verifica se os tipos de atributo estão corretos
- Verifica se os atributos dentro da função seguem uma ordem de prioridade definida
- Para modelos instruction-following como GPT e Llama-2-chat, a regra de ordenação dos atributos era explicitada no prompt, então a avaliação exigia o cumprimento dessa regra
- Para acelerar a avaliação, foram usados em conjunto a batch inference API do Ray e o Aviary da Anyscale
- Isso conecta a geração do LLM ao pós-processamento e distribui o trabalho em várias máquinas
- Os modelos 7B e 13B tiveram grande ganho de precisão após o fine-tuning
- O GPT-4 sofreu uma queda acentuada de precisão quando a prioridade dos atributos foi incluída na avaliação
- Os modelos ajustados sempre seguiram a prioridade, e a precisão não mudou mesmo com essa restrição adicional
- Os resultados do ViGGO mostram que o fine-tuning pode ser um meio estável e eficiente para tarefas que exigem formato estruturado
- Não se trata apenas de ajustar regex simples ou formato JSON, mas de decidir quais argumentos incluir e ainda respeitar a ordem dos argumentos incluídos
- Como os resultados foram obtidos com modelos 7B e 13B, o custo de serving pode ser menor do que chamar um endpoint de GPT-4
Geração de SQL: criando consultas a partir de linguagem natural e contexto de tabela
- A tarefa de geração de SQL recebe uma consulta em linguagem natural e uma instrução SQL
CREATE TABLEcomo entrada para produzir uma consulta SQL executável - O dataset usado, b-mc2/sql-create-context, é um dataset do Hugging Face que combina WikiSQL e Spider
- Cada ponto de dado é composto por uma consulta em linguagem natural, uma instrução SQL
CREATE TABLEe a consulta SQL correspondente - No total, são 78.577 pontos de dado
- Cada ponto de dado é composto por uma consulta em linguagem natural, uma instrução SQL
- O dataset tinha problemas nas respostas SQL corretas
- Em
CREATE TABLE, atributos inteiros frequentemente apareciam comoVARCHAR, mas na consulta SQL eram tratados como inteiros - Todas as consultas SQL que assumiam atributos inteiros foram removidas, reduzindo o dataset de cerca de 70k para 45k
- Em
- Essa tarefa também é adequada para fine-tuning por transformar linguagem natural em uma representação estruturada em SQL
- Diferentemente do ViGGO, pode haver várias consultas SQL corretas que geram o resultado certo, o que a torna mais ambígua
Avaliação e resultados em SQL
- A avaliação de geração de SQL não é adequada para comparação simples de strings
- A comparação caractere a caractere pode gerar muitos falsos negativos
- A comparação por AST também pode ser sensível a elementos como a ordem dos nomes de variáveis
- O método mais confiável é executar o código em um dataset sintético e comparar se a saída é a mesma
- No experimento, foi usado o endpoint GPT-3.5 da OpenAI para gerar tabelas sintéticas de teste unitário para centenas de exemplos
- O GPT-3.5 criava tabelas sintéticas com 10 pontos de dados a partir da pergunta, do schema da tabela e da resposta correta
- Com
sqlglot.executor.execute, eram executadas tanto a SQL correta quanto a SQL do modelo, e os resultados eram comparados
- Para verificar a qualidade das tabelas geradas pelo GPT-3.5, a SQL correta era executada primeiro
- Se a tabela de resultado ficasse vazia ou tivesse o mesmo tamanho da tabela original, o exemplo era descartado
- Nesse processo, cerca de 50% das tabelas sintéticas geradas pelo GPT foram filtradas
- Os modelos Llama-2 7B e 13B ajustados tiveram desempenho superior ao 70B-chat e ao GPT-4
- Um erro comum dos modelos Llama chat era não colocar a SQL de forma consistente dentro das tags
<SQL>, contrariando a instrução do prompt - Esse problema era mais frequente nos modelos chat 7B e 13B do que no 70B
- Um erro comum dos modelos Llama chat era não colocar a SQL de forma consistente dentro das tags
- Algumas consultas em linguagem natural no dataset SQL não estavam em inglês perfeito, e esse ruído pode ter afetado os resultados do GPT-4
- Os modelos ajustados se adaptaram rapidamente até mesmo a esses hábitos peculiares do dataset
GSM8k: raciocínio matemático mais difícil do que aprender estrutura
- O GSM8k é um benchmark acadêmico padrão para avaliar raciocínio matemático e capacidade de compreensão
- Se nas duas tarefas anteriores o foco era principalmente aprender estrutura, no GSM8k o objetivo é verificar o quanto o modelo pode melhorar seu processo de raciocínio para resolver problemas matemáticos
- Um exemplo de problema pergunta o total vendido quando 48 unidades foram vendidas em abril e metade disso em maio; a resposta correta termina no formato
#### 72, junto com os cálculos intermediários - Os LLMs atuais, em vez de calcularem apenas a resposta final internamente e fornecê-la de imediato, precisam gerar parte do processo de pensamento na saída para que a geração dos tokens seguintes possa se basear nesse processo lógico
- Essa tarefa exige não só cálculo simples, mas também uma chain of thought lógica que vá das premissas às conclusões intermediárias e então à resposta final
Método de avaliação e baselines do GSM8k
- A avaliação precisa de um método confiável para extrair a resposta final do output do modelo
- Modelos de linguagem genéricos podem não seguir de forma consistente o formato de saída desejado, o que dificulta a avaliação automática
- Para isso, foi usada a OpenAI function calling API
- O
gpt-3.5-turbo-0613chamava a funçãoreport_answerpara extrair a resposta inteira final das gerações de outros modelos - Por exemplo, mesmo que o modelo respondesse “The answer is four”, isso podia ser interpretado como
4
- Esse método foi validado testando as respostas corretas do dataset, mas tem a desvantagem de adicionar custo de tokens da OpenAI à avaliação
- Os modelos ajustados aprendem rapidamente o padrão de resposta desejado e, mesmo quando erram, mantêm uma estrutura de saída previsível
- A avaliação dos modelos ajustados foi feita com regex
#### {answer}, evitando pós-processamento por endpoint da OpenAI
- A avaliação dos modelos ajustados foi feita com regex
- Os baselines foram os seguintes
- Resultados de 8-shot prompting de modelos base pré-treinados divulgados em artigos
- Vários templates com prompt engineering para variantes Llama-2 ajustadas para chat, treinadas pela Meta com RLHF para atuar como assistente genérico
Resultados do GSM8k e fine-tuning em duas etapas
- O fine-tuning dos modelos base elevou de forma consistente o desempenho no GSM8k, mas nem sempre produziu resultados muito melhores que os dos modelos ajustados para chat
- Os modelos chat tinham precisão maior que os modelos base, possivelmente por já terem sido expostos a exemplos matemáticos durante o chat-tuning
- Usar prompting sobre os modelos ajustados nem sempre produz resultados melhores que nos modelos base
- Por exemplo, o Llama-2-70B-chat pode ficar abaixo de um modelo base com prompt de 8 exemplos
- Os modelos ajustados, porém, foram consistentemente melhores que os modelos base com 8-shot prompting
- Em termos de custo de serving, os modelos ajustados podem levar vantagem
- Abordagens baseadas em prompt adicionam custo de tokens do prompt a cada requisição
- Nos modelos ajustados, na prática o custo reflete apenas a quantidade de tokens da pergunta
- Como os dados de treinamento do GSM8k têm apenas cerca de 8k exemplos, considerou-se difícil extrair todo o potencial do Llama-13B apenas com esse conjunto
- Uma abordagem em duas etapas, primeiro ajustando o modelo base Llama-13B em MathQA e depois novamente em GSM8k, trouxe melhora adicional
- O fine-tuning usando apenas GSM8k melhorou 10 p.p. em relação ao modelo base
- O fine-tuning em duas etapas com MathQA seguido de GSM8k trouxe mais 10 p.p. sobre o primeiro ajuste, totalizando 20 p.p. sobre o modelo base
- O MathQA é composto por 30.000 pares de perguntas e respostas, mas tem mais ruído e estrutura diferente do GSM8k
- A qualidade das respostas é inferior, e a resposta final tem formato multiple choice
- Ainda assim, o fine-tuning em duas etapas usando MathQA foi eficaz para melhorar o resultado final no GSM8k
Critérios a observar na aplicação prática
- Modelos fechados como GPT-4 e Claude-2 são fortes em prototipagem e validação inicial de valor, mas nem sempre bastam para operar apps de LLM em produção
- O fine-tuning de LLMs para niche tasks pode ter valor não só em privacidade, mas também em latência, custo e qualidade
- Nos exemplos de ViGGO e SQL, a qualidade também foi melhor que a do GPT-4
- No fine-tuning, o foco importante não está nos detalhes de implementação da infraestrutura, mas na coleta de dados e na construção do pipeline de avaliação
- O pipeline de avaliação serve de base para comparar, conforme as necessidades do negócio, os trade-offs entre diferentes soluções
- Os experimentos foram realizados com a plataforma de fine-tuning e serving da Anyscale e com o Anyscale Endpoints
- O mesmo processo foi montado sobre Ray para que a solução de fine-tuning e serving da Anyscale possa ser repetida com dados próprios e em nuvem própria
1 comentários
Comentários do Hacker News
Algumas semanas atrás, numa live de programação, cobri bastante como fazer fine-tuning do Llama 2 com um dataset próprio, usando uma única GPU no Colab.
No meu caso, o dataset era o meu código.
Fine-tuning Llama stream: https://www.youtube.com/watch?v=TYgtG2Th6fI&t=2282s
Também tenho mais algumas sessões de fine-tuning com QLoRA, nas quais explico os conceitos do ponto de vista de um engenheiro de software com 8 anos de experiência que recentemente migrou para machine learning e aprendeu por conta própria.
QloRa fine-tuning stream: https://www.youtube.com/watch?v=LitybCiLhSc&t=4584s
Tento explicar da forma mais simples possível como abordo isso em projetos pessoais e na startup baseada em IA em que estou trabalhando atualmente. Uma série sobre fazer fine-tuning do menor LLM para desenvolvimento web também parece ter sido bem recebida; venho fazendo streaming há cerca de um mês e pretendo publicar muito mais daqui para frente.
Também não entendo bem a abordagem de dividir modelos ajustados. É preciso ter um LLM de Terraform, um LLM de SQL e um LLM de Python separados, ou basta ter um único LLM de “código”?
São necessários tantos detalhes de implementação que a acessibilidade fica baixa, a menos que haja um caso de uso realmente relevante. Acho que o privateGPT vai chegar devagar a esse ponto.
É uma parte que muitos outros tutoriais pulam. Tenho curiosidade especialmente sobre como preparar os dados para objetivos diferentes, como segurança e precisão.
Estou enfrentando o mesmo problema com o Llama 2. É quase impossível fazê-lo imprimir somente o texto que eu quero; ele sempre acrescenta algo antes ou depois da resposta.
Gostaria de saber se existe alguma técnica de prompt que corrija isso.
O airoboros oferece suporte a um token PLAINFORMAT que faz a saída ser apenas código, evitando crases, explicações etc.
https://huggingface.co/TheBloke/airoboros-l2-70B-GPT4-2.0-GG...
Para garantir, o melhor é fazer fine-tuning com um dataset pequeno, de cerca de mil exemplos, e melhorar a partir daí.
Meu caso de uso era uma tarefa simples de extração/síntese de informações a partir de texto, mais do que escrita criativa. O modelo base pode não ser adequado para todas as tarefas.
contentou de JSON.Se for JSON, dá para identificar o início e o fim, então é só remover o que estiver fora do JSON.
Fico feliz em ver um texto como este. Houve tanta discussão online sobre customização de modelos que este artigo consegue eliminar bastante ruído.
Também gostei da metodologia de avaliação, e o texto parece bem escrito.
É estranho que LoRA e treinamento quantizado não sejam tratados com mais seriedade. São muito mais baratos, levam menos tempo e há bastante evidência de que funcionam bem.
Não acho que devam ser deixados como uma opção extra para tentar mais tarde.
Fico feliz em ver que uma tarefa parecida com NER teve o melhor desempenho. Eu estava prestes a fazer testes semelhantes para comparar com um modelo BERT ajustado.
Tenho curiosidade sobre qual foi o custo de treinamento dessa tarefa.
O tamanho de bloco poderia ser reduzido, mas era mais fácil não mudar o código, então deixamos assim. O 7B levou cerca de 15 minutos por época em 16xA10G, e o 13B levou cerca de 25 minutos. Portanto, o custo on-demand por época é de cerca de US$ 7,2 para o 7B e US$ 12 para o 13B. Esses valores consideram apenas o tempo usado no treinamento e não incluem o tempo de inicialização/encerramento do cluster.
O texto diz que usaram 16xA10G para o 7B e o 13B, e 32xA10G para o 70B, distribuídas em quatro instâncias g5.48xlarge. Com Ray, não é necessário obter A100s para fazer fine-tuning de todos os parâmetros desses modelos, e o mesmo processo é repetido para cada tarefa. No dataset GSM8k, eles mostram uma execução de exemplo com comprimento de contexto 512 e 3,7 milhões de tokens efetivos por época.
Eles dizem que treinaram por até 10 épocas e escolheram o checkpoint com a menor perplexidade no conjunto de validação.
Uma dificuldade é que, para criar um dataset customizado grande o suficiente, você precisa de algo como um pequeno exército de pessoas ou de um modelo existente muito forte.
No fim, há uma boa chance de ter que usar a OpenAI, mas gerar material de treinamento para outro modelo com a OpenAI viola os termos. Tenho curiosidade se isso já chegou a virar processo. As pessoas simplesmente consideram injusto e ignoram?
Tenho visto mais exemplos de NER ultimamente, e me pergunto por que não usar spaCy para essas tarefas.
Trabalho na Anyscale.
Como este blog parece ter recebido uma boa atenção, planejamos incluí-lo no Ray Summit: https://raysummit.anyscale.com/agenda
Se tiverem ideias sobre que tipos de conteúdo gostariam de ver mais no Ray Summit, seria ótimo saber.
Dizem que, para 3,5 milhões de tokens, o 7B leva cerca de 14 minutos por época, e o 13B, cerca de 26 minutos por época.
Para o 7B e o 13B, parece ser necessário no mínimo 1xg5.16xlarge como nó head e 15xg5.4xlarge como nós worker; fico curioso sobre quanto isso custaria na AWS.
Rodando em us-east-1, dá para considerar cerca de US$ 30 por hora.
https://instances.vantage.sh/?selected=g5.16xlarge,g5.4xlarg...
Tenho curiosidade se dá para fazer fine-tuning local do Llama-2 em um M1 Ultra 64 GB. A maioria dos materiais usa a nuvem ou Nvidia CUDA no Linux, então seria bom ter alguma referência.
Para treinamento, pretendo comprar alguns créditos do RunPod, e acho que deve ser possível com algumas dezenas de dólares.