mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-03-16 13:47:18 +08:00
feat: add data parallel rank to KVEventBatch (#18925)
This commit is contained in:
parent
a8da78eac9
commit
b712be98c7
@ -145,6 +145,7 @@ steps:
|
||||
- examples/offline_inference/rlhf_colocate.py
|
||||
- tests/examples/offline_inference/data_parallel.py
|
||||
- tests/v1/test_async_llm_dp.py
|
||||
- tests/v1/engine/test_engine_core_client.py
|
||||
commands:
|
||||
# test with tp=2 and external_dp=2
|
||||
- VLLM_USE_V1=0 torchrun --nproc-per-node=4 distributed/test_torchrun_example.py
|
||||
@ -154,6 +155,7 @@ steps:
|
||||
# test with internal dp
|
||||
- python3 ../examples/offline_inference/data_parallel.py
|
||||
- TP_SIZE=2 DP_SIZE=2 pytest -v -s v1/test_async_llm_dp.py
|
||||
- pytest -v -s v1/engine/test_engine_core_client.py::test_kv_cache_events_dp
|
||||
- pytest -v -s distributed/test_utils.py
|
||||
- pytest -v -s compile/test_basic_correctness.py
|
||||
- pytest -v -s distributed/test_pynccl.py
|
||||
|
||||
@ -13,11 +13,13 @@ from vllm.distributed.kv_events import EventPublisherFactory
|
||||
|
||||
from .test_events import SampleBatch
|
||||
|
||||
DP_RANK = 0
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def random_port():
|
||||
"""Generate a random port number for testing"""
|
||||
return random.randint(10000, 60000)
|
||||
return random.randint(10000, 59900)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
@ -30,21 +32,23 @@ def publisher_config(random_port, request):
|
||||
replay_endpoint = endpoint + "-replay"
|
||||
else:
|
||||
endpoint = f"tcp://*:{random_port}"
|
||||
replay_endpoint = f"tcp://*:{random_port + 1}"
|
||||
replay_endpoint = f"tcp://*:{random_port + 100}"
|
||||
|
||||
return KVEventsConfig(enable_kv_cache_events=True,
|
||||
publisher="zmq",
|
||||
endpoint=endpoint,
|
||||
replay_endpoint=replay_endpoint,
|
||||
buffer_steps=100,
|
||||
hwm=1000,
|
||||
topic="test")
|
||||
return KVEventsConfig(
|
||||
enable_kv_cache_events=True,
|
||||
publisher="zmq",
|
||||
endpoint=endpoint,
|
||||
replay_endpoint=replay_endpoint,
|
||||
buffer_steps=100,
|
||||
hwm=1000,
|
||||
topic="test",
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def publisher(publisher_config):
|
||||
"""Create and return a publisher instance"""
|
||||
pub = EventPublisherFactory.create(publisher_config)
|
||||
pub = EventPublisherFactory.create(publisher_config, DP_RANK)
|
||||
yield pub
|
||||
pub.shutdown()
|
||||
|
||||
@ -60,7 +64,11 @@ def subscriber(publisher_config):
|
||||
if replay_endpoint and replay_endpoint.startswith("tcp://*"):
|
||||
replay_endpoint = replay_endpoint.replace("*", "127.0.0.1")
|
||||
|
||||
sub = MockSubscriber(endpoint, replay_endpoint, publisher_config.topic)
|
||||
sub = MockSubscriber(
|
||||
[endpoint],
|
||||
[replay_endpoint] if replay_endpoint else None,
|
||||
publisher_config.topic,
|
||||
)
|
||||
yield sub
|
||||
sub.close()
|
||||
|
||||
@ -68,26 +76,37 @@ def subscriber(publisher_config):
|
||||
class MockSubscriber:
|
||||
"""Helper class to receive and verify published events"""
|
||||
|
||||
def __init__(self,
|
||||
pub_endpoint: str,
|
||||
replay_endpoint: Optional[str] = None,
|
||||
topic: str = "",
|
||||
decode_type=SampleBatch):
|
||||
def __init__(
|
||||
self,
|
||||
pub_endpoints: Union[str, list[str]],
|
||||
replay_endpoints: Optional[Union[str, list[str]]] = None,
|
||||
topic: str = "",
|
||||
decode_type=SampleBatch,
|
||||
):
|
||||
self.ctx = zmq.Context.instance()
|
||||
|
||||
# Set up subscriber socket
|
||||
self.sub = self.ctx.socket(zmq.SUB)
|
||||
self.sub.setsockopt(zmq.SUBSCRIBE, topic.encode('utf-8'))
|
||||
self.sub.connect(pub_endpoint)
|
||||
# Convert single endpoint to list for consistency
|
||||
if isinstance(pub_endpoints, str):
|
||||
pub_endpoints = [pub_endpoints]
|
||||
if isinstance(replay_endpoints, str):
|
||||
replay_endpoints = [replay_endpoints]
|
||||
|
||||
# Set up replay socket if provided
|
||||
self.replay = None
|
||||
if replay_endpoint:
|
||||
self.replay = self.ctx.socket(zmq.REQ)
|
||||
self.replay.connect(replay_endpoint)
|
||||
# Set up subscriber socket - connect to all endpoints
|
||||
self.sub = self.ctx.socket(zmq.SUB)
|
||||
self.sub.setsockopt(zmq.SUBSCRIBE, topic.encode("utf-8"))
|
||||
for endpoint in pub_endpoints:
|
||||
self.sub.connect(endpoint)
|
||||
|
||||
# Set up replay sockets if provided
|
||||
self.replay_sockets = []
|
||||
if replay_endpoints:
|
||||
for replay_endpoint in replay_endpoints:
|
||||
replay = self.ctx.socket(zmq.REQ)
|
||||
replay.connect(replay_endpoint)
|
||||
self.replay_sockets.append(replay)
|
||||
|
||||
self.topic = topic
|
||||
self.topic_bytes = topic.encode('utf-8')
|
||||
self.topic_bytes = topic.encode("utf-8")
|
||||
self.received_msgs: list[tuple[int, SampleBatch]] = []
|
||||
self.last_seq = -1
|
||||
self.decoder = msgspec.msgpack.Decoder(type=decode_type)
|
||||
@ -107,25 +126,31 @@ class MockSubscriber:
|
||||
self.received_msgs.append((seq, data))
|
||||
return seq, data
|
||||
|
||||
def request_replay(self, start_seq: int) -> None:
|
||||
def request_replay(self, start_seq: int, socket_idx: int = 0) -> None:
|
||||
"""Request replay of messages starting from start_seq"""
|
||||
if not self.replay:
|
||||
raise ValueError("Replay socket not initialized")
|
||||
if not self.replay_sockets:
|
||||
raise ValueError("Replay sockets not initialized")
|
||||
if socket_idx >= len(self.replay_sockets):
|
||||
raise ValueError(f"Invalid socket index {socket_idx}")
|
||||
|
||||
self.replay.send(start_seq.to_bytes(8, "big"))
|
||||
self.replay_sockets[socket_idx].send(start_seq.to_bytes(8, "big"))
|
||||
|
||||
def receive_replay(self) -> list[tuple[int, SampleBatch]]:
|
||||
"""Receive replayed messages"""
|
||||
if not self.replay:
|
||||
raise ValueError("Replay socket not initialized")
|
||||
def receive_replay(self,
|
||||
socket_idx: int = 0) -> list[tuple[int, SampleBatch]]:
|
||||
"""Receive replayed messages from a specific replay socket"""
|
||||
if not self.replay_sockets:
|
||||
raise ValueError("Replay sockets not initialized")
|
||||
if socket_idx >= len(self.replay_sockets):
|
||||
raise ValueError(f"Invalid socket index {socket_idx}")
|
||||
|
||||
replay_socket = self.replay_sockets[socket_idx]
|
||||
replayed: list[tuple[int, SampleBatch]] = []
|
||||
while True:
|
||||
try:
|
||||
if not self.replay.poll(1000):
|
||||
if not replay_socket.poll(1000):
|
||||
break
|
||||
|
||||
frames = self.replay.recv_multipart()
|
||||
frames = replay_socket.recv_multipart()
|
||||
if not frames or not frames[-1]:
|
||||
# End of replay marker
|
||||
break
|
||||
@ -142,5 +167,5 @@ class MockSubscriber:
|
||||
def close(self):
|
||||
"""Clean up resources"""
|
||||
self.sub.close()
|
||||
if self.replay:
|
||||
self.replay.close()
|
||||
for replay in self.replay_sockets:
|
||||
replay.close()
|
||||
|
||||
@ -9,6 +9,8 @@ import pytest
|
||||
from vllm.distributed.kv_events import (EventBatch, EventPublisherFactory,
|
||||
NullEventPublisher)
|
||||
|
||||
DP_RANK = 0
|
||||
|
||||
|
||||
class EventSample(
|
||||
msgspec.Struct,
|
||||
@ -121,7 +123,7 @@ def test_topic_filtering(publisher_config):
|
||||
publisher_config.replay_endpoint = None
|
||||
|
||||
publisher_config.topic = "foo"
|
||||
pub = EventPublisherFactory.create(publisher_config)
|
||||
pub = EventPublisherFactory.create(publisher_config, DP_RANK)
|
||||
|
||||
from .conftest import MockSubscriber
|
||||
sub_foo = MockSubscriber(publisher_config.endpoint, None, "foo")
|
||||
@ -185,9 +187,72 @@ def test_high_volume(publisher, subscriber):
|
||||
|
||||
def test_null_publisher():
|
||||
"""Test that NullEventPublisher can be used without errors"""
|
||||
publisher = NullEventPublisher()
|
||||
publisher = NullEventPublisher(DP_RANK)
|
||||
|
||||
# This should not raise any errors
|
||||
batch = create_test_events(5)
|
||||
publisher.publish(batch)
|
||||
publisher.shutdown()
|
||||
|
||||
|
||||
def test_data_parallel_rank_tagging(publisher_config):
|
||||
"""Test that events are properly tagged with their data parallel rank"""
|
||||
|
||||
publisher_config.topic = "foo"
|
||||
pub_0 = EventPublisherFactory.create(publisher_config, DP_RANK)
|
||||
pub_1 = EventPublisherFactory.create(publisher_config, DP_RANK + 1)
|
||||
|
||||
# Hardcode the expected endpoints based on port offsetting behavior
|
||||
# Both ranks get offsets according to _offset_endpoint_port function
|
||||
base_endpoint = publisher_config.endpoint
|
||||
if "tcp://" in base_endpoint:
|
||||
# For TCP endpoints: tcp://localhost:5557 -> tcp://localhost:5557, tcp://localhost:5558
|
||||
expected_endpoint_0 = base_endpoint # rank 0 gets port + 0 = same port
|
||||
expected_endpoint_1 = base_endpoint.replace(
|
||||
":5557", ":5558") # rank 1 gets port + 1
|
||||
else:
|
||||
# For inproc endpoints: inproc://test -> inproc://test_dp0, inproc://test_dp1
|
||||
expected_endpoint_0 = base_endpoint # rank 0 gets base
|
||||
expected_endpoint_1 = base_endpoint + "_dp1" # rank 1 gets _dp1
|
||||
|
||||
from .conftest import MockSubscriber
|
||||
sub_0 = MockSubscriber(expected_endpoint_0, None, publisher_config.topic)
|
||||
sub_1 = MockSubscriber(expected_endpoint_1, None, publisher_config.topic)
|
||||
|
||||
try:
|
||||
time.sleep(0.1) # Let publishers start up
|
||||
|
||||
# Publish events from different ranks
|
||||
batch_0 = create_test_events(2)
|
||||
batch_1 = create_test_events(3)
|
||||
|
||||
pub_0.publish(batch_0)
|
||||
pub_1.publish(batch_1)
|
||||
|
||||
# Receive events from rank 0
|
||||
result_0 = sub_0.receive_one(timeout=200)
|
||||
assert result_0 is not None, "No message received from rank 0"
|
||||
seq_0, received_0 = result_0
|
||||
|
||||
# Receive events from rank 1
|
||||
result_1 = sub_1.receive_one(timeout=200)
|
||||
assert result_1 is not None, "No message received from rank 1"
|
||||
seq_1, received_1 = result_1
|
||||
|
||||
# Verify DP rank tagging
|
||||
assert received_0.data_parallel_rank == 0, (
|
||||
f"Expected DP rank 0, got {received_0.data_parallel_rank}")
|
||||
assert received_1.data_parallel_rank == 1, (
|
||||
f"Expected DP rank 1, got {received_1.data_parallel_rank}")
|
||||
|
||||
# Verify event content is correct
|
||||
assert len(
|
||||
received_0.events) == 2, "Wrong number of events from rank 0"
|
||||
assert len(
|
||||
received_1.events) == 3, "Wrong number of events from rank 1"
|
||||
|
||||
finally:
|
||||
pub_0.shutdown()
|
||||
pub_1.shutdown()
|
||||
sub_0.close()
|
||||
sub_1.close()
|
||||
|
||||
@ -12,8 +12,10 @@ from typing import Optional
|
||||
import pytest
|
||||
from transformers import AutoTokenizer
|
||||
|
||||
from tests.utils import multi_gpu_test
|
||||
from vllm import SamplingParams
|
||||
from vllm.distributed.kv_events import BlockStored, KVEventBatch
|
||||
from vllm.distributed.kv_events import (BlockStored, KVEventBatch,
|
||||
ZmqEventPublisher)
|
||||
from vllm.engine.arg_utils import EngineArgs
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.usage.usage_lib import UsageContext
|
||||
@ -37,10 +39,15 @@ PROMPT = "Hello my name is Robert and I love quantization kernels"
|
||||
PROMPT_TOKENS = TOKENIZER(PROMPT).input_ids
|
||||
|
||||
|
||||
def make_request(params: SamplingParams) -> EngineCoreRequest:
|
||||
def make_request(
|
||||
params: SamplingParams,
|
||||
prompt_tokens_ids: Optional[list[int]] = None) -> EngineCoreRequest:
|
||||
if not prompt_tokens_ids:
|
||||
prompt_tokens_ids = PROMPT_TOKENS
|
||||
|
||||
return EngineCoreRequest(
|
||||
request_id=str(uuid.uuid4()),
|
||||
prompt_token_ids=PROMPT_TOKENS,
|
||||
prompt_token_ids=prompt_tokens_ids,
|
||||
mm_inputs=None,
|
||||
mm_hashes=None,
|
||||
mm_placeholders=None,
|
||||
@ -88,6 +95,25 @@ async def loop_until_done_async(client: EngineCoreClient, outputs: dict):
|
||||
break
|
||||
|
||||
|
||||
async def loop_until_fully_done_async(client: EngineCoreClient, outputs: dict):
|
||||
|
||||
while True:
|
||||
engine_core_outputs = (await client.get_output_async()).outputs
|
||||
|
||||
if len(engine_core_outputs) == 0:
|
||||
continue
|
||||
|
||||
# Add outputs to the dict
|
||||
for out in engine_core_outputs:
|
||||
outputs[out.request_id].append(out)
|
||||
|
||||
# Check if all request IDs in outputs have finished
|
||||
if all(outs and outs[-1].finished for outs in outputs.values()):
|
||||
break
|
||||
|
||||
await asyncio.sleep(0.1)
|
||||
|
||||
|
||||
# Dummy utility function to monkey-patch into engine core.
|
||||
def echo(self, msg: str, err_msg: Optional[str] = None) -> str:
|
||||
print(f"echo util function called: {msg}, {err_msg}")
|
||||
@ -273,10 +299,12 @@ def test_kv_cache_events(
|
||||
block_size = 16
|
||||
num_blocks = 2
|
||||
|
||||
engine_args = EngineArgs(model=MODEL_NAME,
|
||||
enforce_eager=True,
|
||||
enable_prefix_caching=True,
|
||||
block_size=block_size)
|
||||
engine_args = EngineArgs(
|
||||
model=MODEL_NAME,
|
||||
enforce_eager=True,
|
||||
enable_prefix_caching=True,
|
||||
block_size=block_size,
|
||||
)
|
||||
engine_args.kv_events_config = publisher_config
|
||||
|
||||
vllm_config = engine_args.create_engine_config(
|
||||
@ -297,19 +325,8 @@ def test_kv_cache_events(
|
||||
|
||||
try:
|
||||
custom_tokens = list(range(num_blocks * block_size))
|
||||
request = EngineCoreRequest(
|
||||
request_id=str(uuid.uuid4()),
|
||||
prompt_token_ids=custom_tokens,
|
||||
mm_inputs=None,
|
||||
mm_hashes=None,
|
||||
mm_placeholders=None,
|
||||
sampling_params=SamplingParams(
|
||||
max_tokens=1), # Short completion for speed
|
||||
eos_token_id=None,
|
||||
arrival_time=time.time(),
|
||||
lora_request=None,
|
||||
cache_salt=None,
|
||||
)
|
||||
sampling_params = SamplingParams(max_tokens=1)
|
||||
request = make_request(sampling_params, custom_tokens)
|
||||
client.add_request(request)
|
||||
|
||||
outputs: dict[str, list] = {request.request_id: []}
|
||||
@ -321,24 +338,130 @@ def test_kv_cache_events(
|
||||
seq, received = result
|
||||
|
||||
assert seq == 0, "Sequence number mismatch"
|
||||
assert len(received.events) == 1, (
|
||||
"We should have exactly one BlockStored event")
|
||||
assert (len(received.events) == 1
|
||||
), "We should have exactly one BlockStored event"
|
||||
event = received.events[0]
|
||||
assert isinstance(
|
||||
event, BlockStored), ("We should have a BlockStored event")
|
||||
assert len(event.block_hashes) == num_blocks, (
|
||||
"We should have a BlockStored event with 2 block_hashes")
|
||||
assert event.block_size == block_size, (
|
||||
"Block size should be the same as the block size")
|
||||
assert event.parent_block_hash is None, (
|
||||
"Parent block hash should be None")
|
||||
event, BlockStored), "We should have a BlockStored event"
|
||||
assert (len(event.block_hashes) == num_blocks
|
||||
), "We should have a BlockStored event with 2 block_hashes"
|
||||
assert (event.block_size == block_size
|
||||
), "Block size should be the same as the block size"
|
||||
assert (event.parent_block_hash
|
||||
is None), "Parent block hash should be None"
|
||||
assert event.lora_id is None, "Lora id should be None"
|
||||
assert len(event.token_ids) == num_blocks * block_size, (
|
||||
"Token ids should be the same as the custom tokens")
|
||||
assert event.token_ids == custom_tokens, (
|
||||
"Token ids should be the same as the custom tokens")
|
||||
assert (len(event.token_ids) == num_blocks * block_size
|
||||
), "Token ids should be the same as the custom tokens"
|
||||
assert (event.token_ids == custom_tokens
|
||||
), "Token ids should be the same as the custom tokens"
|
||||
finally:
|
||||
client.shutdown()
|
||||
subscriber.close()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize(
|
||||
"multiprocessing_mode,publisher_config",
|
||||
[(True, "tcp")],
|
||||
indirect=["publisher_config"],
|
||||
)
|
||||
@multi_gpu_test(num_gpus=4)
|
||||
async def test_kv_cache_events_dp(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
multiprocessing_mode: bool,
|
||||
publisher_config,
|
||||
):
|
||||
|
||||
with monkeypatch.context() as m:
|
||||
m.setenv("VLLM_USE_V1", "1")
|
||||
block_size = 16
|
||||
num_blocks = 2
|
||||
dp_size = 2
|
||||
tp_size = 2
|
||||
|
||||
engine_args = EngineArgs(
|
||||
model=MODEL_NAME,
|
||||
enforce_eager=True,
|
||||
enable_prefix_caching=True,
|
||||
data_parallel_size=dp_size,
|
||||
tensor_parallel_size=tp_size,
|
||||
block_size=block_size,
|
||||
)
|
||||
engine_args.kv_events_config = publisher_config
|
||||
|
||||
vllm_config = engine_args.create_engine_config(
|
||||
UsageContext.UNKNOWN_CONTEXT)
|
||||
|
||||
executor_class = Executor.get_class(vllm_config)
|
||||
client = EngineCoreClient.make_client(
|
||||
multiprocess_mode=multiprocessing_mode,
|
||||
asyncio_mode=True,
|
||||
vllm_config=vllm_config,
|
||||
executor_class=executor_class,
|
||||
log_stats=False,
|
||||
)
|
||||
await asyncio.sleep(1)
|
||||
|
||||
# Build endpoints for all DP ranks
|
||||
base_endpoint = publisher_config.endpoint.replace("*", "127.0.0.1")
|
||||
endpoints = []
|
||||
for i in range(dp_size):
|
||||
offset_endpoint = ZmqEventPublisher.offset_endpoint_port(
|
||||
base_endpoint, i)
|
||||
endpoints.append(offset_endpoint)
|
||||
|
||||
subscriber = MockSubscriber(endpoints,
|
||||
topic=publisher_config.topic,
|
||||
decode_type=KVEventBatch)
|
||||
|
||||
try:
|
||||
custom_tokens = list(range(num_blocks * block_size))
|
||||
sampling_params = SamplingParams(max_tokens=1)
|
||||
all_request_ids = []
|
||||
|
||||
# Create and add 25 requests
|
||||
# NOTE: attempts to force routing to both dp groups but can be flaky
|
||||
for i in range(25):
|
||||
await asyncio.sleep(0.01)
|
||||
request = make_request(sampling_params, custom_tokens)
|
||||
await client.add_request_async(request)
|
||||
all_request_ids.append(request.request_id)
|
||||
|
||||
await asyncio.sleep(0.1)
|
||||
|
||||
# Initialize outputs dict for all requests
|
||||
outputs: dict[str, list] = {
|
||||
req_id: []
|
||||
for req_id in all_request_ids
|
||||
}
|
||||
|
||||
print("processing requests...")
|
||||
await asyncio.wait_for(loop_until_fully_done_async(
|
||||
client, outputs),
|
||||
timeout=20.0)
|
||||
|
||||
# Receive from subscriber until no more messages
|
||||
print("collecting results...")
|
||||
results = []
|
||||
while True:
|
||||
result = subscriber.receive_one(timeout=1)
|
||||
print(result)
|
||||
if result is None:
|
||||
break
|
||||
results.append(result)
|
||||
|
||||
# Collect all events and data_parallel_ranks from all results
|
||||
all_dp_ranks = [
|
||||
received.data_parallel_rank for (_, received) in results
|
||||
]
|
||||
unique_dps = set(all_dp_ranks)
|
||||
assert (
|
||||
len(unique_dps) == 2
|
||||
), f"Expected 2 unique data_parallel_ranks, got {len(unique_dps)}"
|
||||
|
||||
finally:
|
||||
client.shutdown()
|
||||
subscriber.close()
|
||||
|
||||
|
||||
@pytest.mark.timeout(20)
|
||||
|
||||
@ -28,6 +28,7 @@ class EventBatch(
|
||||
):
|
||||
ts: float
|
||||
events: list[Any]
|
||||
data_parallel_rank: Optional[int] = None
|
||||
|
||||
|
||||
class KVCacheEvent(
|
||||
@ -60,7 +61,22 @@ class KVEventBatch(EventBatch):
|
||||
|
||||
|
||||
class EventPublisher(ABC):
|
||||
"""Lightweight publisher for EventBatch batches."""
|
||||
"""Lightweight publisher for EventBatch batches with data parallelism
|
||||
support.
|
||||
|
||||
In data parallel setups, each DP rank runs its own EventPublisher instance
|
||||
to avoid duplicate events and ensure proper event attribution:
|
||||
|
||||
- Each DP rank creates a separate publisher
|
||||
- Publishers automatically annotate events with their data_parallel_rank
|
||||
- This allows consumers to distinguish events from different DP ranks
|
||||
|
||||
The publisher is responsible for adding DP metadata since the scheduler
|
||||
operates independently of DP topology and shouldn't need DP awareness.
|
||||
"""
|
||||
|
||||
def __init__(self, data_parallel_rank: int = 0) -> None:
|
||||
self._data_parallel_rank = data_parallel_rank
|
||||
|
||||
@abstractmethod
|
||||
def publish(self, events: EventBatch) -> None:
|
||||
@ -113,6 +129,7 @@ class ZmqEventPublisher(EventPublisher):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
data_parallel_rank: int,
|
||||
endpoint: str = "tcp://*:5557",
|
||||
replay_endpoint: Optional[str] = None,
|
||||
buffer_steps: int = 10_000,
|
||||
@ -121,6 +138,7 @@ class ZmqEventPublisher(EventPublisher):
|
||||
topic: str = "",
|
||||
) -> None:
|
||||
# Storage
|
||||
super().__init__(data_parallel_rank)
|
||||
self._event_queue = Queue[Optional[EventBatch]](maxsize=max_queue_size)
|
||||
self._buffer = deque[tuple[int, bytes]](maxlen=buffer_steps)
|
||||
|
||||
@ -128,8 +146,11 @@ class ZmqEventPublisher(EventPublisher):
|
||||
self._ctx = zmq.Context.instance()
|
||||
self._pub: Optional[zmq.Socket] = None
|
||||
self._replay: Optional[zmq.Socket] = None
|
||||
self._endpoint = endpoint
|
||||
self._replay_endpoint = replay_endpoint
|
||||
self._dp_rank = data_parallel_rank
|
||||
|
||||
self._endpoint = self.offset_endpoint_port(endpoint, self._dp_rank)
|
||||
self._replay_endpoint = self.offset_endpoint_port(
|
||||
replay_endpoint, self._dp_rank)
|
||||
self._hwm = hwm
|
||||
self._socket_setup()
|
||||
|
||||
@ -149,6 +170,8 @@ class ZmqEventPublisher(EventPublisher):
|
||||
def publish(self, events: EventBatch) -> None:
|
||||
if not self._running:
|
||||
raise RuntimeError("Publisher is closed")
|
||||
if events.data_parallel_rank is None:
|
||||
events.data_parallel_rank = self._data_parallel_rank
|
||||
self._event_queue.put(events)
|
||||
|
||||
def shutdown(self) -> None:
|
||||
@ -191,11 +214,12 @@ class ZmqEventPublisher(EventPublisher):
|
||||
self._pub.set_hwm(self._hwm)
|
||||
# Heuristic: bind if wildcard / * present, else connect.
|
||||
# bind stable, connect volatile convention
|
||||
if ("*" in self._endpoint or "::" in self._endpoint
|
||||
or self._endpoint.startswith("ipc://")
|
||||
or self._endpoint.startswith("inproc://")):
|
||||
if (self._endpoint is not None
|
||||
and ("*" in self._endpoint or "::" in self._endpoint
|
||||
or self._endpoint.startswith("ipc://")
|
||||
or self._endpoint.startswith("inproc://"))):
|
||||
self._pub.bind(self._endpoint)
|
||||
else:
|
||||
elif self._endpoint is not None:
|
||||
self._pub.connect(self._endpoint)
|
||||
|
||||
# Set up replay socket: use ROUTER
|
||||
@ -266,6 +290,38 @@ class ZmqEventPublisher(EventPublisher):
|
||||
# receiving payload is (-1, b""")
|
||||
self._replay.send_multipart((client_id, b"", self.END_SEQ, b""))
|
||||
|
||||
@staticmethod
|
||||
def offset_endpoint_port(endpoint: Optional[str],
|
||||
data_parallel_rank: int) -> Optional[str]:
|
||||
"""Helper function to offset the port in an endpoint by
|
||||
the data parallel rank.
|
||||
|
||||
Args:
|
||||
endpoint: The endpoint string
|
||||
(e.g., "tcp://*:5557" or "inproc://cache")
|
||||
data_parallel_rank: The data parallel rank to offset by
|
||||
|
||||
Returns:
|
||||
The endpoint with the port offset by data_parallel_rank
|
||||
or suffix appended
|
||||
"""
|
||||
# Do nothing if input is None or data_parallel_rank is 0
|
||||
if not endpoint or data_parallel_rank == 0:
|
||||
return endpoint
|
||||
|
||||
if "inproc" in endpoint:
|
||||
return f"{endpoint}_dp{data_parallel_rank}"
|
||||
if "tcp" in endpoint:
|
||||
if endpoint and ":" in endpoint:
|
||||
# Get everything after the last colon (the port)
|
||||
last_colon_idx = endpoint.rfind(":")
|
||||
base_addr = endpoint[:last_colon_idx]
|
||||
base_port = int(endpoint[last_colon_idx + 1:])
|
||||
new_port = base_port + data_parallel_rank
|
||||
return f"{base_addr}:{new_port}"
|
||||
return endpoint
|
||||
raise ValueError("Invalid endpoint: must contain 'inproc' or 'tcp'")
|
||||
|
||||
|
||||
class EventPublisherFactory:
|
||||
_registry: dict[str, Callable[..., EventPublisher]] = {
|
||||
@ -281,7 +337,9 @@ class EventPublisherFactory:
|
||||
cls._registry[name] = ctor
|
||||
|
||||
@classmethod
|
||||
def create(cls, config: Optional[KVEventsConfig]) -> EventPublisher:
|
||||
def create(cls,
|
||||
config: Optional[KVEventsConfig],
|
||||
data_parallel_rank: int = 0) -> EventPublisher:
|
||||
"""Create publisher from a config mapping."""
|
||||
if not config:
|
||||
return NullEventPublisher()
|
||||
@ -294,4 +352,5 @@ class EventPublisherFactory:
|
||||
constructor = cls._registry[kind]
|
||||
except KeyError as exc:
|
||||
raise ValueError(f"Unknown event publisher '{kind}'") from exc
|
||||
return constructor(**config_dict)
|
||||
return constructor(data_parallel_rank=data_parallel_rank,
|
||||
**config_dict)
|
||||
|
||||
@ -80,7 +80,9 @@ class Scheduler(SchedulerInterface):
|
||||
config=self.vllm_config, role=KVConnectorRole.SCHEDULER)
|
||||
|
||||
self.kv_event_publisher = EventPublisherFactory.create(
|
||||
self.kv_events_config)
|
||||
self.kv_events_config,
|
||||
vllm_config.parallel_config.data_parallel_rank,
|
||||
)
|
||||
|
||||
num_gpu_blocks = self.cache_config.num_gpu_blocks
|
||||
assert num_gpu_blocks is not None and num_gpu_blocks > 0
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user