[TPU] make ptxla not imported when using tpu_commons (#23081)

Signed-off-by: Chengji Yao <chengjiyao@gmail.com>
Signed-off-by: Chengji Yao <chengjiyao@google.com>
Co-authored-by: Chengji Yao <chengjiyao@gmail.com>
Chengji Yao authored 2025-08-18 20:46:42 -07:00, committed by GitHub
parent a4454e9401
commit e9d6a3db69
6 changed files with 94 additions and 78 deletions
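
At a high level, the commit replaces the scattered `try: import tpu_commons / except ImportError` fallbacks with a single `USE_TPU_COMMONS` flag owned by vllm/platforms/tpu.py, so PyTorch/XLA (ptxla) is only imported on the fallback path. A minimal sketch of the pattern, with a plain `logging` logger standing in for vLLM's `init_logger`; only the flag name and the tpu_commons/torch_xla module names come from the diffs below, the rest is illustrative:

    import logging

    logger = logging.getLogger(__name__)

    # Probe for tpu_commons exactly once, at import time of the platform module.
    USE_TPU_COMMONS = False
    try:
        import tpu_commons  # noqa: F401
        USE_TPU_COMMONS = True
    except ImportError:
        logger.info("tpu_commons not found, using vLLM's TPU fallback path")

    # Downstream modules check the flag instead of importing torch_xla eagerly.
    if not USE_TPU_COMMONS:
        # Heavyweight import; assumes torch_xla is installed on the fallback path.
        import torch_xla.core.xla_model as xm  # noqa: F401

The same flag then guards the class aliasing at the bottom of the communicator and worker modules, replacing the try/except blocks shown as deletions below.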

vllm/distributed/device_communicators/tpu_communicator.py

@@ -10,6 +10,7 @@ from torch.distributed import ProcessGroup
 from vllm.config import get_current_vllm_config
 from vllm.logger import init_logger
 from vllm.platforms import current_platform
+from vllm.platforms.tpu import USE_TPU_COMMONS

 from .base_device_communicator import DeviceCommunicatorBase
@@ -18,6 +19,8 @@ USE_RAY = parallel_config = get_current_vllm_config(
 logger = init_logger(__name__)

+if not USE_TPU_COMMONS:
+    logger.info("tpu_commons not found, using vLLM's TpuCommunicator")
+
 if current_platform.is_tpu():
     import torch_xla
     import torch_xla.core.xla_model as xm
@@ -25,7 +28,6 @@ if current_platform.is_tpu():
     from torch_xla._internal import pjrt
     from torch_xla.distributed.xla_multiprocessing import (
         create_optimized_replica_groups)
-
 if USE_RAY:
     from vllm.executor import ray_utils
@@ -94,10 +96,7 @@ class TpuCommunicator(DeviceCommunicatorBase):
         return xm.all_gather(input_, dim=dim)


-try:
+if USE_TPU_COMMONS:
     from tpu_commons.distributed.device_communicators import (
         TpuCommunicator as TpuCommonsCommunicator)
     TpuCommunicator = TpuCommonsCommunicator  # type: ignore
-except ImportError:
-    logger.info("tpu_commons not found, using vLLM's TpuCommunicator")
-    pass

vllm/model_executor/layers/fused_moe/moe_pallas.py

@@ -3,7 +3,6 @@

 import torch
 import torch.nn.functional as F
-import torch_xla.experimental.custom_kernel  # noqa: F401


 def _histogram(input: torch.Tensor, min: int, max: int) -> torch.Tensor:
@@ -41,6 +40,7 @@ def fused_moe(
         gating_output: [*, num_experts]
     """
     assert expert_map is None, "expert_map is not supported for pallas MoE."
+    import torch_xla.experimental.custom_kernel  # noqa: F401
     orig_shape = hidden_states.shape
     hidden_size = hidden_states.shape[-1]
     num_tokens = hidden_states.shape[:-1].numel()
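
The moe_pallas change is the standard deferred-import trick: the `torch_xla.experimental.custom_kernel` import moves from module scope into `fused_moe`, so merely importing the module no longer pulls in torch_xla. A sketch of the same technique; the function name `pallas_moe_kernel` is illustrative, only the import line comes from the diff:

    import torch


    def pallas_moe_kernel(hidden_states: torch.Tensor) -> torch.Tensor:
        # Deferred import: torch_xla's custom-kernel registration only happens
        # the first time the kernel actually runs, not when this module loads.
        import torch_xla.experimental.custom_kernel  # noqa: F401
        return hidden_states  # placeholder for the real Pallas MoE computation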

vllm/model_executor/model_loader/default_loader.py

@@ -207,8 +207,13 @@ class DefaultModelLoader(BaseModelLoader):
             )

         if current_platform.is_tpu():
-            # In PyTorch XLA, we should call `xm.mark_step` frequently so that
-            # not too many ops are accumulated in the XLA program.
-            import torch_xla.core.xla_model as xm
+            from vllm.platforms.tpu import USE_TPU_COMMONS
+
+            if not USE_TPU_COMMONS:
+                # In PyTorch XLA, we should call `xm.mark_step`
+                # frequently so that not too many ops are accumulated
+                # in the XLA program.
+                import torch_xla.core.xla_model as xm

             def _xla_weights_iterator(iterator: Generator):

vllm/platforms/tpu.py

@@ -24,6 +24,8 @@ else:

 logger = init_logger(__name__)

+USE_TPU_COMMONS = False
+

 class TpuPlatform(Platform):
     _enum = PlatformEnum.TPU
@@ -201,6 +203,7 @@ class TpuPlatform(Platform):
 try:
     from tpu_commons.platforms import TpuPlatform as TpuCommonsPlatform
     TpuPlatform = TpuCommonsPlatform  # type: ignore
+    USE_TPU_COMMONS = True
 except ImportError:
     logger.info("tpu_commons not found, using vLLM's TpuPlatform")
     pass
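
Because the flag is resolved when `vllm.platforms.tpu` is first imported, callers can inspect which backend was picked. A small debugging snippet, not part of the diff, assuming both names are importable as shown above:

    from vllm.platforms.tpu import USE_TPU_COMMONS, TpuPlatform

    # Reports whether tpu_commons was found and which platform class is in use.
    print("tpu_commons active:", USE_TPU_COMMONS)
    print("platform class:", f"{TpuPlatform.__module__}.{TpuPlatform.__qualname__}")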

vllm/v1/attention/backends/pallas.py

@@ -5,12 +5,6 @@ from dataclasses import dataclass
 from typing import Optional

 import torch
-import torch_xla.core.xla_builder as xb
-import torch_xla.experimental.custom_kernel  # noqa: F401
-# Required to register custom ops.
-from torch.library import impl
-from torch_xla._internal.jax_workarounds import requires_jax
-from torch_xla.experimental.custom_kernel import XLA_LIB

 from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
                                               AttentionLayer, AttentionType)
@@ -37,6 +31,57 @@ TPU_STR_DTYPE_TO_TORCH_DTYPE = {
     "uint8": torch.uint8,
 }

+try:
+    import tpu_commons  # noqa: F401
+except ImportError:
+    # Lazy import torch_xla
+    import torch_xla.core.xla_builder as xb
+    import torch_xla.experimental.custom_kernel  # noqa: F401
+    from torch.library import impl
+    from torch_xla._internal.jax_workarounds import requires_jax
+    from torch_xla.experimental.custom_kernel import XLA_LIB
+
+    @requires_jax
+    def kv_cache_update_op_impl(kv: torch.Tensor, slot_mapping: torch.Tensor,
+                                kv_cache: torch.Tensor,
+                                num_kv_update_slices: torch.Tensor,
+                                page_size: int, num_slices_per_block: int):
+        from vllm.attention.ops.pallas_kv_cache_update import kv_cache_update
+        new_kv_cache = xb.call_jax(
+            kv_cache_update,
+            (kv, slot_mapping, kv_cache, num_kv_update_slices), {
+                "page_size": page_size,
+                "num_slices_per_block": num_slices_per_block
+            })
+        return new_kv_cache
+
+    XLA_LIB.define(
+        "kv_cache_update_op(Tensor kv, Tensor slot_mapping," \
+        "Tensor kv_cache, Tensor num_kv_update_slices, int page_size," \
+        "int num_slices_per_block)" \
+        "-> Tensor", )
+
+    @impl(XLA_LIB, "kv_cache_update_op", "XLA")
+    def kv_cache_update_op_xla(kv: torch.Tensor, slot_mapping: torch.Tensor,
+                               kv_cache: torch.Tensor,
+                               num_kv_update_slices: torch.Tensor,
+                               page_size: int,
+                               num_slices_per_block: int) -> torch.Tensor:
+        new_kv_cache = kv_cache_update_op_impl(kv, slot_mapping, kv_cache,
+                                               num_kv_update_slices, page_size,
+                                               num_slices_per_block)
+        return new_kv_cache
+
+    @impl(XLA_LIB, "kv_cache_update_op", "CompositeExplicitAutograd")
+    def kv_cache_update_op_non_xla(kv: torch.Tensor,
+                                   slot_mapping: torch.Tensor,
+                                   kv_cache: torch.Tensor,
+                                   num_kv_update_slices: torch.Tensor,
+                                   page_size: int,
+                                   num_slices_per_block: int) -> torch.Tensor:
+        return kv_cache
+

 class PallasAttentionBackend(AttentionBackend):
@@ -313,46 +358,6 @@ def write_to_kv_cache(
     kv_cache.copy_(new_kv_cache)


-@requires_jax
-def kv_cache_update_op_impl(kv: torch.Tensor, slot_mapping: torch.Tensor,
-                            kv_cache: torch.Tensor,
-                            num_kv_update_slices: torch.Tensor, page_size: int,
-                            num_slices_per_block: int):
-    from vllm.attention.ops.pallas_kv_cache_update import kv_cache_update
-    new_kv_cache = xb.call_jax(
-        kv_cache_update, (kv, slot_mapping, kv_cache, num_kv_update_slices), {
-            "page_size": page_size,
-            "num_slices_per_block": num_slices_per_block
-        })
-    return new_kv_cache
-
-
-XLA_LIB.define(
-    "kv_cache_update_op(Tensor kv, Tensor slot_mapping, Tensor kv_cache," \
-    "Tensor num_kv_update_slices, int page_size, int num_slices_per_block)" \
-    "-> Tensor", )
-
-
-@impl(XLA_LIB, "kv_cache_update_op", "XLA")
-def kv_cache_update_op_xla(kv: torch.Tensor, slot_mapping: torch.Tensor,
-                           kv_cache: torch.Tensor,
-                           num_kv_update_slices: torch.Tensor, page_size: int,
-                           num_slices_per_block: int) -> torch.Tensor:
-    new_kv_cache = kv_cache_update_op_impl(kv, slot_mapping, kv_cache,
-                                           num_kv_update_slices, page_size,
-                                           num_slices_per_block)
-    return new_kv_cache
-
-
-@impl(XLA_LIB, "kv_cache_update_op", "CompositeExplicitAutograd")
-def kv_cache_update_op_non_xla(kv: torch.Tensor, slot_mapping: torch.Tensor,
-                               kv_cache: torch.Tensor,
-                               num_kv_update_slices: torch.Tensor,
-                               page_size: int,
-                               num_slices_per_block: int) -> torch.Tensor:
-    return kv_cache
-
-
 # We can move this function to a common utils file if it's also useful for other
 # hardware.
 def dtype_bits(dtype: torch.dtype):
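
The relocated block above registers `kv_cache_update_op` with torch_xla's `XLA_LIB` only when tpu_commons is absent; the `CompositeExplicitAutograd` variant is a shape-preserving no-op so tracing off the XLA device still works. Assuming `XLA_LIB` is torch_xla's "xla" op library, a call site would dispatch through the torch.ops namespace roughly like this (a hypothetical helper, not the file's own `write_to_kv_cache`):

    import torch


    def update_kv_cache(kv: torch.Tensor, slot_mapping: torch.Tensor,
                        kv_cache: torch.Tensor, num_kv_update_slices: torch.Tensor,
                        page_size: int, num_slices_per_block: int) -> torch.Tensor:
        # Dispatches to kv_cache_update_op_xla on XLA tensors and to the
        # no-op CompositeExplicitAutograd implementation elsewhere.
        return torch.ops.xla.kv_cache_update_op(kv, slot_mapping, kv_cache,
                                                num_kv_update_slices, page_size,
                                                num_slices_per_block)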

vllm/v1/worker/tpu_worker.py

@@ -1,15 +1,13 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 """A TPU worker class."""
 import os
 from typing import Any, Optional

 import torch
 import torch.distributed
 import torch.nn as nn
-import torch_xla.core.xla_model as xm
-import torch_xla.debug.profiler as xp
-import torch_xla.runtime as xr

 import vllm.envs as envs
 from vllm.config import VllmConfig
@@ -21,19 +19,27 @@ from vllm.logger import init_logger
 from vllm.lora.request import LoRARequest
 from vllm.model_executor import set_random_seed
 from vllm.platforms import current_platform
+from vllm.platforms.tpu import USE_TPU_COMMONS
 from vllm.tasks import SupportedTask
 from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, cdiv
-from vllm.v1.attention.backends.pallas import TPU_HEAD_SIZE_ALIGNMENT
 from vllm.v1.core.sched.output import SchedulerOutput
 from vllm.v1.kv_cache_interface import (AttentionSpec, KVCacheConfig,
                                         KVCacheSpec)
 from vllm.v1.outputs import ModelRunnerOutput
 from vllm.v1.utils import report_usage_stats
-from vllm.v1.worker.tpu_model_runner import TPUModelRunner
 from vllm.v1.worker.utils import bind_kv_cache

 logger = init_logger(__name__)

+if not USE_TPU_COMMONS:
+    logger.info("tpu_commons not found, using vLLM's TPUWorker.")
+    import torch_xla.core.xla_model as xm
+    import torch_xla.debug.profiler as xp
+    import torch_xla.runtime as xr
+
+    from vllm.v1.attention.backends.pallas import TPU_HEAD_SIZE_ALIGNMENT
+    from vllm.v1.worker.tpu_model_runner import TPUModelRunner
+

 class TPUWorker:
@@ -325,9 +331,7 @@ class TPUWorker:
         ensure_kv_transfer_initialized(vllm_config)


-try:
+if USE_TPU_COMMONS:
     from tpu_commons.worker import TPUWorker as TPUCommonsWorker
     TPUWorker = TPUCommonsWorker  # type: ignore
-except ImportError:
-    logger.info("tpu_commons not found, using vLLM's TPUWorker.")
-    pass