[Core] Small simplification in GPUModelRunner._update_states() (#26508)

Signed-off-by: Nick Hill <nhill@redhat.com>
This commit is contained in:
Nick Hill 2025-10-09 19:53:58 -07:00 committed by GitHub
parent 757fa4a4da
commit aafb99a4d4
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@@ -708,6 +708,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
# Update the cached states.
req_state.num_computed_tokens = num_computed_tokens
req_index = self.input_batch.req_id_to_index.get(req_id)
if not is_last_rank:
# When using PP, the scheduler sends the sampled tokens back,
@@ -728,19 +729,13 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
# Some output tokens were discarded due to a sync-KV-load
# failure. Align the cached state.
del req_state.output_token_ids[num_output_tokens:]
req_index = self.input_batch.req_id_to_index.get(req_id)
if req_index is not None:
old_end_idx = self.input_batch.num_tokens_no_spec[req_index]
end_idx = (
self.input_batch.num_prompt_tokens[req_index]
+ num_output_tokens
)
self.input_batch.num_tokens[req_index] = end_idx
self.input_batch.num_tokens_no_spec[req_index] = end_idx
self.input_batch.is_token_ids[req_index, end_idx:old_end_idx] = (
False
)
# Update the block IDs.
if not resumed_from_preemption:
@@ -749,12 +744,12 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
for block_ids, new_ids in zip(req_state.block_ids, new_block_ids):
block_ids.extend(new_ids)
else:
assert req_index is None
assert new_block_ids is not None
# The request is resumed from preemption.
# Replace the existing block IDs with the new ones.
req_state.block_ids = new_block_ids
req_index = self.input_batch.req_id_to_index.get(req_id)
if req_index is None:
# The request is not in the persistent batch.
# The request was either preempted and resumed later, or was not