[V1][Metrics] add support for kv event publishing (#16750)
Signed-off-by: alec-flowers <aflowers@nvidia.com>
Signed-off-by: Mark McLoughlin <markmc@redhat.com>
Co-authored-by: Mark McLoughlin <markmc@redhat.com>
parent 77073c77bc
commit 0be6d05b5e
examples/online_serving/kv_events.sh (new file, 86 lines)
@@ -0,0 +1,86 @@
#!/bin/bash
# This file demonstrates KV cache event publishing.
# We launch a vLLM instance configured to publish KV cache
# events and a simple subscriber to log those events.

set -xe

echo "🚧🚧 Warning: The usage of KV cache events is experimental and subject to change 🚧🚧"
sleep 1

MODEL_NAME=${HF_MODEL_NAME:-meta-llama/Meta-Llama-3.1-8B-Instruct}

# Trap the SIGINT signal (triggered by Ctrl+C)
trap 'cleanup' INT

# Cleanup function
cleanup() {
    echo "Caught Ctrl+C, cleaning up..."
    # Cleanup commands
    pgrep python | xargs kill -9
    pkill -f python
    echo "Cleanup complete. Exiting."
    exit 0
}

export VLLM_HOST_IP=$(hostname -I | awk '{print $1}')

# A function that waits for the vLLM server to start
wait_for_server() {
    local port=$1
    timeout 1200 bash -c "
        until curl -s localhost:${port}/v1/completions > /dev/null; do
            sleep 1
        done" && return 0 || return 1
}

vllm serve $MODEL_NAME \
    --port 8100 \
    --max-model-len 100 \
    --enforce-eager \
    --gpu-memory-utilization 0.8 \
    --trust-remote-code \
    --kv-events-config \
    '{"enable_kv_cache_events": true, "publisher": "zmq", "topic": "kv-events"}' &

wait_for_server 8100

SCRIPT_DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" && pwd )"

python3 "$SCRIPT_DIR/kv_events_subscriber.py" &
sleep 1

# Serve two example requests
output1=$(curl -X POST -s http://localhost:8100/v1/completions \
    -H "Content-Type: application/json" \
    -d '{
        "model": "'"$MODEL_NAME"'",
        "prompt": "Explain quantum computing in simple terms a 5-year-old could understand.",
        "max_tokens": 80,
        "temperature": 0
    }')

output2=$(curl -X POST -s http://localhost:8100/v1/completions \
    -H "Content-Type: application/json" \
    -d '{
        "model": "'"$MODEL_NAME"'",
        "prompt": "Explain quantum computing in simple terms a 50-year-old could understand.",
        "max_tokens": 80,
        "temperature": 0
    }')

# Cleanup commands
pkill -9 -u "$USER" -f python
pkill -9 -u "$USER" -f vllm

sleep 1

echo "Cleaned up"

# Print the outputs of the curl requests
echo ""
echo "Output of first request: $output1"
echo "Output of second request: $output2"

echo "🎉🎉 Successfully finished 2 test requests! 🎉🎉"
echo ""
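Note on the config above: the JSON string handed to --kv-events-config is parsed into the KVEventsConfig model introduced later in this commit, so any of its fields can be set here. One caveat worth flagging: replay_endpoint defaults to None, so the replay socket the subscriber connects to on port 5558 is only served if the flag also sets it, e.g.

    --kv-events-config '{"enable_kv_cache_events": true, "publisher": "zmq", "topic": "kv-events", "replay_endpoint": "tcp://*:5558"}'

(the extra field here is an illustration based on the KVEventsConfig defaults, not part of the original script).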
examples/online_serving/kv_events_subscriber.py (new file, 114 lines)
@@ -0,0 +1,114 @@
# SPDX-License-Identifier: Apache-2.0
from typing import Any, Optional, Union

import msgspec
import zmq
from msgspec.msgpack import Decoder


#
# Types copied from vllm.distributed.kv_events
#
class EventBatch(msgspec.Struct, array_like=True, omit_defaults=True,
                 gc=False):
    ts: float
    events: list[Any]


class KVCacheEvent(msgspec.Struct,
                   array_like=True,
                   omit_defaults=True,
                   gc=False,
                   tag=True):
    """Base class for all KV cache-related events"""


class BlockStored(KVCacheEvent):
    block_hashes: list[int]
    parent_block_hash: Optional[int]
    token_ids: list[int]
    block_size: int
    lora_id: Optional[int]


class BlockRemoved(KVCacheEvent):
    block_hashes: list[int]


class AllBlocksCleared(KVCacheEvent):
    pass


class KVEventBatch(EventBatch):
    events: list[Union[BlockStored, BlockRemoved, AllBlocksCleared]]


def process_event(event_batch):
    print(f"Received event batch at {event_batch.ts}:")
    for event in event_batch.events:
        print(f"  - {event}")


def main():
    decoder = Decoder(type=KVEventBatch)
    last_seq = -1

    context = zmq.Context()

    # Set up the main subscription socket
    sub = context.socket(zmq.SUB)
    sub.connect("tcp://localhost:5557")
    topic = "kv-events"
    sub.setsockopt_string(zmq.SUBSCRIBE, topic)

    # Initialize replay socket
    replay = context.socket(zmq.REQ)
    replay.connect("tcp://localhost:5558")
    poller = zmq.Poller()
    poller.register(replay, zmq.POLLIN)

    print("Listening for KV cache events on topic:", topic)

    while True:
        try:
            if sub.poll(50):
                _, seq_bytes, payload = sub.recv_multipart()
                seq = int.from_bytes(seq_bytes, "big")

                if last_seq >= 0 and seq > last_seq + 1:
                    missed = seq - last_seq - 1
                    print(f"Missed {missed} messages"
                          f" (last: {last_seq}, current: {seq})")

                    replay.send((last_seq + 1).to_bytes(8, "big"))

                    while poller.poll(timeout=200):
                        seq_bytes, replay_payload = replay.recv_multipart()
                        if not replay_payload:
                            # End of replay marker is sent as an empty frame
                            # for the payload
                            break

                        replay_seq = int.from_bytes(seq_bytes, "big")

                        if replay_seq > last_seq:
                            event_batch = decoder.decode(replay_payload)
                            process_event(event_batch)
                            last_seq = replay_seq
                            if replay_seq >= seq - 1:
                                break

                event_batch = decoder.decode(payload)
                process_event(event_batch)
                # Track the latest sequence number so gaps can be detected.
                last_seq = seq

            # ... do other periodic work or check for shutdown ...

        except KeyboardInterrupt:
            print("Interrupted")
            break
        except Exception as e:
            print("Error decoding message:", e)


if __name__ == "__main__":
    main()
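For orientation: each message published by vLLM is a three-frame ZMQ multipart of (topic bytes, 8-byte big-endian sequence number, msgpack-encoded KVEventBatch), and replay replies are (seq_bytes, payload) pairs terminated by an empty payload frame, as the publisher implementation later in this commit shows. A minimal sketch of decoding one already-received frame, reusing the types defined above:

    # Sketch: decode one (topic, seq, payload) frame from recv_multipart().
    from msgspec.msgpack import Decoder

    frame_decoder = Decoder(type=KVEventBatch)

    def decode_frame(topic_bytes: bytes, seq_bytes: bytes, payload: bytes):
        seq = int.from_bytes(seq_bytes, "big")  # monotonic sequence number
        batch = frame_decoder.decode(payload)   # msgpack -> KVEventBatch
        return seq, batch.ts, batch.events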
tests/distributed/conftest.py (new file, 145 lines)
@@ -0,0 +1,145 @@
# SPDX-License-Identifier: Apache-2.0
import random
from typing import Optional, Union

import msgspec
import msgspec.msgpack
import pytest
import zmq

from vllm.config import KVEventsConfig
from vllm.distributed.kv_events import EventPublisherFactory

from .test_events import SampleBatch


@pytest.fixture
def random_port():
    """Generate a random port number for testing"""
    return random.randint(10000, 60000)


@pytest.fixture
def publisher_config(random_port, request):
    """Create a publisher config with inproc transport"""
    how = request.param if hasattr(request, "param") else "inproc"

    if how == "inproc":
        endpoint = f"inproc://test-{random_port}"
        replay_endpoint = endpoint + "-replay"
    else:
        endpoint = f"tcp://*:{random_port}"
        replay_endpoint = f"tcp://*:{random_port + 1}"

    return KVEventsConfig(enable_kv_cache_events=True,
                          publisher="zmq",
                          endpoint=endpoint,
                          replay_endpoint=replay_endpoint,
                          buffer_steps=100,
                          hwm=1000,
                          topic="test")


@pytest.fixture
def publisher(publisher_config):
    """Create and return a publisher instance"""
    pub = EventPublisherFactory.create(publisher_config)
    yield pub
    pub.shutdown()


@pytest.fixture
def subscriber(publisher_config):
    """Create and return a subscriber for testing"""
    endpoint = publisher_config.endpoint
    replay_endpoint = publisher_config.replay_endpoint

    if endpoint.startswith("tcp://*"):
        endpoint = endpoint.replace("*", "127.0.0.1")
    if replay_endpoint and replay_endpoint.startswith("tcp://*"):
        replay_endpoint = replay_endpoint.replace("*", "127.0.0.1")

    sub = MockSubscriber(endpoint, replay_endpoint, publisher_config.topic)
    yield sub
    sub.close()


class MockSubscriber:
    """Helper class to receive and verify published events"""

    def __init__(self,
                 pub_endpoint: str,
                 replay_endpoint: Optional[str] = None,
                 topic: str = "",
                 decode_type=SampleBatch):
        self.ctx = zmq.Context.instance()

        # Set up subscriber socket
        self.sub = self.ctx.socket(zmq.SUB)
        self.sub.setsockopt(zmq.SUBSCRIBE, topic.encode('utf-8'))
        self.sub.connect(pub_endpoint)

        # Set up replay socket if provided
        self.replay = None
        if replay_endpoint:
            self.replay = self.ctx.socket(zmq.REQ)
            self.replay.connect(replay_endpoint)

        self.topic = topic
        self.topic_bytes = topic.encode('utf-8')
        self.received_msgs: list[tuple[int, SampleBatch]] = []
        self.last_seq = -1
        self.decoder = msgspec.msgpack.Decoder(type=decode_type)

    def receive_one(self,
                    timeout=1000) -> Union[tuple[int, SampleBatch], None]:
        """Receive a single message with timeout"""
        if not self.sub.poll(timeout):
            return None

        topic_bytes, seq_bytes, payload = self.sub.recv_multipart()
        assert topic_bytes == self.topic_bytes

        seq = int.from_bytes(seq_bytes, "big")
        data = self.decoder.decode(payload)
        self.last_seq = seq
        self.received_msgs.append((seq, data))
        return seq, data

    def request_replay(self, start_seq: int) -> None:
        """Request replay of messages starting from start_seq"""
        if not self.replay:
            raise ValueError("Replay socket not initialized")

        self.replay.send(start_seq.to_bytes(8, "big"))

    def receive_replay(self) -> list[tuple[int, SampleBatch]]:
        """Receive replayed messages"""
        if not self.replay:
            raise ValueError("Replay socket not initialized")

        replayed: list[tuple[int, SampleBatch]] = []
        while True:
            try:
                if not self.replay.poll(1000):
                    break

                frames = self.replay.recv_multipart()
                if not frames or not frames[-1]:
                    # End of replay marker
                    break

                seq_bytes, payload = frames
                seq = int.from_bytes(seq_bytes, "big")
                data = self.decoder.decode(payload)
                replayed.append((seq, data))
            except zmq.ZMQError as _:
                break

        return replayed

    def close(self):
        """Clean up resources"""
        self.sub.close()
        if self.replay:
            self.replay.close()
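A rough sketch of how these fixtures compose in a test, assuming the default inproc transport (no new API is introduced; test_events_roundtrip is a hypothetical name):

    import time

    def test_events_roundtrip(publisher, subscriber):
        # SampleBatch comes from .test_events, as imported above.
        publisher.publish(SampleBatch(ts=time.time(), events=[]))
        result = subscriber.receive_one(timeout=1000)
        assert result is not None
        seq, batch = result
        assert seq == 0  # first message carries sequence number 0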
tests/distributed/test_events.py (new file, 193 lines)
@@ -0,0 +1,193 @@
# SPDX-License-Identifier: Apache-2.0
import threading
import time

import msgspec
import pytest

from vllm.distributed.kv_events import (EventBatch, EventPublisherFactory,
                                        NullEventPublisher)


class EventSample(
        msgspec.Struct,
        tag=True,  # type: ignore
        array_like=True  # type: ignore
):
    """Test event for publisher testing"""
    id: int
    value: str


class SampleBatch(EventBatch):
    """Test event batch for publisher testing"""
    events: list[EventSample]


def create_test_events(count: int) -> SampleBatch:
    """Create a batch of test events"""
    events = [EventSample(id=i, value=f"test-{i}") for i in range(count)]
    return SampleBatch(ts=time.time(), events=events)


def test_basic_publishing(publisher, subscriber):
    """Test basic event publishing works"""

    test_batch = create_test_events(5)
    publisher.publish(test_batch)

    result = subscriber.receive_one(timeout=1000)
    assert result is not None, "No message received"

    seq, received = result
    assert seq == 0, "Sequence number mismatch"
    assert received.ts == pytest.approx(test_batch.ts,
                                        abs=0.1), ("Timestamp mismatch")
    assert len(received.events) == len(
        test_batch.events), ("Number of events mismatch")

    for i, event in enumerate(received.events):
        assert event.id == i, "Event id mismatch"
        assert event.value == f"test-{i}", "Event value mismatch"


def test_multiple_events(publisher, subscriber):
    """Test publishing and receiving multiple event batches"""
    for _ in range(10):
        batch = create_test_events(2)
        publisher.publish(batch)

    received = []
    for _ in range(10):
        data = subscriber.receive_one(timeout=100)
        if data:
            received.append(data)

    assert len(received) == 10, "Number of messages mismatch"
    seqs = [seq for seq, _ in received]
    assert seqs == list(range(10)), "Sequence numbers mismatch"


def test_replay_mechanism(publisher, subscriber):
    """Test the replay mechanism works correctly"""
    for _ in range(19):
        batch = create_test_events(1)
        publisher.publish(batch)

    time.sleep(0.5)  # Need publisher to process above requests
    subscriber.request_replay(10)

    batch = create_test_events(1)
    publisher.publish(batch)  # 20th message

    replayed = subscriber.receive_replay()

    assert len(replayed) > 0, "No replayed messages received"
    seqs = [seq for seq, _ in replayed]
    assert all(seq >= 10 for seq in seqs), "Replayed messages not in order"
    assert seqs == list(range(min(seqs),
                              max(seqs) +
                              1)), ("Replayed messages not consecutive")


def test_buffer_limit(publisher, subscriber, publisher_config):
    """Test buffer limit behavior"""
    buffer_size = publisher_config.buffer_steps

    # Publish more events than the buffer can hold
    for i in range(buffer_size + 10):
        batch = create_test_events(1)
        publisher.publish(batch)

    time.sleep(0.5)  # Need publisher to process above requests
    subscriber.request_replay(0)

    batch = create_test_events(1)
    publisher.publish(batch)

    replayed = subscriber.receive_replay()

    assert len(replayed) <= buffer_size, "Can't replay more than buffer size"

    oldest_seq = min(seq for seq, _ in replayed)
    assert oldest_seq >= 10, "The oldest sequence should be at least 10"


def test_topic_filtering(publisher_config):
    """
    Test that a subscriber only receives messages matching its topic filter
    """
    publisher_config.replay_endpoint = None

    cfg = publisher_config.model_copy()
    cfg.topic = "foo"
    pub = EventPublisherFactory.create(cfg)

    from .conftest import MockSubscriber
    sub_foo = MockSubscriber(cfg.endpoint, None, "foo")
    sub_bar = MockSubscriber(cfg.endpoint, None, "bar")

    try:
        time.sleep(0.1)

        for _ in range(3):
            pub.publish(create_test_events(1))

        foo_received = [sub_foo.receive_one(timeout=200) for _ in range(3)]
        assert all(msg is not None for msg in foo_received), (
            "Subscriber with matching topic should receive messages")

        bar_received = [sub_bar.receive_one(timeout=200) for _ in range(3)]
        assert all(msg is None for msg in bar_received), (
            "Subscriber with non-matching topic should receive no messages")
    finally:
        pub.shutdown()
        sub_foo.close()
        sub_bar.close()


def test_high_volume(publisher, subscriber):
    """Test publishing and receiving a high volume of events"""
    num_batches = 10_000
    events_per_batch = 100

    # Publish events in a separate thread to not block
    def publish_events():
        for i in range(num_batches):
            batch = create_test_events(events_per_batch)
            publisher.publish(batch)
            # Small delay to avoid overwhelming
            if i % 100 == 0:
                time.sleep(0.01)

    received: list[tuple[int, SampleBatch]] = []

    publisher_thread = threading.Thread(target=publish_events)
    publisher_thread.start()

    start_time = time.time()
    while len(received) < num_batches:
        if time.time() - start_time > 10:  # Timeout after 10 seconds
            break

        result = subscriber.receive_one(timeout=100)
        if result:
            received.append(result)

    publisher_thread.join()

    assert len(received) >= num_batches * 0.9, (
        "We should have received most messages")

    seqs = [seq for seq, _ in received]
    assert sorted(seqs) == seqs, "Sequence numbers should be in order"


def test_null_publisher():
    """Test that NullEventPublisher can be used without errors"""
    publisher = NullEventPublisher()

    # This should not raise any errors
    batch = create_test_events(5)
    publisher.publish(batch)
    publisher.shutdown()
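These tests drive the publisher end to end over local ZMQ sockets, so they need no GPU; a plain pytest invocation from the repository root should run them:

    pytest tests/distributed/test_events.py -v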
@@ -6,6 +6,7 @@ from typing import Optional
 import pytest
 import torch
 
+from vllm.distributed.kv_events import AllBlocksCleared, BlockRemoved
 from vllm.multimodal.inputs import MultiModalKwargs, PlaceholderRange
 from vllm.sampling_params import SamplingParams
 from vllm.utils import sha256
@@ -48,9 +49,10 @@ def make_kv_cache_config(block_size: int, num_blocks: int) -> KVCacheConfig:
         num_blocks=num_blocks,
         tensors={},
         kv_cache_groups=[
-            KVCacheGroupSpec(['layer'],
-                             FullAttentionSpec(block_size, 1, 1, torch.float32,
-                                               False))
+            KVCacheGroupSpec(
+                ["layer"],
+                FullAttentionSpec(block_size, 1, 1, torch.float32, False),
+            )
         ],
     )
 
@@ -783,6 +785,60 @@ def test_prefix_cache_stats_disabled():
     assert manager.prefix_cache_stats is None
 
 
+@pytest.mark.parametrize("blocks_to_cache", [2, 3, 10])
+def test_kv_cache_events(blocks_to_cache: int):
+    block_size = 16
+    num_blocks = blocks_to_cache + 1
+
+    # Allocate Blocks
+    # Should see a single block stored event with a blocks_to_cache number of
+    # block hashes
+    # take_events should reset the kv_event_queue
+    manager = KVCacheManager(
+        make_kv_cache_config(block_size, num_blocks),
+        max_model_len=8192,
+        enable_caching=True,
+        enable_kv_cache_events=True,
+    )
+
+    num_tokens = block_size * blocks_to_cache
+    req0 = make_request("0", list(range(num_tokens)))
+    _ = manager.allocate_slots(req0, num_tokens)
+    events = manager.take_events()
+
+    block = events[-1]
+    assert (len(block.block_hashes) == blocks_to_cache == len(
+        manager.block_pool.cached_block_hash_to_block))
+    assert len(block.token_ids) == block.block_size * len(block.block_hashes)
+    assert len(manager.block_pool.kv_event_queue) == 0
+
+    stored_block_hash = block.block_hashes
+
+    # Remove blocks and send another request
+    # Should see blocks_to_cache number of removed block events and a new
+    # block stored event
+    manager.free(req0)
+    req1 = make_request("1", list(range(num_tokens)))
+    _ = manager.allocate_slots(req1, num_tokens)
+    events = manager.take_events()
+
+    for blocks in events[:-1]:
+        assert blocks.block_hashes[0] in stored_block_hash
+    assert len(events) == blocks_to_cache + 1
+    assert (isinstance(events[-2], BlockRemoved))
+    assert (len(events[-1].block_hashes) == blocks_to_cache == len(
+        manager.block_pool.cached_block_hash_to_block))
+
+    # All Blocks Cleared
+    # Should see a single all blocks cleared event
+    manager.free(req1)
+    manager.reset_prefix_cache()
+    events = manager.take_events()
+
+    assert isinstance(events[-1], AllBlocksCleared)
+    assert len(manager.block_pool.cached_block_hash_to_block) == 0
+
+
 def test_eagle_enabled_removes_last_block():
     """Verify Eagle does NOT remove blocks when request
     length is divisible by block size."""
@@ -13,6 +13,8 @@ from tests.v1.engine.utils import (NUM_PROMPT_LOGPROBS_UNDER_TEST,
 from vllm.engine.arg_utils import EngineArgs
 from vllm.transformers_utils.tokenizer_group import init_tokenizer_from_configs
 
+from ...distributed.conftest import publisher_config, random_port  # noqa: F401
+
 from tests.v1.engine.utils import FULL_STRINGS  # isort: skip
 
 EngineCoreSampleLogprobsType = list[tuple[torch.Tensor, torch.Tensor]]
@@ -11,6 +11,7 @@ import pytest
 from transformers import AutoTokenizer
 
 from vllm import SamplingParams
+from vllm.distributed.kv_events import BlockStored, KVEventBatch
 from vllm.engine.arg_utils import EngineArgs
 from vllm.platforms import current_platform
 from vllm.usage.usage_lib import UsageContext
@@ -20,6 +21,7 @@ from vllm.v1.engine.core_client import (AsyncMPClient, EngineCoreClient,
                                         SyncMPClient)
 from vllm.v1.executor.abstract import Executor
 
+from ...distributed.conftest import MockSubscriber
 from ...utils import create_new_process_for_each_test
 
 if not current_platform.is_cuda():
@@ -199,54 +201,142 @@ async def test_engine_core_client_asyncio(monkeypatch: pytest.MonkeyPatch):
         log_stats=True,
     )
 
-    MAX_TOKENS = 20
-    params = SamplingParams(max_tokens=MAX_TOKENS)
-    """Normal Request Cycle."""
+    try:
+        MAX_TOKENS = 20
+        params = SamplingParams(max_tokens=MAX_TOKENS)
+        """Normal Request Cycle."""
 
-    requests = [make_request(params) for _ in range(10)]
-    request_ids = [req.request_id for req in requests]
+        requests = [make_request(params) for _ in range(10)]
+        request_ids = [req.request_id for req in requests]
 
-    # Add requests to the engine.
-    for request in requests:
-        await client.add_request_async(request)
-        await asyncio.sleep(0.01)
+        # Add requests to the engine.
+        for request in requests:
+            await client.add_request_async(request)
+            await asyncio.sleep(0.01)
 
-    outputs: dict[str, list] = {req_id: [] for req_id in request_ids}
-    await loop_until_done_async(client, outputs)
+        outputs: dict[str, list] = {req_id: [] for req_id in request_ids}
+        await loop_until_done_async(client, outputs)
 
-    for req_id in request_ids:
-        assert len(outputs[req_id]) == MAX_TOKENS, (
-            f"{outputs[req_id]=}, {MAX_TOKENS=}")
-    """Abort Request Cycle."""
+        for req_id in request_ids:
+            assert len(outputs[req_id]) == MAX_TOKENS, (
+                f"{outputs[req_id]=}, {MAX_TOKENS=}")
+        """Abort Request Cycle."""
 
-    # Add requests to the engine.
-    for idx, request in enumerate(requests):
-        await client.add_request_async(request)
-        await asyncio.sleep(0.01)
-        if idx % 2 == 0:
-            await client.abort_requests_async([request.request_id])
+        # Add requests to the engine.
+        for idx, request in enumerate(requests):
+            await client.add_request_async(request)
+            await asyncio.sleep(0.01)
+            if idx % 2 == 0:
+                await client.abort_requests_async([request.request_id])
 
-    outputs = {req_id: [] for req_id in request_ids}
-    await loop_until_done_async(client, outputs)
+        outputs = {req_id: [] for req_id in request_ids}
+        await loop_until_done_async(client, outputs)
 
-    for idx, req_id in enumerate(request_ids):
-        if idx % 2 == 0:
-            assert len(outputs[req_id]) < MAX_TOKENS, (
-                f"{len(outputs[req_id])=}, {MAX_TOKENS=}")
-        else:
-            assert len(outputs[req_id]) == MAX_TOKENS, (
-                f"{len(outputs[req_id])=}, {MAX_TOKENS=}")
-    """Utility method invocation"""
+        for idx, req_id in enumerate(request_ids):
+            if idx % 2 == 0:
+                assert len(outputs[req_id]) < MAX_TOKENS, (
+                    f"{len(outputs[req_id])=}, {MAX_TOKENS=}")
+            else:
+                assert len(outputs[req_id]) == MAX_TOKENS, (
+                    f"{len(outputs[req_id])=}, {MAX_TOKENS=}")
+        """Utility method invocation"""
 
-    core_client: AsyncMPClient = client
+        core_client: AsyncMPClient = client
 
-    result = await core_client.call_utility_async("echo", "testarg")
-    assert result == "testarg"
+        result = await core_client.call_utility_async("echo", "testarg")
+        assert result == "testarg"
 
-    with pytest.raises(Exception) as e_info:
-        await core_client.call_utility_async("echo", None, "help!")
+        with pytest.raises(Exception) as e_info:
+            await core_client.call_utility_async("echo", None, "help!")
 
-    assert str(e_info.value) == "Call to echo method failed: help!"
+        assert str(e_info.value) == "Call to echo method failed: help!"
+    finally:
+        client.shutdown()
+
+
+@pytest.mark.parametrize(
+    "multiprocessing_mode,publisher_config",
+    [(True, "tcp"), (False, "inproc")],
+    indirect=["publisher_config"],
+)
+def test_kv_cache_events(
+    monkeypatch: pytest.MonkeyPatch,
+    multiprocessing_mode: bool,
+    publisher_config,
+):
+    with monkeypatch.context() as m:
+        m.setenv("VLLM_USE_V1", "1")
+        block_size = 16
+        num_blocks = 2
+
+        engine_args = EngineArgs(model=MODEL_NAME,
+                                 enforce_eager=True,
+                                 enable_prefix_caching=True,
+                                 block_size=block_size)
+        engine_args.kv_events_config = publisher_config
+
+        vllm_config = engine_args.create_engine_config(
+            UsageContext.UNKNOWN_CONTEXT)
+
+        executor_class = Executor.get_class(vllm_config)
+        client = EngineCoreClient.make_client(
+            multiprocess_mode=multiprocessing_mode,
+            asyncio_mode=False,
+            vllm_config=vllm_config,
+            executor_class=executor_class,
+            log_stats=False,
+        )
+        endpoint = publisher_config.endpoint.replace("*", "127.0.0.1")
+        time.sleep(0.1)
+        subscriber = MockSubscriber(endpoint,
+                                    topic=publisher_config.topic,
+                                    decode_type=KVEventBatch)
+
+        try:
+            custom_tokens = list(range(num_blocks * block_size))
+            request = EngineCoreRequest(
+                request_id=str(uuid.uuid4()),
+                prompt_token_ids=custom_tokens,
+                mm_inputs=None,
+                mm_hashes=None,
+                mm_placeholders=None,
+                sampling_params=SamplingParams(
+                    max_tokens=1),  # Short completion for speed
+                eos_token_id=None,
+                arrival_time=time.time(),
+                lora_request=None,
+            )
+            client.add_request(request)
+
+            outputs: dict[str, list] = {request.request_id: []}
+            loop_until_done(client, outputs)
+
+            result = subscriber.receive_one(timeout=1000)
+            assert result is not None, "No message received"
+
+            seq, received = result
+
+            assert seq == 0, "Sequence number mismatch"
+            assert len(received.events) == 1, (
+                "We should have exactly one BlockStored event")
+            event = received.events[0]
+            assert isinstance(
+                event, BlockStored), ("We should have a BlockStored event")
+            assert len(event.block_hashes) == num_blocks, (
+                "We should have a BlockStored event with 2 block_hashes")
+            assert event.block_size == block_size, (
+                "Block size should be the same as the block size")
+            assert event.parent_block_hash is None, (
+                "Parent block hash should be None")
+            assert event.lora_id is None, "Lora id should be None"
+            assert len(event.token_ids) == num_blocks * block_size, (
+                "Token ids should be the same as the custom tokens")
+            assert event.token_ids == custom_tokens, (
+                "Token ids should be the same as the custom tokens")
+        finally:
+            client.shutdown()
+        return
+
+
 @pytest.mark.timeout(10)
@@ -1958,6 +1958,8 @@ class SchedulerConfig:
     some image tokens can be scheduled (like TTTTIIIII, leaving IIIII),
     it will be scheduled as TTTT in one step and IIIIIIIIII in the next."""
 
+    # scheduler class or path. "vllm.core.scheduler.Scheduler" (default)
+    # or "mod.custom_class".
     scheduler_cls: Union[str, type[object]] = "vllm.core.scheduler.Scheduler"
     """The scheduler class to use. "vllm.core.scheduler.Scheduler" is the
     default scheduler. Can be a class directly or the path to a class of form
@@ -3417,6 +3419,51 @@ class KVTransferConfig(BaseModel):
         return self.kv_connector_extra_config.get(key, default)
 
 
+class KVEventsConfig(BaseModel):
+    """Configuration for KV event publishing."""
+
+    enable_kv_cache_events: bool = False
+    """If True, enable KV cache events for tracking block storage and removal.
+    Events can be published externally by zmq using the event publisher config.
+    """
+
+    publisher: str = "null"
+    """The publisher to use for publishing kv events. Can be "null", "zmq".
+    """
+
+    endpoint: str = "tcp://*:5557"
+    """The zmq endpoint to use for publishing kv events.
+    """
+
+    replay_endpoint: Optional[str] = None
+    """The zmq endpoint to use for replaying kv events.
+    """
+
+    buffer_steps: int = 10_000
+    """The number of steps to cache for replay endpoint. Will only save
+    events from the last N steps for the replay endpoint.
+    """
+
+    hwm: int = 100_000
+    """The zmq high water mark for the event publisher. After queueing N
+    events, events will start dropping if the consumer is not keeping up.
+    """
+
+    max_queue_size: int = 100_000
+    """The maximum number of events to queue while waiting for publishing.
+    """
+
+    topic: str = ""
+    """The topic to use for the event publisher. Consumers can subscribe to
+    this topic to receive events.
+    """
+
+    @classmethod
+    def from_cli(cls, cli_value: str) -> "KVEventsConfig":
+        """Parse the CLI value for the event publisher config."""
+        return KVEventsConfig.model_validate_json(cli_value)
+
+
 class CompilationLevel:
     # constants for the levels of the compilation process
     NO_COMPILATION = 0
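Since from_cli is just pydantic's model_validate_json, the CLI flag accepts any subset of the fields above and defaults fill in the rest; a sketch (values are illustrative):

    cfg = KVEventsConfig.from_cli(
        '{"enable_kv_cache_events": true, "publisher": "zmq", "topic": "kv"}')
    assert cfg.endpoint == "tcp://*:5557"  # unset fields keep their defaults
    assert cfg.replay_endpoint is None     # replay stays disabled unless set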
@@ -3779,6 +3826,7 @@ class VllmConfig:
                                                  init=True)  # type: ignore
     kv_transfer_config: KVTransferConfig = field(default=None,
                                                  init=True)  # type: ignore
+    kv_events_config: Optional[KVEventsConfig] = None
     # some opaque config, only used to provide additional information
     # for the hash computation, mainly used for testing, debugging or out of
     # tree config registration.
@@ -4038,6 +4086,18 @@ class VllmConfig:
             if self.cache_config is not None:
                 self.cache_config.enable_prefix_caching = False
 
+        if (self.kv_events_config
+                and self.kv_events_config.enable_kv_cache_events
+                and not self.cache_config.enable_prefix_caching):
+            logger.warning(
+                "KV cache events are on, but prefix caching is not enabled. "
+                "Use --enable-prefix-caching to enable.")
+        if (self.kv_events_config and self.kv_events_config.publisher != "null"
+                and not self.kv_events_config.enable_kv_cache_events):
+            logger.warning("KV cache events are disabled, "
+                           "but the scheduler is configured to publish them. "
+                           "Modify KVEventsConfig.enable_kv_cache_events "
+                           "to True to enable.")
         current_platform.check_and_update_config(self)
 
         if not self.instance_id:
vllm/distributed/kv_events.py (new file, 295 lines)
@@ -0,0 +1,295 @@
# SPDX-License-Identifier: Apache-2.0

import queue
import threading
import time
from abc import ABC, abstractmethod
from collections import deque
from itertools import count
from queue import Queue
from typing import Any, Callable, Optional, Union

import msgspec
import zmq

from vllm.config import KVEventsConfig
from vllm.logger import init_logger

logger = init_logger(__name__)


class EventBatch(
        msgspec.Struct,
        array_like=True,  # type: ignore[call-arg]
        omit_defaults=True,  # type: ignore[call-arg]
        gc=False,  # type: ignore[call-arg]
):
    ts: float
    events: list[Any]


class KVCacheEvent(
        msgspec.Struct,
        array_like=True,  # type: ignore[call-arg]
        omit_defaults=True,  # type: ignore[call-arg]
        gc=False,  # type: ignore[call-arg]
        tag=True):
    """Base class for all KV cache-related events"""


class BlockStored(KVCacheEvent):
    block_hashes: list[int]
    parent_block_hash: Optional[int]
    token_ids: list[int]
    block_size: int
    lora_id: Optional[int]


class BlockRemoved(KVCacheEvent):
    block_hashes: list[int]


class AllBlocksCleared(KVCacheEvent):
    pass


class KVEventBatch(EventBatch):
    events: list[Union[BlockStored, BlockRemoved, AllBlocksCleared]]


class EventPublisher(ABC):
    """Lightweight publisher for EventBatch batches."""

    @abstractmethod
    def publish(self, events: EventBatch) -> None:
        """Emit events in order.

        Implementations should guarantee at-least-once delivery and
        monotonic ordering (e.g., via sequence numbers).
        """

    @abstractmethod
    def shutdown(self) -> None:
        """Shutdown the publisher."""


class NullEventPublisher(EventPublisher):
    """No-op implementation (default when disabled)."""

    def publish(self, events) -> None:
        return

    def shutdown(self) -> None:
        return


class ZmqEventPublisher(EventPublisher):
    """Reliable PUB/ROUTER publisher with an in-memory replay buffer.

    Spawns a separate thread to handle publishing from a queue.

    Parameters
    ----------
    endpoint:
        PUB address. Use ``tcp://*:5557`` to bind or ``tcp://host:5557`` to
        connect.
    replay_endpoint:
        Optional ROUTER address for replay requests. When given, subscribers
        can request missed batches by sending the starting sequence number as
        an 8-byte big-endian integer.
    buffer_steps:
        Number of past batches to keep for replay.
    hwm:
        ZeroMQ high-water-mark for PUB socket.
    max_queue_size:
        Maximum number of events to buffer in memory.
    topic:
        Topic to publish events to.
    """
    SHUTDOWN_TIMEOUT: float = 1.0
    END_SEQ = (-1).to_bytes(8, "big", signed=True)

    def __init__(
        self,
        endpoint: str = "tcp://*:5557",
        replay_endpoint: Optional[str] = None,
        buffer_steps: int = 10_000,
        hwm: int = 100_000,
        max_queue_size: int = 100_000,
        topic: str = "",
    ) -> None:
        # Storage
        self._event_queue = Queue[Optional[EventBatch]](maxsize=max_queue_size)
        self._buffer = deque[tuple[int, bytes]](maxlen=buffer_steps)

        # ZMQ sockets
        self._ctx = zmq.Context.instance()
        self._pub: Optional[zmq.Socket] = None
        self._replay: Optional[zmq.Socket] = None
        self._endpoint = endpoint
        self._replay_endpoint = replay_endpoint
        self._hwm = hwm

        # Payload
        self._seq_gen = count()
        self._topic_bytes = topic.encode('utf-8')

        # Thread
        self._running = True
        logger.info("Starting ZMQ publisher thread")

        self._thread = threading.Thread(target=self._publisher_thread,
                                        daemon=True,
                                        name="zmq-publisher")
        self._thread.start()

    def publish(self, events: EventBatch) -> None:
        if not self._running:
            raise RuntimeError("Publisher is closed")
        self._event_queue.put(events)

    def shutdown(self) -> None:
        """Stop the publisher thread and clean up resources."""
        self._running = False
        self._event_queue.put_nowait(None)

        start = time.time()
        pending_items = True
        while pending_items and (time.time() - start < self.SHUTDOWN_TIMEOUT):
            pending_items = not self._event_queue.empty()
            if pending_items:
                time.sleep(0.1)

        if pending_items:
            logger.warning(
                "Warning: Queue still has %s items after %s seconds timeout",
                self._event_queue.qsize(),
                self.SHUTDOWN_TIMEOUT,
            )

        if self._thread.is_alive():
            self._thread.join(timeout=self.SHUTDOWN_TIMEOUT)

        # Clean up ZMQ resources
        try:
            if self._pub is not None:
                self._pub.close(linger=0)
            if self._replay is not None:
                self._replay.close(linger=0)
        finally:
            pass  # Do not terminate context; other sockets may use it

    def _socket_setup(self) -> None:
        """Initialize sockets
        https://pyzmq.readthedocs.io/en/v19.0.0/morethanbindings.html#thread-safety
        """
        if self._pub is None:
            self._pub = self._ctx.socket(zmq.PUB)
            self._pub.set_hwm(self._hwm)
            # Heuristic: bind if wildcard / * present, else connect.
            # bind stable, connect volatile convention
            if ("*" in self._endpoint or "::" in self._endpoint
                    or self._endpoint.startswith("ipc://")
                    or self._endpoint.startswith("inproc://")):
                self._pub.bind(self._endpoint)
            else:
                self._pub.connect(self._endpoint)

        # Set up replay socket: use ROUTER
        # 1) handles multiple REQ clients (identities)
        # 2) lets us send back one request -> many replies (streamed events)
        # 3) works in our non-blocking poll loop alongside PUB
        if self._replay_endpoint is not None:
            self._replay = self._ctx.socket(zmq.ROUTER)
            self._replay.bind(self._replay_endpoint)

    def _publisher_thread(self) -> None:
        """Background thread that processes the event queue."""
        self._pack = msgspec.msgpack.Encoder()
        self._socket_setup()

        assert self._pub is not None  # narrows type for mypy

        while self._running or self._event_queue.qsize() > 0:
            # --- replay (non-critical) ---------------------------------
            if self._replay is not None and self._replay.poll(0):
                try:
                    self._service_replay()
                except Exception as e:
                    logger.exception("Error in replay: %s", e)

            # --- main queue (critical) ---------------------------------
            try:
                event = self._event_queue.get(timeout=0.1)
                if event is None:
                    break  # Sentinel received, exit thread
            except queue.Empty:
                continue

            try:
                seq = next(self._seq_gen)

                payload = self._pack.encode(event)
                seq_bytes = seq.to_bytes(8, "big")
                self._pub.send_multipart(
                    (self._topic_bytes, seq_bytes, payload))

                self._buffer.append((seq, payload))
                self._event_queue.task_done()

            except Exception as e:
                # Publishing failed; back-off a bit to avoid a tight error loop
                logger.exception("Error in publisher thread: %s", e)
                time.sleep(0.1)

    def _service_replay(self) -> None:
        """If a replay request is waiting, send buffered batches."""
        assert self._replay is not None  # narrows type for mypy

        frame = self._replay.recv_multipart()
        if len(frame) != 3:
            logger.warning("Invalid replay request: %s", frame)
            return
        client_id, _, start_seq_bytes = frame
        start_seq = int.from_bytes(start_seq_bytes, "big")

        for seq, buf in self._buffer:
            if seq >= start_seq:
                # [identity, empty_delim, seq_bytes, payload]
                # (identity, empty_delim) are stripped off by the router
                # receiving payload is (seq_bytes, payload)
                self._replay.send_multipart(
                    (client_id, b"", seq.to_bytes(8, "big"), buf))
        # Send end of sequence marker
        # receiving payload is (-1, b"")
        self._replay.send_multipart((client_id, b"", self.END_SEQ, b""))


class EventPublisherFactory:
    _registry: dict[str, Callable[..., EventPublisher]] = {
        "null": NullEventPublisher,
        "zmq": ZmqEventPublisher,
    }

    @classmethod
    def register_publisher(cls, name: str,
                           ctor: Callable[..., EventPublisher]) -> None:
        if name in cls._registry:
            raise KeyError(f"publisher '{name}' already registered")
        cls._registry[name] = ctor

    @classmethod
    def create(cls, config: Optional[KVEventsConfig]) -> EventPublisher:
        """Create publisher from a config mapping."""
        if not config:
            return NullEventPublisher()

        config_dict = config.model_dump()

        kind = config_dict.pop("publisher", "null")
        config_dict.pop("enable_kv_cache_events")
        try:
            constructor = cls._registry[kind]
        except KeyError as exc:
            raise ValueError(f"Unknown event publisher '{kind}'") from exc
        return constructor(**config_dict)
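A minimal sketch of driving the publisher directly through the factory, mirroring what the engine does when kv_events_config is set (the config values are illustrative):

    import time

    config = KVEventsConfig(enable_kv_cache_events=True,
                            publisher="zmq",
                            endpoint="tcp://*:5557",
                            topic="kv-events")
    pub = EventPublisherFactory.create(config)
    try:
        batch = KVEventBatch(ts=time.time(),
                             events=[BlockRemoved(block_hashes=[1234])])
        pub.publish(batch)  # enqueued; the background thread sends it
    finally:
        pub.shutdown()      # drains the queue, then closes the sockets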
@@ -19,14 +19,14 @@ from vllm.config import (BlockSize, CacheConfig, CacheDType, CompilationConfig,
                          ConfigFormat, ConfigType, DecodingConfig, Device,
                          DeviceConfig, DistributedExecutorBackend,
                          GuidedDecodingBackend, GuidedDecodingBackendV1,
-                         HfOverrides, KVTransferConfig, LoadConfig, LoadFormat,
-                         LoRAConfig, ModelConfig, ModelDType, ModelImpl,
-                         MultiModalConfig, ObservabilityConfig, ParallelConfig,
-                         PoolerConfig, PrefixCachingHashAlgo,
-                         PromptAdapterConfig, SchedulerConfig, SchedulerPolicy,
-                         SpeculativeConfig, TaskOption, TokenizerMode,
-                         TokenizerPoolConfig, VllmConfig, get_attr_docs,
-                         get_field)
+                         HfOverrides, KVEventsConfig, KVTransferConfig,
+                         LoadConfig, LoadFormat, LoRAConfig, ModelConfig,
+                         ModelDType, ModelImpl, MultiModalConfig,
+                         ObservabilityConfig, ParallelConfig, PoolerConfig,
+                         PrefixCachingHashAlgo, PromptAdapterConfig,
+                         SchedulerConfig, SchedulerPolicy, SpeculativeConfig,
+                         TaskOption, TokenizerMode, TokenizerPoolConfig,
+                         VllmConfig, get_attr_docs, get_field)
 from vllm.executor.executor_base import ExecutorBase
 from vllm.logger import init_logger
 from vllm.model_executor.layers.quantization import QuantizationMethods
@@ -353,6 +353,7 @@ class EngineArgs:
     worker_extension_cls: str = ParallelConfig.worker_extension_cls
 
     kv_transfer_config: Optional[KVTransferConfig] = None
+    kv_events_config: Optional[KVEventsConfig] = None
 
     generation_config: str = ModelConfig.generation_config
     enable_sleep_mode: bool = ModelConfig.enable_sleep_mode
@@ -769,6 +770,10 @@ class EngineArgs:
                             default=None,
                             help='The configurations for distributed KV cache '
                             'transfer. Should be a JSON string.')
+        parser.add_argument('--kv-events-config',
+                            type=KVEventsConfig.from_cli,
+                            default=None,
+                            help='The configurations for event publishing.')
 
         parser.add_argument(
             '--worker-cls',
@@ -1125,6 +1130,7 @@ class EngineArgs:
             prompt_adapter_config=prompt_adapter_config,
             compilation_config=self.compilation_config,
             kv_transfer_config=self.kv_transfer_config,
+            kv_events_config=self.kv_events_config,
             additional_config=self.additional_config,
         )
 
@@ -3,6 +3,8 @@ from collections import defaultdict
 from collections.abc import Iterable
 from typing import Callable, Optional
 
+from vllm.distributed.kv_events import (AllBlocksCleared, BlockRemoved,
+                                        BlockStored, KVCacheEvent)
 from vllm.logger import init_logger
 from vllm.v1.core.kv_cache_utils import (BlockHashType, FreeKVCacheBlockQueue,
                                          KVCacheBlock,
@@ -26,7 +28,12 @@ class BlockPool:
         enable_caching: Whether to enable prefix caching.
     """
 
-    def __init__(self, num_gpu_blocks: int, enable_caching: bool):
+    def __init__(
+        self,
+        num_gpu_blocks: int,
+        enable_caching: bool,
+        enable_kv_cache_events: bool = False,
+    ):
         assert isinstance(num_gpu_blocks, int) and num_gpu_blocks > 0
         self.num_gpu_blocks = num_gpu_blocks
         self.enable_caching = enable_caching
@@ -56,6 +63,9 @@ class BlockPool:
         # avoid freeing it.
         self.null_block = self.free_block_queue.popleft()
 
+        self.enable_kv_cache_events = enable_kv_cache_events
+        self.kv_event_queue: list[KVCacheEvent] = []
+
     def get_cached_block(self,
                          block_hash: BlockHashType) -> Optional[KVCacheBlock]:
         """Get a cached block by the block hash, or None if cache miss.
@@ -116,6 +126,9 @@ class BlockPool:
             assert prev_block.block_hash is not None
             prev_block_hash_value = prev_block.block_hash.hash_value
 
+        parent_block_hash = prev_block_hash_value
+        new_hashes: Optional[list[int]] = ([] if self.enable_kv_cache_events
+                                           else None)
         for i, blk in enumerate(new_full_blocks):
             assert blk.block_hash is None
 
@@ -153,8 +166,23 @@ class BlockPool:
                 # Update and added the full block to the cache.
                 blk.block_hash = block_hash
                 self.cached_block_hash_to_block[block_hash][blk.block_id] = blk
+                if new_hashes is not None:
+                    new_hashes.append(block_hash.hash_value)
             prev_block_hash_value = block_hash.hash_value
 
+        if self.enable_kv_cache_events:
+            self.kv_event_queue.append(
+                BlockStored(
+                    block_hashes=new_hashes,
+                    parent_block_hash=parent_block_hash,
+                    token_ids=request.
+                    all_token_ids[num_cached_blocks *
+                                  block_size:num_full_blocks * block_size],
+                    block_size=block_size,
+                    lora_id=request.lora_request.id
+                    if request.lora_request else None,
+                ))
+
     def get_new_blocks(self, num_blocks: int) -> list[KVCacheBlock]:
         """Get new blocks from the free block pool.
 
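To make the token_ids slicing above concrete, a worked instance with hypothetical numbers (0 prefix-cache hits, 2 newly full blocks of size 16):

    block_size = 16
    num_cached_blocks, num_full_blocks = 0, 2
    all_token_ids = list(range(40))  # request tokens; last block not yet full
    stored = all_token_ids[num_cached_blocks * block_size:
                           num_full_blocks * block_size]
    assert stored == list(range(32))  # exactly the tokens the new blocks cover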
@@ -206,6 +234,9 @@ class BlockPool:
             if len(self.cached_block_hash_to_block[block_hash]) == 0:
                 del self.cached_block_hash_to_block[block_hash]
 
+            if self.enable_kv_cache_events:
+                self.kv_event_queue.append(
+                    BlockRemoved(block_hashes=[block_hash.hash_value]))
             return True
         return False
 
@ -262,6 +293,10 @@ class BlockPool:
|
|||||||
block.reset_hash()
|
block.reset_hash()
|
||||||
|
|
||||||
logger.info("Successfully reset prefix cache")
|
logger.info("Successfully reset prefix cache")
|
||||||
|
|
||||||
|
if self.enable_kv_cache_events:
|
||||||
|
self.kv_event_queue.append(AllBlocksCleared())
|
||||||
|
|
||||||
return True
|
return True
|
||||||
|
|
||||||
def get_num_free_blocks(self) -> int:
|
def get_num_free_blocks(self) -> int:
|
||||||
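With BlockRemoved and AllBlocksCleared added here, the three event types queued by this file are sufficient for a subscriber to mirror the engine's set of cached block hashes. A sketch of that fold, using plain dicts as stand-ins for the typed event classes:

# Sketch only: real events are classes from vllm.distributed.kv_events;
# dicts keep the example self-contained.
cached: set[int] = set()

def apply(event: dict) -> None:
    kind = event["kind"]
    if kind == "BlockStored":
        cached.update(event["block_hashes"])
    elif kind == "BlockRemoved":
        cached.difference_update(event["block_hashes"])
    elif kind == "AllBlocksCleared":
        cached.clear()

apply({"kind": "BlockStored", "block_hashes": [101, 102]})
apply({"kind": "BlockRemoved", "block_hashes": [101]})
assert cached == {102}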
@@ -279,3 +314,15 @@ class BlockPool:
         The KV cache usage (between 0.0 and 1.0).
         """
         return 1.0 - (self.get_num_free_blocks() / self.num_gpu_blocks)
+
+    def take_events(self) -> list[KVCacheEvent]:
+        """Atomically takes all events and clears the queue.
+
+        Returns:
+            A list of KV cache events.
+        """
+        if not self.enable_kv_cache_events:
+            return []
+        events = self.kv_event_queue
+        self.kv_event_queue = []
+        return events
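take_events() hands ownership of the queue to the caller by rebinding self.kv_event_queue to a fresh list instead of clearing it in place; since the scheduler drives the block pool from a single thread, that rebind is all the atomicity needed. The idiom in isolation:

class EventBuffer:
    """Minimal sketch of the take-and-clear idiom used by take_events()."""

    def __init__(self) -> None:
        self._queue: list[str] = []

    def append(self, event: str) -> None:
        self._queue.append(event)

    def take(self) -> list[str]:
        # Rebind rather than clear in place: the producer can never append
        # to the list the caller now holds.
        events, self._queue = self._queue, []
        return events

buf = EventBuffer()
buf.append("BlockStored")
assert buf.take() == ["BlockStored"]
assert buf.take() == []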
vllm/v1/core/kv_cache_manager.py
@@ -4,6 +4,7 @@ from collections import defaultdict
 from collections.abc import Iterable
 from typing import Optional
 
+from vllm.distributed.kv_events import KVCacheEvent
 from vllm.logger import init_logger
 from vllm.utils import cdiv, sha256
 from vllm.v1.core.block_pool import BlockPool
@@ -27,6 +28,7 @@ class KVCacheManager:
         caching_hash_algo: str = "builtin",
         use_eagle: bool = False,
         log_stats: bool = False,
+        enable_kv_cache_events: bool = False,
     ) -> None:
         assert len(kv_cache_config.kv_cache_groups) == 1, (
             "KVCacheManager does not support hybrid models with more than 1 "
@@ -44,7 +46,9 @@ class KVCacheManager:
         # FIXME: make prefix cache stats conditional on log_stats
         self.prefix_cache_stats = PrefixCacheStats() if log_stats else None
 
-        self.block_pool = BlockPool(self.num_gpu_blocks, enable_caching)
+        self.block_pool = BlockPool(self.num_gpu_blocks, enable_caching,
+                                    enable_kv_cache_events)
 
         self.specialized_manager = get_specialized_manager(
             kv_cache_spec=kv_cache_spec,
             block_pool=self.block_pool,
@@ -383,3 +387,11 @@ class KVCacheManager:
         is finished, not when it is preempted.
         """
         self.req_to_block_hashes.pop(request.request_id, None)
+
+    def take_events(self) -> list[KVCacheEvent]:
+        """Take the KV cache events from the block pool.
+
+        Returns:
+            A list of KV cache events.
+        """
+        return self.block_pool.take_events()

vllm/v1/core/sched/interface.py
@@ -132,3 +132,8 @@ class SchedulerInterface(ABC):
         The SchedulerStats object is created for every scheduling step.
         """
         raise NotImplementedError
+
+    @abstractmethod
+    def shutdown(self) -> None:
+        """Shutdown the scheduler."""
+        raise NotImplementedError

vllm/v1/core/sched/scheduler.py
@@ -8,6 +8,7 @@ from collections.abc import Iterable
 from typing import Optional, Union
 
 from vllm.config import VllmConfig
+from vllm.distributed.kv_events import EventPublisherFactory, KVEventBatch
 from vllm.distributed.kv_transfer.kv_connector.factory import (
     KVConnectorFactory)
 from vllm.distributed.kv_transfer.kv_connector.v1 import KVConnectorRole
@@ -48,6 +49,7 @@ class Scheduler(SchedulerInterface):
         self.cache_config = vllm_config.cache_config
         self.lora_config = vllm_config.lora_config
         self.kv_cache_config = kv_cache_config
+        self.kv_events_config = vllm_config.kv_events_config
         self.log_stats = log_stats
         self.structured_output_manager = structured_output_manager
 
@@ -62,6 +64,9 @@ class Scheduler(SchedulerInterface):
         self.max_num_scheduled_tokens = \
             self.scheduler_config.max_num_batched_tokens
         self.max_model_len = self.scheduler_config.max_model_len
+        self.enable_kv_cache_events = (
+            self.kv_events_config is not None
+            and self.kv_events_config.enable_kv_cache_events)
 
         # Create KVConnector for the Scheduler. Note that each Worker
         # will have a corresponding KVConnector with Role=WORKER.
@@ -71,6 +76,9 @@ class Scheduler(SchedulerInterface):
             self.connector = KVConnectorFactory.create_connector_v1(
                 config=self.vllm_config, role=KVConnectorRole.SCHEDULER)
 
+        self.kv_event_publisher = EventPublisherFactory.create(
+            self.kv_events_config)
+
         num_gpu_blocks = self.cache_config.num_gpu_blocks
         assert num_gpu_blocks is not None and num_gpu_blocks > 0
 
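EventPublisherFactory.create is called unconditionally here, and the later shutdown() merely truth-tests the publisher, which suggests the factory falls back to a do-nothing publisher when events are disabled. A hypothetical sketch of that shape; the class names, config keys, and create_publisher helper below are illustrative, not vLLM's actual API:

from typing import Optional

class NullEventPublisher:
    """Drops every batch; lets call sites skip None checks."""

    def publish(self, batch: object) -> None:
        pass

    def shutdown(self) -> None:
        pass

class LoggingEventPublisher(NullEventPublisher):
    def __init__(self, topic: str) -> None:
        self.topic = topic

    def publish(self, batch: object) -> None:
        print(f"[{self.topic}] {batch}")  # stand-in for a real transport send

def create_publisher(config: Optional[dict]) -> NullEventPublisher:
    # Disabled or missing config yields the null publisher, so callers can
    # publish and shut down without checking for None.
    if not config or not config.get("enable_kv_cache_events"):
        return NullEventPublisher()
    return LoggingEventPublisher(config.get("topic", ""))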
@@ -132,7 +140,9 @@ class Scheduler(SchedulerInterface):
             enable_caching=self.cache_config.enable_prefix_caching,
             caching_hash_algo=self.cache_config.prefix_caching_hash_algo,
             use_eagle=self.use_eagle,
-            log_stats=self.log_stats)
+            log_stats=self.log_stats,
+            enable_kv_cache_events=self.enable_kv_cache_events,
+        )
 
     def schedule(self) -> SchedulerOutput:
         # NOTE(woosuk) on the scheduling algorithm:
@@ -493,6 +503,11 @@ class Scheduler(SchedulerInterface):
             meta = self.connector.build_connector_meta(scheduler_output)
             scheduler_output.kv_connector_metadata = meta
 
+        events = self.kv_cache_manager.take_events()
+        if events:
+            batch = KVEventBatch(ts=time.time(), events=events)
+            self.kv_event_publisher.publish(batch)
+
         # Advance the number of computed tokens for the request AFTER
         # the request is scheduled.
         # 1. The scheduler_output of the current step has to include the
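Events are drained and published once per schedule() call, so each KVEventBatch corresponds to one scheduling step. On the consuming side, a plain ZeroMQ SUB socket is enough to receive batches; the endpoint and the multipart framing below are assumptions for illustration, since the actual wire format is defined by the publisher, not by this hunk:

import zmq  # pyzmq

ctx = zmq.Context()
sock = ctx.socket(zmq.SUB)
sock.connect("tcp://localhost:5557")       # assumed publisher endpoint
sock.setsockopt_string(zmq.SUBSCRIBE, "")  # subscribe to every topic

while True:
    # Framing (topic frame + payload frame) is an assumption; decode the
    # payload with whatever serializer the publisher was configured to use.
    frames = sock.recv_multipart()
    print(f"received {len(frames)} frame(s)")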
@@ -843,3 +858,7 @@ class Scheduler(SchedulerInterface):
                 num_draft_tokens=num_draft_tokens,
                 num_accepted_tokens=num_accepted_tokens)
         return spec_decoding_stats
+
+    def shutdown(self) -> None:
+        if self.kv_event_publisher:
+            self.kv_event_publisher.shutdown()

vllm/v1/engine/core.py
@@ -259,6 +259,8 @@ class EngineCore:
         self.structured_output_manager.clear_backend()
         if self.model_executor:
             self.model_executor.shutdown()
+        if self.scheduler:
+            self.scheduler.shutdown()
 
     def profile(self, is_start: bool = True):
         self.model_executor.profile(is_start)