[core] allow callable in collective_rpc (#12151)

Signed-off-by: youkaichao <youkaichao@gmail.com>
Author: youkaichao
Date: 2025-01-17 20:47:01 +08:00
Committed by: GitHub
Parent: d4e6194570
Commit: 87a0c076af

13 changed files with 147 additions and 50 deletions

View File

@@ -107,7 +107,7 @@ steps:
   source_file_dependencies:
   - vllm/
   commands:
-  - pytest -v -s entrypoints/llm --ignore=entrypoints/llm/test_lazy_outlines.py --ignore=entrypoints/llm/test_generate.py --ignore=entrypoints/llm/test_generate_multiple_loras.py --ignore=entrypoints/llm/test_guided_generate.py
+  - pytest -v -s entrypoints/llm --ignore=entrypoints/llm/test_lazy_outlines.py --ignore=entrypoints/llm/test_generate.py --ignore=entrypoints/llm/test_generate_multiple_loras.py --ignore=entrypoints/llm/test_guided_generate.py --ignore=entrypoints/llm/test_collective_rpc.py
   - pytest -v -s entrypoints/llm/test_lazy_outlines.py # it needs a clean process
   - pytest -v -s entrypoints/llm/test_generate.py # it needs a clean process
   - pytest -v -s entrypoints/llm/test_generate_multiple_loras.py # it needs a clean process
@@ -466,7 +466,9 @@ steps:
   - vllm/worker/worker_base.py
   - vllm/worker/worker.py
   - vllm/worker/model_runner.py
+  - entrypoints/llm/test_collective_rpc.py
   commands:
+  - pytest -v -s entrypoints/llm/test_collective_rpc.py
   - torchrun --nproc-per-node=2 distributed/test_torchrun_example.py
   - pytest -v -s ./compile/test_basic_correctness.py
   - pytest -v -s ./compile/test_wrapper.py

View File

@@ -1,6 +1,6 @@
 import asyncio
 import os
-from typing import Any, Dict, List, Optional, Tuple
+from typing import Any, Callable, Dict, List, Optional, Tuple, Union

 import pytest
@@ -18,7 +18,7 @@ class Mock:
 class CustomUniExecutor(UniProcExecutor):

     def collective_rpc(self,
-                       method: str,
+                       method: Union[str, Callable],
                        timeout: Optional[float] = None,
                        args: Tuple = (),
                        kwargs: Optional[Dict] = None) -> List[Any]:
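
For reference, the pattern this test exercises — subclassing an executor and overriding collective_rpc with the widened signature — looks roughly like the sketch below. The class name and the rpc_calls attribute are illustrative, not part of the diff; only the UniProcExecutor base class and the signature come from the change itself.

from typing import Any, Callable, Dict, List, Optional, Tuple, Union

from vllm.executor.uniproc_executor import UniProcExecutor


class RecordingUniExecutor(UniProcExecutor):
    # class-level log, so the executor's __init__ does not need to change
    rpc_calls: List[Union[str, Callable]] = []

    def collective_rpc(self,
                       method: Union[str, Callable],
                       timeout: Optional[float] = None,
                       args: Tuple = (),
                       kwargs: Optional[Dict] = None) -> List[Any]:
        # record what was requested, then defer to the stock implementation
        self.rpc_calls.append(method)
        return super().collective_rpc(method, timeout, args, kwargs)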

View File

@@ -0,0 +1,36 @@
+import pytest
+
+from vllm import LLM
+
+from ...utils import fork_new_process_for_each_test
+
+
+@pytest.mark.parametrize("tp_size", [1, 2])
+@pytest.mark.parametrize("backend", ["mp", "ray"])
+@fork_new_process_for_each_test
+def test_collective_rpc(tp_size, backend):
+    if tp_size == 1 and backend == "ray":
+        pytest.skip("Skip duplicate test case")
+    if tp_size == 1:
+        backend = None
+
+    # intentionally define the method and class in the test function,
+    # to test if they can be serialized and sent to the workers
+    def echo_rank(self):
+        return self.rank
+
+    from vllm.worker.worker import Worker
+
+    class MyWorker(Worker):
+
+        def echo_rank(self):
+            return self.rank
+
+    llm = LLM(model="meta-llama/Llama-3.2-1B-Instruct",
+              enforce_eager=True,
+              load_format="dummy",
+              tensor_parallel_size=tp_size,
+              distributed_executor_backend=backend,
+              worker_cls=MyWorker)
+
+    for method in ["echo_rank", echo_rank]:
+        assert llm.collective_rpc(method) == list(range(tp_size))
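
The loop above covers the string and callable cases. Callables can also receive extra arguments via args and kwargs; the worker object is always bound as the first (self) parameter. A hypothetical follow-on using the same llm object (scaled_rank is illustrative and not in the diff):

def scaled_rank(self, factor, offset=0):
    # `self` is the worker; `factor` comes from args, `offset` from kwargs
    return self.rank * factor + offset

# with tp_size == 2 this would return [1, 11]
results = llm.collective_rpc(scaled_rank, args=(10, ), kwargs={"offset": 1})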

View File

@@ -5,10 +5,10 @@ from collections import deque
 from contextlib import contextmanager
 from dataclasses import dataclass
 from functools import partial
-from typing import (TYPE_CHECKING, Callable, ClassVar, Deque, Dict, Iterable,
-                    List, Mapping, NamedTuple, Optional)
+from typing import (TYPE_CHECKING, Any, Callable, ClassVar, Deque, Dict,
+                    Iterable, List, Mapping, NamedTuple, Optional)
 from typing import Sequence as GenericSequence
-from typing import Set, Type, Union, cast, overload
+from typing import Set, Tuple, Type, Union, cast, overload

 import torch
 from typing_extensions import TypeVar, deprecated
@@ -1816,6 +1816,17 @@ class LLMEngine:
     def stop_profile(self) -> None:
         self.model_executor.stop_profile()

+    def collective_rpc(self,
+                       method: Union[str, Callable],
+                       timeout: Optional[float] = None,
+                       args: Tuple = (),
+                       kwargs: Optional[Dict] = None) -> List[Any]:
+        """
+        See LLM.collective_rpc for more details.
+        """
+        return self.model_executor.collective_rpc(method, timeout, args,
+                                                  kwargs)
+
     def check_health(self) -> None:
         if self.tokenizer:
             self.tokenizer.check_health()

View File

@@ -1,8 +1,8 @@
 import itertools
 import warnings
 from contextlib import contextmanager
-from typing import (Any, ClassVar, Dict, List, Optional, Sequence, Tuple, Type,
-                    Union, cast, overload)
+from typing import (Any, Callable, ClassVar, Dict, List, Optional, Sequence,
+                    Tuple, Type, Union, cast, overload)

 import cloudpickle
 from tqdm import tqdm
@@ -464,7 +464,7 @@ class LLM:
         return self.engine_class.validate_outputs(outputs, RequestOutput)

     def collective_rpc(self,
-                       method: str,
+                       method: Union[str, Callable],
                        timeout: Optional[float] = None,
                        args: Tuple = (),
                        kwargs: Optional[Dict] = None) -> List[Any]:
@@ -476,9 +476,13 @@ class LLM:
         Then, users can call the new methods through this API.
         It is recommended to use this API to only pass control messages,
         and set up data-plane communication to pass data.
+        The method can also be a callable, which will be serialized
+        and sent to all workers to execute.
+        If the method is a callable, it should accept an additional
+        `self` argument, in addition to the arguments passed in `args`
+        and `kwargs`. The `self` argument will be the worker object.
         """
-        return self.llm_engine.model_executor.collective_rpc(
-            method, timeout, args, kwargs)
+        return self.llm_engine.collective_rpc(method, timeout, args, kwargs)

     def beam_search(
         self,
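
Taken together with the engine change above, the user-facing call now accepts a plain function. A minimal end-to-end sketch (the model name is only an example; any worker attribute could be read the same way):

from vllm import LLM


def echo_rank(self):
    # `self` is the worker object on each rank
    return self.rank


llm = LLM(model="facebook/opt-125m", tensor_parallel_size=2)
print(llm.collective_rpc(echo_rank))  # -> [0, 1]

Passing a string still requires the named method to exist on the worker, e.g. via a custom worker_cls, as the docstring notes.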

View File

@@ -1,6 +1,7 @@
 import asyncio
 from abc import ABC, abstractmethod
-from typing import Any, Awaitable, Dict, List, Optional, Set, Tuple, Union
+from typing import (Any, Awaitable, Callable, Dict, List, Optional, Set, Tuple,
+                    Union)

 from vllm.config import VllmConfig
 from vllm.logger import init_logger
@@ -47,7 +48,7 @@ class ExecutorBase(ABC):

     @abstractmethod
     def collective_rpc(self,
-                       method: str,
+                       method: Union[str, Callable],
                        timeout: Optional[float] = None,
                        args: Tuple = (),
                        kwargs: Optional[Dict] = None) -> List[Any]:
@@ -260,7 +261,7 @@ class DistributedExecutorBase(ExecutorBase):
         raise NotImplementedError

     def collective_rpc(self,
-                       method: str,
+                       method: Union[str, Callable],
                        timeout: Optional[float] = None,
                        args: Tuple = (),
                        kwargs: Optional[Dict] = None) -> List[Any]:
@@ -269,7 +270,7 @@ class DistributedExecutorBase(ExecutorBase):

     @abstractmethod
     def _run_workers(
         self,
-        method: str,
+        method: Union[str, Callable],
         *args,
         async_run_tensor_parallel_workers_only: bool = False,
         max_concurrent_workers: Optional[int] = None,

View File

@@ -1,5 +1,7 @@
 import asyncio
-from typing import Any, List, Optional
+from typing import Any, Callable, List, Optional, Union
+
+import cloudpickle

 from vllm.executor.executor_base import DistributedExecutorBase
 from vllm.executor.multiproc_worker_utils import (
@@ -9,7 +11,7 @@ from vllm.logger import init_logger
 from vllm.model_executor.layers.sampler import SamplerOutput
 from vllm.sequence import ExecuteModelRequest
 from vllm.utils import (_run_task_with_lock, get_distributed_init_method,
-                        get_ip, get_open_port, make_async)
+                        get_ip, get_open_port, make_async, run_method)
 from vllm.worker.worker_base import WorkerWrapperBase

 logger = init_logger(__name__)
@@ -107,7 +109,7 @@ class MultiprocessingDistributedExecutor(DistributedExecutorBase):

     def _run_workers(
         self,
-        method: str,
+        method: Union[str, Callable],
         *args,
         async_run_tensor_parallel_workers_only: bool = False,
         max_concurrent_workers: Optional[int] = None,
@@ -121,6 +123,11 @@ class MultiprocessingDistributedExecutor(DistributedExecutorBase):
         It will also be run asynchronously and return a list of futures
         rather than blocking on the results.
         """
+        if isinstance(method, str):
+            sent_method = method
+        else:
+            sent_method = cloudpickle.dumps(method)
+        del method

         if max_concurrent_workers:
             raise NotImplementedError(
@@ -129,18 +136,18 @@ class MultiprocessingDistributedExecutor(DistributedExecutorBase):
         if async_run_tensor_parallel_workers_only:
             # Run only non-driver workers and just return futures.
             return [
-                worker.execute_method(method, *args, **kwargs)
+                worker.execute_method(sent_method, *args, **kwargs)
                 for worker in self.non_driver_workers
             ]

         # Start all remote workers first.
         worker_outputs = [
-            worker.execute_method(method, *args, **kwargs)
+            worker.execute_method(sent_method, *args, **kwargs)
             for worker in self.workers
         ]

-        driver_worker_method = getattr(self.driver_worker, method)
-        driver_worker_output = driver_worker_method(*args, **kwargs)
+        driver_worker_output = run_method(self.driver_worker, sent_method,
+                                          args, kwargs)

         # Get the results of the workers.
         return [driver_worker_output
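
cloudpickle (rather than the standard pickle module) is what makes locally defined callables such as the test's echo_rank transferable: pickle serializes functions by reference to a module-level name, while cloudpickle serializes the function body itself. A standalone sketch of the difference, independent of vLLM:

import pickle

import cloudpickle


def make_probe():

    def probe(self):  # nested function, just like the one in the new test
        return self.rank

    return probe


class Dummy:
    rank = 0


fn = make_probe()
blob = cloudpickle.dumps(fn)        # works: the code object travels by value
restored = cloudpickle.loads(blob)
assert restored(Dummy()) == 0

try:
    pickle.dumps(fn)                # fails: only a dotted-name reference is stored
except (pickle.PicklingError, AttributeError):
    print("plain pickle cannot serialize a nested function")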

View File

@@ -15,7 +15,7 @@ import torch
 from vllm.config import VllmConfig
 from vllm.logger import init_logger
 from vllm.triton_utils.importing import HAS_TRITON
-from vllm.utils import _check_multiproc_method, get_mp_context
+from vllm.utils import _check_multiproc_method, get_mp_context, run_method

 if HAS_TRITON:
     from vllm.triton_utils import maybe_set_triton_cache_manager
@@ -169,7 +169,7 @@ class ProcessWorkerWrapper:
         self.process.start()

     def _enqueue_task(self, future: Union[ResultFuture, asyncio.Future],
-                      method: str, args, kwargs):
+                      method: Union[str, bytes], args, kwargs):
         task_id = uuid.uuid4()
         self.tasks[task_id] = future
         try:
@@ -180,12 +180,13 @@ class ProcessWorkerWrapper:
             del self.tasks[task_id]
             raise ChildProcessError("worker died") from e

-    def execute_method(self, method: str, *args, **kwargs):
+    def execute_method(self, method: Union[str, bytes], *args, **kwargs):
         future: ResultFuture = ResultFuture()
         self._enqueue_task(future, method, args, kwargs)
         return future

-    async def execute_method_async(self, method: str, *args, **kwargs):
+    async def execute_method_async(self, method: Union[str, bytes], *args,
+                                   **kwargs):
         future = asyncio.get_running_loop().create_future()
         self._enqueue_task(future, method, args, kwargs)
         return await future
@@ -230,8 +231,7 @@ def _run_worker_process(
             exception = None
             task_id, method, args, kwargs = items
             try:
-                executor = getattr(worker, method)
-                output = executor(*args, **kwargs)
+                output = run_method(worker, method, args, kwargs)
             except SystemExit:
                 raise
             except KeyboardInterrupt:

View File

@@ -2,8 +2,9 @@ import asyncio
 import os
 from collections import defaultdict
 from dataclasses import dataclass
-from typing import TYPE_CHECKING, Any, Dict, List, Optional
+from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Union

+import cloudpickle
 import msgspec

 import vllm.envs as envs
@@ -410,7 +411,7 @@ class RayDistributedExecutor(DistributedExecutorBase):

     def _run_workers(
         self,
-        method: str,
+        method: Union[str, Callable],
         *args,
         async_run_tensor_parallel_workers_only: bool = False,
         max_concurrent_workers: Optional[int] = None,
@@ -426,6 +427,11 @@ class RayDistributedExecutor(DistributedExecutorBase):
             rather than blocking on the results.
         - args/kwargs: All workers share the same args/kwargs
         """
+        if isinstance(method, str):
+            sent_method = method
+        else:
+            sent_method = cloudpickle.dumps(method)
+        del method

         if self.use_ray_spmd_worker:
             assert not async_run_tensor_parallel_workers_only, (
                 "async_run_tensor_parallel_workers_only is not supported for "
@@ -440,7 +446,7 @@ class RayDistributedExecutor(DistributedExecutorBase):
         if async_run_tensor_parallel_workers_only:
             ray_workers = self.non_driver_workers
         ray_worker_outputs = [
-            worker.execute_method.remote(method, *args, **kwargs)
+            worker.execute_method.remote(sent_method, *args, **kwargs)
             for worker in ray_workers
         ]
@@ -455,7 +461,7 @@ class RayDistributedExecutor(DistributedExecutorBase):
         if not self.use_ray_spmd_worker:
             # Start the driver worker after all the ray workers.
             driver_worker_output = [
-                self.driver_worker.execute_method(method, *args, **kwargs)
+                self.driver_worker.execute_method(sent_method, *args, **kwargs)
             ]

         # Get the results of the ray workers.

View File

@@ -1,5 +1,5 @@
 import os
-from typing import Any, Dict, List, Optional, Tuple
+from typing import Any, Callable, Dict, List, Optional, Tuple, Union

 import torch
 import torch.distributed as dist
@@ -7,7 +7,8 @@ import torch.distributed as dist
 import vllm.envs as envs
 from vllm.executor.executor_base import ExecutorBase
 from vllm.logger import init_logger
-from vllm.utils import get_distributed_init_method, get_ip, get_open_port
+from vllm.utils import (get_distributed_init_method, get_ip, get_open_port,
+                        run_method)
 from vllm.worker.worker_base import WorkerWrapperBase

 logger = init_logger(__name__)
@@ -39,18 +40,13 @@ class UniProcExecutor(ExecutorBase):
         self.collective_rpc("load_model")

     def collective_rpc(self,
-                       method: str,
+                       method: Union[str, Callable],
                        timeout: Optional[float] = None,
                        args: Tuple = (),
                        kwargs: Optional[Dict] = None) -> List[Any]:
         if kwargs is None:
             kwargs = {}
-        try:
-            func = getattr(self.driver_worker, method)
-        except AttributeError:
-            raise NotImplementedError(f"Method {method} is not implemented.") \
-                from None
-        answer = func(*args, **kwargs)
+        answer = run_method(self.driver_worker, method, args, kwargs)
         return [answer]

     def check_health(self) -> None:

View File

@@ -36,6 +36,7 @@ from typing import (TYPE_CHECKING, Any, AsyncGenerator, Awaitable, Callable,
                     overload)
 from uuid import uuid4

+import cloudpickle
 import numpy as np
 import numpy.typing as npt
 import psutil
@@ -2166,3 +2167,25 @@ def bind_kv_cache(
         assert len(forward_ctx.kv_cache) == len(kv_cache)
         for ve, ve_kv_cache in enumerate(kv_cache):
             forward_ctx.kv_cache[ve] = ve_kv_cache[kv_cache_idx]
+
+
+def run_method(obj: Any, method: Union[str, bytes, Callable], args: Tuple[Any],
+               kwargs: Dict[str, Any]) -> Any:
+    """
+    Run a method of an object with the given arguments and keyword arguments.
+    If the method is a string, it will be resolved to a bound method using
+    getattr. If the method is serialized bytes, it will be deserialized using
+    cloudpickle. If the method is a callable, it will be called directly.
+    """
+    if isinstance(method, bytes):
+        func = partial(cloudpickle.loads(method), obj)
+    elif isinstance(method, str):
+        try:
+            func = getattr(obj, method)
+        except AttributeError:
+            raise NotImplementedError(f"Method {method!r} is not"
+                                      " implemented.") from None
+    else:
+        func = partial(method, obj)  # type: ignore
+    return func(*args, **kwargs)
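
As a quick check of the helper's three dispatch paths, a sketch using an illustrative stand-in for a worker object (Toy is hypothetical, not part of the diff):

import cloudpickle

from vllm.utils import run_method


class Toy:
    rank = 3

    def echo_rank(self):
        return self.rank


toy = Toy()
assert run_method(toy, "echo_rank", (), {}) == 3             # string: getattr
assert run_method(toy, lambda self: self.rank, (), {}) == 3  # callable: partial
blob = cloudpickle.dumps(lambda self: self.rank)
assert run_method(toy, blob, (), {}) == 3                    # bytes: cloudpickle.loads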

View File

@@ -6,9 +6,11 @@ import time
 import weakref
 from dataclasses import dataclass
 from enum import Enum, auto
+from functools import partial
 from multiprocessing.process import BaseProcess
-from typing import Any, Dict, List, Optional, Tuple
+from typing import Any, Callable, Dict, List, Optional, Tuple, Union

+import cloudpickle
 import psutil
 import zmq
@@ -120,7 +122,7 @@ class MultiprocExecutor(Executor):
         return kv_cache_specs[0]

     def collective_rpc(self,
-                       method: str,
+                       method: Union[str, Callable],
                        timeout: Optional[float] = None,
                        args: Tuple = (),
                        kwargs: Optional[Dict] = None) -> List[Any]:
@@ -141,7 +143,12 @@ class MultiprocExecutor(Executor):
         kwargs = kwargs or {}

         try:
-            self.rpc_broadcast_mq.enqueue((method, args, kwargs))
+            if isinstance(method, str):
+                send_method = method
+            else:
+                send_method = cloudpickle.dumps(
+                    method, protocol=pickle.HIGHEST_PROTOCOL)
+            self.rpc_broadcast_mq.enqueue((send_method, args, kwargs))

             responses = [None] * self.world_size
             for w in self.workers:
@@ -408,7 +415,11 @@ class WorkerProc:
             method, args, kwargs = self.rpc_broadcast_mq.dequeue()

             try:
-                output = getattr(self.worker, method)(*args, **kwargs)
+                if isinstance(method, str):
+                    func = getattr(self.worker, method)
+                elif isinstance(method, bytes):
+                    func = partial(cloudpickle.loads(method), self.worker)
+                output = func(*args, **kwargs)
             except Exception as e:
                 self.worker_response_mq.enqueue(
                     (WorkerProc.ResponseStatus.FAILURE, e))

View File

@@ -14,7 +14,8 @@ from vllm.lora.request import LoRARequest
 from vllm.model_executor.layers.sampler import SamplerOutput
 from vllm.sequence import ExecuteModelRequest, IntermediateTensors
 from vllm.utils import (enable_trace_function_call_for_thread,
-                        resolve_obj_by_qualname, update_environment_variables)
+                        resolve_obj_by_qualname, run_method,
+                        update_environment_variables)
 from vllm.worker.model_runner_base import (BroadcastableModelInput,
                                            ModelRunnerBase,
                                            ModelRunnerInputBase)
@@ -539,17 +540,16 @@ class WorkerWrapperBase:
         self.worker = worker_class(**kwargs)
         assert self.worker is not None

-    def execute_method(self, method: str, *args, **kwargs):
+    def execute_method(self, method: Union[str, bytes], *args, **kwargs):
         try:
             target = self if self.worker is None else self.worker
-            executor = getattr(target, method)
-            return executor(*args, **kwargs)
+            return run_method(target, method, args, kwargs)
         except Exception as e:
             # if the driver worker also execute methods,
             # exceptions in the rest worker may cause deadlock in rpc like ray
             # see https://github.com/vllm-project/vllm/issues/3455
             # print the error and inform the user to solve the error
-            msg = (f"Error executing method {method}. "
+            msg = (f"Error executing method {method!r}. "
                    "This might cause deadlock in distributed execution.")
             logger.exception(msg)
             raise e