from lib.proc_init_utils import initialise_tpu; initialise_tpu('v3-32')
import einops as op
from functools import partial
import jax
from jax import Array
from jax.experimental.multihost_utils import process_allgather
import jax.numpy as jnp
import jax.random as rand
import jax_smi
import math
import optax
import signal
import time
from transformers import LlamaTokenizer
from tqdm import tqdm
from typing import Any, Callable
import wandb

from lib.data import TrainData
from lib.dataloader import LlamaDataLoader
from lib.gsm_data import GSMDataset, gsm_collate_fn_train
from lib.llama import Llama, RotaryValues, forward_llama, init_llama, make_rotary_values
# from lib.llama import model_config_dummy as model_config
from lib.llama import model_config_llama2_7B as model_config
from lib.loss import cross_entropy_loss
from lib.multihost_utils import shard_model_params
from lib.param_utils import load_params, save_params
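
# Module-level state shared between `main`, the jitted train step and the
# signal handlers.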
is_process_0: bool
params: Llama
optimize: Callable
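
# Load the checkpoint onto the CPU device first, then shard the parameters
# across the devices of all hosts.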
def load_params_from_disk(path: str) -> Llama:
    cpu_device = jax.devices('cpu')[0]
    with jax.default_device(cpu_device):
        # params = init_llama(key=rand.key(42), model_config=model_config)
        params = load_params(path)
    params = shard_model_params(params)
    return params
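
# Graceful interruption: SIGINT/SIGTERM trigger a checkpoint save. The handlers
# are unset while the save is in progress so a second signal cannot interrupt it,
# then re-installed afterwards.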
def set_save_params_signal():
    signal.signal(signal.SIGINT, save_params_signal_handler)
    signal.signal(signal.SIGTERM, save_params_signal_handler)

def unset_save_params_signal():
    signal.signal(signal.SIGINT, signal.SIG_IGN)
    signal.signal(signal.SIGTERM, signal.SIG_IGN)

def save_params_to_disk() -> None:
    unset_save_params_signal()
    gathered_params = process_allgather(params)
    if is_process_0:
        save_params(gathered_params, f'{wandb.run.name}.pickle')  # type: ignore
    set_save_params_signal()

def save_params_signal_handler(signum, frame):
    save_params_to_disk()
    print(f'Signal {signum} received. Model params have been successfully saved to disk.')
    exit(-1)
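
# Loss for one micro-batch. `jax.value_and_grad` turns this into a function that
# returns (loss, grads). The QK mask combines the per-sequence padding mask with
# a lower-triangular causal mask, broadcast to shape (B, 1, 1, L, L).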
@jax.value_and_grad
def train_forward(params: Llama, rotary_values: RotaryValues, data_batch: TrainData, *, key: Array):
    seq, seq_mask, labels, labels_mask = data_batch
    qk_mask = op.rearrange(jnp.tril(op.einsum(seq_mask, seq_mask, 'B L1, B L2 -> B L1 L2')), 'B L1 L2 -> B 1 1 L1 L2')  # causal QK mask
    logits, _ = forward_llama(params, seq, qk_mask, rotary_values=rotary_values, key=key, model_config=model_config)
    loss = cross_entropy_loss(logits, labels, mask=labels_mask)
    return loss
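
# One jitted step on a micro-batch. Because the optimiser is wrapped in
# optax.MultiSteps (see `main`), the returned updates leave `params` unchanged
# until n_accumulation_steps gradients have been accumulated.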
@jax.jit
def train_step(params: Llama, opt_state: Any, rotary_values: RotaryValues, total_loss: Array, data_batch: TrainData, key: Array) -> tuple[Llama, Any, Array, Array, Array]:
    key, subkey = rand.split(key)
    loss, grads = train_forward(params, rotary_values, data_batch, key=subkey)
    total_loss += loss
    updates, opt_state = optimize(grads, opt_state, params)  # type: ignore
    params = optax.apply_updates(params, updates)
    return params, opt_state, total_loss, loss, key
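
# Fine-tuning hyperparameters are hard-coded in `main` below; the wandb run
# records the effective batch size as batch_size * n_accumulation_steps.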
def main() -> None:
    global is_process_0, params, optimize

    lr = 0.00005
    batch_size = 6
    n_accumulation_steps = 8
    max_len = 640
    n_epochs = 7
    seed = 3407
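
    # Multi-host setup: every process runs this script; only process 0 logs to
    # wandb and writes the checkpoint to disk.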
    jax.distributed.initialize()
    jax_smi.initialise_tracking()
    is_process_0 = jax.process_index() == 0

    if is_process_0:
        wandb.init(project='llama-finetuning-gsm', config=dict(learning_rate=lr, batch_size=batch_size * n_accumulation_steps, n_epochs=n_epochs, optimiser='adamw'))

    key = rand.key(seed, impl='rbg')
    tokenizer = LlamaTokenizer.from_pretrained('meta-llama/Llama-2-7b-hf')
    dataset = GSMDataset(split='train')
    collate_fn = partial(gsm_collate_fn_train, tokenizer, max_len)
    dataloader = LlamaDataLoader(dataset, collate_fn, batch_size, seed, drop_last=True)  # TODO: setting `drop_last` because the `batch_size` of `rotary_values` is not properly handled
    params = load_params_from_disk('llama2-7B.pickle')
    set_save_params_signal()
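
    # Gradient accumulation: the schedule is expressed in optimiser steps
    # (len(dataloader) / n_accumulation_steps), and optax.MultiSteps applies one
    # AdamW update per n_accumulation_steps micro-batches.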
    n_steps = math.ceil(len(dataloader) / n_accumulation_steps)
    schedule = optax.warmup_cosine_decay_schedule(
        init_value=0.,
        peak_value=lr,
        warmup_steps=n_steps,
        decay_steps=n_steps + 1,
        end_value=lr,
    )
    optimizer = optax.adamw(learning_rate=schedule)
    optimizer = optax.MultiSteps(optimizer, n_accumulation_steps)
    optimize = optimizer.update
    opt_state = optimizer.init(params)

    rotary_values = make_rotary_values(None, batch_size, max_len, model_config=model_config)
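
    # Training loop. Only process 0 reports metrics; `report_to_wandb` advances
    # the progress bar and logs the averaged loss once per optimiser step, i.e.
    # every n_accumulation_steps micro-batches.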
    for _ in range(n_epochs):
        pbar = tqdm(total=len(dataloader) // n_accumulation_steps)
        step_loss = 0.0
        total_loss = jnp.zeros(())

        if is_process_0:
            def report_to_wandb(start_time, opt_state, loss):
                nonlocal step_loss
                step_loss += loss.item()
                if optimizer.has_updated(opt_state):
                    wandb.log({'train loss': step_loss / n_accumulation_steps, 'time': time.time() - start_time})
                    step_loss = 0.0
                    pbar.update()

        for step, data_batch in enumerate(dataloader):
            start_time = time.time()
            params, opt_state, total_loss, loss, key = train_step(params, opt_state, rotary_values, total_loss, data_batch, key)
            if is_process_0:
                jax.debug.callback(report_to_wandb, start_time, opt_state, loss)

        if is_process_0:
            wandb.log({'epoch loss': total_loss.item() / (step + 1)})

    save_params_to_disk()

if __name__ == '__main__':
    main()