From b712be98c790794479030313f2c2b9dae17ea7de Mon Sep 17 00:00:00 2001 From: Yan Ru Pei Date: Tue, 3 Jun 2025 17:14:20 -0700 Subject: [PATCH] feat: add data parallel rank to KVEventBatch (#18925) --- .buildkite/test-pipeline.yaml | 2 + tests/distributed/conftest.py | 101 ++++++----- tests/distributed/test_events.py | 69 +++++++- tests/v1/engine/test_engine_core_client.py | 189 +++++++++++++++++---- vllm/distributed/kv_events.py | 77 ++++++++- vllm/v1/core/sched/scheduler.py | 4 +- 6 files changed, 359 insertions(+), 83 deletions(-) diff --git a/.buildkite/test-pipeline.yaml b/.buildkite/test-pipeline.yaml index 5fb8ceaace05d..8ab96b3b7ac3c 100644 --- a/.buildkite/test-pipeline.yaml +++ b/.buildkite/test-pipeline.yaml @@ -145,6 +145,7 @@ steps: - examples/offline_inference/rlhf_colocate.py - tests/examples/offline_inference/data_parallel.py - tests/v1/test_async_llm_dp.py + - tests/v1/engine/test_engine_core_client.py commands: # test with tp=2 and external_dp=2 - VLLM_USE_V1=0 torchrun --nproc-per-node=4 distributed/test_torchrun_example.py @@ -154,6 +155,7 @@ steps: # test with internal dp - python3 ../examples/offline_inference/data_parallel.py - TP_SIZE=2 DP_SIZE=2 pytest -v -s v1/test_async_llm_dp.py + - pytest -v -s v1/engine/test_engine_core_client.py::test_kv_cache_events_dp - pytest -v -s distributed/test_utils.py - pytest -v -s compile/test_basic_correctness.py - pytest -v -s distributed/test_pynccl.py diff --git a/tests/distributed/conftest.py b/tests/distributed/conftest.py index 95f085788b856..666a715cc0da1 100644 --- a/tests/distributed/conftest.py +++ b/tests/distributed/conftest.py @@ -13,11 +13,13 @@ from vllm.distributed.kv_events import EventPublisherFactory from .test_events import SampleBatch +DP_RANK = 0 + @pytest.fixture def random_port(): """Generate a random port number for testing""" - return random.randint(10000, 60000) + return random.randint(10000, 59900) @pytest.fixture @@ -30,21 +32,23 @@ def publisher_config(random_port, request): replay_endpoint = endpoint + "-replay" else: endpoint = f"tcp://*:{random_port}" - replay_endpoint = f"tcp://*:{random_port + 1}" + replay_endpoint = f"tcp://*:{random_port + 100}" - return KVEventsConfig(enable_kv_cache_events=True, - publisher="zmq", - endpoint=endpoint, - replay_endpoint=replay_endpoint, - buffer_steps=100, - hwm=1000, - topic="test") + return KVEventsConfig( + enable_kv_cache_events=True, + publisher="zmq", + endpoint=endpoint, + replay_endpoint=replay_endpoint, + buffer_steps=100, + hwm=1000, + topic="test", + ) @pytest.fixture def publisher(publisher_config): """Create and return a publisher instance""" - pub = EventPublisherFactory.create(publisher_config) + pub = EventPublisherFactory.create(publisher_config, DP_RANK) yield pub pub.shutdown() @@ -60,7 +64,11 @@ def subscriber(publisher_config): if replay_endpoint and replay_endpoint.startswith("tcp://*"): replay_endpoint = replay_endpoint.replace("*", "127.0.0.1") - sub = MockSubscriber(endpoint, replay_endpoint, publisher_config.topic) + sub = MockSubscriber( + [endpoint], + [replay_endpoint] if replay_endpoint else None, + publisher_config.topic, + ) yield sub sub.close() @@ -68,26 +76,37 @@ def subscriber(publisher_config): class MockSubscriber: """Helper class to receive and verify published events""" - def __init__(self, - pub_endpoint: str, - replay_endpoint: Optional[str] = None, - topic: str = "", - decode_type=SampleBatch): + def __init__( + self, + pub_endpoints: Union[str, list[str]], + replay_endpoints: Optional[Union[str, list[str]]] = None, + topic: str = "", + decode_type=SampleBatch, + ): self.ctx = zmq.Context.instance() - # Set up subscriber socket - self.sub = self.ctx.socket(zmq.SUB) - self.sub.setsockopt(zmq.SUBSCRIBE, topic.encode('utf-8')) - self.sub.connect(pub_endpoint) + # Convert single endpoint to list for consistency + if isinstance(pub_endpoints, str): + pub_endpoints = [pub_endpoints] + if isinstance(replay_endpoints, str): + replay_endpoints = [replay_endpoints] - # Set up replay socket if provided - self.replay = None - if replay_endpoint: - self.replay = self.ctx.socket(zmq.REQ) - self.replay.connect(replay_endpoint) + # Set up subscriber socket - connect to all endpoints + self.sub = self.ctx.socket(zmq.SUB) + self.sub.setsockopt(zmq.SUBSCRIBE, topic.encode("utf-8")) + for endpoint in pub_endpoints: + self.sub.connect(endpoint) + + # Set up replay sockets if provided + self.replay_sockets = [] + if replay_endpoints: + for replay_endpoint in replay_endpoints: + replay = self.ctx.socket(zmq.REQ) + replay.connect(replay_endpoint) + self.replay_sockets.append(replay) self.topic = topic - self.topic_bytes = topic.encode('utf-8') + self.topic_bytes = topic.encode("utf-8") self.received_msgs: list[tuple[int, SampleBatch]] = [] self.last_seq = -1 self.decoder = msgspec.msgpack.Decoder(type=decode_type) @@ -107,25 +126,31 @@ class MockSubscriber: self.received_msgs.append((seq, data)) return seq, data - def request_replay(self, start_seq: int) -> None: + def request_replay(self, start_seq: int, socket_idx: int = 0) -> None: """Request replay of messages starting from start_seq""" - if not self.replay: - raise ValueError("Replay socket not initialized") + if not self.replay_sockets: + raise ValueError("Replay sockets not initialized") + if socket_idx >= len(self.replay_sockets): + raise ValueError(f"Invalid socket index {socket_idx}") - self.replay.send(start_seq.to_bytes(8, "big")) + self.replay_sockets[socket_idx].send(start_seq.to_bytes(8, "big")) - def receive_replay(self) -> list[tuple[int, SampleBatch]]: - """Receive replayed messages""" - if not self.replay: - raise ValueError("Replay socket not initialized") + def receive_replay(self, + socket_idx: int = 0) -> list[tuple[int, SampleBatch]]: + """Receive replayed messages from a specific replay socket""" + if not self.replay_sockets: + raise ValueError("Replay sockets not initialized") + if socket_idx >= len(self.replay_sockets): + raise ValueError(f"Invalid socket index {socket_idx}") + replay_socket = self.replay_sockets[socket_idx] replayed: list[tuple[int, SampleBatch]] = [] while True: try: - if not self.replay.poll(1000): + if not replay_socket.poll(1000): break - frames = self.replay.recv_multipart() + frames = replay_socket.recv_multipart() if not frames or not frames[-1]: # End of replay marker break @@ -142,5 +167,5 @@ class MockSubscriber: def close(self): """Clean up resources""" self.sub.close() - if self.replay: - self.replay.close() + for replay in self.replay_sockets: + replay.close() diff --git a/tests/distributed/test_events.py b/tests/distributed/test_events.py index ec1e5a2d62f11..8be9ee0a1889d 100644 --- a/tests/distributed/test_events.py +++ b/tests/distributed/test_events.py @@ -9,6 +9,8 @@ import pytest from vllm.distributed.kv_events import (EventBatch, EventPublisherFactory, NullEventPublisher) +DP_RANK = 0 + class EventSample( msgspec.Struct, @@ -121,7 +123,7 @@ def test_topic_filtering(publisher_config): publisher_config.replay_endpoint = None publisher_config.topic = "foo" - pub = EventPublisherFactory.create(publisher_config) + pub = EventPublisherFactory.create(publisher_config, DP_RANK) from .conftest import MockSubscriber sub_foo = MockSubscriber(publisher_config.endpoint, None, "foo") @@ -185,9 +187,72 @@ def test_high_volume(publisher, subscriber): def test_null_publisher(): """Test that NullEventPublisher can be used without errors""" - publisher = NullEventPublisher() + publisher = NullEventPublisher(DP_RANK) # This should not raise any errors batch = create_test_events(5) publisher.publish(batch) publisher.shutdown() + + +def test_data_parallel_rank_tagging(publisher_config): + """Test that events are properly tagged with their data parallel rank""" + + publisher_config.topic = "foo" + pub_0 = EventPublisherFactory.create(publisher_config, DP_RANK) + pub_1 = EventPublisherFactory.create(publisher_config, DP_RANK + 1) + + # Hardcode the expected endpoints based on port offsetting behavior + # Both ranks get offsets according to _offset_endpoint_port function + base_endpoint = publisher_config.endpoint + if "tcp://" in base_endpoint: + # For TCP endpoints: tcp://localhost:5557 -> tcp://localhost:5557, tcp://localhost:5558 + expected_endpoint_0 = base_endpoint # rank 0 gets port + 0 = same port + expected_endpoint_1 = base_endpoint.replace( + ":5557", ":5558") # rank 1 gets port + 1 + else: + # For inproc endpoints: inproc://test -> inproc://test_dp0, inproc://test_dp1 + expected_endpoint_0 = base_endpoint # rank 0 gets base + expected_endpoint_1 = base_endpoint + "_dp1" # rank 1 gets _dp1 + + from .conftest import MockSubscriber + sub_0 = MockSubscriber(expected_endpoint_0, None, publisher_config.topic) + sub_1 = MockSubscriber(expected_endpoint_1, None, publisher_config.topic) + + try: + time.sleep(0.1) # Let publishers start up + + # Publish events from different ranks + batch_0 = create_test_events(2) + batch_1 = create_test_events(3) + + pub_0.publish(batch_0) + pub_1.publish(batch_1) + + # Receive events from rank 0 + result_0 = sub_0.receive_one(timeout=200) + assert result_0 is not None, "No message received from rank 0" + seq_0, received_0 = result_0 + + # Receive events from rank 1 + result_1 = sub_1.receive_one(timeout=200) + assert result_1 is not None, "No message received from rank 1" + seq_1, received_1 = result_1 + + # Verify DP rank tagging + assert received_0.data_parallel_rank == 0, ( + f"Expected DP rank 0, got {received_0.data_parallel_rank}") + assert received_1.data_parallel_rank == 1, ( + f"Expected DP rank 1, got {received_1.data_parallel_rank}") + + # Verify event content is correct + assert len( + received_0.events) == 2, "Wrong number of events from rank 0" + assert len( + received_1.events) == 3, "Wrong number of events from rank 1" + + finally: + pub_0.shutdown() + pub_1.shutdown() + sub_0.close() + sub_1.close() diff --git a/tests/v1/engine/test_engine_core_client.py b/tests/v1/engine/test_engine_core_client.py index a01b205dfaed5..47181d36f4ccc 100644 --- a/tests/v1/engine/test_engine_core_client.py +++ b/tests/v1/engine/test_engine_core_client.py @@ -12,8 +12,10 @@ from typing import Optional import pytest from transformers import AutoTokenizer +from tests.utils import multi_gpu_test from vllm import SamplingParams -from vllm.distributed.kv_events import BlockStored, KVEventBatch +from vllm.distributed.kv_events import (BlockStored, KVEventBatch, + ZmqEventPublisher) from vllm.engine.arg_utils import EngineArgs from vllm.platforms import current_platform from vllm.usage.usage_lib import UsageContext @@ -37,10 +39,15 @@ PROMPT = "Hello my name is Robert and I love quantization kernels" PROMPT_TOKENS = TOKENIZER(PROMPT).input_ids -def make_request(params: SamplingParams) -> EngineCoreRequest: +def make_request( + params: SamplingParams, + prompt_tokens_ids: Optional[list[int]] = None) -> EngineCoreRequest: + if not prompt_tokens_ids: + prompt_tokens_ids = PROMPT_TOKENS + return EngineCoreRequest( request_id=str(uuid.uuid4()), - prompt_token_ids=PROMPT_TOKENS, + prompt_token_ids=prompt_tokens_ids, mm_inputs=None, mm_hashes=None, mm_placeholders=None, @@ -88,6 +95,25 @@ async def loop_until_done_async(client: EngineCoreClient, outputs: dict): break +async def loop_until_fully_done_async(client: EngineCoreClient, outputs: dict): + + while True: + engine_core_outputs = (await client.get_output_async()).outputs + + if len(engine_core_outputs) == 0: + continue + + # Add outputs to the dict + for out in engine_core_outputs: + outputs[out.request_id].append(out) + + # Check if all request IDs in outputs have finished + if all(outs and outs[-1].finished for outs in outputs.values()): + break + + await asyncio.sleep(0.1) + + # Dummy utility function to monkey-patch into engine core. def echo(self, msg: str, err_msg: Optional[str] = None) -> str: print(f"echo util function called: {msg}, {err_msg}") @@ -273,10 +299,12 @@ def test_kv_cache_events( block_size = 16 num_blocks = 2 - engine_args = EngineArgs(model=MODEL_NAME, - enforce_eager=True, - enable_prefix_caching=True, - block_size=block_size) + engine_args = EngineArgs( + model=MODEL_NAME, + enforce_eager=True, + enable_prefix_caching=True, + block_size=block_size, + ) engine_args.kv_events_config = publisher_config vllm_config = engine_args.create_engine_config( @@ -297,19 +325,8 @@ def test_kv_cache_events( try: custom_tokens = list(range(num_blocks * block_size)) - request = EngineCoreRequest( - request_id=str(uuid.uuid4()), - prompt_token_ids=custom_tokens, - mm_inputs=None, - mm_hashes=None, - mm_placeholders=None, - sampling_params=SamplingParams( - max_tokens=1), # Short completion for speed - eos_token_id=None, - arrival_time=time.time(), - lora_request=None, - cache_salt=None, - ) + sampling_params = SamplingParams(max_tokens=1) + request = make_request(sampling_params, custom_tokens) client.add_request(request) outputs: dict[str, list] = {request.request_id: []} @@ -321,24 +338,130 @@ def test_kv_cache_events( seq, received = result assert seq == 0, "Sequence number mismatch" - assert len(received.events) == 1, ( - "We should have exactly one BlockStored event") + assert (len(received.events) == 1 + ), "We should have exactly one BlockStored event" event = received.events[0] assert isinstance( - event, BlockStored), ("We should have a BlockStored event") - assert len(event.block_hashes) == num_blocks, ( - "We should have a BlockStored event with 2 block_hashes") - assert event.block_size == block_size, ( - "Block size should be the same as the block size") - assert event.parent_block_hash is None, ( - "Parent block hash should be None") + event, BlockStored), "We should have a BlockStored event" + assert (len(event.block_hashes) == num_blocks + ), "We should have a BlockStored event with 2 block_hashes" + assert (event.block_size == block_size + ), "Block size should be the same as the block size" + assert (event.parent_block_hash + is None), "Parent block hash should be None" assert event.lora_id is None, "Lora id should be None" - assert len(event.token_ids) == num_blocks * block_size, ( - "Token ids should be the same as the custom tokens") - assert event.token_ids == custom_tokens, ( - "Token ids should be the same as the custom tokens") + assert (len(event.token_ids) == num_blocks * block_size + ), "Token ids should be the same as the custom tokens" + assert (event.token_ids == custom_tokens + ), "Token ids should be the same as the custom tokens" finally: client.shutdown() + subscriber.close() + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + "multiprocessing_mode,publisher_config", + [(True, "tcp")], + indirect=["publisher_config"], +) +@multi_gpu_test(num_gpus=4) +async def test_kv_cache_events_dp( + monkeypatch: pytest.MonkeyPatch, + multiprocessing_mode: bool, + publisher_config, +): + + with monkeypatch.context() as m: + m.setenv("VLLM_USE_V1", "1") + block_size = 16 + num_blocks = 2 + dp_size = 2 + tp_size = 2 + + engine_args = EngineArgs( + model=MODEL_NAME, + enforce_eager=True, + enable_prefix_caching=True, + data_parallel_size=dp_size, + tensor_parallel_size=tp_size, + block_size=block_size, + ) + engine_args.kv_events_config = publisher_config + + vllm_config = engine_args.create_engine_config( + UsageContext.UNKNOWN_CONTEXT) + + executor_class = Executor.get_class(vllm_config) + client = EngineCoreClient.make_client( + multiprocess_mode=multiprocessing_mode, + asyncio_mode=True, + vllm_config=vllm_config, + executor_class=executor_class, + log_stats=False, + ) + await asyncio.sleep(1) + + # Build endpoints for all DP ranks + base_endpoint = publisher_config.endpoint.replace("*", "127.0.0.1") + endpoints = [] + for i in range(dp_size): + offset_endpoint = ZmqEventPublisher.offset_endpoint_port( + base_endpoint, i) + endpoints.append(offset_endpoint) + + subscriber = MockSubscriber(endpoints, + topic=publisher_config.topic, + decode_type=KVEventBatch) + + try: + custom_tokens = list(range(num_blocks * block_size)) + sampling_params = SamplingParams(max_tokens=1) + all_request_ids = [] + + # Create and add 25 requests + # NOTE: attempts to force routing to both dp groups but can be flaky + for i in range(25): + await asyncio.sleep(0.01) + request = make_request(sampling_params, custom_tokens) + await client.add_request_async(request) + all_request_ids.append(request.request_id) + + await asyncio.sleep(0.1) + + # Initialize outputs dict for all requests + outputs: dict[str, list] = { + req_id: [] + for req_id in all_request_ids + } + + print("processing requests...") + await asyncio.wait_for(loop_until_fully_done_async( + client, outputs), + timeout=20.0) + + # Receive from subscriber until no more messages + print("collecting results...") + results = [] + while True: + result = subscriber.receive_one(timeout=1) + print(result) + if result is None: + break + results.append(result) + + # Collect all events and data_parallel_ranks from all results + all_dp_ranks = [ + received.data_parallel_rank for (_, received) in results + ] + unique_dps = set(all_dp_ranks) + assert ( + len(unique_dps) == 2 + ), f"Expected 2 unique data_parallel_ranks, got {len(unique_dps)}" + + finally: + client.shutdown() + subscriber.close() @pytest.mark.timeout(20) diff --git a/vllm/distributed/kv_events.py b/vllm/distributed/kv_events.py index 9bf1c058a1915..2d7935773dd9f 100644 --- a/vllm/distributed/kv_events.py +++ b/vllm/distributed/kv_events.py @@ -28,6 +28,7 @@ class EventBatch( ): ts: float events: list[Any] + data_parallel_rank: Optional[int] = None class KVCacheEvent( @@ -60,7 +61,22 @@ class KVEventBatch(EventBatch): class EventPublisher(ABC): - """Lightweight publisher for EventBatch batches.""" + """Lightweight publisher for EventBatch batches with data parallelism + support. + + In data parallel setups, each DP rank runs its own EventPublisher instance + to avoid duplicate events and ensure proper event attribution: + + - Each DP rank creates a separate publisher + - Publishers automatically annotate events with their data_parallel_rank + - This allows consumers to distinguish events from different DP ranks + + The publisher is responsible for adding DP metadata since the scheduler + operates independently of DP topology and shouldn't need DP awareness. + """ + + def __init__(self, data_parallel_rank: int = 0) -> None: + self._data_parallel_rank = data_parallel_rank @abstractmethod def publish(self, events: EventBatch) -> None: @@ -113,6 +129,7 @@ class ZmqEventPublisher(EventPublisher): def __init__( self, + data_parallel_rank: int, endpoint: str = "tcp://*:5557", replay_endpoint: Optional[str] = None, buffer_steps: int = 10_000, @@ -121,6 +138,7 @@ class ZmqEventPublisher(EventPublisher): topic: str = "", ) -> None: # Storage + super().__init__(data_parallel_rank) self._event_queue = Queue[Optional[EventBatch]](maxsize=max_queue_size) self._buffer = deque[tuple[int, bytes]](maxlen=buffer_steps) @@ -128,8 +146,11 @@ class ZmqEventPublisher(EventPublisher): self._ctx = zmq.Context.instance() self._pub: Optional[zmq.Socket] = None self._replay: Optional[zmq.Socket] = None - self._endpoint = endpoint - self._replay_endpoint = replay_endpoint + self._dp_rank = data_parallel_rank + + self._endpoint = self.offset_endpoint_port(endpoint, self._dp_rank) + self._replay_endpoint = self.offset_endpoint_port( + replay_endpoint, self._dp_rank) self._hwm = hwm self._socket_setup() @@ -149,6 +170,8 @@ class ZmqEventPublisher(EventPublisher): def publish(self, events: EventBatch) -> None: if not self._running: raise RuntimeError("Publisher is closed") + if events.data_parallel_rank is None: + events.data_parallel_rank = self._data_parallel_rank self._event_queue.put(events) def shutdown(self) -> None: @@ -191,11 +214,12 @@ class ZmqEventPublisher(EventPublisher): self._pub.set_hwm(self._hwm) # Heuristic: bind if wildcard / * present, else connect. # bind stable, connect volatile convention - if ("*" in self._endpoint or "::" in self._endpoint - or self._endpoint.startswith("ipc://") - or self._endpoint.startswith("inproc://")): + if (self._endpoint is not None + and ("*" in self._endpoint or "::" in self._endpoint + or self._endpoint.startswith("ipc://") + or self._endpoint.startswith("inproc://"))): self._pub.bind(self._endpoint) - else: + elif self._endpoint is not None: self._pub.connect(self._endpoint) # Set up replay socket: use ROUTER @@ -266,6 +290,38 @@ class ZmqEventPublisher(EventPublisher): # receiving payload is (-1, b""") self._replay.send_multipart((client_id, b"", self.END_SEQ, b"")) + @staticmethod + def offset_endpoint_port(endpoint: Optional[str], + data_parallel_rank: int) -> Optional[str]: + """Helper function to offset the port in an endpoint by + the data parallel rank. + + Args: + endpoint: The endpoint string + (e.g., "tcp://*:5557" or "inproc://cache") + data_parallel_rank: The data parallel rank to offset by + + Returns: + The endpoint with the port offset by data_parallel_rank + or suffix appended + """ + # Do nothing if input is None or data_parallel_rank is 0 + if not endpoint or data_parallel_rank == 0: + return endpoint + + if "inproc" in endpoint: + return f"{endpoint}_dp{data_parallel_rank}" + if "tcp" in endpoint: + if endpoint and ":" in endpoint: + # Get everything after the last colon (the port) + last_colon_idx = endpoint.rfind(":") + base_addr = endpoint[:last_colon_idx] + base_port = int(endpoint[last_colon_idx + 1:]) + new_port = base_port + data_parallel_rank + return f"{base_addr}:{new_port}" + return endpoint + raise ValueError("Invalid endpoint: must contain 'inproc' or 'tcp'") + class EventPublisherFactory: _registry: dict[str, Callable[..., EventPublisher]] = { @@ -281,7 +337,9 @@ class EventPublisherFactory: cls._registry[name] = ctor @classmethod - def create(cls, config: Optional[KVEventsConfig]) -> EventPublisher: + def create(cls, + config: Optional[KVEventsConfig], + data_parallel_rank: int = 0) -> EventPublisher: """Create publisher from a config mapping.""" if not config: return NullEventPublisher() @@ -294,4 +352,5 @@ class EventPublisherFactory: constructor = cls._registry[kind] except KeyError as exc: raise ValueError(f"Unknown event publisher '{kind}'") from exc - return constructor(**config_dict) + return constructor(data_parallel_rank=data_parallel_rank, + **config_dict) diff --git a/vllm/v1/core/sched/scheduler.py b/vllm/v1/core/sched/scheduler.py index e510a0626c1b4..32d03b311a4ed 100644 --- a/vllm/v1/core/sched/scheduler.py +++ b/vllm/v1/core/sched/scheduler.py @@ -80,7 +80,9 @@ class Scheduler(SchedulerInterface): config=self.vllm_config, role=KVConnectorRole.SCHEDULER) self.kv_event_publisher = EventPublisherFactory.create( - self.kv_events_config) + self.kv_events_config, + vllm_config.parallel_config.data_parallel_rank, + ) num_gpu_blocks = self.cache_config.num_gpu_blocks assert num_gpu_blocks is not None and num_gpu_blocks > 0