[CI] Fix Pre-commit Issue (#25497)

Signed-off-by: yewentao256 <zhyanwentao@126.com>
Wentao Ye 2025-09-23 16:09:37 -04:00 committed by GitHub
parent 8bdd8b5c51
commit 8b8a8afc89

@@ -2367,7 +2367,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
         sampling_metadata: SamplingMetadata,
         hidden_states: torch.Tensor,
         sample_hidden_states: torch.Tensor,
-        aux_hidden_states: Optional[torch.Tensor],
+        aux_hidden_states: Optional[list[torch.Tensor]],
         spec_decode_metadata: Optional[SpecDecodeMetadata],
         common_attn_metadata: CommonAttentionMetadata,
     ) -> Union[list[list[int]], torch.Tensor]:
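
Note: the new Optional[list[torch.Tensor]] annotation matches how
aux_hidden_states is consumed further down in this diff, where it is
iterated per layer and concatenated along the feature dim. A minimal
standalone sketch of that usage pattern (toy shapes, not vLLM code):

    from typing import Optional

    import torch

    def cat_aux(aux_hidden_states: Optional[list[torch.Tensor]],
                num_scheduled_tokens: int) -> torch.Tensor:
        # Each element holds one auxiliary layer's hidden states,
        # shaped [num_tokens, hidden_size]; slice each layer and join
        # them along the feature dim, as the runner does below.
        assert aux_hidden_states is not None
        return torch.cat(
            [h[:num_scheduled_tokens] for h in aux_hidden_states],
            dim=-1)

    out = cat_aux([torch.randn(4, 8), torch.randn(4, 8)], 2)
    print(out.shape)  # torch.Size([2, 16])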
@@ -2387,6 +2387,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
         else:
             indices = []
             offset = 0
+            assert spec_decode_metadata is not None
             for num_draft, tokens in zip(
                     spec_decode_metadata.num_draft_tokens,
                     sampled_token_ids):
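
Note: the added assert changes no runtime behavior on valid inputs; it
is there to narrow the type. In this branch spec decode is active, so
spec_decode_metadata cannot be None, and `assert x is not None` is the
standard way to tell mypy that Optional[X] is now X. A self-contained
sketch with a hypothetical stand-in class (not vLLM's
SpecDecodeMetadata):

    from dataclasses import dataclass
    from typing import Optional

    @dataclass
    class Meta:  # hypothetical stand-in for SpecDecodeMetadata
        num_draft_tokens: list[int]

    def count_drafts(meta: Optional[Meta]) -> int:
        # Without this assert, mypy reports the attribute access below
        # as 'Item "None" ... has no attribute "num_draft_tokens"'.
        assert meta is not None
        return sum(meta.num_draft_tokens)

    print(count_drafts(Meta(num_draft_tokens=[1, 2, 0])))  # 3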
@@ -2437,6 +2438,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
             # TODO(woosuk): Support M-RoPE.
             target_positions = self.positions.gpu[:num_scheduled_tokens]
             if self.use_aux_hidden_state_outputs:
+                assert aux_hidden_states is not None
                 target_hidden_states = torch.cat(
                     [h[:num_scheduled_tokens] for h in aux_hidden_states],
                     dim=-1)
@@ -2462,6 +2464,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
             # TODO(woosuk): Support M-RoPE.
             target_positions = self.positions.gpu[token_indices]
             if self.use_aux_hidden_state_outputs:
+                assert aux_hidden_states is not None
                 target_hidden_states = torch.cat(
                     [h[token_indices] for h in aux_hidden_states], dim=-1)
             else:
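
Note: these two hunks guard the same invariant. When
self.use_aux_hidden_state_outputs is set, the forward pass produced
aux hidden states, but mypy cannot infer that coupling between two
attributes, so each use site narrows the Optional explicitly. A toy
sketch of the second site's gather-then-concat step (assumed shapes):

    import torch

    aux_hidden_states = [torch.randn(6, 4), torch.randn(6, 4)]
    token_indices = torch.tensor([0, 2, 5])

    # Fancy-index each layer's hidden states by the target token rows,
    # then concatenate the layers along the feature dim.
    target_hidden_states = torch.cat(
        [h[token_indices] for h in aux_hidden_states], dim=-1)
    print(target_hidden_states.shape)  # torch.Size([3, 8])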
@@ -2897,7 +2900,9 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
         assert not create_mixed_batch
         num_reqs = cdiv(num_tokens, max_query_len)
         assert num_reqs <= max_num_reqs, \
-            "Do not capture num_reqs > max_num_reqs for uniform batch"
+            f"Do not capture num_reqs {num_reqs} > max_num_reqs " \
+            f"{max_num_reqs} for uniform batch. Num tokens: " \
+            f"{num_tokens}, max_query_len: {max_query_len}"
         num_scheduled_tokens_list = [max_query_len] * num_reqs
         if num_tokens % max_query_len != 0:
             num_scheduled_tokens_list[-1] = num_tokens % max_query_len
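
Note: the reworded assert now reports the offending values, which
helps because num_reqs is derived by ceiling division and the failure
is hard to reconstruct from the old fixed string. A standalone sketch
with an assumed cdiv helper (toy numbers, not vLLM code):

    def cdiv(a: int, b: int) -> int:
        # Ceiling division, mirroring the helper used in the hunk above.
        return -(a // -b)

    num_tokens, max_query_len, max_num_reqs = 60, 8, 10
    num_reqs = cdiv(num_tokens, max_query_len)  # ceil(60 / 8) == 8
    assert num_reqs <= max_num_reqs, \
        f"Do not capture num_reqs {num_reqs} > max_num_reqs " \
        f"{max_num_reqs} for uniform batch. Num tokens: " \
        f"{num_tokens}, max_query_len: {max_query_len}"
    num_scheduled_tokens_list = [max_query_len] * num_reqs
    if num_tokens % max_query_len != 0:
        num_scheduled_tokens_list[-1] = num_tokens % max_query_len
    print(num_scheduled_tokens_list)  # [8, 8, 8, 8, 8, 8, 8, 4]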