feat: add data parallel rank to KVEventBatch (#18925)

This commit is contained in:
Yan Ru Pei 2025-06-03 17:14:20 -07:00 committed by GitHub
parent a8da78eac9
commit b712be98c7
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
6 changed files with 359 additions and 83 deletions

View File

@ -145,6 +145,7 @@ steps:
- examples/offline_inference/rlhf_colocate.py - examples/offline_inference/rlhf_colocate.py
- tests/examples/offline_inference/data_parallel.py - tests/examples/offline_inference/data_parallel.py
- tests/v1/test_async_llm_dp.py - tests/v1/test_async_llm_dp.py
- tests/v1/engine/test_engine_core_client.py
commands: commands:
# test with tp=2 and external_dp=2 # test with tp=2 and external_dp=2
- VLLM_USE_V1=0 torchrun --nproc-per-node=4 distributed/test_torchrun_example.py - VLLM_USE_V1=0 torchrun --nproc-per-node=4 distributed/test_torchrun_example.py
@ -154,6 +155,7 @@ steps:
# test with internal dp # test with internal dp
- python3 ../examples/offline_inference/data_parallel.py - python3 ../examples/offline_inference/data_parallel.py
- TP_SIZE=2 DP_SIZE=2 pytest -v -s v1/test_async_llm_dp.py - TP_SIZE=2 DP_SIZE=2 pytest -v -s v1/test_async_llm_dp.py
- pytest -v -s v1/engine/test_engine_core_client.py::test_kv_cache_events_dp
- pytest -v -s distributed/test_utils.py - pytest -v -s distributed/test_utils.py
- pytest -v -s compile/test_basic_correctness.py - pytest -v -s compile/test_basic_correctness.py
- pytest -v -s distributed/test_pynccl.py - pytest -v -s distributed/test_pynccl.py

View File

@ -13,11 +13,13 @@ from vllm.distributed.kv_events import EventPublisherFactory
from .test_events import SampleBatch from .test_events import SampleBatch
DP_RANK = 0
@pytest.fixture @pytest.fixture
def random_port(): def random_port():
"""Generate a random port number for testing""" """Generate a random port number for testing"""
return random.randint(10000, 60000) return random.randint(10000, 59900)
@pytest.fixture @pytest.fixture
@ -30,21 +32,23 @@ def publisher_config(random_port, request):
replay_endpoint = endpoint + "-replay" replay_endpoint = endpoint + "-replay"
else: else:
endpoint = f"tcp://*:{random_port}" endpoint = f"tcp://*:{random_port}"
replay_endpoint = f"tcp://*:{random_port + 1}" replay_endpoint = f"tcp://*:{random_port + 100}"
return KVEventsConfig(enable_kv_cache_events=True, return KVEventsConfig(
publisher="zmq", enable_kv_cache_events=True,
endpoint=endpoint, publisher="zmq",
replay_endpoint=replay_endpoint, endpoint=endpoint,
buffer_steps=100, replay_endpoint=replay_endpoint,
hwm=1000, buffer_steps=100,
topic="test") hwm=1000,
topic="test",
)
@pytest.fixture @pytest.fixture
def publisher(publisher_config): def publisher(publisher_config):
"""Create and return a publisher instance""" """Create and return a publisher instance"""
pub = EventPublisherFactory.create(publisher_config) pub = EventPublisherFactory.create(publisher_config, DP_RANK)
yield pub yield pub
pub.shutdown() pub.shutdown()
@ -60,7 +64,11 @@ def subscriber(publisher_config):
if replay_endpoint and replay_endpoint.startswith("tcp://*"): if replay_endpoint and replay_endpoint.startswith("tcp://*"):
replay_endpoint = replay_endpoint.replace("*", "127.0.0.1") replay_endpoint = replay_endpoint.replace("*", "127.0.0.1")
sub = MockSubscriber(endpoint, replay_endpoint, publisher_config.topic) sub = MockSubscriber(
[endpoint],
[replay_endpoint] if replay_endpoint else None,
publisher_config.topic,
)
yield sub yield sub
sub.close() sub.close()
@ -68,26 +76,37 @@ def subscriber(publisher_config):
class MockSubscriber: class MockSubscriber:
"""Helper class to receive and verify published events""" """Helper class to receive and verify published events"""
def __init__(self, def __init__(
pub_endpoint: str, self,
replay_endpoint: Optional[str] = None, pub_endpoints: Union[str, list[str]],
topic: str = "", replay_endpoints: Optional[Union[str, list[str]]] = None,
decode_type=SampleBatch): topic: str = "",
decode_type=SampleBatch,
):
self.ctx = zmq.Context.instance() self.ctx = zmq.Context.instance()
# Set up subscriber socket # Convert single endpoint to list for consistency
self.sub = self.ctx.socket(zmq.SUB) if isinstance(pub_endpoints, str):
self.sub.setsockopt(zmq.SUBSCRIBE, topic.encode('utf-8')) pub_endpoints = [pub_endpoints]
self.sub.connect(pub_endpoint) if isinstance(replay_endpoints, str):
replay_endpoints = [replay_endpoints]
# Set up replay socket if provided # Set up subscriber socket - connect to all endpoints
self.replay = None self.sub = self.ctx.socket(zmq.SUB)
if replay_endpoint: self.sub.setsockopt(zmq.SUBSCRIBE, topic.encode("utf-8"))
self.replay = self.ctx.socket(zmq.REQ) for endpoint in pub_endpoints:
self.replay.connect(replay_endpoint) self.sub.connect(endpoint)
# Set up replay sockets if provided
self.replay_sockets = []
if replay_endpoints:
for replay_endpoint in replay_endpoints:
replay = self.ctx.socket(zmq.REQ)
replay.connect(replay_endpoint)
self.replay_sockets.append(replay)
self.topic = topic self.topic = topic
self.topic_bytes = topic.encode('utf-8') self.topic_bytes = topic.encode("utf-8")
self.received_msgs: list[tuple[int, SampleBatch]] = [] self.received_msgs: list[tuple[int, SampleBatch]] = []
self.last_seq = -1 self.last_seq = -1
self.decoder = msgspec.msgpack.Decoder(type=decode_type) self.decoder = msgspec.msgpack.Decoder(type=decode_type)
@ -107,25 +126,31 @@ class MockSubscriber:
self.received_msgs.append((seq, data)) self.received_msgs.append((seq, data))
return seq, data return seq, data
def request_replay(self, start_seq: int) -> None: def request_replay(self, start_seq: int, socket_idx: int = 0) -> None:
"""Request replay of messages starting from start_seq""" """Request replay of messages starting from start_seq"""
if not self.replay: if not self.replay_sockets:
raise ValueError("Replay socket not initialized") raise ValueError("Replay sockets not initialized")
if socket_idx >= len(self.replay_sockets):
raise ValueError(f"Invalid socket index {socket_idx}")
self.replay.send(start_seq.to_bytes(8, "big")) self.replay_sockets[socket_idx].send(start_seq.to_bytes(8, "big"))
def receive_replay(self) -> list[tuple[int, SampleBatch]]: def receive_replay(self,
"""Receive replayed messages""" socket_idx: int = 0) -> list[tuple[int, SampleBatch]]:
if not self.replay: """Receive replayed messages from a specific replay socket"""
raise ValueError("Replay socket not initialized") if not self.replay_sockets:
raise ValueError("Replay sockets not initialized")
if socket_idx >= len(self.replay_sockets):
raise ValueError(f"Invalid socket index {socket_idx}")
replay_socket = self.replay_sockets[socket_idx]
replayed: list[tuple[int, SampleBatch]] = [] replayed: list[tuple[int, SampleBatch]] = []
while True: while True:
try: try:
if not self.replay.poll(1000): if not replay_socket.poll(1000):
break break
frames = self.replay.recv_multipart() frames = replay_socket.recv_multipart()
if not frames or not frames[-1]: if not frames or not frames[-1]:
# End of replay marker # End of replay marker
break break
@ -142,5 +167,5 @@ class MockSubscriber:
def close(self): def close(self):
"""Clean up resources""" """Clean up resources"""
self.sub.close() self.sub.close()
if self.replay: for replay in self.replay_sockets:
self.replay.close() replay.close()

View File

@ -9,6 +9,8 @@ import pytest
from vllm.distributed.kv_events import (EventBatch, EventPublisherFactory, from vllm.distributed.kv_events import (EventBatch, EventPublisherFactory,
NullEventPublisher) NullEventPublisher)
DP_RANK = 0
class EventSample( class EventSample(
msgspec.Struct, msgspec.Struct,
@ -121,7 +123,7 @@ def test_topic_filtering(publisher_config):
publisher_config.replay_endpoint = None publisher_config.replay_endpoint = None
publisher_config.topic = "foo" publisher_config.topic = "foo"
pub = EventPublisherFactory.create(publisher_config) pub = EventPublisherFactory.create(publisher_config, DP_RANK)
from .conftest import MockSubscriber from .conftest import MockSubscriber
sub_foo = MockSubscriber(publisher_config.endpoint, None, "foo") sub_foo = MockSubscriber(publisher_config.endpoint, None, "foo")
@ -185,9 +187,72 @@ def test_high_volume(publisher, subscriber):
def test_null_publisher(): def test_null_publisher():
"""Test that NullEventPublisher can be used without errors""" """Test that NullEventPublisher can be used without errors"""
publisher = NullEventPublisher() publisher = NullEventPublisher(DP_RANK)
# This should not raise any errors # This should not raise any errors
batch = create_test_events(5) batch = create_test_events(5)
publisher.publish(batch) publisher.publish(batch)
publisher.shutdown() publisher.shutdown()
def test_data_parallel_rank_tagging(publisher_config):
"""Test that events are properly tagged with their data parallel rank"""
publisher_config.topic = "foo"
pub_0 = EventPublisherFactory.create(publisher_config, DP_RANK)
pub_1 = EventPublisherFactory.create(publisher_config, DP_RANK + 1)
# Hardcode the expected endpoints based on port offsetting behavior
# Both ranks get offsets according to _offset_endpoint_port function
base_endpoint = publisher_config.endpoint
if "tcp://" in base_endpoint:
# For TCP endpoints: tcp://localhost:5557 -> tcp://localhost:5557, tcp://localhost:5558
expected_endpoint_0 = base_endpoint # rank 0 gets port + 0 = same port
expected_endpoint_1 = base_endpoint.replace(
":5557", ":5558") # rank 1 gets port + 1
else:
# For inproc endpoints: inproc://test -> inproc://test_dp0, inproc://test_dp1
expected_endpoint_0 = base_endpoint # rank 0 gets base
expected_endpoint_1 = base_endpoint + "_dp1" # rank 1 gets _dp1
from .conftest import MockSubscriber
sub_0 = MockSubscriber(expected_endpoint_0, None, publisher_config.topic)
sub_1 = MockSubscriber(expected_endpoint_1, None, publisher_config.topic)
try:
time.sleep(0.1) # Let publishers start up
# Publish events from different ranks
batch_0 = create_test_events(2)
batch_1 = create_test_events(3)
pub_0.publish(batch_0)
pub_1.publish(batch_1)
# Receive events from rank 0
result_0 = sub_0.receive_one(timeout=200)
assert result_0 is not None, "No message received from rank 0"
seq_0, received_0 = result_0
# Receive events from rank 1
result_1 = sub_1.receive_one(timeout=200)
assert result_1 is not None, "No message received from rank 1"
seq_1, received_1 = result_1
# Verify DP rank tagging
assert received_0.data_parallel_rank == 0, (
f"Expected DP rank 0, got {received_0.data_parallel_rank}")
assert received_1.data_parallel_rank == 1, (
f"Expected DP rank 1, got {received_1.data_parallel_rank}")
# Verify event content is correct
assert len(
received_0.events) == 2, "Wrong number of events from rank 0"
assert len(
received_1.events) == 3, "Wrong number of events from rank 1"
finally:
pub_0.shutdown()
pub_1.shutdown()
sub_0.close()
sub_1.close()

View File

@ -12,8 +12,10 @@ from typing import Optional
import pytest import pytest
from transformers import AutoTokenizer from transformers import AutoTokenizer
from tests.utils import multi_gpu_test
from vllm import SamplingParams from vllm import SamplingParams
from vllm.distributed.kv_events import BlockStored, KVEventBatch from vllm.distributed.kv_events import (BlockStored, KVEventBatch,
ZmqEventPublisher)
from vllm.engine.arg_utils import EngineArgs from vllm.engine.arg_utils import EngineArgs
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.usage.usage_lib import UsageContext from vllm.usage.usage_lib import UsageContext
@ -37,10 +39,15 @@ PROMPT = "Hello my name is Robert and I love quantization kernels"
PROMPT_TOKENS = TOKENIZER(PROMPT).input_ids PROMPT_TOKENS = TOKENIZER(PROMPT).input_ids
def make_request(params: SamplingParams) -> EngineCoreRequest: def make_request(
params: SamplingParams,
prompt_tokens_ids: Optional[list[int]] = None) -> EngineCoreRequest:
if not prompt_tokens_ids:
prompt_tokens_ids = PROMPT_TOKENS
return EngineCoreRequest( return EngineCoreRequest(
request_id=str(uuid.uuid4()), request_id=str(uuid.uuid4()),
prompt_token_ids=PROMPT_TOKENS, prompt_token_ids=prompt_tokens_ids,
mm_inputs=None, mm_inputs=None,
mm_hashes=None, mm_hashes=None,
mm_placeholders=None, mm_placeholders=None,
@ -88,6 +95,25 @@ async def loop_until_done_async(client: EngineCoreClient, outputs: dict):
break break
async def loop_until_fully_done_async(client: EngineCoreClient, outputs: dict):
while True:
engine_core_outputs = (await client.get_output_async()).outputs
if len(engine_core_outputs) == 0:
continue
# Add outputs to the dict
for out in engine_core_outputs:
outputs[out.request_id].append(out)
# Check if all request IDs in outputs have finished
if all(outs and outs[-1].finished for outs in outputs.values()):
break
await asyncio.sleep(0.1)
# Dummy utility function to monkey-patch into engine core. # Dummy utility function to monkey-patch into engine core.
def echo(self, msg: str, err_msg: Optional[str] = None) -> str: def echo(self, msg: str, err_msg: Optional[str] = None) -> str:
print(f"echo util function called: {msg}, {err_msg}") print(f"echo util function called: {msg}, {err_msg}")
@ -273,10 +299,12 @@ def test_kv_cache_events(
block_size = 16 block_size = 16
num_blocks = 2 num_blocks = 2
engine_args = EngineArgs(model=MODEL_NAME, engine_args = EngineArgs(
enforce_eager=True, model=MODEL_NAME,
enable_prefix_caching=True, enforce_eager=True,
block_size=block_size) enable_prefix_caching=True,
block_size=block_size,
)
engine_args.kv_events_config = publisher_config engine_args.kv_events_config = publisher_config
vllm_config = engine_args.create_engine_config( vllm_config = engine_args.create_engine_config(
@ -297,19 +325,8 @@ def test_kv_cache_events(
try: try:
custom_tokens = list(range(num_blocks * block_size)) custom_tokens = list(range(num_blocks * block_size))
request = EngineCoreRequest( sampling_params = SamplingParams(max_tokens=1)
request_id=str(uuid.uuid4()), request = make_request(sampling_params, custom_tokens)
prompt_token_ids=custom_tokens,
mm_inputs=None,
mm_hashes=None,
mm_placeholders=None,
sampling_params=SamplingParams(
max_tokens=1), # Short completion for speed
eos_token_id=None,
arrival_time=time.time(),
lora_request=None,
cache_salt=None,
)
client.add_request(request) client.add_request(request)
outputs: dict[str, list] = {request.request_id: []} outputs: dict[str, list] = {request.request_id: []}
@ -321,24 +338,130 @@ def test_kv_cache_events(
seq, received = result seq, received = result
assert seq == 0, "Sequence number mismatch" assert seq == 0, "Sequence number mismatch"
assert len(received.events) == 1, ( assert (len(received.events) == 1
"We should have exactly one BlockStored event") ), "We should have exactly one BlockStored event"
event = received.events[0] event = received.events[0]
assert isinstance( assert isinstance(
event, BlockStored), ("We should have a BlockStored event") event, BlockStored), "We should have a BlockStored event"
assert len(event.block_hashes) == num_blocks, ( assert (len(event.block_hashes) == num_blocks
"We should have a BlockStored event with 2 block_hashes") ), "We should have a BlockStored event with 2 block_hashes"
assert event.block_size == block_size, ( assert (event.block_size == block_size
"Block size should be the same as the block size") ), "Block size should be the same as the block size"
assert event.parent_block_hash is None, ( assert (event.parent_block_hash
"Parent block hash should be None") is None), "Parent block hash should be None"
assert event.lora_id is None, "Lora id should be None" assert event.lora_id is None, "Lora id should be None"
assert len(event.token_ids) == num_blocks * block_size, ( assert (len(event.token_ids) == num_blocks * block_size
"Token ids should be the same as the custom tokens") ), "Token ids should be the same as the custom tokens"
assert event.token_ids == custom_tokens, ( assert (event.token_ids == custom_tokens
"Token ids should be the same as the custom tokens") ), "Token ids should be the same as the custom tokens"
finally: finally:
client.shutdown() client.shutdown()
subscriber.close()
@pytest.mark.asyncio
@pytest.mark.parametrize(
"multiprocessing_mode,publisher_config",
[(True, "tcp")],
indirect=["publisher_config"],
)
@multi_gpu_test(num_gpus=4)
async def test_kv_cache_events_dp(
monkeypatch: pytest.MonkeyPatch,
multiprocessing_mode: bool,
publisher_config,
):
with monkeypatch.context() as m:
m.setenv("VLLM_USE_V1", "1")
block_size = 16
num_blocks = 2
dp_size = 2
tp_size = 2
engine_args = EngineArgs(
model=MODEL_NAME,
enforce_eager=True,
enable_prefix_caching=True,
data_parallel_size=dp_size,
tensor_parallel_size=tp_size,
block_size=block_size,
)
engine_args.kv_events_config = publisher_config
vllm_config = engine_args.create_engine_config(
UsageContext.UNKNOWN_CONTEXT)
executor_class = Executor.get_class(vllm_config)
client = EngineCoreClient.make_client(
multiprocess_mode=multiprocessing_mode,
asyncio_mode=True,
vllm_config=vllm_config,
executor_class=executor_class,
log_stats=False,
)
await asyncio.sleep(1)
# Build endpoints for all DP ranks
base_endpoint = publisher_config.endpoint.replace("*", "127.0.0.1")
endpoints = []
for i in range(dp_size):
offset_endpoint = ZmqEventPublisher.offset_endpoint_port(
base_endpoint, i)
endpoints.append(offset_endpoint)
subscriber = MockSubscriber(endpoints,
topic=publisher_config.topic,
decode_type=KVEventBatch)
try:
custom_tokens = list(range(num_blocks * block_size))
sampling_params = SamplingParams(max_tokens=1)
all_request_ids = []
# Create and add 25 requests
# NOTE: attempts to force routing to both dp groups but can be flaky
for i in range(25):
await asyncio.sleep(0.01)
request = make_request(sampling_params, custom_tokens)
await client.add_request_async(request)
all_request_ids.append(request.request_id)
await asyncio.sleep(0.1)
# Initialize outputs dict for all requests
outputs: dict[str, list] = {
req_id: []
for req_id in all_request_ids
}
print("processing requests...")
await asyncio.wait_for(loop_until_fully_done_async(
client, outputs),
timeout=20.0)
# Receive from subscriber until no more messages
print("collecting results...")
results = []
while True:
result = subscriber.receive_one(timeout=1)
print(result)
if result is None:
break
results.append(result)
# Collect all events and data_parallel_ranks from all results
all_dp_ranks = [
received.data_parallel_rank for (_, received) in results
]
unique_dps = set(all_dp_ranks)
assert (
len(unique_dps) == 2
), f"Expected 2 unique data_parallel_ranks, got {len(unique_dps)}"
finally:
client.shutdown()
subscriber.close()
@pytest.mark.timeout(20) @pytest.mark.timeout(20)

View File

@ -28,6 +28,7 @@ class EventBatch(
): ):
ts: float ts: float
events: list[Any] events: list[Any]
data_parallel_rank: Optional[int] = None
class KVCacheEvent( class KVCacheEvent(
@ -60,7 +61,22 @@ class KVEventBatch(EventBatch):
class EventPublisher(ABC): class EventPublisher(ABC):
"""Lightweight publisher for EventBatch batches.""" """Lightweight publisher for EventBatch batches with data parallelism
support.
In data parallel setups, each DP rank runs its own EventPublisher instance
to avoid duplicate events and ensure proper event attribution:
- Each DP rank creates a separate publisher
- Publishers automatically annotate events with their data_parallel_rank
- This allows consumers to distinguish events from different DP ranks
The publisher is responsible for adding DP metadata since the scheduler
operates independently of DP topology and shouldn't need DP awareness.
"""
def __init__(self, data_parallel_rank: int = 0) -> None:
self._data_parallel_rank = data_parallel_rank
@abstractmethod @abstractmethod
def publish(self, events: EventBatch) -> None: def publish(self, events: EventBatch) -> None:
@ -113,6 +129,7 @@ class ZmqEventPublisher(EventPublisher):
def __init__( def __init__(
self, self,
data_parallel_rank: int,
endpoint: str = "tcp://*:5557", endpoint: str = "tcp://*:5557",
replay_endpoint: Optional[str] = None, replay_endpoint: Optional[str] = None,
buffer_steps: int = 10_000, buffer_steps: int = 10_000,
@ -121,6 +138,7 @@ class ZmqEventPublisher(EventPublisher):
topic: str = "", topic: str = "",
) -> None: ) -> None:
# Storage # Storage
super().__init__(data_parallel_rank)
self._event_queue = Queue[Optional[EventBatch]](maxsize=max_queue_size) self._event_queue = Queue[Optional[EventBatch]](maxsize=max_queue_size)
self._buffer = deque[tuple[int, bytes]](maxlen=buffer_steps) self._buffer = deque[tuple[int, bytes]](maxlen=buffer_steps)
@ -128,8 +146,11 @@ class ZmqEventPublisher(EventPublisher):
self._ctx = zmq.Context.instance() self._ctx = zmq.Context.instance()
self._pub: Optional[zmq.Socket] = None self._pub: Optional[zmq.Socket] = None
self._replay: Optional[zmq.Socket] = None self._replay: Optional[zmq.Socket] = None
self._endpoint = endpoint self._dp_rank = data_parallel_rank
self._replay_endpoint = replay_endpoint
self._endpoint = self.offset_endpoint_port(endpoint, self._dp_rank)
self._replay_endpoint = self.offset_endpoint_port(
replay_endpoint, self._dp_rank)
self._hwm = hwm self._hwm = hwm
self._socket_setup() self._socket_setup()
@ -149,6 +170,8 @@ class ZmqEventPublisher(EventPublisher):
def publish(self, events: EventBatch) -> None: def publish(self, events: EventBatch) -> None:
if not self._running: if not self._running:
raise RuntimeError("Publisher is closed") raise RuntimeError("Publisher is closed")
if events.data_parallel_rank is None:
events.data_parallel_rank = self._data_parallel_rank
self._event_queue.put(events) self._event_queue.put(events)
def shutdown(self) -> None: def shutdown(self) -> None:
@ -191,11 +214,12 @@ class ZmqEventPublisher(EventPublisher):
self._pub.set_hwm(self._hwm) self._pub.set_hwm(self._hwm)
# Heuristic: bind if wildcard / * present, else connect. # Heuristic: bind if wildcard / * present, else connect.
# bind stable, connect volatile convention # bind stable, connect volatile convention
if ("*" in self._endpoint or "::" in self._endpoint if (self._endpoint is not None
or self._endpoint.startswith("ipc://") and ("*" in self._endpoint or "::" in self._endpoint
or self._endpoint.startswith("inproc://")): or self._endpoint.startswith("ipc://")
or self._endpoint.startswith("inproc://"))):
self._pub.bind(self._endpoint) self._pub.bind(self._endpoint)
else: elif self._endpoint is not None:
self._pub.connect(self._endpoint) self._pub.connect(self._endpoint)
# Set up replay socket: use ROUTER # Set up replay socket: use ROUTER
@ -266,6 +290,38 @@ class ZmqEventPublisher(EventPublisher):
# receiving payload is (-1, b""") # receiving payload is (-1, b""")
self._replay.send_multipart((client_id, b"", self.END_SEQ, b"")) self._replay.send_multipart((client_id, b"", self.END_SEQ, b""))
@staticmethod
def offset_endpoint_port(endpoint: Optional[str],
data_parallel_rank: int) -> Optional[str]:
"""Helper function to offset the port in an endpoint by
the data parallel rank.
Args:
endpoint: The endpoint string
(e.g., "tcp://*:5557" or "inproc://cache")
data_parallel_rank: The data parallel rank to offset by
Returns:
The endpoint with the port offset by data_parallel_rank
or suffix appended
"""
# Do nothing if input is None or data_parallel_rank is 0
if not endpoint or data_parallel_rank == 0:
return endpoint
if "inproc" in endpoint:
return f"{endpoint}_dp{data_parallel_rank}"
if "tcp" in endpoint:
if endpoint and ":" in endpoint:
# Get everything after the last colon (the port)
last_colon_idx = endpoint.rfind(":")
base_addr = endpoint[:last_colon_idx]
base_port = int(endpoint[last_colon_idx + 1:])
new_port = base_port + data_parallel_rank
return f"{base_addr}:{new_port}"
return endpoint
raise ValueError("Invalid endpoint: must contain 'inproc' or 'tcp'")
class EventPublisherFactory: class EventPublisherFactory:
_registry: dict[str, Callable[..., EventPublisher]] = { _registry: dict[str, Callable[..., EventPublisher]] = {
@ -281,7 +337,9 @@ class EventPublisherFactory:
cls._registry[name] = ctor cls._registry[name] = ctor
@classmethod @classmethod
def create(cls, config: Optional[KVEventsConfig]) -> EventPublisher: def create(cls,
config: Optional[KVEventsConfig],
data_parallel_rank: int = 0) -> EventPublisher:
"""Create publisher from a config mapping.""" """Create publisher from a config mapping."""
if not config: if not config:
return NullEventPublisher() return NullEventPublisher()
@ -294,4 +352,5 @@ class EventPublisherFactory:
constructor = cls._registry[kind] constructor = cls._registry[kind]
except KeyError as exc: except KeyError as exc:
raise ValueError(f"Unknown event publisher '{kind}'") from exc raise ValueError(f"Unknown event publisher '{kind}'") from exc
return constructor(**config_dict) return constructor(data_parallel_rank=data_parallel_rank,
**config_dict)

View File

@ -80,7 +80,9 @@ class Scheduler(SchedulerInterface):
config=self.vllm_config, role=KVConnectorRole.SCHEDULER) config=self.vllm_config, role=KVConnectorRole.SCHEDULER)
self.kv_event_publisher = EventPublisherFactory.create( self.kv_event_publisher = EventPublisherFactory.create(
self.kv_events_config) self.kv_events_config,
vllm_config.parallel_config.data_parallel_rank,
)
num_gpu_blocks = self.cache_config.num_gpu_blocks num_gpu_blocks = self.cache_config.num_gpu_blocks
assert num_gpu_blocks is not None and num_gpu_blocks > 0 assert num_gpu_blocks is not None and num_gpu_blocks > 0