From 8f0d5167155247934d247eb10ae086108db8d473 Mon Sep 17 00:00:00 2001 From: wenxindongwork <161090399+wenxindongwork@users.noreply.github.com> Date: Wed, 30 Jul 2025 10:02:12 -0700 Subject: [PATCH] [TPU] Support Pathways in vLLM (#21417) Signed-off-by: wenxindongwork --- vllm/envs.py | 5 +++++ vllm/platforms/__init__.py | 18 ++++++++++++------ 2 files changed, 17 insertions(+), 6 deletions(-) diff --git a/vllm/envs.py b/vllm/envs.py index ec4b0888d0f40..19bc9156b2586 100755 --- a/vllm/envs.py +++ b/vllm/envs.py @@ -124,6 +124,7 @@ if TYPE_CHECKING: VLLM_V1_USE_OUTLINES_CACHE: bool = False VLLM_TPU_BUCKET_PADDING_GAP: int = 0 VLLM_TPU_MOST_MODEL_LEN: Optional[int] = None + VLLM_TPU_USING_PATHWAYS: bool = False VLLM_USE_DEEP_GEMM: bool = False VLLM_USE_FLASHINFER_MOE_FP8: bool = False VLLM_USE_FLASHINFER_MOE_FP4: bool = False @@ -900,6 +901,10 @@ environment_variables: dict[str, Callable[[], Any]] = { "VLLM_TPU_MOST_MODEL_LEN": lambda: maybe_convert_int(os.environ.get("VLLM_TPU_MOST_MODEL_LEN", None)), + # Whether using Pathways + "VLLM_TPU_USING_PATHWAYS": + lambda: bool("proxy" in os.getenv("JAX_PLATFORMS", "").lower()), + # Allow use of DeepGemm kernels for fused moe ops. "VLLM_USE_DEEP_GEMM": lambda: bool(int(os.getenv("VLLM_USE_DEEP_GEMM", "0"))), diff --git a/vllm/platforms/__init__.py b/vllm/platforms/__init__.py index c13659f8a06e6..56edb8629e45b 100644 --- a/vllm/platforms/__init__.py +++ b/vllm/platforms/__init__.py @@ -1,11 +1,11 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project - import logging import traceback from itertools import chain from typing import TYPE_CHECKING, Optional +from vllm import envs from vllm.plugins import load_plugins_by_group from vllm.utils import resolve_obj_by_qualname, supports_xccl @@ -31,20 +31,26 @@ def vllm_version_matches_substr(substr: str) -> bool: def tpu_platform_plugin() -> Optional[str]: - is_tpu = False logger.debug("Checking if TPU platform is available.") + + # Check for Pathways TPU proxy + if envs.VLLM_TPU_USING_PATHWAYS: + logger.debug("Confirmed TPU platform is available via Pathways proxy.") + return "tpu_commons.platforms.tpu_jax.TpuPlatform" + + # Check for libtpu installation try: # While it's technically possible to install libtpu on a # non-TPU machine, this is a very uncommon scenario. Therefore, - # we assume that libtpu is installed if and only if the machine + # we assume that libtpu is installed only if the machine # has TPUs. + import libtpu # noqa: F401 - is_tpu = True logger.debug("Confirmed TPU platform is available.") + return "vllm.platforms.tpu.TpuPlatform" except Exception as e: logger.debug("TPU platform is not available because: %s", str(e)) - - return "vllm.platforms.tpu.TpuPlatform" if is_tpu else None + return None def cuda_platform_plugin() -> Optional[str]: