From 33c63f263d588f6db142dc4d44cdfc16671b2f39 Mon Sep 17 00:00:00 2001 From: zhuhaoran Date: Fri, 12 Dec 2025 22:59:57 +0800 Subject: [PATCH] fix: use num_draft_tokens to trim draft_token_ids_cpu Signed-off-by: zhuhaoran --- vllm/v1/worker/gpu_input_batch.py | 11 ++++++++--- vllm/v1/worker/gpu_model_runner.py | 5 ++++- 2 files changed, 12 insertions(+), 4 deletions(-) diff --git a/vllm/v1/worker/gpu_input_batch.py b/vllm/v1/worker/gpu_input_batch.py index 786dfe95106e6..ee2f6b9964e55 100644 --- a/vllm/v1/worker/gpu_input_batch.py +++ b/vllm/v1/worker/gpu_input_batch.py @@ -966,7 +966,9 @@ class InputBatch: ] def update_async_spec_token_ids( - self, draft_token_ids_cpu: list[list[int]] | None + self, + draft_token_ids_cpu: list[list[int]] | None, + num_draft_tokens: list[int] | None = None, ) -> None: """ In async scheduling case, update spec_token_ids in sampling metadata @@ -985,11 +987,14 @@ class InputBatch: prev_index = self.prev_req_id_to_index.get(req_id) if prev_index is None: continue - assert prev_index < len(draft_token_ids_cpu) draft_ids = draft_token_ids_cpu[prev_index] if not draft_ids: continue - assert index < len(spec_token_ids) + + if num_draft_tokens is not None: + scheduled_count = num_draft_tokens[index] + assert scheduled_count <= len(draft_ids) + draft_ids = draft_ids[:scheduled_count] spec_token_ids[index].clear() spec_token_ids[index].extend(draft_ids) diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index f6aa773def973..0da2f082052d6 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -2582,7 +2582,10 @@ class GPUModelRunner( # Update spec_token_ids with real draft tokens from previous step draft_token_ids_cpu = self._get_draft_token_ids_cpu() - self.input_batch.update_async_spec_token_ids(draft_token_ids_cpu) + self.input_batch.update_async_spec_token_ids( + draft_token_ids_cpu, + num_draft_tokens=spec_decode_metadata.num_draft_tokens, + ) sampler_output = self.rejection_sampler( spec_decode_metadata,