feat: add data parallel rank to KVEventBatch (#18925)

This commit is contained in:
Yan Ru Pei 2025-06-03 17:14:20 -07:00 committed by GitHub
parent a8da78eac9
commit b712be98c7
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
6 changed files with 359 additions and 83 deletions

View File

@ -145,6 +145,7 @@ steps:
- examples/offline_inference/rlhf_colocate.py
- tests/examples/offline_inference/data_parallel.py
- tests/v1/test_async_llm_dp.py
- tests/v1/engine/test_engine_core_client.py
commands:
# test with tp=2 and external_dp=2
- VLLM_USE_V1=0 torchrun --nproc-per-node=4 distributed/test_torchrun_example.py
@ -154,6 +155,7 @@ steps:
# test with internal dp
- python3 ../examples/offline_inference/data_parallel.py
- TP_SIZE=2 DP_SIZE=2 pytest -v -s v1/test_async_llm_dp.py
- pytest -v -s v1/engine/test_engine_core_client.py::test_kv_cache_events_dp
- pytest -v -s distributed/test_utils.py
- pytest -v -s compile/test_basic_correctness.py
- pytest -v -s distributed/test_pynccl.py

View File

@ -13,11 +13,13 @@ from vllm.distributed.kv_events import EventPublisherFactory
from .test_events import SampleBatch
DP_RANK = 0
@pytest.fixture
def random_port():
    """Generate a random port number for testing.

    Upper bound is 59900 (not 60000) so that derived endpoints that add an
    offset to this port (e.g. the replay endpoint at port + 100, or
    per-DP-rank offsets) still land below 60000.
    """
    return random.randint(10000, 59900)
@pytest.fixture
@ -30,21 +32,23 @@ def publisher_config(random_port, request):
replay_endpoint = endpoint + "-replay"
else:
endpoint = f"tcp://*:{random_port}"
replay_endpoint = f"tcp://*:{random_port + 1}"
replay_endpoint = f"tcp://*:{random_port + 100}"
return KVEventsConfig(enable_kv_cache_events=True,
publisher="zmq",
endpoint=endpoint,
replay_endpoint=replay_endpoint,
buffer_steps=100,
hwm=1000,
topic="test")
return KVEventsConfig(
enable_kv_cache_events=True,
publisher="zmq",
endpoint=endpoint,
replay_endpoint=replay_endpoint,
buffer_steps=100,
hwm=1000,
topic="test",
)
@pytest.fixture
def publisher(publisher_config):
    """Create and return a publisher instance for DP rank 0.

    Shuts the publisher down after the test to release its sockets.
    """
    pub = EventPublisherFactory.create(publisher_config, DP_RANK)
    yield pub
    pub.shutdown()
@ -60,7 +64,11 @@ def subscriber(publisher_config):
if replay_endpoint and replay_endpoint.startswith("tcp://*"):
replay_endpoint = replay_endpoint.replace("*", "127.0.0.1")
sub = MockSubscriber(endpoint, replay_endpoint, publisher_config.topic)
sub = MockSubscriber(
[endpoint],
[replay_endpoint] if replay_endpoint else None,
publisher_config.topic,
)
yield sub
sub.close()
@ -68,26 +76,37 @@ def subscriber(publisher_config):
class MockSubscriber:
    """Helper class to receive and verify published events.

    Subscribes to one or more publisher endpoints (one per DP rank in data
    parallel setups) and optionally opens one REQ socket per replay
    endpoint so missed messages can be re-requested.
    """

    def __init__(
        self,
        pub_endpoints: Union[str, list[str]],
        replay_endpoints: Optional[Union[str, list[str]]] = None,
        topic: str = "",
        decode_type=SampleBatch,
    ):
        self.ctx = zmq.Context.instance()

        # Accept a single endpoint or a list of endpoints for convenience.
        if isinstance(pub_endpoints, str):
            pub_endpoints = [pub_endpoints]
        if isinstance(replay_endpoints, str):
            replay_endpoints = [replay_endpoints]

        # Set up subscriber socket - connect to all publisher endpoints.
        self.sub = self.ctx.socket(zmq.SUB)
        self.sub.setsockopt(zmq.SUBSCRIBE, topic.encode("utf-8"))
        for endpoint in pub_endpoints:
            self.sub.connect(endpoint)

        # Set up one REQ replay socket per replay endpoint, if provided.
        self.replay_sockets = []
        if replay_endpoints:
            for replay_endpoint in replay_endpoints:
                replay = self.ctx.socket(zmq.REQ)
                replay.connect(replay_endpoint)
                self.replay_sockets.append(replay)

        self.topic = topic
        self.topic_bytes = topic.encode("utf-8")
        # (seq, payload) pairs in arrival order; last_seq tracks gaps.
        self.received_msgs: list[tuple[int, SampleBatch]] = []
        self.last_seq = -1
        self.decoder = msgspec.msgpack.Decoder(type=decode_type)
@ -107,25 +126,31 @@ class MockSubscriber:
self.received_msgs.append((seq, data))
return seq, data
def request_replay(self, start_seq: int, socket_idx: int = 0) -> None:
    """Request replay of messages starting from start_seq.

    Args:
        start_seq: First sequence number to replay.
        socket_idx: Index of the replay socket (one per DP rank) to use.

    Raises:
        ValueError: if no replay sockets exist or socket_idx is out of range.
    """
    if not self.replay_sockets:
        raise ValueError("Replay sockets not initialized")
    if socket_idx >= len(self.replay_sockets):
        raise ValueError(f"Invalid socket index {socket_idx}")

    # Replay protocol: send the 8-byte big-endian start sequence number.
    self.replay_sockets[socket_idx].send(start_seq.to_bytes(8, "big"))
def receive_replay(self) -> list[tuple[int, SampleBatch]]:
"""Receive replayed messages"""
if not self.replay:
raise ValueError("Replay socket not initialized")
def receive_replay(self,
socket_idx: int = 0) -> list[tuple[int, SampleBatch]]:
"""Receive replayed messages from a specific replay socket"""
if not self.replay_sockets:
raise ValueError("Replay sockets not initialized")
if socket_idx >= len(self.replay_sockets):
raise ValueError(f"Invalid socket index {socket_idx}")
replay_socket = self.replay_sockets[socket_idx]
replayed: list[tuple[int, SampleBatch]] = []
while True:
try:
if not self.replay.poll(1000):
if not replay_socket.poll(1000):
break
frames = self.replay.recv_multipart()
frames = replay_socket.recv_multipart()
if not frames or not frames[-1]:
# End of replay marker
break
@ -142,5 +167,5 @@ class MockSubscriber:
def close(self):
    """Clean up resources: the SUB socket and all replay REQ sockets."""
    self.sub.close()
    for replay in self.replay_sockets:
        replay.close()

View File

@ -9,6 +9,8 @@ import pytest
from vllm.distributed.kv_events import (EventBatch, EventPublisherFactory,
NullEventPublisher)
DP_RANK = 0
class EventSample(
msgspec.Struct,
@ -121,7 +123,7 @@ def test_topic_filtering(publisher_config):
publisher_config.replay_endpoint = None
publisher_config.topic = "foo"
pub = EventPublisherFactory.create(publisher_config)
pub = EventPublisherFactory.create(publisher_config, DP_RANK)
from .conftest import MockSubscriber
sub_foo = MockSubscriber(publisher_config.endpoint, None, "foo")
@ -185,9 +187,72 @@ def test_high_volume(publisher, subscriber):
def test_null_publisher():
    """Test that NullEventPublisher can be used without errors."""
    publisher = NullEventPublisher(DP_RANK)

    # Publishing and shutting down a no-op publisher must not raise.
    batch = create_test_events(5)
    publisher.publish(batch)
    publisher.shutdown()
def test_data_parallel_rank_tagging(publisher_config):
    """Test that events are properly tagged with their data parallel rank."""
    publisher_config.topic = "foo"
    pub_0 = EventPublisherFactory.create(publisher_config, DP_RANK)
    pub_1 = EventPublisherFactory.create(publisher_config, DP_RANK + 1)

    # Derive the per-rank endpoints from the configured base endpoint
    # instead of hard-coding a port: rank 0 keeps the base endpoint; for
    # tcp endpoints rank N publishes on port + N, for inproc endpoints
    # rank N gets a "_dpN" suffix (matches ZmqEventPublisher's offsetting).
    base_endpoint = publisher_config.endpoint
    expected_endpoint_0 = base_endpoint
    if "tcp://" in base_endpoint:
        last_colon_idx = base_endpoint.rfind(":")
        base_port = int(base_endpoint[last_colon_idx + 1:])
        expected_endpoint_1 = (
            f"{base_endpoint[:last_colon_idx]}:{base_port + 1}")
    else:
        expected_endpoint_1 = base_endpoint + "_dp1"

    from .conftest import MockSubscriber
    sub_0 = MockSubscriber(expected_endpoint_0, None, publisher_config.topic)
    sub_1 = MockSubscriber(expected_endpoint_1, None, publisher_config.topic)

    try:
        time.sleep(0.1)  # Let publishers start up

        # Publish events from different ranks.
        batch_0 = create_test_events(2)
        batch_1 = create_test_events(3)
        pub_0.publish(batch_0)
        pub_1.publish(batch_1)

        result_0 = sub_0.receive_one(timeout=200)
        assert result_0 is not None, "No message received from rank 0"
        _, received_0 = result_0

        result_1 = sub_1.receive_one(timeout=200)
        assert result_1 is not None, "No message received from rank 1"
        _, received_1 = result_1

        # Verify DP rank tagging.
        assert received_0.data_parallel_rank == 0, (
            f"Expected DP rank 0, got {received_0.data_parallel_rank}")
        assert received_1.data_parallel_rank == 1, (
            f"Expected DP rank 1, got {received_1.data_parallel_rank}")

        # Verify event content is correct.
        assert len(
            received_0.events) == 2, "Wrong number of events from rank 0"
        assert len(
            received_1.events) == 3, "Wrong number of events from rank 1"
    finally:
        pub_0.shutdown()
        pub_1.shutdown()
        sub_0.close()
        sub_1.close()

View File

@ -12,8 +12,10 @@ from typing import Optional
import pytest
from transformers import AutoTokenizer
from tests.utils import multi_gpu_test
from vllm import SamplingParams
from vllm.distributed.kv_events import BlockStored, KVEventBatch
from vllm.distributed.kv_events import (BlockStored, KVEventBatch,
ZmqEventPublisher)
from vllm.engine.arg_utils import EngineArgs
from vllm.platforms import current_platform
from vllm.usage.usage_lib import UsageContext
@ -37,10 +39,15 @@ PROMPT = "Hello my name is Robert and I love quantization kernels"
PROMPT_TOKENS = TOKENIZER(PROMPT).input_ids
def make_request(params: SamplingParams) -> EngineCoreRequest:
def make_request(
params: SamplingParams,
prompt_tokens_ids: Optional[list[int]] = None) -> EngineCoreRequest:
if not prompt_tokens_ids:
prompt_tokens_ids = PROMPT_TOKENS
return EngineCoreRequest(
request_id=str(uuid.uuid4()),
prompt_token_ids=PROMPT_TOKENS,
prompt_token_ids=prompt_tokens_ids,
mm_inputs=None,
mm_hashes=None,
mm_placeholders=None,
@ -88,6 +95,25 @@ async def loop_until_done_async(client: EngineCoreClient, outputs: dict):
break
async def loop_until_fully_done_async(client: EngineCoreClient, outputs: dict):
    """Poll the client until every request in `outputs` has finished.

    Appends each received output to its request's list in `outputs`.
    """
    while True:
        batch = (await client.get_output_async()).outputs
        if not batch:
            continue

        # Record this round of outputs per request.
        for item in batch:
            outputs[item.request_id].append(item)

        # Done once the latest output of every request is marked finished.
        if all(history and history[-1].finished
               for history in outputs.values()):
            break
        await asyncio.sleep(0.1)
# Dummy utility function to monkey-patch into engine core.
def echo(self, msg: str, err_msg: Optional[str] = None) -> str:
print(f"echo util function called: {msg}, {err_msg}")
@ -273,10 +299,12 @@ def test_kv_cache_events(
block_size = 16
num_blocks = 2
engine_args = EngineArgs(model=MODEL_NAME,
enforce_eager=True,
enable_prefix_caching=True,
block_size=block_size)
engine_args = EngineArgs(
model=MODEL_NAME,
enforce_eager=True,
enable_prefix_caching=True,
block_size=block_size,
)
engine_args.kv_events_config = publisher_config
vllm_config = engine_args.create_engine_config(
@ -297,19 +325,8 @@ def test_kv_cache_events(
try:
custom_tokens = list(range(num_blocks * block_size))
request = EngineCoreRequest(
request_id=str(uuid.uuid4()),
prompt_token_ids=custom_tokens,
mm_inputs=None,
mm_hashes=None,
mm_placeholders=None,
sampling_params=SamplingParams(
max_tokens=1), # Short completion for speed
eos_token_id=None,
arrival_time=time.time(),
lora_request=None,
cache_salt=None,
)
sampling_params = SamplingParams(max_tokens=1)
request = make_request(sampling_params, custom_tokens)
client.add_request(request)
outputs: dict[str, list] = {request.request_id: []}
@ -321,24 +338,130 @@ def test_kv_cache_events(
seq, received = result
assert seq == 0, "Sequence number mismatch"
assert len(received.events) == 1, (
"We should have exactly one BlockStored event")
assert (len(received.events) == 1
), "We should have exactly one BlockStored event"
event = received.events[0]
assert isinstance(
event, BlockStored), ("We should have a BlockStored event")
assert len(event.block_hashes) == num_blocks, (
"We should have a BlockStored event with 2 block_hashes")
assert event.block_size == block_size, (
"Block size should be the same as the block size")
assert event.parent_block_hash is None, (
"Parent block hash should be None")
event, BlockStored), "We should have a BlockStored event"
assert (len(event.block_hashes) == num_blocks
), "We should have a BlockStored event with 2 block_hashes"
assert (event.block_size == block_size
), "Block size should be the same as the block size"
assert (event.parent_block_hash
is None), "Parent block hash should be None"
assert event.lora_id is None, "Lora id should be None"
assert len(event.token_ids) == num_blocks * block_size, (
"Token ids should be the same as the custom tokens")
assert event.token_ids == custom_tokens, (
"Token ids should be the same as the custom tokens")
assert (len(event.token_ids) == num_blocks * block_size
), "Token ids should be the same as the custom tokens"
assert (event.token_ids == custom_tokens
), "Token ids should be the same as the custom tokens"
finally:
client.shutdown()
subscriber.close()
@pytest.mark.asyncio
@pytest.mark.parametrize(
    "multiprocessing_mode,publisher_config",
    [(True, "tcp")],
    indirect=["publisher_config"],
)
@multi_gpu_test(num_gpus=4)
async def test_kv_cache_events_dp(
    monkeypatch: pytest.MonkeyPatch,
    multiprocessing_mode: bool,
    publisher_config,
):
    """KV-cache events must be published by every DP rank, tagged with the
    rank that produced them.

    Spins up a DP=2 x TP=2 engine, drives enough requests through it that
    both DP groups are (very likely) exercised, then asserts the subscriber
    saw events from `dp_size` distinct data_parallel_ranks.
    """
    with monkeypatch.context() as m:
        m.setenv("VLLM_USE_V1", "1")
        block_size = 16
        num_blocks = 2
        dp_size = 2
        tp_size = 2

        engine_args = EngineArgs(
            model=MODEL_NAME,
            enforce_eager=True,
            enable_prefix_caching=True,
            data_parallel_size=dp_size,
            tensor_parallel_size=tp_size,
            block_size=block_size,
        )
        engine_args.kv_events_config = publisher_config
        vllm_config = engine_args.create_engine_config(
            UsageContext.UNKNOWN_CONTEXT)

        executor_class = Executor.get_class(vllm_config)
        client = EngineCoreClient.make_client(
            multiprocess_mode=multiprocessing_mode,
            asyncio_mode=True,
            vllm_config=vllm_config,
            executor_class=executor_class,
            log_stats=False,
        )
        await asyncio.sleep(1)

        # Each DP rank publishes on base port + rank; subscribe to all of
        # them through a single subscriber.
        base_endpoint = publisher_config.endpoint.replace("*", "127.0.0.1")
        endpoints = [
            ZmqEventPublisher.offset_endpoint_port(base_endpoint, rank)
            for rank in range(dp_size)
        ]
        subscriber = MockSubscriber(endpoints,
                                    topic=publisher_config.topic,
                                    decode_type=KVEventBatch)

        try:
            custom_tokens = list(range(num_blocks * block_size))
            sampling_params = SamplingParams(max_tokens=1)
            all_request_ids = []

            # Create and add 25 requests, spaced out slightly.
            # NOTE: attempts to force routing to both dp groups but can be
            # flaky.
            for _ in range(25):
                await asyncio.sleep(0.01)
                request = make_request(sampling_params, custom_tokens)
                await client.add_request_async(request)
                all_request_ids.append(request.request_id)
            await asyncio.sleep(0.1)

            # Initialize outputs dict for all requests, then wait until
            # every request has finished.
            outputs: dict[str, list] = {
                req_id: []
                for req_id in all_request_ids
            }
            await asyncio.wait_for(loop_until_fully_done_async(
                client, outputs),
                                   timeout=20.0)

            # Drain the subscriber until no more messages arrive.
            results = []
            while (result := subscriber.receive_one(timeout=1)) is not None:
                results.append(result)

            # Every DP rank should have produced at least one tagged batch;
            # compare against dp_size rather than a hard-coded constant.
            unique_dps = {
                received.data_parallel_rank
                for (_, received) in results
            }
            assert len(unique_dps) == dp_size, (
                f"Expected {dp_size} unique data_parallel_ranks, "
                f"got {len(unique_dps)}")
        finally:
            client.shutdown()
            subscriber.close()
@pytest.mark.timeout(20)

View File

@ -28,6 +28,7 @@ class EventBatch(
):
ts: float
events: list[Any]
data_parallel_rank: Optional[int] = None
class KVCacheEvent(
@ -60,7 +61,22 @@ class KVEventBatch(EventBatch):
class EventPublisher(ABC):
"""Lightweight publisher for EventBatch batches."""
"""Lightweight publisher for EventBatch batches with data parallelism
support.
In data parallel setups, each DP rank runs its own EventPublisher instance
to avoid duplicate events and ensure proper event attribution:
- Each DP rank creates a separate publisher
- Publishers automatically annotate events with their data_parallel_rank
- This allows consumers to distinguish events from different DP ranks
The publisher is responsible for adding DP metadata since the scheduler
operates independently of DP topology and shouldn't need DP awareness.
"""
def __init__(self, data_parallel_rank: int = 0) -> None:
    # Rank of the DP group this publisher serves; subclasses use it to tag
    # outgoing event batches that were not already tagged by the caller.
    self._data_parallel_rank = data_parallel_rank
@abstractmethod
def publish(self, events: EventBatch) -> None:
@ -113,6 +129,7 @@ class ZmqEventPublisher(EventPublisher):
def __init__(
self,
data_parallel_rank: int,
endpoint: str = "tcp://*:5557",
replay_endpoint: Optional[str] = None,
buffer_steps: int = 10_000,
@ -121,6 +138,7 @@ class ZmqEventPublisher(EventPublisher):
topic: str = "",
) -> None:
# Storage
super().__init__(data_parallel_rank)
self._event_queue = Queue[Optional[EventBatch]](maxsize=max_queue_size)
self._buffer = deque[tuple[int, bytes]](maxlen=buffer_steps)
@ -128,8 +146,11 @@ class ZmqEventPublisher(EventPublisher):
self._ctx = zmq.Context.instance()
self._pub: Optional[zmq.Socket] = None
self._replay: Optional[zmq.Socket] = None
self._endpoint = endpoint
self._replay_endpoint = replay_endpoint
self._dp_rank = data_parallel_rank
self._endpoint = self.offset_endpoint_port(endpoint, self._dp_rank)
self._replay_endpoint = self.offset_endpoint_port(
replay_endpoint, self._dp_rank)
self._hwm = hwm
self._socket_setup()
@ -149,6 +170,8 @@ class ZmqEventPublisher(EventPublisher):
def publish(self, events: EventBatch) -> None:
    """Queue an event batch for asynchronous publishing.

    Tags the batch with this publisher's DP rank unless the caller has
    already set one.

    Raises:
        RuntimeError: if the publisher has been shut down.
    """
    if not self._running:
        raise RuntimeError("Publisher is closed")
    if events.data_parallel_rank is None:
        events.data_parallel_rank = self._data_parallel_rank
    self._event_queue.put(events)
def shutdown(self) -> None:
@ -191,11 +214,12 @@ class ZmqEventPublisher(EventPublisher):
self._pub.set_hwm(self._hwm)
# Heuristic: bind if wildcard / * present, else connect.
# bind stable, connect volatile convention
if ("*" in self._endpoint or "::" in self._endpoint
or self._endpoint.startswith("ipc://")
or self._endpoint.startswith("inproc://")):
if (self._endpoint is not None
and ("*" in self._endpoint or "::" in self._endpoint
or self._endpoint.startswith("ipc://")
or self._endpoint.startswith("inproc://"))):
self._pub.bind(self._endpoint)
else:
elif self._endpoint is not None:
self._pub.connect(self._endpoint)
# Set up replay socket: use ROUTER
@ -266,6 +290,38 @@ class ZmqEventPublisher(EventPublisher):
# receiving payload is (-1, b""")
self._replay.send_multipart((client_id, b"", self.END_SEQ, b""))
@staticmethod
def offset_endpoint_port(endpoint: Optional[str],
data_parallel_rank: int) -> Optional[str]:
"""Helper function to offset the port in an endpoint by
the data parallel rank.
Args:
endpoint: The endpoint string
(e.g., "tcp://*:5557" or "inproc://cache")
data_parallel_rank: The data parallel rank to offset by
Returns:
The endpoint with the port offset by data_parallel_rank
or suffix appended
"""
# Do nothing if input is None or data_parallel_rank is 0
if not endpoint or data_parallel_rank == 0:
return endpoint
if "inproc" in endpoint:
return f"{endpoint}_dp{data_parallel_rank}"
if "tcp" in endpoint:
if endpoint and ":" in endpoint:
# Get everything after the last colon (the port)
last_colon_idx = endpoint.rfind(":")
base_addr = endpoint[:last_colon_idx]
base_port = int(endpoint[last_colon_idx + 1:])
new_port = base_port + data_parallel_rank
return f"{base_addr}:{new_port}"
return endpoint
raise ValueError("Invalid endpoint: must contain 'inproc' or 'tcp'")
class EventPublisherFactory:
_registry: dict[str, Callable[..., EventPublisher]] = {
@ -281,7 +337,9 @@ class EventPublisherFactory:
cls._registry[name] = ctor
@classmethod
def create(cls, config: Optional[KVEventsConfig]) -> EventPublisher:
def create(cls,
config: Optional[KVEventsConfig],
data_parallel_rank: int = 0) -> EventPublisher:
"""Create publisher from a config mapping."""
if not config:
return NullEventPublisher()
@ -294,4 +352,5 @@ class EventPublisherFactory:
constructor = cls._registry[kind]
except KeyError as exc:
raise ValueError(f"Unknown event publisher '{kind}'") from exc
return constructor(**config_dict)
return constructor(data_parallel_rank=data_parallel_rank,
**config_dict)

View File

@ -80,7 +80,9 @@ class Scheduler(SchedulerInterface):
config=self.vllm_config, role=KVConnectorRole.SCHEDULER)
self.kv_event_publisher = EventPublisherFactory.create(
self.kv_events_config)
self.kv_events_config,
vllm_config.parallel_config.data_parallel_rank,
)
num_gpu_blocks = self.cache_config.num_gpu_blocks
assert num_gpu_blocks is not None and num_gpu_blocks > 0