mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-05-30 07:47:05 +08:00
feat: add data parallel rank to KVEventBatch (#18925)
This commit is contained in:
parent
a8da78eac9
commit
b712be98c7
@ -145,6 +145,7 @@ steps:
|
|||||||
- examples/offline_inference/rlhf_colocate.py
|
- examples/offline_inference/rlhf_colocate.py
|
||||||
- tests/examples/offline_inference/data_parallel.py
|
- tests/examples/offline_inference/data_parallel.py
|
||||||
- tests/v1/test_async_llm_dp.py
|
- tests/v1/test_async_llm_dp.py
|
||||||
|
- tests/v1/engine/test_engine_core_client.py
|
||||||
commands:
|
commands:
|
||||||
# test with tp=2 and external_dp=2
|
# test with tp=2 and external_dp=2
|
||||||
- VLLM_USE_V1=0 torchrun --nproc-per-node=4 distributed/test_torchrun_example.py
|
- VLLM_USE_V1=0 torchrun --nproc-per-node=4 distributed/test_torchrun_example.py
|
||||||
@ -154,6 +155,7 @@ steps:
|
|||||||
# test with internal dp
|
# test with internal dp
|
||||||
- python3 ../examples/offline_inference/data_parallel.py
|
- python3 ../examples/offline_inference/data_parallel.py
|
||||||
- TP_SIZE=2 DP_SIZE=2 pytest -v -s v1/test_async_llm_dp.py
|
- TP_SIZE=2 DP_SIZE=2 pytest -v -s v1/test_async_llm_dp.py
|
||||||
|
- pytest -v -s v1/engine/test_engine_core_client.py::test_kv_cache_events_dp
|
||||||
- pytest -v -s distributed/test_utils.py
|
- pytest -v -s distributed/test_utils.py
|
||||||
- pytest -v -s compile/test_basic_correctness.py
|
- pytest -v -s compile/test_basic_correctness.py
|
||||||
- pytest -v -s distributed/test_pynccl.py
|
- pytest -v -s distributed/test_pynccl.py
|
||||||
|
|||||||
@ -13,11 +13,13 @@ from vllm.distributed.kv_events import EventPublisherFactory
|
|||||||
|
|
||||||
from .test_events import SampleBatch
|
from .test_events import SampleBatch
|
||||||
|
|
||||||
|
DP_RANK = 0
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def random_port():
|
def random_port():
|
||||||
"""Generate a random port number for testing"""
|
"""Generate a random port number for testing"""
|
||||||
return random.randint(10000, 60000)
|
return random.randint(10000, 59900)
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
@ -30,21 +32,23 @@ def publisher_config(random_port, request):
|
|||||||
replay_endpoint = endpoint + "-replay"
|
replay_endpoint = endpoint + "-replay"
|
||||||
else:
|
else:
|
||||||
endpoint = f"tcp://*:{random_port}"
|
endpoint = f"tcp://*:{random_port}"
|
||||||
replay_endpoint = f"tcp://*:{random_port + 1}"
|
replay_endpoint = f"tcp://*:{random_port + 100}"
|
||||||
|
|
||||||
return KVEventsConfig(enable_kv_cache_events=True,
|
return KVEventsConfig(
|
||||||
publisher="zmq",
|
enable_kv_cache_events=True,
|
||||||
endpoint=endpoint,
|
publisher="zmq",
|
||||||
replay_endpoint=replay_endpoint,
|
endpoint=endpoint,
|
||||||
buffer_steps=100,
|
replay_endpoint=replay_endpoint,
|
||||||
hwm=1000,
|
buffer_steps=100,
|
||||||
topic="test")
|
hwm=1000,
|
||||||
|
topic="test",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def publisher(publisher_config):
|
def publisher(publisher_config):
|
||||||
"""Create and return a publisher instance"""
|
"""Create and return a publisher instance"""
|
||||||
pub = EventPublisherFactory.create(publisher_config)
|
pub = EventPublisherFactory.create(publisher_config, DP_RANK)
|
||||||
yield pub
|
yield pub
|
||||||
pub.shutdown()
|
pub.shutdown()
|
||||||
|
|
||||||
@ -60,7 +64,11 @@ def subscriber(publisher_config):
|
|||||||
if replay_endpoint and replay_endpoint.startswith("tcp://*"):
|
if replay_endpoint and replay_endpoint.startswith("tcp://*"):
|
||||||
replay_endpoint = replay_endpoint.replace("*", "127.0.0.1")
|
replay_endpoint = replay_endpoint.replace("*", "127.0.0.1")
|
||||||
|
|
||||||
sub = MockSubscriber(endpoint, replay_endpoint, publisher_config.topic)
|
sub = MockSubscriber(
|
||||||
|
[endpoint],
|
||||||
|
[replay_endpoint] if replay_endpoint else None,
|
||||||
|
publisher_config.topic,
|
||||||
|
)
|
||||||
yield sub
|
yield sub
|
||||||
sub.close()
|
sub.close()
|
||||||
|
|
||||||
@ -68,26 +76,37 @@ def subscriber(publisher_config):
|
|||||||
class MockSubscriber:
|
class MockSubscriber:
|
||||||
"""Helper class to receive and verify published events"""
|
"""Helper class to receive and verify published events"""
|
||||||
|
|
||||||
def __init__(self,
|
def __init__(
|
||||||
pub_endpoint: str,
|
self,
|
||||||
replay_endpoint: Optional[str] = None,
|
pub_endpoints: Union[str, list[str]],
|
||||||
topic: str = "",
|
replay_endpoints: Optional[Union[str, list[str]]] = None,
|
||||||
decode_type=SampleBatch):
|
topic: str = "",
|
||||||
|
decode_type=SampleBatch,
|
||||||
|
):
|
||||||
self.ctx = zmq.Context.instance()
|
self.ctx = zmq.Context.instance()
|
||||||
|
|
||||||
# Set up subscriber socket
|
# Convert single endpoint to list for consistency
|
||||||
self.sub = self.ctx.socket(zmq.SUB)
|
if isinstance(pub_endpoints, str):
|
||||||
self.sub.setsockopt(zmq.SUBSCRIBE, topic.encode('utf-8'))
|
pub_endpoints = [pub_endpoints]
|
||||||
self.sub.connect(pub_endpoint)
|
if isinstance(replay_endpoints, str):
|
||||||
|
replay_endpoints = [replay_endpoints]
|
||||||
|
|
||||||
# Set up replay socket if provided
|
# Set up subscriber socket - connect to all endpoints
|
||||||
self.replay = None
|
self.sub = self.ctx.socket(zmq.SUB)
|
||||||
if replay_endpoint:
|
self.sub.setsockopt(zmq.SUBSCRIBE, topic.encode("utf-8"))
|
||||||
self.replay = self.ctx.socket(zmq.REQ)
|
for endpoint in pub_endpoints:
|
||||||
self.replay.connect(replay_endpoint)
|
self.sub.connect(endpoint)
|
||||||
|
|
||||||
|
# Set up replay sockets if provided
|
||||||
|
self.replay_sockets = []
|
||||||
|
if replay_endpoints:
|
||||||
|
for replay_endpoint in replay_endpoints:
|
||||||
|
replay = self.ctx.socket(zmq.REQ)
|
||||||
|
replay.connect(replay_endpoint)
|
||||||
|
self.replay_sockets.append(replay)
|
||||||
|
|
||||||
self.topic = topic
|
self.topic = topic
|
||||||
self.topic_bytes = topic.encode('utf-8')
|
self.topic_bytes = topic.encode("utf-8")
|
||||||
self.received_msgs: list[tuple[int, SampleBatch]] = []
|
self.received_msgs: list[tuple[int, SampleBatch]] = []
|
||||||
self.last_seq = -1
|
self.last_seq = -1
|
||||||
self.decoder = msgspec.msgpack.Decoder(type=decode_type)
|
self.decoder = msgspec.msgpack.Decoder(type=decode_type)
|
||||||
@ -107,25 +126,31 @@ class MockSubscriber:
|
|||||||
self.received_msgs.append((seq, data))
|
self.received_msgs.append((seq, data))
|
||||||
return seq, data
|
return seq, data
|
||||||
|
|
||||||
def request_replay(self, start_seq: int) -> None:
|
def request_replay(self, start_seq: int, socket_idx: int = 0) -> None:
|
||||||
"""Request replay of messages starting from start_seq"""
|
"""Request replay of messages starting from start_seq"""
|
||||||
if not self.replay:
|
if not self.replay_sockets:
|
||||||
raise ValueError("Replay socket not initialized")
|
raise ValueError("Replay sockets not initialized")
|
||||||
|
if socket_idx >= len(self.replay_sockets):
|
||||||
|
raise ValueError(f"Invalid socket index {socket_idx}")
|
||||||
|
|
||||||
self.replay.send(start_seq.to_bytes(8, "big"))
|
self.replay_sockets[socket_idx].send(start_seq.to_bytes(8, "big"))
|
||||||
|
|
||||||
def receive_replay(self) -> list[tuple[int, SampleBatch]]:
|
def receive_replay(self,
|
||||||
"""Receive replayed messages"""
|
socket_idx: int = 0) -> list[tuple[int, SampleBatch]]:
|
||||||
if not self.replay:
|
"""Receive replayed messages from a specific replay socket"""
|
||||||
raise ValueError("Replay socket not initialized")
|
if not self.replay_sockets:
|
||||||
|
raise ValueError("Replay sockets not initialized")
|
||||||
|
if socket_idx >= len(self.replay_sockets):
|
||||||
|
raise ValueError(f"Invalid socket index {socket_idx}")
|
||||||
|
|
||||||
|
replay_socket = self.replay_sockets[socket_idx]
|
||||||
replayed: list[tuple[int, SampleBatch]] = []
|
replayed: list[tuple[int, SampleBatch]] = []
|
||||||
while True:
|
while True:
|
||||||
try:
|
try:
|
||||||
if not self.replay.poll(1000):
|
if not replay_socket.poll(1000):
|
||||||
break
|
break
|
||||||
|
|
||||||
frames = self.replay.recv_multipart()
|
frames = replay_socket.recv_multipart()
|
||||||
if not frames or not frames[-1]:
|
if not frames or not frames[-1]:
|
||||||
# End of replay marker
|
# End of replay marker
|
||||||
break
|
break
|
||||||
@ -142,5 +167,5 @@ class MockSubscriber:
|
|||||||
def close(self):
|
def close(self):
|
||||||
"""Clean up resources"""
|
"""Clean up resources"""
|
||||||
self.sub.close()
|
self.sub.close()
|
||||||
if self.replay:
|
for replay in self.replay_sockets:
|
||||||
self.replay.close()
|
replay.close()
|
||||||
|
|||||||
@ -9,6 +9,8 @@ import pytest
|
|||||||
from vllm.distributed.kv_events import (EventBatch, EventPublisherFactory,
|
from vllm.distributed.kv_events import (EventBatch, EventPublisherFactory,
|
||||||
NullEventPublisher)
|
NullEventPublisher)
|
||||||
|
|
||||||
|
DP_RANK = 0
|
||||||
|
|
||||||
|
|
||||||
class EventSample(
|
class EventSample(
|
||||||
msgspec.Struct,
|
msgspec.Struct,
|
||||||
@ -121,7 +123,7 @@ def test_topic_filtering(publisher_config):
|
|||||||
publisher_config.replay_endpoint = None
|
publisher_config.replay_endpoint = None
|
||||||
|
|
||||||
publisher_config.topic = "foo"
|
publisher_config.topic = "foo"
|
||||||
pub = EventPublisherFactory.create(publisher_config)
|
pub = EventPublisherFactory.create(publisher_config, DP_RANK)
|
||||||
|
|
||||||
from .conftest import MockSubscriber
|
from .conftest import MockSubscriber
|
||||||
sub_foo = MockSubscriber(publisher_config.endpoint, None, "foo")
|
sub_foo = MockSubscriber(publisher_config.endpoint, None, "foo")
|
||||||
@ -185,9 +187,72 @@ def test_high_volume(publisher, subscriber):
|
|||||||
|
|
||||||
def test_null_publisher():
|
def test_null_publisher():
|
||||||
"""Test that NullEventPublisher can be used without errors"""
|
"""Test that NullEventPublisher can be used without errors"""
|
||||||
publisher = NullEventPublisher()
|
publisher = NullEventPublisher(DP_RANK)
|
||||||
|
|
||||||
# This should not raise any errors
|
# This should not raise any errors
|
||||||
batch = create_test_events(5)
|
batch = create_test_events(5)
|
||||||
publisher.publish(batch)
|
publisher.publish(batch)
|
||||||
publisher.shutdown()
|
publisher.shutdown()
|
||||||
|
|
||||||
|
|
||||||
|
def test_data_parallel_rank_tagging(publisher_config):
    """Test that events are properly tagged with their data parallel rank.

    Two publishers are created from the same config, one per DP rank. Each
    binds to its own rank-offset endpoint, so a dedicated subscriber per
    rank must only see batches tagged with that rank.
    """
    # Local imports, matching how the other tests in this module pull in
    # their helpers.
    from vllm.distributed.kv_events import ZmqEventPublisher

    from .conftest import MockSubscriber

    publisher_config.topic = "foo"
    ranks = (DP_RANK, DP_RANK + 1)

    # Derive the per-rank endpoints with the same helper the publisher uses
    # instead of hard-coding a port: the fixture picks a random port, so a
    # literal ":5557" -> ":5558" substitution would silently be a no-op.
    # A wildcard bind address ("tcp://*:port") is not connectable, so point
    # the subscribers at the loopback interface instead.
    base_endpoint = publisher_config.endpoint.replace("*", "127.0.0.1")
    expected_endpoints = [
        ZmqEventPublisher.offset_endpoint_port(base_endpoint, rank)
        for rank in ranks
    ]

    publishers = []
    subscribers = []
    try:
        # Create everything inside the try block so that a failure half-way
        # through still shuts down whatever was already started.
        for rank in ranks:
            publishers.append(
                EventPublisherFactory.create(publisher_config, rank))
        for endpoint in expected_endpoints:
            subscribers.append(
                MockSubscriber(endpoint, None, publisher_config.topic))
        pub_0, pub_1 = publishers
        sub_0, sub_1 = subscribers

        time.sleep(0.1)  # Let publishers start up

        # Publish events from different ranks
        batch_0 = create_test_events(2)
        batch_1 = create_test_events(3)

        pub_0.publish(batch_0)
        pub_1.publish(batch_1)

        # Receive events from rank 0
        result_0 = sub_0.receive_one(timeout=200)
        assert result_0 is not None, "No message received from rank 0"
        seq_0, received_0 = result_0

        # Receive events from rank 1
        result_1 = sub_1.receive_one(timeout=200)
        assert result_1 is not None, "No message received from rank 1"
        seq_1, received_1 = result_1

        # Verify DP rank tagging
        assert received_0.data_parallel_rank == 0, (
            f"Expected DP rank 0, got {received_0.data_parallel_rank}")
        assert received_1.data_parallel_rank == 1, (
            f"Expected DP rank 1, got {received_1.data_parallel_rank}")

        # Verify event content is correct
        assert len(
            received_0.events) == 2, "Wrong number of events from rank 0"
        assert len(
            received_1.events) == 3, "Wrong number of events from rank 1"

    finally:
        for pub in publishers:
            pub.shutdown()
        for sub in subscribers:
            sub.close()
|
||||||
|
|||||||
@ -12,8 +12,10 @@ from typing import Optional
|
|||||||
import pytest
|
import pytest
|
||||||
from transformers import AutoTokenizer
|
from transformers import AutoTokenizer
|
||||||
|
|
||||||
|
from tests.utils import multi_gpu_test
|
||||||
from vllm import SamplingParams
|
from vllm import SamplingParams
|
||||||
from vllm.distributed.kv_events import BlockStored, KVEventBatch
|
from vllm.distributed.kv_events import (BlockStored, KVEventBatch,
|
||||||
|
ZmqEventPublisher)
|
||||||
from vllm.engine.arg_utils import EngineArgs
|
from vllm.engine.arg_utils import EngineArgs
|
||||||
from vllm.platforms import current_platform
|
from vllm.platforms import current_platform
|
||||||
from vllm.usage.usage_lib import UsageContext
|
from vllm.usage.usage_lib import UsageContext
|
||||||
@ -37,10 +39,15 @@ PROMPT = "Hello my name is Robert and I love quantization kernels"
|
|||||||
PROMPT_TOKENS = TOKENIZER(PROMPT).input_ids
|
PROMPT_TOKENS = TOKENIZER(PROMPT).input_ids
|
||||||
|
|
||||||
|
|
||||||
def make_request(params: SamplingParams) -> EngineCoreRequest:
|
def make_request(
|
||||||
|
params: SamplingParams,
|
||||||
|
prompt_tokens_ids: Optional[list[int]] = None) -> EngineCoreRequest:
|
||||||
|
if not prompt_tokens_ids:
|
||||||
|
prompt_tokens_ids = PROMPT_TOKENS
|
||||||
|
|
||||||
return EngineCoreRequest(
|
return EngineCoreRequest(
|
||||||
request_id=str(uuid.uuid4()),
|
request_id=str(uuid.uuid4()),
|
||||||
prompt_token_ids=PROMPT_TOKENS,
|
prompt_token_ids=prompt_tokens_ids,
|
||||||
mm_inputs=None,
|
mm_inputs=None,
|
||||||
mm_hashes=None,
|
mm_hashes=None,
|
||||||
mm_placeholders=None,
|
mm_placeholders=None,
|
||||||
@ -88,6 +95,25 @@ async def loop_until_done_async(client: EngineCoreClient, outputs: dict):
|
|||||||
break
|
break
|
||||||
|
|
||||||
|
|
||||||
|
async def loop_until_fully_done_async(client: EngineCoreClient, outputs: dict):
    """Drain engine outputs until every tracked request has finished.

    Args:
        client: Object whose ``get_output_async()`` result exposes an
            ``outputs`` sequence; each item carries ``request_id`` and
            ``finished``.
        outputs: Mapping of request id to the list of outputs received so
            far. It is updated in place, and its keys define which requests
            must finish before this coroutine returns.
    """
    while True:
        batch = (await client.get_output_async()).outputs

        # An empty batch carries no progress; poll again straight away.
        if len(batch) > 0:
            # Record every output under the request it belongs to.
            for item in batch:
                outputs[item.request_id].append(item)

            # Done once each tracked request has at least one output and its
            # most recent output is flagged as finished.
            if all(history and history[-1].finished
                   for history in outputs.values()):
                return

            await asyncio.sleep(0.1)
|
||||||
|
|
||||||
|
|
||||||
# Dummy utility function to monkey-patch into engine core.
|
# Dummy utility function to monkey-patch into engine core.
|
||||||
def echo(self, msg: str, err_msg: Optional[str] = None) -> str:
|
def echo(self, msg: str, err_msg: Optional[str] = None) -> str:
|
||||||
print(f"echo util function called: {msg}, {err_msg}")
|
print(f"echo util function called: {msg}, {err_msg}")
|
||||||
@ -273,10 +299,12 @@ def test_kv_cache_events(
|
|||||||
block_size = 16
|
block_size = 16
|
||||||
num_blocks = 2
|
num_blocks = 2
|
||||||
|
|
||||||
engine_args = EngineArgs(model=MODEL_NAME,
|
engine_args = EngineArgs(
|
||||||
enforce_eager=True,
|
model=MODEL_NAME,
|
||||||
enable_prefix_caching=True,
|
enforce_eager=True,
|
||||||
block_size=block_size)
|
enable_prefix_caching=True,
|
||||||
|
block_size=block_size,
|
||||||
|
)
|
||||||
engine_args.kv_events_config = publisher_config
|
engine_args.kv_events_config = publisher_config
|
||||||
|
|
||||||
vllm_config = engine_args.create_engine_config(
|
vllm_config = engine_args.create_engine_config(
|
||||||
@ -297,19 +325,8 @@ def test_kv_cache_events(
|
|||||||
|
|
||||||
try:
|
try:
|
||||||
custom_tokens = list(range(num_blocks * block_size))
|
custom_tokens = list(range(num_blocks * block_size))
|
||||||
request = EngineCoreRequest(
|
sampling_params = SamplingParams(max_tokens=1)
|
||||||
request_id=str(uuid.uuid4()),
|
request = make_request(sampling_params, custom_tokens)
|
||||||
prompt_token_ids=custom_tokens,
|
|
||||||
mm_inputs=None,
|
|
||||||
mm_hashes=None,
|
|
||||||
mm_placeholders=None,
|
|
||||||
sampling_params=SamplingParams(
|
|
||||||
max_tokens=1), # Short completion for speed
|
|
||||||
eos_token_id=None,
|
|
||||||
arrival_time=time.time(),
|
|
||||||
lora_request=None,
|
|
||||||
cache_salt=None,
|
|
||||||
)
|
|
||||||
client.add_request(request)
|
client.add_request(request)
|
||||||
|
|
||||||
outputs: dict[str, list] = {request.request_id: []}
|
outputs: dict[str, list] = {request.request_id: []}
|
||||||
@ -321,24 +338,130 @@ def test_kv_cache_events(
|
|||||||
seq, received = result
|
seq, received = result
|
||||||
|
|
||||||
assert seq == 0, "Sequence number mismatch"
|
assert seq == 0, "Sequence number mismatch"
|
||||||
assert len(received.events) == 1, (
|
assert (len(received.events) == 1
|
||||||
"We should have exactly one BlockStored event")
|
), "We should have exactly one BlockStored event"
|
||||||
event = received.events[0]
|
event = received.events[0]
|
||||||
assert isinstance(
|
assert isinstance(
|
||||||
event, BlockStored), ("We should have a BlockStored event")
|
event, BlockStored), "We should have a BlockStored event"
|
||||||
assert len(event.block_hashes) == num_blocks, (
|
assert (len(event.block_hashes) == num_blocks
|
||||||
"We should have a BlockStored event with 2 block_hashes")
|
), "We should have a BlockStored event with 2 block_hashes"
|
||||||
assert event.block_size == block_size, (
|
assert (event.block_size == block_size
|
||||||
"Block size should be the same as the block size")
|
), "Block size should be the same as the block size"
|
||||||
assert event.parent_block_hash is None, (
|
assert (event.parent_block_hash
|
||||||
"Parent block hash should be None")
|
is None), "Parent block hash should be None"
|
||||||
assert event.lora_id is None, "Lora id should be None"
|
assert event.lora_id is None, "Lora id should be None"
|
||||||
assert len(event.token_ids) == num_blocks * block_size, (
|
assert (len(event.token_ids) == num_blocks * block_size
|
||||||
"Token ids should be the same as the custom tokens")
|
), "Token ids should be the same as the custom tokens"
|
||||||
assert event.token_ids == custom_tokens, (
|
assert (event.token_ids == custom_tokens
|
||||||
"Token ids should be the same as the custom tokens")
|
), "Token ids should be the same as the custom tokens"
|
||||||
finally:
|
finally:
|
||||||
client.shutdown()
|
client.shutdown()
|
||||||
|
subscriber.close()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
@pytest.mark.parametrize(
    "multiprocessing_mode,publisher_config",
    [(True, "tcp")],
    indirect=["publisher_config"],
)
@multi_gpu_test(num_gpus=4)
async def test_kv_cache_events_dp(
    monkeypatch: pytest.MonkeyPatch,
    multiprocessing_mode: bool,
    publisher_config,
):
    """Check that KV cache events from every data parallel rank are tagged.

    Starts an engine with ``dp_size`` x ``tp_size`` parallelism, subscribes
    to the rank-offset event endpoint of each DP rank, submits a burst of
    requests and asserts that the received batches carry ``dp_size``
    distinct ``data_parallel_rank`` values.
    """
    with monkeypatch.context() as m:
        m.setenv("VLLM_USE_V1", "1")
        block_size = 16
        num_blocks = 2
        dp_size = 2
        tp_size = 2

        engine_args = EngineArgs(
            model=MODEL_NAME,
            enforce_eager=True,
            enable_prefix_caching=True,
            data_parallel_size=dp_size,
            tensor_parallel_size=tp_size,
            block_size=block_size,
        )
        engine_args.kv_events_config = publisher_config

        vllm_config = engine_args.create_engine_config(
            UsageContext.UNKNOWN_CONTEXT)

        executor_class = Executor.get_class(vllm_config)
        client = EngineCoreClient.make_client(
            multiprocess_mode=multiprocessing_mode,
            asyncio_mode=True,
            vllm_config=vllm_config,
            executor_class=executor_class,
            log_stats=False,
        )

        # Enter the try block as soon as the client exists so that a failure
        # while building endpoints or the subscriber still shuts it down.
        subscriber = None
        try:
            await asyncio.sleep(1)

            # Build endpoints for all DP ranks
            base_endpoint = publisher_config.endpoint.replace(
                "*", "127.0.0.1")
            endpoints = [
                ZmqEventPublisher.offset_endpoint_port(base_endpoint, rank)
                for rank in range(dp_size)
            ]

            subscriber = MockSubscriber(endpoints,
                                        topic=publisher_config.topic,
                                        decode_type=KVEventBatch)

            custom_tokens = list(range(num_blocks * block_size))
            sampling_params = SamplingParams(max_tokens=1)
            all_request_ids = []

            # Create and add 25 requests
            # NOTE: attempts to force routing to both dp groups but can be flaky
            for _ in range(25):
                await asyncio.sleep(0.01)
                request = make_request(sampling_params, custom_tokens)
                await client.add_request_async(request)
                all_request_ids.append(request.request_id)

            await asyncio.sleep(0.1)

            # Initialize outputs dict for all requests
            outputs: dict[str, list] = {
                req_id: []
                for req_id in all_request_ids
            }

            print("processing requests...")
            await asyncio.wait_for(loop_until_fully_done_async(
                client, outputs),
                                   timeout=20.0)

            # Receive from subscriber until no more messages
            print("collecting results...")
            results = []
            while True:
                result = subscriber.receive_one(timeout=1)
                if result is None:
                    break
                results.append(result)

            # Collect the data_parallel_rank of every received batch
            unique_dps = {
                received.data_parallel_rank
                for (_, received) in results
            }
            assert len(unique_dps) == dp_size, (
                f"Expected {dp_size} unique data_parallel_ranks, "
                f"got {len(unique_dps)}: {unique_dps}")

        finally:
            client.shutdown()
            if subscriber is not None:
                subscriber.close()
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.timeout(20)
|
@pytest.mark.timeout(20)
|
||||||
|
|||||||
@ -28,6 +28,7 @@ class EventBatch(
|
|||||||
):
|
):
|
||||||
ts: float
|
ts: float
|
||||||
events: list[Any]
|
events: list[Any]
|
||||||
|
data_parallel_rank: Optional[int] = None
|
||||||
|
|
||||||
|
|
||||||
class KVCacheEvent(
|
class KVCacheEvent(
|
||||||
@ -60,7 +61,22 @@ class KVEventBatch(EventBatch):
|
|||||||
|
|
||||||
|
|
||||||
class EventPublisher(ABC):
|
class EventPublisher(ABC):
|
||||||
"""Lightweight publisher for EventBatch batches."""
|
"""Lightweight publisher for EventBatch batches with data parallelism
|
||||||
|
support.
|
||||||
|
|
||||||
|
In data parallel setups, each DP rank runs its own EventPublisher instance
|
||||||
|
to avoid duplicate events and ensure proper event attribution:
|
||||||
|
|
||||||
|
- Each DP rank creates a separate publisher
|
||||||
|
- Publishers automatically annotate events with their data_parallel_rank
|
||||||
|
- This allows consumers to distinguish events from different DP ranks
|
||||||
|
|
||||||
|
The publisher is responsible for adding DP metadata since the scheduler
|
||||||
|
operates independently of DP topology and shouldn't need DP awareness.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, data_parallel_rank: int = 0) -> None:
|
||||||
|
self._data_parallel_rank = data_parallel_rank
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def publish(self, events: EventBatch) -> None:
|
def publish(self, events: EventBatch) -> None:
|
||||||
@ -113,6 +129,7 @@ class ZmqEventPublisher(EventPublisher):
|
|||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
|
data_parallel_rank: int,
|
||||||
endpoint: str = "tcp://*:5557",
|
endpoint: str = "tcp://*:5557",
|
||||||
replay_endpoint: Optional[str] = None,
|
replay_endpoint: Optional[str] = None,
|
||||||
buffer_steps: int = 10_000,
|
buffer_steps: int = 10_000,
|
||||||
@ -121,6 +138,7 @@ class ZmqEventPublisher(EventPublisher):
|
|||||||
topic: str = "",
|
topic: str = "",
|
||||||
) -> None:
|
) -> None:
|
||||||
# Storage
|
# Storage
|
||||||
|
super().__init__(data_parallel_rank)
|
||||||
self._event_queue = Queue[Optional[EventBatch]](maxsize=max_queue_size)
|
self._event_queue = Queue[Optional[EventBatch]](maxsize=max_queue_size)
|
||||||
self._buffer = deque[tuple[int, bytes]](maxlen=buffer_steps)
|
self._buffer = deque[tuple[int, bytes]](maxlen=buffer_steps)
|
||||||
|
|
||||||
@ -128,8 +146,11 @@ class ZmqEventPublisher(EventPublisher):
|
|||||||
self._ctx = zmq.Context.instance()
|
self._ctx = zmq.Context.instance()
|
||||||
self._pub: Optional[zmq.Socket] = None
|
self._pub: Optional[zmq.Socket] = None
|
||||||
self._replay: Optional[zmq.Socket] = None
|
self._replay: Optional[zmq.Socket] = None
|
||||||
self._endpoint = endpoint
|
self._dp_rank = data_parallel_rank
|
||||||
self._replay_endpoint = replay_endpoint
|
|
||||||
|
self._endpoint = self.offset_endpoint_port(endpoint, self._dp_rank)
|
||||||
|
self._replay_endpoint = self.offset_endpoint_port(
|
||||||
|
replay_endpoint, self._dp_rank)
|
||||||
self._hwm = hwm
|
self._hwm = hwm
|
||||||
self._socket_setup()
|
self._socket_setup()
|
||||||
|
|
||||||
@ -149,6 +170,8 @@ class ZmqEventPublisher(EventPublisher):
|
|||||||
def publish(self, events: EventBatch) -> None:
|
def publish(self, events: EventBatch) -> None:
|
||||||
if not self._running:
|
if not self._running:
|
||||||
raise RuntimeError("Publisher is closed")
|
raise RuntimeError("Publisher is closed")
|
||||||
|
if events.data_parallel_rank is None:
|
||||||
|
events.data_parallel_rank = self._data_parallel_rank
|
||||||
self._event_queue.put(events)
|
self._event_queue.put(events)
|
||||||
|
|
||||||
def shutdown(self) -> None:
|
def shutdown(self) -> None:
|
||||||
@ -191,11 +214,12 @@ class ZmqEventPublisher(EventPublisher):
|
|||||||
self._pub.set_hwm(self._hwm)
|
self._pub.set_hwm(self._hwm)
|
||||||
# Heuristic: bind if wildcard / * present, else connect.
|
# Heuristic: bind if wildcard / * present, else connect.
|
||||||
# bind stable, connect volatile convention
|
# bind stable, connect volatile convention
|
||||||
if ("*" in self._endpoint or "::" in self._endpoint
|
if (self._endpoint is not None
|
||||||
or self._endpoint.startswith("ipc://")
|
and ("*" in self._endpoint or "::" in self._endpoint
|
||||||
or self._endpoint.startswith("inproc://")):
|
or self._endpoint.startswith("ipc://")
|
||||||
|
or self._endpoint.startswith("inproc://"))):
|
||||||
self._pub.bind(self._endpoint)
|
self._pub.bind(self._endpoint)
|
||||||
else:
|
elif self._endpoint is not None:
|
||||||
self._pub.connect(self._endpoint)
|
self._pub.connect(self._endpoint)
|
||||||
|
|
||||||
# Set up replay socket: use ROUTER
|
# Set up replay socket: use ROUTER
|
||||||
@ -266,6 +290,38 @@ class ZmqEventPublisher(EventPublisher):
|
|||||||
# receiving payload is (-1, b""")
|
# receiving payload is (-1, b""")
|
||||||
self._replay.send_multipart((client_id, b"", self.END_SEQ, b""))
|
self._replay.send_multipart((client_id, b"", self.END_SEQ, b""))
|
||||||
|
|
||||||
|
@staticmethod
def offset_endpoint_port(endpoint: Optional[str],
                         data_parallel_rank: int) -> Optional[str]:
    """Helper function to offset the port in an endpoint by
    the data parallel rank.

    Args:
        endpoint: The endpoint string
            (e.g., "tcp://*:5557" or "inproc://cache")
        data_parallel_rank: The data parallel rank to offset by

    Returns:
        The endpoint with the port offset by data_parallel_rank
        or suffix appended. ``None``/empty input and rank 0 are returned
        unchanged.

    Raises:
        ValueError: If the endpoint contains neither "inproc" nor "tcp", or
            if a tcp endpoint does not end in a numeric port.
    """
    # Do nothing if input is None or data_parallel_rank is 0
    if not endpoint or data_parallel_rank == 0:
        return endpoint

    if "inproc" in endpoint:
        return f"{endpoint}_dp{data_parallel_rank}"
    if "tcp" in endpoint:
        if ":" not in endpoint:
            return endpoint
        # Everything after the last colon is the port; splitting from the
        # right keeps IPv6 hosts such as "tcp://[::1]:5557" intact.
        base_addr, _, port_text = endpoint.rpartition(":")
        try:
            base_port = int(port_text)
        except ValueError:
            # Without a port the last colon is the one in "tcp://"; report
            # the offending endpoint rather than a bare int() parse error.
            raise ValueError(
                f"Invalid tcp endpoint {endpoint!r}: expected a numeric "
                "port after the last ':'") from None
        return f"{base_addr}:{base_port + data_parallel_rank}"
    raise ValueError("Invalid endpoint: must contain 'inproc' or 'tcp'")
|
||||||
|
|
||||||
|
|
||||||
class EventPublisherFactory:
|
class EventPublisherFactory:
|
||||||
_registry: dict[str, Callable[..., EventPublisher]] = {
|
_registry: dict[str, Callable[..., EventPublisher]] = {
|
||||||
@ -281,7 +337,9 @@ class EventPublisherFactory:
|
|||||||
cls._registry[name] = ctor
|
cls._registry[name] = ctor
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def create(cls, config: Optional[KVEventsConfig]) -> EventPublisher:
|
def create(cls,
|
||||||
|
config: Optional[KVEventsConfig],
|
||||||
|
data_parallel_rank: int = 0) -> EventPublisher:
|
||||||
"""Create publisher from a config mapping."""
|
"""Create publisher from a config mapping."""
|
||||||
if not config:
|
if not config:
|
||||||
return NullEventPublisher()
|
return NullEventPublisher()
|
||||||
@ -294,4 +352,5 @@ class EventPublisherFactory:
|
|||||||
constructor = cls._registry[kind]
|
constructor = cls._registry[kind]
|
||||||
except KeyError as exc:
|
except KeyError as exc:
|
||||||
raise ValueError(f"Unknown event publisher '{kind}'") from exc
|
raise ValueError(f"Unknown event publisher '{kind}'") from exc
|
||||||
return constructor(**config_dict)
|
return constructor(data_parallel_rank=data_parallel_rank,
|
||||||
|
**config_dict)
|
||||||
|
|||||||
@ -80,7 +80,9 @@ class Scheduler(SchedulerInterface):
|
|||||||
config=self.vllm_config, role=KVConnectorRole.SCHEDULER)
|
config=self.vllm_config, role=KVConnectorRole.SCHEDULER)
|
||||||
|
|
||||||
self.kv_event_publisher = EventPublisherFactory.create(
|
self.kv_event_publisher = EventPublisherFactory.create(
|
||||||
self.kv_events_config)
|
self.kv_events_config,
|
||||||
|
vllm_config.parallel_config.data_parallel_rank,
|
||||||
|
)
|
||||||
|
|
||||||
num_gpu_blocks = self.cache_config.num_gpu_blocks
|
num_gpu_blocks = self.cache_config.num_gpu_blocks
|
||||||
assert num_gpu_blocks is not None and num_gpu_blocks > 0
|
assert num_gpu_blocks is not None and num_gpu_blocks > 0
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user