[core] allow callable in collective_rpc (#12151)
Signed-off-by: youkaichao <youkaichao@gmail.com>
This commit is contained in: parent d4e6194570, commit 87a0c076af
@@ -107,7 +107,7 @@ steps:
   source_file_dependencies:
   - vllm/
   commands:
-  - pytest -v -s entrypoints/llm --ignore=entrypoints/llm/test_lazy_outlines.py --ignore=entrypoints/llm/test_generate.py --ignore=entrypoints/llm/test_generate_multiple_loras.py --ignore=entrypoints/llm/test_guided_generate.py
+  - pytest -v -s entrypoints/llm --ignore=entrypoints/llm/test_lazy_outlines.py --ignore=entrypoints/llm/test_generate.py --ignore=entrypoints/llm/test_generate_multiple_loras.py --ignore=entrypoints/llm/test_guided_generate.py --ignore=entrypoints/llm/test_collective_rpc.py
   - pytest -v -s entrypoints/llm/test_lazy_outlines.py # it needs a clean process
   - pytest -v -s entrypoints/llm/test_generate.py # it needs a clean process
   - pytest -v -s entrypoints/llm/test_generate_multiple_loras.py # it needs a clean process
@@ -466,7 +466,9 @@ steps:
   - vllm/worker/worker_base.py
   - vllm/worker/worker.py
   - vllm/worker/model_runner.py
+  - entrypoints/llm/test_collective_rpc.py
   commands:
+  - pytest -v -s entrypoints/llm/test_collective_rpc.py
   - torchrun --nproc-per-node=2 distributed/test_torchrun_example.py
   - pytest -v -s ./compile/test_basic_correctness.py
   - pytest -v -s ./compile/test_wrapper.py
@@ -1,6 +1,6 @@
 import asyncio
 import os
-from typing import Any, Dict, List, Optional, Tuple
+from typing import Any, Callable, Dict, List, Optional, Tuple, Union

 import pytest

@@ -18,7 +18,7 @@ class Mock:
 class CustomUniExecutor(UniProcExecutor):

     def collective_rpc(self,
-                       method: str,
+                       method: Union[str, Callable],
                        timeout: Optional[float] = None,
                        args: Tuple = (),
                        kwargs: Optional[Dict] = None) -> List[Any]:
tests/entrypoints/llm/test_collective_rpc.py (new file, 36 lines)
@@ -0,0 +1,36 @@
+import pytest
+
+from vllm import LLM
+
+from ...utils import fork_new_process_for_each_test
+
+
+@pytest.mark.parametrize("tp_size", [1, 2])
+@pytest.mark.parametrize("backend", ["mp", "ray"])
+@fork_new_process_for_each_test
+def test_collective_rpc(tp_size, backend):
+    if tp_size == 1 and backend == "ray":
+        pytest.skip("Skip duplicate test case")
+    if tp_size == 1:
+        backend = None
+
+    # intentionally define the method and class in the test function,
+    # to test if they can be serialized and sent to the workers
+    def echo_rank(self):
+        return self.rank
+
+    from vllm.worker.worker import Worker
+
+    class MyWorker(Worker):
+
+        def echo_rank(self):
+            return self.rank
+
+    llm = LLM(model="meta-llama/Llama-3.2-1B-Instruct",
+              enforce_eager=True,
+              load_format="dummy",
+              tensor_parallel_size=tp_size,
+              distributed_executor_backend=backend,
+              worker_cls=MyWorker)
+    for method in ["echo_rank", echo_rank]:
+        assert llm.collective_rpc(method) == list(range(tp_size))
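In the loop at the end of the test, both forms resolve to the same per-rank method: the string "echo_rank" is looked up by name on MyWorker, while the locally defined echo_rank function is cloudpickled, shipped to each worker, and bound as `self`, so either way the RPC returns list(range(tp_size)).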
@@ -5,10 +5,10 @@ from collections import deque
 from contextlib import contextmanager
 from dataclasses import dataclass
 from functools import partial
-from typing import (TYPE_CHECKING, Callable, ClassVar, Deque, Dict, Iterable,
-                    List, Mapping, NamedTuple, Optional)
+from typing import (TYPE_CHECKING, Any, Callable, ClassVar, Deque, Dict,
+                    Iterable, List, Mapping, NamedTuple, Optional)
 from typing import Sequence as GenericSequence
-from typing import Set, Type, Union, cast, overload
+from typing import Set, Tuple, Type, Union, cast, overload

 import torch
 from typing_extensions import TypeVar, deprecated
@@ -1816,6 +1816,17 @@ class LLMEngine:
     def stop_profile(self) -> None:
         self.model_executor.stop_profile()

+    def collective_rpc(self,
+                       method: Union[str, Callable],
+                       timeout: Optional[float] = None,
+                       args: Tuple = (),
+                       kwargs: Optional[Dict] = None) -> List[Any]:
+        """
+        See LLM.collective_rpc for more details.
+        """
+        return self.model_executor.collective_rpc(method, timeout, args,
+                                                  kwargs)
+
     def check_health(self) -> None:
         if self.tokenizer:
             self.tokenizer.check_health()
@@ -1,8 +1,8 @@
 import itertools
 import warnings
 from contextlib import contextmanager
-from typing import (Any, ClassVar, Dict, List, Optional, Sequence, Tuple, Type,
-                    Union, cast, overload)
+from typing import (Any, Callable, ClassVar, Dict, List, Optional, Sequence,
+                    Tuple, Type, Union, cast, overload)

 import cloudpickle
 from tqdm import tqdm
@@ -464,7 +464,7 @@ class LLM:
         return self.engine_class.validate_outputs(outputs, RequestOutput)

     def collective_rpc(self,
-                       method: str,
+                       method: Union[str, Callable],
                        timeout: Optional[float] = None,
                        args: Tuple = (),
                        kwargs: Optional[Dict] = None) -> List[Any]:
@@ -476,9 +476,13 @@ class LLM:
         Then, users can call the new methods through this API.
         It is recommended to use this API to only pass control messages,
         and set up data-plane communication to pass data.
+        The method can also be a callable, which will be serialized
+        and sent to all workers to execute.
+        If the method is a callable, it should accept an additional
+        `self` argument, in addition to the arguments passed in `args`
+        and `kwargs`. The `self` argument will be the worker object.
         """
-        return self.llm_engine.model_executor.collective_rpc(
-            method, timeout, args, kwargs)
+        return self.llm_engine.collective_rpc(method, timeout, args, kwargs)

     def beam_search(
         self,
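A minimal usage sketch of the new callable form, condensed from the test added in this commit (it assumes a machine with two GPUs; load_format="dummy" skips downloading real weights):

from vllm import LLM

def echo_rank(self):
    # `self` is the worker object on each rank
    return self.rank

llm = LLM(model="meta-llama/Llama-3.2-1B-Instruct",
          enforce_eager=True,
          load_format="dummy",
          tensor_parallel_size=2)
# the function is cloudpickled, sent to every worker, and called
# as echo_rank(worker); one result is returned per rank
assert llm.collective_rpc(echo_rank) == [0, 1]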
@@ -1,6 +1,7 @@
 import asyncio
 from abc import ABC, abstractmethod
-from typing import Any, Awaitable, Dict, List, Optional, Set, Tuple, Union
+from typing import (Any, Awaitable, Callable, Dict, List, Optional, Set, Tuple,
+                    Union)

 from vllm.config import VllmConfig
 from vllm.logger import init_logger
@@ -47,7 +48,7 @@ class ExecutorBase(ABC):

     @abstractmethod
     def collective_rpc(self,
-                       method: str,
+                       method: Union[str, Callable],
                        timeout: Optional[float] = None,
                        args: Tuple = (),
                        kwargs: Optional[Dict] = None) -> List[Any]:
@@ -260,7 +261,7 @@ class DistributedExecutorBase(ExecutorBase):
         raise NotImplementedError

     def collective_rpc(self,
-                       method: str,
+                       method: Union[str, Callable],
                        timeout: Optional[float] = None,
                        args: Tuple = (),
                        kwargs: Optional[Dict] = None) -> List[Any]:
@@ -269,7 +270,7 @@ class DistributedExecutorBase(ExecutorBase):
     @abstractmethod
     def _run_workers(
         self,
-        method: str,
+        method: Union[str, Callable],
         *args,
         async_run_tensor_parallel_workers_only: bool = False,
         max_concurrent_workers: Optional[int] = None,
@@ -1,5 +1,7 @@
 import asyncio
-from typing import Any, List, Optional
+from typing import Any, Callable, List, Optional, Union
+
+import cloudpickle

 from vllm.executor.executor_base import DistributedExecutorBase
 from vllm.executor.multiproc_worker_utils import (
@@ -9,7 +11,7 @@ from vllm.logger import init_logger
 from vllm.model_executor.layers.sampler import SamplerOutput
 from vllm.sequence import ExecuteModelRequest
 from vllm.utils import (_run_task_with_lock, get_distributed_init_method,
-                        get_ip, get_open_port, make_async)
+                        get_ip, get_open_port, make_async, run_method)
 from vllm.worker.worker_base import WorkerWrapperBase

 logger = init_logger(__name__)
@@ -107,7 +109,7 @@ class MultiprocessingDistributedExecutor(DistributedExecutorBase):

     def _run_workers(
         self,
-        method: str,
+        method: Union[str, Callable],
         *args,
         async_run_tensor_parallel_workers_only: bool = False,
         max_concurrent_workers: Optional[int] = None,
@@ -121,6 +123,11 @@ class MultiprocessingDistributedExecutor(DistributedExecutorBase):
         It will also be run asynchronously and return a list of futures
         rather than blocking on the results.
         """
+        if isinstance(method, str):
+            sent_method = method
+        else:
+            sent_method = cloudpickle.dumps(method)
+        del method

         if max_concurrent_workers:
             raise NotImplementedError(
@@ -129,18 +136,18 @@ class MultiprocessingDistributedExecutor(DistributedExecutorBase):
         if async_run_tensor_parallel_workers_only:
             # Run only non-driver workers and just return futures.
             return [
-                worker.execute_method(method, *args, **kwargs)
+                worker.execute_method(sent_method, *args, **kwargs)
                 for worker in self.non_driver_workers
             ]

         # Start all remote workers first.
         worker_outputs = [
-            worker.execute_method(method, *args, **kwargs)
+            worker.execute_method(sent_method, *args, **kwargs)
             for worker in self.workers
         ]

-        driver_worker_method = getattr(self.driver_worker, method)
-        driver_worker_output = driver_worker_method(*args, **kwargs)
+        driver_worker_output = run_method(self.driver_worker, sent_method,
+                                          args, kwargs)

         # Get the results of the workers.
         return [driver_worker_output
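The key idea above is to serialize a callable once on the driver and let every worker rebind it to its own worker object. A standalone sketch of that round trip, using a toy class rather than a real vLLM worker:

from functools import partial

import cloudpickle

class ToyWorker:
    def __init__(self, rank):
        self.rank = rank

def echo_rank(self):
    return self.rank

# driver side: pickle the function once before fanning out
sent_method = cloudpickle.dumps(echo_rank)

# worker side: deserialize and bind the worker as `self`
workers = [ToyWorker(0), ToyWorker(1)]
outputs = [partial(cloudpickle.loads(sent_method), w)() for w in workers]
assert outputs == [0, 1]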
@@ -15,7 +15,7 @@ import torch
 from vllm.config import VllmConfig
 from vllm.logger import init_logger
 from vllm.triton_utils.importing import HAS_TRITON
-from vllm.utils import _check_multiproc_method, get_mp_context
+from vllm.utils import _check_multiproc_method, get_mp_context, run_method

 if HAS_TRITON:
     from vllm.triton_utils import maybe_set_triton_cache_manager
@@ -169,7 +169,7 @@ class ProcessWorkerWrapper:
         self.process.start()

     def _enqueue_task(self, future: Union[ResultFuture, asyncio.Future],
-                      method: str, args, kwargs):
+                      method: Union[str, bytes], args, kwargs):
         task_id = uuid.uuid4()
         self.tasks[task_id] = future
         try:
@@ -180,12 +180,13 @@ class ProcessWorkerWrapper:
         del self.tasks[task_id]
         raise ChildProcessError("worker died") from e

-    def execute_method(self, method: str, *args, **kwargs):
+    def execute_method(self, method: Union[str, bytes], *args, **kwargs):
         future: ResultFuture = ResultFuture()
         self._enqueue_task(future, method, args, kwargs)
         return future

-    async def execute_method_async(self, method: str, *args, **kwargs):
+    async def execute_method_async(self, method: Union[str, bytes], *args,
+                                   **kwargs):
         future = asyncio.get_running_loop().create_future()
         self._enqueue_task(future, method, args, kwargs)
         return await future
@@ -230,8 +231,7 @@ def _run_worker_process(
         exception = None
         task_id, method, args, kwargs = items
         try:
-            executor = getattr(worker, method)
-            output = executor(*args, **kwargs)
+            output = run_method(worker, method, args, kwargs)
         except SystemExit:
             raise
         except KeyboardInterrupt:
@@ -2,8 +2,9 @@ import asyncio
 import os
 from collections import defaultdict
 from dataclasses import dataclass
-from typing import TYPE_CHECKING, Any, Dict, List, Optional
+from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Union

+import cloudpickle
 import msgspec

 import vllm.envs as envs
@@ -410,7 +411,7 @@ class RayDistributedExecutor(DistributedExecutorBase):

     def _run_workers(
         self,
-        method: str,
+        method: Union[str, Callable],
         *args,
         async_run_tensor_parallel_workers_only: bool = False,
         max_concurrent_workers: Optional[int] = None,
@@ -426,6 +427,11 @@ class RayDistributedExecutor(DistributedExecutorBase):
           rather than blocking on the results.
         - args/kwargs: All workers share the same args/kwargs
         """
+        if isinstance(method, str):
+            sent_method = method
+        else:
+            sent_method = cloudpickle.dumps(method)
+        del method
         if self.use_ray_spmd_worker:
             assert not async_run_tensor_parallel_workers_only, (
                 "async_run_tensor_parallel_workers_only is not supported for "
@@ -440,7 +446,7 @@ class RayDistributedExecutor(DistributedExecutorBase):
         if async_run_tensor_parallel_workers_only:
             ray_workers = self.non_driver_workers
         ray_worker_outputs = [
-            worker.execute_method.remote(method, *args, **kwargs)
+            worker.execute_method.remote(sent_method, *args, **kwargs)
             for worker in ray_workers
         ]

@@ -455,7 +461,7 @@ class RayDistributedExecutor(DistributedExecutorBase):
         if not self.use_ray_spmd_worker:
             # Start the driver worker after all the ray workers.
             driver_worker_output = [
-                self.driver_worker.execute_method(method, *args, **kwargs)
+                self.driver_worker.execute_method(sent_method, *args, **kwargs)
             ]

         # Get the results of the ray workers.
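The Ray path mirrors the multiprocessing executor: the callable is pickled once on the driver, and each actor's execute_method (a WorkerWrapperBase method, updated at the end of this diff) routes the string-or-bytes payload through run_method.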
@@ -1,5 +1,5 @@
 import os
-from typing import Any, Dict, List, Optional, Tuple
+from typing import Any, Callable, Dict, List, Optional, Tuple, Union

 import torch
 import torch.distributed as dist
@@ -7,7 +7,8 @@ import torch.distributed as dist
 import vllm.envs as envs
 from vllm.executor.executor_base import ExecutorBase
 from vllm.logger import init_logger
-from vllm.utils import get_distributed_init_method, get_ip, get_open_port
+from vllm.utils import (get_distributed_init_method, get_ip, get_open_port,
+                        run_method)
 from vllm.worker.worker_base import WorkerWrapperBase

 logger = init_logger(__name__)
@@ -39,18 +40,13 @@ class UniProcExecutor(ExecutorBase):
         self.collective_rpc("load_model")

     def collective_rpc(self,
-                       method: str,
+                       method: Union[str, Callable],
                        timeout: Optional[float] = None,
                        args: Tuple = (),
                        kwargs: Optional[Dict] = None) -> List[Any]:
         if kwargs is None:
             kwargs = {}
-        try:
-            func = getattr(self.driver_worker, method)
-        except AttributeError:
-            raise NotImplementedError(f"Method {method} is not implemented.") \
-                from None
-        answer = func(*args, **kwargs)
+        answer = run_method(self.driver_worker, method, args, kwargs)
         return [answer]

     def check_health(self) -> None:
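The getattr/NotImplementedError handling removed here is not lost: it moves into the shared run_method helper introduced later in this diff, so unknown string methods still raise NotImplementedError.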
@@ -36,6 +36,7 @@ from typing import (TYPE_CHECKING, Any, AsyncGenerator, Awaitable, Callable,
                     overload)
 from uuid import uuid4

+import cloudpickle
 import numpy as np
 import numpy.typing as npt
 import psutil
@@ -2166,3 +2167,25 @@ def bind_kv_cache(
         assert len(forward_ctx.kv_cache) == len(kv_cache)
         for ve, ve_kv_cache in enumerate(kv_cache):
             forward_ctx.kv_cache[ve] = ve_kv_cache[kv_cache_idx]
+
+
+def run_method(obj: Any, method: Union[str, bytes, Callable], args: Tuple[Any],
+               kwargs: Dict[str, Any]) -> Any:
+    """
+    Run a method of an object with the given arguments and keyword arguments.
+    If the method is a string, it will be converted to a method using getattr.
+    If the method is serialized bytes, it will be deserialized using
+    cloudpickle.
+    If the method is a callable, it will be called directly.
+    """
+    if isinstance(method, bytes):
+        func = partial(cloudpickle.loads(method), obj)
+    elif isinstance(method, str):
+        try:
+            func = getattr(obj, method)
+        except AttributeError:
+            raise NotImplementedError(f"Method {method!r} is not"
+                                      " implemented.") from None
+    else:
+        func = partial(method, obj)  # type: ignore
+    return func(*args, **kwargs)
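run_method is now the single dispatch point used by the executors in this diff. A quick demonstration of its three accepted forms (a sketch; Toy is a stand-in object, and run_method is imported from vllm.utils as the surrounding hunks do):

import cloudpickle

from vllm.utils import run_method

class Toy:

    def greet(self, name):
        return f"hello {name}"

obj = Toy()
# str: resolved with getattr on the object
assert run_method(obj, "greet", ("alice",), {}) == "hello alice"
# callable: the object is bound as the first argument
assert run_method(obj, lambda self, n: n.upper(), ("bob",), {}) == "BOB"
# bytes: a cloudpickled callable, deserialized and then bound
payload = cloudpickle.dumps(lambda self, n: n * 2)
assert run_method(obj, payload, ("hi",), {}) == "hihi"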
@@ -6,9 +6,11 @@ import time
 import weakref
 from dataclasses import dataclass
 from enum import Enum, auto
+from functools import partial
 from multiprocessing.process import BaseProcess
-from typing import Any, Dict, List, Optional, Tuple
+from typing import Any, Callable, Dict, List, Optional, Tuple, Union

+import cloudpickle
 import psutil
 import zmq

@@ -120,7 +122,7 @@ class MultiprocExecutor(Executor):
         return kv_cache_specs[0]

     def collective_rpc(self,
-                       method: str,
+                       method: Union[str, Callable],
                        timeout: Optional[float] = None,
                        args: Tuple = (),
                        kwargs: Optional[Dict] = None) -> List[Any]:
@@ -141,7 +143,12 @@ class MultiprocExecutor(Executor):
         kwargs = kwargs or {}

         try:
-            self.rpc_broadcast_mq.enqueue((method, args, kwargs))
+            if isinstance(method, str):
+                send_method = method
+            else:
+                send_method = cloudpickle.dumps(
+                    method, protocol=pickle.HIGHEST_PROTOCOL)
+            self.rpc_broadcast_mq.enqueue((send_method, args, kwargs))

             responses = [None] * self.world_size
             for w in self.workers:
@@ -408,7 +415,11 @@ class WorkerProc:
             method, args, kwargs = self.rpc_broadcast_mq.dequeue()

             try:
-                output = getattr(self.worker, method)(*args, **kwargs)
+                if isinstance(method, str):
+                    func = getattr(self.worker, method)
+                elif isinstance(method, bytes):
+                    func = partial(cloudpickle.loads(method), self.worker)
+                output = func(*args, **kwargs)
             except Exception as e:
                 self.worker_response_mq.enqueue(
                     (WorkerProc.ResponseStatus.FAILURE, e))
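Note that the v1 WorkerProc inlines the same dispatch instead of calling run_method; there is no else branch, but only str or bytes can arrive here, since the driver's collective_rpc above pickles any callable before enqueueing it.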
@@ -14,7 +14,8 @@ from vllm.lora.request import LoRARequest
 from vllm.model_executor.layers.sampler import SamplerOutput
 from vllm.sequence import ExecuteModelRequest, IntermediateTensors
 from vllm.utils import (enable_trace_function_call_for_thread,
-                        resolve_obj_by_qualname, update_environment_variables)
+                        resolve_obj_by_qualname, run_method,
+                        update_environment_variables)
 from vllm.worker.model_runner_base import (BroadcastableModelInput,
                                            ModelRunnerBase,
                                            ModelRunnerInputBase)
@@ -539,17 +540,16 @@ class WorkerWrapperBase:
         self.worker = worker_class(**kwargs)
         assert self.worker is not None

-    def execute_method(self, method: str, *args, **kwargs):
+    def execute_method(self, method: Union[str, bytes], *args, **kwargs):
         try:
             target = self if self.worker is None else self.worker
-            executor = getattr(target, method)
-            return executor(*args, **kwargs)
+            return run_method(target, method, args, kwargs)
         except Exception as e:
             # if the driver worker also execute methods,
             # exceptions in the rest worker may cause deadlock in rpc like ray
             # see https://github.com/vllm-project/vllm/issues/3455
             # print the error and inform the user to solve the error
-            msg = (f"Error executing method {method}. "
+            msg = (f"Error executing method {method!r}. "
                    "This might cause deadlock in distributed execution.")
             logger.exception(msg)
             raise e