From 25f560a62c4f955672e2c6080b17ab3a48f96201 Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Mon, 24 Mar 2025 21:04:41 -0700 Subject: [PATCH] [V1][Spec Decode] Update target_logits in place for rejection sampling (#15427) Signed-off-by: Woosuk Kwon --- vllm/v1/sample/rejection_sampler.py | 7 +++++-- vllm/v1/worker/gpu_model_runner.py | 9 +++++++-- 2 files changed, 12 insertions(+), 4 deletions(-) diff --git a/vllm/v1/sample/rejection_sampler.py b/vllm/v1/sample/rejection_sampler.py index e0db9474f61cb..69bc68174d504 100644 --- a/vllm/v1/sample/rejection_sampler.py +++ b/vllm/v1/sample/rejection_sampler.py @@ -67,6 +67,7 @@ class RejectionSampler(nn.Module): Shape is [num_tokens, vocab_size]. Here, probabilities from different requests are flattened into a single tensor because this is the shape of the output logits. + NOTE: `target_logits` can be updated in place to save memory. bonus_token_ids_tensor (torch.Tensor): A tensor containing bonus tokens. Shape is [batch_size, 1]. Bonus tokens are added to the end of the sequence if all @@ -83,6 +84,8 @@ class RejectionSampler(nn.Module): ''' assert metadata.max_spec_len <= MAX_SPEC_LEN # [num_tokens, vocab_size] + # NOTE(woosuk): `target_logits` can be updated in place inside the + # `compute_probs` function. target_probs = compute_probs( target_logits, metadata.cu_num_draft_tokens, @@ -252,8 +255,8 @@ def compute_probs( replace_from=GREEDY_TEMPERATURE, replace_to=1, ) - # TODO(woosuk): Consider using in-place op to reduce memory usage. - logits = logits / temperature.unsqueeze(-1) + # NOTE(woosuk): Update `logits` in place to avoid allocating a new tensor. + logits.div_(temperature.unsqueeze(-1)) # Get expanded top_k and top_p tensors. top_k = None diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index c6741fdc5d6f4..a85009f1a36a4 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -1059,7 +1059,10 @@ class GPUModelRunner(LoRAModelRunnerMixin): sampling_metadata=sampling_metadata, ) else: - # TODO(woosuk): Optimize the memory usage. + # When indexing with a tensor (bonus_logits_indices), PyTorch + # creates a new tensor with separate storage from the original + # logits tensor. This means any in-place operations on bonus_logits + # won't affect the original logits tensor. bonus_logits = logits[spec_decode_metadata.bonus_logits_indices] sampler_output = self.model.sample( logits=bonus_logits, @@ -1067,7 +1070,9 @@ class GPUModelRunner(LoRAModelRunnerMixin): ) bonus_token_ids = sampler_output.sampled_token_ids - # TODO(woosuk): Optimize the memory usage. + # Just like `bonus_logits`, `target_logits` is a new tensor with + # separate storage from the original `logits` tensor. Therefore, + # it is safe to update `target_logits` in place. target_logits = logits[spec_decode_metadata.target_logits_indices] output_token_ids = self.rejection_sampler( spec_decode_metadata,