mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-04-15 02:37:03 +08:00
[V1][Spec Decode] Update target_logits in place for rejection sampling (#15427)
Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
This commit is contained in:
parent
a09ad90a72
commit
25f560a62c
@ -67,6 +67,7 @@ class RejectionSampler(nn.Module):
|
||||
Shape is [num_tokens, vocab_size]. Here, probabilities from
|
||||
different requests are flattened into a single tensor because
|
||||
this is the shape of the output logits.
|
||||
NOTE: `target_logits` can be updated in place to save memory.
|
||||
bonus_token_ids_tensor (torch.Tensor):
|
||||
A tensor containing bonus tokens. Shape is [batch_size, 1].
|
||||
Bonus tokens are added to the end of the sequence if all
|
||||
@ -83,6 +84,8 @@ class RejectionSampler(nn.Module):
|
||||
'''
|
||||
assert metadata.max_spec_len <= MAX_SPEC_LEN
|
||||
# [num_tokens, vocab_size]
|
||||
# NOTE(woosuk): `target_logits` can be updated in place inside the
|
||||
# `compute_probs` function.
|
||||
target_probs = compute_probs(
|
||||
target_logits,
|
||||
metadata.cu_num_draft_tokens,
|
||||
@ -252,8 +255,8 @@ def compute_probs(
|
||||
replace_from=GREEDY_TEMPERATURE,
|
||||
replace_to=1,
|
||||
)
|
||||
# TODO(woosuk): Consider using in-place op to reduce memory usage.
|
||||
logits = logits / temperature.unsqueeze(-1)
|
||||
# NOTE(woosuk): Update `logits` in place to avoid allocating a new tensor.
|
||||
logits.div_(temperature.unsqueeze(-1))
|
||||
|
||||
# Get expanded top_k and top_p tensors.
|
||||
top_k = None
|
||||
|
||||
@ -1059,7 +1059,10 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
||||
sampling_metadata=sampling_metadata,
|
||||
)
|
||||
else:
|
||||
# TODO(woosuk): Optimize the memory usage.
|
||||
# When indexing with a tensor (bonus_logits_indices), PyTorch
|
||||
# creates a new tensor with separate storage from the original
|
||||
# logits tensor. This means any in-place operations on bonus_logits
|
||||
# won't affect the original logits tensor.
|
||||
bonus_logits = logits[spec_decode_metadata.bonus_logits_indices]
|
||||
sampler_output = self.model.sample(
|
||||
logits=bonus_logits,
|
||||
@ -1067,7 +1070,9 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
||||
)
|
||||
bonus_token_ids = sampler_output.sampled_token_ids
|
||||
|
||||
# TODO(woosuk): Optimize the memory usage.
|
||||
# Just like `bonus_logits`, `target_logits` is a new tensor with
|
||||
# separate storage from the original `logits` tensor. Therefore,
|
||||
# it is safe to update `target_logits` in place.
|
||||
target_logits = logits[spec_decode_metadata.target_logits_indices]
|
||||
output_token_ids = self.rejection_sampler(
|
||||
spec_decode_metadata,
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user