From 94744ba41a2807cb195e4a41a85d4d49f6867967 Mon Sep 17 00:00:00 2001 From: wwl2755 Date: Sat, 29 Mar 2025 05:39:14 -0500 Subject: [PATCH] [V1] [Feature] Collective RPC (#15444) Signed-off-by: wwl2755 --- .buildkite/test-pipeline.yaml | 6 ++--- vllm/engine/llm_engine.py | 13 +++++++++-- vllm/entrypoints/llm.py | 4 ++-- vllm/v1/engine/core.py | 12 +++++++++- vllm/v1/engine/core_client.py | 43 ++++++++++++++++++++++++++++++++++- vllm/v1/engine/llm_engine.py | 10 +++++++- vllm/v1/serial_utils.py | 8 +++++++ 7 files changed, 86 insertions(+), 10 deletions(-) diff --git a/.buildkite/test-pipeline.yaml b/.buildkite/test-pipeline.yaml index 428b4c593c38e..62872bf8e3e18 100644 --- a/.buildkite/test-pipeline.yaml +++ b/.buildkite/test-pipeline.yaml @@ -150,8 +150,8 @@ steps: # TODO: create a dedicated test section for multi-GPU example tests # when we have multiple distributed example tests - pushd ../examples/offline_inference - - VLLM_ENABLE_V1_MULTIPROCESSING=0 python3 rlhf.py - - VLLM_ENABLE_V1_MULTIPROCESSING=0 RAY_DEDUP_LOGS=0 python3 rlhf_colocate.py + - python3 rlhf.py + - RAY_DEDUP_LOGS=0 python3 rlhf_colocate.py - popd - label: Metrics, Tracing Test # 10min @@ -520,7 +520,7 @@ steps: - vllm/v1/engine/ commands: - TP_SIZE=1 DP_SIZE=2 pytest -v -s v1/test_async_llm_dp.py - - VLLM_ENABLE_V1_MULTIPROCESSING=0 pytest -v -s entrypoints/llm/test_collective_rpc.py + - pytest -v -s entrypoints/llm/test_collective_rpc.py - pytest -v -s ./compile/test_basic_correctness.py - pytest -v -s ./compile/test_wrapper.py - VLLM_TEST_SAME_HOST=1 torchrun --nproc-per-node=4 distributed/test_same_node.py | grep 'Same node test passed' diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index 5682b3dabe2e8..10677878ecc8f 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -7,8 +7,8 @@ from collections import deque from contextlib import contextmanager from dataclasses import dataclass from functools import partial -from typing import (TYPE_CHECKING, Callable, ClassVar, Deque, Dict, Iterable, - List, Mapping, NamedTuple, Optional) +from typing import (TYPE_CHECKING, Any, Callable, ClassVar, Deque, Dict, + Iterable, List, Mapping, NamedTuple, Optional) from typing import Sequence as GenericSequence from typing import Set, Type, Union, cast, overload @@ -67,6 +67,7 @@ _LOCAL_LOGGING_INTERVAL_SEC = 5 _G = TypeVar("_G", bound=BaseTokenizerGroup, default=BaseTokenizerGroup) _O = TypeVar("_O", RequestOutput, PoolingRequestOutput) +_R = TypeVar("_R", default=Any) @dataclass @@ -2123,6 +2124,14 @@ class LLMEngine: return sampling_params + def collective_rpc(self, + method: Union[str, Callable[..., _R]], + timeout: Optional[float] = None, + args: tuple = (), + kwargs: Optional[dict[str, Any]] = None) -> list[_R]: + return self.model_executor.collective_rpc(method, timeout, args, + kwargs) + if envs.is_set("VLLM_USE_V1") and envs.VLLM_USE_V1: from vllm.v1.engine.llm_engine import LLMEngine as V1LLMEngine diff --git a/vllm/entrypoints/llm.py b/vllm/entrypoints/llm.py index 1887caf25a30f..7c354be2d45c5 100644 --- a/vllm/entrypoints/llm.py +++ b/vllm/entrypoints/llm.py @@ -492,8 +492,8 @@ class LLM: It is recommended to use this API to only pass control messages, and set up data-plane communication to pass data. 
""" - executor = self.llm_engine.model_executor - return executor.collective_rpc(method, timeout, args, kwargs) + + return self.llm_engine.collective_rpc(method, timeout, args, kwargs) def apply_model(self, func: Callable[[nn.Module], _R]) -> list[_R]: """ diff --git a/vllm/v1/engine/core.py b/vllm/v1/engine/core.py index 20904cd495f91..6083eea45cd98 100644 --- a/vllm/v1/engine/core.py +++ b/vllm/v1/engine/core.py @@ -8,7 +8,7 @@ import time from concurrent.futures import Future from inspect import isclass, signature from logging import DEBUG -from typing import Any, Optional +from typing import Any, Callable, Optional, TypeVar, Union import msgspec import psutil @@ -43,6 +43,8 @@ logger = init_logger(__name__) POLLING_TIMEOUT_S = 2.5 +_R = TypeVar('_R') # Return type for collective_rpc + class EngineCore: """Inner loop of vLLM's Engine.""" @@ -280,6 +282,14 @@ class EngineCore: def pin_lora(self, lora_id: int) -> bool: return self.model_executor.pin_lora(lora_id) + def collective_rpc(self, + method: Union[str, Callable[..., _R]], + timeout: Optional[float] = None, + args: tuple = (), + kwargs: Optional[dict[str, Any]] = None) -> list[_R]: + return self.model_executor.collective_rpc(method, timeout, args, + kwargs) + class EngineCoreProc(EngineCore): """ZMQ-wrapper for running EngineCore in background process.""" diff --git a/vllm/v1/engine/core_client.py b/vllm/v1/engine/core_client.py index 8858a564d2c2b..3dc33a1284a12 100644 --- a/vllm/v1/engine/core_client.py +++ b/vllm/v1/engine/core_client.py @@ -12,7 +12,7 @@ from collections.abc import Awaitable, Sequence from concurrent.futures import Future from dataclasses import dataclass, field from threading import Thread -from typing import Any, Callable, Optional, Union +from typing import Any, Callable, Optional, TypeVar, Union import zmq import zmq.asyncio @@ -33,6 +33,8 @@ logger = init_logger(__name__) AnyFuture = Union[asyncio.Future[Any], Future[Any]] +_R = TypeVar('_R') # Return type for collective_rpc + class EngineCoreClient(ABC): """ @@ -117,6 +119,13 @@ class EngineCoreClient(ABC): def pin_lora(self, lora_id: int) -> bool: raise NotImplementedError + def collective_rpc(self, + method: Union[str, Callable[..., _R]], + timeout: Optional[float] = None, + args: tuple = (), + kwargs: Optional[dict[str, Any]] = None) -> list[_R]: + raise NotImplementedError + async def get_output_async(self) -> EngineCoreOutputs: raise NotImplementedError @@ -153,6 +162,14 @@ class EngineCoreClient(ABC): async def pin_lora_async(self, lora_id: int) -> bool: raise NotImplementedError + async def collective_rpc_async( + self, + method: Union[str, Callable[..., _R]], + timeout: Optional[float] = None, + args: tuple = (), + kwargs: Optional[dict[str, Any]] = None) -> list[_R]: + raise NotImplementedError + class InprocClient(EngineCoreClient): """ @@ -210,6 +227,13 @@ class InprocClient(EngineCoreClient): def pin_lora(self, lora_id: int) -> bool: return self.engine_core.pin_lora(lora_id) + def collective_rpc(self, + method: Union[str, Callable[..., _R]], + timeout: Optional[float] = None, + args: tuple = (), + kwargs: Optional[dict[str, Any]] = None) -> list[_R]: + return self.engine_core.collective_rpc(method, timeout, args, kwargs) + class CoreEngine: """One per data parallel rank.""" @@ -505,6 +529,14 @@ class SyncMPClient(MPClient): def execute_dummy_batch(self) -> None: self.call_utility("execute_dummy_batch") + def collective_rpc(self, + method: Union[str, Callable[..., _R]], + timeout: Optional[float] = None, + args: tuple = (), + kwargs: 
Optional[dict[str, Any]] = None) -> list[_R]: + return self.call_utility("collective_rpc", method, timeout, args, + kwargs) + class AsyncMPClient(MPClient): """Asyncio-compatible client for multi-proc EngineCore.""" @@ -636,6 +668,15 @@ class AsyncMPClient(MPClient): async def pin_lora_async(self, lora_id: int) -> bool: return await self.call_utility_async("pin_lora", lora_id) + async def collective_rpc_async( + self, + method: Union[str, Callable[..., _R]], + timeout: Optional[float] = None, + args: tuple = (), + kwargs: Optional[dict[str, Any]] = None) -> list[_R]: + return await self.call_utility_async("collective_rpc", method, timeout, + args, kwargs) + class DPAsyncMPClient(AsyncMPClient): """Asyncio-compatible client for multi-proc, multi-engine (data parallel) diff --git a/vllm/v1/engine/llm_engine.py b/vllm/v1/engine/llm_engine.py index 000de21fbe7bf..764c643b5c974 100644 --- a/vllm/v1/engine/llm_engine.py +++ b/vllm/v1/engine/llm_engine.py @@ -2,7 +2,7 @@ from collections.abc import Mapping from copy import copy -from typing import Optional, Union +from typing import Any, Callable, Optional, Union from typing_extensions import TypeVar @@ -32,6 +32,7 @@ from vllm.v1.executor.abstract import Executor logger = init_logger(__name__) _G = TypeVar("_G", bound=BaseTokenizerGroup, default=BaseTokenizerGroup) +_R = TypeVar("_R", default=Any) class LLMEngine: @@ -282,6 +283,13 @@ class LLMEngine: """Prevent an adapter from being evicted.""" return self.engine_core.pin_lora(lora_id) + def collective_rpc(self, + method: Union[str, Callable[..., _R]], + timeout: Optional[float] = None, + args: tuple = (), + kwargs: Optional[dict[str, Any]] = None) -> list[_R]: + return self.engine_core.collective_rpc(method, timeout, args, kwargs) + def __del__(self): if dp_group := getattr(self, "dp_group", None): stateless_destroy_torch_distributed_process_group(dp_group) diff --git a/vllm/v1/serial_utils.py b/vllm/v1/serial_utils.py index 3f000abcde0d1..146d7d747f1a4 100644 --- a/vllm/v1/serial_utils.py +++ b/vllm/v1/serial_utils.py @@ -1,13 +1,16 @@ # SPDX-License-Identifier: Apache-2.0 import pickle +from types import FunctionType from typing import Any, Optional +import cloudpickle import torch from msgspec import msgpack CUSTOM_TYPE_TENSOR = 1 CUSTOM_TYPE_PICKLE = 2 +CUSTOM_TYPE_CLOUDPICKLE = 3 class MsgpackEncoder: @@ -41,6 +44,9 @@ def custom_enc_hook(obj: Any) -> Any: # https://gist.github.com/tlrmchlsmth/8067f1b24a82b6e2f90450e7764fa103 # noqa: E501 return msgpack.Ext(CUSTOM_TYPE_TENSOR, pickle.dumps(obj.numpy())) + if isinstance(obj, FunctionType): + return msgpack.Ext(CUSTOM_TYPE_CLOUDPICKLE, cloudpickle.dumps(obj)) + return msgpack.Ext(CUSTOM_TYPE_PICKLE, pickle.dumps(obj)) @@ -49,5 +55,7 @@ def custom_ext_hook(code: int, data: memoryview) -> Any: return torch.from_numpy(pickle.loads(data)) if code == CUSTOM_TYPE_PICKLE: return pickle.loads(data) + if code == CUSTOM_TYPE_CLOUDPICKLE: + return cloudpickle.loads(data) raise NotImplementedError(f"Extension type code {code} is not supported")
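
Usage sketch (illustrative, not part of the patch): with the default V1 multiprocessing engine, LLM.collective_rpc now flows through LLMEngine.collective_rpc -> EngineCoreClient -> EngineCore -> model_executor.collective_rpc, which is why the VLLM_ENABLE_V1_MULTIPROCESSING=0 workarounds above can be dropped. The model name below is an assumption for the example, and the helper assumes the usual convention that a callable passed to collective_rpc receives the worker object (which exposes a rank attribute) as its first argument.

    from vllm import LLM


    def echo_rank(worker) -> int:
        # Runs inside each worker process. The function object itself is
        # shipped to the engine-core process via the cloudpickle msgpack
        # extension added in serial_utils.py.
        return worker.rank


    if __name__ == "__main__":
        # Model choice is only illustrative; any small model works.
        llm = LLM(model="facebook/opt-125m", enforce_eager=True)
        # One result per worker, e.g. [0] for a single-GPU, TP=1 run.
        print(llm.collective_rpc(echo_rank))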
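
Serialization sketch (illustrative, not part of the patch): the new CUSTOM_TYPE_CLOUDPICKLE extension type is what lets a plain Python function cross the msgpack/ZMQ boundary between the client and the engine-core process. The snippet below wires msgspec's Encoder/Decoder up with the hooks from serial_utils.py directly; the in-tree MsgpackEncoder presumably uses the same hooks.

    from msgspec import msgpack

    from vllm.v1.serial_utils import custom_enc_hook, custom_ext_hook

    # Functions are not natively msgpack-serializable, so encoding falls back
    # to custom_enc_hook, which wraps FunctionType objects in
    # Ext(CUSTOM_TYPE_CLOUDPICKLE, cloudpickle.dumps(...)); custom_ext_hook
    # rebuilds them with cloudpickle.loads on the receiving side.
    encoder = msgpack.Encoder(enc_hook=custom_enc_hook)
    decoder = msgpack.Decoder(ext_hook=custom_ext_hook)


    def double(x: int) -> int:
        return 2 * x


    restored = decoder.decode(encoder.encode(double))
    assert restored(21) == 42  # the function survives the round trip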