diff --git a/vllm/v1/worker/gpu_input_batch.py b/vllm/v1/worker/gpu_input_batch.py index 786dfe95106e6..ee2f6b9964e55 100644 --- a/vllm/v1/worker/gpu_input_batch.py +++ b/vllm/v1/worker/gpu_input_batch.py @@ -966,7 +966,9 @@ class InputBatch: ] def update_async_spec_token_ids( - self, draft_token_ids_cpu: list[list[int]] | None + self, + draft_token_ids_cpu: list[list[int]] | None, + num_draft_tokens: list[int] | None = None, ) -> None: """ In async scheduling case, update spec_token_ids in sampling metadata @@ -985,11 +987,14 @@ class InputBatch: prev_index = self.prev_req_id_to_index.get(req_id) if prev_index is None: continue - assert prev_index < len(draft_token_ids_cpu) draft_ids = draft_token_ids_cpu[prev_index] if not draft_ids: continue - assert index < len(spec_token_ids) + + if num_draft_tokens is not None: + scheduled_count = num_draft_tokens[index] + assert scheduled_count <= len(draft_ids) + draft_ids = draft_ids[:scheduled_count] spec_token_ids[index].clear() spec_token_ids[index].extend(draft_ids) diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index f6aa773def973..0da2f082052d6 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -2582,7 +2582,10 @@ class GPUModelRunner( # Update spec_token_ids with real draft tokens from previous step draft_token_ids_cpu = self._get_draft_token_ids_cpu() - self.input_batch.update_async_spec_token_ids(draft_token_ids_cpu) + self.input_batch.update_async_spec_token_ids( + draft_token_ids_cpu, + num_draft_tokens=spec_decode_metadata.num_draft_tokens, + ) sampler_output = self.rejection_sampler( spec_decode_metadata,