[TPU] Support Pathways in vLLM (#21417)

Signed-off-by: wenxindongwork <wenxindong@google.com>
This commit is contained in:
wenxindongwork 2025-07-30 10:02:12 -07:00 committed by GitHub
parent f4135232b9
commit 8f0d516715
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 17 additions and 6 deletions

View File

@ -124,6 +124,7 @@ if TYPE_CHECKING:
VLLM_V1_USE_OUTLINES_CACHE: bool = False
VLLM_TPU_BUCKET_PADDING_GAP: int = 0
VLLM_TPU_MOST_MODEL_LEN: Optional[int] = None
VLLM_TPU_USING_PATHWAYS: bool = False
VLLM_USE_DEEP_GEMM: bool = False
VLLM_USE_FLASHINFER_MOE_FP8: bool = False
VLLM_USE_FLASHINFER_MOE_FP4: bool = False
@ -900,6 +901,10 @@ environment_variables: dict[str, Callable[[], Any]] = {
"VLLM_TPU_MOST_MODEL_LEN":
lambda: maybe_convert_int(os.environ.get("VLLM_TPU_MOST_MODEL_LEN", None)),
# Whether using Pathways
"VLLM_TPU_USING_PATHWAYS":
lambda: bool("proxy" in os.getenv("JAX_PLATFORMS", "").lower()),
# Allow use of DeepGemm kernels for fused moe ops.
"VLLM_USE_DEEP_GEMM":
lambda: bool(int(os.getenv("VLLM_USE_DEEP_GEMM", "0"))),

View File

@ -1,11 +1,11 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import logging
import traceback
from itertools import chain
from typing import TYPE_CHECKING, Optional
from vllm import envs
from vllm.plugins import load_plugins_by_group
from vllm.utils import resolve_obj_by_qualname, supports_xccl
@ -31,20 +31,26 @@ def vllm_version_matches_substr(substr: str) -> bool:
def tpu_platform_plugin() -> Optional[str]:
    """Detect whether a TPU backend is available.

    Returns:
        The fully-qualified class path of the TPU platform to load,
        or ``None`` when no TPU is detected. Two backends are possible:
        the Pathways/JAX proxy backend (``tpu_commons``) and the
        libtpu-based in-tree backend.
    """
    logger.debug("Checking if TPU platform is available.")

    # Check for Pathways TPU proxy: VLLM_TPU_USING_PATHWAYS is derived from
    # the JAX_PLATFORMS env var containing "proxy", which indicates the
    # process talks to TPUs through a Pathways proxy rather than libtpu.
    if envs.VLLM_TPU_USING_PATHWAYS:
        logger.debug("Confirmed TPU platform is available via Pathways proxy.")
        return "tpu_commons.platforms.tpu_jax.TpuPlatform"

    # Check for libtpu installation.
    try:
        # While it's technically possible to install libtpu on a
        # non-TPU machine, this is a very uncommon scenario. Therefore,
        # we assume that libtpu is installed only if the machine
        # has TPUs.
        import libtpu  # noqa: F401
        logger.debug("Confirmed TPU platform is available.")
        return "vllm.platforms.tpu.TpuPlatform"
    except Exception as e:
        # Broad catch is deliberate: any failure importing libtpu
        # (missing package, driver issues) means "no TPU here".
        logger.debug("TPU platform is not available because: %s", str(e))
        return None
def cuda_platform_plugin() -> Optional[str]: