diff --git a/vllm/platforms/interface.py b/vllm/platforms/interface.py
index 2695da5778aad..8c099b9531c5f 100644
--- a/vllm/platforms/interface.py
+++ b/vllm/platforms/interface.py
@@ -8,7 +8,7 @@ from typing import TYPE_CHECKING, NamedTuple, Optional, Tuple, Union
 import numpy as np
 import torch
 
-from vllm.inputs import PromptType
+from vllm.inputs import ProcessorInputs, PromptType
 from vllm.logger import init_logger
 
 if TYPE_CHECKING:
@@ -400,6 +400,7 @@ class Platform:
         cls,
         prompt: PromptType,
         params: Union[SamplingParams, PoolingParams],
+        processed_inputs: ProcessorInputs,
     ) -> None:
         """Raises if this request is unsupported on this platform"""
 
diff --git a/vllm/platforms/tpu.py b/vllm/platforms/tpu.py
index d8807a72ba2f3..83dd3e9c817af 100644
--- a/vllm/platforms/tpu.py
+++ b/vllm/platforms/tpu.py
@@ -5,7 +5,7 @@ from typing import TYPE_CHECKING, Optional, Union
 import torch
 
 import vllm.envs as envs
-from vllm.inputs import PromptType
+from vllm.inputs import ProcessorInputs, PromptType
 from vllm.logger import init_logger
 from vllm.sampling_params import SamplingParams, SamplingType
 
@@ -150,6 +150,7 @@ class TpuPlatform(Platform):
         cls,
         prompt: PromptType,
         params: Union[SamplingParams, PoolingParams],
+        processed_inputs: ProcessorInputs,
     ) -> None:
         """Raises if this request is unsupported on this platform"""
         if isinstance(params, SamplingParams):
diff --git a/vllm/v1/engine/processor.py b/vllm/v1/engine/processor.py
index 396fe25e2628f..225e78f53eab3 100644
--- a/vllm/v1/engine/processor.py
+++ b/vllm/v1/engine/processor.py
@@ -202,12 +202,6 @@ class Processor:
 
         # TODO(woosuk): Support pooling models.
         # TODO(woosuk): Support encoder-decoder models.
-
-        from vllm.platforms import current_platform
-        current_platform.validate_request(
-            prompt=prompt,
-            params=params,
-        )
         self._validate_lora(lora_request)
         self._validate_params(params)
         if priority != 0:
@@ -231,6 +225,12 @@ class Processor:
             prompt_adapter_request=prompt_adapter_request,
             return_mm_hashes=self.use_hash,
         )
+        from vllm.platforms import current_platform
+        current_platform.validate_request(
+            prompt=prompt,
+            params=params,
+            processed_inputs=processed_inputs,
+        )
         eos_token_id = self.input_preprocessor.get_eos_token_id(lora_request)
 
         self._validate_model_inputs(processed_inputs, lora_request)
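
The diff threads the preprocessed inputs into Platform.validate_request and moves the platform check after input_preprocessor.preprocess(), so platforms can validate the tokenized request rather than only the raw prompt and params. Below is a minimal sketch of how an out-of-tree platform could use the extended hook; the MyPlatform class name, the 8192-token limit, and the prompt_token_ids lookup are illustrative assumptions, not part of this change.

# Hypothetical platform override using the new validate_request signature.
# Only the signature is taken from this diff; everything else is a sketch.
from typing import Union

from vllm.inputs import ProcessorInputs, PromptType
from vllm.platforms.interface import Platform
from vllm.pooling_params import PoolingParams
from vllm.sampling_params import SamplingParams


class MyPlatform(Platform):  # illustrative plugin platform

    @classmethod
    def validate_request(
        cls,
        prompt: PromptType,
        params: Union[SamplingParams, PoolingParams],
        processed_inputs: ProcessorInputs,
    ) -> None:
        """Raises if this request is unsupported on this platform"""
        # Validation now runs after preprocessing, so the tokenized inputs
        # are available here in addition to the raw prompt and params.
        # Assumption: decoder-only token inputs expose "prompt_token_ids".
        token_ids = processed_inputs.get("prompt_token_ids")
        if token_ids is not None and len(token_ids) > 8192:
            raise ValueError(
                "Prompts longer than 8192 tokens are not supported on "
                "this platform.")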