[V1] TPU - Tensor parallel MP support (#15059)

Author: Alexander Matveev
Date: 2025-03-19 20:55:18 -04:00 (committed by GitHub)
Commit: cfbca8a2f2 (parent: 0fe5609874)
2 changed files with 36 additions and 14 deletions

vllm/config.py

@@ -1473,7 +1473,7 @@ class ParallelConfig:
         os.environ["VLLM_ENABLE_V1_MULTIPROCESSING"] = "0"
         logger.info("Disabling V1 multiprocessing for external launcher.")
 
-        ray_only_devices = ["tpu"]
+        ray_only_devices: list[str] = []
         from vllm.platforms import current_platform
         if (current_platform.device_type in ray_only_devices
                 and self.world_size > 1):
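With "tpu" removed from ray_only_devices, a single-host TPU deployment no longer has to stand up Ray for tensor parallelism. A minimal usage sketch of what this enables (the model name and chip count are illustrative, not taken from this commit):

    # Hypothetical usage: tensor-parallel inference on one TPU host using the
    # multiprocessing (MP) executor instead of Ray. Model/TP size are examples.
    from vllm import LLM

    llm = LLM(
        model="meta-llama/Llama-3.1-8B-Instruct",  # assumed example model
        tensor_parallel_size=4,                    # one worker per local TPU chip
        distributed_executor_backend="mp",         # no Ray cluster required
    )
    print(llm.generate("Hello, TPU!")[0].outputs[0].text)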

vllm/distributed/device_communicators/tpu_communicator.py

@@ -6,15 +6,24 @@ from typing import Optional
 import torch
 from torch.distributed import ProcessGroup
 
+from vllm.config import get_current_vllm_config
+from vllm.logger import init_logger
 from vllm.platforms import current_platform
 
 from .base_device_communicator import DeviceCommunicatorBase
 
+USE_RAY = (get_current_vllm_config()
+           .parallel_config.distributed_executor_backend == "ray")
+
+logger = init_logger(__name__)
+
 if current_platform.is_tpu():
+    import torch_xla
     import torch_xla.core.xla_model as xm
     import torch_xla.runtime as xr
     from torch_xla._internal import pjrt
 
-    from vllm.executor import ray_utils
+    if USE_RAY:
+        from vllm.executor import ray_utils
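The USE_RAY flag is resolved once at import time from the current vLLM config, and the Ray helpers are imported only when that backend is actually selected. A simplified sketch of this gating pattern; "Config" here stands in for vLLM's parallel config and is not a real vLLM class:

    from dataclasses import dataclass
    from typing import Optional

    @dataclass
    class Config:
        distributed_executor_backend: Optional[str] = None  # "ray", "mp", or None

    def make_use_ray_flag(config: Config) -> bool:
        # Mirrors USE_RAY: only an explicit "ray" backend enables the Ray path.
        return config.distributed_executor_backend == "ray"

    USE_RAY = make_use_ray_flag(Config(distributed_executor_backend="mp"))

    if USE_RAY:
        # Heavy, optional dependency: imported only for Ray deployments.
        from vllm.executor import ray_utils  # noqa: F401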
@@ -33,6 +42,8 @@ class TpuCommunicator(DeviceCommunicatorBase):
         global_rank = self.global_rank
         global_world_size = self.global_world_size
 
+        if USE_RAY:
+            logger.info("TpuCommunicator initialized with RAY")
-        # Calculate how many TPU nodes are in the current deployment. This
-        # is the Ray placement group if it is deployed with Ray. Default
-        # to the number of TPU nodes in the Ray cluster. The number of TPU
+            # Calculate how many TPU nodes are in the current deployment. This
+            # is the Ray placement group if it is deployed with Ray. Default
+            # to the number of TPU nodes in the Ray cluster. The number of TPU

@@ -46,6 +57,17 @@ class TpuCommunicator(DeviceCommunicatorBase):
-        local_world_size = global_world_size // num_nodes
-        local_rank = global_rank % local_world_size
+            local_world_size = global_world_size // num_nodes
+            local_rank = global_rank % local_world_size
+        else:
+            logger.info("TpuCommunicator initialized with MP")
+
+            # Sanity check: verify that we are running on a single host.
+            num_hosts = torch_xla.tpu.num_tpu_workers()
+            assert num_hosts == 1
+
+            # Get the number of TPU chips available on this host.
+            local_world_size = torch_xla.tpu.num_available_chips()
+
+            # Get the current rank within this host.
+            local_rank = global_rank % local_world_size
 
         # Ensure environment variables are set for multihost deployments.
         # On GKE, this is needed for libtpu and TPU driver to know which TPU
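The two branches compute the same (local_world_size, local_rank) pair from different inputs: the Ray path divides the global world evenly across nodes, while the MP path assumes a single host and uses the local chip count directly. A small worked sketch of that arithmetic (plain Python, no TPU required; function names are illustrative):

    # Ray path: the global world is split evenly across nodes.
    def ray_local_coords(global_rank: int, global_world_size: int, num_nodes: int):
        local_world_size = global_world_size // num_nodes
        return local_world_size, global_rank % local_world_size

    # MP path: one host only (num_tpu_workers() == 1 is asserted above), so the
    # local world size is simply the number of chips on this machine.
    def mp_local_coords(global_rank: int, num_local_chips: int):
        return num_local_chips, global_rank % num_local_chips

    # Ray: 8 workers over 2 nodes -> 4 per node; global rank 6 is local rank 2.
    assert ray_local_coords(6, 8, 2) == (4, 2)
    # MP: 4 chips on one host; global rank 3 is local rank 3.
    assert mp_local_coords(3, 4) == (4, 3)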