diff --git a/vllm/v1/worker/gpu_input_batch.py b/vllm/v1/worker/gpu_input_batch.py index 699b273642b07..426bde2201cd4 100644 --- a/vllm/v1/worker/gpu_input_batch.py +++ b/vllm/v1/worker/gpu_input_batch.py @@ -27,7 +27,7 @@ class InputBatch: # batch_idx -> num_scheduled_tokens num_scheduled_tokens: np.ndarray total_num_tokens: int - max_num_tokens: int + max_query_len: int num_reqs: int attn_metadata: dict[str, Any] diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index a289ab35a903a..8fe0633dd138b 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -577,7 +577,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): # Get the number of scheduled tokens for each request. tokens = [scheduler_output.num_scheduled_tokens[i] for i in req_ids] num_scheduled_tokens = np.array(tokens, dtype=np.int32) - max_num_scheduled_tokens = max(tokens) + max_num_scheduled_tokens = int(num_scheduled_tokens.max()) prepare_inputs( idx_mapping=idx_mapping_np, @@ -751,7 +751,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): idx_mapping_np=idx_mapping_np, num_reqs=num_reqs, total_num_tokens=total_num_scheduled_tokens, - max_num_tokens=max_num_scheduled_tokens, + max_query_len=max_num_scheduled_tokens, attn_metadata=attn_metadata, spec_decode_metadata=spec_decode_metadata, spec_decode_common_attn_metadata=spec_decode_common_attn_metadata, @@ -1412,10 +1412,10 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): intermediate_tensors = self.sync_and_slice_intermediate_tensors( num_input_tokens, intermediate_tensors, True) - uniform_decode = (input_batch.max_num_tokens + uniform_decode = (input_batch.max_query_len == self.uniform_decode_query_len and num_scheduled_tokens - == input_batch.num_reqs * input_batch.max_num_tokens) + == input_batch.num_reqs * input_batch.max_query_len) batch_descriptor = BatchDescriptor(num_tokens=num_input_tokens, uniform_decode=uniform_decode) cudagraph_runtime_mode, batch_descriptor = \