diff --git a/examples/online_serving/kv_events.sh b/examples/online_serving/kv_events.sh
new file mode 100644
index 000000000000..a111db2179fc
--- /dev/null
+++ b/examples/online_serving/kv_events.sh
@@ -0,0 +1,86 @@
+#!/bin/bash
+# This script demonstrates KV cache event publishing.
+# We launch a vLLM instance configured to publish KV cache
+# events and a simple subscriber to log those events.
+
+set -xe
+
+echo "🚧🚧 Warning: The usage of KV cache events is experimental and subject to change 🚧🚧"
+sleep 1
+
+MODEL_NAME=${HF_MODEL_NAME:-meta-llama/Meta-Llama-3.1-8B-Instruct}
+
+# Trap the SIGINT signal (triggered by Ctrl+C)
+trap 'cleanup' INT
+
+# Cleanup function
+cleanup() {
+    echo "Caught Ctrl+C, cleaning up..."
+    # Cleanup commands
+    pgrep python | xargs kill -9
+    pkill -f python
+    echo "Cleanup complete. Exiting."
+    exit 0
+}
+
+export VLLM_HOST_IP=$(hostname -I | awk '{print $1}')
+
+# Wait for the vLLM server on the given port to start accepting requests
+wait_for_server() {
+    local port=$1
+    timeout 1200 bash -c "
+        until curl -s localhost:${port}/v1/completions > /dev/null; do
+            sleep 1
+        done" && return 0 || return 1
+}
+
+vllm serve $MODEL_NAME \
+    --port 8100 \
+    --max-model-len 100 \
+    --enforce-eager \
+    --gpu-memory-utilization 0.8 \
+    --trust-remote-code \
+    --kv-events-config \
+    '{"enable_kv_cache_events": true, "publisher": "zmq", "topic": "kv-events"}' &
+
+wait_for_server 8100
+
+SCRIPT_DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" && pwd )"
+
+python3 "$SCRIPT_DIR/kv_events_subscriber.py" &
+sleep 1
+
+# Send two example requests
+output1=$(curl -X POST -s http://localhost:8100/v1/completions \
+-H "Content-Type: application/json" \
+-d '{
+"model": "'"$MODEL_NAME"'",
+"prompt": "Explain quantum computing in simple terms a 5-year-old could understand.",
+"max_tokens": 80,
+"temperature": 0
+}')
+
+output2=$(curl -X POST -s http://localhost:8100/v1/completions \
+-H "Content-Type: application/json" \
+-d '{
+"model": "'"$MODEL_NAME"'",
+"prompt": "Explain quantum computing in simple terms a 50-year-old could understand.",
+"max_tokens": 80,
+"temperature": 0
+}')
+
+# Cleanup commands
+pkill -9 -u "$USER" -f python
+pkill -9 -u "$USER" -f vllm
+
+sleep 1
+
+echo "Cleaned up"
+
+# Print the outputs of the curl requests
+echo ""
+echo "Output of first request: $output1"
+echo "Output of second request: $output2"
+
+echo "πŸŽ‰πŸŽ‰ Successfully finished 2 test requests! 
πŸŽ‰πŸŽ‰" +echo "" diff --git a/examples/online_serving/kv_events_subscriber.py b/examples/online_serving/kv_events_subscriber.py new file mode 100644 index 000000000000..88bbbebd7478 --- /dev/null +++ b/examples/online_serving/kv_events_subscriber.py @@ -0,0 +1,114 @@ +# SPDX-License-Identifier: Apache-2.0 +from typing import Any, Optional, Union + +import msgspec +import zmq +from msgspec.msgpack import Decoder + + +# +# Types copied from vllm.distributed.kv_events +# +class EventBatch(msgspec.Struct, array_like=True, omit_defaults=True, + gc=False): + ts: float + events: list[Any] + + +class KVCacheEvent(msgspec.Struct, + array_like=True, + omit_defaults=True, + gc=False, + tag=True): + """Base class for all KV cache-related events""" + + +class BlockStored(KVCacheEvent): + block_hashes: list[int] + parent_block_hash: Optional[int] + token_ids: list[int] + block_size: int + lora_id: Optional[int] + + +class BlockRemoved(KVCacheEvent): + block_hashes: list[int] + + +class AllBlocksCleared(KVCacheEvent): + pass + + +class KVEventBatch(EventBatch): + events: list[Union[BlockStored, BlockRemoved, AllBlocksCleared]] + + +def process_event(event_batch): + print(f"Received event batch at {event_batch.ts}:") + for event in event_batch.events: + print(f" - {event}") + + +def main(): + decoder = Decoder(type=KVEventBatch) + last_seq = -1 + + context = zmq.Context() + + # Set up the main subscription socket + sub = context.socket(zmq.SUB) + sub.connect("tcp://localhost:5557") + topic = "kv-events" + sub.setsockopt_string(zmq.SUBSCRIBE, topic) + + # Initialize replay socket + replay = context.socket(zmq.REQ) + replay.connect("tcp://localhost:5558") + poller = zmq.Poller() + poller.register(replay, zmq.POLLIN) + + print("Listening for KV cache events on topic:", topic) + + while True: + try: + if sub.poll(50): + _, seq_bytes, payload = sub.recv_multipart() + seq = int.from_bytes(seq_bytes, "big") + + if last_seq >= 0 and seq > last_seq + 1: + missed = seq - last_seq - 1 + print(f"Missed {missed} messages" + f" (last: {last_seq}, current: {seq})") + + replay.send((last_seq + 1).to_bytes(8, "big")) + + while poller.poll(timeout=200): + seq_bytes, replay_payload = replay.recv_multipart() + if not replay_payload: + # End of replay marker is sent as an empty frame + # for the payload + break + + replay_seq = int.from_bytes(seq_bytes, "big") + + if replay_seq > last_seq: + event_batch = decoder.decode(replay_payload) + process_event(event_batch) + last_seq = replay_seq + if replay_seq >= seq - 1: + break + + event_batch = decoder.decode(payload) + process_event(event_batch) + + # ... do other periodic work or check for shutdown ... 
+ + except KeyboardInterrupt: + print("Interrupted") + break + except Exception as e: + print("Error decoding message:", e) + + +if __name__ == "__main__": + main() diff --git a/tests/distributed/conftest.py b/tests/distributed/conftest.py new file mode 100644 index 000000000000..ee8f2097933d --- /dev/null +++ b/tests/distributed/conftest.py @@ -0,0 +1,145 @@ +# SPDX-License-Identifier: Apache-2.0 +import random +from typing import Optional, Union + +import msgspec +import msgspec.msgpack +import pytest +import zmq + +from vllm.config import KVEventsConfig +from vllm.distributed.kv_events import EventPublisherFactory + +from .test_events import SampleBatch + + +@pytest.fixture +def random_port(): + """Generate a random port number for testing""" + return random.randint(10000, 60000) + + +@pytest.fixture +def publisher_config(random_port, request): + """Create a publisher config with inproc transport""" + how = request.param if hasattr(request, "param") else "inproc" + + if how == "inproc": + endpoint = f"inproc://test-{random_port}" + replay_endpoint = endpoint + "-replay" + else: + endpoint = f"tcp://*:{random_port}" + replay_endpoint = f"tcp://*:{random_port + 1}" + + return KVEventsConfig(enable_kv_cache_events=True, + publisher="zmq", + endpoint=endpoint, + replay_endpoint=replay_endpoint, + buffer_steps=100, + hwm=1000, + topic="test") + + +@pytest.fixture +def publisher(publisher_config): + """Create and return a publisher instance""" + pub = EventPublisherFactory.create(publisher_config) + yield pub + pub.shutdown() + + +@pytest.fixture +def subscriber(publisher_config): + """Create and return a subscriber for testing""" + endpoint = publisher_config.endpoint + replay_endpoint = publisher_config.replay_endpoint + + if endpoint.startswith("tcp://*"): + endpoint = endpoint.replace("*", "127.0.0.1") + if replay_endpoint and replay_endpoint.startswith("tcp://*"): + replay_endpoint = replay_endpoint.replace("*", "127.0.0.1") + + sub = MockSubscriber(endpoint, replay_endpoint, publisher_config.topic) + yield sub + sub.close() + + +class MockSubscriber: + """Helper class to receive and verify published events""" + + def __init__(self, + pub_endpoint: str, + replay_endpoint: Optional[str] = None, + topic: str = "", + decode_type=SampleBatch): + self.ctx = zmq.Context.instance() + + # Set up subscriber socket + self.sub = self.ctx.socket(zmq.SUB) + self.sub.setsockopt(zmq.SUBSCRIBE, topic.encode('utf-8')) + self.sub.connect(pub_endpoint) + + # Set up replay socket if provided + self.replay = None + if replay_endpoint: + self.replay = self.ctx.socket(zmq.REQ) + self.replay.connect(replay_endpoint) + + self.topic = topic + self.topic_bytes = topic.encode('utf-8') + self.received_msgs: list[tuple[int, SampleBatch]] = [] + self.last_seq = -1 + self.decoder = msgspec.msgpack.Decoder(type=decode_type) + + def receive_one(self, + timeout=1000) -> Union[tuple[int, SampleBatch], None]: + """Receive a single message with timeout""" + if not self.sub.poll(timeout): + return None + + topic_bytes, seq_bytes, payload = self.sub.recv_multipart() + assert topic_bytes == self.topic_bytes + + seq = int.from_bytes(seq_bytes, "big") + data = self.decoder.decode(payload) + self.last_seq = seq + self.received_msgs.append((seq, data)) + return seq, data + + def request_replay(self, start_seq: int) -> None: + """Request replay of messages starting from start_seq""" + if not self.replay: + raise ValueError("Replay socket not initialized") + + self.replay.send(start_seq.to_bytes(8, "big")) + + def 
receive_replay(self) -> list[tuple[int, SampleBatch]]: + """Receive replayed messages""" + if not self.replay: + raise ValueError("Replay socket not initialized") + + replayed: list[tuple[int, SampleBatch]] = [] + while True: + try: + if not self.replay.poll(1000): + break + + frames = self.replay.recv_multipart() + if not frames or not frames[-1]: + # End of replay marker + break + + seq_bytes, payload = frames + seq = int.from_bytes(seq_bytes, "big") + data = self.decoder.decode(payload) + replayed.append((seq, data)) + except zmq.ZMQError as _: + break + + return replayed + + def close(self): + """Clean up resources""" + self.sub.close() + if self.replay: + self.replay.close() diff --git a/tests/distributed/test_events.py b/tests/distributed/test_events.py new file mode 100644 index 000000000000..15bcfdb8555f --- /dev/null +++ b/tests/distributed/test_events.py @@ -0,0 +1,193 @@ +# SPDX-License-Identifier: Apache-2.0 +import threading +import time + +import msgspec +import pytest + +from vllm.distributed.kv_events import (EventBatch, EventPublisherFactory, + NullEventPublisher) + + +class EventSample( + msgspec.Struct, + tag=True, # type: ignore + array_like=True # type: ignore +): + """Test event for publisher testing""" + id: int + value: str + + +class SampleBatch(EventBatch): + """Test event batch for publisher testing""" + events: list[EventSample] + + +def create_test_events(count: int) -> SampleBatch: + """Create a batch of test events""" + events = [EventSample(id=i, value=f"test-{i}") for i in range(count)] + return SampleBatch(ts=time.time(), events=events) + + +def test_basic_publishing(publisher, subscriber): + """Test basic event publishing works""" + + test_batch = create_test_events(5) + publisher.publish(test_batch) + + result = subscriber.receive_one(timeout=1000) + assert result is not None, "No message received" + + seq, received = result + assert seq == 0, "Sequence number mismatch" + assert received.ts == pytest.approx(test_batch.ts, + abs=0.1), ("Timestamp mismatch") + assert len(received.events) == len( + test_batch.events), ("Number of events mismatch") + + for i, event in enumerate(received.events): + assert event.id == i, "Event id mismatch" + assert event.value == f"test-{i}", "Event value mismatch" + + +def test_multiple_events(publisher, subscriber): + """Test publishing and receiving multiple event batches""" + for _ in range(10): + batch = create_test_events(2) + publisher.publish(batch) + + received = [] + for _ in range(10): + data = subscriber.receive_one(timeout=100) + if data: + received.append(data) + + assert len(received) == 10, "Number of messages mismatch" + seqs = [seq for seq, _ in received] + assert seqs == list(range(10)), "Sequence numbers mismatch" + + +def test_replay_mechanism(publisher, subscriber): + """Test the replay mechanism works correctly""" + for _ in range(19): + batch = create_test_events(1) + publisher.publish(batch) + + time.sleep(0.5) # Need publisher to process above requests + subscriber.request_replay(10) + + batch = create_test_events(1) + publisher.publish(batch) # 20th message + + replayed = subscriber.receive_replay() + + assert len(replayed) > 0, "No replayed messages received" + seqs = [seq for seq, _ in replayed] + assert all(seq >= 10 for seq in seqs), "Replayed messages not in order" + assert seqs == list(range(min(seqs), + max(seqs) + + 1)), ("Replayed messages not consecutive") + + +def test_buffer_limit(publisher, subscriber, publisher_config): + """Test buffer limit behavior""" + buffer_size = 
publisher_config.buffer_steps + + # Publish more events than the buffer can hold + for i in range(buffer_size + 10): + batch = create_test_events(1) + publisher.publish(batch) + + time.sleep(0.5) # Need publisher to process above requests + subscriber.request_replay(0) + + batch = create_test_events(1) + publisher.publish(batch) + + replayed = subscriber.receive_replay() + + assert len(replayed) <= buffer_size, "Can't replay more than buffer size" + + oldest_seq = min(seq for seq, _ in replayed) + assert oldest_seq >= 10, "The oldest sequence should be at least 10" + + +def test_topic_filtering(publisher_config): + """ + Test that a subscriber only receives messages matching its topic filter + """ + publisher_config.replay_endpoint = None + + cfg = publisher_config.model_copy() + cfg.topic = "foo" + pub = EventPublisherFactory.create(cfg) + + from .conftest import MockSubscriber + sub_foo = MockSubscriber(cfg.endpoint, None, "foo") + sub_bar = MockSubscriber(cfg.endpoint, None, "bar") + + try: + time.sleep(0.1) + + for _ in range(3): + pub.publish(create_test_events(1)) + + foo_received = [sub_foo.receive_one(timeout=200) for _ in range(3)] + assert all(msg is not None for msg in foo_received), ( + "Subscriber with matching topic should receive messages") + + bar_received = [sub_bar.receive_one(timeout=200) for _ in range(3)] + assert all(msg is None for msg in bar_received), ( + "Subscriber with non-matching topic should receive no messages") + finally: + pub.shutdown() + sub_foo.close() + sub_bar.close() + + +def test_high_volume(publisher, subscriber): + """Test publishing and receiving a high volume of events""" + num_batches = 10_000 + events_per_batch = 100 + + # Publish events in a separate thread to not block + def publish_events(): + for i in range(num_batches): + batch = create_test_events(events_per_batch) + publisher.publish(batch) + # Small delay to avoid overwhelming + if i % 100 == 0: + time.sleep(0.01) + + received: list[tuple[int, SampleBatch]] = [] + + publisher_thread = threading.Thread(target=publish_events) + publisher_thread.start() + + start_time = time.time() + while len(received) < num_batches: + if time.time() - start_time > 10: # Timeout after 10 seconds + break + + result = subscriber.receive_one(timeout=100) + if result: + received.append(result) + + publisher_thread.join() + + assert len(received) >= num_batches * 0.9, ( + "We should have received most messages") + + seqs = [seq for seq, _ in received] + assert sorted(seqs) == seqs, "Sequence numbers should be in order" + + +def test_null_publisher(): + """Test that NullEventPublisher can be used without errors""" + publisher = NullEventPublisher() + + # This should not raise any errors + batch = create_test_events(5) + publisher.publish(batch) + publisher.shutdown() diff --git a/tests/v1/core/test_prefix_caching.py b/tests/v1/core/test_prefix_caching.py index ae4bd95d22aa..af0fef89d15c 100644 --- a/tests/v1/core/test_prefix_caching.py +++ b/tests/v1/core/test_prefix_caching.py @@ -6,6 +6,7 @@ from typing import Optional import pytest import torch +from vllm.distributed.kv_events import AllBlocksCleared, BlockRemoved from vllm.multimodal.inputs import MultiModalKwargs, PlaceholderRange from vllm.sampling_params import SamplingParams from vllm.utils import sha256 @@ -48,9 +49,10 @@ def make_kv_cache_config(block_size: int, num_blocks: int) -> KVCacheConfig: num_blocks=num_blocks, tensors={}, kv_cache_groups=[ - KVCacheGroupSpec(['layer'], - FullAttentionSpec(block_size, 1, 1, torch.float32, - False)) + 
KVCacheGroupSpec( + ["layer"], + FullAttentionSpec(block_size, 1, 1, torch.float32, False), + ) ], ) @@ -783,6 +785,60 @@ def test_prefix_cache_stats_disabled(): assert manager.prefix_cache_stats is None +@pytest.mark.parametrize("blocks_to_cache", [2, 3, 10]) +def test_kv_cache_events(blocks_to_cache: int): + block_size = 16 + num_blocks = blocks_to_cache + 1 + + # Allocate Blocks + # Should see a single block stored event with a blocks_to_cache number of + # block hashes + # take_events should reset the kv_event_queue + manager = KVCacheManager( + make_kv_cache_config(block_size, num_blocks), + max_model_len=8192, + enable_caching=True, + enable_kv_cache_events=True, + ) + + num_tokens = block_size * blocks_to_cache + req0 = make_request("0", list(range(num_tokens))) + _ = manager.allocate_slots(req0, num_tokens) + events = manager.take_events() + + block = events[-1] + assert (len(block.block_hashes) == blocks_to_cache == len( + manager.block_pool.cached_block_hash_to_block)) + assert len(block.token_ids) == block.block_size * len(block.block_hashes) + assert len(manager.block_pool.kv_event_queue) == 0 + + stored_block_hash = block.block_hashes + + # Remove blocks and send another request + # Should see block_to_cache number of removed block events and a new block + # stored event + manager.free(req0) + req1 = make_request("1", list(range(num_tokens))) + _ = manager.allocate_slots(req1, num_tokens) + events = manager.take_events() + + for blocks in events[:-1]: + assert blocks.block_hashes[0] in stored_block_hash + assert len(events) == blocks_to_cache + 1 + assert (isinstance(events[-2], BlockRemoved)) + assert (len(events[-1].block_hashes) == blocks_to_cache == len( + manager.block_pool.cached_block_hash_to_block)) + + # All Blocks Cleared + # Should see a single all blocks cleared event + manager.free(req1) + manager.reset_prefix_cache() + events = manager.take_events() + + assert isinstance(events[-1], AllBlocksCleared) + assert len(manager.block_pool.cached_block_hash_to_block) == 0 + + def test_eagle_enabled_removes_last_block(): """Verify Eagle does NOT remove blocks when request length is divisible by block size.""" diff --git a/tests/v1/engine/conftest.py b/tests/v1/engine/conftest.py index f8addd920d57..d04679c12448 100644 --- a/tests/v1/engine/conftest.py +++ b/tests/v1/engine/conftest.py @@ -13,6 +13,8 @@ from tests.v1.engine.utils import (NUM_PROMPT_LOGPROBS_UNDER_TEST, from vllm.engine.arg_utils import EngineArgs from vllm.transformers_utils.tokenizer_group import init_tokenizer_from_configs +from ...distributed.conftest import publisher_config, random_port # noqa: F401 + from tests.v1.engine.utils import FULL_STRINGS # isort: skip EngineCoreSampleLogprobsType = list[tuple[torch.Tensor, torch.Tensor]] diff --git a/tests/v1/engine/test_engine_core_client.py b/tests/v1/engine/test_engine_core_client.py index 5514a328497f..3e1aa56882a8 100644 --- a/tests/v1/engine/test_engine_core_client.py +++ b/tests/v1/engine/test_engine_core_client.py @@ -11,6 +11,7 @@ import pytest from transformers import AutoTokenizer from vllm import SamplingParams +from vllm.distributed.kv_events import BlockStored, KVEventBatch from vllm.engine.arg_utils import EngineArgs from vllm.platforms import current_platform from vllm.usage.usage_lib import UsageContext @@ -20,6 +21,7 @@ from vllm.v1.engine.core_client import (AsyncMPClient, EngineCoreClient, SyncMPClient) from vllm.v1.executor.abstract import Executor +from ...distributed.conftest import MockSubscriber from ...utils import 
create_new_process_for_each_test if not current_platform.is_cuda(): @@ -199,54 +201,142 @@ async def test_engine_core_client_asyncio(monkeypatch: pytest.MonkeyPatch): log_stats=True, ) - MAX_TOKENS = 20 - params = SamplingParams(max_tokens=MAX_TOKENS) - """Normal Request Cycle.""" + try: + MAX_TOKENS = 20 + params = SamplingParams(max_tokens=MAX_TOKENS) + """Normal Request Cycle.""" - requests = [make_request(params) for _ in range(10)] - request_ids = [req.request_id for req in requests] + requests = [make_request(params) for _ in range(10)] + request_ids = [req.request_id for req in requests] - # Add requests to the engine. - for request in requests: - await client.add_request_async(request) - await asyncio.sleep(0.01) + # Add requests to the engine. + for request in requests: + await client.add_request_async(request) + await asyncio.sleep(0.01) - outputs: dict[str, list] = {req_id: [] for req_id in request_ids} - await loop_until_done_async(client, outputs) + outputs: dict[str, list] = {req_id: [] for req_id in request_ids} + await loop_until_done_async(client, outputs) - for req_id in request_ids: - assert len(outputs[req_id]) == MAX_TOKENS, ( - f"{outputs[req_id]=}, {MAX_TOKENS=}") - """Abort Request Cycle.""" - - # Add requests to the engine. - for idx, request in enumerate(requests): - await client.add_request_async(request) - await asyncio.sleep(0.01) - if idx % 2 == 0: - await client.abort_requests_async([request.request_id]) - - outputs = {req_id: [] for req_id in request_ids} - await loop_until_done_async(client, outputs) - - for idx, req_id in enumerate(request_ids): - if idx % 2 == 0: - assert len(outputs[req_id]) < MAX_TOKENS, ( - f"{len(outputs[req_id])=}, {MAX_TOKENS=}") - else: + for req_id in request_ids: assert len(outputs[req_id]) == MAX_TOKENS, ( - f"{len(outputs[req_id])=}, {MAX_TOKENS=}") - """Utility method invocation""" + f"{outputs[req_id]=}, {MAX_TOKENS=}") + """Abort Request Cycle.""" - core_client: AsyncMPClient = client + # Add requests to the engine. + for idx, request in enumerate(requests): + await client.add_request_async(request) + await asyncio.sleep(0.01) + if idx % 2 == 0: + await client.abort_requests_async([request.request_id]) - result = await core_client.call_utility_async("echo", "testarg") - assert result == "testarg" + outputs = {req_id: [] for req_id in request_ids} + await loop_until_done_async(client, outputs) - with pytest.raises(Exception) as e_info: - await core_client.call_utility_async("echo", None, "help!") + for idx, req_id in enumerate(request_ids): + if idx % 2 == 0: + assert len(outputs[req_id]) < MAX_TOKENS, ( + f"{len(outputs[req_id])=}, {MAX_TOKENS=}") + else: + assert len(outputs[req_id]) == MAX_TOKENS, ( + f"{len(outputs[req_id])=}, {MAX_TOKENS=}") + """Utility method invocation""" - assert str(e_info.value) == "Call to echo method failed: help!" + core_client: AsyncMPClient = client + + result = await core_client.call_utility_async("echo", "testarg") + assert result == "testarg" + + with pytest.raises(Exception) as e_info: + await core_client.call_utility_async("echo", None, "help!") + + assert str(e_info.value) == "Call to echo method failed: help!" 
+ finally: + client.shutdown() + + +@pytest.mark.parametrize( + "multiprocessing_mode,publisher_config", + [(True, "tcp"), (False, "inproc")], + indirect=["publisher_config"], +) +def test_kv_cache_events( + monkeypatch: pytest.MonkeyPatch, + multiprocessing_mode: bool, + publisher_config, +): + + with monkeypatch.context() as m: + m.setenv("VLLM_USE_V1", "1") + block_size = 16 + num_blocks = 2 + + engine_args = EngineArgs(model=MODEL_NAME, + enforce_eager=True, + enable_prefix_caching=True, + block_size=block_size) + engine_args.kv_events_config = publisher_config + + vllm_config = engine_args.create_engine_config( + UsageContext.UNKNOWN_CONTEXT) + + executor_class = Executor.get_class(vllm_config) + client = EngineCoreClient.make_client( + multiprocess_mode=multiprocessing_mode, + asyncio_mode=False, + vllm_config=vllm_config, + executor_class=executor_class, + log_stats=False, + ) + endpoint = publisher_config.endpoint.replace("*", "127.0.0.1") + time.sleep(0.1) + subscriber = MockSubscriber(endpoint, + topic=publisher_config.topic, + decode_type=KVEventBatch) + + try: + custom_tokens = list(range(num_blocks * block_size)) + request = EngineCoreRequest( + request_id=str(uuid.uuid4()), + prompt_token_ids=custom_tokens, + mm_inputs=None, + mm_hashes=None, + mm_placeholders=None, + sampling_params=SamplingParams( + max_tokens=1), # Short completion for speed + eos_token_id=None, + arrival_time=time.time(), + lora_request=None, + ) + client.add_request(request) + + outputs: dict[str, list] = {request.request_id: []} + loop_until_done(client, outputs) + + result = subscriber.receive_one(timeout=1000) + assert result is not None, "No message received" + + seq, received = result + + assert seq == 0, "Sequence number mismatch" + assert len(received.events) == 1, ( + "We should have exactly one BlockStored event") + event = received.events[0] + assert isinstance( + event, BlockStored), ("We should have a BlockStored event") + assert len(event.block_hashes) == num_blocks, ( + "We should have a BlockStored event with 2 block_hashes") + assert event.block_size == block_size, ( + "Block size should be the same as the block size") + assert event.parent_block_hash is None, ( + "Parent block hash should be None") + assert event.lora_id is None, "Lora id should be None" + assert len(event.token_ids) == num_blocks * block_size, ( + "Token ids should be the same as the custom tokens") + assert event.token_ids == custom_tokens, ( + "Token ids should be the same as the custom tokens") + finally: + client.shutdown() + return @pytest.mark.timeout(10) diff --git a/vllm/config.py b/vllm/config.py index f9c5e25a47d4..5da1ab2587d0 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -1958,6 +1958,8 @@ class SchedulerConfig: some image tokens can be scheduled (like TTTTIIIII, leaving IIIII), it will be scheduled as TTTT in one step and IIIIIIIIII in the next.""" + # scheduler class or path. "vllm.core.scheduler.Scheduler" (default) + # or "mod.custom_class". scheduler_cls: Union[str, type[object]] = "vllm.core.scheduler.Scheduler" """The scheduler class to use. "vllm.core.scheduler.Scheduler" is the default scheduler. Can be a class directly or the path to a class of form @@ -3417,6 +3419,51 @@ class KVTransferConfig(BaseModel): return self.kv_connector_extra_config.get(key, default) +class KVEventsConfig(BaseModel): + """Configuration for KV event publishing.""" + + enable_kv_cache_events: bool = False + """If True, enable KV cache events for tracking block storage and removal. 
+ Events can be published externally by zmq using the event publisher config. + """ + + publisher: str = "null" + """The publisher to use for publishing kv events. Can be "null", "zmq". + """ + + endpoint: str = "tcp://*:5557" + """The zmq endpoint to use for publishing kv events. + """ + + replay_endpoint: Optional[str] = None + """The zmq endpoint to use for replaying kv events. + """ + + buffer_steps: int = 10_000 + """The number of steps to cache for replay endpoint. Will only save + events from the last N steps for the replay endpoint. + """ + + hwm: int = 100_000 + """The zmq high water mark for the event publisher. After queueing N events, + events will start dropping if the consumer is not keeping up. + """ + + max_queue_size: int = 100_000 + """The maximum number of events to queue while waiting for publishing. + """ + + topic: str = "" + """The topic to use for the event publisher. Consumers can subscribe to + this topic to receive events. + """ + + @classmethod + def from_cli(cls, cli_value: str) -> "KVEventsConfig": + """Parse the CLI value for the event publisher config.""" + return KVEventsConfig.model_validate_json(cli_value) + + class CompilationLevel: # constants for the levels of the compilation process NO_COMPILATION = 0 @@ -3779,6 +3826,7 @@ class VllmConfig: init=True) # type: ignore kv_transfer_config: KVTransferConfig = field(default=None, init=True) # type: ignore + kv_events_config: Optional[KVEventsConfig] = None # some opaque config, only used to provide additional information # for the hash computation, mainly used for testing, debugging or out of # tree config registration. @@ -4038,6 +4086,18 @@ class VllmConfig: if self.cache_config is not None: self.cache_config.enable_prefix_caching = False + if (self.kv_events_config + and self.kv_events_config.enable_kv_cache_events + and not self.cache_config.enable_prefix_caching): + logger.warning( + "KV cache events are on, but prefix caching is not enabled." + "Use --enable-prefix-caching to enable.") + if (self.kv_events_config and self.kv_events_config.publisher != "null" + and not self.kv_events_config.enable_kv_cache_events): + logger.warning("KV cache events are disabled," + "but the scheduler is configured to publish them." 
+ "Modify KVEventsConfig.enable_kv_cache_events" + "to True to enable.") current_platform.check_and_update_config(self) if not self.instance_id: diff --git a/vllm/distributed/kv_events.py b/vllm/distributed/kv_events.py new file mode 100644 index 000000000000..960913858527 --- /dev/null +++ b/vllm/distributed/kv_events.py @@ -0,0 +1,295 @@ +# SPDX-License-Identifier: Apache-2.0 + +import queue +import threading +import time +from abc import ABC, abstractmethod +from collections import deque +from itertools import count +from queue import Queue +from typing import Any, Callable, Optional, Union + +import msgspec +import zmq + +from vllm.config import KVEventsConfig +from vllm.logger import init_logger + +logger = init_logger(__name__) + + +class EventBatch( + msgspec.Struct, + array_like=True, # type: ignore[call-arg] + omit_defaults=True, # type: ignore[call-arg] + gc=False, # type: ignore[call-arg] +): + ts: float + events: list[Any] + + +class KVCacheEvent( + msgspec.Struct, + array_like=True, # type: ignore[call-arg] + omit_defaults=True, # type: ignore[call-arg] + gc=False, # type: ignore[call-arg] + tag=True): + """Base class for all KV cache-related events""" + + +class BlockStored(KVCacheEvent): + block_hashes: list[int] + parent_block_hash: Optional[int] + token_ids: list[int] + block_size: int + lora_id: Optional[int] + + +class BlockRemoved(KVCacheEvent): + block_hashes: list[int] + + +class AllBlocksCleared(KVCacheEvent): + pass + + +class KVEventBatch(EventBatch): + events: list[Union[BlockStored, BlockRemoved, AllBlocksCleared]] + + +class EventPublisher(ABC): + """Lightweight publisher for EventBatch batches.""" + + @abstractmethod + def publish(self, events: EventBatch) -> None: + """Emit events in order. + + Implementations should guarantee at-least-once delivery and + monotonic ordering (e.g., via sequence numbers). + """ + + @abstractmethod + def shutdown(self) -> None: + """Shutdown the publisher.""" + + +class NullEventPublisher(EventPublisher): + """No-op implementation (default when disabled).""" + + def publish(self, events) -> None: + return + + def shutdown(self) -> None: + return + + +class ZmqEventPublisher(EventPublisher): + """Reliable PUB/ROUTER publisher with an in-memory replay buffer. + + Spawns a separate thread to handle publishing from a queue. + + Parameters + ---------- + endpoint: + PUB address. Use ``tcp://*:5557`` to bind or ``tcp://host:5557`` to + connect. + replay_endpoint: + Optional ROUTER address for replay requests. When given, subscribers can + request missed batches by sending the starting sequence number as an + 8-byte big-endian integer. + buffer_steps: + Number of past batches to keep for replay. + hwm: + ZeroMQ high-water-mark for PUB socket. + max_queue_size: + Maximum number of events to buffer in memory. + topic: + Topic to publish events to. 
+ """ + SHUTDOWN_TIMEOUT: float = 1.0 + END_SEQ = (-1).to_bytes(8, "big", signed=True) + + def __init__( + self, + endpoint: str = "tcp://*:5557", + replay_endpoint: Optional[str] = None, + buffer_steps: int = 10_000, + hwm: int = 100_000, + max_queue_size: int = 100_000, + topic: str = "", + ) -> None: + # Storage + self._event_queue = Queue[Optional[EventBatch]](maxsize=max_queue_size) + self._buffer = deque[tuple[int, bytes]](maxlen=buffer_steps) + + # ZMQ sockets + self._ctx = zmq.Context.instance() + self._pub: Optional[zmq.Socket] = None + self._replay: Optional[zmq.Socket] = None + self._endpoint = endpoint + self._replay_endpoint = replay_endpoint + self._hwm = hwm + + # Payload + self._seq_gen = count() + self._topic_bytes = topic.encode('utf-8') + + # Thread + self._running = True + logger.info("Starting ZMQ publisher thread") + + self._thread = threading.Thread(target=self._publisher_thread, + daemon=True, + name="zmq-publisher") + self._thread.start() + + def publish(self, events: EventBatch) -> None: + if not self._running: + raise RuntimeError("Publisher is closed") + self._event_queue.put(events) + + def shutdown(self) -> None: + """Stop the publisher thread and clean up resources.""" + self._running = False + self._event_queue.put_nowait(None) + + start = time.time() + pending_items = True + while pending_items and (time.time() - start < self.SHUTDOWN_TIMEOUT): + pending_items = not self._event_queue.empty() + if pending_items: + time.sleep(0.1) + + if pending_items: + logger.warning( + "Warning: Queue still has %s items after %s seconds timeout", + self._event_queue.qsize(), + self.SHUTDOWN_TIMEOUT, + ) + + if self._thread.is_alive(): + self._thread.join(timeout=self.SHUTDOWN_TIMEOUT) + + # Clean up ZMQ resources + try: + if self._pub is not None: + self._pub.close(linger=0) + if self._replay is not None: + self._replay.close(linger=0) + finally: + pass # Do not terminate context; other sockets may use it + + def _socket_setup(self) -> None: + """Initialize sockets + https://pyzmq.readthedocs.io/en/v19.0.0/morethanbindings.html#thread-safety + """ + if self._pub is None: + self._pub = self._ctx.socket(zmq.PUB) + self._pub.set_hwm(self._hwm) + # Heuristic: bind if wildcard / * present, else connect. 
+ # bind stable, connect volatile convention + if ("*" in self._endpoint or "::" in self._endpoint + or self._endpoint.startswith("ipc://") + or self._endpoint.startswith("inproc://")): + self._pub.bind(self._endpoint) + else: + self._pub.connect(self._endpoint) + + # Set up replay socket: use ROUTER + # 1) handles multiple REQ clients (identities) + # 2) lets us send back one request β†’ many replies (streamed events) + # 3) works in our non‑blocking poll loop alongside PUB + if self._replay_endpoint is not None: + self._replay = self._ctx.socket(zmq.ROUTER) + self._replay.bind(self._replay_endpoint) + + def _publisher_thread(self) -> None: + """Background thread that processes the event queue.""" + self._pack = msgspec.msgpack.Encoder() + self._socket_setup() + + assert self._pub is not None # narrows type for mypy + + while self._running or self._event_queue.qsize() > 0: + # --- replay (non-critical) --------------------------------- + if self._replay is not None and self._replay.poll(0): + try: + self._service_replay() + except Exception as e: + logger.exception("Error in replay: %s", e) + + # --- main queue (critical) --------------------------------- + try: + event = self._event_queue.get(timeout=0.1) + if event is None: + break # Sentinel received, exit thread + except queue.Empty: + continue + + try: + seq = next(self._seq_gen) + + payload = self._pack.encode(event) + seq_bytes = seq.to_bytes(8, "big") + self._pub.send_multipart( + (self._topic_bytes, seq_bytes, payload)) + + self._buffer.append((seq, payload)) + self._event_queue.task_done() + + except Exception as e: + # Publishing failed; back-off a bit to avoid a tight error loop + logger.exception("Error in publisher thread: %s", e) + time.sleep(0.1) + + def _service_replay(self) -> None: + """If a replay request is waiting, send buffered batches.""" + assert self._replay is not None # narrows type for mypy + + frame = self._replay.recv_multipart() + if len(frame) != 3: + logger.warning("Invalid replay request: %s", frame) + return + client_id, _, start_seq_bytes = frame + start_seq = int.from_bytes(start_seq_bytes, "big") + + for seq, buf in self._buffer: + if seq >= start_seq: + # [identity, empty_delim, seq_bytes, payload] + # (identity, empty_delim) are stripped off by the router + # receiving payload is (seq_bytes, payload) + self._replay.send_multipart( + (client_id, b"", seq.to_bytes(8, "big"), buf)) + # Send end of sequence marker + # receiving payload is (-1, b""") + self._replay.send_multipart((client_id, b"", self.END_SEQ, b"")) + + +class EventPublisherFactory: + _registry: dict[str, Callable[..., EventPublisher]] = { + "null": NullEventPublisher, + "zmq": ZmqEventPublisher, + } + + @classmethod + def register_publisher(cls, name: str, + ctor: Callable[..., EventPublisher]) -> None: + if name in cls._registry: + raise KeyError(f"publisher '{name}' already registered") + cls._registry[name] = ctor + + @classmethod + def create(cls, config: Optional[KVEventsConfig]) -> EventPublisher: + """Create publisher from a config mapping.""" + if not config: + return NullEventPublisher() + + config_dict = config.model_dump() + + kind = config_dict.pop("publisher", "null") + config_dict.pop("enable_kv_cache_events") + try: + constructor = cls._registry[kind] + except KeyError as exc: + raise ValueError(f"Unknown event publisher '{kind}'") from exc + return constructor(**config_dict) diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index 4f074fcd1b8e..c7a580cf1051 100644 --- a/vllm/engine/arg_utils.py +++ 
b/vllm/engine/arg_utils.py @@ -19,14 +19,14 @@ from vllm.config import (BlockSize, CacheConfig, CacheDType, CompilationConfig, ConfigFormat, ConfigType, DecodingConfig, Device, DeviceConfig, DistributedExecutorBackend, GuidedDecodingBackend, GuidedDecodingBackendV1, - HfOverrides, KVTransferConfig, LoadConfig, LoadFormat, - LoRAConfig, ModelConfig, ModelDType, ModelImpl, - MultiModalConfig, ObservabilityConfig, ParallelConfig, - PoolerConfig, PrefixCachingHashAlgo, - PromptAdapterConfig, SchedulerConfig, SchedulerPolicy, - SpeculativeConfig, TaskOption, TokenizerMode, - TokenizerPoolConfig, VllmConfig, get_attr_docs, - get_field) + HfOverrides, KVEventsConfig, KVTransferConfig, + LoadConfig, LoadFormat, LoRAConfig, ModelConfig, + ModelDType, ModelImpl, MultiModalConfig, + ObservabilityConfig, ParallelConfig, PoolerConfig, + PrefixCachingHashAlgo, PromptAdapterConfig, + SchedulerConfig, SchedulerPolicy, SpeculativeConfig, + TaskOption, TokenizerMode, TokenizerPoolConfig, + VllmConfig, get_attr_docs, get_field) from vllm.executor.executor_base import ExecutorBase from vllm.logger import init_logger from vllm.model_executor.layers.quantization import QuantizationMethods @@ -353,6 +353,7 @@ class EngineArgs: worker_extension_cls: str = ParallelConfig.worker_extension_cls kv_transfer_config: Optional[KVTransferConfig] = None + kv_events_config: Optional[KVEventsConfig] = None generation_config: str = ModelConfig.generation_config enable_sleep_mode: bool = ModelConfig.enable_sleep_mode @@ -769,6 +770,10 @@ class EngineArgs: default=None, help='The configurations for distributed KV cache ' 'transfer. Should be a JSON string.') + parser.add_argument('--kv-events-config', + type=KVEventsConfig.from_cli, + default=None, + help='The configurations for event publishing.') parser.add_argument( '--worker-cls', @@ -1125,6 +1130,7 @@ class EngineArgs: prompt_adapter_config=prompt_adapter_config, compilation_config=self.compilation_config, kv_transfer_config=self.kv_transfer_config, + kv_events_config=self.kv_events_config, additional_config=self.additional_config, ) diff --git a/vllm/v1/core/block_pool.py b/vllm/v1/core/block_pool.py index 74f3f7852c9a..f2ed183b68fc 100644 --- a/vllm/v1/core/block_pool.py +++ b/vllm/v1/core/block_pool.py @@ -3,6 +3,8 @@ from collections import defaultdict from collections.abc import Iterable from typing import Callable, Optional +from vllm.distributed.kv_events import (AllBlocksCleared, BlockRemoved, + BlockStored, KVCacheEvent) from vllm.logger import init_logger from vllm.v1.core.kv_cache_utils import (BlockHashType, FreeKVCacheBlockQueue, KVCacheBlock, @@ -26,7 +28,12 @@ class BlockPool: enable_caching: Whether to enable prefix caching. """ - def __init__(self, num_gpu_blocks: int, enable_caching: bool): + def __init__( + self, + num_gpu_blocks: int, + enable_caching: bool, + enable_kv_cache_events: bool = False, + ): assert isinstance(num_gpu_blocks, int) and num_gpu_blocks > 0 self.num_gpu_blocks = num_gpu_blocks self.enable_caching = enable_caching @@ -56,6 +63,9 @@ class BlockPool: # avoid freeing it. self.null_block = self.free_block_queue.popleft() + self.enable_kv_cache_events = enable_kv_cache_events + self.kv_event_queue: list[KVCacheEvent] = [] + def get_cached_block(self, block_hash: BlockHashType) -> Optional[KVCacheBlock]: """Get a cached block by the block hash, or None if cache miss. 
@@ -116,6 +126,9 @@ class BlockPool: assert prev_block.block_hash is not None prev_block_hash_value = prev_block.block_hash.hash_value + parent_block_hash = prev_block_hash_value + new_hashes: Optional[list[int]] = ([] if self.enable_kv_cache_events + else None) for i, blk in enumerate(new_full_blocks): assert blk.block_hash is None @@ -153,8 +166,23 @@ class BlockPool: # Update and added the full block to the cache. blk.block_hash = block_hash self.cached_block_hash_to_block[block_hash][blk.block_id] = blk + if new_hashes is not None: + new_hashes.append(block_hash.hash_value) prev_block_hash_value = block_hash.hash_value + if self.enable_kv_cache_events: + self.kv_event_queue.append( + BlockStored( + block_hashes=new_hashes, + parent_block_hash=parent_block_hash, + token_ids=request. + all_token_ids[num_cached_blocks * + block_size:num_full_blocks * block_size], + block_size=block_size, + lora_id=request.lora_request.id + if request.lora_request else None, + )) + def get_new_blocks(self, num_blocks: int) -> list[KVCacheBlock]: """Get new blocks from the free block pool. @@ -206,6 +234,9 @@ class BlockPool: if len(self.cached_block_hash_to_block[block_hash]) == 0: del self.cached_block_hash_to_block[block_hash] + if self.enable_kv_cache_events: + self.kv_event_queue.append( + BlockRemoved(block_hashes=[block_hash.hash_value])) return True return False @@ -262,6 +293,10 @@ class BlockPool: block.reset_hash() logger.info("Successfully reset prefix cache") + + if self.enable_kv_cache_events: + self.kv_event_queue.append(AllBlocksCleared()) + return True def get_num_free_blocks(self) -> int: @@ -279,3 +314,15 @@ class BlockPool: The KV cache usage (between 0.0 and 1.0). """ return 1.0 - (self.get_num_free_blocks() / self.num_gpu_blocks) + + def take_events(self) -> list[KVCacheEvent]: + """Atomically takes all events and clears the queue. + + Returns: + A list of KV cache events. + """ + if not self.enable_kv_cache_events: + return [] + events = self.kv_event_queue + self.kv_event_queue = [] + return events diff --git a/vllm/v1/core/kv_cache_manager.py b/vllm/v1/core/kv_cache_manager.py index 0830d8433d89..39554bed0fcf 100644 --- a/vllm/v1/core/kv_cache_manager.py +++ b/vllm/v1/core/kv_cache_manager.py @@ -4,6 +4,7 @@ from collections import defaultdict from collections.abc import Iterable from typing import Optional +from vllm.distributed.kv_events import KVCacheEvent from vllm.logger import init_logger from vllm.utils import cdiv, sha256 from vllm.v1.core.block_pool import BlockPool @@ -27,6 +28,7 @@ class KVCacheManager: caching_hash_algo: str = "builtin", use_eagle: bool = False, log_stats: bool = False, + enable_kv_cache_events: bool = False, ) -> None: assert len(kv_cache_config.kv_cache_groups) == 1, ( "KVCacheManager does not support hybrid models with more than 1 " @@ -44,7 +46,9 @@ class KVCacheManager: # FIXME: make prefix cache stats conditional on log_stats self.prefix_cache_stats = PrefixCacheStats() if log_stats else None - self.block_pool = BlockPool(self.num_gpu_blocks, enable_caching) + self.block_pool = BlockPool(self.num_gpu_blocks, enable_caching, + enable_kv_cache_events) + self.specialized_manager = get_specialized_manager( kv_cache_spec=kv_cache_spec, block_pool=self.block_pool, @@ -383,3 +387,11 @@ class KVCacheManager: is finished, not when it is preempted. """ self.req_to_block_hashes.pop(request.request_id, None) + + def take_events(self) -> list[KVCacheEvent]: + """Take the KV cache events from the block pool. + + Returns: + A list of KV cache events. 
+ """ + return self.block_pool.take_events() diff --git a/vllm/v1/core/sched/interface.py b/vllm/v1/core/sched/interface.py index 1de236d42f02..0b328f510903 100644 --- a/vllm/v1/core/sched/interface.py +++ b/vllm/v1/core/sched/interface.py @@ -132,3 +132,8 @@ class SchedulerInterface(ABC): The SchedulerStats object is created for every scheduling step. """ raise NotImplementedError + + @abstractmethod + def shutdown(self) -> None: + """Shutdown the scheduler.""" + raise NotImplementedError diff --git a/vllm/v1/core/sched/scheduler.py b/vllm/v1/core/sched/scheduler.py index 7ebbb4954f51..ae7280a14706 100644 --- a/vllm/v1/core/sched/scheduler.py +++ b/vllm/v1/core/sched/scheduler.py @@ -8,6 +8,7 @@ from collections.abc import Iterable from typing import Optional, Union from vllm.config import VllmConfig +from vllm.distributed.kv_events import EventPublisherFactory, KVEventBatch from vllm.distributed.kv_transfer.kv_connector.factory import ( KVConnectorFactory) from vllm.distributed.kv_transfer.kv_connector.v1 import KVConnectorRole @@ -48,6 +49,7 @@ class Scheduler(SchedulerInterface): self.cache_config = vllm_config.cache_config self.lora_config = vllm_config.lora_config self.kv_cache_config = kv_cache_config + self.kv_events_config = vllm_config.kv_events_config self.log_stats = log_stats self.structured_output_manager = structured_output_manager @@ -62,6 +64,9 @@ class Scheduler(SchedulerInterface): self.max_num_scheduled_tokens = \ self.scheduler_config.max_num_batched_tokens self.max_model_len = self.scheduler_config.max_model_len + self.enable_kv_cache_events = ( + self.kv_events_config is not None + and self.kv_events_config.enable_kv_cache_events) # Create KVConnector for the Scheduler. Note that each Worker # will have a corresponding KVConnector with Role=WORKER. @@ -71,6 +76,9 @@ class Scheduler(SchedulerInterface): self.connector = KVConnectorFactory.create_connector_v1( config=self.vllm_config, role=KVConnectorRole.SCHEDULER) + self.kv_event_publisher = EventPublisherFactory.create( + self.kv_events_config) + num_gpu_blocks = self.cache_config.num_gpu_blocks assert num_gpu_blocks is not None and num_gpu_blocks > 0 @@ -132,7 +140,9 @@ class Scheduler(SchedulerInterface): enable_caching=self.cache_config.enable_prefix_caching, caching_hash_algo=self.cache_config.prefix_caching_hash_algo, use_eagle=self.use_eagle, - log_stats=self.log_stats) + log_stats=self.log_stats, + enable_kv_cache_events=self.enable_kv_cache_events, + ) def schedule(self) -> SchedulerOutput: # NOTE(woosuk) on the scheduling algorithm: @@ -493,6 +503,11 @@ class Scheduler(SchedulerInterface): meta = self.connector.build_connector_meta(scheduler_output) scheduler_output.kv_connector_metadata = meta + events = self.kv_cache_manager.take_events() + if events: + batch = KVEventBatch(ts=time.time(), events=events) + self.kv_event_publisher.publish(batch) + # Advance the number of computed tokens for the request AFTER # the request is scheduled. # 1. 
The scheduler_output of the current step has to include the @@ -843,3 +858,7 @@ class Scheduler(SchedulerInterface): num_draft_tokens=num_draft_tokens, num_accepted_tokens=num_accepted_tokens) return spec_decoding_stats + + def shutdown(self) -> None: + if self.kv_event_publisher: + self.kv_event_publisher.shutdown() diff --git a/vllm/v1/engine/core.py b/vllm/v1/engine/core.py index 5912318f19ff..e772615b7861 100644 --- a/vllm/v1/engine/core.py +++ b/vllm/v1/engine/core.py @@ -259,6 +259,8 @@ class EngineCore: self.structured_output_manager.clear_backend() if self.model_executor: self.model_executor.shutdown() + if self.scheduler: + self.scheduler.shutdown() def profile(self, is_start: bool = True): self.model_executor.profile(is_start)
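For downstream integrations, the EventPublisherFactory added above exposes a registration hook, so event sinks other than ZMQ can be plugged in without modifying vLLM. The sketch below is illustrative and not part of the diff: the "logging" publisher name and the LoggingEventPublisher class are hypothetical; only EventPublisherFactory.register_publisher, the EventPublisher interface, and the KVEventsConfig fields come from the changes above.

# Hypothetical sketch: registering a custom KV event publisher.
# Only the factory, ABC, and config APIs below exist in the diff;
# the "logging" publisher itself is made up for illustration.
from vllm.config import KVEventsConfig
from vllm.distributed.kv_events import (EventBatch, EventPublisher,
                                        EventPublisherFactory)


class LoggingEventPublisher(EventPublisher):
    """Prints each event batch instead of sending it over ZMQ."""

    def __init__(self, topic: str = "", **kwargs) -> None:
        # EventPublisherFactory.create() forwards the remaining
        # KVEventsConfig fields (endpoint, buffer_steps, hwm, ...) as
        # keyword arguments; this publisher only uses the topic.
        self.topic = topic

    def publish(self, events: EventBatch) -> None:
        print(f"[{self.topic}] {len(events.events)} events at {events.ts}")

    def shutdown(self) -> None:
        return


EventPublisherFactory.register_publisher("logging", LoggingEventPublisher)

# The custom publisher is then selected the same way as the built-in "zmq"
# one, e.g. via --kv-events-config with "publisher": "logging".
config = KVEventsConfig(enable_kv_cache_events=True, publisher="logging")
publisher = EventPublisherFactory.create(config)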
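Because the payloads are plain msgpack encodings of tagged, array-like msgspec structs, the event stream can also be consumed by tooling other than the bundled subscriber example. The round-trip below is a minimal sketch with made-up values; it relies only on the types and encoder/decoder usage shown in the diff and corresponds to the final frame of each published multipart message (topic frame, 8-byte big-endian sequence number, then this payload).

# Minimal sketch of the msgpack payload format (values are made up).
import time

import msgspec

from vllm.distributed.kv_events import BlockStored, KVEventBatch

batch = KVEventBatch(
    ts=time.time(),
    events=[
        BlockStored(
            block_hashes=[12345],  # hypothetical hash value
            parent_block_hash=None,
            token_ids=list(range(16)),
            block_size=16,
            lora_id=None,
        )
    ],
)

# ZmqEventPublisher sends this bytes object as the last frame of each message.
payload = msgspec.msgpack.Encoder().encode(batch)

# Subscribers decode it back into the tagged union of event types.
decoded = msgspec.msgpack.Decoder(type=KVEventBatch).decode(payload)
assert isinstance(decoded.events[0], BlockStored)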