From e1b004839a2c6d1f1771b7ab9c97acd0ed0c7aa2 Mon Sep 17 00:00:00 2001 From: Joe Runde Date: Wed, 16 Apr 2025 18:28:42 +0200 Subject: [PATCH] [Hardware] Add processor inputs to platform validation (#16680) Signed-off-by: Joe Runde --- vllm/platforms/interface.py | 3 ++- vllm/platforms/tpu.py | 3 ++- vllm/v1/engine/processor.py | 12 ++++++------ 3 files changed, 10 insertions(+), 8 deletions(-) diff --git a/vllm/platforms/interface.py b/vllm/platforms/interface.py index 2695da5778aad..8c099b9531c5f 100644 --- a/vllm/platforms/interface.py +++ b/vllm/platforms/interface.py @@ -8,7 +8,7 @@ from typing import TYPE_CHECKING, NamedTuple, Optional, Tuple, Union import numpy as np import torch -from vllm.inputs import PromptType +from vllm.inputs import ProcessorInputs, PromptType from vllm.logger import init_logger if TYPE_CHECKING: @@ -400,6 +400,7 @@ class Platform: cls, prompt: PromptType, params: Union[SamplingParams, PoolingParams], + processed_inputs: ProcessorInputs, ) -> None: """Raises if this request is unsupported on this platform""" diff --git a/vllm/platforms/tpu.py b/vllm/platforms/tpu.py index d8807a72ba2f3..83dd3e9c817af 100644 --- a/vllm/platforms/tpu.py +++ b/vllm/platforms/tpu.py @@ -5,7 +5,7 @@ from typing import TYPE_CHECKING, Optional, Union import torch import vllm.envs as envs -from vllm.inputs import PromptType +from vllm.inputs import ProcessorInputs, PromptType from vllm.logger import init_logger from vllm.sampling_params import SamplingParams, SamplingType @@ -150,6 +150,7 @@ class TpuPlatform(Platform): cls, prompt: PromptType, params: Union[SamplingParams, PoolingParams], + processed_inputs: ProcessorInputs, ) -> None: """Raises if this request is unsupported on this platform""" if isinstance(params, SamplingParams): diff --git a/vllm/v1/engine/processor.py b/vllm/v1/engine/processor.py index 396fe25e2628f..225e78f53eab3 100644 --- a/vllm/v1/engine/processor.py +++ b/vllm/v1/engine/processor.py @@ -202,12 +202,6 @@ class Processor: # TODO(woosuk): Support pooling models. # TODO(woosuk): Support encoder-decoder models. - - from vllm.platforms import current_platform - current_platform.validate_request( - prompt=prompt, - params=params, - ) self._validate_lora(lora_request) self._validate_params(params) if priority != 0: @@ -231,6 +225,12 @@ class Processor: prompt_adapter_request=prompt_adapter_request, return_mm_hashes=self.use_hash, ) + from vllm.platforms import current_platform + current_platform.validate_request( + prompt=prompt, + params=params, + processed_inputs=processed_inputs, + ) eos_token_id = self.input_preprocessor.get_eos_token_id(lora_request) self._validate_model_inputs(processed_inputs, lora_request)