[Misc] Minor code cleanup for _get_prompt_logprobs_dict (#23064)

Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
This commit is contained in:
Woosuk Kwon 2025-08-17 18:16:03 -07:00 committed by GitHub
parent 0fc8fa751a
commit 8ea0c2753a
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@@ -1722,7 +1722,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
# Compute prompt logprobs if needed.
prompt_logprobs_dict = self._get_prompt_logprobs_dict(
hidden_states[:num_scheduled_tokens],
scheduler_output,
scheduler_output.num_scheduled_tokens,
)
# Get the valid generated tokens.
@@ -2064,7 +2064,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
def _get_prompt_logprobs_dict(
self,
hidden_states: torch.Tensor,
scheduler_output: "SchedulerOutput",
num_scheduled_tokens: dict[str, int],
) -> dict[str, Optional[LogprobsTensors]]:
num_prompt_logprobs_dict = self.input_batch.num_prompt_logprobs
if not num_prompt_logprobs_dict:
@@ -2077,8 +2077,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
# maintainable loop over optimal performance.
completed_prefill_reqs = []
for req_id, num_prompt_logprobs in num_prompt_logprobs_dict.items():
num_tokens = scheduler_output.num_scheduled_tokens[req_id]
num_tokens = num_scheduled_tokens[req_id]
# Get metadata for this request.
request = self.requests[req_id]