[V1] Optimize the overhead of rewinding (#14905)
Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
commit faa0275730
parent 8a5a9b70d7
@@ -1032,17 +1032,16 @@ class GPUModelRunner(LoRAModelRunnerMixin):
-        # TODO(woosuk): The following loop can be slow since it iterates over
-        # the requests one by one. Optimize.
-        for i, req_id in enumerate(self.input_batch.req_ids):
+        for i, generator in self.input_batch.generators.items():
+            req_id = self.input_batch.req_ids[i]
             req_state = self.requests[req_id]
             seq_len = (req_state.num_computed_tokens +
                        scheduler_output.num_scheduled_tokens[req_id])
             if seq_len < req_state.num_tokens:
-                # Ignore the sampled token.
+                # Ignore the sampled token for partial prefills.
                 # Rewind the generator state as if the token was not sampled.
-                generator = self.input_batch.generators.get(i)
-                if generator is not None:
-                    # This relies on cuda-specific torch-internal impl details
-                    generator.set_offset(generator.get_offset() - 4)
+                # This relies on cuda-specific torch-internal impl details
+                generator.set_offset(generator.get_offset() - 4)
 
         # NOTE: GPU -> CPU Sync happens here.
         # Move as many CPU operations as possible before this sync point.
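Why this is faster (a sketch from the editor, not part of the commit): self.input_batch.generators is a dict keyed by batch index that only holds entries for requests sampling with a per-request generator, so iterating generators.items() visits just those requests, where the old loop enumerated every request in self.input_batch.req_ids and paid a dict lookup per iteration. The names below mirror the diff; the data is made up.

# Minimal sketch of the loop restructuring; `generators` is assumed sparse.
req_ids = ["req-0", "req-1", "req-2", "req-3"]
generators = {2: "generator-for-req-2"}  # only seeded requests have entries

# Before: visit every request and look up its generator each time.
for i, req_id in enumerate(req_ids):
    generator = generators.get(i)
    if generator is None:
        continue
    ...  # rewind check would go here

# After: visit only the requests that actually have a generator.
for i, generator in generators.items():
    req_id = req_ids[i]
    ...  # rewind check would go here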
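For context on set_offset(get_offset() - 4): a CUDA torch.Generator's state is its seed plus a Philox counter offset, exposed via get_offset()/set_offset(), so restoring a saved offset replays the random stream from that point. The hard-coded 4 assumes vLLM's sampling op advances the counter by exactly that much per call, which is the cuda-specific, torch-internal detail the code comment warns about. A standalone sketch (editor's illustration; requires a CUDA device and a PyTorch release that exposes these generator methods):

import torch

gen = torch.Generator(device="cuda").manual_seed(42)
probs = torch.full((8,), 1 / 8, device="cuda")  # uniform over 8 "tokens"

offset = gen.get_offset()
first = torch.multinomial(probs, num_samples=1, generator=gen)
# How far one sampling call moved the counter; the diff assumes 4
# for vLLM's sampling kernel.
print(gen.get_offset() - offset)

gen.set_offset(offset)  # rewind, as if the token was never sampled
again = torch.multinomial(probs, num_samples=1, generator=gen)
assert torch.equal(first, again)  # the rewound generator replays the draw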