-
Notifications
You must be signed in to change notification settings - Fork 0
/
templates.py
116 lines (95 loc) · 5.08 KB
/
templates.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
import json
import pickle
import random
import warnings
from itertools import product
from pathlib import Path
from data import DATASET_TO_DATACLASS, SEPS, BIG_SEPS
from evaluate import evaluate_setup
from models import load_generator
from utils import parse_args, get_results_torch, save_results_torch, save_results_nirvana
class Template:
    """A prompt format made of an input verbalizer, a separator and an output verbalizer.

    The string form of a template is the plain concatenation
    ``inp_verbalizer + sep + out_verbalizer + big_sep``.

    Templates compare equal when all four components are equal, and hash
    consistently with that equality, so they can be used as dictionary keys
    (for example to look up results that were saved by an earlier run).
    """

    def __init__(self, inp_verbalizer, out_verbalizer, sep, big_sep=""):
        """Store the template components after validating the verbalizers.

        Args:
            inp_verbalizer: Format string for the input; must contain ``{}``.
            out_verbalizer: Format string for the output; must contain ``{}``.
            sep: String placed between the input and output verbalizers.
            big_sep: String appended after the output verbalizer
                (presumably the separator between demonstrations -- confirm
                against the code that consumes templates).

        Raises:
            ValueError: If either verbalizer lacks a ``{}`` placeholder.
        """
        if "{}" not in inp_verbalizer:
            raise ValueError("inp_verbalizer must contain {} for formatting")
        if "{}" not in out_verbalizer:
            raise ValueError("out_verbalizer must contain {} for formatting")
        self.inp_verbalizer = inp_verbalizer
        self.out_verbalizer = out_verbalizer
        self.sep = sep
        self.big_sep = big_sep

    def _key(self):
        """Return the tuple of components that defines this template's identity."""
        return (self.inp_verbalizer, self.out_verbalizer, self.sep, self.big_sep)

    def __str__(self):
        return "".join([self.inp_verbalizer, self.sep, self.out_verbalizer, self.big_sep])

    def __repr__(self):
        # Deliberately the same text as __str__.
        return self.__str__()

    def __eq__(self, other):
        # Without __eq__, Python falls back to identity comparison, so two
        # templates built from the same components would never match as
        # dictionary keys even though their hashes are equal.
        if not isinstance(other, Template):
            return NotImplemented
        return self._key() == other._key()

    def __hash__(self):
        # Equal components give equal strings, so this stays consistent with __eq__.
        return hash(self.__str__())

    def toJSON(self):
        """Return the instance attributes serialized as a JSON object with sorted keys."""
        return json.dumps(self, default=lambda o: o.__dict__, sort_keys=True)
def generate_random_templates(dataset_name, n=10, seed=59, with_big_seps=False):
    """Sample ``n`` templates at random from the options available for a dataset.

    Args:
        dataset_name: Key into ``DATASET_TO_DATACLASS``; the looked-up object
            supplies ``input_verbalizers`` and ``output_verbalizers``.
        n: Number of templates to draw (drawn independently, so duplicates
            are possible).
        seed: Seed passed to ``random.seed`` before sampling.
        with_big_seps: When true, ``big_sep`` is drawn from ``BIG_SEPS``;
            otherwise every template gets an empty ``big_sep``.

    Returns:
        A list of ``n`` ``Template`` objects.
    """
    verbalizer_source = DATASET_TO_DATACLASS[dataset_name]
    input_options = verbalizer_source.input_verbalizers
    output_options = verbalizer_source.output_verbalizers

    # This reseeds the module-level RNG, so it also affects any later use of
    # the `random` module by the caller.
    random.seed(seed)

    # Arguments are evaluated left to right, so each template consumes the RNG
    # in the order: input verbalizer, output verbalizer, sep, (optional) big sep.
    return [
        Template(
            random.choice(input_options),
            random.choice(output_options),
            random.choice(SEPS),
            random.choice(BIG_SEPS) if with_big_seps else '',
        )
        for _ in range(n)
    ]
def get_templates(dataset, n_shots, n_templates=10, templates_path=None, templates_seed=59):
    """Return templates, either unpickled from a file or randomly generated.

    Args:
        dataset: Dataset name forwarded to ``generate_random_templates``.
        n_shots: Number of shots; big separators are sampled only when this
            is greater than zero.
        n_templates: How many templates to generate when no file is given.
        templates_path: Optional path to a pickle file. When provided, its
            unpickled content is returned as-is and the other arguments are
            ignored.
        templates_seed: Seed used for random generation.

    Returns:
        The unpickled object from ``templates_path``, or a list of randomly
        generated templates.
    """
    if templates_path is not None:
        # NOTE(review): unpickling can execute arbitrary code -- only pass
        # paths to files that come from a trusted source.
        with open(templates_path, "rb") as pickled_templates:
            return pickle.load(pickled_templates)

    use_big_seps = n_shots > 0
    return generate_random_templates(
        dataset,
        n=n_templates,
        seed=templates_seed,
        with_big_seps=use_big_seps,
    )
if __name__ == '__main__':
    # Script entry point: for every requested model and every combination of
    # dataset / seed / prediction method / example-selection method, evaluate
    # each template in the full grid of verbalizers and separators and store
    # the scores keyed by template.
    args = parse_args()

    for model in args.models:
        generator = load_generator(model, cache_dir=args.cache_dir, precision=args.precision,
                                   local_files_only=args.local_files_only, device_map=args.device_map,
                                   )

        for dataset, seed, prediction_method, selection_method in product(
                args.dataset, args.seed, args.prediction_method, args.examples_selection_method):
            num_shots_range = [0] if selection_method == '0-shot' else args.num_shots

            res_dir = args.save_dir if args.save_dir else Path("results/template_selection", dataset)

            # NOTE(review): nesting reconstructed from a flattened source --
            # these two methods always run with labels_loss=True and warn when
            # the user asked otherwise.
            if prediction_method in ["channel", "calibrate"]:
                if not args.labels_loss:
                    warnings.warn(f"Using {prediction_method} with labels_loss set to False is highly discouraged, "
                                  f"setting to True.")
                labels_loss = True
            else:
                labels_loss = args.labels_loss
            method_name = f"{prediction_method}_{labels_loss}"

            # The original code called `dataset_templates`, which is neither
            # defined nor imported in this module. The verbalizers come from
            # the dataset's dataclass and the separators from SEPS, the same
            # sources generate_random_templates uses.
            dataset_dataclass = DATASET_TO_DATACLASS[dataset]
            inp_verbalizers = dataset_dataclass.input_verbalizers
            out_verbalizers = dataset_dataclass.output_verbalizers
            seps = SEPS

            for num_shots in num_shots_range:
                name = f"{model}_formats_stats"
                # The suffix marks zero-shot runs; the original condition was
                # inverted (`num_shots > 0`).
                if num_shots == 0:
                    name += "_zero_shot"
                name += f"_{method_name}.out"

                results = get_results_torch(save_dir=res_dir, name=name)

                for inp, out, sep, big_sep in product(inp_verbalizers, out_verbalizers, seps, BIG_SEPS):
                    template = Template(inp, out, sep, big_sep)
                    # Skip templates already present in the loaded results;
                    # this lookup relies on Template hashing and equality.
                    if template in results:
                        continue

                    evaluation_result = evaluate_setup(dataset=dataset, generator=generator, seed=seed,
                                                       num_shots=num_shots,
                                                       prediction_method=prediction_method, labels_loss=labels_loss,
                                                       selection_method=selection_method,
                                                       example_ids=args.example_ids, examples_path=args.examples_path,
                                                       batch_size=args.eval_batch_size,
                                                       template=template,
                                                       )
                    results[template] = evaluation_result["scores"]
                    # NOTE(review): saved after every template so an
                    # interrupted run can resume -- placement reconstructed
                    # from a flattened source.
                    save_results_torch(res_obj=results, name=name, save_dir=res_dir)

    save_results_nirvana()