Skip to content

Support DoReMi training 🚀

Compare
Choose a tag to compare
@NouamaneTazi NouamaneTazi released this 22 Feb 15:31
· 489 commits to main since this release
9f9af42

You might think that one of the key ways to speed up pretraining performance is either by finding more quality data, increasing FLOPs, or changing the model architecture, but actually, these are not the only ways. DoReMi shows that, given the same source of training data, a model using an optimal data mixing strategy could outperform its counterpart with random sampling in at least 70% domains or all domains and downstream evaluations without any knowledge of the downstream evaluation tasks.

DoReMi Blog: https://crfm.stanford.edu/2023/09/14/doremi

Using DoReMi in Nanotron:

(Thanks to @xrsrke)

  • Step 0: Preprocessing data

  • Step 1: Train a small reference model using uniform sampling from each domain (for a given global batch size, you equally sample x samples across all domains, or in some cases, a domain has a smaller amount of samples than other domains. This leads to some domains running out of samples early, so you could enable automatic domain weights based on the token count).

CUDA_DEVICE_MAX_CONNECTIONS=1 torchrun --nproc_per_node=4 examples/doremi/train_reference.py --config-file examples/doremi/configs/config_280m_llama.yaml
  • Step 2: Use the trained reference model from step 1 to train an identical model, and use its performance to dynamically tune the domain weights during training.
CUDA_DEVICE_MAX_CONNECTIONS=1 torchrun --nproc_per_node=4 examples/doremi/train_doremi.py --config-file examples/doremi/configs/config_280m_llama_proxy.yaml
  • Step 3: Nanotron saves the domain weights in the model checkpoint. Now, calculate the optimal domain weights by averaging the domain weights across all training steps from step 1: $\bar{\alpha}=\frac{1}{T} \sum_{i=1}^T \alpha_t$.
import torch

domain_weights = torch.load("checkpoints/doremi/proxy-280m-llama/doremi_domain_weights_100000.pt")

total_weights = sum(d["domain_weights"] for d in domain_weights)
avg_weights = total_weights / len(domain_weights)

Then, set these avg_weights in the config of the larger run in the doremi section.

  • Step 4: Use the optimized domain weights from step 3 to train a larger model (could be 10x to 30x larger).
CUDA_DEVICE_MAX_CONNECTIONS=1 torchrun --nproc_per_node=8 examples/doremi/train_reference.py --config-file examples/doremi/configs/config_2.8b_llama_with_tuned_weights.yaml
  • Step 5: Profit 🤑