Memory efficiency improvement to logprobs_from_logits_v2 #220

Merged: vermouth1992 merged 3 commits into volcengine:main from tyler-romero:tyler/logprobs_of_labels_v2_fix on Feb 8, 2025
Conversation

The existing logprobs_from_logits_v2 doesn't achieve the memory savings it claims, because logsumexp still allocates a bs*seqlen*vocab tensor internally to hold the element-wise application of exp. However, by applying a loop over logsumexp, we can compute the logsumexp outputs iteratively. Benchmarks show this uses significantly less memory to compute logprobs. A fix is provided, as well as a separate memory-efficient approach for the bfloat16 case.
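The hidden allocation is easy to observe directly. A minimal probe along these lines (assuming a CUDA device; the shape mirrors the benchmark below) shows the peak climbing to roughly twice the size of the logits, consistent with the naive and method-1 numbers further down:

import torch

# Per the PR's analysis, logsumexp materializes an exp(...) intermediate
# the same size as its input, so peak memory is roughly input + intermediate.
x = torch.randn(16, 1024, 32768, device="cuda")  # ~2.1 GB in fp32
torch.cuda.reset_peak_memory_stats()
x.logsumexp(dim=-1)
print("peak during logsumexp: {:.0f} MB".format(torch.cuda.max_memory_allocated() / 1e6))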
Benchmarks:

import time

import torch


def naive_method(logits, input_ids):
    log_probs = logits.log_softmax(dim=-1)
    return torch.gather(log_probs, dim=-1, index=input_ids.unsqueeze(-1)).squeeze(-1)


def method_1(logits, input_ids):  # old logprobs_from_logits_v2 implementation
    token_logits = torch.gather(logits, dim=-1, index=input_ids.unsqueeze(-1)).squeeze(-1)
    logsumexp_values = torch.logsumexp(logits, dim=-1)
    token_log_probs = token_logits - logsumexp_values  # log_softmax(logits) = logits - log(sum(exp(logits)))
    return token_log_probs


def method_2(logits, input_ids):  # compute log_softmax in a loop to reduce peak memory
    per_token_logps = []
    for logits_row, input_ids_row in zip(logits, input_ids):
        log_probs = logits_row.log_softmax(dim=-1)
        token_log_prob = torch.gather(log_probs, dim=-1, index=input_ids_row.unsqueeze(-1)).squeeze(-1)
        per_token_logps.append(token_log_prob)
    return torch.stack(per_token_logps)


def method_3(logits, input_ids):  # combine methods 1 and 2
    per_token_logps = []
    for logits_row, input_ids_row in zip(logits, input_ids):
        token_logits = torch.gather(logits_row, dim=-1, index=input_ids_row.unsqueeze(-1)).squeeze(-1)
        token_log_prob = token_logits - torch.logsumexp(logits_row, dim=-1)
        per_token_logps.append(token_log_prob)
    return torch.stack(per_token_logps)


def efficient_method(logits, input_ids):  # pull everything out of the loop except logsumexp
    token_logits = torch.gather(logits, dim=-1, index=input_ids.unsqueeze(-1)).squeeze(-1)
    logsumexp_values = torch.stack([torch.logsumexp(l, dim=-1) for l in logits])
    token_log_probs = token_logits - logsumexp_values
    return token_log_probs


def measure_memory_and_time(func, logits, input_ids):
    torch.cuda.reset_peak_memory_stats()
    start_time = time.perf_counter()
    result = func(logits, input_ids)
    end_time = time.perf_counter()
    mem_peak = torch.cuda.max_memory_allocated()
    return result, end_time - start_time, mem_peak


# Simulated data
torch.manual_seed(42)
vocab_size = 32768
seq_len = 1024
batch_size = 16
device = "cuda" if torch.cuda.is_available() else "cpu"
logits = torch.randn(batch_size, seq_len, vocab_size, device=device, dtype=torch.float32)
input_ids = torch.randint(0, vocab_size, (batch_size, seq_len), device=device)
logit_mem = torch.cuda.max_memory_allocated()

# Run all methods
naive_result, naive_time, naive_mem = measure_memory_and_time(naive_method, logits, input_ids)
method1_result, method1_time, method1_mem = measure_memory_and_time(method_1, logits, input_ids)
method2_result, method2_time, method2_mem = measure_memory_and_time(method_2, logits, input_ids)
method3_result, method3_time, method3_mem = measure_memory_and_time(method_3, logits, input_ids)
efficient_result, efficient_time, efficient_mem = measure_memory_and_time(efficient_method, logits, input_ids)

# Check equivalence
print("Max absolute difference (naive and 1):", (naive_result - method1_result).abs().max().item())
print("Max absolute difference (naive and 2):", (naive_result - method2_result).abs().max().item())
print("Max absolute difference (naive and 3):", (naive_result - method3_result).abs().max().item())
print("Max absolute difference (naive and efficient):", (naive_result - efficient_result).abs().max().item())
print("Memory consumed by logits: {:.2f} MB".format(logit_mem / 1e6))
print("Naive method time: {:.6f} sec, Memory peak: {:.2f} MB".format(naive_time, naive_mem / 1e6))
print("Method 1 time: {:.6f} sec, Memory peak: {:.2f} MB".format(method1_time, method1_mem / 1e6))
print("Method 2 time: {:.6f} sec, Memory peak: {:.2f} MB".format(method2_time, method2_mem / 1e6))
print("Method 3 time: {:.6f} sec, Memory peak: {:.2f} MB".format(method3_time, method3_mem / 1e6))
print("Efficient method time: {:.6f} sec, Memory peak: {:.2f} MB".format(efficient_time, efficient_mem / 1e6))

# Results:
# > Max absolute difference (naive and 1): 1.9073486328125e-06
# > Max absolute difference (naive and 2): 0.0
# > Max absolute difference (naive and 3): 1.9073486328125e-06
# > Max absolute difference (naive and efficient): 1.9073486328125e-06
# > Memory consumed by logits: 2147.61 MB
# > Naive method time: 0.036307 sec, Memory peak: 4295.16 MB
# > Method 1 time: 0.134651 sec, Memory peak: 4295.43 MB
# > Method 2 time: 0.012156 sec, Memory peak: 2416.18 MB
# > Method 3 time: 0.001496 sec, Memory peak: 2282.10 MB
# > Efficient method time: 0.000918 sec, Memory peak: 2282.23 MB
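Based on these numbers, one plausible shape for the fixed function is a dtype dispatch: the gather-plus-looped-logsumexp path (efficient_method) for float32, and the per-row log_softmax loop (method_2, which matched the naive result exactly) for bfloat16, where subtracting a logsumexp in low precision is riskier numerically. This is a sketch assembled from the benchmarked pieces above, not necessarily the exact code that was merged:

import torch

def logprobs_from_logits_v2(logits: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
    """Sketch of a memory-efficient per-token log-prob, assembled from the
    benchmarked methods above; the merged verl implementation may differ."""
    if logits.dtype == torch.float32:
        # efficient_method: single gather, then a per-row logsumexp loop
        # so the exp intermediate is only ever seqlen*vocab, not bs*seqlen*vocab
        token_logits = torch.gather(logits, dim=-1, index=labels.unsqueeze(-1)).squeeze(-1)
        logsumexp_values = torch.stack([torch.logsumexp(row, dim=-1) for row in logits])
        return token_logits - logsumexp_values
    # lower-precision dtypes: method_2 lets log_softmax handle the numerics row by row
    per_token_logps = []
    for logits_row, labels_row in zip(logits, labels):
        log_probs = logits_row.log_softmax(dim=-1)
        per_token_logps.append(
            torch.gather(log_probs, dim=-1, index=labels_row.unsqueeze(-1)).squeeze(-1)
        )
    return torch.stack(per_token_logps)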
Hi @tyler-romero, great catch! Could you please put your test cases into our CI system so that future PRs won't break this? Thanks. You can move your tests to https://github.com/volcengine/verl/blob/main/tests/gpu_utility/test_torch_functional.py following pytest style.
Added tests, and they're passing locally for me.
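For reference, a minimal pytest-style regression check in the spirit of what was requested could look like the following (hypothetical test name and tolerances; the actual test added to tests/gpu_utility/test_torch_functional.py may differ):

import pytest
import torch

def test_logprobs_from_logits_v2_matches_naive():
    # hypothetical sketch of the requested regression test
    if not torch.cuda.is_available():
        pytest.skip("requires CUDA")
    torch.manual_seed(42)
    logits = torch.randn(4, 128, 32768, device="cuda", dtype=torch.float32)
    input_ids = torch.randint(0, 32768, (4, 128), device="cuda")
    expected = torch.gather(
        logits.log_softmax(dim=-1), dim=-1, index=input_ids.unsqueeze(-1)
    ).squeeze(-1)
    actual = efficient_method(logits, input_ids)  # from the benchmark above
    torch.testing.assert_close(actual, expected, atol=1e-5, rtol=1e-5)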
vermouth1992 approved these changes on Feb 8, 2025.
as12138 pushed a commit to as12138/verl that referenced this pull request on Feb 20, 2025: Memory efficiency improvement to logprobs_from_logits_v2 (#220)