[v1] torchrun compatibility (#13642)

Signed-off-by: youkaichao <youkaichao@gmail.com>
youkaichao 2025-02-23 22:47:24 +08:00 committed by GitHub
parent 9bebc9512f
commit eb24dc4a45
14 changed files with 67 additions and 24 deletions

View File

@@ -503,6 +503,7 @@ steps:
   - entrypoints/llm/test_collective_rpc.py
   commands:
   - pytest -v -s entrypoints/llm/test_collective_rpc.py
+  - VLLM_USE_V1=1 torchrun --nproc-per-node=2 distributed/test_torchrun_example.py
   - torchrun --nproc-per-node=2 distributed/test_torchrun_example.py
   - pytest -v -s ./compile/test_basic_correctness.py
   - pytest -v -s ./compile/test_wrapper.py
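
For reference, the new CI entry exercises the torchrun-compatible launch path. A minimal sketch of such a torchrun script follows, with the model name and prompt as illustrative placeholders rather than the test's actual contents:

# launch with: torchrun --nproc-per-node=2 this_script.py
from vllm import LLM, SamplingParams

# Every torchrun rank builds the same LLM; the external-launcher backend
# reuses the process group that torchrun already initialized.
llm = LLM(
    model="facebook/opt-125m",  # placeholder model
    tensor_parallel_size=2,
    distributed_executor_backend="external_launcher",
)
outputs = llm.generate(["Hello, my name is"],
                       SamplingParams(temperature=0.0, max_tokens=16))
for output in outputs:
    print(output.outputs[0].text)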

View File

@@ -48,6 +48,12 @@ test_consistent_across_ranks(
 test_consistent_across_ranks(
     llm.llm_engine.vllm_config.cache_config.num_gpu_blocks)

+# make sure we can access the model parameters from the calling process
+# of the `LLM` instance.
+params = list(llm.llm_engine.model_executor.driver_worker.worker.model_runner.
+              model.parameters())
+test_consistent_across_ranks(len(params))
+
 # all ranks should have the same outputs
 for output in outputs:
     prompt = output.prompt
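
The helper test_consistent_across_ranks is defined earlier in this test file and is not shown in the diff; a plausible sketch of such a check (an assumption, not the actual implementation) broadcasts rank 0's value and asserts every rank agrees:

import torch.distributed as dist

def consistent_across_ranks_sketch(value):
    # Hypothetical helper: rank 0 broadcasts its value as a Python object,
    # and every rank checks that its local value matches.
    container = [value if dist.get_rank() == 0 else None]
    dist.broadcast_object_list(container, src=0)
    assert container[0] == value, f"rank {dist.get_rank()} disagrees"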

View File

@@ -5,6 +5,7 @@ import threading
 import time
 import uuid
 from concurrent.futures import Future
+from typing import List

 import pytest
 from transformers import AutoTokenizer
@@ -211,8 +212,9 @@ def test_engine_core_concurrent_batches(monkeypatch):

     class DummyExecutor(UniProcExecutor):

-        def initialize(self, kv_cache_config: KVCacheConfig) -> None:
-            super().initialize(kv_cache_config)
+        def initialize_from_config(
+                self, kv_cache_configs: List[KVCacheConfig]) -> None:
+            super().initialize_from_config(kv_cache_configs)

             # This executor actually can only run 1 batch at a time
             self.semaphore = threading.Semaphore(1)

View File

@@ -1407,6 +1407,11 @@ class ParallelConfig:
         self.data_parallel_master_port = envs.VLLM_DP_MASTER_PORT
         self.world_size_across_dp = self.world_size * self.data_parallel_size

+        if self.distributed_executor_backend == "external_launcher":
+            import os
+            os.environ["VLLM_ENABLE_V1_MULTIPROCESSING"] = "0"
+            logger.info("Disabling V1 multiprocessing for external launcher.")
+
         ray_only_devices = ["tpu"]
         from vllm.platforms import current_platform
         if (current_platform.device_type in ray_only_devices
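
In effect, selecting the external_launcher backend keeps the V1 engine core in the launching process, so each torchrun rank drives its own worker directly. The environment variable the config change sets can also be exported explicitly before the engine is created; this snippet only illustrates that flag, it is not required after this commit:

import os

# Equivalent to what the config code above now does automatically when
# distributed_executor_backend="external_launcher"; must run before the
# engine reads the flag.
os.environ["VLLM_ENABLE_V1_MULTIPROCESSING"] = "0"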

View File

@@ -541,7 +541,7 @@ class RayDistributedExecutor(DistributedExecutorBase):
             # and the TP group executes in SPMD fashion.
             if self.use_v1:
                 outputs = [
-                    worker.execute_model.
+                    worker.execute_model_ray.
                     bind(  # type: ignore[attr-defined]
                         outputs[i]) for i, worker in enumerate(tp_group)
                 ]

View File

@@ -112,10 +112,12 @@ try:
                 torch.cuda.set_device(self.worker.device)
                 self.compiled_dag_cuda_device_set = True

-        def execute_model(
+        def execute_model_ray(
             self,
             scheduler_output: "SchedulerOutput",
         ) -> "ModelRunnerOutput":
+            # This method is used when compiling the Ray CG,
+            # and it needs the special logic of self.setup_device_if_necessary().
             self.setup_device_if_necessary()
             assert self.worker is not None, "Worker is not initialized"
             if isinstance(scheduler_output, tuple):

View File

@@ -93,9 +93,10 @@ class ExecutorWithExternalLauncher(UniProcExecutor):
             ("ExecutorWithExternalLauncher needs deterministic "
              "execution, so it"
              "does not support delay_factor in scheduling")
-        assert not envs.VLLM_USE_V1, \
-            ("V1 architecture cannot guarantee deterministic execution, "
-             "so it is not supported in ExecutorWithExternalLauncher.")
+        if envs.VLLM_USE_V1:
+            assert not envs.VLLM_ENABLE_V1_MULTIPROCESSING, \
+                ("To get deterministic execution in V1, "
+                 "please set VLLM_ENABLE_V1_MULTIPROCESSING=0")
         self.driver_worker = WorkerWrapperBase(vllm_config=self.vllm_config,
                                                rpc_rank=0)
         # engines are launched in torchrun-compatible launchers

View File

@@ -110,7 +110,7 @@ class EngineCore:
         num_cpu_blocks = 0

         # Initialize kv cache and warmup the execution
-        self.model_executor.initialize(kv_cache_configs)
+        self.model_executor.initialize_from_config(kv_cache_configs)

         elapsed = time.time() - start
         logger.info(("init engine (profile, create kv cache, "

View File

@@ -4,10 +4,10 @@ from typing import Dict, List, Mapping, Optional, Type, Union

 from typing_extensions import TypeVar

+import vllm.envs as envs
 from vllm.config import ParallelConfig, VllmConfig
 from vllm.engine.arg_utils import EngineArgs
 from vllm.engine.metrics_types import StatLoggerBase
-from vllm.envs import VLLM_ENABLE_V1_MULTIPROCESSING
 from vllm.inputs import INPUT_REGISTRY, InputRegistry, PromptType
 from vllm.logger import init_logger
 from vllm.lora.request import LoRARequest
@@ -44,6 +44,7 @@ class LLMEngine:
         use_cached_outputs: bool = False,
         multiprocess_mode: bool = False,
     ) -> None:
+        self.vllm_config = vllm_config
         self.model_config = vllm_config.model_config
         self.cache_config = vllm_config.cache_config
@@ -83,6 +84,10 @@
             log_stats=False,  # FIXME: implement
         )

+        if not multiprocess_mode:
+            # for v0 compatibility
+            self.model_executor = self.engine_core.engine_core.model_executor  # type: ignore
+
     @classmethod
     def from_engine_args(
         cls,
@@ -97,7 +102,7 @@
         vllm_config = engine_args.create_engine_config(usage_context)
         executor_class = Executor.get_class(vllm_config)

-        if VLLM_ENABLE_V1_MULTIPROCESSING:
+        if envs.VLLM_ENABLE_V1_MULTIPROCESSING:
             logger.debug("Enabling multiprocessing for LLMEngine.")
             enable_multiprocessing = True
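
Keeping a model_executor attribute when multiprocessing is off is what lets the torchrun test earlier in this commit reach the model weights from the calling process. A short sketch of that access pattern, mirroring the attribute path used in distributed/test_torchrun_example.py and assuming llm was built with the external launcher:

# Requires VLLM_ENABLE_V1_MULTIPROCESSING=0 so the executor is in-process.
executor = llm.llm_engine.model_executor
model = executor.driver_worker.worker.model_runner.model
num_params = len(list(model.parameters()))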

View File

@@ -3,6 +3,9 @@
 from concurrent.futures import Future
 from typing import List, Type, Union

+import torch
+import torch.distributed as dist
+
 from vllm.config import VllmConfig
 from vllm.executor.executor_base import ExecutorBase
 from vllm.executor.uniproc_executor import (  # noqa
@@ -49,12 +52,14 @@ class Executor(ExecutorBase):
                 f"{distributed_executor_backend}")
         return executor_class

-    def initialize(self, kv_cache_configs: List[KVCacheConfig]) -> None:
+    def initialize_from_config(self,
+                               kv_cache_configs: List[KVCacheConfig]) -> None:
         """
         Initialize the KV caches and begin the model execution loop of the
         underlying workers.
         """
-        self.collective_rpc("initialize_cache", args=(kv_cache_configs, ))
+        self.collective_rpc("initialize_from_config",
+                            args=(kv_cache_configs, ))
         self.collective_rpc("compile_or_warm_up_model")

     def determine_available_memory(self) -> int:  # in bytes
@@ -89,4 +94,13 @@ class UniProcExecutor(UniProcExecutorV0, Executor):

 class ExecutorWithExternalLauncher(ExecutorWithExternalLauncherV0, Executor):
-    pass
+
+    def determine_available_memory(self) -> int:  # in bytes
+        # same as determine_num_available_blocks in v0,
+        # we need to get the min across all ranks.
+        memory = super().determine_available_memory()
+        from vllm.distributed.parallel_state import get_world_group
+        cpu_group = get_world_group().cpu_group
+        memory_tensor = torch.tensor([memory], device="cpu", dtype=torch.int64)
+        dist.all_reduce(memory_tensor, group=cpu_group, op=dist.ReduceOp.MIN)
+        return memory_tensor.item()
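
The new determine_available_memory override applies a min-reduction so every rank sizes its KV cache to the most constrained device. The same pattern in isolation, with the process group passed in rather than taken from vLLM's world group:

import torch
import torch.distributed as dist

def min_available_memory_sketch(local_free_bytes: int, cpu_group) -> int:
    # Each rank contributes its locally measured free memory; taking the
    # minimum prevents any rank from over-allocating its KV cache.
    t = torch.tensor([local_free_bytes], device="cpu", dtype=torch.int64)
    dist.all_reduce(t, group=cpu_group, op=dist.ReduceOp.MIN)
    return int(t.item())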

View File

@@ -216,9 +216,10 @@ class WorkerProc:
             "local_rank": local_rank,
             "rank": rank,
             "distributed_init_method": distributed_init_method,
+            "is_driver_worker": rank == 0,
         }
         wrapper.init_worker(all_kwargs)
-        self.worker = wrapper.worker
+        self.worker = wrapper

         pid = os.getpid()
         _add_prefix(sys.stdout, f"VllmWorker rank={rank}", pid)
@@ -239,7 +240,7 @@
             ready_socket.send_string(WorkerProc.READY_STR)
             ready_socket.send(payload)

-        wrapper.init_device()
+        self.worker.init_device()
         self.worker.load_model()

     @staticmethod

View File

@@ -2,7 +2,7 @@
 """A GPU worker class."""
 import gc
 import os
-from typing import TYPE_CHECKING, List, Optional
+from typing import TYPE_CHECKING, Optional

 import torch
 import torch.distributed
@@ -185,9 +185,8 @@ class Worker(WorkerBase):
     def get_kv_cache_spec(self) -> KVCacheSpec:
         return self.model_runner.get_kv_cache_spec()

-    def initialize_cache(self, kv_cache_configs: List[KVCacheConfig]) -> None:
+    def initialize_from_config(self, kv_cache_config: KVCacheConfig) -> None:
         """Allocate GPU KV cache with the specified kv_cache_config."""
-        kv_cache_config = kv_cache_configs[self.rank]
         if self.vllm_config.model_config.enable_sleep_mode:
             allocator = CuMemAllocator.get_instance()
             context = allocator.use_memory_pool(tag="kv_cache")
@@ -225,7 +224,7 @@
         scheduler_output: "SchedulerOutput",
     ) -> Optional[ModelRunnerOutput]:
         output = self.model_runner.execute_model(scheduler_output)
-        return output if self.rank == 0 else None
+        return output if self.is_driver_worker else None

     def profile(self, is_start: bool = True):
         if self.profiler is None:

View File

@@ -36,6 +36,7 @@ class TPUWorker:
         distributed_init_method: str,
         is_driver_worker: bool = False,
     ):
+        self.is_driver_worker = is_driver_worker
         self.vllm_config = vllm_config
         self.model_config = vllm_config.model_config
         self.cache_config = vllm_config.cache_config
@@ -151,7 +152,7 @@
         scheduler_output: "SchedulerOutput",
     ) -> Optional[ModelRunnerOutput]:
         output = self.model_runner.execute_model(scheduler_output)
-        return output if self.rank == 0 else None
+        return output if self.is_driver_worker else None

     def load_model(self) -> None:
         self.model_runner.load_model()
@@ -170,9 +171,8 @@
     def get_kv_cache_spec(self) -> KVCacheSpec:
         return self.model_runner.get_kv_cache_spec()

-    def initialize_cache(self, kv_cache_configs: List[KVCacheConfig]) -> None:
+    def initialize_from_config(self, kv_cache_config: KVCacheConfig) -> None:
         """Allocate GPU KV cache with the specified kv_cache_config."""
-        kv_cache_config = kv_cache_configs[self.rank]
         self.model_runner.initialize_kv_cache(kv_cache_config)

     def check_health(self) -> None:

View File

@@ -567,6 +567,10 @@ class WorkerWrapperBase:
         self.worker = worker_class(**kwargs)
         assert self.worker is not None

+    def initialize_from_config(self, kv_cache_configs: List[Any]) -> None:
+        kv_cache_config = kv_cache_configs[self.rpc_rank]
+        self.worker.initialize_from_config(kv_cache_config)  # type: ignore
+
     def init_device(self):
         with set_current_vllm_config(self.vllm_config):
             # To make vLLM config available during device initialization
@@ -574,8 +578,11 @@
     def execute_method(self, method: Union[str, bytes], *args, **kwargs):
         try:
-            target = self if self.worker is None else self.worker
-            return run_method(target, method, args, kwargs)
+            # Method resolution order:
+            # if a method is defined in this class, it is called directly;
+            # otherwise, since we define `__getattr__` to redirect attribute
+            # lookups to `self.worker`, the method is called on the worker.
+            return run_method(self, method, args, kwargs)
         except Exception as e:
             # if the driver worker also execute methods,
             # exceptions in the rest worker may cause deadlock in rpc like ray
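
The rewritten execute_method relies on WorkerWrapperBase defining __getattr__ to forward unknown attribute lookups to self.worker. A minimal, vLLM-independent sketch of that delegation pattern:

class WorkerWrapperSketch:
    """Illustrative stand-in for a wrapper that delegates to a worker."""

    def __init__(self, worker):
        self.worker = worker

    def __getattr__(self, name):
        # Only invoked when normal attribute lookup fails, so methods defined
        # on the wrapper itself win; everything else goes to the wrapped worker.
        return getattr(self.worker, name)

    def initialize_from_config(self, configs):
        # Defined on the wrapper, so it intercepts the call and can pick out
        # this rank's config before delegating (index 0 in this single-process
        # sketch).
        return self.worker.initialize_from_config(configs[0])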