[V1][Spec Decode] Update target_logits in place for rejection sampling (#15427)

Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
This commit is contained in:
Woosuk Kwon 2025-03-24 21:04:41 -07:00 committed by GitHub
parent a09ad90a72
commit 25f560a62c
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 12 additions and 4 deletions

View File

@ -67,6 +67,7 @@ class RejectionSampler(nn.Module):
Shape is [num_tokens, vocab_size]. Here, probabilities from
different requests are flattened into a single tensor because
this is the shape of the output logits.
NOTE: `target_logits` can be updated in place to save memory.
bonus_token_ids_tensor (torch.Tensor):
A tensor containing bonus tokens. Shape is [batch_size, 1].
Bonus tokens are added to the end of the sequence if all
@ -83,6 +84,8 @@ class RejectionSampler(nn.Module):
'''
assert metadata.max_spec_len <= MAX_SPEC_LEN
# [num_tokens, vocab_size]
# NOTE(woosuk): `target_logits` can be updated in place inside the
# `compute_probs` function.
target_probs = compute_probs(
target_logits,
metadata.cu_num_draft_tokens,
@ -252,8 +255,8 @@ def compute_probs(
replace_from=GREEDY_TEMPERATURE,
replace_to=1,
)
# TODO(woosuk): Consider using in-place op to reduce memory usage.
logits = logits / temperature.unsqueeze(-1)
# NOTE(woosuk): Update `logits` in place to avoid allocating a new tensor.
logits.div_(temperature.unsqueeze(-1))
# Get expanded top_k and top_p tensors.
top_k = None

View File

@ -1059,7 +1059,10 @@ class GPUModelRunner(LoRAModelRunnerMixin):
sampling_metadata=sampling_metadata,
)
else:
# TODO(woosuk): Optimize the memory usage.
# When indexing with a tensor (bonus_logits_indices), PyTorch
# creates a new tensor with separate storage from the original
# logits tensor. This means any in-place operations on bonus_logits
# won't affect the original logits tensor.
bonus_logits = logits[spec_decode_metadata.bonus_logits_indices]
sampler_output = self.model.sample(
logits=bonus_logits,
@ -1067,7 +1070,9 @@ class GPUModelRunner(LoRAModelRunnerMixin):
)
bonus_token_ids = sampler_output.sampled_token_ids
# TODO(woosuk): Optimize the memory usage.
# Just like `bonus_logits`, `target_logits` is a new tensor with
# separate storage from the original `logits` tensor. Therefore,
# it is safe to update `target_logits` in place.
target_logits = logits[spec_decode_metadata.target_logits_indices]
output_token_ids = self.rejection_sampler(
spec_decode_metadata,