[V1] TPU - Add tensor parallel support via Ray (#13618)

Signed-off-by: Alexander Matveev <amatveev@redhat.com>
Author: Alexander Matveev
Date:   2025-03-08 08:19:38 -05:00 (committed by GitHub)
Parent: 33f227e16b
Commit: cb8bdfade2
7 changed files with 80 additions and 4 deletions


@@ -42,6 +42,10 @@ def run_test(more_args=None):
             ), f"Expected: {EXPECTED_VALUE} | Measured: {measured_value}"
 
 
+# TODO: [AlexM] Fix it with new CI/CD tests
+TPU_TP_TEST_STR = ""  #"tensor_parallel_size=4"
+
+
 @pytest.mark.skipif(not current_platform.is_cuda()
                     and not current_platform.is_tpu(),
                     reason="V1 is currently only supported on CUDA and TPU")
@@ -56,6 +60,10 @@ def test_lm_eval_accuracy_v1_engine(monkeypatch):
         # Limit compilation time for TPU V1
         more_args = "max_num_seqs=64"
 
+    # Add TP test (if provided)
+    if TPU_TP_TEST_STR:
+        more_args += ",{}".format(TPU_TP_TEST_STR)
+
     run_test(more_args)
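For reference, once the TODO above is resolved and TPU_TP_TEST_STR is re-enabled, the concatenation in this hunk produces a single comma-separated argument string. A small illustration, not part of the diff:

# Illustration of the string building above, with the TP value re-enabled.
TPU_TP_TEST_STR = "tensor_parallel_size=4"
more_args = "max_num_seqs=64"
if TPU_TP_TEST_STR:
    more_args += ",{}".format(TPU_TP_TEST_STR)
assert more_args == "max_num_seqs=64,tensor_parallel_size=4"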

tests/v1/tpu/__init__.py (new, empty file)

tests/v1/tpu/test_basic.py (new file)

@@ -0,0 +1,54 @@
+# SPDX-License-Identifier: Apache-2.0
+"""A basic correctness check for TPUs
+
+Run `pytest tests/v1/tpu/test_basic.py`.
+"""
+import pytest
+
+from vllm.platforms import current_platform
+
+from ...conftest import VllmRunner
+
+MODELS = [
+    # "Qwen/Qwen2-7B-Instruct",
+    "meta-llama/Llama-3.1-8B",
+    # TODO: Add models here as necessary
+]
+
+TENSOR_PARALLEL_SIZES = [1]
+
+# TODO: Enable when CI/CD will have a multi-tpu instance
+# TENSOR_PARALLEL_SIZES = [1, 4]
+
+
+@pytest.mark.skipif(not current_platform.is_tpu(),
+                    reason="This is a basic test for TPU only")
+@pytest.mark.parametrize("model", MODELS)
+@pytest.mark.parametrize("max_tokens", [5])
+@pytest.mark.parametrize("enforce_eager", [True])
+@pytest.mark.parametrize("tensor_parallel_size", TENSOR_PARALLEL_SIZES)
+def test_models(
+    monkeypatch,
+    model: str,
+    max_tokens: int,
+    enforce_eager: bool,
+    tensor_parallel_size: int,
+) -> None:
+    prompt = "The next numbers of the sequence " + ", ".join(
+        str(i) for i in range(1024)) + " are:"
+    example_prompts = [prompt]
+
+    with monkeypatch.context() as m:
+        m.setenv("VLLM_USE_V1", "1")
+
+        with VllmRunner(
+                model,
+                max_model_len=8192,
+                enforce_eager=enforce_eager,
+                gpu_memory_utilization=0.7,
+                max_num_seqs=16,
+                tensor_parallel_size=tensor_parallel_size) as vllm_model:
+            vllm_outputs = vllm_model.generate_greedy(example_prompts,
+                                                      max_tokens)
+        output = vllm_outputs[0][1]
+        assert "1024" in output


@@ -73,9 +73,14 @@ class RayDistributedExecutor(DistributedExecutorBase):
     def _init_executor(self) -> None:
         self.forward_dag: Optional[ray.dag.CompiledDAG] = None
         if envs.VLLM_USE_V1:
-            # v1 always uses the compiled DAG and SPMD worker.
+            # V1 uses SPMD worker and compiled DAG
            os.environ["VLLM_USE_RAY_SPMD_WORKER"] = "1"
            os.environ["VLLM_USE_RAY_COMPILED_DAG"] = "1"
+
+            # For TPU, avoid compiling NVIDIA's NCCL
+            if current_platform.is_tpu():
+                os.environ["VLLM_USE_RAY_COMPILED_DAG_NCCL_CHANNEL"] = "0"
+
         # If the env var is set, it uses the Ray's compiled DAG API
         # which optimizes the control plane overhead.
         # Run vLLM with VLLM_USE_RAY_COMPILED_DAG=1 to enable it.
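For clarity, these are the toggles the V1 Ray executor applies at startup, written out explicitly. Setting the NCCL channel variable to "0" tells Ray's compiled DAG not to use its NCCL-based channel, which presumably leaves it on its default (non-NCCL) channel; a sketch grounded in the hunk above, not part of the diff:

import os

# What _init_executor() sets for V1 (from the hunk above):
os.environ["VLLM_USE_RAY_SPMD_WORKER"] = "1"    # SPMD-style Ray workers
os.environ["VLLM_USE_RAY_COMPILED_DAG"] = "1"   # Ray compiled DAG control plane
# On TPU there is no NCCL, so skip the NCCL channel and let the compiled DAG
# use its default transport instead.
os.environ["VLLM_USE_RAY_COMPILED_DAG_NCCL_CHANNEL"] = "0"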


@@ -11,6 +11,7 @@ import vllm.platforms
 from vllm.config import ParallelConfig
 from vllm.executor.msgspec_utils import decode_hook, encode_hook
 from vllm.logger import init_logger
+from vllm.platforms import current_platform
 from vllm.sequence import ExecuteModelRequest, IntermediateTensors
 from vllm.utils import get_ip
 from vllm.worker.worker_base import WorkerWrapperBase
@@ -106,10 +107,15 @@
             # on a background thread, so we need to reset torch's current
             # device.
             # We can remove this API after it is fixed in compiled graph.
-            import torch
             assert self.worker is not None, "Worker is not initialized"
             if not self.compiled_dag_cuda_device_set:
-                torch.cuda.set_device(self.worker.device)
+                if current_platform.is_tpu():
+                    # Not needed
+                    pass
+                else:
+                    import torch
+                    torch.cuda.set_device(self.worker.device)
+
                 self.compiled_dag_cuda_device_set = True
 
         def execute_model_ray(
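The TPU/CUDA branch above is a platform guard that could also be read as a small helper; a hypothetical sketch (maybe_set_cuda_device is illustrative and not part of vLLM):

from vllm.platforms import current_platform


def maybe_set_cuda_device(device) -> None:
    # Mirror of the guard above: torch.cuda.set_device() only makes sense on
    # CUDA-like platforms; on TPU, torch_xla owns device placement.
    if current_platform.is_tpu():
        return
    import torch
    torch.cuda.set_device(device)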


@@ -21,6 +21,7 @@ from vllm.model_executor.model_loader import get_model
 from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalKwargs
 from vllm.multimodal.utils import group_mm_inputs_by_modality
 from vllm.sampling_params import SamplingType
+from vllm.sequence import IntermediateTensors
 from vllm.utils import LayerBlockType, cdiv, is_pin_memory_available
 from vllm.v1.attention.backends.pallas import (NUM_KV_PAGES_PER_BLOCK,
                                                NUM_QUERIES_PER_BLOCK,
@@ -545,6 +546,7 @@ class TPUModelRunner:
     def execute_model(
         self,
         scheduler_output: "SchedulerOutput",
+        intermediate_tensors: Optional[IntermediateTensors] = None,
     ) -> ModelRunnerOutput:
         # Update cached state
         self._update_states(scheduler_output)
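The added intermediate_tensors keyword matches the GPU model runner's execute_model() signature, where intermediate tensors carry activations between pipeline-parallel stages; the diff only adds the parameter, so on TPU it appears to be accepted but unused. Illustrative call, not from the diff:

# Both calls are valid after this change; the keyword defaults to None.
output = runner.execute_model(scheduler_output)
output = runner.execute_model(scheduler_output, intermediate_tensors=None)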


@@ -96,7 +96,8 @@ class TPUWorker:
 
         # Set random seed.
         set_random_seed(self.model_config.seed)
-        xm.set_rng_state(self.model_config.seed, self.device)
+        if self.model_config.seed is not None:
+            xm.set_rng_state(self.model_config.seed, self.device)
 
         # Increase the cache size limit, which is the maximum number of
         # dynamo graphs that can be compiled.
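For context, the guard added above avoids calling xm.set_rng_state() with a None seed. A standalone sketch, assuming torch_xla is installed; the seed_tpu wrapper is illustrative, not vLLM API:

import torch_xla.core.xla_model as xm


def seed_tpu(seed, device) -> None:
    # Only reseed XLA's RNG when a seed was actually configured; None means
    # the engine was started without a fixed seed, so leave the RNG alone.
    if seed is not None:
        xm.set_rng_state(seed, device)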