mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-06-01 00:57:05 +08:00
[Misc] Avoid accessing req_ids inside a loop (#23159)
Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
This commit is contained in:
parent
5bfe0dea7a
commit
21bcc8263f
@@ -1748,6 +1748,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
|||||||
# NOTE(woosuk): As an exception, when using PP, the scheduler sends
|
# NOTE(woosuk): As an exception, when using PP, the scheduler sends
|
||||||
# the sampled tokens back, because there's no direct communication
|
# the sampled tokens back, because there's no direct communication
|
||||||
# between the first-stage worker and the last-stage worker.
|
# between the first-stage worker and the last-stage worker.
|
||||||
|
req_ids = self.input_batch.req_ids
|
||||||
for req_idx, sampled_ids in enumerate(valid_sampled_token_ids):
|
for req_idx, sampled_ids in enumerate(valid_sampled_token_ids):
|
||||||
if not sampled_ids:
|
if not sampled_ids:
|
||||||
continue
|
continue
|
||||||
@@ -1763,7 +1764,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
|||||||
start_idx:end_idx] = sampled_ids
|
start_idx:end_idx] = sampled_ids
|
||||||
self.input_batch.num_tokens_no_spec[req_idx] = end_idx
|
self.input_batch.num_tokens_no_spec[req_idx] = end_idx
|
||||||
self.input_batch.num_tokens[req_idx] = end_idx
|
self.input_batch.num_tokens[req_idx] = end_idx
|
||||||
req_id = self.input_batch.req_ids[req_idx]
|
req_id = req_ids[req_idx]
|
||||||
req_state = self.requests[req_id]
|
req_state = self.requests[req_id]
|
||||||
req_state.output_token_ids.extend(sampled_ids)
|
req_state.output_token_ids.extend(sampled_ids)
|
||||||
|
|
||||||
@@ -1843,6 +1844,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
|||||||
elif self.speculative_config.use_eagle():
|
elif self.speculative_config.use_eagle():
|
||||||
assert isinstance(self.drafter, EagleProposer)
|
assert isinstance(self.drafter, EagleProposer)
|
||||||
# TODO(woosuk): Refactor the loop.
|
# TODO(woosuk): Refactor the loop.
|
||||||
|
req_ids = self.input_batch.req_ids
|
||||||
next_token_ids: list[int] = []
|
next_token_ids: list[int] = []
|
||||||
for i, token_ids in enumerate(sampled_token_ids):
|
for i, token_ids in enumerate(sampled_token_ids):
|
||||||
if token_ids:
|
if token_ids:
|
||||||
@@ -1851,7 +1853,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
|||||||
else:
|
else:
|
||||||
# Partial prefill (rare case).
|
# Partial prefill (rare case).
|
||||||
# Get the next token id from the request state.
|
# Get the next token id from the request state.
|
||||||
req_id = self.input_batch.req_ids[i]
|
req_id = req_ids[i]
|
||||||
req_state = self.requests[req_id]
|
req_state = self.requests[req_id]
|
||||||
seq_len = (req_state.num_computed_tokens +
|
seq_len = (req_state.num_computed_tokens +
|
||||||
scheduler_output.num_scheduled_tokens[req_id])
|
scheduler_output.num_scheduled_tokens[req_id])
|
||||||
@@ -1914,6 +1916,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
|||||||
sampled_token_ids: list[list[int]],
|
sampled_token_ids: list[list[int]],
|
||||||
) -> list[list[int]]:
|
) -> list[list[int]]:
|
||||||
# TODO(woosuk): Optimize.
|
# TODO(woosuk): Optimize.
|
||||||
|
req_ids = self.input_batch.req_ids
|
||||||
draft_token_ids: list[list[int]] = []
|
draft_token_ids: list[list[int]] = []
|
||||||
for i, sampled_ids in enumerate(sampled_token_ids):
|
for i, sampled_ids in enumerate(sampled_token_ids):
|
||||||
num_sampled_ids = len(sampled_ids)
|
num_sampled_ids = len(sampled_ids)
|
||||||
@@ -1924,7 +1927,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
|||||||
|
|
||||||
# Skip requests that require sampling parameters that are not
|
# Skip requests that require sampling parameters that are not
|
||||||
# supported with speculative decoding.
|
# supported with speculative decoding.
|
||||||
req_id = self.input_batch.req_ids[i]
|
req_id = req_ids[i]
|
||||||
if req_id in self.input_batch.spec_decode_unsupported_reqs:
|
if req_id in self.input_batch.spec_decode_unsupported_reqs:
|
||||||
draft_token_ids.append([])
|
draft_token_ids.append([])
|
||||||
continue
|
continue
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user