[Hardware][TPU] Optionally import for TPU backend (#18269)
Signed-off-by: Siyuan Liu <lsiyuan@google.com>
Signed-off-by: Jade Zheng <zheng.shoujian@outlook.com>
Co-authored-by: Carol Zheng <cazheng@google.com>
Co-authored-by: Jade Zheng <zheng.shoujian@outlook.com>
Co-authored-by: Hongmin Fan <fanhongmin@google.com>
This commit is contained in:
parent
3e0d435027
commit
48ac2bed5b
@@ -91,3 +91,12 @@ class TpuCommunicator(DeviceCommunicatorBase):
     def all_gather(self, input_: torch.Tensor, dim: int = -1) -> torch.Tensor:
         assert dim == -1, "TPUs only support dim=-1 for all-gather."
         return xm.all_gather(input_, dim=dim)
+
+
+try:
+    from tpu_commons.distributed.device_communicators import (
+        TpuCommunicator as TpuCommonsCommunicator)
+    TpuCommunicator = TpuCommonsCommunicator  # type: ignore
+except ImportError:
+    logger.info("tpu_commons not found, using vLLM's TpuCommunicator")
+    pass
@@ -194,3 +194,11 @@ class TpuPlatform(Platform):
             if params.sampling_type == SamplingType.RANDOM_SEED:
                 raise ValueError(
                     "Torch XLA does not support per-request seed.")
+
+
+try:
+    from tpu_commons.platforms import TpuPlatform as TpuCommonsPlatform
+    TpuPlatform = TpuCommonsPlatform  # type: ignore
+except ImportError:
+    logger.info("tpu_commons not found, using vLLM's TpuPlatform")
+    pass
@@ -267,3 +267,11 @@ def init_tpu_worker_distributed_environment(
     ensure_model_parallel_initialized(parallel_config.tensor_parallel_size,
                                       parallel_config.pipeline_parallel_size,
                                       parallel_config.enable_expert_parallel)
+
+
+try:
+    from tpu_commons.worker import TPUWorker as TPUCommonsWorker
+    TPUWorker = TPUCommonsWorker  # type: ignore
+except ImportError:
+    logger.info("tpu_commons not found, using vLLM's TPUWorker.")
+    pass
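The three hunks above apply the same optional-import override: try to import a drop-in replacement from the out-of-tree tpu_commons package and rebind the public name to it; if the package is not installed, log an informational message and keep vLLM's in-tree implementation. Below is a minimal, self-contained sketch of that pattern; the `accelerated` package and `FastAdder` class are hypothetical stand-ins for illustration only, not part of vLLM or tpu_commons.

import logging

logger = logging.getLogger(__name__)


class Adder:
    """Default in-tree implementation."""

    def add(self, a: int, b: int) -> int:
        return a + b


# Optional override: if the (hypothetical) 'accelerated' package is
# installed, rebind the public name to its implementation; otherwise
# fall back to the default defined above and log a message.
try:
    from accelerated import FastAdder  # hypothetical optional dependency
    Adder = FastAdder  # type: ignore
except ImportError:
    logger.info("accelerated not found, using the default Adder")

Because the rebinding happens at module import time, any caller that does `from this_module import Adder` transparently gets whichever implementation is active, which is why the diff can leave all existing call sites untouched.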