From 111faf11185a75435403e3843e9a21a57cf84ccb Mon Sep 17 00:00:00 2001 From: Or Ozeri Date: Tue, 28 Oct 2025 23:01:33 +0200 Subject: [PATCH] [Core] Scheduler: Publish connector events after output (#25875) Signed-off-by: Or Ozeri --- tests/v1/kv_offload/test_cpu_offloading.py | 144 +++++++++++++++++---- vllm/v1/core/sched/scheduler.py | 34 ++--- 2 files changed, 135 insertions(+), 43 deletions(-) diff --git a/tests/v1/kv_offload/test_cpu_offloading.py b/tests/v1/kv_offload/test_cpu_offloading.py index e9c255b1ee994..b654ea4298dbb 100644 --- a/tests/v1/kv_offload/test_cpu_offloading.py +++ b/tests/v1/kv_offload/test_cpu_offloading.py @@ -1,15 +1,68 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import socket import time +import msgspec +import msgspec.msgpack import pytest +import zmq +from tqdm import tqdm -from vllm import LLM, SamplingParams -from vllm.config import KVTransferConfig +from vllm import LLM, SamplingParams, TokensPrompt +from vllm.config import KVEventsConfig, KVTransferConfig +from vllm.distributed.kv_events import BlockStored, KVEventBatch CPU_BLOCK_SIZES = [16, 48] +class MockSubscriber: + """Helper class to receive and verify published events""" + + def __init__( + self, + endpoint: str, + topic: str, + ): + self.ctx = zmq.Context.instance() + self.topic_bytes = topic.encode("utf-8") + + # Set up subscriber socket + self.sub = self.ctx.socket(zmq.SUB) + self.sub.setsockopt(zmq.SUBSCRIBE, self.topic_bytes) + self.sub.connect(endpoint) + + self.decoder = msgspec.msgpack.Decoder(type=KVEventBatch) + + def get_new_cpu_stored_events(self) -> list[BlockStored]: + cpu_stored_events: list[BlockStored] = [] + + poller = zmq.Poller() + poller.register(self.sub, zmq.POLLIN) + + timeout = 1000 # 1 second + while True: + events = dict(poller.poll(timeout)) + + if events.get(self.sub) != zmq.POLLIN: + return cpu_stored_events + + topic_bytes, _, payload = self.sub.recv_multipart() + + assert topic_bytes == self.topic_bytes + + event_batch = self.decoder.decode(payload) + assert isinstance(event_batch, KVEventBatch) + for event in event_batch.events: + if isinstance(event, BlockStored) and event.medium == "CPU": + cpu_stored_events.append(event) + timeout = 100 + + def close(self): + """Clean up resources""" + self.sub.close() + + @pytest.mark.parametrize("cpu_block_size", CPU_BLOCK_SIZES) def test_cpu_offloading(cpu_block_size: int) -> None: """ @@ -20,41 +73,80 @@ def test_cpu_offloading(cpu_block_size: int) -> None: kv_transfer_config = KVTransferConfig( kv_connector="OffloadingConnector", kv_role="kv_both", - kv_connector_extra_config={"num_cpu_blocks": 100, "block_size": cpu_block_size}, + kv_connector_extra_config={ + "num_cpu_blocks": 1000, + "block_size": cpu_block_size, + }, + ) + + port: int + with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: + s.bind(("0.0.0.0", 0)) + port = s.getsockname()[1] + + events_endpoint = f"tcp://*:{port}" + kv_events_config = KVEventsConfig( + enable_kv_cache_events=True, + publisher="zmq", + endpoint=events_endpoint, + topic="test", ) llm = LLM( model="meta-llama/Llama-3.2-1B-Instruct", gpu_memory_utilization=0.5, + kv_events_config=kv_events_config, kv_transfer_config=kv_transfer_config, - disable_hybrid_kv_cache_manager=True, ) - prompts = ["Hi " * 100] - sampling_params = SamplingParams(temperature=0, max_tokens=20) + sampling_params = SamplingParams(temperature=0, max_tokens=1) - # run generation - this should trigger saving KV cache - start_time = time.time() - llm.generate(prompts, sampling_params, use_tqdm=False) - cold_time = time.time() - start_time + events_endpoint = events_endpoint.replace("*", "127.0.0.1") + subscriber = MockSubscriber(events_endpoint, topic=kv_events_config.topic) - # run generation again - should hit the GPU prefix cache - start_time = time.time() - llm.generate(prompts, sampling_params, use_tqdm=False) - gpu_hit_time = time.time() - start_time + try: + num_times_cpu_better_than_cold = 0 + num_tests = 10 + total_cold_time = 0.0 + total_gpu_hit_time = 0.0 + total_cpu_hit_time = 0.0 + prompt_token_ids = [0] * 10001 + for i in tqdm(range(num_tests), desc="Running tests"): + prompt_token_ids[0] = i + prompts = [TokensPrompt(prompt_token_ids=prompt_token_ids)] - # reset prefix cache to avoid GPU hit. - llm.reset_prefix_cache() + # run generation - this should trigger saving KV cache + start_time = time.time() + llm.generate(prompts, sampling_params, use_tqdm=False) + cold_time = time.time() - start_time + total_cold_time += cold_time - # sleep for a sec to make sure CPU finished storing - time.sleep(1) + # run generation again - should hit the GPU prefix cache + start_time = time.time() + llm.generate(prompts, sampling_params, use_tqdm=False) + gpu_hit_time = time.time() - start_time + total_gpu_hit_time += gpu_hit_time - # run generation again - this should trigger loading from CPU - start_time = time.time() - llm.generate(prompts, sampling_params, use_tqdm=False) - cpu_hit_time = time.time() - start_time + # reset prefix cache to avoid GPU hit. + llm.reset_prefix_cache() - print("Generation times:") - print(f" Cold: {cold_time * 1000:.2f}ms") - print(f" GPU hit: {gpu_hit_time * 1000:.2f}ms") - print(f" CPU hit: {cpu_hit_time * 1000:.2f}ms") + assert subscriber.get_new_cpu_stored_events() + + # run generation again - this should trigger loading from CPU + start_time = time.time() + llm.generate(prompts, sampling_params, use_tqdm=False) + cpu_hit_time = time.time() - start_time + total_cpu_hit_time += cpu_hit_time + + if cpu_hit_time < cold_time: + num_times_cpu_better_than_cold += 1 + + print("Average times:") + print(f" Cold: {total_cold_time * 1000 / num_tests:.2f}ms") + print(f" GPU hit: {total_gpu_hit_time * 1000 / num_tests:.2f}ms") + print(f" CPU hit: {total_cpu_hit_time * 1000 / num_tests:.2f}ms") + + assert num_times_cpu_better_than_cold >= 0.8 * num_tests + finally: + subscriber.close() + del llm diff --git a/vllm/v1/core/sched/scheduler.py b/vllm/v1/core/sched/scheduler.py index 14bdf295317d7..00b34fe4fbb98 100644 --- a/vllm/v1/core/sched/scheduler.py +++ b/vllm/v1/core/sched/scheduler.py @@ -646,23 +646,6 @@ class Scheduler(SchedulerInterface): meta = self.connector.build_connector_meta(scheduler_output) scheduler_output.kv_connector_metadata = meta - # collect KV cache events from KV cache manager - events = self.kv_cache_manager.take_events() - - # collect KV cache events from connector - if self.connector is not None: - connector_events = self.connector.take_events() - if connector_events: - if events is None: - events = list(connector_events) - else: - events.extend(connector_events) - - # publish collected KV cache events - if events: - batch = KVEventBatch(ts=time.time(), events=events) - self.kv_event_publisher.publish(batch) - self._update_after_schedule(scheduler_output) return scheduler_output @@ -1057,6 +1040,23 @@ class Scheduler(SchedulerInterface): if kv_connector_output: self._update_from_kv_xfer_finished(kv_connector_output) + # collect KV cache events from KV cache manager + events = self.kv_cache_manager.take_events() + + # collect KV cache events from connector + if self.connector is not None: + connector_events = self.connector.take_events() + if connector_events: + if events is None: + events = list(connector_events) + else: + events.extend(connector_events) + + # publish collected KV cache events + if events: + batch = KVEventBatch(ts=time.time(), events=events) + self.kv_event_publisher.publish(batch) + # Create EngineCoreOutputs for all clients that have requests with # outputs in this step. engine_core_outputs = {