
Commit

validation metric update
CHLee0801 committed Nov 30, 2022
1 parent a8749f6 commit 1281130
Showing 1 changed file with 22 additions and 30 deletions.
52 changes: 22 additions & 30 deletions models/GPT2_Model.py
@@ -29,12 +29,10 @@ def __init__(self, hparams):
         super(GPT2, self).__init__()
         self.save_hyperparameters(hparams)
         self.unchanged_loss = 0
-        self.updated_loss = 0
-        self.new_loss = 0
+        self.changed_loss = 0
         self.invariant_loss = 0
         self.unchanged = 0
-        self.updated = 0
-        self.new = 0
+        self.changed = 0
         self.invariant = 0
         self.validation = 0
         self.validation_loss = 0
@@ -147,9 +145,9 @@ def calculate_scores(self, predictions, ground_truths):
         f1_score /= len(predictions)
         return em_score*100, f1_score*100
 
-    def get_dataset(self, tokenizer, type_path, args, length=None):
+    def get_dataset(self, tokenizer, type_path, args, length=None, lama_type=None):
         dataset = CustomDataset(tokenizer=tokenizer, type_path=type_path, input_length=args.max_input_length,
-                        output_length=args.max_output_length, args=args, length=length)
+                        output_length=args.max_output_length, args=args, length=length, lama_type=lama_type)
         return dataset
 
     def freeze_params(self, model):
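Note on the hunk above: the only change is threading a new lama_type argument from get_dataset into CustomDataset. CustomDataset itself is not part of this diff, so the snippet below is only a minimal sketch of how such a flag might select the 'unchanged' vs. 'changed' validation file; the class name, file layout, and column names are hypothetical, not the repository's code.

```python
# Hypothetical sketch of how a lama_type flag could be consumed by a dataset.
# The real CustomDataset in this repository may differ; paths and columns are assumptions.
import pandas as pd
from torch.utils.data import Dataset

class LamaSplitDataset(Dataset):
    def __init__(self, type_path, lama_type=None, data_dir="data"):
        # e.g. data/validation_unchanged.csv vs. data/validation_changed.csv
        suffix = f"_{lama_type}" if lama_type else ""
        self.frame = pd.read_csv(f"{data_dir}/{type_path}{suffix}.csv")

    def __len__(self):
        return len(self.frame)

    def __getitem__(self, idx):
        row = self.frame.iloc[idx]
        return {"input": row["input"], "output": row["output"]}
```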
Expand Down Expand Up @@ -240,31 +238,23 @@ def _generative_step_finetune(self, batch, batch_idx):
self.log('F1 score', f1_score, prog_bar=True, logger=True)


def _generative_step(self, batch, batch_idx):
def _generative_step(self, batch, batch_idx, dataloader_idx=-1):
loss = self._step(batch)
self.log('val_loss', loss, on_step=True, on_epoch=True, prog_bar=True, logger=True)

if (batch_idx < (10000//(self.hparams.eval_batch_size * self.hparams.n_gpu))):
if dataloader_idx == 0:
self.unchanged +=1
self.unchanged_loss += loss
average_loss = self.unchanged_loss / self.unchanged
ppl = torch.exp(average_loss)
self.log('UnL_ppl', ppl, prog_bar=True, logger=True)
print('UnL_ppl', ppl)
elif (batch_idx < (15000//(self.hparams.eval_batch_size * self.hparams.n_gpu))):
self.updated +=1
self.updated_loss += loss
average_loss = self.updated_loss / self.updated
ppl = torch.exp(average_loss)
self.log('UL_ppl', ppl, prog_bar=True, logger=True)
print('UL_ppl', ppl)
elif (batch_idx < (20000//(self.hparams.eval_batch_size * self.hparams.n_gpu))):
self.new +=1
self.new_loss += loss
average_loss = self.new_loss / self.new
self.log('UnC_ppl', ppl, prog_bar=True, logger=True)
print('UnC_ppl', ppl)
elif dataloader_idx == 1:
self.changed +=1
self.changed_loss += loss
average_loss = self.changed_loss / self.changed
ppl = torch.exp(average_loss)
self.log('NL_ppl', ppl, prog_bar=True, logger=True)
print('NL_ppl', ppl)
self.log('C_ppl', ppl, prog_bar=True, logger=True)
print('C_ppl', ppl)
else:
self.invariant +=1
self.invariant_loss += loss
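The rewritten branching above relies on two things: PyTorch Lightning passes a dataloader_idx to the validation step once val_dataloader() returns more than one loader, and each category's perplexity is the exponential of its running mean validation loss. A minimal, framework-free sketch of that bookkeeping (the loss values are made up for illustration):

```python
# Minimal sketch of the running-perplexity bookkeeping used in _generative_step,
# detached from Lightning so it can be sanity-checked on its own.
import torch

class RunningPerplexity:
    def __init__(self):
        self.count = 0
        self.total_loss = torch.tensor(0.0)

    def update(self, loss):
        # Accumulate per-batch mean cross-entropy and report exp(mean loss).
        self.count += 1
        self.total_loss = self.total_loss + loss
        return torch.exp(self.total_loss / self.count)

unchanged_ppl = RunningPerplexity()
changed_ppl = RunningPerplexity()

# Fabricated (dataloader_idx, loss) pairs standing in for validation batches.
for dataloader_idx, loss in [(0, torch.tensor(2.1)), (1, torch.tensor(2.7)), (0, torch.tensor(1.9))]:
    tracker = unchanged_ppl if dataloader_idx == 0 else changed_ppl
    print(dataloader_idx, tracker.update(loss).item())
```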
@@ -290,10 +280,10 @@ def on_train_epoch_end(self):
         train_set = self.train_dataloader().dataset
         self.epoch+=1
 
-    def validation_step(self, batch, batch_idx):
+    def validation_step(self, batch, batch_idx, dataloader_idx=-1):
         if self.hparams.mode == 'finetune':
-            return self._generative_step_finetune(batch, batch_idx)
-        return self._generative_step(batch, batch_idx)
+            return self._generative_step_finetune(batch, batch_idx, dataloader_idx)
+        return self._generative_step(batch, batch_idx, dataloader_idx)
 
     def configure_optimizers(self, train_len=None):
         "Prepare optimizer and schedule (linear warmup and decay)"
@@ -401,9 +391,11 @@ def train_dataloader(self):
         return dataloader
 
     def val_dataloader(self):
-        validation_dataset = self.get_dataset(tokenizer=self.tokenizer, type_path="validation", args=self.hparams,)
-        return DataLoader(validation_dataset, batch_size=self.hparams.eval_batch_size, num_workers=self.hparams.num_workers, shuffle=False)
-
+        validation_dataset_unchanged = self.get_dataset(tokenizer=self.tokenizer, type_path="validation", args=self.hparams, lama_type='unchanged')
+        validation_dataset_changed = self.get_dataset(tokenizer=self.tokenizer, type_path="validation", args=self.hparams, lama_type='changed')
+        return [DataLoader(validation_dataset_unchanged, batch_size=self.hparams.eval_batch_size, num_workers=self.hparams.num_workers, shuffle=False),
+                DataLoader(validation_dataset_changed, batch_size=self.hparams.eval_batch_size, num_workers=self.hparams.num_workers, shuffle=False),
+                ]
     def test_dataloader(self):
         test_dataset = self.get_dataset(tokenizer=self.tokenizer, type_path="test", args=self.hparams)
 
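The order of the list returned by val_dataloader is what fixes the indices consumed in _generative_step: the 'unchanged' loader arrives with dataloader_idx == 0 and the 'changed' loader with dataloader_idx == 1. A trivial stand-in showing that positional mapping (dummy tensors instead of CustomDataset):

```python
# Dummy stand-ins to show how list order maps to dataloader_idx (0 = unchanged, 1 = changed).
import torch
from torch.utils.data import DataLoader, TensorDataset

loaders = [DataLoader(TensorDataset(torch.zeros(4, 2)), batch_size=2),  # dataloader_idx == 0
           DataLoader(TensorDataset(torch.ones(4, 2)), batch_size=2)]   # dataloader_idx == 1

for idx, loader in enumerate(loaders):
    for (batch,) in loader:
        print("dataloader_idx", idx, batch.shape)
```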
