Faster & memory-efficient logprobs calculation
li-plus committed Dec 3, 2023
1 parent c26d450 commit aa1031a
Showing 2 changed files with 9 additions and 9 deletions.
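The helper logprobs_of_labels is not touched by this two-file diff (it is defined elsewhere in the repository), so the commit only shows the call sites dropping the logits[:, :-1, :] slice. As a rough illustration of what a faster, memory-leaner log-prob gather can look like, here is a minimal sketch, assuming a gather-plus-logsumexp formulation and assuming the helper now truncates the logits to the label length internally; the function name and the internal slicing are hypothetical, inferred from the call-site change rather than taken from this commit.

import torch

def logprobs_of_labels_sketch(logits: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
    # Hypothetical stand-in, not the trlx source: log p(label) = logit[label] - logsumexp(logits).
    # Gathering the label logits first and reducing with logsumexp avoids allocating a second
    # [batch, seq, vocab]-sized tensor, which a full log_softmax followed by gather would create.
    logits = logits[:, : labels.shape[1], :]  # assumed internal alignment to the shifted labels
    label_logits = torch.gather(logits, dim=-1, index=labels.unsqueeze(-1)).squeeze(-1)
    return label_logits - torch.logsumexp(logits, dim=-1)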
6 changes: 3 additions & 3 deletions trlx/models/modeling_nemo_ppo.py
@@ -993,7 +993,7 @@ def loss_func(model_output):
     start = batch.query_tensors.shape[1]
     end = start + response_length

-    label_logprobs = logprobs_of_labels(logits[:, :-1, :], inputs[:, 1:])
+    label_logprobs = logprobs_of_labels(logits, inputs[:, 1:])
     label_logprobs = label_logprobs[:, start:end]

     advantages, returns = self.ppo_config.get_advantages_and_returns(
@@ -1079,11 +1079,11 @@ def ppo_postprocess(model_output):
     # to save memory

     if run_policy_model and compute_logprobs:
-        logprobs = logprobs_of_labels(logits[:, :-1, :], tokens[:, 1:])
+        logprobs = logprobs_of_labels(logits, tokens[:, 1:])
         return logprobs, dict(logprobs=logprobs, values=values)

     if run_reference_model and compute_logprobs:
-        ref_logprobs = logprobs_of_labels(ref_logits[:, :-1, :], tokens[:, 1:])
+        ref_logprobs = logprobs_of_labels(ref_logits, tokens[:, 1:])
         return ref_logprobs, dict(ref_logprobs=ref_logprobs)

     return logits, {"logits": logits, "values": values, "ref_logits": ref_logits}
12 changes: 6 additions & 6 deletions trlx/trainer/accelerate_ppo_trainer.py
@@ -163,7 +163,7 @@ def loss(self, batch: PPORLBatch) -> Tuple[float, Dict[str, Any]]:

     logits = outputs.logits
     values_pred = outputs.value
-    logprobs = logprobs_of_labels(logits[:, :-1, :], decoder_input_ids[:, 1:])
+    logprobs = logprobs_of_labels(logits, decoder_input_ids[:, 1:])
     mask = decoder_input_ids.ne(self.tokenizer.pad_token_id).long().to(self.accelerator.device)
     start = 0
     end = start + response_length
@@ -181,7 +181,7 @@ def loss(self, batch: PPORLBatch) -> Tuple[float, Dict[str, Any]]:
     logits = outputs.logits
     values_pred = outputs.value
     values_pred = values_pred[:, :-1]
-    logprobs = logprobs_of_labels(logits[:, :-1, :], tokens[:, 1:])
+    logprobs = logprobs_of_labels(logits, tokens[:, 1:])

     start = query_tensors.shape[1] - 1
     end = start + response_length
@@ -438,12 +438,12 @@ def make_experience(self, num_rollouts: int = 1024, iter_count: int = 0):  # noq
     ref_logits = ref_logits.to(device)

     if self.config.model.model_arch_type == "seq2seq":
-        logprobs = logprobs_of_labels(logits[:, :-1, :], sample_outputs[:, 1:])
-        ref_logprobs = logprobs_of_labels(ref_logits[:, :-1, :], sample_outputs[:, 1:])
+        logprobs = logprobs_of_labels(logits, sample_outputs[:, 1:])
+        ref_logprobs = logprobs_of_labels(ref_logits, sample_outputs[:, 1:])
     else:
         # NOTE: logprob[i] is (log)prob at which all_token[i+1] was sampled
-        logprobs = logprobs_of_labels(logits[:, :-1, :], all_tokens[:, 1:])
-        ref_logprobs = logprobs_of_labels(ref_logits[:, :-1, :], all_tokens[:, 1:])
+        logprobs = logprobs_of_labels(logits, all_tokens[:, 1:])
+        ref_logprobs = logprobs_of_labels(ref_logits, all_tokens[:, 1:])

     n_samples: int = samples.shape[0]

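Reading across both files, the pattern is uniform: every caller now passes the unshifted logits and shifts only the labels (inputs[:, 1:], tokens[:, 1:], sample_outputs[:, 1:], all_tokens[:, 1:]). Since logits[:, :-1, :] is a view in PyTorch, dropping the slice is not itself the saving; the implied contract, inferred from this diff rather than stated in it, is that logprobs_of_labels now aligns the two sequences and computes the per-token log-probabilities more economically. A small shape check against the sketch above, with made-up sizes:

import torch

# Made-up sizes purely to exercise the new calling convention.
batch, seq_len, vocab = 2, 8, 32
logits = torch.randn(batch, seq_len, vocab)
tokens = torch.randint(vocab, (batch, seq_len))

# Old call sites: logprobs_of_labels(logits[:, :-1, :], tokens[:, 1:])
# New call sites: logprobs_of_labels(logits, tokens[:, 1:])
out = logprobs_of_labels_sketch(logits, tokens[:, 1:])
assert out.shape == (batch, seq_len - 1)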
