[V1] Optimize the overhead of rewinding (#14905)

Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
Authored by Woosuk Kwon on 2025-03-16 20:19:30 -07:00; committed by GitHub
parent 8a5a9b70d7
commit faa0275730

@@ -1032,17 +1032,16 @@ class GPUModelRunner(LoRAModelRunnerMixin):
         # TODO(woosuk): The following loop can be slow since it iterates over
         # the requests one by one. Optimize.
-        for i, req_id in enumerate(self.input_batch.req_ids):
+        for i, generator in self.input_batch.generators.items():
+            req_id = self.input_batch.req_ids[i]
             req_state = self.requests[req_id]
             seq_len = (req_state.num_computed_tokens +
                        scheduler_output.num_scheduled_tokens[req_id])
             if seq_len < req_state.num_tokens:
-                # Ignore the sampled token.
+                # Ignore the sampled token for partial prefills.
                 # Rewind the generator state as if the token was not sampled.
-                generator = self.input_batch.generators.get(i)
-                if generator is not None:
-                    # This relies on cuda-specific torch-internal impl details
-                    generator.set_offset(generator.get_offset() - 4)
+                # This relies on cuda-specific torch-internal impl details
+                generator.set_offset(generator.get_offset() - 4)
 
         # NOTE: GPU -> CPU Sync happens here.
         # Move as many CPU operations as possible before this sync point.
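
The rewrite cuts overhead in two ways: the loop now iterates only over self.input_batch.generators.items(), i.e. just the requests that carry a seeded generator, instead of scanning every request and probing the dict for each one; and the None check disappears because the dict holds only real generators. The rewind itself relies on torch's CUDA (Philox-based) RNG advancing its offset by 4 per sampled token, the "cuda-specific torch-internal impl details" the comment warns about. Below is a minimal sketch of the same rewind idea, not vLLM's code: it snapshots the offset instead of hard-coding the step of 4, and assumes a CUDA device, where get_offset/set_offset are supported.

import torch

# Sketch only, not the vLLM implementation: snapshot the RNG offset and
# restore it, rather than assuming the per-sample step of 4. Requires a
# CUDA device; get_offset/set_offset apply to torch's CUDA generators.
gen = torch.Generator(device="cuda")
gen.manual_seed(42)

probs = torch.full((16,), 1 / 16, device="cuda")

offset_before = gen.get_offset()            # remember the RNG position
first = torch.multinomial(probs, 1, generator=gen)

gen.set_offset(offset_before)               # rewind as if nothing was sampled
second = torch.multinomial(probs, 1, generator=gen)

assert torch.equal(first, second)           # the same token is drawn again

This matters for partial prefills because the sampled token is discarded; without the rewind, a seeded request would consume RNG state for a token that was never emitted, breaking reproducibility of the eventual sample.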