diff --git a/Datasets_.py b/Datasets_.py index 708c584..48e2ab2 100644 --- a/Datasets_.py +++ b/Datasets_.py @@ -58,12 +58,14 @@ def __init__(self, tokenizer, type_path, input_length, output_length, args, leng self.dataset = pd.read_csv('data/evaluation/final/'+ self.args.dataset + '.csv') elif self.args.dataset=='IL': self.dataset = pd.read_csv('data/TWiki_Probes/IL.csv') - elif self.args.dataset=='data/twiki_corpus_1024/09': - df1 = pd.read_csv('data/new_data/twiki_probes/0801-0901_unchanged.csv') - df2 = pd.read_csv('data/new_data/twiki_probes/0801-0901_changed.csv') - df3 = pd.read_csv('data/kilt_wikipedia_10000.csv') - df1 = pd.concat([df1, df2]) - self.dataset = pd.concat([df1, df3]) + elif self.args.dataset=='data/1024/09' or self.args.dataset=='data/1024/08': + unchanged = pd.read_csv('data/new_data/twiki_probes/0801-0901_unchanged.csv') + changed = pd.read_csv('data/new_data/twiki_probes/0801-0901_changed.csv') + openwebtext = pd.read_csv('data/moee_validation/openwebtext/openwebtext_10000.csv') + kilt_wikipedia = pd.read_csv('data/moee_validation/kilt_wikipedia/kilt_wikipedia_10000.csv') + lambada = pd.read_json('data/moee_validation/lambada/lambada_test.jsonl', lines=True) + invariantlama = pd.read_csv('data/moee_validation/IL.csv') + self.dataset = pd.concat([openwebtext, kilt_wikipedia, lambada, invariantlama, unchanged, changed]) elif self.args.dataset=='data/wikipedia_09' or self.args.dataset=='wikipedia_0809' or self.args.dataset=='data/wikipedia_09_gpt2' or self.args.dataset=='wikipedia_0809_gpt2': df1 = pd.read_csv('data/TWiki_Probes/aligned/0801-0901_unchanged.csv') df2 = pd.read_csv('data/TWiki_Probes/aligned/0801-0901_updated.csv') @@ -115,16 +117,17 @@ def convert_to_features(self, example_batch, index=None): input_ = example_batch['text'] target_ = example_batch['text'] else: - if 'text' in example_batch: - input_ = example_batch['text'] - target_ = example_batch['text'] + text = example_batch['text'] + s = str(example_batch['subject']) + r = str(example_batch['relation']) + o = str(example_batch['object']) + probe = s + ' ' + r + ' ' + o + if text != text: + input_ = probe + target_ = probe else: - s = example_batch['subject'] - r = example_batch['relation'] - o = example_batch['objective'] - target_ = s + ' ' + r + ' ' + o - input_ = s + ' ' + r + ' ' + o - #input_nonprompt = ' ' + o + input_ = text + target_ = text else: if self.args.mode == 'finetune': s = example_batch['subject'] @@ -136,21 +139,32 @@ def convert_to_features(self, example_batch, index=None): else: input_ = example_batch['text'] target_ = example_batch['text'] - source = self.tokenizer.batch_encode_plus([str(input_)], max_length=self.input_length, + source = self.tokenizer.batch_encode_plus([str(input_)], max_length=self.input_length, padding='max_length', truncation=True, return_tensors="pt") - targets = self.tokenizer.batch_encode_plus([str(target_)], max_length=self.output_length, - padding='max_length', truncation=True, return_tensors="pt") - if input_nonprompt is not None: - input_nonprompt = self.tokenizer.batch_encode_plus([str(input_nonprompt)], max_length=self.input_length, - padding='max_length', truncation=True, return_tensors="pt") - if label_ is not None: - label_ = self.tokenizer.batch_encode_plus([str(label_)], max_length=self.input_length, - padding='max_length', truncation=True, return_tensors="pt") - + targets = self.tokenizer.batch_encode_plus([str(target_)], max_length=self.output_length, + padding='max_length', truncation=True, return_tensors="pt") return source, targets, 
input_nonprompt, label_
+    def input_to_target(self, input):
+        # Split off the final whitespace-separated word: everything before it becomes
+        # the context, and the last word (with a leading space) becomes the target.
+        input_s = input.split(' ')
+        input = " ".join(input_s[:len(input_s)-1])
+        target = " " + input_s[len(input_s)-1]
+        return input, target
+
+    def convert_to_features_(self, example_batch, index):
+        # Row layout follows the concat order in __init__ above:
+        # rows [0, 20000) are openwebtext + kilt_wikipedia full-text rows,
+        # rows [20000, 25153) are LAMBADA last-word-prediction rows,
+        # and the remaining rows are subject/relation/object probes.
+        if index < 20000:
+            input_, target_ = example_batch['text'], example_batch['text']
+        elif index < 25153:
+            input_, target_ = self.input_to_target(example_batch['text'])
+        else:
+            input_ = example_batch['subject'] + " " + example_batch['relation']
+            target_ = " " + example_batch['object']
+        source = self.tokenizer.batch_encode_plus([str(input_)], max_length=self.input_length, padding='max_length', truncation=True, return_tensors="pt")
+        targets = self.tokenizer.batch_encode_plus([str(target_)], max_length=self.output_length, padding='max_length', truncation=True, return_tensors="pt")
+        return source, targets
+
     def __getitem__(self, index):
-        source, targets, input_nonprompt, label = self.convert_to_features(self.dataset.iloc[index])
+        source, targets = self.convert_to_features_(self.dataset.iloc[index], index=index)
 
         source_ids = source["input_ids"].squeeze()
         target_ids = targets["input_ids"].squeeze()
@@ -158,21 +172,7 @@ def __getitem__(self, index):
         src_mask    = source["attention_mask"].squeeze()
         target_mask = targets["attention_mask"].squeeze()
 
-        if input_nonprompt is not None:
-            source_nonprompt_ids = input_nonprompt["input_ids"].squeeze()
-            source_nonprompt_mask = input_nonprompt["attention_mask"].squeeze()
-        else:
-            source_nonprompt_mask = -1
-            source_nonprompt_ids = -1
-
-        if label is not None:
-            label_ids = label["input_ids"].squeeze()
-            label_mask = label["attention_mask"].squeeze()
-        else:
-            label_ids = -1
-            label_mask = -1
-
-        return {"source_ids": source_ids, "source_mask": src_mask, "target_ids": target_ids, "target_mask": target_mask, "source_nonprompt_ids" : source_nonprompt_ids, "source_nonprompt_mask": source_nonprompt_mask, "label_ids": label_ids, "label_mask": label_mask}
+        return {"source_ids": source_ids, "source_mask": src_mask, "target_ids": target_ids, "target_mask": target_mask}
 
 class Pretrain_Chunks(Dataset):
     def __init__(self, dataset_name, tokenizer, input_length, output_length, args):
@@ -194,9 +194,9 @@ def convert_to_features(self, example_batch, index=None):
         else:
             input_ = example_batch['input']
             target_ = example_batch['output']
-        source = self.tokenizer.batch_encode_plus([str(input_)], max_length=self.input_length,
+        source = self.tokenizer.batch_encode_plus([str(input_)], max_length=self.input_length,
                                                     padding='max_length', truncation=True, return_tensors="pt")
-        targets = self.tokenizer.batch_encode_plus([str(target_)], max_length=self.output_length,
+        targets = self.tokenizer.batch_encode_plus([str(target_)], max_length=self.output_length,
                                                     padding='max_length', truncation=True, return_tensors="pt")
         return source, targets
diff --git a/configs/training/full.json b/configs/training/full.json
index ef392eb..855631f 100644
--- a/configs/training/full.json
+++ b/configs/training/full.json
@@ -1,17 +1,17 @@
 {
-    "input_length" : 512,
-    "output_length" : 512,
-    "num_train_epochs" : 1,
+    "input_length" : 1024,
+    "output_length" : 1024,
+    "num_train_epochs" : 4,
     "num_files" : 2,
-    "output_dir" : "outputs/GPT2_large_10_1e-4",
-    "dataset" : "data/wikipedia_10_gpt2",
+    "output_dir" : "outputs/GPT2_large_08_1r2.5e-4_full",
+    "dataset" : "data/1024/08",
     "dataset_version" : "full",
-    "len_data" : 4000000,
-    "train_batch_size" : 8,
-    "learning_rate" : 1e-4,
+    "len_data" : 2697988,
+    "train_batch_size" : 3,
+
"learning_rate" : 0.000025, "model" : "gpt2-large", "method": "baseline", - "gradient_accumulation_steps" : 1, + "gradient_accumulation_steps" : 20, "ngpu" : 8, "num_workers" : 40, "resume_from_checkpoint" : null, @@ -19,10 +19,10 @@ "fp16" : true, "CUDA_VISIBLE_DEVICES" : "0,1,2,3,4,5,6,7", "wandb_log": true, - "wandb_project": "ever_changing", - "wandb_run_name" : "GPT2_large_10_1e-4_full", + "wandb_project": "elm_new", + "wandb_run_name" : "GPT2_large_08_1r2.5e-4_full", "mode" : "pretrain_brute", "use_lr_scheduling" : true, "check_validation" : false, - "checkpoint_path" : "outputs/GPT2_large_1e-4_/epoch=1" + "checkpoint_path" : "" } \ No newline at end of file diff --git a/configs/training/full_.json b/configs/training/full_.json deleted file mode 100644 index 000cde8..0000000 --- a/configs/training/full_.json +++ /dev/null @@ -1,28 +0,0 @@ -{ - "input_length" : 1024, - "output_length" : 1024, - "num_train_epochs" : 1, - "num_files" : 2, - "output_dir" : "outputs/GPT2_large_10_1e-4", - "dataset" : "data/wikipedia_10_gpt2", - "dataset_version" : "full", - "len_data" : 4000000, - "train_batch_size" : 8, - "learning_rate" : 1e-4, - "model" : "gpt2-large", - "method": "baseline", - "gradient_accumulation_steps" : 1, - "ngpu" : 8, - "num_workers" : 40, - "resume_from_checkpoint" : null, - "accelerator" : "deepspeed_stage_2", - "fp16" : true, - "CUDA_VISIBLE_DEVICES" : "0,1,2,3,4,5,6,7", - "wandb_log": true, - "wandb_project": "ever_changing", - "wandb_run_name" : "GPT2_large_10_1e-4_full", - "mode" : "pretrain_brute", - "use_lr_scheduling" : true, - "check_validation" : false, - "checkpoint_path" : "outputs/GPT2_large_1e-4_/epoch=1" -} \ No newline at end of file diff --git a/models/GPT2_Model_.py b/models/GPT2_Model_.py index b88df97..c0929ae 100644 --- a/models/GPT2_Model_.py +++ b/models/GPT2_Model_.py @@ -1,5 +1,5 @@ import pytorch_lightning as pl - +from models import utils from transformers import ( Adafactor, GPT2LMHeadModel, @@ -20,6 +20,7 @@ import math import os import csv +import torch.nn.functional as F from models.GPT2_Model_Kadapter import GPT2LMHeadModel as GPT2_Kadapter from models.GPT2_Model_LoRA import GPT2LMHeadModel as GPT2_Lora @@ -247,14 +248,14 @@ def _generative_step(self, batch, batch_idx): self.unchanged_loss += loss average_loss = self.unchanged_loss / self.unchanged ppl = torch.exp(average_loss) - self.log('UnL_ppl', ppl, prog_bar=True, logger=True) + self.log('Un_ppl', ppl, prog_bar=True, logger=True) print('Un_ppl', ppl) elif (batch_idx < (8713//(self.hparams.eval_batch_size * self.hparams.n_gpu))): self.changed +=1 self.changed_loss += loss average_loss = self.changed_loss / self.changed ppl = torch.exp(average_loss) - self.log('Ch_ppl', ppl, prog_bar=True, logger=True) + self.log('C_ppl', ppl, prog_bar=True, logger=True) print('C_ppl', ppl) else: self.wikipedia +=1 @@ -280,11 +281,114 @@ def on_train_epoch_end(self): if self.hparams.method=='mixreview': train_set = self.train_dataloader().dataset self.epoch+=1 - + def validation_step(self, batch, batch_idx): - if self.hparams.mode == 'finetune': - return self._generative_step_finetune(batch, batch_idx) - return self._generative_step(batch, batch_idx) + loss = self._step(batch) + ppl = torch.exp(loss) + self.log('val_loss', loss, on_step=True, on_epoch=True, prog_bar=True, logger=True) + if (batch_idx < (10000//(self.hparams.eval_batch_size * self.hparams.n_gpu))): + self.log('openwebtext_ppl', ppl, prog_bar=True, logger=True) + elif (batch_idx < (20000//(self.hparams.eval_batch_size * self.hparams.n_gpu))): + 
self.log('kilt_wikipedia_ppl', ppl, prog_bar=True, logger=True)
+        elif (batch_idx < (25153//(self.hparams.eval_batch_size * self.hparams.n_gpu))):
+            # LAMBADA slice of the validation set: score last-word prediction instead of logging ppl
+            #self.log('lambada_ppl', ppl, prog_bar=True, logger=True)
+            self.predict_step(padding_length=self.hparams.max_input_length, task='lambada', batch=batch, batch_idx=batch_idx)
+        elif (batch_idx < (41291//(self.hparams.eval_batch_size * self.hparams.n_gpu))):
+            # InvariantLAMA slice: subject/relation -> object probes
+            #self.log('lama_ppl', ppl, prog_bar=True, logger=True)
+            self.predict_step(padding_length=self.hparams.max_input_length, task='lama', batch=batch, batch_idx=batch_idx)
+        elif (batch_idx < (48226//(self.hparams.eval_batch_size * self.hparams.n_gpu))):
+            # TWiki unchanged-probe slice
+            #self.log('Un_ppl', ppl, prog_bar=True, logger=True)
+            self.predict_step(padding_length=self.hparams.max_input_length, task='Unchanged', batch=batch, batch_idx=batch_idx)
+        else:
+            # TWiki changed-probe slice
+            #self.log('C_ppl', ppl, prog_bar=True, logger=True)
+            self.predict_step(padding_length=self.hparams.max_input_length, task='Changed', batch=batch, batch_idx=batch_idx)
+
+    def get_rid_of_pad(self, tokens):
+        # Strip leading ignore-index (-100) and padding ids (50259 in this tokenizer setup)
+        # so that only real continuation tokens get scored.
+        while tokens[0]==-100 or tokens[0]==50259:
+            tokens.pop(0)
+        return tokens
+
+    def predict_step(self, padding_length, task, batch, batch_idx):
+        source_ids = batch["source_ids"].tolist()
+        target_ids = batch["target_ids"].tolist()
+        batch_size = len(source_ids)
+        batch_loss = 0
+        batch_acc = 0
+        batch_f1 = 0
+        inps = []
+        cont_toks_list = []
+        inplens = []
+        for i in range(batch_size):
+            if source_ids[i]==target_ids[i]:
+                # identical source/target (full-text row): treat the last 10 positions as the continuation
+                context_enc = source_ids[i][:padding_length-10]
+                continuation_enc = target_ids[i][padding_length-10:]
+            else:
+                # prompt/target row: the source is the context, the (un-padded) target is the continuation
+                context_enc = source_ids[i]
+                continuation_enc = self.get_rid_of_pad(target_ids[i])
+            #if len(continuation_enc) > 10:
+            #    continuation_enc = continuation_enc[len(continuation_enc)-10:]
+            # sanity check
+            assert len(context_enc) > 0
+            assert len(continuation_enc) > 0
+            assert len(continuation_enc) <= self.max_length
+
+            #inp = torch.tensor(
+            #    (context_enc + continuation_enc)[-(self.max_length+1):][:-1],
+            #    dtype=torch.long
+            #).to(self.device)
+            inp = torch.tensor(
+                (context_enc + continuation_enc)[-(padding_length):][:-1],
+                dtype=torch.long
+            ).to(self.device)
+            inplen, = inp.shape
+            cont = continuation_enc
+
+            # since in _collate we make sure length is descending, the longest is always the first one.
+ #padding_length = padding_length if padding_length is not None else inplen + # pad length from seq to padding_length + inp = torch.cat([ + inp, # [seq] + torch.zeros(padding_length - inplen, dtype=torch.long).to(inp.device) # [padding_length - seq] + ], dim=0) + inps.append(inp.unsqueeze(0)) # [1, padding_length] + cont_toks_list.append(cont) + inplens.append(inplen) + + batched_inps = torch.cat(inps, dim=0) # [batch, padding_length + multi_logits = F.log_softmax(self._model_call(batched_inps), dim=-1).cpu() # [batch, padding_length, vocab] + for logits, inp, inplen, cont_toks \ + in zip(multi_logits, inps, inplens, cont_toks_list): + + # Slice to original seq length + contlen = len(cont_toks) + original_logits = logits + logits = logits[inplen-contlen:inplen].unsqueeze(0) # [1, seq, vocab] + # Check if per-token argmax is exactly equal to continuation + greedy_tokens = logits.argmax(dim=-1) + cont_toks = torch.tensor(cont_toks, dtype=torch.long).unsqueeze(0) # [1, seq] + max_equal = (greedy_tokens == cont_toks).all() + predicted = self.ids_to_clean_text(greedy_tokens) + ground_truth = self.ids_to_clean_text(cont_toks) + em = self.exact_match_score(predicted[0], ground_truth[0]) + f1 = self._f1_score(predicted[0], ground_truth[0]) + + # Obtain log-probs at the corresponding continuation token indices + # last_token_slice = logits[:, -1, :].squeeze(0).tolist() + logits = torch.gather(logits, 2, cont_toks.unsqueeze(-1)).squeeze(-1) # [1, seq] + # Answer: (log prob, is-exact-match) + loss = -float(logits.sum()) + if bool(max_equal) or em==1: + batch_acc+=1 + batch_loss += loss + batch_f1 += f1 + + batch_loss_avg = batch_loss / batch_size + batch_acc_avg = batch_acc / batch_size + batch_f1_avg = batch_f1 / batch_size + self.log(f'{task}_loss', batch_loss_avg, prog_bar=True, logger=True) + self.log(f'{task}_acc', batch_acc_avg, prog_bar=True, logger=True) + self.log(f'{task}_f1', batch_f1_avg, prog_bar=True, logger=True) + return def configure_optimizers(self, train_len=None): "Prepare optimizer and schedule (linear warmup and decay)" @@ -363,7 +467,7 @@ def configure_optimizers(self, train_len=None): denomniator = (self.hparams.n_gpu * self.hparams.gradient_accumulation_steps) steps_per_epoch = ( len_data // denomniator ) + 1 - schedule_scale_factor = 8 + schedule_scale_factor = 1 total_num_steps = ( steps_per_epoch * self.hparams.num_train_epochs ) * self.hparams.num_files * schedule_scale_factor print(f'total number of steps : {total_num_steps}') @@ -398,4 +502,363 @@ def val_dataloader(self): def test_dataloader(self): test_dataset = self.get_dataset(tokenizer=self.tokenizer, type_path="test", args=self.hparams) - return DataLoader(test_dataset, batch_size=self.hparams.eval_batch_size, num_workers=self.hparams.num_workers, shuffle=False) \ No newline at end of file + return DataLoader(test_dataset, batch_size=self.hparams.eval_batch_size, num_workers=self.hparams.num_workers, shuffle=False) + + @property + def eot_token_id(self): + # we use EOT because end of *text* is more accurate for what we're doing than end of *sentence* + return self.tokenizer.eos_token_id + + @property + def max_length(self): + try: + return self.model.config.n_ctx + except AttributeError: + # gptneoconfig doesn't have n_ctx apparently + return self.model.config.max_position_embeddings + + @property + def max_gen_toks(self): + return 256 + + @property + def batch_size(self): + # TODO: fix multi-gpu + return self.batch_size_per_gpu # * gpus + + @property + def device(self): + # TODO: fix multi-gpu + return 
self._device + + def tok_encode(self, string: str): + return self.tokenizer.encode(string, add_special_tokens=False) + + def tok_decode(self, tokens): + return self.tokenizer.decode(tokens) + + def _model_call(self, inps): + """ + inps: a torch tensor of shape [batch, sequence] + the size of sequence may vary from call to call + returns: a torch tensor of shape [batch, sequence, vocab] with the + logits returned from the model + """ + with torch.no_grad(): + res = self.model(inps) + return res[0][:, :, :50257] + + def _model_generate(self, context, max_length, eos_token_id): + return self.model.generate( + context, + max_length=max_length, + eos_token_id=eos_token_id, + do_sample=False + ) + + def loglikelihood(self, requests): + new_reqs = [] + for context, continuation in requests: + if context == "": + # end of text as context + context_enc = [self.eot_token_id] + else: + context_enc = self.tok_encode(context) + + continuation_enc = self.tok_encode(continuation) + new_reqs.append(((context, continuation), context_enc, continuation_enc)) + + return self._loglikelihood_tokens(new_reqs) + + def loglikelihood_rolling(self, requests): + # TODO: Implement caching once we've confirmed the perplexity implementation + # TODO: automatic batch size detection for vectorization + + loglikelihoods = [] + for string, in tqdm(requests): + rolling_token_windows = list(map(utils.make_disjoint_window, utils.get_rolling_token_windows( + token_list=self.tok_encode(string), + prefix_token=self.eot_token_id, + max_seq_len=self.max_length, + context_len=1, + ))) + + rolling_token_windows = [(None,) + x for x in rolling_token_windows] + + # TODO: extract out this call so it only gets called once and also somehow figure out partial caching for + # that + string_nll = self._loglikelihood_tokens(rolling_token_windows, disable_tqdm=True) + + # discard is_greedy + string_nll = [x[0] for x in string_nll] + + string_nll = sum(string_nll) + loglikelihoods.append(string_nll) + + return loglikelihoods + + def _loglikelihood_tokens(self, requests, disable_tqdm=False): + # TODO: implement some kind of efficient-request-middleware that lumps together requests with the same context + res = [] + + def _collate(x): + # the negative sign on len(toks) sorts descending - this has a few advantages: + # - time estimates will always be over not underestimates, which is more useful for planning + # - to know the size of a batch when going through the list, you know the first one is always the batch + # padded context length. 
this is useful to simplify the batching logic and more importantly to make + # automatic adaptive batches much much easier to implement + # - any OOMs will happen right away rather than near the end + + toks = x[1] + x[2] + return -len(toks), tuple(toks) + + + # TODO: automatic (variable) batch size detection for vectorization + reord = utils.Reorderer(requests, _collate) + for chunk in utils.chunks(tqdm(reord.get_reordered(), disable=disable_tqdm), self.batch_size): + inps = [] + cont_toks_list = [] + inplens = [] + + padding_length = None + + # because vectorizing is annoying, we first convert each (context, continuation) pair to padded + # tensors, then we pack them together into a batch, call the model, and then pick it all apart + # again because vectorizing is annoying + + for _, context_enc, continuation_enc in chunk: + # sanity check + assert len(context_enc) > 0 + assert len(continuation_enc) > 0 + assert len(continuation_enc) <= self.max_length + + # how this all works: + # CTX CONT + # inp 0 1 2 3|4 5 6 7 8 9 <- last token is deleted by inp[:, :-1] + # gpt2 \ \ + # logits 1 2 3|4 5 6 7 8 9 <- the ctx half gets tossed out by the + # cont_toks 4 5 6 7 8 9 [:, -len(continuation_enc):, :self.vocab_size] slice + + # when too long to fit in context, truncate from the left + inp = torch.tensor( + (context_enc + continuation_enc)[-(self.max_length+1):][:-1], + dtype=torch.long + ).to(self.device) + inplen, = inp.shape + + cont = continuation_enc + + # since in _collate we make sure length is descending, the longest is always the first one. + padding_length = padding_length if padding_length is not None else inplen + + # pad length from seq to padding_length + inp = torch.cat([ + inp, # [seq] + torch.zeros(padding_length - inplen, dtype=torch.long).to(inp.device) # [padding_length - seq] + ], dim=0) + + inps.append(inp.unsqueeze(0)) # [1, padding_length] + cont_toks_list.append(cont) + inplens.append(inplen) + + batched_inps = torch.cat(inps, dim=0) # [batch, padding_length + multi_logits = F.log_softmax(self._model_call(batched_inps), dim=-1).cpu() # [batch, padding_length, vocab] + # Make prediction directory if not exist + pred_dir = self.pred_log.split('/')[0] + isExist = os.path.exists(pred_dir) + if not isExist: + os.makedirs(pred_dir) + #Write prediction + with open(self.pred_log, 'a', newline='') as writefile: + writer = csv.writer(writefile) + for (cache_key, _, _), logits, inp, inplen, cont_toks \ + in zip(chunk, multi_logits, inps, inplens, cont_toks_list): + + # Slice to original seq length + contlen = len(cont_toks) + original_logits = logits + logits = logits[inplen-contlen:inplen].unsqueeze(0) # [1, seq, vocab] + + # Check if per-token argmax is exactly equal to continuation + greedy_tokens = logits.argmax(dim=-1) + cont_toks = torch.tensor(cont_toks, dtype=torch.long).unsqueeze(0) # [1, seq] + max_equal = (greedy_tokens == cont_toks).all() + lines = "".join(self.ids_to_clean_text_(inp)) + predicted = self.ids_to_clean_text_(greedy_tokens) + ground_truth = self.ids_to_clean_text_(cont_toks) + if max_equal: + writer.writerow([lines, ground_truth, predicted, "CORRECT"]) + else: + writer.writerow([lines, ground_truth, predicted, "WRONG"]) + # Obtain log-probs at the corresponding continuation token indices + # last_token_slice = logits[:, -1, :].squeeze(0).tolist() + logits = torch.gather(logits, 2, cont_toks.unsqueeze(-1)).squeeze(-1) # [1, seq] + # Answer: (log prob, is-exact-match) + answer = (float(logits.sum()), bool(max_equal)) + # partial caching + """ + if cache_key 
is not None: + self.cache_hook.add_partial("loglikelihood", cache_key, answer) + """ + res.append(answer) + + return reord.get_original(res) + + def greedy_until(self, requests): + # TODO: implement fully general `until` that handles untils that are + # multiple tokens or that span multiple tokens correctly + + # TODO: extract to TokenizedLM? + res = [] + + def _collate(x): + toks = self.tok_encode(x[0]) + return len(toks), x[0] + + reord = utils.Reorderer(requests, _collate) + + for context, until in tqdm.tqdm(reord.get_reordered()): + if isinstance(until, str): + until = [until] + + primary_until, = self.tok_encode(until[0]) + + context_enc = torch.tensor([self.tok_encode(context)[self.max_gen_toks - self.max_length:]]).to(self.device) + + cont = self._model_generate(context_enc, context_enc.shape[1] + self.max_gen_toks, primary_until) + + s = self.tok_decode(cont[0].tolist()[context_enc.shape[1]:]) + + for term in until: + s = s.split(term)[0] + """ + # partial caching + self.cache_hook.add_partial("greedy_until", (context, until), s) + """ + res.append(s) + + return reord.get_original(res) + + @utils.positional_deprecated + def fewshot_context(self, doc, num_fewshot, provide_description=None, rnd=None, description=None): + """ Returns a fewshot context string that is made up of a prepended description + (if provided), the `num_fewshot` number of examples, and an appended prompt example. + + :param doc: str + The document as returned from training_docs, validation_docs, or test_docs. + :param num_fewshot: int + The number of fewshot examples to provide in the returned context string. + :param provide_description: bool + Not implemented, and this option is deprecated and will be removed in a future version in favor of a different description providing method + :param rnd: random.Random + The pseudo-random number generator used to randomly sample examples. + WARNING: This is currently a required arg although it's optionalized with a default `None`. + :param description: str + The task's description that will be prepended to the fewshot examples. + :returns: str + The fewshot context. + """ + assert rnd is not None, "A `random.Random` generator argument must be provided to `rnd`" + assert not provide_description, ( + "The `provide_description` arg will be removed in future versions. To prepend " + "a custom description to the context, supply the corresponding string via the " + "`description` arg." 
+ ) + if provide_description is not None: + # nudge people to not specify it at all + print("WARNING: provide_description is deprecated and will be removed in a future version in favor of description_dict") + + description = description + "\n\n" if description else "" + + if num_fewshot == 0: + labeled_examples = "" + else: + # for sets with no training docs, draw from other set *but ensure no overlap with current doc* + if self.has_training_docs(): + fewshotex = self.fewshot_examples(k=num_fewshot, rnd=rnd) + else: + if self._fewshot_docs is None: + self._fewshot_docs = list( + self.validation_docs() if self.has_validation_docs() else self.test_docs() + ) + + fewshotex = rnd.sample(self._fewshot_docs, num_fewshot + 1) + + # get rid of the doc that's the one we're evaluating, if it's in the fewshot + fewshotex = [x for x in fewshotex if x != doc][:num_fewshot] + + labeled_examples = "\n\n".join( + [self.doc_to_text(doc) + self.doc_to_target(doc) for doc in fewshotex] + ) + "\n\n" + + example = self.doc_to_text(doc) + return description + labeled_examples + example + + def has_training_docs(self): + return True + + def has_validation_docs(self): + return True + + def has_test_docs(self): + return False + + def training_docs(self): + return jsonlines.open('data/triviaqa/unfiltered-web-train.jsonl') + + def validation_docs(self): + return jsonlines.open('data/triviaqa/unfiltered-web-dev.jsonl') + + def fewshot_context(self, doc, num_fewshot, provide_description=None, rnd=None, description=None): + """ Returns a fewshot context string that is made up of a prepended description + (if provided), the `num_fewshot` number of examples, and an appended prompt example. + + :param doc: str + The document as returned from training_docs, validation_docs, or test_docs. + :param num_fewshot: int + The number of fewshot examples to provide in the returned context string. + :param provide_description: bool + Not implemented, and this option is deprecated and will be removed in a future version in favor of a different description providing method + :param rnd: random.Random + The pseudo-random number generator used to randomly sample examples. + WARNING: This is currently a required arg although it's optionalized with a default `None`. + :param description: str + The task's description that will be prepended to the fewshot examples. + :returns: str + The fewshot context. + """ + assert rnd is not None, "A `random.Random` generator argument must be provided to `rnd`" + assert not provide_description, ( + "The `provide_description` arg will be removed in future versions. To prepend " + "a custom description to the context, supply the corresponding string via the " + "`description` arg." 
+ ) + if provide_description is not None: + # nudge people to not specify it at all + print("WARNING: provide_description is deprecated and will be removed in a future version in favor of description_dict") + + description = description + "\n\n" if description else "" + + if num_fewshot == 0: + labeled_examples = "" + else: + # for sets with no training docs, draw from other set *but ensure no overlap with current doc* + if self.has_training_docs(): + fewshotex = self.fewshot_examples(k=num_fewshot, rnd=rnd) + else: + if self._fewshot_docs is None: + self._fewshot_docs = list( + self.validation_docs() if self.has_validation_docs() else self.test_docs() + ) + + fewshotex = rnd.sample(self._fewshot_docs, num_fewshot + 1) + + # get rid of the doc that's the one we're evaluating, if it's in the fewshot + fewshotex = [x for x in fewshotex if x != doc][:num_fewshot] + + labeled_examples = "\n\n".join( + [self.doc_to_text(doc) + self.doc_to_target(doc) for doc in fewshotex] + ) + "\n\n" + + example = self.doc_to_text(doc) + return description + labeled_examples + example \ No newline at end of file diff --git a/models/utils.py b/models/utils.py new file mode 100644 index 0000000..5a2ad46 --- /dev/null +++ b/models/utils.py @@ -0,0 +1,189 @@ +import os +import pathlib +import re +import collections +import functools +import inspect +import sys +import pytest +from typing import List +from models import utils + +class ExitCodeError(Exception): + pass + + +def sh(x): + if os.system(x): + raise ExitCodeError() + + +def simple_parse_args_string(args_string): + """ + Parses something like + args1=val1,arg2=val2 + Into a dictionary + """ + args_string = args_string.strip() + if not args_string: + return {} + arg_list = args_string.split(",") + args_dict = {} + for arg in arg_list: + k, v = arg.split("=") + args_dict[k] = v + return args_dict + +def join_iters(iters): + for iter in iters: + yield from iter + + +def chunks(iter, n): + arr = [] + for x in iter: + arr.append(x) + if len(arr) == n: + yield arr + arr = [] + + if arr: yield arr + +def group(arr, fn): + res = collections.defaultdict(list) + + for ob in arr: + res[fn(ob)].append(ob) + + return list(res.values()) + +def general_detokenize(string): + string = string.replace(" n't", "n't") + string = string.replace(" )", ")") + string = string.replace("( ", "(") + string = string.replace("\" ", "\"") + string = string.replace(" \"", "\"") + string = re.sub(r" (['.,])", r"\1", string) + return string + + +def get_rolling_token_windows(token_list, prefix_token, max_seq_len, context_len): + """ + - context_len allows for a rolling window context, allowing each prediction window to potentially + condition on some context + :param token_list: list + List of tokens to be PREDICTED + :param max_seq_len: int + max_seq_len of model (or max_seq_len we want to use) + :param context_len: int + Amount of desired token context for prediction. Needs to be at least 1. 
+ :param prefix_token: token + Dummy token like so the first token has something to condition on + :return: generator + Generator of tuples + (input_tokens, pred_tokens) + Note: Score only the last len(pred_tokens) logits of the LM + """ + assert 1 <= context_len <= max_seq_len + if not token_list: + return + # +1 offset, going from input->preds + pred_len = max_seq_len - context_len + 1 + predicted = 0 + + # Special handling for first window: predict all tokens + first_seq_len = min(max_seq_len, len(token_list)) + yield ( + [prefix_token] + token_list[:first_seq_len - 1], + token_list[:first_seq_len] + ) + predicted += first_seq_len + + while predicted < len(token_list): + window_pred_len = min(len(token_list) - predicted, pred_len) + window_end = predicted + window_pred_len + + yield ( + token_list[window_end - max_seq_len - 1:window_end - 1], + token_list[window_end - window_pred_len:window_end], + ) + predicted += window_pred_len + +def make_disjoint_window(pair): + """ Takes output from get_rolling_token_windows and makes the context not overlap with the continuation """ + + a, b = pair + + return a[:-(len(b) - 1)], b + +class Reorderer: + def __init__(self, arr, fn): + self.size = len(arr) + arr = list(enumerate(arr)) + arr = group(arr, lambda x: fn(x[1])) + arr = [ + ([y[0] for y in x], x[0][1]) for x in arr + ] + arr.sort(key=lambda x: fn(x[1])) + + self.arr = arr + + + def get_reordered(self): + return [x[1] for x in self.arr] + + def get_original(self, newarr): + res = [None] * self.size + cov = [False] * self.size + + for (inds, _), v in zip(self.arr, newarr): + for ind in inds: + res[ind] = v + cov[ind] = True + + assert all(cov) + + return res + +def positional_deprecated(fn): + """ + A decorator to nudge users into passing only keyword args (`kwargs`) to the + wrapped function, `fn`. + """ + @functools.wraps(fn) + def _wrapper(*args, **kwargs): + if len(args) != 1 if inspect.ismethod(fn) else 0: + print(f"WARNING: using {fn.__name__} with positional arguments is " + "deprecated and will be disallowed in a future version of " + "lm-evaluation-harness!") + return fn(*args, **kwargs) + return _wrapper + +@positional_deprecated +def find_test_root(start_path: pathlib.Path) -> pathlib.Path: + """ + Search upward in the directory tree to a maximum of three layers + to find and return the package root (containing the 'tests' folder) + """ + cur_path = start_path.resolve() + max_layers = 3 + for _ in range(max_layers): + if (cur_path / 'tests' / 'test_version_stable.py').exists(): + return cur_path + else: + cur_path = cur_path.parent.resolve() + raise FileNotFoundError(f"Unable to find package root within {max_layers} upwards" +\ + f"of {start_path}") + +@positional_deprecated +def run_task_tests(task_list: List[str]): + """ + Find the package root and run the tests for the given tasks + """ + package_root = find_test_root(start_path=pathlib.Path(__file__)) + task_string = ' or '.join(task_list) + args = [f'{package_root}/tests/test_version_stable.py', f'--rootdir={package_root}', '-k', f'{task_string}'] + sys.path.append(str(package_root)) + pytest_return_val = pytest.main(args) + if pytest_return_val: + raise ValueError(f"Not all tests for the specified tasks ({task_list}) ran successfully! 
Error code: {pytest_return_val}") \ No newline at end of file diff --git a/requirements.txt b/requirements.txt index a201830..699821f 100644 --- a/requirements.txt +++ b/requirements.txt @@ -9,4 +9,10 @@ nltk google-cloud-storage deepspeed boto3 -rouge \ No newline at end of file +rouge + +torch==1.9.0+cu111 +torchvision==0.10.0+cu111 +torchaudio==0.9.0 +# package location +--find-links https://download.pytorch.org/whl/torch_stable.html \ No newline at end of file diff --git a/run.py b/run.py index a342e1b..d048997 100644 --- a/run.py +++ b/run.py @@ -39,7 +39,7 @@ def set_seed(seed): #Init configs that are not given if 'grad_norm' not in hparam: - hparam.grad_norm = 0.5 + hparam.grad_norm = 1.0 if 'weight_decay' not in hparam: hparam.weight_decay = 0.01 if 'output_log' not in hparam: @@ -101,7 +101,7 @@ def set_seed(seed): num_workers=hparam.num_workers, resume_from_checkpoint=hparam.resume_from_checkpoint, use_lr_scheduling = hparam.use_lr_scheduling, - val_check_interval = 1.0, + val_check_interval = 0.25, fp16=hparam.fp16, opt_level='O1', # you can find out more on optimisation levels here https://nvidia.github.io/apex/amp.html#opt-levels-and-properties max_grad_norm=hparam.grad_norm, # if you enable 16-bit training then set this to a sensible value, 0.5 is a good default @@ -115,7 +115,7 @@ def set_seed(seed): #Setting different val & checkpoint saving config for mode if args.mode=='pretrain_brute': - saving_epoch = args.num_files + saving_epoch = 1 else: saving_epoch = 1 @@ -135,18 +135,19 @@ def set_seed(seed): # Setting Flags for pytorch lightning trainer. Details: https://pytorch-lightning.readthedocs.io/en/stable/common/trainer.html#trainer-flags train_params = dict( - accumulate_grad_batches=args.gradient_accumulation_steps, - gpus=args.n_gpu, - max_epochs=int(args.num_train_epochs * args.num_files), - precision= 16 if args.fp16 else 32, - amp_backend="native", - resume_from_checkpoint=args.resume_from_checkpoint, - gradient_clip_val=args.max_grad_norm, - enable_checkpointing=checkpoint_callback, - check_val_every_n_epoch= saving_epoch, + accumulate_grad_batches = args.gradient_accumulation_steps, + gpus = args.n_gpu, + max_epochs =int(args.num_train_epochs * args.num_files), + precision = 16 if args.fp16 else 32, + amp_backend = "native", + resume_from_checkpoint = args.resume_from_checkpoint, + gradient_clip_val = args.max_grad_norm, + enable_checkpointing = checkpoint_callback, + #check_val_every_n_epoch = saving_epoch, + val_check_interval = args.val_check_interval, logger = wandb_logger, callbacks = callbacks, - strategy=args.accelerator + strategy = args.accelerator ) if 't5' in args.model_name_or_path: Model = load_model('T5')