[V1] TPU - Add tensor parallel support via Ray (#13618)
Signed-off-by: Alexander Matveev <amatveev@redhat.com>
This commit is contained in:
parent 33f227e16b
commit cb8bdfade2
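What this commit enables, shown as a hedged usage sketch (not part of the diff): starting the V1 engine on a multi-chip TPU host with tensor parallelism routed through the Ray backend. The model name and the 4-chip count are assumptions borrowed from the new test file added below; the LLM keyword arguments are standard EngineArgs fields.

import os

os.environ["VLLM_USE_V1"] = "1"  # opt in to the V1 engine

from vllm import LLM, SamplingParams

# Assumed setup: a single host with 4 TPU chips and the Llama-3.1-8B
# checkpoint used in tests/v1/tpu/test_basic.py.
llm = LLM(
    model="meta-llama/Llama-3.1-8B",
    tensor_parallel_size=4,              # shard weights across the 4 chips
    distributed_executor_backend="ray",  # TP on TPU runs through the Ray executor
    max_model_len=8192,
    max_num_seqs=16,
    enforce_eager=True,
)
outputs = llm.generate(["The next number after 1, 2, 3 is"],
                       SamplingParams(temperature=0.0, max_tokens=5))
print(outputs[0].outputs[0].text)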
@@ -42,6 +42,10 @@ def run_test(more_args=None):
     ), f"Expected: {EXPECTED_VALUE} | Measured: {measured_value}"


+# TODO: [AlexM] Fix it with new CI/CD tests
+TPU_TP_TEST_STR = "" #"tensor_parallel_size=4"
+
+
 @pytest.mark.skipif(not current_platform.is_cuda()
                     and not current_platform.is_tpu(),
                     reason="V1 is currently only supported on CUDA and TPU")
@@ -56,6 +60,10 @@ def test_lm_eval_accuracy_v1_engine(monkeypatch):
         # Limit compilation time for TPU V1
         more_args = "max_num_seqs=64"

+    # Add TP test (if provided)
+    if TPU_TP_TEST_STR:
+        more_args += ",{}".format(TPU_TP_TEST_STR)
+
     run_test(more_args)
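For clarity, a small standalone sketch of the string that run_test() would receive once the TP test string is switched on; the values mirror the two hunks above and are illustrative only.

# Illustrative: what the accuracy test passes along once TP is enabled.
TPU_TP_TEST_STR = "tensor_parallel_size=4"  # currently "" in the diff above

more_args = "max_num_seqs=64"
# Add TP test (if provided)
if TPU_TP_TEST_STR:
    more_args += ",{}".format(TPU_TP_TEST_STR)

assert more_args == "max_num_seqs=64,tensor_parallel_size=4"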
tests/v1/tpu/__init__.py   (new file, 0 lines)
tests/v1/tpu/test_basic.py (new file, 54 lines)
@@ -0,0 +1,54 @@
+# SPDX-License-Identifier: Apache-2.0
+"""A basic correctness check for TPUs
+
+Run `pytest tests/v1/tpu/test_basic.py`.
+"""
+import pytest
+
+from vllm.platforms import current_platform
+
+from ...conftest import VllmRunner
+
+MODELS = [
+    # "Qwen/Qwen2-7B-Instruct",
+    "meta-llama/Llama-3.1-8B",
+    # TODO: Add models here as necessary
+]
+
+TENSOR_PARALLEL_SIZES = [1]
+
+# TODO: Enable when CI/CD will have a multi-tpu instance
+# TENSOR_PARALLEL_SIZES = [1, 4]
+
+
+@pytest.mark.skipif(not current_platform.is_tpu(),
+                    reason="This is a basic test for TPU only")
+@pytest.mark.parametrize("model", MODELS)
+@pytest.mark.parametrize("max_tokens", [5])
+@pytest.mark.parametrize("enforce_eager", [True])
+@pytest.mark.parametrize("tensor_parallel_size", TENSOR_PARALLEL_SIZES)
+def test_models(
+    monkeypatch,
+    model: str,
+    max_tokens: int,
+    enforce_eager: bool,
+    tensor_parallel_size: int,
+) -> None:
+    prompt = "The next numbers of the sequence " + ", ".join(
+        str(i) for i in range(1024)) + " are:"
+    example_prompts = [prompt]
+
+    with monkeypatch.context() as m:
+        m.setenv("VLLM_USE_V1", "1")
+
+        with VllmRunner(
+                model,
+                max_model_len=8192,
+                enforce_eager=enforce_eager,
+                gpu_memory_utilization=0.7,
+                max_num_seqs=16,
+                tensor_parallel_size=tensor_parallel_size) as vllm_model:
+            vllm_outputs = vllm_model.generate_greedy(example_prompts,
+                                                      max_tokens)
+        output = vllm_outputs[0][1]
+        assert "1024" in output
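A hedged convenience sketch, not part of the commit: running just this new file programmatically (equivalent to the `pytest tests/v1/tpu/test_basic.py` command in its docstring), e.g. from a TPU VM. To exercise the 4-chip case once multi-TPU CI exists, TENSOR_PARALLEL_SIZES would be flipped to [1, 4] as the TODO above notes.

import pytest

if __name__ == "__main__":
    # -q keeps output short; the skipif marker makes this a no-op off-TPU.
    raise SystemExit(pytest.main(["-q", "tests/v1/tpu/test_basic.py"]))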
@@ -73,9 +73,14 @@ class RayDistributedExecutor(DistributedExecutorBase):
     def _init_executor(self) -> None:
         self.forward_dag: Optional[ray.dag.CompiledDAG] = None
         if envs.VLLM_USE_V1:
-            # v1 always uses the compiled DAG and SPMD worker.
+            # V1 uses SPMD worker and compiled DAG
             os.environ["VLLM_USE_RAY_SPMD_WORKER"] = "1"
             os.environ["VLLM_USE_RAY_COMPILED_DAG"] = "1"
+
+            # For TPU, avoid compiling NVIDIA's NCCL
+            if current_platform.is_tpu():
+                os.environ["VLLM_USE_RAY_COMPILED_DAG_NCCL_CHANNEL"] = "0"
+
         # If the env var is set, it uses the Ray's compiled DAG API
         # which optimizes the control plane overhead.
         # Run vLLM with VLLM_USE_RAY_COMPILED_DAG=1 to enable it.
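The three environment variables touched above are ordinary vLLM env switches, so the same configuration can be reproduced manually before engine construction when experimenting. A hedged sketch of the values the V1 executor ends up with on TPU:

import os

# Values the hunk above applies automatically under V1:
os.environ["VLLM_USE_RAY_SPMD_WORKER"] = "1"   # drive all Ray workers in SPMD style
os.environ["VLLM_USE_RAY_COMPILED_DAG"] = "1"  # use Ray's compiled DAG control plane
# TPU only: skip NCCL-based channels, since NCCL is NVIDIA-specific.
os.environ["VLLM_USE_RAY_COMPILED_DAG_NCCL_CHANNEL"] = "0"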
@@ -11,6 +11,7 @@ import vllm.platforms
 from vllm.config import ParallelConfig
 from vllm.executor.msgspec_utils import decode_hook, encode_hook
 from vllm.logger import init_logger
+from vllm.platforms import current_platform
 from vllm.sequence import ExecuteModelRequest, IntermediateTensors
 from vllm.utils import get_ip
 from vllm.worker.worker_base import WorkerWrapperBase
@@ -106,10 +107,15 @@ try:
             # on a background thread, so we need to reset torch's current
             # device.
             # We can remove this API after it is fixed in compiled graph.
-            import torch
             assert self.worker is not None, "Worker is not initialized"
             if not self.compiled_dag_cuda_device_set:
-                torch.cuda.set_device(self.worker.device)
+                if current_platform.is_tpu():
+                    # Not needed
+                    pass
+                else:
+                    import torch
+                    torch.cuda.set_device(self.worker.device)
+
                 self.compiled_dag_cuda_device_set = True

         def execute_model_ray(
@@ -21,6 +21,7 @@ from vllm.model_executor.model_loader import get_model
 from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalKwargs
 from vllm.multimodal.utils import group_mm_inputs_by_modality
 from vllm.sampling_params import SamplingType
+from vllm.sequence import IntermediateTensors
 from vllm.utils import LayerBlockType, cdiv, is_pin_memory_available
 from vllm.v1.attention.backends.pallas import (NUM_KV_PAGES_PER_BLOCK,
                                                NUM_QUERIES_PER_BLOCK,
@@ -545,6 +546,7 @@ class TPUModelRunner:
     def execute_model(
         self,
         scheduler_output: "SchedulerOutput",
+        intermediate_tensors: Optional[IntermediateTensors] = None,
    ) -> ModelRunnerOutput:
        # Update cached state
        self._update_states(scheduler_output)
@@ -96,7 +96,8 @@ class TPUWorker:

        # Set random seed.
        set_random_seed(self.model_config.seed)
-        xm.set_rng_state(self.model_config.seed, self.device)
+        if self.model_config.seed is not None:
+            xm.set_rng_state(self.model_config.seed, self.device)

        # Increase the cache size limit, which is the maximum number of
        # dynamo graphs that can be compiled.