[core] allow callable in collective_rpc (#12151)

Signed-off-by: youkaichao <youkaichao@gmail.com>
youkaichao 2025-01-17 20:47:01 +08:00 committed by GitHub
parent d4e6194570
commit 87a0c076af
13 changed files with 147 additions and 50 deletions

View File

@ -107,7 +107,7 @@ steps:
source_file_dependencies:
- vllm/
commands:
- pytest -v -s entrypoints/llm --ignore=entrypoints/llm/test_lazy_outlines.py --ignore=entrypoints/llm/test_generate.py --ignore=entrypoints/llm/test_generate_multiple_loras.py --ignore=entrypoints/llm/test_guided_generate.py
- pytest -v -s entrypoints/llm --ignore=entrypoints/llm/test_lazy_outlines.py --ignore=entrypoints/llm/test_generate.py --ignore=entrypoints/llm/test_generate_multiple_loras.py --ignore=entrypoints/llm/test_guided_generate.py --ignore=entrypoints/llm/test_collective_rpc.py
- pytest -v -s entrypoints/llm/test_lazy_outlines.py # it needs a clean process
- pytest -v -s entrypoints/llm/test_generate.py # it needs a clean process
- pytest -v -s entrypoints/llm/test_generate_multiple_loras.py # it needs a clean process
@ -466,7 +466,9 @@ steps:
- vllm/worker/worker_base.py
- vllm/worker/worker.py
- vllm/worker/model_runner.py
- entrypoints/llm/test_collective_rpc.py
commands:
- pytest -v -s entrypoints/llm/test_collective_rpc.py
- torchrun --nproc-per-node=2 distributed/test_torchrun_example.py
- pytest -v -s ./compile/test_basic_correctness.py
- pytest -v -s ./compile/test_wrapper.py

View File

@ -1,6 +1,6 @@
import asyncio
import os
from typing import Any, Dict, List, Optional, Tuple
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
import pytest
@ -18,7 +18,7 @@ class Mock:
class CustomUniExecutor(UniProcExecutor):
def collective_rpc(self,
method: str,
method: Union[str, Callable],
timeout: Optional[float] = None,
args: Tuple = (),
kwargs: Optional[Dict] = None) -> List[Any]:

View File

@ -0,0 +1,36 @@
import pytest
from vllm import LLM
from ...utils import fork_new_process_for_each_test
@pytest.mark.parametrize("tp_size", [1, 2])
@pytest.mark.parametrize("backend", ["mp", "ray"])
@fork_new_process_for_each_test
def test_collective_rpc(tp_size, backend):
if tp_size == 1 and backend == "ray":
pytest.skip("Skip duplicate test case")
if tp_size == 1:
backend = None
# intentionally define the method and class in the test function,
# to test if they can be serialized and sent to the workers
def echo_rank(self):
return self.rank
from vllm.worker.worker import Worker
class MyWorker(Worker):
def echo_rank(self):
return self.rank
llm = LLM(model="meta-llama/Llama-3.2-1B-Instruct",
enforce_eager=True,
load_format="dummy",
tensor_parallel_size=tp_size,
distributed_executor_backend=backend,
worker_cls=MyWorker)
for method in ["echo_rank", echo_rank]:
assert llm.collective_rpc(method) == list(range(tp_size))

View File

@ -5,10 +5,10 @@ from collections import deque
from contextlib import contextmanager
from dataclasses import dataclass
from functools import partial
from typing import (TYPE_CHECKING, Callable, ClassVar, Deque, Dict, Iterable,
List, Mapping, NamedTuple, Optional)
from typing import (TYPE_CHECKING, Any, Callable, ClassVar, Deque, Dict,
Iterable, List, Mapping, NamedTuple, Optional)
from typing import Sequence as GenericSequence
from typing import Set, Type, Union, cast, overload
from typing import Set, Tuple, Type, Union, cast, overload
import torch
from typing_extensions import TypeVar, deprecated
@ -1816,6 +1816,17 @@ class LLMEngine:
def stop_profile(self) -> None:
self.model_executor.stop_profile()
def collective_rpc(self,
method: Union[str, Callable],
timeout: Optional[float] = None,
args: Tuple = (),
kwargs: Optional[Dict] = None) -> List[Any]:
"""
See LLM.collective_rpc for more details.
"""
return self.model_executor.collective_rpc(method, timeout, args,
kwargs)
def check_health(self) -> None:
if self.tokenizer:
self.tokenizer.check_health()

View File

@ -1,8 +1,8 @@
import itertools
import warnings
from contextlib import contextmanager
from typing import (Any, ClassVar, Dict, List, Optional, Sequence, Tuple, Type,
Union, cast, overload)
from typing import (Any, Callable, ClassVar, Dict, List, Optional, Sequence,
Tuple, Type, Union, cast, overload)
import cloudpickle
from tqdm import tqdm
@ -464,7 +464,7 @@ class LLM:
return self.engine_class.validate_outputs(outputs, RequestOutput)
def collective_rpc(self,
method: str,
method: Union[str, Callable],
timeout: Optional[float] = None,
args: Tuple = (),
kwargs: Optional[Dict] = None) -> List[Any]:
@ -476,9 +476,13 @@ class LLM:
Then, users can call the new methods through this API.
It is recommended to use this API to only pass control messages,
and set up data-plane communication to pass data.
The method can also be a callable, which will be serialized
and sent to all workers to execute.
If the method is a callable, it should accept `self` as its
first argument, in addition to the arguments passed in `args`
and `kwargs`; the `self` argument will be the worker object.
"""
return self.llm_engine.model_executor.collective_rpc(
method, timeout, args, kwargs)
return self.llm_engine.collective_rpc(method, timeout, args, kwargs)
def beam_search(
self,

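A minimal sketch of the callable contract described in the docstring above (the function name and `prefix` argument are hypothetical; `rank` is a worker attribute, as exercised by the new test):

# illustrative only: the callable receives the worker object as `self`,
# plus whatever is passed through `args`/`kwargs`
def echo_rank_with_prefix(self, prefix: str = "worker") -> str:
    return f"{prefix}-{self.rank}"

# results = llm.collective_rpc(echo_rank_with_prefix, kwargs={"prefix": "gpu"})
# -> one entry per rank, e.g. ["gpu-0", "gpu-1"] for tensor_parallel_size=2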
View File

@ -1,6 +1,7 @@
import asyncio
from abc import ABC, abstractmethod
from typing import Any, Awaitable, Dict, List, Optional, Set, Tuple, Union
from typing import (Any, Awaitable, Callable, Dict, List, Optional, Set, Tuple,
Union)
from vllm.config import VllmConfig
from vllm.logger import init_logger
@ -47,7 +48,7 @@ class ExecutorBase(ABC):
@abstractmethod
def collective_rpc(self,
method: str,
method: Union[str, Callable],
timeout: Optional[float] = None,
args: Tuple = (),
kwargs: Optional[Dict] = None) -> List[Any]:
@ -260,7 +261,7 @@ class DistributedExecutorBase(ExecutorBase):
raise NotImplementedError
def collective_rpc(self,
method: str,
method: Union[str, Callable],
timeout: Optional[float] = None,
args: Tuple = (),
kwargs: Optional[Dict] = None) -> List[Any]:
@ -269,7 +270,7 @@ class DistributedExecutorBase(ExecutorBase):
@abstractmethod
def _run_workers(
self,
method: str,
method: Union[str, Callable],
*args,
async_run_tensor_parallel_workers_only: bool = False,
max_concurrent_workers: Optional[int] = None,

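To show what the widened signature asks of concrete executors, here is a hedged sketch of a single-worker implementation (the class name is hypothetical; UniProcExecutor below does essentially this via the new run_method helper):

from typing import Any, Callable, Dict, List, Optional, Tuple, Union

from vllm.utils import run_method  # added in this change


class SingleWorkerExecutorSketch:  # hypothetical, not part of this diff
    def __init__(self, worker: Any) -> None:
        self.driver_worker = worker

    def collective_rpc(self,
                       method: Union[str, Callable],
                       timeout: Optional[float] = None,
                       args: Tuple = (),
                       kwargs: Optional[Dict] = None) -> List[Any]:
        # run_method handles strings, serialized bytes, and plain callables
        return [run_method(self.driver_worker, method, args, kwargs or {})]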
View File

@ -1,5 +1,7 @@
import asyncio
from typing import Any, List, Optional
from typing import Any, Callable, List, Optional, Union
import cloudpickle
from vllm.executor.executor_base import DistributedExecutorBase
from vllm.executor.multiproc_worker_utils import (
@ -9,7 +11,7 @@ from vllm.logger import init_logger
from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.sequence import ExecuteModelRequest
from vllm.utils import (_run_task_with_lock, get_distributed_init_method,
get_ip, get_open_port, make_async)
get_ip, get_open_port, make_async, run_method)
from vllm.worker.worker_base import WorkerWrapperBase
logger = init_logger(__name__)
@ -107,7 +109,7 @@ class MultiprocessingDistributedExecutor(DistributedExecutorBase):
def _run_workers(
self,
method: str,
method: Union[str, Callable],
*args,
async_run_tensor_parallel_workers_only: bool = False,
max_concurrent_workers: Optional[int] = None,
@ -121,6 +123,11 @@ class MultiprocessingDistributedExecutor(DistributedExecutorBase):
It will also be run asynchronously and return a list of futures
rather than blocking on the results.
"""
if isinstance(method, str):
sent_method = method
else:
sent_method = cloudpickle.dumps(method)
del method
if max_concurrent_workers:
raise NotImplementedError(
@ -129,18 +136,18 @@ class MultiprocessingDistributedExecutor(DistributedExecutorBase):
if async_run_tensor_parallel_workers_only:
# Run only non-driver workers and just return futures.
return [
worker.execute_method(method, *args, **kwargs)
worker.execute_method(sent_method, *args, **kwargs)
for worker in self.non_driver_workers
]
# Start all remote workers first.
worker_outputs = [
worker.execute_method(method, *args, **kwargs)
worker.execute_method(sent_method, *args, **kwargs)
for worker in self.workers
]
driver_worker_method = getattr(self.driver_worker, method)
driver_worker_output = driver_worker_method(*args, **kwargs)
driver_worker_output = run_method(self.driver_worker, sent_method,
args, kwargs)
# Get the results of the workers.
return [driver_worker_output

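The reason a function defined inside a test body survives the trip to a worker process is that cloudpickle serializes such callables by value, closures included. A standalone sketch of that round trip (plain cloudpickle, hypothetical names, no vLLM involved):

import cloudpickle

def make_example():
    offset = 10
    def echo_plus_offset(worker_like, value):
        # worker_like stands in for the worker object prepended by run_method
        return value + offset
    return echo_plus_offset

payload = cloudpickle.dumps(make_example())  # what the executor sends
restored = cloudpickle.loads(payload)        # what the worker receives
assert restored(object(), 5) == 15           # the closure's offset travels along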
View File

@ -15,7 +15,7 @@ import torch
from vllm.config import VllmConfig
from vllm.logger import init_logger
from vllm.triton_utils.importing import HAS_TRITON
from vllm.utils import _check_multiproc_method, get_mp_context
from vllm.utils import _check_multiproc_method, get_mp_context, run_method
if HAS_TRITON:
from vllm.triton_utils import maybe_set_triton_cache_manager
@ -169,7 +169,7 @@ class ProcessWorkerWrapper:
self.process.start()
def _enqueue_task(self, future: Union[ResultFuture, asyncio.Future],
method: str, args, kwargs):
method: Union[str, bytes], args, kwargs):
task_id = uuid.uuid4()
self.tasks[task_id] = future
try:
@ -180,12 +180,13 @@ class ProcessWorkerWrapper:
del self.tasks[task_id]
raise ChildProcessError("worker died") from e
def execute_method(self, method: str, *args, **kwargs):
def execute_method(self, method: Union[str, bytes], *args, **kwargs):
future: ResultFuture = ResultFuture()
self._enqueue_task(future, method, args, kwargs)
return future
async def execute_method_async(self, method: str, *args, **kwargs):
async def execute_method_async(self, method: Union[str, bytes], *args,
**kwargs):
future = asyncio.get_running_loop().create_future()
self._enqueue_task(future, method, args, kwargs)
return await future
@ -230,8 +231,7 @@ def _run_worker_process(
exception = None
task_id, method, args, kwargs = items
try:
executor = getattr(worker, method)
output = executor(*args, **kwargs)
output = run_method(worker, method, args, kwargs)
except SystemExit:
raise
except KeyboardInterrupt:

View File

@ -2,8 +2,9 @@ import asyncio
import os
from collections import defaultdict
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, Dict, List, Optional
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Union
import cloudpickle
import msgspec
import vllm.envs as envs
@ -410,7 +411,7 @@ class RayDistributedExecutor(DistributedExecutorBase):
def _run_workers(
self,
method: str,
method: Union[str, Callable],
*args,
async_run_tensor_parallel_workers_only: bool = False,
max_concurrent_workers: Optional[int] = None,
@ -426,6 +427,11 @@ class RayDistributedExecutor(DistributedExecutorBase):
rather than blocking on the results.
- args/kwargs: All workers share the same args/kwargs
"""
if isinstance(method, str):
sent_method = method
else:
sent_method = cloudpickle.dumps(method)
del method
if self.use_ray_spmd_worker:
assert not async_run_tensor_parallel_workers_only, (
"async_run_tensor_parallel_workers_only is not supported for "
@ -440,7 +446,7 @@ class RayDistributedExecutor(DistributedExecutorBase):
if async_run_tensor_parallel_workers_only:
ray_workers = self.non_driver_workers
ray_worker_outputs = [
worker.execute_method.remote(method, *args, **kwargs)
worker.execute_method.remote(sent_method, *args, **kwargs)
for worker in ray_workers
]
@ -455,7 +461,7 @@ class RayDistributedExecutor(DistributedExecutorBase):
if not self.use_ray_spmd_worker:
# Start the driver worker after all the ray workers.
driver_worker_output = [
self.driver_worker.execute_method(method, *args, **kwargs)
self.driver_worker.execute_method(sent_method, *args, **kwargs)
]
# Get the results of the ray workers.

View File

@ -1,5 +1,5 @@
import os
from typing import Any, Dict, List, Optional, Tuple
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
import torch
import torch.distributed as dist
@ -7,7 +7,8 @@ import torch.distributed as dist
import vllm.envs as envs
from vllm.executor.executor_base import ExecutorBase
from vllm.logger import init_logger
from vllm.utils import get_distributed_init_method, get_ip, get_open_port
from vllm.utils import (get_distributed_init_method, get_ip, get_open_port,
run_method)
from vllm.worker.worker_base import WorkerWrapperBase
logger = init_logger(__name__)
@ -39,18 +40,13 @@ class UniProcExecutor(ExecutorBase):
self.collective_rpc("load_model")
def collective_rpc(self,
method: str,
method: Union[str, Callable],
timeout: Optional[float] = None,
args: Tuple = (),
kwargs: Optional[Dict] = None) -> List[Any]:
if kwargs is None:
kwargs = {}
try:
func = getattr(self.driver_worker, method)
except AttributeError:
raise NotImplementedError(f"Method {method} is not implemented.") \
from None
answer = func(*args, **kwargs)
answer = run_method(self.driver_worker, method, args, kwargs)
return [answer]
def check_health(self) -> None:

View File

@ -36,6 +36,7 @@ from typing import (TYPE_CHECKING, Any, AsyncGenerator, Awaitable, Callable,
overload)
from uuid import uuid4
import cloudpickle
import numpy as np
import numpy.typing as npt
import psutil
@ -2166,3 +2167,25 @@ def bind_kv_cache(
assert len(forward_ctx.kv_cache) == len(kv_cache)
for ve, ve_kv_cache in enumerate(kv_cache):
forward_ctx.kv_cache[ve] = ve_kv_cache[kv_cache_idx]
def run_method(obj: Any, method: Union[str, bytes, Callable], args: Tuple[Any],
kwargs: Dict[str, Any]) -> Any:
"""
Run a method of an object with the given arguments and keyword arguments.
If the method is a string, it will be resolved to a method of the object
using getattr.
If the method is serialized bytes, it will be deserialized using
cloudpickle.
If the method is a callable, it will be called directly.
"""
if isinstance(method, bytes):
func = partial(cloudpickle.loads(method), obj)
elif isinstance(method, str):
try:
func = getattr(obj, method)
except AttributeError:
raise NotImplementedError(f"Method {method!r} is not"
" implemented.") from None
else:
func = partial(method, obj) # type: ignore
return func(*args, **kwargs)
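A small usage sketch of the three dispatch paths (DummyWorker is a hypothetical stand-in for a real worker; assumes this change is installed so run_method is importable from vllm.utils):

import cloudpickle

from vllm.utils import run_method


class DummyWorker:
    rank = 0

    def ping(self) -> str:
        return "pong"


worker = DummyWorker()
assert run_method(worker, "ping", (), {}) == "pong"             # str -> getattr
assert run_method(worker, lambda self: self.rank, (), {}) == 0  # callable, called with the worker as self
blob = cloudpickle.dumps(lambda self: self.rank + 1)
assert run_method(worker, blob, (), {}) == 1                    # bytes -> cloudpickle.loads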

View File

@ -6,9 +6,11 @@ import time
import weakref
from dataclasses import dataclass
from enum import Enum, auto
from functools import partial
from multiprocessing.process import BaseProcess
from typing import Any, Dict, List, Optional, Tuple
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
import cloudpickle
import psutil
import zmq
@ -120,7 +122,7 @@ class MultiprocExecutor(Executor):
return kv_cache_specs[0]
def collective_rpc(self,
method: str,
method: Union[str, Callable],
timeout: Optional[float] = None,
args: Tuple = (),
kwargs: Optional[Dict] = None) -> List[Any]:
@ -141,7 +143,12 @@ class MultiprocExecutor(Executor):
kwargs = kwargs or {}
try:
self.rpc_broadcast_mq.enqueue((method, args, kwargs))
if isinstance(method, str):
send_method = method
else:
send_method = cloudpickle.dumps(
method, protocol=pickle.HIGHEST_PROTOCOL)
self.rpc_broadcast_mq.enqueue((send_method, args, kwargs))
responses = [None] * self.world_size
for w in self.workers:
@ -408,7 +415,11 @@ class WorkerProc:
method, args, kwargs = self.rpc_broadcast_mq.dequeue()
try:
output = getattr(self.worker, method)(*args, **kwargs)
if isinstance(method, str):
func = getattr(self.worker, method)
elif isinstance(method, bytes):
func = partial(cloudpickle.loads(method), self.worker)
output = func(*args, **kwargs)
except Exception as e:
self.worker_response_mq.enqueue(
(WorkerProc.ResponseStatus.FAILURE, e))
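The same enqueue/dequeue split can be sketched with a plain multiprocessing queue (toy names, no vLLM): the parent serializes callables with cloudpickle before broadcasting, and the worker loop dispatches on the received type, mirroring the logic above.

import pickle
from functools import partial
from multiprocessing import Process, Queue

import cloudpickle


class ToyWorker:  # hypothetical stand-in for the real worker
    rank = 0


def worker_loop(task_q, result_q):
    # child side: dispatch on the received type, as in the dequeue loop above
    worker = ToyWorker()
    method, args, kwargs = task_q.get()
    if isinstance(method, str):
        func = getattr(worker, method)
    else:  # bytes: a callable serialized by the parent
        func = partial(cloudpickle.loads(method), worker)
    result_q.put(func(*args, **kwargs))


if __name__ == "__main__":
    task_q, result_q = Queue(), Queue()
    proc = Process(target=worker_loop, args=(task_q, result_q))
    proc.start()
    # parent side: serialize the callable before enqueueing, as above
    payload = cloudpickle.dumps(lambda self: self.rank + 1,
                                protocol=pickle.HIGHEST_PROTOCOL)
    task_q.put((payload, (), {}))
    print(result_q.get())  # -> 1
    proc.join()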

View File

@ -14,7 +14,8 @@ from vllm.lora.request import LoRARequest
from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.sequence import ExecuteModelRequest, IntermediateTensors
from vllm.utils import (enable_trace_function_call_for_thread,
resolve_obj_by_qualname, update_environment_variables)
resolve_obj_by_qualname, run_method,
update_environment_variables)
from vllm.worker.model_runner_base import (BroadcastableModelInput,
ModelRunnerBase,
ModelRunnerInputBase)
@ -539,17 +540,16 @@ class WorkerWrapperBase:
self.worker = worker_class(**kwargs)
assert self.worker is not None
def execute_method(self, method: str, *args, **kwargs):
def execute_method(self, method: Union[str, bytes], *args, **kwargs):
try:
target = self if self.worker is None else self.worker
executor = getattr(target, method)
return executor(*args, **kwargs)
return run_method(target, method, args, kwargs)
except Exception as e:
# if the driver worker also executes methods,
# exceptions in the other workers may cause deadlock in rpc frameworks like ray
# see https://github.com/vllm-project/vllm/issues/3455
# print the error and inform the user how to solve it
msg = (f"Error executing method {method}. "
msg = (f"Error executing method {method!r}. "
"This might cause deadlock in distributed execution.")
logger.exception(msg)
raise e