# eval.py (forked from JonasGeiping/cramming)
"""Script to evaluate a pretrained model."""
import torch
import hydra
import time
import datetime
import logging
from collections import defaultdict
import cramming
import evaluate
log = logging.getLogger(__name__)
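
# Overall flow: the hydra entry point at the bottom of this file hands the config to
# cramming's main_launcher, which prepares the environment ("setup") and then calls
# main_downstream_process to finetune the pretrained checkpoint on every prepared task.
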
def main_downstream_process(cfg, setup):
    """This function controls the central routine."""
    local_time = time.time()

    tokenizer, cfg_arch, model_file = cramming.utils.find_pretrained_checkpoint(cfg)
    tasks = cramming.prepare_task_dataloaders(tokenizer, cfg.eval, cfg.impl)

    metrics = dict()
    stats = defaultdict(list)
    # Start the clocks now:
    for task_name, task in tasks.items():
        cfg.eval.steps = len(task["trainloader"]) * cfg.eval.epochs
        log.info(f"Finetuning task {task_name} with {task['num_classes']} classes for {cfg.eval.steps} steps.")
        # Prepare model for finetuning:
        model = cramming.construct_model(cfg_arch, tokenizer.vocab_size, downstream_classes=task["num_classes"])
        model_engine, _, _, _ = cramming.load_backend(model, None, tokenizer, cfg.eval, cfg.impl, setup=setup)
        model_engine.load_checkpoint(cfg_arch, model_file)
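        # The engine returned by cramming.load_backend bundles the model with its optimizer
        # and scheduler; device placement (to_device), the training step (step), and
        # checkpoint I/O all go through it rather than through the raw model.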
        try:
            assert task_name != "record"
            metric = evaluate.load(task["details"]["collection"], task_name, cache_dir=cfg.impl.path)
        except (FileNotFoundError, AssertionError):  # no specific metric downloadable from evaluate, construct directly
            targets = [evaluate.load(metric_name, cache_dir=cfg.impl.path) for metric_name in task["details"]["target_metrics"]]
            metric = evaluate.CombinedEvaluations(targets)
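        # Either way, "metric" exposes the same add_batch()/compute() interface, so validate()
        # below never needs to distinguish a single metric from a combination of target
        # metrics. The assert deliberately routes "record" to the fallback path.
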
        # Launch training
        model_engine.train(cfg.eval.eval_in_train_mode)
        loss_vals = []
        for epoch in range(cfg.eval.epochs):
            train_time = time.time()
            for step, batch in enumerate(task["trainloader"]):
                # Heavy lifting is moved to engines
                device_batch = model_engine.to_device(batch, keys=["input_ids", "labels", "attention_mask"])
                loss = model_engine.step(device_batch)
                loss_vals.append(loss.detach())
                if cfg.dryrun:
                    break

            metrics[task_name] = validate(model_engine, task["validloader"], metric, setup, cfg)
            stats[f"{task_name}_epoch"] += [epoch]
            stats[f"{task_name}_loss"] += [loss.item()]
            stats[f"{task_name}_avg_loss"] += [torch.stack(loss_vals).mean().item()]  # Smoothed loss
            loss_vals = []
            current_lr = model_engine.optimizer.param_groups[0]["lr"]

            log_msg = f"Train loss {loss.item():2.4f} at step {step} with lr {current_lr:.5f}. "
            log_msg += f"[Avg: {stats[f'{task_name}_avg_loss'][-1]:2.4f}] after epoch {epoch}."
            stats[f"{task_name}_train_time"] += [(time.time() - train_time)]
            estimated_train_finish = str(datetime.timedelta(seconds=stats[f"{task_name}_train_time"][-1] * cfg.eval.epochs))
            # Throughput: processed steps * sequence length * microbatch size per wall-clock second
            tokens_per_second = (step + 1) * cfg.eval.max_seq_length * cfg.impl.microbatch_size / stats[f"{task_name}_train_time"][-1]
            log_msg += (
                f" Perf: {stats[f'{task_name}_train_time'][-1]/60:2.4f}min per epoch ({tokens_per_second:.0f}t/s). "
                f"Estimated Total Train: {estimated_train_finish}."
            )
            for name, metric_val in metrics[task_name].items():
                stats[f"{task_name}_{name}"] += [metric_val]
            log.info(log_msg)
            msg_metrics = " ".join([f"{k}: {v:2.4f}" for k, v in metrics[task_name].items()])
            log.info(f"Validation metric is {msg_metrics} after epoch {epoch}.")
            cramming.utils.wandb_log(stats, cfg)
            if cfg.dryrun:
                break
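
        # cfg.dryrun short-circuits both training loops above after a single step/epoch,
        # so the full pipeline can be smoke-tested without a real finetuning run.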
        # Launch extra testing if an extra validation set exists (as with MNLI-mismatched):
        if task["extra_validloader"] is not None:
            extra_eval_metric = validate(model_engine, task["extra_validloader"], metric, setup, cfg)
            metrics[task_name].update({f"{k}_extra": v for k, v in extra_eval_metric.items()})
            for name, metric_val in extra_eval_metric.items():
                stats[f"{task_name}_{name}_extra"] += [metric_val]
            msg_metrics = " ".join([f"{k}: {v:2.4f}" for k, v in extra_eval_metric.items()])
            log.info(f"Extra validation metric is {msg_metrics} after finetuning.")
            cramming.utils.wandb_log({f"{task_name}_{k}_extra": [v] for k, v in extra_eval_metric.items()}, cfg)

    # Check average metric over all tasks:
    target_metrics = []
    for task_name, task in tasks.items():
        target_metric_names = task["details"]["target_metrics"]
        for metric_name in target_metric_names:
            target_metrics.append(metrics[task_name][metric_name])
    metrics[f"{cfg.eval.name}_amean"] = torch.as_tensor(target_metrics).mean().item()
    metrics[f"{cfg.eval.name}_hmean"] = torch.as_tensor(target_metrics).pow(-1).mean().pow(-1).item()
log.info(f"Overall average metric on evaluation {cfg.eval.name} is {metrics[f'{cfg.eval.name}_amean']:.2f}.")
cramming.utils.wandb_log(
{f"{cfg.eval.name}_amean": [metrics[f"{cfg.eval.name}_amean"]], f"{cfg.eval.name}_hmean": [metrics[f"{cfg.eval.name}_hmean"]]},
cfg,
)
# Save to summary:
if cramming.utils.is_main_process():
cramming.utils.save_summary("downstream", cfg, stats, time.time() - local_time, setup)
return metrics # will be dumped into yaml


@torch.no_grad()
def validate(model_engine, validloader, metric, setup, cfg):
    """Evaluate on the validation set."""
    model_engine.eval()
    for step, batch in enumerate(validloader):
        device_batch = model_engine.to_device(batch, keys=["input_ids", "labels", "attention_mask"])
        _, predictions = model_engine.forward_inference(**device_batch)
        if getattr(metric, "config_name", "") != "multirc":
            metric.add_batch(predictions=predictions, references=device_batch["labels"])
        else:  # multirc needs special packaging, see the note below this function
            pred_indices = range(step * predictions.shape[0], (step + 1) * predictions.shape[0])
            packages = [dict(idx=validloader.index_lookup[pred_indices[i]], prediction=p) for i, p in enumerate(predictions.cpu())]
            metric.add_batch(predictions=packages, references=batch["labels"])
        if cfg.dryrun and step > 1:
            break
    try:
        eval_metric = metric.compute()
    except ValueError:  # Pearson correlation raises if the predictions contain non-finite values
        log.info("ValueError in metric computation, possibly non-finite values in predictions. Returning backup score.")
        eval_metric = metric.compute(predictions=[0, 1], references=[1, 0])  # spoof a terrible result so the run still reports a number
    model_engine.train(cfg.eval.eval_in_train_mode)
    return {k: float(v) for k, v in eval_metric.items()}  # force float returns
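
# Note on the multirc branch in validate(): the SuperGLUE multirc metric scores grouped
# question/answer pairs, so each prediction has to be submitted together with its original
# dataset index (recovered here via validloader.index_lookup) rather than as a flat list.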


@hydra.main(config_path="cramming/config", config_name="cfg_eval", version_base="1.1")
def launch(cfg):
    cramming.utils.main_launcher(cfg, main_downstream_process, job_name="downstream finetuning")


if __name__ == "__main__":
    launch()
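
# Example invocation (a sketch: the eval config, run name, and batch-size override below are
# illustrative and assume the stock cramming config layout under cramming/config):
#   python eval.py eval=GLUE name=my_pretrained_run eval.checkpoint=latest impl.microbatch_size=16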