[TPU] make ptxla not imported when using tpu_commons (#23081)
Signed-off-by: Chengji Yao <chengjiyao@gmail.com>
Signed-off-by: Chengji Yao <chengjiyao@google.com>
Co-authored-by: Chengji Yao <chengjiyao@gmail.com>
commit e9d6a3db69 (parent a4454e9401)
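
The pattern applied throughout this commit: vllm.platforms.tpu probes for tpu_commons once and records the result in a module-level USE_TPU_COMMONS flag, and every other module gates its torch_xla ("ptxla") imports and fallback logging on that flag instead of re-probing with try/except ImportError. A minimal sketch of the idiom, collapsed into one file (it assumes an environment where torch_xla is installed whenever tpu_commons is not):

# Probe once, at platform-detection time.
USE_TPU_COMMONS = False
try:
    import tpu_commons  # noqa: F401
    USE_TPU_COMMONS = True
except ImportError:
    pass

# Everywhere else: gate the heavyweight import on the flag.
if not USE_TPU_COMMONS:
    # When tpu_commons is present this branch never runs,
    # so torch_xla is never imported.
    import torch_xla.core.xla_model as xm  # noqa: F401

print("using tpu_commons:", USE_TPU_COMMONS)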

vllm/distributed/device_communicators/tpu_communicator.py

@@ -10,6 +10,7 @@ from torch.distributed import ProcessGroup
 from vllm.config import get_current_vllm_config
 from vllm.logger import init_logger
 from vllm.platforms import current_platform
+from vllm.platforms.tpu import USE_TPU_COMMONS
 
 from .base_device_communicator import DeviceCommunicatorBase
 
@@ -18,6 +19,8 @@ USE_RAY = parallel_config = get_current_vllm_config(
 
 logger = init_logger(__name__)
 
+if not USE_TPU_COMMONS:
+    logger.info("tpu_commons not found, using vLLM's TpuCommunicator")
 if current_platform.is_tpu():
     import torch_xla
     import torch_xla.core.xla_model as xm
@@ -25,7 +28,6 @@ if current_platform.is_tpu():
     from torch_xla._internal import pjrt
     from torch_xla.distributed.xla_multiprocessing import (
         create_optimized_replica_groups)
-
 
 if USE_RAY:
     from vllm.executor import ray_utils
@@ -94,10 +96,7 @@ class TpuCommunicator(DeviceCommunicatorBase):
         return xm.all_gather(input_, dim=dim)
 
 
-try:
+if USE_TPU_COMMONS:
     from tpu_commons.distributed.device_communicators import (
         TpuCommunicator as TpuCommonsCommunicator)
     TpuCommunicator = TpuCommonsCommunicator  # type: ignore
-except ImportError:
-    logger.info("tpu_commons not found, using vLLM's TpuCommunicator")
-    pass
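
The module-bottom alias swap is unchanged; only the guard differs. Checking USE_TPU_COMMONS instead of catching ImportError means a genuine import failure inside tpu_commons now surfaces instead of being silently swallowed by the except clause. A self-contained sketch of the flag-gated substitution (CommunicatorSketch and alternate_backend are hypothetical):

USE_ALTERNATE = False  # in vLLM this is USE_TPU_COMMONS from vllm.platforms.tpu


class CommunicatorSketch:
    """Stand-in for the torch_xla-backed TpuCommunicator."""

    def all_gather(self, x):
        return x


if USE_ALTERNATE:
    # The flag already established that the package exists, so an
    # ImportError here would be a real bug, not a missing optional
    # dependency to ignore.
    from alternate_backend import Communicator as CommunicatorSketch  # type: ignore

print(CommunicatorSketch().all_gather([1, 2]))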

vllm/model_executor/layers/fused_moe/moe_pallas.py

@@ -3,7 +3,6 @@
 
 import torch
 import torch.nn.functional as F
-import torch_xla.experimental.custom_kernel  # noqa: F401
 
 
 def _histogram(input: torch.Tensor, min: int, max: int) -> torch.Tensor:
@@ -41,6 +40,7 @@ def fused_moe(
         gating_output: [*, num_experts]
     """
     assert expert_map is None, "expert_map is not supported for pallas MoE."
+    import torch_xla.experimental.custom_kernel  # noqa: F401
     orig_shape = hidden_states.shape
     hidden_size = hidden_states.shape[-1]
     num_tokens = hidden_states.shape[:-1].numel()
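
Moving the custom_kernel import from module scope into fused_moe defers loading torch_xla until the kernel is actually invoked. Python caches imported modules in sys.modules, so only the first call pays the import cost. A runnable sketch of the pattern (names hypothetical, with math standing in for the heavyweight dependency):

def fused_op(x: float) -> float:
    # Deferred import: loaded on first call, not at module import time;
    # subsequent calls resolve it from sys.modules.
    import math  # stand-in for torch_xla.experimental.custom_kernel
    return math.sqrt(x)


print(fused_op(4.0))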

vllm/model_executor/model_loader/default_loader.py

@@ -207,8 +207,11 @@ class DefaultModelLoader(BaseModelLoader):
             )
 
         if current_platform.is_tpu():
-            # In PyTorch XLA, we should call `xm.mark_step` frequently so that
-            # not too many ops are accumulated in the XLA program.
-            import torch_xla.core.xla_model as xm
+            from vllm.platforms.tpu import USE_TPU_COMMONS
+
+            if not USE_TPU_COMMONS:
+                # In PyTorch XLA, we should call `xm.mark_step` frequently so
+                # that not too many ops are accumulated in the XLA program.
+                import torch_xla.core.xla_model as xm
 
                 def _xla_weights_iterator(iterator: Generator):
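
The comment explains the rationale for _xla_weights_iterator: under the lazy PyTorch/XLA runtime, every weight-loading op is queued into one growing XLA program unless a sync point cuts it. A sketch of an iterator wrapper in that spirit (names hypothetical; flush stands in for xm.mark_step):

from typing import Callable, Iterable, Iterator, TypeVar

T = TypeVar("T")


def synced_iterator(items: Iterable[T], flush: Callable[[], None]) -> Iterator[T]:
    for item in items:
        yield item
        flush()  # cut the accumulated lazy-op graph after each weight


for name in synced_iterator(["w1", "w2"], flush=lambda: None):
    print(name)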

vllm/platforms/tpu.py

@@ -24,6 +24,8 @@ else:
 
 logger = init_logger(__name__)
 
+USE_TPU_COMMONS = False
+
 
 class TpuPlatform(Platform):
     _enum = PlatformEnum.TPU
@@ -201,6 +203,7 @@ class TpuPlatform(Platform):
 try:
     from tpu_commons.platforms import TpuPlatform as TpuCommonsPlatform
     TpuPlatform = TpuCommonsPlatform  # type: ignore
+    USE_TPU_COMMONS = True
 except ImportError:
     logger.info("tpu_commons not found, using vLLM's TpuPlatform")
     pass
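
This file is where the flag gets its value: it starts as False and flips to True only if the tpu_commons import succeeds. The try/except form is used because the probe also needs TpuCommonsPlatform; when only presence matters, importlib can detect a package without importing it at all. A lighter-weight variant (an alternative sketch, not what the commit does):

import importlib.util

# find_spec locates the package on sys.path without executing it.
USE_TPU_COMMONS = importlib.util.find_spec("tpu_commons") is not None
print("tpu_commons available:", USE_TPU_COMMONS)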

vllm/v1/attention/backends/pallas.py

@@ -5,12 +5,6 @@ from dataclasses import dataclass
 from typing import Optional
 
 import torch
-import torch_xla.core.xla_builder as xb
-import torch_xla.experimental.custom_kernel  # noqa: F401
-# Required to register custom ops.
-from torch.library import impl
-from torch_xla._internal.jax_workarounds import requires_jax
-from torch_xla.experimental.custom_kernel import XLA_LIB
 
 from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
                                               AttentionLayer, AttentionType)
@@ -37,6 +31,57 @@ TPU_STR_DTYPE_TO_TORCH_DTYPE = {
     "uint8": torch.uint8,
 }
 
+try:
+    import tpu_commons  # noqa: F401
+except ImportError:
+    # Lazy import torch_xla
+    import torch_xla.core.xla_builder as xb
+    import torch_xla.experimental.custom_kernel  # noqa: F401
+    from torch.library import impl
+    from torch_xla._internal.jax_workarounds import requires_jax
+    from torch_xla.experimental.custom_kernel import XLA_LIB
+
+    @requires_jax
+    def kv_cache_update_op_impl(kv: torch.Tensor, slot_mapping: torch.Tensor,
+                                kv_cache: torch.Tensor,
+                                num_kv_update_slices: torch.Tensor,
+                                page_size: int, num_slices_per_block: int):
+        from vllm.attention.ops.pallas_kv_cache_update import kv_cache_update
+        new_kv_cache = xb.call_jax(
+            kv_cache_update,
+            (kv, slot_mapping, kv_cache, num_kv_update_slices), {
+                "page_size": page_size,
+                "num_slices_per_block": num_slices_per_block
+            })
+        return new_kv_cache
+
+    XLA_LIB.define(
+        "kv_cache_update_op(Tensor kv, Tensor slot_mapping," \
+        "Tensor kv_cache, Tensor num_kv_update_slices, int page_size," \
+        "int num_slices_per_block)" \
+        "-> Tensor", )
+
+    @impl(XLA_LIB, "kv_cache_update_op", "XLA")
+    def kv_cache_update_op_xla(kv: torch.Tensor, slot_mapping: torch.Tensor,
+                               kv_cache: torch.Tensor,
+                               num_kv_update_slices: torch.Tensor,
+                               page_size: int,
+                               num_slices_per_block: int) -> torch.Tensor:
+        new_kv_cache = kv_cache_update_op_impl(kv, slot_mapping, kv_cache,
+                                               num_kv_update_slices, page_size,
+                                               num_slices_per_block)
+        return new_kv_cache
+
+    @impl(XLA_LIB, "kv_cache_update_op", "CompositeExplicitAutograd")
+    def kv_cache_update_op_non_xla(kv: torch.Tensor,
+                                   slot_mapping: torch.Tensor,
+                                   kv_cache: torch.Tensor,
+                                   num_kv_update_slices: torch.Tensor,
+                                   page_size: int,
+                                   num_slices_per_block: int) -> torch.Tensor:
+        return kv_cache
+
+
 class PallasAttentionBackend(AttentionBackend):
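
The except branch re-registers the Pallas KV-cache update op exactly as before, now lazily: XLA_LIB.define declares the schema, the "XLA" impl dispatches to the JAX-backed kernel, and the "CompositeExplicitAutograd" impl is a shape-preserving no-op so tracing on non-XLA backends still works. A self-contained sketch of the same define/impl pattern in a scratch namespace ("demo" is hypothetical; the real op is registered in torch_xla's XLA_LIB):

import torch
from torch.library import Library, impl

DEMO_LIB = Library("demo", "DEF")
DEMO_LIB.define("scale(Tensor x, float f) -> Tensor")


@impl(DEMO_LIB, "scale", "CompositeExplicitAutograd")
def scale_fallback(x: torch.Tensor, f: float) -> torch.Tensor:
    # Generic fallback, analogous to kv_cache_update_op_non_xla: it keeps
    # shape propagation working on backends with no specialized kernel.
    return x * f


print(torch.ops.demo.scale(torch.ones(2), 3.0))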
@@ -313,46 +358,6 @@ def write_to_kv_cache(
     kv_cache.copy_(new_kv_cache)
 
 
-@requires_jax
-def kv_cache_update_op_impl(kv: torch.Tensor, slot_mapping: torch.Tensor,
-                            kv_cache: torch.Tensor,
-                            num_kv_update_slices: torch.Tensor, page_size: int,
-                            num_slices_per_block: int):
-    from vllm.attention.ops.pallas_kv_cache_update import kv_cache_update
-    new_kv_cache = xb.call_jax(
-        kv_cache_update, (kv, slot_mapping, kv_cache, num_kv_update_slices), {
-            "page_size": page_size,
-            "num_slices_per_block": num_slices_per_block
-        })
-    return new_kv_cache
-
-
-XLA_LIB.define(
-    "kv_cache_update_op(Tensor kv, Tensor slot_mapping, Tensor kv_cache," \
-    "Tensor num_kv_update_slices, int page_size, int num_slices_per_block)" \
-    "-> Tensor", )
-
-
-@impl(XLA_LIB, "kv_cache_update_op", "XLA")
-def kv_cache_update_op_xla(kv: torch.Tensor, slot_mapping: torch.Tensor,
-                           kv_cache: torch.Tensor,
-                           num_kv_update_slices: torch.Tensor, page_size: int,
-                           num_slices_per_block: int) -> torch.Tensor:
-    new_kv_cache = kv_cache_update_op_impl(kv, slot_mapping, kv_cache,
-                                           num_kv_update_slices, page_size,
-                                           num_slices_per_block)
-    return new_kv_cache
-
-
-@impl(XLA_LIB, "kv_cache_update_op", "CompositeExplicitAutograd")
-def kv_cache_update_op_non_xla(kv: torch.Tensor, slot_mapping: torch.Tensor,
-                               kv_cache: torch.Tensor,
-                               num_kv_update_slices: torch.Tensor,
-                               page_size: int,
-                               num_slices_per_block: int) -> torch.Tensor:
-    return kv_cache
-
-
 # We can move this function to a common utils file if it's also useful for other
 # hardware.
 def dtype_bits(dtype: torch.dtype):

vllm/v1/worker/tpu_worker.py

@@ -1,15 +1,13 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 """A TPU worker class."""
 
 import os
 from typing import Any, Optional
 
 import torch
 import torch.distributed
 import torch.nn as nn
-import torch_xla.core.xla_model as xm
-import torch_xla.debug.profiler as xp
-import torch_xla.runtime as xr
 
 import vllm.envs as envs
 from vllm.config import VllmConfig
@@ -21,19 +19,27 @@ from vllm.logger import init_logger
 from vllm.lora.request import LoRARequest
 from vllm.model_executor import set_random_seed
 from vllm.platforms import current_platform
+from vllm.platforms.tpu import USE_TPU_COMMONS
 from vllm.tasks import SupportedTask
 from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, cdiv
-from vllm.v1.attention.backends.pallas import TPU_HEAD_SIZE_ALIGNMENT
 from vllm.v1.core.sched.output import SchedulerOutput
 from vllm.v1.kv_cache_interface import (AttentionSpec, KVCacheConfig,
                                         KVCacheSpec)
 from vllm.v1.outputs import ModelRunnerOutput
 from vllm.v1.utils import report_usage_stats
-from vllm.v1.worker.tpu_model_runner import TPUModelRunner
 from vllm.v1.worker.utils import bind_kv_cache
 
 logger = init_logger(__name__)
 
+if not USE_TPU_COMMONS:
+    logger.info("tpu_commons not found, using vLLM's TPUWorker.")
+    import torch_xla.core.xla_model as xm
+    import torch_xla.debug.profiler as xp
+    import torch_xla.runtime as xr
+
+    from vllm.v1.attention.backends.pallas import TPU_HEAD_SIZE_ALIGNMENT
+    from vllm.v1.worker.tpu_model_runner import TPUModelRunner
+
 
 class TPUWorker:
@@ -325,9 +331,7 @@ class TPUWorker:
         ensure_kv_transfer_initialized(vllm_config)
 
 
-try:
+if USE_TPU_COMMONS:
     from tpu_commons.worker import TPUWorker as TPUCommonsWorker
     TPUWorker = TPUCommonsWorker  # type: ignore
-except ImportError:
-    logger.info("tpu_commons not found, using vLLM's TPUWorker.")
-    pass
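
A quick way to check the commit's stated goal, i.e. that importing the worker no longer drags in ptxla when tpu_commons is present. This is a verification sketch and assumes an environment with tpu_commons installed:

import sys

import vllm.v1.worker.tpu_worker  # noqa: F401

assert "torch_xla" not in sys.modules, "torch_xla was imported eagerly"
print("ok: torch_xla stayed out of sys.modules at import time")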