Memory efficiency improvement to logprobs_from_logits_v2 #220

Merged

Conversation

@tyler-romero (Contributor) commented Feb 7, 2025

The existing logprobs_from_logits_v2 doesn't achieve the memory savings it claims: logsumexp still allocates a bs*seqlen*vocab tensor internally to hold the element-wise exp. By looping over the batch dimension and applying logsumexp row by row, we can compute the same outputs with a far smaller peak allocation.

Benchmarks show this uses significantly less memory to compute logprobs.

A fix is provided, along with a separate memory-efficient approach for the bfloat16 case.
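
For reference, here is a sketch of the fixed helper, assembled from the efficient_method and method_2 variants benchmarked below. The merged code may differ in its details, and the bfloat16 branch assumes that subtracting a logsumexp computed in bf16 is too imprecise:

import torch
import torch.nn.functional as F

def logprobs_from_logits_v2(logits, labels):
    if logits.dtype in (torch.float32, torch.float64):
        # gather target-token logits up front, then loop logsumexp over the
        # batch so only one (seqlen, vocab) row is reduced at a time
        token_logits = torch.gather(logits, dim=-1, index=labels.unsqueeze(-1)).squeeze(-1)
        logsumexp_values = torch.stack([torch.logsumexp(row, dim=-1) for row in logits])
        return token_logits - logsumexp_values  # log_softmax(x)_i = x_i - logsumexp(x)
    # bfloat16: subtracting a low-precision logsumexp loses accuracy, so
    # compute log_softmax row by row instead (method_2 below)
    per_token_logps = []
    for row_logits, row_labels in zip(logits, labels):
        row_logprobs = F.log_softmax(row_logits, dim=-1)
        per_token_logps.append(
            torch.gather(row_logprobs, dim=-1, index=row_labels.unsqueeze(-1)).squeeze(-1)
        )
    return torch.stack(per_token_logps)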

@tyler-romero (Author) commented:

Benchmarks:

import time
import torch

def naive_method(logits, input_ids):
    log_probs = logits.log_softmax(dim=-1)
    return torch.gather(log_probs, dim=-1, index=input_ids.unsqueeze(-1)).squeeze(-1)

def method_1(logits, input_ids):  # old logprobs_from_logits_v2 implementation
    token_logits = torch.gather(logits, dim=-1, index=input_ids.unsqueeze(-1)).squeeze(-1)
    logsumexp_values = torch.logsumexp(logits, dim=-1)
    token_log_probs = token_logits - logsumexp_values  # log_softmax(logits) = logits - log(sum(exp(logits)))
    return token_log_probs

def method_2(logits, input_ids):  # compute log_softmax in a loop to reduce peak memory
    per_token_logps = []
    for logits_row, input_ids_row in zip(logits, input_ids):
        log_probs = logits_row.log_softmax(dim=-1)
        token_log_prob = torch.gather(log_probs, dim=-1, index=input_ids_row.unsqueeze(-1)).squeeze(-1)
        per_token_logps.append(token_log_prob)
    return torch.stack(per_token_logps)

def method_3(logits, input_ids):  # combine methods 1 and 2
    per_token_logps = []
    for logits_row, input_ids_row in zip(logits, input_ids):
        token_logits = torch.gather(logits_row, dim=-1, index=input_ids_row.unsqueeze(-1)).squeeze(-1)
        token_log_prob = token_logits - torch.logsumexp(logits_row, dim=-1)
        per_token_logps.append(token_log_prob)
    return torch.stack(per_token_logps)

def efficient_method(logits, input_ids):  # pull everything out of the loop except logsumexp
    token_logits = torch.gather(logits, dim=-1, index=input_ids.unsqueeze(-1)).squeeze(-1)
    logsumexp_values = torch.stack([torch.logsumexp(l, dim=-1) for l in logits])
    token_log_probs = token_logits - logsumexp_values
    return token_log_probs

def measure_memory_and_time(func, logits, input_ids):
    torch.cuda.reset_peak_memory_stats()
    torch.cuda.synchronize()  # make sure no prior kernels are still in flight
    start_time = time.perf_counter()
    result = func(logits, input_ids)
    torch.cuda.synchronize()  # wait for async CUDA kernels so the timing is accurate
    end_time = time.perf_counter()
    mem_peak = torch.cuda.max_memory_allocated()
    return result, end_time - start_time, mem_peak

# Simulated data
torch.manual_seed(42)
vocab_size = 32768
seq_len = 1024
batch_size = 16

device = "cuda" if torch.cuda.is_available() else "cpu"
logits = torch.randn(batch_size, seq_len, vocab_size, device=device, dtype=torch.float32)
input_ids = torch.randint(0, vocab_size, (batch_size, seq_len), device=device)
logit_mem = torch.cuda.max_memory_allocated()

# Run all methods
naive_result, naive_time, naive_mem = measure_memory_and_time(naive_method, logits, input_ids)
method1_result, method1_time, method1_mem = measure_memory_and_time(method_1, logits, input_ids)
method2_result, method2_time, method2_mem = measure_memory_and_time(method_2, logits, input_ids)
method3_result, method3_time, method3_mem = measure_memory_and_time(method_3, logits, input_ids)
efficient_result, efficient_time, efficient_mem = measure_memory_and_time(efficient_method, logits, input_ids)

# Check equivalence
print("Max absolute difference (naive and 1):", (naive_result - method1_result).abs().max().item())
print("Max absolute difference (naive and 2):", (naive_result - method2_result).abs().max().item())
print("Max absolute difference (naive and 3):", (naive_result - method3_result).abs().max().item())
print("Max absolute difference (naive and efficient):", (naive_result - efficient_result).abs().max().item())
print("Memory consumed by logits: {:.2f} MB".format(logit_mem / 1e6))
print("Naive method time:      {:.6f} sec, Memory peak: {:.2f} MB".format(naive_time, naive_mem / 1e6))
print("Method 1 time:          {:.6f} sec, Memory peak: {:.2f} MB".format(method1_time, method1_mem / 1e6))
print("Method 2 time:          {:.6f} sec, Memory peak: {:.2f} MB".format(method2_time, method2_mem / 1e6))
print("Method 3 time:          {:.6f} sec, Memory peak: {:.2f} MB".format(method3_time, method3_mem / 1e6))
print("Efficient method time:  {:.6f} sec, Memory peak: {:.2f} MB".format(efficient_time, efficient_mem / 1e6))

# Results:
# > Max absolute difference (naive and 1): 1.9073486328125e-06
# > Max absolute difference (naive and 2): 0.0
# > Max absolute difference (naive and 3): 1.9073486328125e-06
# > Max absolute difference (naive and efficient): 1.9073486328125e-06
# > Memory consumed by logits: 2147.61 MB
# > Naive method time:      0.036307 sec, Memory peak: 4295.16 MB
# > Method 1 time:          0.134651 sec, Memory peak: 4295.43 MB
# > Method 2 time:          0.012156 sec, Memory peak: 2416.18 MB
# > Method 3 time:          0.001496 sec, Memory peak: 2282.10 MB
# > Efficient method time:  0.000918 sec, Memory peak: 2282.23 MB

@tyler-romero tyler-romero marked this pull request as ready for review February 7, 2025 07:57
@vermouth1992 (Collaborator) commented:

Hi @tyler-romero,

Great catch! Could you please add your test cases to our CI so that future PRs won't break this? You can move your tests to https://github.com/volcengine/verl/blob/main/tests/gpu_utility/test_torch_functional.py, following pytest style. Thanks.
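
Something along these lines would do (illustrative sketch; the import path verl.utils.torch_functional is an assumption about where the helper lives):

import pytest
import torch

@pytest.mark.skipif(not torch.cuda.is_available(), reason="requires CUDA")
def test_logprobs_from_logits_v2_matches_naive():
    from verl.utils.torch_functional import logprobs_from_logits_v2  # assumed location

    torch.manual_seed(42)
    logits = torch.randn(4, 128, 1024, device="cuda", dtype=torch.float32)
    input_ids = torch.randint(0, 1024, (4, 128), device="cuda")

    # naive reference: full log_softmax followed by a gather
    expected = torch.gather(
        logits.log_softmax(dim=-1), dim=-1, index=input_ids.unsqueeze(-1)
    ).squeeze(-1)
    actual = logprobs_from_logits_v2(logits, input_ids)
    torch.testing.assert_close(actual, expected, atol=1e-5, rtol=1e-5)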

@tyler-romero (Author) commented:

Added tests, and they're passing locally for me.

@vermouth1992 vermouth1992 merged commit 4b51624 into volcengine:main Feb 8, 2025
11 checks passed
as12138 pushed a commit to as12138/verl that referenced this pull request Feb 20, 2025