[TPU] make ptxla not imported when using tpu_commons (#23081)
Signed-off-by: Chengji Yao <chengjiyao@gmail.com>
Signed-off-by: Chengji Yao <chengjiyao@google.com>
Co-authored-by: Chengji Yao <chengjiyao@gmail.com>
parent a4454e9401
commit e9d6a3db69
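The change applies one pattern across the files below: probe once for the optional tpu_commons package, record the result in a module-level USE_TPU_COMMONS flag, and import torch_xla (PyTorch/XLA, the "ptxla" of the title) only on the fallback path. A minimal sketch of that pattern, with illustrative placement rather than the exact vLLM modules:

# Probe for the optional backend once and gate the heavy fallback import on it.
try:
    import tpu_commons  # noqa: F401
    USE_TPU_COMMONS = True
except ImportError:
    USE_TPU_COMMONS = False

if not USE_TPU_COMMONS:
    # torch_xla is only pulled in when tpu_commons cannot take over.
    import torch_xla.core.xla_model as xm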
@@ -10,6 +10,7 @@ from torch.distributed import ProcessGroup
 from vllm.config import get_current_vllm_config
 from vllm.logger import init_logger
 from vllm.platforms import current_platform
+from vllm.platforms.tpu import USE_TPU_COMMONS
 
 from .base_device_communicator import DeviceCommunicatorBase
@@ -18,16 +19,17 @@ USE_RAY = parallel_config = get_current_vllm_config(
 
 logger = init_logger(__name__)
 
-if current_platform.is_tpu():
-    import torch_xla
-    import torch_xla.core.xla_model as xm
-    import torch_xla.runtime as xr
-    from torch_xla._internal import pjrt
-    from torch_xla.distributed.xla_multiprocessing import (
-        create_optimized_replica_groups)
-
-if USE_RAY:
-    from vllm.executor import ray_utils
+if not USE_TPU_COMMONS:
+    logger.info("tpu_commons not found, using vLLM's TpuCommunicator")
+    if current_platform.is_tpu():
+        import torch_xla
+        import torch_xla.core.xla_model as xm
+        import torch_xla.runtime as xr
+        from torch_xla._internal import pjrt
+        from torch_xla.distributed.xla_multiprocessing import (
+            create_optimized_replica_groups)
+    if USE_RAY:
+        from vllm.executor import ray_utils
 
 
 class TpuCommunicator(DeviceCommunicatorBase):
@@ -94,10 +96,7 @@ class TpuCommunicator(DeviceCommunicatorBase):
         return xm.all_gather(input_, dim=dim)
 
 
-try:
+if USE_TPU_COMMONS:
     from tpu_commons.distributed.device_communicators import (
         TpuCommunicator as TpuCommonsCommunicator)
     TpuCommunicator = TpuCommonsCommunicator  # type: ignore
-except ImportError:
-    logger.info("tpu_commons not found, using vLLM's TpuCommunicator")
-    pass
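Beyond the import gating, the hunk above also changes how the public name is re-bound: TpuCommunicator is pointed at the tpu_commons class when USE_TPU_COMMONS is true, instead of re-probing for the package with a second try/except ImportError. Callers are unaffected because they import the name, not a concrete class. A hedged sketch of the rebinding idea with placeholder names (Communicator, backend_pkg are illustrative, not vLLM symbols):

# Module-level class rebinding behind a flag; names here are placeholders.
class Communicator:
    """Built-in fallback implementation."""

USE_BACKEND = False  # in the real code this comes from the platform probe

if USE_BACKEND:
    from backend_pkg import Communicator as BackendCommunicator  # hypothetical
    Communicator = BackendCommunicator  # importers now receive the backend class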
@@ -3,7 +3,6 @@
 
 import torch
 import torch.nn.functional as F
-import torch_xla.experimental.custom_kernel  # noqa: F401
 
 
 def _histogram(input: torch.Tensor, min: int, max: int) -> torch.Tensor:
@@ -41,6 +40,7 @@ def fused_moe(
         gating_output: [*, num_experts]
     """
     assert expert_map is None, "expert_map is not supported for pallas MoE."
+    import torch_xla.experimental.custom_kernel  # noqa: F401
     orig_shape = hidden_states.shape
     hidden_size = hidden_states.shape[-1]
     num_tokens = hidden_states.shape[:-1].numel()
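The fused_moe hunk moves an import that exists purely for its side effects (registering the Pallas custom kernels) from module scope into the function body, so importing the module no longer requires torch_xla at all. A short sketch of the deferred-import idea; run_pallas_moe is an illustrative stand-in, not the real vLLM function:

import torch


def run_pallas_moe(hidden_states: torch.Tensor) -> torch.Tensor:
    # The import executes on the first call; later calls hit the sys.modules
    # cache, so the recurring cost is a dictionary lookup, not a re-import.
    import torch_xla.experimental.custom_kernel  # noqa: F401
    return hidden_states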
@@ -207,16 +207,21 @@ class DefaultModelLoader(BaseModelLoader):
             )
 
         if current_platform.is_tpu():
-            # In PyTorch XLA, we should call `xm.mark_step` frequently so that
-            # not too many ops are accumulated in the XLA program.
-            import torch_xla.core.xla_model as xm
+            from vllm.platforms.tpu import USE_TPU_COMMONS
 
-            def _xla_weights_iterator(iterator: Generator):
-                for weights in iterator:
-                    yield weights
-                    xm.mark_step()
+            if not USE_TPU_COMMONS:
+                # In PyTorch XLA, we should call `xm.mark_step`
+                # frequently so that not too many ops are accumulated
+                # in the XLA program.
+                import torch_xla.core.xla_model as xm
 
-            weights_iterator = _xla_weights_iterator(weights_iterator)
+                def _xla_weights_iterator(iterator: Generator):
+                    for weights in iterator:
+                        yield weights
+                        xm.mark_step()
+
+                weights_iterator = _xla_weights_iterator(weights_iterator)
 
         if self.counter_before_loading_weights == 0.0:
             self.counter_before_loading_weights = time.perf_counter()
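The loader hunk keeps the _xla_weights_iterator wrapper but now defines it only on the non-tpu_commons path. The wrapper itself is a small generator that yields each weight tensor and then flushes the pending XLA graph. A generic sketch of the same idea; with_step_callback is an illustrative name:

from typing import Callable, Iterable, Iterator, TypeVar

T = TypeVar("T")


def with_step_callback(items: Iterable[T],
                       step: Callable[[], None]) -> Iterator[T]:
    # Yield every item, then run the callback, e.g. xm.mark_step() on TPU so
    # ops do not pile up into one huge XLA program.
    for item in items:
        yield item
        step()

# usage: weights_iterator = with_step_callback(weights_iterator, xm.mark_step)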
@@ -24,6 +24,8 @@ else:
 
 logger = init_logger(__name__)
 
+USE_TPU_COMMONS = False
+
 
 class TpuPlatform(Platform):
     _enum = PlatformEnum.TPU
@@ -201,6 +203,7 @@ class TpuPlatform(Platform):
 try:
     from tpu_commons.platforms import TpuPlatform as TpuCommonsPlatform
     TpuPlatform = TpuCommonsPlatform  # type: ignore
+    USE_TPU_COMMONS = True
 except ImportError:
     logger.info("tpu_commons not found, using vLLM's TpuPlatform")
     pass
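With these two hunks, vllm/platforms/tpu.py owns the flag: USE_TPU_COMMONS defaults to False and flips to True only if the tpu_commons platform import near the end of the module succeeds. That try/except runs when the module is first imported, so consumers that from-import the flag afterwards (as the other files in this diff do) see its final value. A consumer-side sketch mirroring those imports:

from vllm.platforms.tpu import USE_TPU_COMMONS

if not USE_TPU_COMMONS:
    # Fallback-only dependency, never imported when tpu_commons is installed.
    import torch_xla.core.xla_model as xm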
@@ -5,12 +5,6 @@ from dataclasses import dataclass
 from typing import Optional
 
 import torch
-import torch_xla.core.xla_builder as xb
-import torch_xla.experimental.custom_kernel  # noqa: F401
-# Required to register custom ops.
-from torch.library import impl
-from torch_xla._internal.jax_workarounds import requires_jax
-from torch_xla.experimental.custom_kernel import XLA_LIB
 
 from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
                                               AttentionLayer, AttentionType)
@@ -37,6 +31,57 @@ TPU_STR_DTYPE_TO_TORCH_DTYPE = {
     "uint8": torch.uint8,
 }
 
+try:
+    import tpu_commons  # noqa: F401
+except ImportError:
+    # Lazy import torch_xla
+    import torch_xla.core.xla_builder as xb
+    import torch_xla.experimental.custom_kernel  # noqa: F401
+    from torch.library import impl
+    from torch_xla._internal.jax_workarounds import requires_jax
+    from torch_xla.experimental.custom_kernel import XLA_LIB
+
+    @requires_jax
+    def kv_cache_update_op_impl(kv: torch.Tensor, slot_mapping: torch.Tensor,
+                                kv_cache: torch.Tensor,
+                                num_kv_update_slices: torch.Tensor,
+                                page_size: int, num_slices_per_block: int):
+        from vllm.attention.ops.pallas_kv_cache_update import kv_cache_update
+        new_kv_cache = xb.call_jax(
+            kv_cache_update,
+            (kv, slot_mapping, kv_cache, num_kv_update_slices), {
+                "page_size": page_size,
+                "num_slices_per_block": num_slices_per_block
+            })
+        return new_kv_cache
+
+    XLA_LIB.define(
+        "kv_cache_update_op(Tensor kv, Tensor slot_mapping," \
+        "Tensor kv_cache, Tensor num_kv_update_slices, int page_size," \
+        "int num_slices_per_block)" \
+        "-> Tensor", )
+
+    @impl(XLA_LIB, "kv_cache_update_op", "XLA")
+    def kv_cache_update_op_xla(kv: torch.Tensor, slot_mapping: torch.Tensor,
+                               kv_cache: torch.Tensor,
+                               num_kv_update_slices: torch.Tensor,
+                               page_size: int,
+                               num_slices_per_block: int) -> torch.Tensor:
+        new_kv_cache = kv_cache_update_op_impl(kv, slot_mapping, kv_cache,
+                                               num_kv_update_slices, page_size,
+                                               num_slices_per_block)
+        return new_kv_cache
+
+    @impl(XLA_LIB, "kv_cache_update_op", "CompositeExplicitAutograd")
+    def kv_cache_update_op_non_xla(kv: torch.Tensor,
+                                   slot_mapping: torch.Tensor,
+                                   kv_cache: torch.Tensor,
+                                   num_kv_update_slices: torch.Tensor,
+                                   page_size: int,
+                                   num_slices_per_block: int) -> torch.Tensor:
+        return kv_cache
+
+
 class PallasAttentionBackend(AttentionBackend):
@@ -313,46 +358,6 @@ def write_to_kv_cache(
     kv_cache.copy_(new_kv_cache)
 
 
-@requires_jax
-def kv_cache_update_op_impl(kv: torch.Tensor, slot_mapping: torch.Tensor,
-                            kv_cache: torch.Tensor,
-                            num_kv_update_slices: torch.Tensor, page_size: int,
-                            num_slices_per_block: int):
-    from vllm.attention.ops.pallas_kv_cache_update import kv_cache_update
-    new_kv_cache = xb.call_jax(
-        kv_cache_update, (kv, slot_mapping, kv_cache, num_kv_update_slices), {
-            "page_size": page_size,
-            "num_slices_per_block": num_slices_per_block
-        })
-    return new_kv_cache
-
-
-XLA_LIB.define(
-    "kv_cache_update_op(Tensor kv, Tensor slot_mapping, Tensor kv_cache," \
-    "Tensor num_kv_update_slices, int page_size, int num_slices_per_block)" \
-    "-> Tensor", )
-
-
-@impl(XLA_LIB, "kv_cache_update_op", "XLA")
-def kv_cache_update_op_xla(kv: torch.Tensor, slot_mapping: torch.Tensor,
-                           kv_cache: torch.Tensor,
-                           num_kv_update_slices: torch.Tensor, page_size: int,
-                           num_slices_per_block: int) -> torch.Tensor:
-    new_kv_cache = kv_cache_update_op_impl(kv, slot_mapping, kv_cache,
-                                           num_kv_update_slices, page_size,
-                                           num_slices_per_block)
-    return new_kv_cache
-
-
-@impl(XLA_LIB, "kv_cache_update_op", "CompositeExplicitAutograd")
-def kv_cache_update_op_non_xla(kv: torch.Tensor, slot_mapping: torch.Tensor,
-                               kv_cache: torch.Tensor,
-                               num_kv_update_slices: torch.Tensor,
-                               page_size: int,
-                               num_slices_per_block: int) -> torch.Tensor:
-    return kv_cache
-
-
 # We can move this function to a common utils file if it's also useful for other
 # hardware.
 def dtype_bits(dtype: torch.dtype):
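The two pallas.py hunks relocate the kv_cache_update_op registration under the except ImportError branch, so the torch.library registration against torch_xla's XLA_LIB happens only when tpu_commons is absent. The registration follows the usual torch.library flow: define a schema once, then attach one implementation per dispatch key, including a shape-preserving fallback for non-XLA execution. A self-contained sketch of that flow under a hypothetical namespace (demo_ops), not torch_xla's library:

import torch
from torch.library import Library, impl

_demo_lib = Library("demo_ops", "DEF")  # hypothetical namespace, not XLA_LIB
_demo_lib.define("copy_through(Tensor x) -> Tensor")


@impl(_demo_lib, "copy_through", "CompositeExplicitAutograd")
def copy_through_fallback(x: torch.Tensor) -> torch.Tensor:
    # Analogous to kv_cache_update_op_non_xla above: a pass-through used when
    # the op runs or is traced outside the specialized backend.
    return x

# Calls dispatch through torch.ops.<namespace>.<op>:
out = torch.ops.demo_ops.copy_through(torch.zeros(2))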
@@ -1,15 +1,13 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 """A TPU worker class."""
 
 import os
 from typing import Any, Optional
 
 import torch
 import torch.distributed
 import torch.nn as nn
-import torch_xla.core.xla_model as xm
-import torch_xla.debug.profiler as xp
-import torch_xla.runtime as xr
 
 import vllm.envs as envs
 from vllm.config import VllmConfig
@@ -21,19 +19,27 @@ from vllm.logger import init_logger
 from vllm.lora.request import LoRARequest
 from vllm.model_executor import set_random_seed
 from vllm.platforms import current_platform
+from vllm.platforms.tpu import USE_TPU_COMMONS
 from vllm.tasks import SupportedTask
 from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, cdiv
-from vllm.v1.attention.backends.pallas import TPU_HEAD_SIZE_ALIGNMENT
 from vllm.v1.core.sched.output import SchedulerOutput
 from vllm.v1.kv_cache_interface import (AttentionSpec, KVCacheConfig,
                                         KVCacheSpec)
 from vllm.v1.outputs import ModelRunnerOutput
 from vllm.v1.utils import report_usage_stats
-from vllm.v1.worker.tpu_model_runner import TPUModelRunner
 from vllm.v1.worker.utils import bind_kv_cache
 
 logger = init_logger(__name__)
 
+if not USE_TPU_COMMONS:
+    logger.info("tpu_commons not found, using vLLM's TPUWorker.")
+    import torch_xla.core.xla_model as xm
+    import torch_xla.debug.profiler as xp
+    import torch_xla.runtime as xr
+
+    from vllm.v1.attention.backends.pallas import TPU_HEAD_SIZE_ALIGNMENT
+    from vllm.v1.worker.tpu_model_runner import TPUModelRunner
+
 
 class TPUWorker:
@@ -325,9 +331,7 @@ class TPUWorker:
         ensure_kv_transfer_initialized(vllm_config)
 
 
-try:
+if USE_TPU_COMMONS:
     from tpu_commons.worker import TPUWorker as TPUCommonsWorker
 
     TPUWorker = TPUCommonsWorker  # type: ignore
-except ImportError:
-    logger.info("tpu_commons not found, using vLLM's TPUWorker.")
-    pass