
Commit

uploading debug midpoint
joeljang committed Apr 18, 2022
1 parent 0556819 commit 450f080
Showing 4 changed files with 59 additions and 47 deletions.
34 changes: 20 additions & 14 deletions Datasets_.py
@@ -23,16 +23,17 @@ def __init__(self, tokenizer, type_path, input_length, output_length, args, leng
skip = sorted(random.sample(range(1,total_line+1),total_line-length))
self.dataset = pd.read_csv('data/Wikipedia_Full/wikipedia_08_gpt2/part1.csv', usecols=['text'], skiprows=skip)
else:
openwebtext = pd.read_csv('data/moee_validation/openwebtext/openwebtext_10000.csv')
kilt_wikipedia = pd.read_csv('data/moee_validation/kilt_wikipedia/kilt_wikipedia_10000.csv')
lambada = pd.read_json('data/moee_validation/lambada/lambada_test.jsonl', lines=True)
invariantlama = pd.read_csv('data/moee_validation/IL.csv')
dataset_prefix = 'data/new_moee_validation/'
openwebtext = pd.read_csv(f'{dataset_prefix}openwebtext.csv')
kilt_wikipedia = pd.read_csv(f'{dataset_prefix}kilt_wikipedia.csv')
lambada = pd.read_csv(f'{dataset_prefix}lambada.csv')
lama = pd.read_csv(f'{dataset_prefix}trex_lama.csv')
if self.args.dataset=='data/new_data/twiki_corpus_1024/08' or self.args.dataset=='data/new_data/twiki_corpus_1024/09':
unchanged = pd.read_csv('data/new_data/twiki_probes/0801-0901_unchanged.csv')
changed = pd.read_csv('data/new_data/twiki_probes/0801-0901_changed.csv')
unchanged = pd.read_csv('data/new_data/tempate_mapped/0801-0901_unchanged_final_template2.csv')
changed = pd.read_csv('data/new_data/tempate_mapped/0801-0901_changed_final_template2.csv')
else:
raise Exception(f'the following training data {self.args.dataset} does not have a designated validation dataset')
self.dataset = pd.concat([openwebtext, kilt_wikipedia, lambada, invariantlama, unchanged, changed])
self.dataset = pd.concat([openwebtext, kilt_wikipedia, lambada, lama, unchanged, changed])

print(f'Length of dataset retrieving is.. {len(self.dataset)}')
self.input_length = input_length
@@ -48,27 +49,32 @@ def input_to_target(self, input):
return input, target

def convert_to_features(self, example_batch, index):
if index < 20000:
if index < 19200:
input_, target_ = example_batch['text'], example_batch['text']
elif index < 25153:
elif index < 24192:
input_, target_ = self.input_to_target(example_batch['text'])
elif index < 52928:
input_ = example_batch['input']
input_ = input_[:-1]
target_ = " " + example_batch['output']
else:
input_ = example_batch['subject'] + " " + example_batch['relation']
target_ = " " + example_batch['object']
input_ = example_batch['input']
target_ = " " + example_batch['target']
task = example_batch['task']
source = self.tokenizer.batch_encode_plus([str(input_)], max_length=self.input_length, padding='max_length', truncation=True, return_tensors="pt")
targets = self.tokenizer.batch_encode_plus([str(target_)], max_length=self.output_length, padding='max_length', truncation=True, return_tensors="pt")
return source, targets
return source, targets, task

def __getitem__(self, index):
source, targets = self.convert_to_features(self.dataset.iloc[index], index=index)
source, targets, task = self.convert_to_features(self.dataset.iloc[index], index=index)

source_ids = source["input_ids"].squeeze()
target_ids = targets["input_ids"].squeeze()

src_mask = source["attention_mask"].squeeze()
target_mask = targets["attention_mask"].squeeze()

return {"source_ids": source_ids, "source_mask": src_mask, "target_ids": target_ids, "target_mask": target_mask}
return {"source_ids": source_ids, "source_mask": src_mask, "target_ids": target_ids, "target_mask": target_mask, "task": task}

class Pretrain_Chunks(Dataset):
def __init__(self, dataset_name, tokenizer, input_length, output_length, args):
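The rewritten validation path concatenates six CSVs (openwebtext, kilt_wikipedia, lambada, trex_lama, and the two twiki probe files) and dispatches preprocessing by row index. Below is a minimal Python sketch, not the committed code, of the routing that the new convert_to_features appears to implement: the offsets 19200 / 24192 / 52928 come straight from the diff, while the last-word split for LAMBADA (what input_to_target presumably does) and the presence of a 'task' column in every CSV are assumptions.

def route_example(row, index):
    # Sketch only: the offsets are the ones hard-coded in the diff above.
    if index < 19200:                  # openwebtext + kilt_wikipedia: plain LM, text -> text
        input_, target_ = row['text'], row['text']
    elif index < 24192:                # lambada: assumed last-word split (input_to_target)
        input_, target_ = row['text'].rsplit(' ', 1)
        target_ = ' ' + target_
    elif index < 52928:                # trex_lama: 'input' column -> 'output' column
        input_, target_ = row['input'][:-1], ' ' + row['output']
    else:                              # twiki probes (unchanged / changed templates)
        input_, target_ = row['input'], ' ' + row['target']
    return input_, target_, row['task']    # 'task' column assumed present in every CSV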
18 changes: 9 additions & 9 deletions configs/training/full.json
@@ -3,26 +3,26 @@
"output_length" : 1024,
"num_train_epochs" : 4,
"num_files" : 2,
"output_dir" : "outputs/GPT2_large_08_1r2.5e-4_full",
"output_dir" : "outputs/GPT2_small_08_1r1e-4",
"dataset" : "data/new_data/twiki_corpus_1024/08",
"dataset_version" : "full",
"len_data" : 2697988,
"train_batch_size" : 3,
"learning_rate" : 0.000025,
"model" : "gpt2-large",
"train_batch_size" : 12,
"learning_rate" : 1e-4,
"model" : "gpt2",
"method": "baseline",
"gradient_accumulation_steps" : 20,
"ngpu" : 8,
"gradient_accumulation_steps" : 5,
"ngpu" : 1,
"num_workers" : 40,
"resume_from_checkpoint" : null,
"accelerator" : "deepspeed_stage_2",
"fp16" : true,
"CUDA_VISIBLE_DEVICES" : "0,1,2,3,4,5,6,7",
"CUDA_VISIBLE_DEVICES" : "0",
"wandb_log": true,
"wandb_project": "elm_new",
"wandb_run_name" : "GPT2_large_08_1r2.5e-4_full",
"wandb_run_name" : "GPT2_small_08_1r1e-4_full",
"mode" : "pretrain_brute",
"use_lr_scheduling" : true,
"check_validation" : false,
"checkpoint_path" : ""
}
}
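The config swap does more than change the model size: combined with the new gradient_accumulation_steps and ngpu it also shrinks the effective batch per optimizer step and raises the learning rate. A back-of-the-envelope check using only the numbers above (assuming the usual per-device batch x accumulation x GPUs product):

# Effective sequences (length 1024) per optimizer step implied by the two configs.
old_effective_batch = 3 * 20 * 8    # gpt2-large config: 480
new_effective_batch = 12 * 5 * 1    # gpt2 debug config:  60
old_lr, new_lr = 2.5e-5, 1e-4       # learning_rate before and after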
52 changes: 29 additions & 23 deletions models/GPT2_Model_.py
@@ -77,7 +77,7 @@ def __init__(self, hparams):


self.model.resize_token_embeddings(len(self.tokenizer))
self.tokenizer.padding_side = "left"
#self.tokenizer.padding_side = "left"

self.output_dir = self.hparams.output_dir
if self.hparams.mode=='pretrain_brute':
@@ -283,29 +283,34 @@ def on_train_epoch_end(self):
self.epoch+=1

def validation_step(self, batch, batch_idx):
loss = self._step(batch)
ppl = torch.exp(loss)
self.log('val_loss', loss, on_step=True, on_epoch=True, prog_bar=True, logger=True)
if (batch_idx < (10000//(self.hparams.eval_batch_size * self.hparams.n_gpu))):
self.log('openwebtext_ppl', ppl, prog_bar=True, logger=True)
elif (batch_idx < (20000//(self.hparams.eval_batch_size * self.hparams.n_gpu))):
self.log('kilt_wikipedia_ppl', ppl, prog_bar=True, logger=True)
elif (batch_idx < (25153//(self.hparams.eval_batch_size * self.hparams.n_gpu))):
#self.log('lambada_ppl', ppl, prog_bar=True, logger=True)
self.predict_step(padding_length=self.hparams.max_input_length,task='lambada', batch=batch, batch_idx=batch_idx)
elif (batch_idx < (41291//(self.hparams.eval_batch_size * self.hparams.n_gpu))):
#self.log('lama_ppl', ppl, prog_bar=True, logger=True)
self.predict_step(padding_length=self.hparams.max_input_length,task='lama', batch=batch, batch_idx=batch_idx)
elif (batch_idx < (48226//(self.hparams.eval_batch_size * self.hparams.n_gpu))):
#self.log('Un_ppl', ppl, prog_bar=True, logger=True)
self.predict_step(padding_length=self.hparams.max_input_length,task='Unchanged', batch=batch, batch_idx=batch_idx)
else:
#self.log('C_ppl', ppl, prog_bar=True, logger=True)
self.predict_step(padding_length=self.hparams.max_input_length,task='Changed', batch=batch, batch_idx=batch_idx)
tasks = batch["task"]
first_task = tasks[0]
uniformed = True
for t in tasks:
if t!=first_task:
uniformed = False
if uniformed:
loss = self._step(batch)
ppl = torch.exp(loss)
self.log('val_loss', loss, on_step=True, on_epoch=True, prog_bar=True, logger=True)
if (batch_idx < (9600//(self.hparams.eval_batch_size * self.hparams.n_gpu))):
self.log('openwebtext_ppl', ppl, prog_bar=True, logger=True)
elif (batch_idx < (19200//(self.hparams.eval_batch_size * self.hparams.n_gpu))):
self.log('kilt_wikipedia_ppl', ppl, prog_bar=True, logger=True)
elif (batch_idx < (24192//(self.hparams.eval_batch_size * self.hparams.n_gpu))):
self.predict_step(padding_length=self.hparams.max_input_length,task='lambada', batch=batch, batch_idx=batch_idx)
elif (batch_idx < (52928//(self.hparams.eval_batch_size * self.hparams.n_gpu))):
self.predict_step(padding_length=self.hparams.max_input_length,task='lama', batch=batch, batch_idx=batch_idx)
elif (batch_idx < (56052//(self.hparams.eval_batch_size * self.hparams.n_gpu))):
self.predict_step(padding_length=self.hparams.max_input_length,task='Unchanged', batch=batch, batch_idx=batch_idx)
else:
self.predict_step(padding_length=self.hparams.max_input_length,task='Changed', batch=batch, batch_idx=batch_idx)
else:
print(f'The batch {batch_idx} is not uniformed..')

def get_rid_of_pad(self, tokens):
while tokens[0]==-100 or tokens[0]==50259:
tokens.pop(0)
while tokens[-1]==-100 or tokens[-1]==50259:
tokens.pop()
return tokens

def predict_step(self, padding_length, task, batch, batch_idx):
@@ -323,7 +328,7 @@ def predict_step(self, padding_length, task, batch, batch_idx):
context_enc = source_ids[i][:padding_length-10]
continuation_enc = target_ids[i][padding_length-10:]
else:
context_enc = source_ids[i]
context_enc = self.get_rid_of_pad(source_ids[i])
continuation_enc = self.get_rid_of_pad(target_ids[i])
#if len(continuation_enc) > 10:
# continuation_enc = continuation_enc[len(continuation_enc)-10:]
@@ -379,6 +384,7 @@ def predict_step(self, padding_length, task, batch, batch_idx):
loss = -float(logits.sum())
if bool(max_equal) or em==1:
batch_acc+=1

batch_loss += loss
batch_f1 += f1

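The rewritten validation_step only scores a batch when every example in it carries the same task label; mixed batches are skipped with a message. A compact equivalent of that uniformity check, as a sketch rather than the committed code:

def batch_is_uniform(tasks):
    """True when all task labels in a batch match the first one."""
    return all(t == tasks[0] for t in tasks)

When the check passes, the step logs loss/perplexity and routes the batch to predict_step using the same row offsets as Datasets_.py (9600 / 19200 / 24192 / 52928 / 56052), divided by eval_batch_size * n_gpu to convert rows into batch indices.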
2 changes: 1 addition & 1 deletion run.py
@@ -101,7 +101,7 @@ def set_seed(seed):
num_workers=hparam.num_workers,
resume_from_checkpoint=hparam.resume_from_checkpoint,
use_lr_scheduling = hparam.use_lr_scheduling,
val_check_interval = 0.05,
val_check_interval = 0.01,
fp16=hparam.fp16,
opt_level='O1', # you can find out more on optimisation levels here https://nvidia.github.io/apex/amp.html#opt-levels-and-properties
max_grad_norm=hparam.grad_norm, # if you enable 16-bit training then set this to a sensible value, 0.5 is a good default
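In PyTorch Lightning a float val_check_interval is a fraction of the training epoch, so moving from 0.05 to 0.01 runs validation roughly every 1% of an epoch, about five times as often. A minimal sketch of the equivalent Trainer argument (the surrounding arguments from run.py are omitted):

import pytorch_lightning as pl

trainer = pl.Trainer(val_check_interval=0.01)  # validate ~100 times per training epoch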

