[TPU] make ptxla not imported when using tpu_commons (#23081)

Signed-off-by: Chengji Yao <chengjiyao@gmail.com>
Signed-off-by: Chengji Yao <chengjiyao@google.com>
Co-authored-by: Chengji Yao <chengjiyao@gmail.com>
This commit is contained in:
Chengji Yao 2025-08-18 20:46:42 -07:00 committed by GitHub
parent a4454e9401
commit e9d6a3db69
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
6 changed files with 94 additions and 78 deletions

View File

@ -10,6 +10,7 @@ from torch.distributed import ProcessGroup
from vllm.config import get_current_vllm_config
from vllm.logger import init_logger
from vllm.platforms import current_platform
from vllm.platforms.tpu import USE_TPU_COMMONS
from .base_device_communicator import DeviceCommunicatorBase
@ -18,16 +19,17 @@ USE_RAY = parallel_config = get_current_vllm_config(
logger = init_logger(__name__)
if current_platform.is_tpu():
import torch_xla
import torch_xla.core.xla_model as xm
import torch_xla.runtime as xr
from torch_xla._internal import pjrt
from torch_xla.distributed.xla_multiprocessing import (
create_optimized_replica_groups)
if USE_RAY:
from vllm.executor import ray_utils
if not USE_TPU_COMMONS:
logger.info("tpu_commons not found, using vLLM's TpuCommunicator")
if current_platform.is_tpu():
import torch_xla
import torch_xla.core.xla_model as xm
import torch_xla.runtime as xr
from torch_xla._internal import pjrt
from torch_xla.distributed.xla_multiprocessing import (
create_optimized_replica_groups)
if USE_RAY:
from vllm.executor import ray_utils
class TpuCommunicator(DeviceCommunicatorBase):
@ -94,10 +96,7 @@ class TpuCommunicator(DeviceCommunicatorBase):
return xm.all_gather(input_, dim=dim)
try:
if USE_TPU_COMMONS:
from tpu_commons.distributed.device_communicators import (
TpuCommunicator as TpuCommonsCommunicator)
TpuCommunicator = TpuCommonsCommunicator # type: ignore
except ImportError:
logger.info("tpu_commons not found, using vLLM's TpuCommunicator")
pass

View File

@ -3,7 +3,6 @@
import torch
import torch.nn.functional as F
import torch_xla.experimental.custom_kernel # noqa: F401
def _histogram(input: torch.Tensor, min: int, max: int) -> torch.Tensor:
@ -41,6 +40,7 @@ def fused_moe(
gating_output: [*, num_experts]
"""
assert expert_map is None, "expert_map is not supported for pallas MoE."
import torch_xla.experimental.custom_kernel # noqa: F401
orig_shape = hidden_states.shape
hidden_size = hidden_states.shape[-1]
num_tokens = hidden_states.shape[:-1].numel()

View File

@ -207,16 +207,21 @@ class DefaultModelLoader(BaseModelLoader):
)
if current_platform.is_tpu():
# In PyTorch XLA, we should call `xm.mark_step` frequently so that
# not too many ops are accumulated in the XLA program.
import torch_xla.core.xla_model as xm
from vllm.platforms.tpu import USE_TPU_COMMONS
def _xla_weights_iterator(iterator: Generator):
for weights in iterator:
yield weights
xm.mark_step()
if not USE_TPU_COMMONS:
# In PyTorch XLA, we should call `xm.mark_step`
# requently so that not too many ops are accumulated
# in the XLA program. import torch_xla.core.xla_model
# as xm
import torch_xla.core.xla_model as xm
weights_iterator = _xla_weights_iterator(weights_iterator)
def _xla_weights_iterator(iterator: Generator):
for weights in iterator:
yield weights
xm.mark_step()
weights_iterator = _xla_weights_iterator(weights_iterator)
if self.counter_before_loading_weights == 0.0:
self.counter_before_loading_weights = time.perf_counter()

View File

@ -24,6 +24,8 @@ else:
logger = init_logger(__name__)
USE_TPU_COMMONS = False
class TpuPlatform(Platform):
_enum = PlatformEnum.TPU
@ -201,6 +203,7 @@ class TpuPlatform(Platform):
try:
from tpu_commons.platforms import TpuPlatform as TpuCommonsPlatform
TpuPlatform = TpuCommonsPlatform # type: ignore
USE_TPU_COMMONS = True
except ImportError:
logger.info("tpu_commons not found, using vLLM's TpuPlatform")
pass

View File

@ -5,12 +5,6 @@ from dataclasses import dataclass
from typing import Optional
import torch
import torch_xla.core.xla_builder as xb
import torch_xla.experimental.custom_kernel # noqa: F401
# Required to register custom ops.
from torch.library import impl
from torch_xla._internal.jax_workarounds import requires_jax
from torch_xla.experimental.custom_kernel import XLA_LIB
from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
AttentionLayer, AttentionType)
@ -37,6 +31,57 @@ TPU_STR_DTYPE_TO_TORCH_DTYPE = {
"uint8": torch.uint8,
}
try:
import tpu_commons # noqa: F401
except ImportError:
# Lazy import torch_xla
import torch_xla.core.xla_builder as xb
import torch_xla.experimental.custom_kernel # noqa: F401
from torch.library import impl
from torch_xla._internal.jax_workarounds import requires_jax
from torch_xla.experimental.custom_kernel import XLA_LIB
@requires_jax
def kv_cache_update_op_impl(kv: torch.Tensor, slot_mapping: torch.Tensor,
kv_cache: torch.Tensor,
num_kv_update_slices: torch.Tensor,
page_size: int, num_slices_per_block: int):
from vllm.attention.ops.pallas_kv_cache_update import kv_cache_update
new_kv_cache = xb.call_jax(
kv_cache_update,
(kv, slot_mapping, kv_cache, num_kv_update_slices), {
"page_size": page_size,
"num_slices_per_block": num_slices_per_block
})
return new_kv_cache
XLA_LIB.define(
"kv_cache_update_op(Tensor kv, Tensor slot_mapping," \
"Tensor kv_cache, Tensor num_kv_update_slices, int page_size," \
"int num_slices_per_block)" \
"-> Tensor", )
@impl(XLA_LIB, "kv_cache_update_op", "XLA")
def kv_cache_update_op_xla(kv: torch.Tensor, slot_mapping: torch.Tensor,
kv_cache: torch.Tensor,
num_kv_update_slices: torch.Tensor,
page_size: int,
num_slices_per_block: int) -> torch.Tensor:
new_kv_cache = kv_cache_update_op_impl(kv, slot_mapping, kv_cache,
num_kv_update_slices, page_size,
num_slices_per_block)
return new_kv_cache
@impl(XLA_LIB, "kv_cache_update_op", "CompositeExplicitAutograd")
def kv_cache_update_op_non_xla(kv: torch.Tensor,
slot_mapping: torch.Tensor,
kv_cache: torch.Tensor,
num_kv_update_slices: torch.Tensor,
page_size: int,
num_slices_per_block: int) -> torch.Tensor:
return kv_cache
class PallasAttentionBackend(AttentionBackend):
@ -313,46 +358,6 @@ def write_to_kv_cache(
kv_cache.copy_(new_kv_cache)
@requires_jax
def kv_cache_update_op_impl(kv: torch.Tensor, slot_mapping: torch.Tensor,
kv_cache: torch.Tensor,
num_kv_update_slices: torch.Tensor, page_size: int,
num_slices_per_block: int):
from vllm.attention.ops.pallas_kv_cache_update import kv_cache_update
new_kv_cache = xb.call_jax(
kv_cache_update, (kv, slot_mapping, kv_cache, num_kv_update_slices), {
"page_size": page_size,
"num_slices_per_block": num_slices_per_block
})
return new_kv_cache
XLA_LIB.define(
"kv_cache_update_op(Tensor kv, Tensor slot_mapping, Tensor kv_cache," \
"Tensor num_kv_update_slices, int page_size, int num_slices_per_block)" \
"-> Tensor", )
@impl(XLA_LIB, "kv_cache_update_op", "XLA")
def kv_cache_update_op_xla(kv: torch.Tensor, slot_mapping: torch.Tensor,
kv_cache: torch.Tensor,
num_kv_update_slices: torch.Tensor, page_size: int,
num_slices_per_block: int) -> torch.Tensor:
new_kv_cache = kv_cache_update_op_impl(kv, slot_mapping, kv_cache,
num_kv_update_slices, page_size,
num_slices_per_block)
return new_kv_cache
@impl(XLA_LIB, "kv_cache_update_op", "CompositeExplicitAutograd")
def kv_cache_update_op_non_xla(kv: torch.Tensor, slot_mapping: torch.Tensor,
kv_cache: torch.Tensor,
num_kv_update_slices: torch.Tensor,
page_size: int,
num_slices_per_block: int) -> torch.Tensor:
return kv_cache
# We can move this function to a common utils file if it's also useful for other
# hardware.
def dtype_bits(dtype: torch.dtype):

View File

@ -1,15 +1,13 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""A TPU worker class."""
import os
from typing import Any, Optional
import torch
import torch.distributed
import torch.nn as nn
import torch_xla.core.xla_model as xm
import torch_xla.debug.profiler as xp
import torch_xla.runtime as xr
import vllm.envs as envs
from vllm.config import VllmConfig
@ -21,19 +19,27 @@ from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.model_executor import set_random_seed
from vllm.platforms import current_platform
from vllm.platforms.tpu import USE_TPU_COMMONS
from vllm.tasks import SupportedTask
from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, cdiv
from vllm.v1.attention.backends.pallas import TPU_HEAD_SIZE_ALIGNMENT
from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.kv_cache_interface import (AttentionSpec, KVCacheConfig,
KVCacheSpec)
from vllm.v1.outputs import ModelRunnerOutput
from vllm.v1.utils import report_usage_stats
from vllm.v1.worker.tpu_model_runner import TPUModelRunner
from vllm.v1.worker.utils import bind_kv_cache
logger = init_logger(__name__)
if not USE_TPU_COMMONS:
logger.info("tpu_commons not found, using vLLM's TPUWorker.")
import torch_xla.core.xla_model as xm
import torch_xla.debug.profiler as xp
import torch_xla.runtime as xr
from vllm.v1.attention.backends.pallas import TPU_HEAD_SIZE_ALIGNMENT
from vllm.v1.worker.tpu_model_runner import TPUModelRunner
class TPUWorker:
@ -325,9 +331,7 @@ class TPUWorker:
ensure_kv_transfer_initialized(vllm_config)
try:
if USE_TPU_COMMONS:
from tpu_commons.worker import TPUWorker as TPUCommonsWorker
TPUWorker = TPUCommonsWorker # type: ignore
except ImportError:
logger.info("tpu_commons not found, using vLLM's TPUWorker.")
pass