Skip to content

Chain-of-thought prompting experiments with BLOOM-7B on GSM-8K dataset

Notifications You must be signed in to change notification settings

aebogdanova/BLOOM-CoT-prompting-experiments

Repository files navigation

Chain-of-thought prompting experiments

Описание задачи

Задача исследования заключается в сравнении двух подходов к chain-of-thought prompting:

  • классический chain-of-thought, основанный на greedy decoding (оригинальная статья здесь),
  • ансамблированный chain-of-thought, или self-consistency подход, основанный на различных методах сэмплирования и агрегирования ответов модели (оригинальная статья здесь)

Эксперименты должны были проводиться на распределенной модели BLOOM-176B, однако в силу низкой скорости инференса модели, не позволяющей в разумные сроки получить сколько-нибудь значимые метрики, было принято решение использовать модель меньшего размера, BLOOM-7B. В целом, имеющийся код позволяет воспроизвести решение и для BLOOM-176B, для этого необходимо лишь указать соответствующее название модели в блоке Preparation ноутбука CoT-prompting-experiments.ipynb:

MODEL_NAME = "bigscience/bloom-petals"

Оценка полученных результатов проводилась на тестовой выборке бенчмарка GSM-8K, включающего математические задачи уровня начальной школы.

Чтобы воспроизвести эксперименты, необходимо запустить ноутбук CoT-prompting-experiments.ipynb и следовать инструкциям, указанным в нем. В директории generated-answers содержатся сохраненные ответы модели, по которым можно воспроизвести метрики, полученные в ходе исследования (инструкция также есть в ноутбуке).

Особенности реализации

В качестве промптов для инференса модели был взят тот же самый набор из 8 примеров задач, который использовали авторы статьи о chain-of-thought prompting и авторы статьи о self-consistency prompting.

В ходе работы удалось провести эксперименты на 500 примерах из тестового датасета:

  • с классическим chain-of-thought prompting:
    • max_new_tokens=140
  • с self-consistency подходом:
    • max_new_tokens=140,
    • temperature=0.5,
    • top_k=40.

В случае с ансамблированным chain-of-thought было сгенерировано по 12 сэмплов на каждый пример из тестового датасета. Агрегация ответов проводилась методом простого большинства (выбирался самый частый ответ).

Результаты

Результаты генерации оценивались метрикой accuracy (долей правильных ответов среди всех ответов модели). Self-consistency подход показал улучшение в метрике, хоть и незначительное.

Method Accuracy
BLOOM-7B CoT-prompting 4.2
Self-consistency 5.0

Если сравнивать полученные результаты с перфомансом других моделей, близких по размеру BLOOM-7B, то можно увидеть, что на данной модели метрики так или иначе близки метрикам, например, 20-миллиардной UL2:

Method Accuracy
BLOOM-7B CoT-prompting 4.2
Self-consistency 5.0
UL2-20B CoT-prompting 4.1
Self-consistency 7.3

Тем не менее, необходимо заметить, что для полноценного анализа полученных результатов проведенных экспериментов, безусловно, недостаточно. В первую очередь, необходимо увеличить количество сэмплов для примеров в экспериментах с self-consistency методом (в оригинальной статье авторы брали по 40 сэмплов, здесь рассмотрено только 12). Это подтверждается тем, что если для CoT-prompting метода метрика BLOOM-7B почти совпадает с метрикой UL2, то для self-consistency подхода разница уже более значима: возможно, при бОльшем количестве сэмплов BLOOM-7B покажет более высокие результаты. Во-вторых, для более корректной репрезентации метрик было бы правильно протестировать промпты на всем тестовом датасете, состоящем из 1319 примеров (в нашем случае взято только 500).

Дальнейшие шаги

Помимо проблем, обозначенных в п.3, для более полного исследования следует также расширить количество экспериментов с self-consistency prompting:

  • изменить параметры генерации: температуру, top-k, попробовать применить top-p; можно использовать, например, те сочетания параметров, которые рассматривали авторы оригинальной статьи:
    • temperature=0.7, top_k=40
    • temperature=0.7, top_k=20
    • temperature=0.3, top_k=40
    • temperature=0.7, no top_k
    • top_p=0.95
    • top_p=0.9
  • исследовать влияние количества сэмплов (reasoning paths) на качество генерации: например, начать с 5-10 и с шагом в 5-10 постепенно увеличивать до 40 (или более),
  • установить repetition-penalty, который, предположительно, может способствовать улучшению качества генерации reasoning paths.

В качестве других идей, которые могут улучшить полученные результаты, можно предложить:

  • использовать разнообразные форматы промптов, как это было представлено в свежей статье Large Language Models Can Self-improve. Несмотря на то, что в данной статье разные форматы промптов использовали для того, чтобы модель не переобучалась при fine-tuning (на сгенерированных ей самой reasoning paths), возможно, разные форматы промптов при prompt engineering также могут дать более разнообразные ответы,
  • при генерации арифметических операций типа The distances are 80 + 150 = 210 miles. The answer is 210. подключать модуль калькулятора, чтобы получать верно посчитанный ответ. При обращении к ответам модели часто можно видеть, что модель при верном рассуждении выдает некорректный ответ именно из-за ошибок при расчетах.

About

Chain-of-thought prompting experiments with BLOOM-7B on GSM-8K dataset

Resources

Stars

Watchers

Forks

Releases

No releases published

Packages

No packages published