diff --git a/vllm/platforms/tpu.py b/vllm/platforms/tpu.py index 0b66b52713e97..fc68e5d63a6e5 100644 --- a/vllm/platforms/tpu.py +++ b/vllm/platforms/tpu.py @@ -91,13 +91,19 @@ class TpuPlatform(Platform): parallel_config = vllm_config.parallel_config scheduler_config = vllm_config.scheduler_config if parallel_config.worker_cls == "auto": - if envs.VLLM_USE_V1: - parallel_config.worker_cls = \ - "vllm.v1.worker.tpu_worker.TPUWorker" - else: - if scheduler_config.is_multi_step: + if scheduler_config.is_multi_step: + if envs.VLLM_USE_V1: + raise NotImplementedError( + "Multi-step scheduling is not supported (and not " + "needed) on vLLM V1. Please launch without " + "--num-scheduler-steps.") + else: parallel_config.worker_cls = \ "vllm.worker.multi_step_tpu_worker.MultiStepTPUWorker" + else: + if envs.VLLM_USE_V1: + parallel_config.worker_cls = \ + "vllm.v1.worker.tpu_worker.TPUWorker" else: parallel_config.worker_cls = \ "vllm.worker.tpu_worker.TPUWorker"