diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py
index 9b0345a6aa3ad..634f955207fac 100644
--- a/vllm/v1/worker/gpu_model_runner.py
+++ b/vllm/v1/worker/gpu_model_runner.py
@@ -1748,6 +1748,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
         # NOTE(woosuk): As an exception, when using PP, the scheduler sends
         # the sampled tokens back, because there's no direct communication
         # between the first-stage worker and the last-stage worker.
+        req_ids = self.input_batch.req_ids
         for req_idx, sampled_ids in enumerate(valid_sampled_token_ids):
             if not sampled_ids:
                 continue
@@ -1763,7 +1764,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
                                            start_idx:end_idx] = sampled_ids
             self.input_batch.num_tokens_no_spec[req_idx] = end_idx
             self.input_batch.num_tokens[req_idx] = end_idx
-            req_id = self.input_batch.req_ids[req_idx]
+            req_id = req_ids[req_idx]
             req_state = self.requests[req_id]
             req_state.output_token_ids.extend(sampled_ids)
@@ -1843,6 +1844,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
         elif self.speculative_config.use_eagle():
             assert isinstance(self.drafter, EagleProposer)
             # TODO(woosuk): Refactor the loop.
+            req_ids = self.input_batch.req_ids
             next_token_ids: list[int] = []
             for i, token_ids in enumerate(sampled_token_ids):
                 if token_ids:
@@ -1851,7 +1853,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
                 else:
                     # Partial prefill (rare case).
                     # Get the next token id from the request state.
-                    req_id = self.input_batch.req_ids[i]
+                    req_id = req_ids[i]
                     req_state = self.requests[req_id]
                     seq_len = (req_state.num_computed_tokens +
                                scheduler_output.num_scheduled_tokens[req_id])
@@ -1914,6 +1916,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
         sampled_token_ids: list[list[int]],
     ) -> list[list[int]]:
         # TODO(woosuk): Optimize.
+        req_ids = self.input_batch.req_ids
         draft_token_ids: list[list[int]] = []
         for i, sampled_ids in enumerate(sampled_token_ids):
             num_sampled_ids = len(sampled_ids)
@@ -1924,7 +1927,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):

             # Skip requests that require sampling parameters that are not
             # supported with speculative decoding.
-            req_id = self.input_batch.req_ids[i]
+            req_id = req_ids[i]
             if req_id in self.input_batch.spec_decode_unsupported_reqs:
                 draft_token_ids.append([])
                 continue
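Each hunk applies the same micro-optimization: `self.input_batch.req_ids` is hoisted out of a per-request loop into a local `req_ids`, so the loop body pays one fast local-variable load instead of two attribute lookups per iteration. Below is a minimal sketch of why this helps; it is not vLLM code, and the `InputBatch` and `Runner` classes are simplified stand-ins for illustration only.

import timeit


class InputBatch:
    # Stand-in for vLLM's input batch; only carries the req_ids list.
    def __init__(self, n: int) -> None:
        self.req_ids = [f"req-{i}" for i in range(n)]


class Runner:
    def __init__(self, n: int) -> None:
        self.input_batch = InputBatch(n)

    def lookup_per_iteration(self) -> None:
        # Pattern before the patch: the attribute chain
        # self.input_batch.req_ids is re-resolved on every iteration.
        for i in range(len(self.input_batch.req_ids)):
            _ = self.input_batch.req_ids[i]

    def hoisted_local(self) -> None:
        # Pattern after the patch: resolve the chain once, then the
        # loop body only performs a local load plus an index.
        req_ids = self.input_batch.req_ids
        for i in range(len(req_ids)):
            _ = req_ids[i]


runner = Runner(4096)
print("per-iteration:", timeit.timeit(runner.lookup_per_iteration, number=1000))
print("hoisted local:", timeit.timeit(runner.hoisted_local, number=1000))

The hoisting is safe here because nothing inside these loops rebinds `self.input_batch` or its `req_ids` attribute, so the local alias and the attribute always refer to the same list for the duration of the loop.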