From 9ee9d0e274819f6439b2926fba82eecd049267b7 Mon Sep 17 00:00:00 2001
From: Woosuk Kwon
Date: Thu, 28 Aug 2025 15:02:07 -0700
Subject: [PATCH] fix

Signed-off-by: Woosuk Kwon
---
 vllm/v1/worker/gpu_model_runner.py  | 17 +++++++++++------
 vllm/v1/worker/gpu_worker_states.py |  1 +
 2 files changed, 12 insertions(+), 6 deletions(-)

diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py
index cf8eebf4dc2c4..4ca411e83d4be 100644
--- a/vllm/v1/worker/gpu_model_runner.py
+++ b/vllm/v1/worker/gpu_model_runner.py
@@ -311,6 +311,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
         # Cached outputs.
         self._draft_token_ids: Optional[Union[list[list[int]],
                                               torch.Tensor]] = None
+        self._draft_req_ids: Optional[list[str]] = None
         self.transfer_event = torch.cuda.Event()
         self.sampled_token_ids_pinned_cpu = torch.empty(
             (self.max_model_len, 1),
@@ -997,13 +998,13 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
         # list of tuple (mm_hash, position_info)
         mm_hashes_pos = list[tuple[str, PlaceholderRange]]()
         for req_id, encoder_input_ids in scheduled_encoder_inputs.items():
-            req_state = self.requests[req_id]
+            req_data = self.requests.req_data[req_id]
             for mm_input_id in encoder_input_ids:
-                mm_hash = req_state.mm_hashes[mm_input_id]
-                mm_kwargs.append(req_state.mm_kwargs[mm_input_id])
+                mm_hash = req_data.mm_hashes[mm_input_id]
+                mm_kwargs.append(req_data.mm_kwargs[mm_input_id])
                 mm_hashes_pos.append(
-                    (mm_hash, req_state.mm_positions[mm_input_id]))
+                    (mm_hash, req_data.mm_positions[mm_input_id]))

         # Batch mm inputs as much as we can: if a request in the batch has
         # multiple modalities or a different modality than the previous one,
@@ -1576,6 +1577,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
             input_batch.spec_decode_metadata,
             input_batch.spec_decode_common_attn_metadata,
         )
+        self._draft_req_ids = input_batch.req_ids

         self.eplb_step()

@@ -1593,12 +1595,15 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
     def take_draft_token_ids(self) -> Optional[DraftTokenIds]:
         if self._draft_token_ids is None:
             return None
-        req_ids = self.requests.req_ids
         if isinstance(self._draft_token_ids, torch.Tensor):
             draft_token_ids = self._draft_token_ids.tolist()
         else:
             draft_token_ids = self._draft_token_ids
         self._draft_token_ids = None
+
+        assert self._draft_req_ids
+        req_ids = self._draft_req_ids
+        self._draft_req_ids = None
         return DraftTokenIds(req_ids, draft_token_ids)

     def propose_draft_token_ids(
@@ -1614,7 +1619,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
         if self.speculative_config.method == "ngram":
             assert isinstance(self.drafter, NgramProposer)
             draft_token_ids = self.propose_ngram_draft_token_ids(
-                sampled_token_ids)
+                input_batch, sampled_token_ids)
         elif self.speculative_config.method == "medusa":
             assert isinstance(self.drafter, MedusaProposer)
             if sample_hidden_states.shape[0] == len(sampled_token_ids):
diff --git a/vllm/v1/worker/gpu_worker_states.py b/vllm/v1/worker/gpu_worker_states.py
index 4aa656ce9ab89..ab07237cc1c7d 100644
--- a/vllm/v1/worker/gpu_worker_states.py
+++ b/vllm/v1/worker/gpu_worker_states.py
@@ -28,6 +28,7 @@ class RequestData:

     sampling_params: Optional[SamplingParams]
     pooling_params: Optional[PoolingParams]
+    mm_hashes: list[str]
     # M-RoPE (only for Qwen2/2.5-VL)
     mrope_positions: Optional[torch.Tensor] = None
     mrope_position_delta: Optional[int] = None
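
The gist of the gpu_model_runner.py change: the request IDs are snapshotted
into self._draft_req_ids at the moment the draft tokens are proposed, so that
take_draft_token_ids() pairs the drafts with the batch that actually produced
them, rather than reading the live request table, which may already reflect a
newer scheduling step. Below is a minimal, self-contained sketch of that
take/consume pattern; DraftTokenIds mirrors the patch, while DraftCache and
its methods are simplified, hypothetical stand-ins for the runner internals,
not vLLM's real API.

    from dataclasses import dataclass
    from typing import Optional


    @dataclass
    class DraftTokenIds:
        req_ids: list[str]
        draft_token_ids: list[list[int]]


    class DraftCache:
        """Toy stand-in for the runner's cached-output fields."""

        def __init__(self) -> None:
            self._draft_token_ids: Optional[list[list[int]]] = None
            self._draft_req_ids: Optional[list[str]] = None

        def propose(self, req_ids: list[str],
                    draft_token_ids: list[list[int]]) -> None:
            # Snapshot the req_ids of the batch that produced the drafts;
            # the live request table may change before take() is called.
            self._draft_token_ids = draft_token_ids
            self._draft_req_ids = list(req_ids)

        def take(self) -> Optional[DraftTokenIds]:
            # Consume both cached fields together so the token lists and
            # the request IDs always come from the same scheduling step.
            if self._draft_token_ids is None:
                return None
            draft_token_ids = self._draft_token_ids
            self._draft_token_ids = None

            assert self._draft_req_ids is not None
            req_ids = self._draft_req_ids
            self._draft_req_ids = None
            return DraftTokenIds(req_ids, draft_token_ids)


    if __name__ == "__main__":
        cache = DraftCache()
        cache.propose(["req-0", "req-1"], [[1, 2], [3]])
        print(cache.take())  # DraftTokenIds with both requests' drafts
        print(cache.take())  # None: drafts are consumed exactly once

Consuming both fields in one call is what makes the consume-once contract
hold: a second take() returns None instead of stale IDs.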
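
On the gpu_worker_states.py side, RequestData gains mm_hashes next to the
existing mm_kwargs and mm_positions, so the encoder-input loop can index all
three per-request multimodal lists by the same mm_input_id. A hedged sketch of
that lookup shape follows; PlaceholderRange and the field types here are
simplified stand-ins, not the real vLLM definitions.

    from dataclasses import dataclass, field
    from typing import Any


    @dataclass
    class PlaceholderRange:
        # Stand-in for vLLM's PlaceholderRange.
        offset: int
        length: int


    @dataclass
    class RequestData:
        # Parallel per-request lists, all indexed by mm_input_id.
        mm_hashes: list[str] = field(default_factory=list)
        mm_kwargs: list[dict[str, Any]] = field(default_factory=list)
        mm_positions: list[PlaceholderRange] = field(default_factory=list)


    def gather_mm_inputs(
        req_data: RequestData, encoder_input_ids: list[int]
    ) -> tuple[list[dict[str, Any]], list[tuple[str, PlaceholderRange]]]:
        # Mirrors the loop body in the patch: one lookup per scheduled
        # encoder input, keeping hash, kwargs, and position in sync.
        mm_kwargs: list[dict[str, Any]] = []
        mm_hashes_pos: list[tuple[str, PlaceholderRange]] = []
        for mm_input_id in encoder_input_ids:
            mm_hash = req_data.mm_hashes[mm_input_id]
            mm_kwargs.append(req_data.mm_kwargs[mm_input_id])
            mm_hashes_pos.append(
                (mm_hash, req_data.mm_positions[mm_input_id]))
        return mm_kwargs, mm_hashes_pos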