[v1] torchrun compatibility (#13642)

Signed-off-by: youkaichao <youkaichao@gmail.com>
youkaichao 2025-02-23 22:47:24 +08:00 committed by GitHub
parent 9bebc9512f
commit eb24dc4a45
14 changed files with 67 additions and 24 deletions

View File

@@ -503,6 +503,7 @@ steps:
- entrypoints/llm/test_collective_rpc.py
commands:
- pytest -v -s entrypoints/llm/test_collective_rpc.py
- VLLM_USE_V1=1 torchrun --nproc-per-node=2 distributed/test_torchrun_example.py
- torchrun --nproc-per-node=2 distributed/test_torchrun_example.py
- pytest -v -s ./compile/test_basic_correctness.py
- pytest -v -s ./compile/test_wrapper.py
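
For context: the new pipeline entry runs the same torchrun example twice, once with VLLM_USE_V1=1 and once under the default V0 path. A minimal sketch of what such a torchrun-compatible script looks like, assuming the external_launcher backend that distributed/test_torchrun_example.py is built around (model name and prompt are illustrative):

from vllm import LLM, SamplingParams

# every rank runs this same script; torchrun supplies RANK/WORLD_SIZE
llm = LLM(
    model="facebook/opt-125m",  # illustrative model
    tensor_parallel_size=2,  # must match --nproc-per-node
    distributed_executor_backend="external_launcher",
)
outputs = llm.generate(["Hello, my name is"],
                       SamplingParams(temperature=0.0))

Launched as torchrun --nproc-per-node=2 script.py, each rank constructs the same LLM, holds one tensor-parallel shard, and (with greedy sampling) produces identical outputs.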

View File

@@ -48,6 +48,12 @@ test_consistent_across_ranks(
test_consistent_across_ranks(
llm.llm_engine.vllm_config.cache_config.num_gpu_blocks)
# make sure we can access the model parameters from the calling process
# of the `LLM` instance.
params = list(llm.llm_engine.model_executor.driver_worker.worker.model_runner.
model.parameters())
test_consistent_across_ranks(len(params))
# all ranks should have the same outputs
for output in outputs:
prompt = output.prompt
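
test_consistent_across_ranks is a helper defined in the test file itself; its implementation is not shown in this hunk, but a hedged sketch of such a check could look like this (all_gather_object-based, not necessarily vLLM's actual code):

import torch.distributed as dist

def test_consistent_across_ranks(obj):
    # collect the value from every rank and assert they all agree
    gathered = [None] * dist.get_world_size()
    dist.all_gather_object(gathered, obj)
    assert all(x == gathered[0] for x in gathered), \
        f"value differs across ranks: {gathered}"

The new assertions extend the existing consistency checks to the model parameters themselves, proving the calling process can reach the in-process model under V1.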

View File

@@ -5,6 +5,7 @@ import threading
import time
import uuid
from concurrent.futures import Future
from typing import List
import pytest
from transformers import AutoTokenizer
@@ -211,8 +212,9 @@ def test_engine_core_concurrent_batches(monkeypatch):
class DummyExecutor(UniProcExecutor):
def initialize(self, kv_cache_config: KVCacheConfig) -> None:
super().initialize(kv_cache_config)
def initialize_from_config(
self, kv_cache_configs: List[KVCacheConfig]) -> None:
super().initialize_from_config(kv_cache_configs)
# This executor can actually only run one batch at a time
self.semaphore = threading.Semaphore(1)

View File

@@ -1407,6 +1407,11 @@ class ParallelConfig:
self.data_parallel_master_port = envs.VLLM_DP_MASTER_PORT
self.world_size_across_dp = self.world_size * self.data_parallel_size
if self.distributed_executor_backend == "external_launcher":
import os
os.environ["VLLM_ENABLE_V1_MULTIPROCESSING"] = "0"
logger.info("Disabling V1 multiprocessing for external launcher.")
ray_only_devices = ["tpu"]
from vllm.platforms import current_platform
if (current_platform.device_type in ray_only_devices
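
The reasoning: under the external launcher, every rank is already its own process (spawned by torchrun), so V1's separate EngineCore process would be redundant and would break the expectation that the engine runs inside the calling process. A minimal check of this contract, assuming ParallelConfig is constructible with field defaults at this commit (hypothetical usage, not from the diff):

import os
from vllm.config import ParallelConfig

ParallelConfig(distributed_executor_backend="external_launcher")
assert os.environ["VLLM_ENABLE_V1_MULTIPROCESSING"] == "0"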

View File

@@ -541,7 +541,7 @@ class RayDistributedExecutor(DistributedExecutorBase):
# and the TP group executes in SPMD fashion.
if self.use_v1:
outputs = [
worker.execute_model.
worker.execute_model_ray.
bind( # type: ignore[attr-defined]
outputs[i]) for i, worker in enumerate(tp_group)
]

View File

@@ -112,10 +112,12 @@ try:
torch.cuda.set_device(self.worker.device)
self.compiled_dag_cuda_device_set = True
def execute_model(
def execute_model_ray(
self,
scheduler_output: "SchedulerOutput",
) -> "ModelRunnerOutput":
# this method is used to compile the Ray compiled graph (CG),
# and it needs the special logic in self.setup_device_if_necessary()
self.setup_device_if_necessary()
assert self.worker is not None, "Worker is not initialized"
if isinstance(scheduler_output, tuple):
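
The rename (here and in the Ray executor hunk above) gives the compiled-graph entry point its own name: unlike a plain execute_model call, it must run setup_device_if_necessary() first, and a distinct name presumably also avoids shadowing the worker's ordinary execute_model now that wrapper-level methods resolve first (see the WorkerWrapperBase change at the end of this diff).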

View File

@@ -93,9 +93,10 @@ class ExecutorWithExternalLauncher(UniProcExecutor):
("ExecutorWithExternalLauncher needs deterministic "
"execution, so it"
"does not support delay_factor in scheduling")
assert not envs.VLLM_USE_V1, \
("V1 architecture cannot guarantee deterministic execution, "
"so it is not supported in ExecutorWithExternalLauncher.")
if envs.VLLM_USE_V1:
assert not envs.VLLM_ENABLE_V1_MULTIPROCESSING, \
("To get deterministic execution in V1, "
"please set VLLM_ENABLE_V1_MULTIPROCESSING=0")
self.driver_worker = WorkerWrapperBase(vllm_config=self.vllm_config,
rpc_rank=0)
# engines are launched in torchrun-compatible launchers
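
The requirement thus relaxes from "no V1 at all" to "no V1 multiprocessing". In practice users do not have to set the variable by hand: the ParallelConfig change above exports VLLM_ENABLE_V1_MULTIPROCESSING=0 automatically when the external launcher is selected. Determinism matters here because each torchrun rank computes its own outputs and all ranks must agree.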

View File

@@ -110,7 +110,7 @@ class EngineCore:
num_cpu_blocks = 0
# Initialize kv cache and warmup the execution
self.model_executor.initialize(kv_cache_configs)
self.model_executor.initialize_from_config(kv_cache_configs)
elapsed = time.time() - start
logger.info(("init engine (profile, create kv cache, "

View File

@@ -4,10 +4,10 @@ from typing import Dict, List, Mapping, Optional, Type, Union
from typing_extensions import TypeVar
import vllm.envs as envs
from vllm.config import ParallelConfig, VllmConfig
from vllm.engine.arg_utils import EngineArgs
from vllm.engine.metrics_types import StatLoggerBase
from vllm.envs import VLLM_ENABLE_V1_MULTIPROCESSING
from vllm.inputs import INPUT_REGISTRY, InputRegistry, PromptType
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
@@ -44,6 +44,7 @@ class LLMEngine:
use_cached_outputs: bool = False,
multiprocess_mode: bool = False,
) -> None:
self.vllm_config = vllm_config
self.model_config = vllm_config.model_config
self.cache_config = vllm_config.cache_config
@@ -83,6 +84,10 @@
log_stats=False, # FIXME: implement
)
if not multiprocess_mode:
# for v0 compatibility
self.model_executor = self.engine_core.engine_core.model_executor # type: ignore
@classmethod
def from_engine_args(
cls,
@@ -97,7 +102,7 @@
vllm_config = engine_args.create_engine_config(usage_context)
executor_class = Executor.get_class(vllm_config)
if VLLM_ENABLE_V1_MULTIPROCESSING:
if envs.VLLM_ENABLE_V1_MULTIPROCESSING:
logger.debug("Enabling multiprocessing for LLMEngine.")
enable_multiprocessing = True
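
Re-exposing model_executor on the V1 LLMEngine restores the v0 attribute path that external-launcher scripts rely on: with multiprocessing disabled, the EngineCore lives in the calling process, so introspection like llm.llm_engine.model_executor.driver_worker.worker.model_runner.model.parameters() (asserted in the updated torchrun test above) keeps working.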

View File

@@ -3,6 +3,9 @@
from concurrent.futures import Future
from typing import List, Type, Union
import torch
import torch.distributed as dist
from vllm.config import VllmConfig
from vllm.executor.executor_base import ExecutorBase
from vllm.executor.uniproc_executor import ( # noqa
@@ -49,12 +52,14 @@ class Executor(ExecutorBase):
f"{distributed_executor_backend}")
return executor_class
def initialize(self, kv_cache_configs: List[KVCacheConfig]) -> None:
def initialize_from_config(self,
kv_cache_configs: List[KVCacheConfig]) -> None:
"""
Initialize the KV caches and begin the model execution loop of the
underlying workers.
"""
self.collective_rpc("initialize_cache", args=(kv_cache_configs, ))
self.collective_rpc("initialize_from_config",
args=(kv_cache_configs, ))
self.collective_rpc("compile_or_warm_up_model")
def determine_available_memory(self) -> int: # in bytes
@@ -89,4 +94,13 @@ class UniProcExecutor(UniProcExecutorV0, Executor):
class ExecutorWithExternalLauncher(ExecutorWithExternalLauncherV0, Executor):
pass
def determine_available_memory(self) -> int: # in bytes
# same as determine_num_available_blocks in v0,
# we need to get the min across all ranks.
memory = super().determine_available_memory()
from vllm.distributed.parallel_state import get_world_group
cpu_group = get_world_group().cpu_group
memory_tensor = torch.tensor([memory], device="cpu", dtype=torch.int64)
dist.all_reduce(memory_tensor, group=cpu_group, op=dist.ReduceOp.MIN)
return memory_tensor.item()
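
The MIN reduction is what lets ranks with unequal free memory agree on a single KV cache size: if rank 0 sees 30 GiB free but rank 1 only 28 GiB, both must size the cache for 28 GiB or rank 1 would over-allocate. A standalone illustration of the pattern, assuming an already-initialized process group whose CPU group runs on a gloo backend:

import torch
import torch.distributed as dist

def min_across_ranks(value: int, group=None) -> int:
    # every rank contributes its value; all ranks receive the minimum
    t = torch.tensor([value], device="cpu", dtype=torch.int64)
    dist.all_reduce(t, group=group, op=dist.ReduceOp.MIN)
    return int(t.item())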

View File

@@ -216,9 +216,10 @@ class WorkerProc:
"local_rank": local_rank,
"rank": rank,
"distributed_init_method": distributed_init_method,
"is_driver_worker": rank == 0,
}
wrapper.init_worker(all_kwargs)
self.worker = wrapper.worker
self.worker = wrapper
pid = os.getpid()
_add_prefix(sys.stdout, f"VllmWorker rank={rank}", pid)
@@ -239,7 +240,7 @@
ready_socket.send_string(WorkerProc.READY_STR)
ready_socket.send(payload)
wrapper.init_device()
self.worker.init_device()
self.worker.load_model()
@staticmethod
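
Storing the WorkerWrapperBase itself as self.worker (instead of unwrapping it) makes the multiprocessing path resolve methods the same way as the single-process path: wrapper-level helpers such as initialize_from_config run first, and everything else still falls through to the real worker via __getattr__. The new is_driver_worker: rank == 0 kwarg feeds the output gating shown in the worker hunks below.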

View File

@@ -2,7 +2,7 @@
"""A GPU worker class."""
import gc
import os
from typing import TYPE_CHECKING, List, Optional
from typing import TYPE_CHECKING, Optional
import torch
import torch.distributed
@@ -185,9 +185,8 @@ class Worker(WorkerBase):
def get_kv_cache_spec(self) -> KVCacheSpec:
return self.model_runner.get_kv_cache_spec()
def initialize_cache(self, kv_cache_configs: List[KVCacheConfig]) -> None:
def initialize_from_config(self, kv_cache_config: KVCacheConfig) -> None:
"""Allocate GPU KV cache with the specified kv_cache_config."""
kv_cache_config = kv_cache_configs[self.rank]
if self.vllm_config.model_config.enable_sleep_mode:
allocator = CuMemAllocator.get_instance()
context = allocator.use_memory_pool(tag="kv_cache")
@@ -225,7 +224,7 @@
scheduler_output: "SchedulerOutput",
) -> Optional[ModelRunnerOutput]:
output = self.model_runner.execute_model(scheduler_output)
return output if self.rank == 0 else None
return output if self.is_driver_worker else None
def profile(self, is_start: bool = True):
if self.profiler is None:
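
Gating on is_driver_worker instead of rank == 0 decouples "who returns results" from global rank. Under torchrun with the external launcher, every process presumably needs the output locally (the test above asserts that all ranks produce identical outputs), which a hard-coded rank-0 check cannot express; the same change is applied to the TPU worker below.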

View File

@@ -36,6 +36,7 @@ class TPUWorker:
distributed_init_method: str,
is_driver_worker: bool = False,
):
self.is_driver_worker = is_driver_worker
self.vllm_config = vllm_config
self.model_config = vllm_config.model_config
self.cache_config = vllm_config.cache_config
@@ -151,7 +152,7 @@
scheduler_output: "SchedulerOutput",
) -> Optional[ModelRunnerOutput]:
output = self.model_runner.execute_model(scheduler_output)
return output if self.rank == 0 else None
return output if self.is_driver_worker else None
def load_model(self) -> None:
self.model_runner.load_model()
@@ -170,9 +171,8 @@
def get_kv_cache_spec(self) -> KVCacheSpec:
return self.model_runner.get_kv_cache_spec()
def initialize_cache(self, kv_cache_configs: List[KVCacheConfig]) -> None:
def initialize_from_config(self, kv_cache_config: KVCacheConfig) -> None:
"""Allocate GPU KV cache with the specified kv_cache_config."""
kv_cache_config = kv_cache_configs[self.rank]
self.model_runner.initialize_kv_cache(kv_cache_config)
def check_health(self) -> None:

View File

@@ -567,6 +567,10 @@ class WorkerWrapperBase:
self.worker = worker_class(**kwargs)
assert self.worker is not None
def initialize_from_config(self, kv_cache_configs: List[Any]) -> None:
kv_cache_config = kv_cache_configs[self.rpc_rank]
self.worker.initialize_from_config(kv_cache_config) # type: ignore
def init_device(self):
with set_current_vllm_config(self.vllm_config):
# To make vLLM config available during device initialization
@@ -574,8 +578,11 @@
def execute_method(self, method: Union[str, bytes], *args, **kwargs):
try:
target = self if self.worker is None else self.worker
return run_method(target, method, args, kwargs)
# method resolution order:
# if a method is defined in this class, it will be called directly.
# otherwise, since we define `__getattr__` and redirect attribute
# queries to `self.worker`, the method will be called on the worker.
return run_method(self, method, args, kwargs)
except Exception as e:
# if the driver worker also executes methods,
# exceptions in the other workers may cause deadlock in rpc frameworks like ray
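
The resolution order described in this comment is easy to demonstrate in isolation. A self-contained sketch of the pattern (class and method names are illustrative, not vLLM's):

class _Worker:
    def load_model(self):
        return "resolved on the worker"

class _Wrapper:
    def __init__(self, worker):
        self.worker = worker

    def initialize_from_config(self, kv_cache_configs):
        # defined on the wrapper, so run_method(self, ...) finds it first
        return "resolved on the wrapper"

    def __getattr__(self, name):
        # only invoked when normal attribute lookup on the wrapper fails
        return getattr(self.worker, name)

w = _Wrapper(_Worker())
assert w.initialize_from_config([]) == "resolved on the wrapper"
assert w.load_model() == "resolved on the worker"  # delegated via __getattr__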