diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py
index d0fe5e25d6ab..f78582495814 100644
--- a/vllm/v1/worker/gpu_model_runner.py
+++ b/vllm/v1/worker/gpu_model_runner.py
@@ -2367,7 +2367,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
         sampling_metadata: SamplingMetadata,
         hidden_states: torch.Tensor,
         sample_hidden_states: torch.Tensor,
-        aux_hidden_states: Optional[torch.Tensor],
+        aux_hidden_states: Optional[list[torch.Tensor]],
         spec_decode_metadata: Optional[SpecDecodeMetadata],
         common_attn_metadata: CommonAttentionMetadata,
     ) -> Union[list[list[int]], torch.Tensor]:
@@ -2387,6 +2387,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
         else:
             indices = []
             offset = 0
+            assert spec_decode_metadata is not None
             for num_draft, tokens in zip(
                     spec_decode_metadata.num_draft_tokens,
                     sampled_token_ids):
@@ -2437,6 +2438,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
             # TODO(woosuk): Support M-RoPE.
             target_positions = self.positions.gpu[:num_scheduled_tokens]
             if self.use_aux_hidden_state_outputs:
+                assert aux_hidden_states is not None
                 target_hidden_states = torch.cat(
                     [h[:num_scheduled_tokens] for h in aux_hidden_states],
                     dim=-1)
@@ -2462,6 +2464,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
             # TODO(woosuk): Support M-RoPE.
             target_positions = self.positions.gpu[token_indices]
             if self.use_aux_hidden_state_outputs:
+                assert aux_hidden_states is not None
                 target_hidden_states = torch.cat(
                     [h[token_indices] for h in aux_hidden_states], dim=-1)
             else:
@@ -2897,7 +2900,9 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
             assert not create_mixed_batch
             num_reqs = cdiv(num_tokens, max_query_len)
             assert num_reqs <= max_num_reqs, \
-                "Do not capture num_reqs > max_num_reqs for uniform batch"
+                f"Do not capture num_reqs {num_reqs} > max_num_reqs " \
+                f"{max_num_reqs} for uniform batch. Num tokens: " \
+                f"{num_tokens}, max_query_len: {max_query_len}"
             num_scheduled_tokens_list = [max_query_len] * num_reqs
             if num_tokens % max_query_len != 0:
                 num_scheduled_tokens_list[-1] = num_tokens % max_query_len