# Config for single device knowledge distillation (KD) in knowledge_distillation_single_device.py
# using a LLAMA3 teacher and student model
#
# This config assumes that you've run the following commands before launching KD.
# First download the student and teacher models:
#   tune download meta-llama/Llama-3.2-1B-Instruct --output-dir /tmp/Llama-3.2-1B-Instruct --ignore-patterns "original/consolidated.00.pth"
#   tune download meta-llama/Meta-Llama-3.1-8B-Instruct --output-dir /tmp/Meta-Llama-3.1-8B-Instruct --ignore-patterns "original/consolidated.00.pth"
#
# You get better results with KD if the teacher model has already been fine-tuned on the target dataset:
#   tune run lora_finetune_single_device --config llama3_1/8B_lora_single_device
#
# To launch on a single device, run the following command from the repo root:
#   tune run knowledge_distillation_single_device --config llama3_2/knowledge_distillation_single_device
#
# This config works only for training on a single device.
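#
# Any field below can also be overridden from the command line with key=value syntax
# (the values here are illustrative, not defaults), e.g.:
#   tune run knowledge_distillation_single_device --config llama3_2/knowledge_distillation_single_device batch_size=8 kd_ratio=0.7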

# Model Arguments
model:
  _component_: torchtune.models.llama3_2.lora_llama3_2_1b
  lora_attn_modules: ['q_proj', 'v_proj', 'output_proj']
  apply_lora_to_mlp: True
  apply_lora_to_output: False
  lora_rank: 64
  lora_alpha: 128
  lora_dropout: 0.0
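  # LoRA scales the adapter update by lora_alpha / lora_rank, so rank 64 with
  # alpha 128 keeps an effective scale of 2; raising lora_rank adds trainable
  # adapter parameters.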

teacher_model:
  _component_: torchtune.models.llama3_1.llama3_1_8b

# Tokenizer
tokenizer:
  _component_: torchtune.models.llama3.llama3_tokenizer
  path: /tmp/Llama-3.2-1B-Instruct/original/tokenizer.model
  max_seq_len: null
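  # max_seq_len: null leaves sequences untruncated; set a finite value (e.g. 2048)
  # to cap per-sample length and peak memory.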

checkpointer:
  _component_: torchtune.training.FullModelHFCheckpointer
  checkpoint_dir: /tmp/Llama-3.2-1B-Instruct/
  checkpoint_files: [
    model.safetensors
  ]
  recipe_checkpoint: null
  output_dir: /tmp/Llama-3.2-1B-Instruct/
  model_type: LLAMA3
resume_from_checkpoint: False
save_adapter_weights_only: False
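# With save_adapter_weights_only: False the recipe should save both the LoRA adapter
# weights and the full merged student weights at each checkpoint; set it to True to
# keep only the much smaller adapter files.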

# Teacher checkpoint
teacher_checkpointer:
  _component_: torchtune.training.FullModelHFCheckpointer
  checkpoint_dir: /tmp/Meta-Llama-3.1-8B-Instruct/
  checkpoint_files: [
    model-00001-of-00004.safetensors,
    model-00002-of-00004.safetensors,
    model-00003-of-00004.safetensors,
    model-00004-of-00004.safetensors
  ]
  recipe_checkpoint: null
  output_dir: /tmp/Meta-Llama-3.1-8B-Instruct/
  model_type: LLAMA3

# Dataset and Sampler
dataset:
  _component_: torchtune.datasets.alpaca_cleaned_dataset
  packed: False  # True increases speed
seed: null
shuffle: True
batch_size: 4
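# Note: packed: True requires a finite tokenizer.max_seq_len so that samples can be
# packed to a fixed length.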

# Optimizer and Scheduler
optimizer:
  _component_: torch.optim.AdamW
  fused: True
  weight_decay: 0.01
  lr: 3e-4

lr_scheduler:
  _component_: torchtune.training.lr_schedulers.get_cosine_schedule_with_warmup
  num_warmup_steps: 100
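  # The learning rate warms up linearly over the first 100 optimizer steps, then
  # follows a cosine decay for the rest of training.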

loss:
  _component_: torchtune.modules.loss.CEWithChunkedOutputLoss

kd_loss:
  _component_: torchtune.modules.loss.ForwardKLWithChunkedOutputLoss
kd_ratio: 0.5
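# The recipe blends the two losses, roughly:
#   total_loss = (1 - kd_ratio) * ce_loss + kd_ratio * kd_loss
# so kd_ratio: 0.5 weights the hard-label cross-entropy and the forward-KL
# distillation term equally, and kd_ratio: 0.0 disables distillation.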

# Training
epochs: 1
max_steps_per_epoch: null
gradient_accumulation_steps: 8  # Use to increase effective batch size
compile: False  # torch.compile the model and loss; set to True for better perf/memory
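# Effective batch size = batch_size * gradient_accumulation_steps = 4 * 8 = 32
# samples per optimizer step.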

# Logging
output_dir: /tmp/kd_output
metric_logger:
  _component_: torchtune.training.metric_logging.DiskLogger
  log_dir: ${output_dir}
log_every_n_steps: 1
log_peak_memory_stats: True
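# To log to Weights & Biases instead, swap in torchtune's WandBLogger
# (requires wandb to be installed; the project name below is illustrative):
# metric_logger:
#   _component_: torchtune.training.metric_logging.WandBLogger
#   project: llama3_2_kd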

# Environment
device: cuda
dtype: bf16

# Activations Memory
enable_activation_checkpointing: False  # True reduces memory

# Profiler (disabled)
profiler:
  _component_: torchtune.training.setup_torch_profiler
  enabled: False

  # Output directory of trace artifacts
  output_dir: ${output_dir}/profiling_outputs

  # `torch.profiler.ProfilerActivity` types to trace
  cpu: True
  cuda: True

  # Trace options passed to `torch.profiler.profile`
  profile_memory: False
  with_stack: False
  record_shapes: True
  with_flops: False

  # `torch.profiler.schedule` options:
  # wait_steps -> wait, warmup_steps -> warmup, active_steps -> active, num_cycles -> repeat
  wait_steps: 5
  warmup_steps: 3
  active_steps: 2
  num_cycles: 1
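  # With this schedule, each profiling cycle waits 5 steps, warms up for 3, then
  # records 2 active steps (5 + 3 + 2 = 10 steps per cycle), and runs once.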