diff --git a/vllm/platforms/tpu.py b/vllm/platforms/tpu.py index 4ab037fdb77ee..c1218801bc077 100644 --- a/vllm/platforms/tpu.py +++ b/vllm/platforms/tpu.py @@ -9,20 +9,25 @@ from tpu_info import device from vllm.inputs import ProcessorInputs, PromptType from vllm.logger import init_logger -from vllm.sampling_params import SamplingParams, SamplingType from .interface import Platform, PlatformEnum if TYPE_CHECKING: + from typing import TypeAlias + from vllm.attention.backends.registry import AttentionBackendEnum from vllm.config import VllmConfig from vllm.config.cache import BlockSize from vllm.pooling_params import PoolingParams + from vllm.sampling_params import SamplingParams + + ParamsType: TypeAlias = SamplingParams | PoolingParams else: BlockSize = None VllmConfig = None PoolingParams = None AttentionBackendEnum = None + ParamsType = None logger = init_logger(__name__) @@ -203,10 +208,12 @@ class TpuPlatform(Platform): def validate_request( cls, prompt: PromptType, - params: SamplingParams | PoolingParams, + params: ParamsType, processed_inputs: ProcessorInputs, ) -> None: """Raises if this request is unsupported on this platform""" + from vllm.sampling_params import SamplingParams, SamplingType + if ( isinstance(params, SamplingParams) and params.sampling_type == SamplingType.RANDOM_SEED