1 pontos por GN⁺ 2023-08-13 | 1 comentários | Compartilhar no WhatsApp
  • Quando um LLM genérico é excessivo para tarefas especializadas, fazer fine-tuning direto no Llama-2 pode melhorar ao mesmo tempo qualidade, custo e latência com um modelo menor e mais barato
  • Após o fine-tuning, o Llama-2 13B teve aumento de precisão em representações funcionais do ViGGO de 58%→98%, em geração de SQL de 42%→89% e no GSM8k de 28%→47%
  • Em tarefas nas quais o formato de saída é importante, como ViGGO e geração de SQL, modelos menores do Llama-2 superaram o GPT-4, mas em raciocínio matemático não chegaram ao nível do GPT-4
  • Os experimentos foram feitos com scripts baseados em Ray Train, Ray Data, DeepSpeed e Accelerate; os modelos 7B e 13B foram treinados em 16xA10G, e o 70B em 32xA10G
  • A chave do ganho de desempenho foi qualidade dos dados e pipeline de avaliação, mais do que o tamanho do modelo; é preciso comparar, por tarefa, o trade-off de custo e qualidade entre prompt engineering e fine-tuning

Efeito do fine-tuning em três tarefas

  • Grandes modelos genéricos como GPT-4 e Claude-2 são úteis para prototipagem rápida, mas podem ser exagerados em custo e desempenho para demandas de escopo estreito, como resumo e classificação de tickets de suporte
  • O experimento compara o quanto melhora ao ajustar modelos Llama-2 para três tarefas realistas com fine-tuning de parâmetros completos
    • ViGGO: extração de representações funcionais a partir de texto não estruturado
    • SQL-create-context: geração de SQL a partir de linguagem natural e contexto de CREATE TABLE
    • GSM8k: resolução de problemas matemáticos de nível fundamental
  • No caso do Llama-2 13B, a mudança de precisão foi a seguinte
    • Representação funcional do ViGGO: 58% → 98%
    • Geração de SQL: 42% → 89%
    • GSM8k: 28% → 47%
  • Em ViGGO e geração de SQL, modelos menores do Llama-2 tiveram resultados melhores que o GPT-4; em tarefas de raciocínio matemático como GSM8k, mesmo após o fine-tuning, eles não alcançaram o desempenho do GPT-4

Método de fine-tuning e infraestrutura de treinamento

  • As três tarefas usaram fine-tuning padrão de parâmetros completos
    • O treinamento foi feito com previsão do próximo token
    • Todos os parâmetros do modelo foram alvo de atualização por gradiente
    • Abordagens como LoRA ou congelamento de parte dos blocos transformer ficaram fora do escopo do experimento
  • Os scripts do experimento foram construídos sobre Ray Train, Ray Data, DeepSpeed e Accelerate
    • Há suporte para executar Llama-2 7B, 13B e 70B
    • O TorchTrainer do Ray Train distribui o loop de treinamento entre vários processos worker e recursos de GPU
    • O particionamento dos dados é tratado pelo Ray Train, e cada worker acessa seu fragmento atribuído com session.get_dataset_shard("train") e session.get_dataset_shard("valid")
  • O particionamento do modelo foi tratado com DeepSpeed ZeRO stage 3 e offloading do estado do otimizador
    • Como os fragmentos do modelo ficam distribuídos entre vários workers, quando é necessário acessar o modelo completo, como ao salvar checkpoints, é preciso desembrulhá-lo com accelerator.unwrap_model(model)
  • Os recursos computacionais foram os seguintes
    • 7B e 13B: 16xA10G
    • 70B: 32xA10G, em 4 instâncias g5.48xlarge
    • Com Ray, o fine-tuning de parâmetros completos não exige necessariamente A100
  • O treinamento foi executado por até 10 epochs, e foi escolhido o checkpoint com menor perplexity no conjunto de validação

Fixando a estrutura de entrada e saída com tokens especiais

  • Os dados de fine-tuning representam a estrutura da tarefa com tokens especiais em vez de prompts em forma de instrução
    • Exemplo: <START_Q>{question}<END_Q><START_A>{answer}<END_A>
  • Os tokens especiais ajudam o modelo a distinguir as seções de entrada e saída e a aprender com clareza onde a geração deve parar
    • No exemplo, <END_A> é definido como stopping token para interromper a saída ao concluir a tarefa
  • O tokenizer do Llama produz, por padrão, 32.000 IDs de token
    • Ao adicionar quatro tokens especiais, ele passa a produzir 32.004 IDs
    • <START_Q> recebe um novo ID como 32000, <END_Q> recebe 32001, e assim por diante
  • O script adiciona os tokens especiais com tokenizer.add_tokens(special_tokens, special_tokens=True) e cria novos parâmetros treináveis com model.resize_token_embeddings(len(tokenizer))

ViGGO: convertendo texto não estruturado em representação funcional

  • O ViGGO é originalmente um dataset em inglês que transforma representações funcionais baseadas em pares atributo-valor em texto natural; no experimento, a direção foi invertida para converter texto não estruturado em representação funcional estruturada
    • O domínio é o de opiniões sobre videogames
    • A representação resultante pode ser usada para indexação e aplicações posteriores
  • O modelo precisa gerar a função e os valores de atributos adequados à frase
    • Entre as funções candidatas estão inform, request, give_opinion, confirm, verify_attribute, suggest, request_explanation, recommend e request_attribute
    • Entre os atributos candidatos estão name, release_year, esrb, genres, platforms, available_on_steam, has_linux_release, has_mac_release, specifier, rating, player_perspective, has_multiplayer, developer, exp_release_date etc.
  • Para a entrada de exemplo What's a really fast-paced game with multiplayer that you like to play?, a saída esperada é request(has_multiplayer[yes], specifier[fast-paced])
  • Modelos genéricos não seguiam bem o formato de saída desejado, e havia também o problema de o tempo de processamento da entrada ser maior que o de geração da saída por causa do contexto longo de entrada
  • Essa tarefa depende mais de reconhecimento de padrões e compreensão básica da linguagem do que de raciocínio lógico complexo
    • É uma grounded task, em que todos os fatos necessários estão contidos na entrada
    • O fato de few-shot prompting ajudar é visto como um sinal de que modelos menores do Llama-2 também podem melhorar com fine-tuning

Avaliação e resultados do ViGGO

  • A avaliação não usa apenas correspondência exata de caracteres
    • Verifica se a função gerada está correta
    • Verifica se os tipos de atributo estão corretos
    • Verifica se os atributos dentro da função seguem uma ordem de prioridade definida
  • Para modelos instruction-following como GPT e Llama-2-chat, a regra de ordenação dos atributos era explicitada no prompt, então a avaliação exigia o cumprimento dessa regra
  • Para acelerar a avaliação, foram usados em conjunto a batch inference API do Ray e o Aviary da Anyscale
    • Isso conecta a geração do LLM ao pós-processamento e distribui o trabalho em várias máquinas
  • Os modelos 7B e 13B tiveram grande ganho de precisão após o fine-tuning
    • O GPT-4 sofreu uma queda acentuada de precisão quando a prioridade dos atributos foi incluída na avaliação
    • Os modelos ajustados sempre seguiram a prioridade, e a precisão não mudou mesmo com essa restrição adicional
  • Os resultados do ViGGO mostram que o fine-tuning pode ser um meio estável e eficiente para tarefas que exigem formato estruturado
    • Não se trata apenas de ajustar regex simples ou formato JSON, mas de decidir quais argumentos incluir e ainda respeitar a ordem dos argumentos incluídos
    • Como os resultados foram obtidos com modelos 7B e 13B, o custo de serving pode ser menor do que chamar um endpoint de GPT-4

Geração de SQL: criando consultas a partir de linguagem natural e contexto de tabela

  • A tarefa de geração de SQL recebe uma consulta em linguagem natural e uma instrução SQL CREATE TABLE como entrada para produzir uma consulta SQL executável
  • O dataset usado, b-mc2/sql-create-context, é um dataset do Hugging Face que combina WikiSQL e Spider
    • Cada ponto de dado é composto por uma consulta em linguagem natural, uma instrução SQL CREATE TABLE e a consulta SQL correspondente
    • No total, são 78.577 pontos de dado
  • O dataset tinha problemas nas respostas SQL corretas
    • Em CREATE TABLE, atributos inteiros frequentemente apareciam como VARCHAR, mas na consulta SQL eram tratados como inteiros
    • Todas as consultas SQL que assumiam atributos inteiros foram removidas, reduzindo o dataset de cerca de 70k para 45k
  • Essa tarefa também é adequada para fine-tuning por transformar linguagem natural em uma representação estruturada em SQL
    • Diferentemente do ViGGO, pode haver várias consultas SQL corretas que geram o resultado certo, o que a torna mais ambígua

Avaliação e resultados em SQL

  • A avaliação de geração de SQL não é adequada para comparação simples de strings
    • A comparação caractere a caractere pode gerar muitos falsos negativos
    • A comparação por AST também pode ser sensível a elementos como a ordem dos nomes de variáveis
    • O método mais confiável é executar o código em um dataset sintético e comparar se a saída é a mesma
  • No experimento, foi usado o endpoint GPT-3.5 da OpenAI para gerar tabelas sintéticas de teste unitário para centenas de exemplos
    • O GPT-3.5 criava tabelas sintéticas com 10 pontos de dados a partir da pergunta, do schema da tabela e da resposta correta
    • Com sqlglot.executor.execute, eram executadas tanto a SQL correta quanto a SQL do modelo, e os resultados eram comparados
  • Para verificar a qualidade das tabelas geradas pelo GPT-3.5, a SQL correta era executada primeiro
    • Se a tabela de resultado ficasse vazia ou tivesse o mesmo tamanho da tabela original, o exemplo era descartado
    • Nesse processo, cerca de 50% das tabelas sintéticas geradas pelo GPT foram filtradas
  • Os modelos Llama-2 7B e 13B ajustados tiveram desempenho superior ao 70B-chat e ao GPT-4
    • Um erro comum dos modelos Llama chat era não colocar a SQL de forma consistente dentro das tags <SQL>, contrariando a instrução do prompt
    • Esse problema era mais frequente nos modelos chat 7B e 13B do que no 70B
  • Algumas consultas em linguagem natural no dataset SQL não estavam em inglês perfeito, e esse ruído pode ter afetado os resultados do GPT-4
    • Os modelos ajustados se adaptaram rapidamente até mesmo a esses hábitos peculiares do dataset

GSM8k: raciocínio matemático mais difícil do que aprender estrutura

  • O GSM8k é um benchmark acadêmico padrão para avaliar raciocínio matemático e capacidade de compreensão
  • Se nas duas tarefas anteriores o foco era principalmente aprender estrutura, no GSM8k o objetivo é verificar o quanto o modelo pode melhorar seu processo de raciocínio para resolver problemas matemáticos
  • Um exemplo de problema pergunta o total vendido quando 48 unidades foram vendidas em abril e metade disso em maio; a resposta correta termina no formato #### 72, junto com os cálculos intermediários
  • Os LLMs atuais, em vez de calcularem apenas a resposta final internamente e fornecê-la de imediato, precisam gerar parte do processo de pensamento na saída para que a geração dos tokens seguintes possa se basear nesse processo lógico
  • Essa tarefa exige não só cálculo simples, mas também uma chain of thought lógica que vá das premissas às conclusões intermediárias e então à resposta final

Método de avaliação e baselines do GSM8k

  • A avaliação precisa de um método confiável para extrair a resposta final do output do modelo
  • Modelos de linguagem genéricos podem não seguir de forma consistente o formato de saída desejado, o que dificulta a avaliação automática
    • Para isso, foi usada a OpenAI function calling API
    • O gpt-3.5-turbo-0613 chamava a função report_answer para extrair a resposta inteira final das gerações de outros modelos
    • Por exemplo, mesmo que o modelo respondesse “The answer is four”, isso podia ser interpretado como 4
  • Esse método foi validado testando as respostas corretas do dataset, mas tem a desvantagem de adicionar custo de tokens da OpenAI à avaliação
  • Os modelos ajustados aprendem rapidamente o padrão de resposta desejado e, mesmo quando erram, mantêm uma estrutura de saída previsível
    • A avaliação dos modelos ajustados foi feita com regex #### {answer}, evitando pós-processamento por endpoint da OpenAI
  • Os baselines foram os seguintes
    • Resultados de 8-shot prompting de modelos base pré-treinados divulgados em artigos
    • Vários templates com prompt engineering para variantes Llama-2 ajustadas para chat, treinadas pela Meta com RLHF para atuar como assistente genérico

Resultados do GSM8k e fine-tuning em duas etapas

  • O fine-tuning dos modelos base elevou de forma consistente o desempenho no GSM8k, mas nem sempre produziu resultados muito melhores que os dos modelos ajustados para chat
    • Os modelos chat tinham precisão maior que os modelos base, possivelmente por já terem sido expostos a exemplos matemáticos durante o chat-tuning
  • Usar prompting sobre os modelos ajustados nem sempre produz resultados melhores que nos modelos base
    • Por exemplo, o Llama-2-70B-chat pode ficar abaixo de um modelo base com prompt de 8 exemplos
    • Os modelos ajustados, porém, foram consistentemente melhores que os modelos base com 8-shot prompting
  • Em termos de custo de serving, os modelos ajustados podem levar vantagem
    • Abordagens baseadas em prompt adicionam custo de tokens do prompt a cada requisição
    • Nos modelos ajustados, na prática o custo reflete apenas a quantidade de tokens da pergunta
  • Como os dados de treinamento do GSM8k têm apenas cerca de 8k exemplos, considerou-se difícil extrair todo o potencial do Llama-13B apenas com esse conjunto
  • Uma abordagem em duas etapas, primeiro ajustando o modelo base Llama-13B em MathQA e depois novamente em GSM8k, trouxe melhora adicional
    • O fine-tuning usando apenas GSM8k melhorou 10 p.p. em relação ao modelo base
    • O fine-tuning em duas etapas com MathQA seguido de GSM8k trouxe mais 10 p.p. sobre o primeiro ajuste, totalizando 20 p.p. sobre o modelo base
  • O MathQA é composto por 30.000 pares de perguntas e respostas, mas tem mais ruído e estrutura diferente do GSM8k
    • A qualidade das respostas é inferior, e a resposta final tem formato multiple choice
    • Ainda assim, o fine-tuning em duas etapas usando MathQA foi eficaz para melhorar o resultado final no GSM8k

Critérios a observar na aplicação prática

  • Modelos fechados como GPT-4 e Claude-2 são fortes em prototipagem e validação inicial de valor, mas nem sempre bastam para operar apps de LLM em produção
  • O fine-tuning de LLMs para niche tasks pode ter valor não só em privacidade, mas também em latência, custo e qualidade
    • Nos exemplos de ViGGO e SQL, a qualidade também foi melhor que a do GPT-4
  • No fine-tuning, o foco importante não está nos detalhes de implementação da infraestrutura, mas na coleta de dados e na construção do pipeline de avaliação
    • O pipeline de avaliação serve de base para comparar, conforme as necessidades do negócio, os trade-offs entre diferentes soluções
  • Os experimentos foram realizados com a plataforma de fine-tuning e serving da Anyscale e com o Anyscale Endpoints
  • O mesmo processo foi montado sobre Ray para que a solução de fine-tuning e serving da Anyscale possa ser repetida com dados próprios e em nuvem própria

1 comentários

 
GN⁺ 2023-08-13
Comentários do Hacker News
  • Algumas semanas atrás, numa live de programação, cobri bastante como fazer fine-tuning do Llama 2 com um dataset próprio, usando uma única GPU no Colab.
    No meu caso, o dataset era o meu código.
    Fine-tuning Llama stream: https://www.youtube.com/watch?v=TYgtG2Th6fI&t=2282s
    Também tenho mais algumas sessões de fine-tuning com QLoRA, nas quais explico os conceitos do ponto de vista de um engenheiro de software com 8 anos de experiência que recentemente migrou para machine learning e aprendeu por conta própria.
    QloRa fine-tuning stream: https://www.youtube.com/watch?v=LitybCiLhSc&t=4584s
    Tento explicar da forma mais simples possível como abordo isso em projetos pessoais e na startup baseada em IA em que estou trabalhando atualmente. Uma série sobre fazer fine-tuning do menor LLM para desenvolvimento web também parece ter sido bem recebida; venho fazendo streaming há cerca de um mês e pretendo publicar muito mais daqui para frente.

    • Tenho curiosidade sobre os critérios gerais para decidir quando faz sentido usar RAG ou fine-tuning.
      Também não entendo bem a abordagem de dividir modelos ajustados. É preciso ter um LLM de Terraform, um LLM de SQL e um LLM de Python separados, ou basta ter um único LLM de “código”?
    • Há uma necessidade real de um app/módulo/biblioteca simples no nível de “coloque os materiais-fonte neste diretório, aperte um botão e converse com aquele conteúdo”.
      São necessários tantos detalhes de implementação que a acessibilidade fica baixa, a menos que haja um caso de uso realmente relevante. Acho que o privateGPT vai chegar devagar a esse ponto.
    • Gostei, e seria ótimo se você fizesse também uma série sobre preparação de datasets customizados para fine-tuning.
      É uma parte que muitos outros tutoriais pulam. Tenho curiosidade especialmente sobre como preparar os dados para objetivos diferentes, como segurança e precisão.
    • Dá para fazer com uma GPU? Tenho curiosidade se é realista mesmo com uma 3060.
  • Estou enfrentando o mesmo problema com o Llama 2. É quase impossível fazê-lo imprimir somente o texto que eu quero; ele sempre acrescenta algo antes ou depois da resposta.
    Gostaria de saber se existe alguma técnica de prompt que corrija isso.

    • É melhor usar um modelo melhor.
      O airoboros oferece suporte a um token PLAINFORMAT que faz a saída ser apenas código, evitando crases, explicações etc.
      https://huggingface.co/TheBloke/airoboros-l2-70B-GPT4-2.0-GG...
    • Os modelos Llama-2-chat foram excessivamente ajustados dessa forma. Você pode tentar few-shot prompting, mas isso não garante a saída desejada.
      Para garantir, o melhor é fazer fine-tuning com um dataset pequeno, de cerca de mil exemplos, e melhorar a partir daí.
    • Depende do objetivo, mas consegui reproduzir um formato de saída específico fazendo fine-tuning do modelo LLaMA2 base em vez de um modelo com RLHF.
      Meu caso de uso era uma tarefa simples de extração/síntese de informações a partir de texto, mais do que escrita criativa. O modelo base pode não ser adequado para todas as tarefas.
    • Basta dar um prompt para que o modelo sempre coloque a resposta ou o código dentro de uma string content ou de JSON.
      Se for JSON, dá para identificar o início e o fim, então é só remover o que estiver fora do JSON.
  • Fico feliz em ver um texto como este. Houve tanta discussão online sobre customização de modelos que este artigo consegue eliminar bastante ruído.
    Também gostei da metodologia de avaliação, e o texto parece bem escrito.

  • É estranho que LoRA e treinamento quantizado não sejam tratados com mais seriedade. São muito mais baratos, levam menos tempo e há bastante evidência de que funcionam bem.
    Não acho que devam ser deixados como uma opção extra para tentar mais tarde.

  • Fico feliz em ver que uma tarefa parecida com NER teve o melhor desempenho. Eu estava prestes a fazer testes semelhantes para comparar com um modelo BERT ajustado.
    Tenho curiosidade sobre qual foi o custo de treinamento dessa tarefa.

    • Sou coautor do artigo. Os dados de treinamento do ViGGO têm cerca de 5,1 mil linhas, e treinamos com tamanho de bloco 512.
      O tamanho de bloco poderia ser reduzido, mas era mais fácil não mudar o código, então deixamos assim. O 7B levou cerca de 15 minutos por época em 16xA10G, e o 13B levou cerca de 25 minutos. Portanto, o custo on-demand por época é de cerca de US$ 7,2 para o 7B e US$ 12 para o 13B. Esses valores consideram apenas o tempo usado no treinamento e não incluem o tempo de inicialização/encerramento do cluster.
    • Boa pergunta. É uma pena que não tenham indicado quanto tempo levaram as 10 épocas, pois assim seria possível calcular o custo. Melhor ainda teria sido publicar tanto o tempo quanto o custo.
      O texto diz que usaram 16xA10G para o 7B e o 13B, e 32xA10G para o 70B, distribuídas em quatro instâncias g5.48xlarge. Com Ray, não é necessário obter A100s para fazer fine-tuning de todos os parâmetros desses modelos, e o mesmo processo é repetido para cada tarefa. No dataset GSM8k, eles mostram uma execução de exemplo com comprimento de contexto 512 e 3,7 milhões de tokens efetivos por época.
      Eles dizem que treinaram por até 10 épocas e escolheram o checkpoint com a menor perplexidade no conjunto de validação.
  • Uma dificuldade é que, para criar um dataset customizado grande o suficiente, você precisa de algo como um pequeno exército de pessoas ou de um modelo existente muito forte.
    No fim, há uma boa chance de ter que usar a OpenAI, mas gerar material de treinamento para outro modelo com a OpenAI viola os termos. Tenho curiosidade se isso já chegou a virar processo. As pessoas simplesmente consideram injusto e ignoram?

    • Isso não vale para todas as tarefas. Em muitas tarefas de processamento de linguagem natural, basta reformatar dados existentes para o formato de LLM.
    • Há algum motivo para não ignorar os termos? Na pior das hipóteses, você perde o acesso.
  • Tenho visto mais exemplos de NER ultimamente, e me pergunto por que não usar spaCy para essas tarefas.

    • O spaCy não funciona bem com dados de treinamento multilíngues, e já vi ele quebrar de maneiras mais numerosas e estranhas do que modelos da família transformers.
    • Estou pensando numa abordagem em que um modelo caro rotula os dados e depois, no esquema professor/aluno, treinamos um modelo menor como SpaCy ou BERT para equilibrar custo e velocidade.
    • Uso modelos da família BERT ajustados para NER, mas gostaria de fazer uma comparação de desempenho.
  • Trabalho na Anyscale.
    Como este blog parece ter recebido uma boa atenção, planejamos incluí-lo no Ray Summit: https://raysummit.anyscale.com/agenda
    Se tiverem ideias sobre que tipos de conteúdo gostariam de ver mais no Ray Summit, seria ótimo saber.

  • Dizem que, para 3,5 milhões de tokens, o 7B leva cerca de 14 minutos por época, e o 13B, cerca de 26 minutos por época.
    Para o 7B e o 13B, parece ser necessário no mínimo 1xg5.16xlarge como nó head e 15xg5.4xlarge como nós worker; fico curioso sobre quanto isso custaria na AWS.

  • Tenho curiosidade se dá para fazer fine-tuning local do Llama-2 em um M1 Ultra 64 GB. A maioria dos materiais usa a nuvem ou Nvidia CUDA no Linux, então seria bom ter alguma referência.

    • Acho que não. Uso um M1 Max 64 GB, e algumas inferências rodam razoavelmente bem.
      Para treinamento, pretendo comprar alguns créditos do RunPod, e acho que deve ser possível com algumas dezenas de dólares.