diff --git a/trlx/models/modeling_nemo_ppo.py b/trlx/models/modeling_nemo_ppo.py
index 5ad044130..049d8fa4c 100644
--- a/trlx/models/modeling_nemo_ppo.py
+++ b/trlx/models/modeling_nemo_ppo.py
@@ -993,7 +993,7 @@ def loss_func(model_output):
             start = batch.query_tensors.shape[1]
             end = start + response_length
 
-            label_logprobs = logprobs_of_labels(logits[:, :-1, :], inputs[:, 1:])
+            label_logprobs = logprobs_of_labels(logits, inputs[:, 1:])
             label_logprobs = label_logprobs[:, start:end]
 
             advantages, returns = self.ppo_config.get_advantages_and_returns(
@@ -1079,11 +1079,11 @@ def ppo_postprocess(model_output):
 
             # to save memory
             if run_policy_model and compute_logprobs:
-                logprobs = logprobs_of_labels(logits[:, :-1, :], tokens[:, 1:])
+                logprobs = logprobs_of_labels(logits, tokens[:, 1:])
                 return logprobs, dict(logprobs=logprobs, values=values)
 
             if run_reference_model and compute_logprobs:
-                ref_logprobs = logprobs_of_labels(ref_logits[:, :-1, :], tokens[:, 1:])
+                ref_logprobs = logprobs_of_labels(ref_logits, tokens[:, 1:])
                 return ref_logprobs, dict(ref_logprobs=ref_logprobs)
 
             return logits, {"logits": logits, "values": values, "ref_logits": ref_logits}
diff --git a/trlx/trainer/accelerate_ppo_trainer.py b/trlx/trainer/accelerate_ppo_trainer.py
index 1a4801aaf..723149ebe 100644
--- a/trlx/trainer/accelerate_ppo_trainer.py
+++ b/trlx/trainer/accelerate_ppo_trainer.py
@@ -163,7 +163,7 @@ def loss(self, batch: PPORLBatch) -> Tuple[float, Dict[str, Any]]:
 
             logits = outputs.logits
             values_pred = outputs.value
-            logprobs = logprobs_of_labels(logits[:, :-1, :], decoder_input_ids[:, 1:])
+            logprobs = logprobs_of_labels(logits, decoder_input_ids[:, 1:])
             mask = decoder_input_ids.ne(self.tokenizer.pad_token_id).long().to(self.accelerator.device)
             start = 0
             end = start + response_length
@@ -181,7 +181,7 @@ def loss(self, batch: PPORLBatch) -> Tuple[float, Dict[str, Any]]:
             logits = outputs.logits
             values_pred = outputs.value
             values_pred = values_pred[:, :-1]
-            logprobs = logprobs_of_labels(logits[:, :-1, :], tokens[:, 1:])
+            logprobs = logprobs_of_labels(logits, tokens[:, 1:])
             start = query_tensors.shape[1] - 1
             end = start + response_length
 
@@ -438,12 +438,12 @@ def make_experience(self, num_rollouts: int = 1024, iter_count: int = 0):  # noq
                 ref_logits = ref_logits.to(device)
 
             if self.config.model.model_arch_type == "seq2seq":
-                logprobs = logprobs_of_labels(logits[:, :-1, :], sample_outputs[:, 1:])
-                ref_logprobs = logprobs_of_labels(ref_logits[:, :-1, :], sample_outputs[:, 1:])
+                logprobs = logprobs_of_labels(logits, sample_outputs[:, 1:])
+                ref_logprobs = logprobs_of_labels(ref_logits, sample_outputs[:, 1:])
             else:
                 # NOTE: logprob[i] is (log)prob at which all_token[i+1] was sampled
-                logprobs = logprobs_of_labels(logits[:, :-1, :], all_tokens[:, 1:])
-                ref_logprobs = logprobs_of_labels(ref_logits[:, :-1, :], all_tokens[:, 1:])
+                logprobs = logprobs_of_labels(logits, all_tokens[:, 1:])
+                ref_logprobs = logprobs_of_labels(ref_logits, all_tokens[:, 1:])
 
             n_samples: int = samples.shape[0]
 
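Every changed call site drops the explicit `logits[:, :-1, :]` slice and passes the full logits tensor to `logprobs_of_labels` alongside the shifted labels. The sketch below is a hedged check of why that is safe under one assumption: that `logprobs_of_labels` is the gather-based helper from `trlx/utils/modeling.py` (reimplemented here as an approximation, not copied from the repo). With that implementation, `torch.gather`'s output takes the shape of the label tensor, so only the first `seq_len - 1` sequence positions of the logits are ever read, which is exactly what the old slice made explicit.

```python
import torch
import torch.nn.functional as F


def logprobs_of_labels(logits, labels):
    # Approximation of trlx's gather-based helper: log-softmax over the vocab,
    # then pick out the log-probability assigned to each label token.
    logprobs = F.log_softmax(logits, dim=-1)
    return torch.gather(logprobs, dim=-1, index=labels.unsqueeze(-1)).squeeze(-1)


batch, seq_len, vocab = 2, 6, 11
logits = torch.randn(batch, seq_len, vocab)
tokens = torch.randint(vocab, (batch, seq_len))

# Old call form: slice the logits so position t explicitly scores token t + 1.
old = logprobs_of_labels(logits[:, :-1, :], tokens[:, 1:])
# New call form: pass the full logits; gather only indexes positions
# 0 .. seq_len - 2 because the label tensor is one step shorter.
new = logprobs_of_labels(logits, tokens[:, 1:])

assert old.shape == new.shape == (batch, seq_len - 1)
assert torch.allclose(old, new)
```

If the helper were instead implemented with a shape assertion or a fused cross-entropy that requires `logits.size(1) == labels.size(1)`, the two call forms would no longer be interchangeable, so this equivalence is specific to the gather-based definition assumed above.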