supporting validation during training
joeljang committed Apr 11, 2022
1 parent f7ab149 commit 2687339
Showing 7 changed files with 737 additions and 106 deletions.
86 changes: 43 additions & 43 deletions Datasets_.py
@@ -58,12 +58,14 @@ def __init__(self, tokenizer, type_path, input_length, output_length, args, leng
self.dataset = pd.read_csv('data/evaluation/final/'+ self.args.dataset + '.csv')
elif self.args.dataset=='IL':
self.dataset = pd.read_csv('data/TWiki_Probes/IL.csv')
elif self.args.dataset=='data/twiki_corpus_1024/09':
df1 = pd.read_csv('data/new_data/twiki_probes/0801-0901_unchanged.csv')
df2 = pd.read_csv('data/new_data/twiki_probes/0801-0901_changed.csv')
df3 = pd.read_csv('data/kilt_wikipedia_10000.csv')
df1 = pd.concat([df1, df2])
self.dataset = pd.concat([df1, df3])
elif self.args.dataset=='data/1024/09' or self.args.dataset=='data/1024/08':
unchanged = pd.read_csv('data/new_data/twiki_probes/0801-0901_unchanged.csv')
changed = pd.read_csv('data/new_data/twiki_probes/0801-0901_changed.csv')
openwebtext = pd.read_csv('data/moee_validation/openwebtext/openwebtext_10000.csv')
kilt_wikipedia = pd.read_csv('data/moee_validation/kilt_wikipedia/kilt_wikipedia_10000.csv')
lambada = pd.read_json('data/moee_validation/lambada/lambada_test.jsonl', lines=True)
invariantlama = pd.read_csv('data/moee_validation/IL.csv')
self.dataset = pd.concat([openwebtext, kilt_wikipedia, lambada, invariantlama, unchanged, changed])
elif self.args.dataset=='data/wikipedia_09' or self.args.dataset=='wikipedia_0809' or self.args.dataset=='data/wikipedia_09_gpt2' or self.args.dataset=='wikipedia_0809_gpt2':
df1 = pd.read_csv('data/TWiki_Probes/aligned/0801-0901_unchanged.csv')
df2 = pd.read_csv('data/TWiki_Probes/aligned/0801-0901_updated.csv')
@@ -115,16 +117,17 @@ def convert_to_features(self, example_batch, index=None):
input_ = example_batch['text']
target_ = example_batch['text']
else:
if 'text' in example_batch:
input_ = example_batch['text']
target_ = example_batch['text']
text = example_batch['text']
s = str(example_batch['subject'])
r = str(example_batch['relation'])
o = str(example_batch['object'])
probe = s + ' ' + r + ' ' + o
if text != text:
input_ = probe
target_ = probe
else:
s = example_batch['subject']
r = example_batch['relation']
o = example_batch['objective']
target_ = s + ' ' + r + ' ' + o
input_ = s + ' ' + r + ' ' + o
#input_nonprompt = ' ' + o
input_ = text
target_ = text
else:
if self.args.mode == 'finetune':
s = example_batch['subject']
else:
input_ = example_batch['text']
target_ = example_batch['text']
source = self.tokenizer.batch_encode_plus([str(input_)], max_length=self.input_length,
padding='max_length', truncation=True, return_tensors="pt")
targets = self.tokenizer.batch_encode_plus([str(target_)], max_length=self.output_length,
padding='max_length', truncation=True, return_tensors="pt")
if input_nonprompt is not None:
input_nonprompt = self.tokenizer.batch_encode_plus([str(input_nonprompt)], max_length=self.input_length,
padding='max_length', truncation=True, return_tensors="pt")
if label_ is not None:
label_ = self.tokenizer.batch_encode_plus([str(label_)], max_length=self.input_length,
padding='max_length', truncation=True, return_tensors="pt")

targets = self.tokenizer.batch_encode_plus([str(target_)], max_length=self.output_length,
padding='max_length', truncation=True, return_tensors="pt")
return source, targets, input_nonprompt, label_

def input_to_target(self, input):
input_s = input.split(' ')
input = " ".join(input_s[:len(input_s)-1])
target = " " + input_s[len(input_s)-1]
return input, target

def convert_to_features_(self, example_batch, index):
if index < 20000:
input_, target_ = example_batch['text'], example_batch['text']
elif index < 25153:
input_, target_ = self.input_to_target(example_batch['text'])
else:
input_ = example_batch['subject'] + " " + example_batch['relation']
target_ = " " + example_batch['object']
source = self.tokenizer.batch_encode_plus([str(input_)], max_length=self.input_length, padding='max_length', truncation=True, return_tensors="pt")
targets = self.tokenizer.batch_encode_plus([str(target_)], max_length=self.output_length, padding='max_length', truncation=True, return_tensors="pt")
return source, targets

def __getitem__(self, index):
source, targets, input_nonprompt, label = self.convert_to_features(self.dataset.iloc[index])
source, targets = self.convert_to_features_(self.dataset.iloc[index], index=index)

source_ids = source["input_ids"].squeeze()
target_ids = targets["input_ids"].squeeze()

src_mask = source["attention_mask"].squeeze()
target_mask = targets["attention_mask"].squeeze()

if input_nonprompt is not None:
source_nonprompt_ids = input_nonprompt["input_ids"].squeeze()
source_nonprompt_mask = input_nonprompt["attention_mask"].squeeze()
else:
source_nonprompt_mask = -1
source_nonprompt_ids = -1

if label is not None:
label_ids = label["input_ids"].squeeze()
label_mask = label["attention_mask"].squeeze()
else:
label_ids = -1
label_mask = -1

return {"source_ids": source_ids, "source_mask": src_mask, "target_ids": target_ids, "target_mask": target_mask, "source_nonprompt_ids" : source_nonprompt_ids, "source_nonprompt_mask": source_nonprompt_mask, "label_ids": label_ids, "label_mask": label_mask}
return {"source_ids": source_ids, "source_mask": src_mask, "target_ids": target_ids, "target_mask": target_mask}

class Pretrain_Chunks(Dataset):
def __init__(self, dataset_name, tokenizer, input_length, output_length, args):
Expand All @@ -194,9 +194,9 @@ def convert_to_features(self, example_batch, index=None):
else:
input_ = example_batch['input']
target_ = example_batch['output']
source = self.tokenizer.batch_encode_plus([str(input_)], max_length=self.input_length,
padding='max_length', truncation=True, return_tensors="pt")
targets = self.tokenizer.batch_encode_plus([str(target_)], max_length=self.output_length,
padding='max_length', truncation=True, return_tensors="pt")
return source, targets

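Note on the new validation routing in Datasets_.py: convert_to_features_ dispatches on the row index, so it only works if the concatenation order set up in __init__ is preserved. Indices below 20,000 are presumably the 10,000 openwebtext plus 10,000 kilt_wikipedia rows (plain language-modelling text), indices below 25,153 the LAMBADA passages (the published test split has 5,153 lines), and everything after that the probe-style rows (IL, unchanged, changed) scored as subject-relation-object completions. A minimal sketch of that dispatch, assuming those row counts:

import pandas as pd

# Mirrors convert_to_features_ / input_to_target from the diff above;
# the 20000 / 25153 thresholds come straight from the committed code.
def route_validation_row(row: pd.Series, index: int):
    if index < 20000:
        # plain language modelling: the text is both input and target
        return row['text'], row['text']
    elif index < 25153:
        # LAMBADA-style last-word prediction: context in, final word out
        words = str(row['text']).split(' ')
        return " ".join(words[:-1]), " " + words[-1]
    else:
        # knowledge probe: subject + relation in, object out
        return row['subject'] + " " + row['relation'], " " + row['object']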
24 changes: 12 additions & 12 deletions configs/training/full.json
@@ -1,28 +1,28 @@
{
"input_length" : 512,
"output_length" : 512,
"num_train_epochs" : 1,
"input_length" : 1024,
"output_length" : 1024,
"num_train_epochs" : 4,
"num_files" : 2,
"output_dir" : "outputs/GPT2_large_10_1e-4",
"dataset" : "data/wikipedia_10_gpt2",
"output_dir" : "outputs/GPT2_large_08_1r2.5e-4_full",
"dataset" : "data/1024/08",
"dataset_version" : "full",
"len_data" : 4000000,
"train_batch_size" : 8,
"learning_rate" : 1e-4,
"len_data" : 2697988,
"train_batch_size" : 3,
"learning_rate" : 0.000025,
"model" : "gpt2-large",
"method": "baseline",
"gradient_accumulation_steps" : 1,
"gradient_accumulation_steps" : 20,
"ngpu" : 8,
"num_workers" : 40,
"resume_from_checkpoint" : null,
"accelerator" : "deepspeed_stage_2",
"fp16" : true,
"CUDA_VISIBLE_DEVICES" : "0,1,2,3,4,5,6,7",
"wandb_log": true,
"wandb_project": "ever_changing",
"wandb_run_name" : "GPT2_large_10_1e-4_full",
"wandb_project": "elm_new",
"wandb_run_name" : "GPT2_large_08_1r2.5e-4_full",
"mode" : "pretrain_brute",
"use_lr_scheduling" : true,
"check_validation" : false,
"checkpoint_path" : "outputs/GPT2_large_1e-4_/epoch=1"
"checkpoint_path" : ""
}
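For orientation, the new full.json settings imply a much larger effective batch than the old ones. A rough back-of-the-envelope calculation, assuming plain data parallelism across the 8 listed GPUs and that len_data counts 1024-token training sequences:

# Effective batch per optimizer step under the new config (estimate only).
train_batch_size = 3               # per-GPU micro-batch
gradient_accumulation_steps = 20
ngpu = 8
input_length = 1024

sequences_per_step = train_batch_size * gradient_accumulation_steps * ngpu   # 480
tokens_per_step = sequences_per_step * input_length                          # 491,520
steps_per_epoch = 2_697_988 // sequences_per_step                            # ~5,620 (from len_data)
print(sequences_per_step, tokens_per_step, steps_per_epoch)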
28 changes: 0 additions & 28 deletions configs/training/full_.json

This file was deleted.
