[Core] Scheduler: Publish connector events after output (#25875)
Signed-off-by: Or Ozeri <oro@il.ibm.com>
commit 111faf1118 (parent 6afc28a9ba)
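This change moves the KV cache event collection and publishing step out of schedule() and into update_from_output(). Connector events, such as the BlockStored events the OffloadingConnector emits once blocks finish saving to CPU, are produced while model output is processed, so publishing after output makes them visible to subscribers in the same step instead of waiting for the next schedule() call. The moved block appears verbatim in the two scheduler hunks below; as a rough sketch, factored into a hypothetical free function for illustration (the real code is inlined in the scheduler), the logic is:

import time

from vllm.distributed.kv_events import KVEventBatch


def collect_and_publish_kv_events(kv_cache_manager, connector, publisher) -> None:
    # Collect KV cache events from the KV cache manager...
    events = kv_cache_manager.take_events()

    # ...and merge in events from the connector, if one is configured.
    if connector is not None:
        connector_events = connector.take_events()
        if connector_events:
            if events is None:
                events = list(connector_events)
            else:
                events.extend(connector_events)

    # Publish everything as a single batch stamped with the current time.
    if events:
        publisher.publish(KVEventBatch(ts=time.time(), events=events))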
@@ -1,15 +1,68 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+import socket
 import time
 
+import msgspec
+import msgspec.msgpack
 import pytest
+import zmq
+from tqdm import tqdm
 
-from vllm import LLM, SamplingParams
-from vllm.config import KVTransferConfig
+from vllm import LLM, SamplingParams, TokensPrompt
+from vllm.config import KVEventsConfig, KVTransferConfig
+from vllm.distributed.kv_events import BlockStored, KVEventBatch
 
 CPU_BLOCK_SIZES = [16, 48]
 
 
+class MockSubscriber:
+    """Helper class to receive and verify published events"""
+
+    def __init__(
+        self,
+        endpoint: str,
+        topic: str,
+    ):
+        self.ctx = zmq.Context.instance()
+        self.topic_bytes = topic.encode("utf-8")
+
+        # Set up subscriber socket
+        self.sub = self.ctx.socket(zmq.SUB)
+        self.sub.setsockopt(zmq.SUBSCRIBE, self.topic_bytes)
+        self.sub.connect(endpoint)
+
+        self.decoder = msgspec.msgpack.Decoder(type=KVEventBatch)
+
+    def get_new_cpu_stored_events(self) -> list[BlockStored]:
+        cpu_stored_events: list[BlockStored] = []
+
+        poller = zmq.Poller()
+        poller.register(self.sub, zmq.POLLIN)
+
+        timeout = 1000  # 1 second
+        while True:
+            events = dict(poller.poll(timeout))
+
+            if events.get(self.sub) != zmq.POLLIN:
+                return cpu_stored_events
+
+            topic_bytes, _, payload = self.sub.recv_multipart()
+
+            assert topic_bytes == self.topic_bytes
+
+            event_batch = self.decoder.decode(payload)
+            assert isinstance(event_batch, KVEventBatch)
+            for event in event_batch.events:
+                if isinstance(event, BlockStored) and event.medium == "CPU":
+                    cpu_stored_events.append(event)
+            timeout = 100
+
+    def close(self):
+        """Clean up resources"""
+        self.sub.close()
+
+
 @pytest.mark.parametrize("cpu_block_size", CPU_BLOCK_SIZES)
 def test_cpu_offloading(cpu_block_size: int) -> None:
     """
@@ -20,41 +73,80 @@ def test_cpu_offloading(cpu_block_size: int) -> None:
     kv_transfer_config = KVTransferConfig(
         kv_connector="OffloadingConnector",
         kv_role="kv_both",
-        kv_connector_extra_config={"num_cpu_blocks": 100, "block_size": cpu_block_size},
+        kv_connector_extra_config={
+            "num_cpu_blocks": 1000,
+            "block_size": cpu_block_size,
+        },
     )
 
+    port: int
+    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
+        s.bind(("0.0.0.0", 0))
+        port = s.getsockname()[1]
+
+    events_endpoint = f"tcp://*:{port}"
+    kv_events_config = KVEventsConfig(
+        enable_kv_cache_events=True,
+        publisher="zmq",
+        endpoint=events_endpoint,
+        topic="test",
+    )
+
     llm = LLM(
         model="meta-llama/Llama-3.2-1B-Instruct",
         gpu_memory_utilization=0.5,
+        kv_events_config=kv_events_config,
         kv_transfer_config=kv_transfer_config,
         disable_hybrid_kv_cache_manager=True,
     )
 
-    prompts = ["Hi " * 100]
-    sampling_params = SamplingParams(temperature=0, max_tokens=20)
+    sampling_params = SamplingParams(temperature=0, max_tokens=1)
 
-    # run generation - this should trigger saving KV cache
-    start_time = time.time()
-    llm.generate(prompts, sampling_params, use_tqdm=False)
-    cold_time = time.time() - start_time
+    events_endpoint = events_endpoint.replace("*", "127.0.0.1")
+    subscriber = MockSubscriber(events_endpoint, topic=kv_events_config.topic)
 
-    # run generation again - should hit the GPU prefix cache
-    start_time = time.time()
-    llm.generate(prompts, sampling_params, use_tqdm=False)
-    gpu_hit_time = time.time() - start_time
+    try:
+        num_times_cpu_better_than_cold = 0
+        num_tests = 10
+        total_cold_time = 0.0
+        total_gpu_hit_time = 0.0
+        total_cpu_hit_time = 0.0
+        prompt_token_ids = [0] * 10001
+        for i in tqdm(range(num_tests), desc="Running tests"):
+            prompt_token_ids[0] = i
+            prompts = [TokensPrompt(prompt_token_ids=prompt_token_ids)]
 
-    # reset prefix cache to avoid GPU hit.
-    llm.reset_prefix_cache()
+            # run generation - this should trigger saving KV cache
+            start_time = time.time()
+            llm.generate(prompts, sampling_params, use_tqdm=False)
+            cold_time = time.time() - start_time
+            total_cold_time += cold_time
+
+            # sleep for a sec to make sure CPU finished storing
+            time.sleep(1)
+
+            # run generation again - should hit the GPU prefix cache
+            start_time = time.time()
+            llm.generate(prompts, sampling_params, use_tqdm=False)
+            gpu_hit_time = time.time() - start_time
+            total_gpu_hit_time += gpu_hit_time
 
-    # run generation again - this should trigger loading from CPU
-    start_time = time.time()
-    llm.generate(prompts, sampling_params, use_tqdm=False)
-    cpu_hit_time = time.time() - start_time
+            # reset prefix cache to avoid GPU hit.
+            llm.reset_prefix_cache()
 
-    print("Generation times:")
-    print(f"  Cold: {cold_time * 1000:.2f}ms")
-    print(f"  GPU hit: {gpu_hit_time * 1000:.2f}ms")
-    print(f"  CPU hit: {cpu_hit_time * 1000:.2f}ms")
+            assert subscriber.get_new_cpu_stored_events()
+
+            # run generation again - this should trigger loading from CPU
+            start_time = time.time()
+            llm.generate(prompts, sampling_params, use_tqdm=False)
+            cpu_hit_time = time.time() - start_time
+            total_cpu_hit_time += cpu_hit_time
+
+            if cpu_hit_time < cold_time:
+                num_times_cpu_better_than_cold += 1
+
+        print("Average times:")
+        print(f"  Cold: {total_cold_time * 1000 / num_tests:.2f}ms")
+        print(f"  GPU hit: {total_gpu_hit_time * 1000 / num_tests:.2f}ms")
+        print(f"  CPU hit: {total_cpu_hit_time * 1000 / num_tests:.2f}ms")
+
+        assert num_times_cpu_better_than_cold >= 0.8 * num_tests
+    finally:
+        subscriber.close()
+        del llm
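The test drives the new behavior end to end: each iteration saves blocks to CPU during the cold run, then asserts that BlockStored events with medium="CPU" actually reached the subscriber before attempting the CPU-hit run. The subscriber side boils down to a plain ZMQ SUB socket plus a msgspec decoder. A minimal standalone sketch follows; the endpoint and topic values here are placeholders for illustration, not values from the diff:

import msgspec.msgpack
import zmq

from vllm.distributed.kv_events import BlockStored, KVEventBatch

ctx = zmq.Context.instance()
sub = ctx.socket(zmq.SUB)
sub.setsockopt(zmq.SUBSCRIBE, b"test")  # must match KVEventsConfig.topic
sub.connect("tcp://127.0.0.1:5557")  # placeholder port; the test binds a free one

decoder = msgspec.msgpack.Decoder(type=KVEventBatch)

# Each published message is a multipart frame: the topic, a middle frame
# the test ignores, and the msgpack-encoded event batch.
topic, _, payload = sub.recv_multipart()
batch = decoder.decode(payload)
cpu_events = [
    e for e in batch.events if isinstance(e, BlockStored) and e.medium == "CPU"
]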
@@ -646,23 +646,6 @@ class Scheduler(SchedulerInterface):
             meta = self.connector.build_connector_meta(scheduler_output)
             scheduler_output.kv_connector_metadata = meta
 
-        # collect KV cache events from KV cache manager
-        events = self.kv_cache_manager.take_events()
-
-        # collect KV cache events from connector
-        if self.connector is not None:
-            connector_events = self.connector.take_events()
-            if connector_events:
-                if events is None:
-                    events = list(connector_events)
-                else:
-                    events.extend(connector_events)
-
-        # publish collected KV cache events
-        if events:
-            batch = KVEventBatch(ts=time.time(), events=events)
-            self.kv_event_publisher.publish(batch)
-
         self._update_after_schedule(scheduler_output)
         return scheduler_output
 
@@ -1057,6 +1040,23 @@ class Scheduler(SchedulerInterface):
         if kv_connector_output:
             self._update_from_kv_xfer_finished(kv_connector_output)
 
+        # collect KV cache events from KV cache manager
+        events = self.kv_cache_manager.take_events()
+
+        # collect KV cache events from connector
+        if self.connector is not None:
+            connector_events = self.connector.take_events()
+            if connector_events:
+                if events is None:
+                    events = list(connector_events)
+                else:
+                    events.extend(connector_events)
+
+        # publish collected KV cache events
+        if events:
+            batch = KVEventBatch(ts=time.time(), events=events)
+            self.kv_event_publisher.publish(batch)
+
         # Create EngineCoreOutputs for all clients that have requests with
         # outputs in this step.
         engine_core_outputs = {
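The net effect of the move: under the old placement, events the connector produced during output processing sat in its queue until the next schedule() call published them; under the new placement they go out in the same step. A toy simulation of that ordering difference, using a stub class rather than real vLLM code:

class StubConnector:
    """Stands in for a KV connector that queues events as saves finish."""

    def __init__(self):
        self._events: list[str] = []

    def on_save_finished(self) -> None:
        self._events.append("BlockStored(medium='CPU')")

    def take_events(self) -> list[str]:
        events, self._events = self._events, []
        return events


connector = StubConnector()

# Old order: publish at the end of schedule(), before output is processed.
published = connector.take_events()  # [] - the save has not finished yet
connector.on_save_finished()  # happens while processing model output
assert published == []

# New order: publish in update_from_output(), after output is processed.
published = connector.take_events()  # now includes the CPU BlockStored event
assert published == ["BlockStored(medium='CPU')"]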