[core] allow callable in collective_rpc (#12151)
Signed-off-by: youkaichao <youkaichao@gmail.com>
This commit is contained in:
parent d4e6194570
commit 87a0c076af
@@ -107,7 +107,7 @@ steps:
   source_file_dependencies:
   - vllm/
   commands:
-  - pytest -v -s entrypoints/llm --ignore=entrypoints/llm/test_lazy_outlines.py --ignore=entrypoints/llm/test_generate.py --ignore=entrypoints/llm/test_generate_multiple_loras.py --ignore=entrypoints/llm/test_guided_generate.py
+  - pytest -v -s entrypoints/llm --ignore=entrypoints/llm/test_lazy_outlines.py --ignore=entrypoints/llm/test_generate.py --ignore=entrypoints/llm/test_generate_multiple_loras.py --ignore=entrypoints/llm/test_guided_generate.py --ignore=entrypoints/llm/test_collective_rpc.py
   - pytest -v -s entrypoints/llm/test_lazy_outlines.py # it needs a clean process
   - pytest -v -s entrypoints/llm/test_generate.py # it needs a clean process
   - pytest -v -s entrypoints/llm/test_generate_multiple_loras.py # it needs a clean process
@@ -466,7 +466,9 @@ steps:
   - vllm/worker/worker_base.py
   - vllm/worker/worker.py
   - vllm/worker/model_runner.py
+  - entrypoints/llm/test_collective_rpc.py
   commands:
+  - pytest -v -s entrypoints/llm/test_collective_rpc.py
   - torchrun --nproc-per-node=2 distributed/test_torchrun_example.py
   - pytest -v -s ./compile/test_basic_correctness.py
   - pytest -v -s ./compile/test_wrapper.py
@@ -1,6 +1,6 @@
 import asyncio
 import os
-from typing import Any, Dict, List, Optional, Tuple
+from typing import Any, Callable, Dict, List, Optional, Tuple, Union

 import pytest

@@ -18,7 +18,7 @@ class Mock:
 class CustomUniExecutor(UniProcExecutor):

     def collective_rpc(self,
-                       method: str,
+                       method: Union[str, Callable],
                        timeout: Optional[float] = None,
                        args: Tuple = (),
                        kwargs: Optional[Dict] = None) -> List[Any]:
tests/entrypoints/llm/test_collective_rpc.py (new file, 36 lines)
@@ -0,0 +1,36 @@
+import pytest
+
+from vllm import LLM
+
+from ...utils import fork_new_process_for_each_test
+
+
+@pytest.mark.parametrize("tp_size", [1, 2])
+@pytest.mark.parametrize("backend", ["mp", "ray"])
+@fork_new_process_for_each_test
+def test_collective_rpc(tp_size, backend):
+    if tp_size == 1 and backend == "ray":
+        pytest.skip("Skip duplicate test case")
+    if tp_size == 1:
+        backend = None
+
+    # intentionally define the method and class in the test function,
+    # to test if they can be serialized and sent to the workers
+    def echo_rank(self):
+        return self.rank
+
+    from vllm.worker.worker import Worker
+
+    class MyWorker(Worker):
+
+        def echo_rank(self):
+            return self.rank
+
+    llm = LLM(model="meta-llama/Llama-3.2-1B-Instruct",
+              enforce_eager=True,
+              load_format="dummy",
+              tensor_parallel_size=tp_size,
+              distributed_executor_backend=backend,
+              worker_cls=MyWorker)
+    for method in ["echo_rank", echo_rank]:
+        assert llm.collective_rpc(method) == list(range(tp_size))
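The inline comment in the test above points at the key requirement: echo_rank and MyWorker are defined inside the test function, so they can only reach the workers if the executor serializes them by value. That is what cloudpickle provides and the standard library pickle does not. A minimal standalone sketch (not part of this commit):

import pickle

import cloudpickle


def make_local_fn():
    # A function defined inside another function, like echo_rank in the test.
    def echo(x):
        return x

    return echo


fn = make_local_fn()

# The standard pickle module refuses to serialize a locally defined function.
try:
    pickle.dumps(fn)
except Exception as exc:  # typically AttributeError / PicklingError
    print("pickle failed:", type(exc).__name__)

# cloudpickle serializes the function by value, which is what lets the
# executors ship callables to worker processes and rebuild them there.
restored = cloudpickle.loads(cloudpickle.dumps(fn))
assert restored("hello") == "hello"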
@@ -5,10 +5,10 @@ from collections import deque
 from contextlib import contextmanager
 from dataclasses import dataclass
 from functools import partial
-from typing import (TYPE_CHECKING, Callable, ClassVar, Deque, Dict, Iterable,
-                    List, Mapping, NamedTuple, Optional)
+from typing import (TYPE_CHECKING, Any, Callable, ClassVar, Deque, Dict,
+                    Iterable, List, Mapping, NamedTuple, Optional)
 from typing import Sequence as GenericSequence
-from typing import Set, Type, Union, cast, overload
+from typing import Set, Tuple, Type, Union, cast, overload

 import torch
 from typing_extensions import TypeVar, deprecated
@@ -1816,6 +1816,17 @@ class LLMEngine:
     def stop_profile(self) -> None:
         self.model_executor.stop_profile()

+    def collective_rpc(self,
+                       method: Union[str, Callable],
+                       timeout: Optional[float] = None,
+                       args: Tuple = (),
+                       kwargs: Optional[Dict] = None) -> List[Any]:
+        """
+        See LLM.collective_rpc for more details.
+        """
+        return self.model_executor.collective_rpc(method, timeout, args,
+                                                  kwargs)
+
     def check_health(self) -> None:
         if self.tokenizer:
             self.tokenizer.check_health()
@@ -1,8 +1,8 @@
 import itertools
 import warnings
 from contextlib import contextmanager
-from typing import (Any, ClassVar, Dict, List, Optional, Sequence, Tuple, Type,
-                    Union, cast, overload)
+from typing import (Any, Callable, ClassVar, Dict, List, Optional, Sequence,
+                    Tuple, Type, Union, cast, overload)

 import cloudpickle
 from tqdm import tqdm
@@ -464,7 +464,7 @@ class LLM:
         return self.engine_class.validate_outputs(outputs, RequestOutput)

     def collective_rpc(self,
-                       method: str,
+                       method: Union[str, Callable],
                        timeout: Optional[float] = None,
                        args: Tuple = (),
                        kwargs: Optional[Dict] = None) -> List[Any]:
@@ -476,9 +476,13 @@ class LLM:
         Then, users can call the new methods through this API.
         It is recommended to use this API to only pass control messages,
         and set up data-plane communication to pass data.
+        The method can also be a callable, which will be serialized
+        and sent to all workers to execute.
+        If the method is a callable, it should accept an additional
+        `self` argument, in addition to the arguments passed in `args`
+        and `kwargs`. The `self` argument will be the worker object.
         """
-        return self.llm_engine.model_executor.collective_rpc(
-            method, timeout, args, kwargs)
+        return self.llm_engine.collective_rpc(method, timeout, args, kwargs)

     def beam_search(
         self,
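The convention described in the added docstring means any plain function can be broadcast to the workers, with its first parameter bound to the worker object and the remaining parameters filled from args/kwargs. A minimal sketch, assuming an LLM instance llm built as in the new test above (scaled_rank is an illustrative helper, not part of this commit):

def scaled_rank(self, factor: int = 1):
    # `self` is the worker object on each rank, as the docstring describes.
    return self.rank * factor

# kwargs are shared by all workers; the call returns one result per rank.
results = llm.collective_rpc(scaled_rank, kwargs={"factor": 10})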
@@ -1,6 +1,7 @@
 import asyncio
 from abc import ABC, abstractmethod
-from typing import Any, Awaitable, Dict, List, Optional, Set, Tuple, Union
+from typing import (Any, Awaitable, Callable, Dict, List, Optional, Set, Tuple,
+                    Union)

 from vllm.config import VllmConfig
 from vllm.logger import init_logger
@@ -47,7 +48,7 @@ class ExecutorBase(ABC):

     @abstractmethod
     def collective_rpc(self,
-                       method: str,
+                       method: Union[str, Callable],
                        timeout: Optional[float] = None,
                        args: Tuple = (),
                        kwargs: Optional[Dict] = None) -> List[Any]:
@@ -260,7 +261,7 @@ class DistributedExecutorBase(ExecutorBase):
         raise NotImplementedError

     def collective_rpc(self,
-                       method: str,
+                       method: Union[str, Callable],
                        timeout: Optional[float] = None,
                        args: Tuple = (),
                        kwargs: Optional[Dict] = None) -> List[Any]:
@@ -269,7 +270,7 @@ class DistributedExecutorBase(ExecutorBase):
     @abstractmethod
     def _run_workers(
         self,
-        method: str,
+        method: Union[str, Callable],
         *args,
         async_run_tensor_parallel_workers_only: bool = False,
         max_concurrent_workers: Optional[int] = None,
@@ -1,5 +1,7 @@
 import asyncio
-from typing import Any, List, Optional
+from typing import Any, Callable, List, Optional, Union

+import cloudpickle
+
 from vllm.executor.executor_base import DistributedExecutorBase
 from vllm.executor.multiproc_worker_utils import (
@@ -9,7 +11,7 @@ from vllm.logger import init_logger
 from vllm.model_executor.layers.sampler import SamplerOutput
 from vllm.sequence import ExecuteModelRequest
 from vllm.utils import (_run_task_with_lock, get_distributed_init_method,
-                        get_ip, get_open_port, make_async)
+                        get_ip, get_open_port, make_async, run_method)
 from vllm.worker.worker_base import WorkerWrapperBase

 logger = init_logger(__name__)
@@ -107,7 +109,7 @@ class MultiprocessingDistributedExecutor(DistributedExecutorBase):

     def _run_workers(
         self,
-        method: str,
+        method: Union[str, Callable],
         *args,
         async_run_tensor_parallel_workers_only: bool = False,
         max_concurrent_workers: Optional[int] = None,
@@ -121,6 +123,11 @@ class MultiprocessingDistributedExecutor(DistributedExecutorBase):
         It will also be run asynchronously and return a list of futures
         rather than blocking on the results.
         """
+        if isinstance(method, str):
+            sent_method = method
+        else:
+            sent_method = cloudpickle.dumps(method)
+        del method

         if max_concurrent_workers:
             raise NotImplementedError(
@@ -129,18 +136,18 @@ class MultiprocessingDistributedExecutor(DistributedExecutorBase):
         if async_run_tensor_parallel_workers_only:
             # Run only non-driver workers and just return futures.
             return [
-                worker.execute_method(method, *args, **kwargs)
+                worker.execute_method(sent_method, *args, **kwargs)
                 for worker in self.non_driver_workers
             ]

         # Start all remote workers first.
         worker_outputs = [
-            worker.execute_method(method, *args, **kwargs)
+            worker.execute_method(sent_method, *args, **kwargs)
             for worker in self.workers
         ]

-        driver_worker_method = getattr(self.driver_worker, method)
-        driver_worker_output = driver_worker_method(*args, **kwargs)
+        driver_worker_output = run_method(self.driver_worker, sent_method,
+                                          args, kwargs)

         # Get the results of the workers.
         return [driver_worker_output
@@ -15,7 +15,7 @@ import torch
 from vllm.config import VllmConfig
 from vllm.logger import init_logger
 from vllm.triton_utils.importing import HAS_TRITON
-from vllm.utils import _check_multiproc_method, get_mp_context
+from vllm.utils import _check_multiproc_method, get_mp_context, run_method

 if HAS_TRITON:
     from vllm.triton_utils import maybe_set_triton_cache_manager
@@ -169,7 +169,7 @@ class ProcessWorkerWrapper:
         self.process.start()

     def _enqueue_task(self, future: Union[ResultFuture, asyncio.Future],
-                      method: str, args, kwargs):
+                      method: Union[str, bytes], args, kwargs):
         task_id = uuid.uuid4()
         self.tasks[task_id] = future
         try:
@@ -180,12 +180,13 @@ class ProcessWorkerWrapper:
             del self.tasks[task_id]
             raise ChildProcessError("worker died") from e

-    def execute_method(self, method: str, *args, **kwargs):
+    def execute_method(self, method: Union[str, bytes], *args, **kwargs):
         future: ResultFuture = ResultFuture()
         self._enqueue_task(future, method, args, kwargs)
         return future

-    async def execute_method_async(self, method: str, *args, **kwargs):
+    async def execute_method_async(self, method: Union[str, bytes], *args,
+                                   **kwargs):
         future = asyncio.get_running_loop().create_future()
         self._enqueue_task(future, method, args, kwargs)
         return await future
@@ -230,8 +231,7 @@ def _run_worker_process(
         exception = None
         task_id, method, args, kwargs = items
         try:
-            executor = getattr(worker, method)
-            output = executor(*args, **kwargs)
+            output = run_method(worker, method, args, kwargs)
         except SystemExit:
             raise
         except KeyboardInterrupt:
@@ -2,8 +2,9 @@ import asyncio
 import os
 from collections import defaultdict
 from dataclasses import dataclass
-from typing import TYPE_CHECKING, Any, Dict, List, Optional
+from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Union

+import cloudpickle
 import msgspec

 import vllm.envs as envs
@@ -410,7 +411,7 @@ class RayDistributedExecutor(DistributedExecutorBase):

     def _run_workers(
         self,
-        method: str,
+        method: Union[str, Callable],
         *args,
         async_run_tensor_parallel_workers_only: bool = False,
         max_concurrent_workers: Optional[int] = None,
@@ -426,6 +427,11 @@ class RayDistributedExecutor(DistributedExecutorBase):
           rather than blocking on the results.
         - args/kwargs: All workers share the same args/kwargs
         """
+        if isinstance(method, str):
+            sent_method = method
+        else:
+            sent_method = cloudpickle.dumps(method)
+        del method
         if self.use_ray_spmd_worker:
             assert not async_run_tensor_parallel_workers_only, (
                 "async_run_tensor_parallel_workers_only is not supported for "
@@ -440,7 +446,7 @@ class RayDistributedExecutor(DistributedExecutorBase):
         if async_run_tensor_parallel_workers_only:
             ray_workers = self.non_driver_workers
         ray_worker_outputs = [
-            worker.execute_method.remote(method, *args, **kwargs)
+            worker.execute_method.remote(sent_method, *args, **kwargs)
             for worker in ray_workers
         ]

@@ -455,7 +461,7 @@ class RayDistributedExecutor(DistributedExecutorBase):
         if not self.use_ray_spmd_worker:
             # Start the driver worker after all the ray workers.
             driver_worker_output = [
-                self.driver_worker.execute_method(method, *args, **kwargs)
+                self.driver_worker.execute_method(sent_method, *args, **kwargs)
             ]

         # Get the results of the ray workers.
@@ -1,5 +1,5 @@
 import os
-from typing import Any, Dict, List, Optional, Tuple
+from typing import Any, Callable, Dict, List, Optional, Tuple, Union

 import torch
 import torch.distributed as dist
@@ -7,7 +7,8 @@ import torch.distributed as dist
 import vllm.envs as envs
 from vllm.executor.executor_base import ExecutorBase
 from vllm.logger import init_logger
-from vllm.utils import get_distributed_init_method, get_ip, get_open_port
+from vllm.utils import (get_distributed_init_method, get_ip, get_open_port,
+                        run_method)
 from vllm.worker.worker_base import WorkerWrapperBase

 logger = init_logger(__name__)
@@ -39,18 +40,13 @@ class UniProcExecutor(ExecutorBase):
         self.collective_rpc("load_model")

     def collective_rpc(self,
-                       method: str,
+                       method: Union[str, Callable],
                        timeout: Optional[float] = None,
                        args: Tuple = (),
                        kwargs: Optional[Dict] = None) -> List[Any]:
         if kwargs is None:
             kwargs = {}
-        try:
-            func = getattr(self.driver_worker, method)
-        except AttributeError:
-            raise NotImplementedError(f"Method {method} is not implemented.") \
-                from None
-        answer = func(*args, **kwargs)
+        answer = run_method(self.driver_worker, method, args, kwargs)
         return [answer]

     def check_health(self) -> None:
@@ -36,6 +36,7 @@ from typing import (TYPE_CHECKING, Any, AsyncGenerator, Awaitable, Callable,
                     overload)
 from uuid import uuid4

+import cloudpickle
 import numpy as np
 import numpy.typing as npt
 import psutil
@@ -2166,3 +2167,25 @@ def bind_kv_cache(
         assert len(forward_ctx.kv_cache) == len(kv_cache)
     for ve, ve_kv_cache in enumerate(kv_cache):
         forward_ctx.kv_cache[ve] = ve_kv_cache[kv_cache_idx]
+
+
+def run_method(obj: Any, method: Union[str, bytes, Callable], args: Tuple[Any],
+               kwargs: Dict[str, Any]) -> Any:
+    """
+    Run a method of an object with the given arguments and keyword arguments.
+    If the method is string, it will be converted to a method using getattr.
+    If the method is serialized bytes and will be deserialized using
+    cloudpickle.
+    If the method is a callable, it will be called directly.
+    """
+    if isinstance(method, bytes):
+        func = partial(cloudpickle.loads(method), obj)
+    elif isinstance(method, str):
+        try:
+            func = getattr(obj, method)
+        except AttributeError:
+            raise NotImplementedError(f"Method {method!r} is not"
+                                      " implemented.") from None
+    else:
+        func = partial(method, obj)  # type: ignore
+    return func(*args, **kwargs)
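A quick sketch of the three dispatch paths handled by run_method above (the Counter class is illustrative only, not part of vLLM):

import cloudpickle

from vllm.utils import run_method


class Counter:

    def __init__(self):
        self.value = 41

    def bump(self):
        self.value += 1
        return self.value


obj = Counter()

# 1. String: resolved with getattr on the object.
assert run_method(obj, "bump", (), {}) == 42

# 2. Callable: called with the object bound as the first argument.
assert run_method(obj, lambda self: self.value, (), {}) == 42

# 3. Bytes: deserialized with cloudpickle, then called the same way.
blob = cloudpickle.dumps(lambda self, delta: self.value + delta)
assert run_method(obj, blob, (8,), {}) == 50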
@@ -6,9 +6,11 @@ import time
 import weakref
 from dataclasses import dataclass
 from enum import Enum, auto
+from functools import partial
 from multiprocessing.process import BaseProcess
-from typing import Any, Dict, List, Optional, Tuple
+from typing import Any, Callable, Dict, List, Optional, Tuple, Union

+import cloudpickle
 import psutil
 import zmq

@@ -120,7 +122,7 @@ class MultiprocExecutor(Executor):
         return kv_cache_specs[0]

     def collective_rpc(self,
-                       method: str,
+                       method: Union[str, Callable],
                        timeout: Optional[float] = None,
                        args: Tuple = (),
                        kwargs: Optional[Dict] = None) -> List[Any]:
@@ -141,7 +143,12 @@ class MultiprocExecutor(Executor):
         kwargs = kwargs or {}

         try:
-            self.rpc_broadcast_mq.enqueue((method, args, kwargs))
+            if isinstance(method, str):
+                send_method = method
+            else:
+                send_method = cloudpickle.dumps(
+                    method, protocol=pickle.HIGHEST_PROTOCOL)
+            self.rpc_broadcast_mq.enqueue((send_method, args, kwargs))

             responses = [None] * self.world_size
             for w in self.workers:
@@ -408,7 +415,11 @@ class WorkerProc:
             method, args, kwargs = self.rpc_broadcast_mq.dequeue()

             try:
-                output = getattr(self.worker, method)(*args, **kwargs)
+                if isinstance(method, str):
+                    func = getattr(self.worker, method)
+                elif isinstance(method, bytes):
+                    func = partial(cloudpickle.loads(method), self.worker)
+                output = func(*args, **kwargs)
             except Exception as e:
                 self.worker_response_mq.enqueue(
                     (WorkerProc.ResponseStatus.FAILURE, e))
@@ -14,7 +14,8 @@ from vllm.lora.request import LoRARequest
 from vllm.model_executor.layers.sampler import SamplerOutput
 from vllm.sequence import ExecuteModelRequest, IntermediateTensors
 from vllm.utils import (enable_trace_function_call_for_thread,
-                        resolve_obj_by_qualname, update_environment_variables)
+                        resolve_obj_by_qualname, run_method,
+                        update_environment_variables)
 from vllm.worker.model_runner_base import (BroadcastableModelInput,
                                            ModelRunnerBase,
                                            ModelRunnerInputBase)
@@ -539,17 +540,16 @@ class WorkerWrapperBase:
         self.worker = worker_class(**kwargs)
         assert self.worker is not None

-    def execute_method(self, method: str, *args, **kwargs):
+    def execute_method(self, method: Union[str, bytes], *args, **kwargs):
         try:
             target = self if self.worker is None else self.worker
-            executor = getattr(target, method)
-            return executor(*args, **kwargs)
+            return run_method(target, method, args, kwargs)
         except Exception as e:
             # if the driver worker also execute methods,
             # exceptions in the rest worker may cause deadlock in rpc like ray
             # see https://github.com/vllm-project/vllm/issues/3455
             # print the error and inform the user to solve the error
-            msg = (f"Error executing method {method}. "
+            msg = (f"Error executing method {method!r}. "
                   "This might cause deadlock in distributed execution.")
             logger.exception(msg)
             raise e