mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-09 09:35:34 +08:00
[V1] TPU - Add tensor parallel support via Ray (#13618)
Signed-off-by: Alexander Matveev <amatveev@redhat.com>
This commit is contained in:
parent
33f227e16b
commit
cb8bdfade2
@ -42,6 +42,10 @@ def run_test(more_args=None):
|
||||
), f"Expected: {EXPECTED_VALUE} | Measured: {measured_value}"
|
||||
|
||||
|
||||
# TODO: [AlexM] Fix it with new CI/CD tests
|
||||
TPU_TP_TEST_STR = "" #"tensor_parallel_size=4"
|
||||
|
||||
|
||||
@pytest.mark.skipif(not current_platform.is_cuda()
|
||||
and not current_platform.is_tpu(),
|
||||
reason="V1 is currently only supported on CUDA and TPU")
|
||||
@ -56,6 +60,10 @@ def test_lm_eval_accuracy_v1_engine(monkeypatch):
|
||||
# Limit compilation time for TPU V1
|
||||
more_args = "max_num_seqs=64"
|
||||
|
||||
# Add TP test (if provided)
|
||||
if TPU_TP_TEST_STR:
|
||||
more_args += ",{}".format(TPU_TP_TEST_STR)
|
||||
|
||||
run_test(more_args)
|
||||
|
||||
|
||||
|
||||
0
tests/v1/tpu/__init__.py
Normal file
0
tests/v1/tpu/__init__.py
Normal file
54
tests/v1/tpu/test_basic.py
Normal file
54
tests/v1/tpu/test_basic.py
Normal file
@ -0,0 +1,54 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
"""A basic correctness check for TPUs
|
||||
|
||||
Run `pytest tests/v1/tpu/test_basic.py`.
|
||||
"""
|
||||
import pytest
|
||||
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
from ...conftest import VllmRunner
|
||||
|
||||
MODELS = [
|
||||
# "Qwen/Qwen2-7B-Instruct",
|
||||
"meta-llama/Llama-3.1-8B",
|
||||
# TODO: Add models here as necessary
|
||||
]
|
||||
|
||||
TENSOR_PARALLEL_SIZES = [1]
|
||||
|
||||
# TODO: Enable when CI/CD will have a multi-tpu instance
|
||||
# TENSOR_PARALLEL_SIZES = [1, 4]
|
||||
|
||||
|
||||
@pytest.mark.skipif(not current_platform.is_tpu(),
|
||||
reason="This is a basic test for TPU only")
|
||||
@pytest.mark.parametrize("model", MODELS)
|
||||
@pytest.mark.parametrize("max_tokens", [5])
|
||||
@pytest.mark.parametrize("enforce_eager", [True])
|
||||
@pytest.mark.parametrize("tensor_parallel_size", TENSOR_PARALLEL_SIZES)
|
||||
def test_models(
|
||||
monkeypatch,
|
||||
model: str,
|
||||
max_tokens: int,
|
||||
enforce_eager: bool,
|
||||
tensor_parallel_size: int,
|
||||
) -> None:
|
||||
prompt = "The next numbers of the sequence " + ", ".join(
|
||||
str(i) for i in range(1024)) + " are:"
|
||||
example_prompts = [prompt]
|
||||
|
||||
with monkeypatch.context() as m:
|
||||
m.setenv("VLLM_USE_V1", "1")
|
||||
|
||||
with VllmRunner(
|
||||
model,
|
||||
max_model_len=8192,
|
||||
enforce_eager=enforce_eager,
|
||||
gpu_memory_utilization=0.7,
|
||||
max_num_seqs=16,
|
||||
tensor_parallel_size=tensor_parallel_size) as vllm_model:
|
||||
vllm_outputs = vllm_model.generate_greedy(example_prompts,
|
||||
max_tokens)
|
||||
output = vllm_outputs[0][1]
|
||||
assert "1024" in output
|
||||
@ -73,9 +73,14 @@ class RayDistributedExecutor(DistributedExecutorBase):
|
||||
def _init_executor(self) -> None:
|
||||
self.forward_dag: Optional[ray.dag.CompiledDAG] = None
|
||||
if envs.VLLM_USE_V1:
|
||||
# v1 always uses the compiled DAG and SPMD worker.
|
||||
# V1 uses SPMD worker and compiled DAG
|
||||
os.environ["VLLM_USE_RAY_SPMD_WORKER"] = "1"
|
||||
os.environ["VLLM_USE_RAY_COMPILED_DAG"] = "1"
|
||||
|
||||
# For TPU, avoid compiling NVIDIA's NCCL
|
||||
if current_platform.is_tpu():
|
||||
os.environ["VLLM_USE_RAY_COMPILED_DAG_NCCL_CHANNEL"] = "0"
|
||||
|
||||
# If the env var is set, it uses the Ray's compiled DAG API
|
||||
# which optimizes the control plane overhead.
|
||||
# Run vLLM with VLLM_USE_RAY_COMPILED_DAG=1 to enable it.
|
||||
|
||||
@ -11,6 +11,7 @@ import vllm.platforms
|
||||
from vllm.config import ParallelConfig
|
||||
from vllm.executor.msgspec_utils import decode_hook, encode_hook
|
||||
from vllm.logger import init_logger
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.sequence import ExecuteModelRequest, IntermediateTensors
|
||||
from vllm.utils import get_ip
|
||||
from vllm.worker.worker_base import WorkerWrapperBase
|
||||
@ -106,10 +107,15 @@ try:
|
||||
# on a background thread, so we need to reset torch's current
|
||||
# device.
|
||||
# We can remove this API after it is fixed in compiled graph.
|
||||
import torch
|
||||
assert self.worker is not None, "Worker is not initialized"
|
||||
if not self.compiled_dag_cuda_device_set:
|
||||
torch.cuda.set_device(self.worker.device)
|
||||
if current_platform.is_tpu():
|
||||
# Not needed
|
||||
pass
|
||||
else:
|
||||
import torch
|
||||
torch.cuda.set_device(self.worker.device)
|
||||
|
||||
self.compiled_dag_cuda_device_set = True
|
||||
|
||||
def execute_model_ray(
|
||||
|
||||
@ -21,6 +21,7 @@ from vllm.model_executor.model_loader import get_model
|
||||
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalKwargs
|
||||
from vllm.multimodal.utils import group_mm_inputs_by_modality
|
||||
from vllm.sampling_params import SamplingType
|
||||
from vllm.sequence import IntermediateTensors
|
||||
from vllm.utils import LayerBlockType, cdiv, is_pin_memory_available
|
||||
from vllm.v1.attention.backends.pallas import (NUM_KV_PAGES_PER_BLOCK,
|
||||
NUM_QUERIES_PER_BLOCK,
|
||||
@ -545,6 +546,7 @@ class TPUModelRunner:
|
||||
def execute_model(
|
||||
self,
|
||||
scheduler_output: "SchedulerOutput",
|
||||
intermediate_tensors: Optional[IntermediateTensors] = None,
|
||||
) -> ModelRunnerOutput:
|
||||
# Update cached state
|
||||
self._update_states(scheduler_output)
|
||||
|
||||
@ -96,7 +96,8 @@ class TPUWorker:
|
||||
|
||||
# Set random seed.
|
||||
set_random_seed(self.model_config.seed)
|
||||
xm.set_rng_state(self.model_config.seed, self.device)
|
||||
if self.model_config.seed is not None:
|
||||
xm.set_rng_state(self.model_config.seed, self.device)
|
||||
|
||||
# Increase the cache size limit, which is the maximum number of
|
||||
# dynamo graphs that can be compiled.
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user