[Metrics] Log multi-modal cache stats and fix reset (#26285)

Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
Cyrus Leung 2025-10-10 16:45:55 +08:00 committed by GitHub
parent 6f0f570c43
commit ad430a67ca
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
25 changed files with 586 additions and 235 deletions

View File

@ -0,0 +1,74 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import pytest
from vllm import LLM
from vllm.entrypoints.chat_utils import ChatCompletionMessageParam
from vllm.v1.metrics.reader import Counter, Metric
from ..openai.test_vision import TEST_IMAGE_ASSETS
def _make_messages(image_url: str) -> list[ChatCompletionMessageParam]:
return [
{
"role": "user",
"content": [
{
"type": "image_url",
"image_url": {"url": image_url},
},
],
}
]
def _get_counter_value(metrics: list[Metric], name: str):
metric = next(m for m in metrics if m.name == name)
assert isinstance(metric, Counter)
return metric.value
def _get_mm_cache_stats(metrics: list[Metric]):
mm_cache_queries = _get_counter_value(metrics, "vllm:mm_cache_queries")
mm_cache_hits = _get_counter_value(metrics, "vllm:mm_cache_hits")
return mm_cache_queries, mm_cache_hits
@pytest.mark.parametrize("image_urls", [TEST_IMAGE_ASSETS[:2]], indirect=True)
@pytest.mark.parametrize("mm_processor_cache_type", ["lru", "shm"])
def test_mm_cache_stats(
num_gpus_available,
image_urls,
mm_processor_cache_type,
):
llm = LLM(
model="llava-hf/llava-1.5-7b-hf",
max_model_len=4096,
max_num_seqs=5,
enforce_eager=True,
mm_processor_cache_type=mm_processor_cache_type,
disable_log_stats=False,
limit_mm_per_prompt={"image": 2},
)
llm.chat(_make_messages(image_urls[0]))
assert _get_mm_cache_stats(llm.get_metrics()) == (1, 0)
llm.chat(_make_messages(image_urls[1]))
assert _get_mm_cache_stats(llm.get_metrics()) == (2, 0)
llm.chat(_make_messages(image_urls[0]))
assert _get_mm_cache_stats(llm.get_metrics()) == (3, 1)
# NOTE: This only resets hit rate stats in CachingMetrics
# The raw queries and hits counts remain unaffected
llm.reset_mm_cache()
llm.chat(_make_messages(image_urls[0]))
assert _get_mm_cache_stats(llm.get_metrics()) == (4, 1)
llm.chat(_make_messages(image_urls[1]))
assert _get_mm_cache_stats(llm.get_metrics()) == (5, 1)

View File

@ -18,10 +18,18 @@ from vllm import version
from ...utils import RemoteOpenAIServer
MODEL_NAME = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
MODELS = {
"text": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
"multimodal": "HuggingFaceTB/SmolVLM-256M-Instruct",
}
PREV_MINOR_VERSION = version._prev_minor_version()
@pytest.fixture(scope="module", params=list(MODELS.keys()))
def model_key(request):
yield request.param
@pytest.fixture(scope="module")
def default_server_args():
return [
@ -45,11 +53,12 @@ def default_server_args():
f"--show-hidden-metrics-for-version={PREV_MINOR_VERSION}",
],
)
def server(default_server_args, request):
def server(model_key, default_server_args, request):
if request.param:
default_server_args.append(request.param)
with RemoteOpenAIServer(MODEL_NAME, default_server_args) as remote_server:
model_name = MODELS[model_key]
with RemoteOpenAIServer(model_name, default_server_args) as remote_server:
yield remote_server
@ -60,64 +69,70 @@ async def client(server):
_PROMPT = "Hello my name is Robert and I love magic"
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
_TOKENIZED_PROMPT = tokenizer(_PROMPT)["input_ids"]
_IMAGE_URL = "https://upload.wikimedia.org/wikipedia/commons/thumb/d/dd/Gfp-wisconsin-madison-the-nature-boardwalk.jpg/2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg"
_NUM_REQUESTS = 10
_NUM_PROMPT_TOKENS_PER_REQUEST = len(_TOKENIZED_PROMPT)
_NUM_GENERATION_TOKENS_PER_REQUEST = 10
# {metric_family: [(suffix, expected_value)]}
EXPECTED_VALUES = {
"vllm:time_to_first_token_seconds": [("_count", _NUM_REQUESTS)],
"vllm:time_per_output_token_seconds": [
("_count", _NUM_REQUESTS * (_NUM_GENERATION_TOKENS_PER_REQUEST - 1))
],
"vllm:e2e_request_latency_seconds": [("_count", _NUM_REQUESTS)],
"vllm:request_queue_time_seconds": [("_count", _NUM_REQUESTS)],
"vllm:request_inference_time_seconds": [("_count", _NUM_REQUESTS)],
"vllm:request_prefill_time_seconds": [("_count", _NUM_REQUESTS)],
"vllm:request_decode_time_seconds": [("_count", _NUM_REQUESTS)],
"vllm:request_prompt_tokens": [
("_sum", _NUM_REQUESTS * _NUM_PROMPT_TOKENS_PER_REQUEST),
("_count", _NUM_REQUESTS),
],
"vllm:request_generation_tokens": [
("_sum", _NUM_REQUESTS * _NUM_GENERATION_TOKENS_PER_REQUEST),
("_count", _NUM_REQUESTS),
],
"vllm:request_params_n": [("_count", _NUM_REQUESTS)],
"vllm:request_params_max_tokens": [
("_sum", _NUM_REQUESTS * _NUM_GENERATION_TOKENS_PER_REQUEST),
("_count", _NUM_REQUESTS),
],
"vllm:iteration_tokens_total": [
(
"_sum",
_NUM_REQUESTS
* (_NUM_PROMPT_TOKENS_PER_REQUEST + _NUM_GENERATION_TOKENS_PER_REQUEST),
),
("_count", _NUM_REQUESTS * _NUM_GENERATION_TOKENS_PER_REQUEST),
],
"vllm:prompt_tokens": [("_total", _NUM_REQUESTS * _NUM_PROMPT_TOKENS_PER_REQUEST)],
"vllm:generation_tokens": [
("_total", _NUM_REQUESTS * _NUM_PROMPT_TOKENS_PER_REQUEST)
],
"vllm:request_success": [("_total", _NUM_REQUESTS)],
}
def _get_expected_values(num_requests: int, prompt_ids: list[int], max_tokens: int):
num_prompt_tokens = len(prompt_ids)
# {metric_family: [(suffix, expected_value)]}
return {
"vllm:time_to_first_token_seconds": [("_count", num_requests)],
"vllm:time_per_output_token_seconds": [
("_count", num_requests * (max_tokens - 1))
],
"vllm:e2e_request_latency_seconds": [("_count", num_requests)],
"vllm:request_queue_time_seconds": [("_count", num_requests)],
"vllm:request_inference_time_seconds": [("_count", num_requests)],
"vllm:request_prefill_time_seconds": [("_count", num_requests)],
"vllm:request_decode_time_seconds": [("_count", num_requests)],
"vllm:request_prompt_tokens": [
("_sum", num_requests * num_prompt_tokens),
("_count", num_requests),
],
"vllm:request_generation_tokens": [
("_sum", num_requests * max_tokens),
("_count", num_requests),
],
"vllm:request_params_n": [("_count", num_requests)],
"vllm:request_params_max_tokens": [
("_sum", num_requests * max_tokens),
("_count", num_requests),
],
"vllm:iteration_tokens_total": [
(
"_sum",
num_requests * (num_prompt_tokens + max_tokens),
),
("_count", num_requests * max_tokens),
],
"vllm:prompt_tokens": [("_total", num_requests * num_prompt_tokens)],
"vllm:generation_tokens": [("_total", num_requests * max_tokens)],
"vllm:request_success": [("_total", num_requests)],
}
@pytest.mark.asyncio
async def test_metrics_counts(
server: RemoteOpenAIServer,
client: openai.AsyncClient,
model_key: str,
):
for _ in range(_NUM_REQUESTS):
if model_key == "multimodal":
pytest.skip("Unnecessary test")
model_name = MODELS[model_key]
tokenizer = AutoTokenizer.from_pretrained(model_name)
prompt_ids = tokenizer.encode(_PROMPT)
num_requests = 10
max_tokens = 10
for _ in range(num_requests):
# sending a request triggers the metrics to be logged.
await client.completions.create(
model=MODEL_NAME,
prompt=_TOKENIZED_PROMPT,
max_tokens=_NUM_GENERATION_TOKENS_PER_REQUEST,
model=model_name,
prompt=prompt_ids,
max_tokens=max_tokens,
)
response = requests.get(server.url_for("metrics"))
@ -125,8 +140,9 @@ async def test_metrics_counts(
assert response.status_code == HTTPStatus.OK
# Loop over all expected metric_families
for metric_family, suffix_values_list in EXPECTED_VALUES.items():
if (metric_family not in EXPECTED_METRICS_V1) or (
expected_values = _get_expected_values(num_requests, prompt_ids, max_tokens)
for metric_family, suffix_values_list in expected_values.items():
if metric_family not in EXPECTED_METRICS_V1 or (
not server.show_hidden_metrics
and metric_family in HIDDEN_DEPRECATED_METRICS
):
@ -217,6 +233,11 @@ EXPECTED_METRICS_V1 = [
"vllm:request_decode_time_seconds_count",
]
EXPECTED_METRICS_MM = [
"vllm:mm_cache_queries",
"vllm:mm_cache_hits",
]
HIDDEN_DEPRECATED_METRICS: list[str] = [
"vllm:gpu_cache_usage_perc",
"vllm:gpu_prefix_cache_queries",
@ -231,19 +252,43 @@ HIDDEN_DEPRECATED_METRICS: list[str] = [
async def test_metrics_exist(
server: RemoteOpenAIServer,
client: openai.AsyncClient,
model_key: str,
):
model_name = MODELS[model_key]
# sending a request triggers the metrics to be logged.
await client.completions.create(
model=MODEL_NAME,
prompt="Hello, my name is",
max_tokens=5,
temperature=0.0,
)
if model_key == "text":
await client.completions.create(
model=model_name,
prompt="Hello, my name is",
max_tokens=5,
temperature=0.0,
)
else:
await client.chat.completions.create(
model=model_name,
messages=[
{
"role": "user",
"content": [
{"type": "image_url", "image_url": {"url": _IMAGE_URL}},
{"type": "text", "text": "What's in this image?"},
],
}
],
max_tokens=5,
temperature=0.0,
)
response = requests.get(server.url_for("metrics"))
assert response.status_code == HTTPStatus.OK
for metric in EXPECTED_METRICS_V1:
expected_metrics = EXPECTED_METRICS_V1
if model_key == "multimodal":
# NOTE: Don't use in-place assignment
expected_metrics = expected_metrics + EXPECTED_METRICS_MM
for metric in expected_metrics:
if metric in HIDDEN_DEPRECATED_METRICS and not server.show_hidden_metrics:
continue
assert metric in response.text
@ -253,9 +298,14 @@ async def test_metrics_exist(
async def test_abort_metrics_reset(
server: RemoteOpenAIServer,
client: openai.AsyncClient,
model_key: str,
):
model_name = MODELS[model_key]
tokenizer = AutoTokenizer.from_pretrained(model_name)
prompt_ids = tokenizer.encode(_PROMPT)
running_requests, waiting_requests, kv_cache_usage = _get_running_metrics_from_api(
server
server,
)
# Expect no running requests or kvcache usage
@ -268,8 +318,8 @@ async def test_abort_metrics_reset(
for _ in range(3):
task = asyncio.create_task(
client.completions.create(
model=MODEL_NAME,
prompt=_TOKENIZED_PROMPT,
model=model_name,
prompt=prompt_ids,
max_tokens=100, # Long generation to give time to abort
temperature=0.0,
)
@ -281,7 +331,7 @@ async def test_abort_metrics_reset(
# Check that we have running requests
running_requests, waiting_requests, kv_cache_usage = _get_running_metrics_from_api(
server
server,
)
# Expect running requests and kvcache usage

View File

@ -20,7 +20,6 @@ from vllm.v1.core.kv_cache_utils import (
BlockHash,
FreeKVCacheBlockQueue,
KVCacheBlock,
PrefixCachingMetrics,
estimate_max_model_len,
generate_block_hash_extra_keys,
generate_scheduler_kv_cache_config,
@ -42,7 +41,7 @@ from vllm.v1.kv_cache_interface import (
SlidingWindowSpec,
UniformTypeKVCacheSpecs,
)
from vllm.v1.metrics.stats import PrefixCacheStats
from vllm.v1.metrics.stats import CachingMetrics, PrefixCacheStats
from vllm.v1.request import Request
pytestmark = pytest.mark.cpu_test
@ -536,7 +535,7 @@ def test_metrics():
"""
Test the prefix caching metrics.
"""
metrics = PrefixCachingMetrics(max_recent_requests=5)
metrics = CachingMetrics(max_recent_requests=5)
assert metrics.hit_rate == 0.0
metrics.observe(_stats(1, 20, 9))
@ -568,7 +567,7 @@ def test_metrics_empty_stats():
"""
Test the prefix caching metrics with empty stats.
"""
metrics = PrefixCachingMetrics(max_recent_requests=5)
metrics = CachingMetrics(max_recent_requests=5)
metrics.observe(_stats(0, 0, 0))
metrics.observe(_stats(1, 20, 9))
metrics.observe(_stats(0, 0, 0))

View File

@ -17,7 +17,7 @@ from vllm.sampling_params import RequestOutputKind
from vllm.v1.engine.async_llm import AsyncLLM
from vllm.v1.engine.core_client import DPAsyncMPClient
from vllm.v1.metrics.loggers import StatLoggerBase
from vllm.v1.metrics.stats import IterationStats, SchedulerStats
from vllm.v1.metrics.stats import IterationStats, MultiModalCacheStats, SchedulerStats
DP_SIZE = int(os.getenv("DP_SIZE", 2))
@ -93,6 +93,7 @@ async def test_load(
self,
scheduler_stats: Optional[SchedulerStats],
iteration_stats: Optional[IterationStats],
mm_cache_stats: Optional[MultiModalCacheStats] = None,
engine_idx: int = 0,
):
if iteration_stats:

View File

@ -354,6 +354,10 @@ class LLM:
else:
self.llm_engine.tokenizer = get_cached_tokenizer(tokenizer)
def reset_mm_cache(self) -> None:
self.processor.clear_mm_cache()
self.llm_engine.reset_mm_cache()
def get_default_sampling_params(self) -> SamplingParams:
if self.default_sampling_params is None:
self.default_sampling_params = self.model_config.get_diff_sampling_param()

View File

@ -274,6 +274,10 @@ class OpenAIServing:
self.model_config = self.models.model_config
self.max_model_len = self.model_config.max_model_len
async def reset_mm_cache(self) -> None:
self.processor.clear_mm_cache()
await self.engine_client.reset_mm_cache()
async def beam_search(
self,
prompt: PromptType,

View File

@ -169,6 +169,10 @@ class ExecutorBase(ABC):
assert s == sets[0], "All workers should have the same LORAs."
return sets[0]
def reset_mm_cache(self) -> None:
"""Reset the multi-modal cache in each worker."""
self.collective_rpc("reset_mm_cache")
def start_profile(self) -> None:
self.collective_rpc("start_profile")

View File

@ -12,11 +12,8 @@ import torch.distributed as dist
import vllm.envs as envs
from vllm.executor.executor_base import ExecutorBase
from vllm.logger import init_logger
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.cache import worker_receiver_cache_from_config
from vllm.utils import get_distributed_init_method, get_ip, get_open_port, run_method
from vllm.v1.engine import ReconfigureDistributedRequest, ReconfigureRankType
from vllm.v1.executor.utils import get_and_update_mm_cache
from vllm.v1.outputs import AsyncModelRunnerOutput
from vllm.v1.worker.worker_base import WorkerWrapperBase
@ -30,16 +27,13 @@ class UniProcExecutor(ExecutorBase):
"""Initialize the worker and load the model."""
self.driver_worker = WorkerWrapperBase(vllm_config=self.vllm_config, rpc_rank=0)
distributed_init_method, rank, local_rank = self._distributed_args()
is_driver_worker = True
kwargs = dict(
vllm_config=self.vllm_config,
local_rank=local_rank,
rank=rank,
distributed_init_method=distributed_init_method,
is_driver_worker=is_driver_worker,
)
self.mm_receiver_cache = worker_receiver_cache_from_config(
self.vllm_config, MULTIMODAL_REGISTRY, Lock()
is_driver_worker=True,
shared_worker_lock=Lock(),
)
self.async_output_thread: Optional[ThreadPoolExecutor] = None
@ -74,8 +68,6 @@ class UniProcExecutor(ExecutorBase):
) -> list[Any]:
if kwargs is None:
kwargs = {}
if self.mm_receiver_cache is not None and method == "execute_model":
get_and_update_mm_cache(self.mm_receiver_cache, args)
if not non_block:
return [run_method(self.driver_worker, method, args, kwargs)]

View File

@ -19,6 +19,7 @@ from vllm.multimodal.inputs import (
from vllm.multimodal.processing import BaseMultiModalProcessor
from vllm.transformers_utils.tokenizer import AnyTokenizer
from vllm.utils.jsontree import json_iter_leaves
from vllm.v1.metrics.stats import MultiModalCacheStats
from .data import (
DecoderOnlyInputs,
@ -56,6 +57,8 @@ class InputPreprocessor:
self.mm_registry = mm_registry
self.mm_processor_cache = mm_processor_cache
self.mm_cache_stats = MultiModalCacheStats() if mm_processor_cache else None
def get_tokenizer(self) -> AnyTokenizer:
if self.tokenizer is None:
raise ValueError(
@ -664,14 +667,13 @@ class InputPreprocessor:
return self._build_decoder_only_llm_inputs(prompt_comps)
def preprocess(
def _preprocess(
self,
prompt: PromptType,
tokenization_kwargs: Optional[dict[str, Any]] = None,
*,
mm_uuids: Optional[MultiModalUUIDDict] = None,
) -> ProcessorInputs:
"""Preprocess the input prompt."""
if self.model_config.is_encoder_decoder:
# Encoder-decoder model requires special mapping of
# input prompts to encoder & decoder.
@ -694,6 +696,40 @@ class InputPreprocessor:
mm_uuids=mm_uuids,
)
def clear_cache(self) -> None:
def preprocess(
self,
prompt: PromptType,
tokenization_kwargs: Optional[dict[str, Any]] = None,
*,
mm_uuids: Optional[MultiModalUUIDDict] = None,
) -> ProcessorInputs:
"""Preprocess the input prompt."""
res = self._preprocess(
prompt,
tokenization_kwargs,
mm_uuids=mm_uuids,
)
if self.mm_processor_cache and self.mm_cache_stats is not None:
delta = self.mm_processor_cache.make_stats(delta=True)
self.mm_cache_stats.requests += 1
self.mm_cache_stats.queries += delta.total
self.mm_cache_stats.hits += delta.hits
return res
def stat_mm_cache(self) -> Optional[MultiModalCacheStats]:
mm_cache_stats = self.mm_cache_stats
if mm_cache_stats is None:
return None
self.mm_cache_stats = MultiModalCacheStats()
return mm_cache_stats
def clear_mm_cache(self) -> None:
if self.mm_processor_cache is not None:
self.mm_processor_cache.clear_cache()
if self.mm_cache_stats is not None:
self.mm_cache_stats.reset = True

View File

@ -18,7 +18,7 @@ from vllm.distributed.device_communicators.shm_object_storage import (
from vllm.envs import VLLM_OBJECT_STORAGE_SHM_BUFFER_NAME
from vllm.logger import init_logger
from vllm.utils import GiB_bytes, MiB_bytes
from vllm.utils.cache import LRUCache
from vllm.utils.cache import CacheInfo, LRUCache
from vllm.utils.jsontree import json_count_leaves, json_map_leaves, json_reduce_leaves
from .inputs import (
@ -302,6 +302,16 @@ class BaseMultiModalProcessorCache(
"""
return [self.is_cached_item(mm_hash) for mm_hash in mm_hashes]
@abstractmethod
def make_stats(self, *, delta: bool = False) -> CacheInfo:
"""
Get (and reset) the multi-modal cache stats.
Returns:
The current multi-modal caching stats.
"""
raise NotImplementedError
class MultiModalProcessorOnlyCache(BaseMultiModalProcessorCache):
"""
@ -347,6 +357,10 @@ class MultiModalProcessorOnlyCache(BaseMultiModalProcessorCache):
def clear_cache(self) -> None:
self._cache.clear()
@override
def make_stats(self, *, delta: bool = False) -> CacheInfo:
return self._cache.stat(delta=delta)
class MultiModalProcessorSenderCache(BaseMultiModalProcessorCache):
"""
@ -397,6 +411,10 @@ class MultiModalProcessorSenderCache(BaseMultiModalProcessorCache):
def clear_cache(self) -> None:
self._cache.clear()
@override
def make_stats(self, *, delta: bool = False) -> CacheInfo:
return self._cache.stat(delta=delta)
class ShmObjectStoreSenderCache(BaseMultiModalProcessorCache):
"""
@ -430,6 +448,20 @@ class ShmObjectStoreSenderCache(BaseMultiModalProcessorCache):
# cache (prompt_updates, modality) for P0 only
self._p0_cache: dict[str, tuple[Sequence[ResolvedPromptUpdate], str]] = {}
self._hits = 0
self._total = 0
self._last_info = CacheInfo(hits=0, total=0)
def _stat(self, *, delta: bool = False) -> CacheInfo:
info = CacheInfo(hits=self._hits, total=self._total)
if delta:
info_delta = info - self._last_info
self._last_info = info
info = info_delta
return info
@override
def is_cached_item(self, mm_hash: str) -> bool:
return self._shm_cache.is_cached(mm_hash)
@ -441,12 +473,17 @@ class ShmObjectStoreSenderCache(BaseMultiModalProcessorCache):
mm_hash: str,
) -> MultiModalProcessorCacheOutItem:
if self._shm_cache.is_cached(mm_hash):
self._hits += 1
self._total += 1
address, monotonic_id = self._shm_cache.get_cached(mm_hash)
prompt_updates, modality = self._p0_cache[mm_hash]
return self.address_as_item(address, monotonic_id, modality), prompt_updates
assert mm_item is not None, f"Expected a cached item for {mm_hash=}"
self._total += 1
try:
address, monotonic_id = self._shm_cache.put(mm_hash, mm_item[0])
# Try to remove dangling items if p0 cache is too large.
@ -469,6 +506,14 @@ class ShmObjectStoreSenderCache(BaseMultiModalProcessorCache):
self._shm_cache.clear()
self._p0_cache.clear()
self._hits = 0
self._total = 0
self._last_info = CacheInfo(hits=0, total=0)
@override
def make_stats(self, *, delta: bool = False) -> CacheInfo:
return self._stat(delta=delta)
def remove_dangling_items(self) -> None:
"""Remove items that are no longer in the shared memory cache."""
cached_hashes = self._shm_cache.key_index.keys()

View File

@ -4,7 +4,7 @@
import copy
import os
from collections import defaultdict, deque
from collections import defaultdict
from collections.abc import Iterable, Sequence
from dataclasses import dataclass
from typing import Any, Callable, NewType, Optional, Union
@ -23,7 +23,6 @@ from vllm.v1.kv_cache_interface import (
SlidingWindowSpec,
UniformTypeKVCacheSpecs,
)
from vllm.v1.metrics.stats import PrefixCacheStats
from vllm.v1.request import Request
# BlockHash represents the hash of a single KV-cache block used for
@ -101,78 +100,6 @@ def init_none_hash(hash_fn: Callable[[Any], bytes]):
NONE_HASH = BlockHash(hash_fn(hash_seed))
class PrefixCachingMetrics:
"""Metrics for prefix caching with a hit rate of the max recent N requests.
Args:
max_recent_requests: The number of the max recent requests to aggregate.
Defaults to 1000.
"""
def __init__(self, max_recent_requests: int = 1000):
self.max_recent_requests = max_recent_requests
# The current aggregated values.
self.aggregated_requests = 0
self.aggregated_query_total = 0
self.aggregated_query_hit = 0
# A deque of (requests, queries, hits) for the most recent requests.
self.query_queue: deque[tuple[int, int, int]] = deque()
def observe(self, stats: PrefixCacheStats):
"""Observe the prefix caching for a set of requests.
This function is called with information gathered when new requests
are being scheduled and are looking for computed blocks.
When there are more than `max_recent_requests` requests, the oldest set
of requests are removed from the metrics.
Args:
stats: The prefix cache stats.
"""
# reset_prefix_cache was invoked before the current update.
# Reset the metrics before aggregating the current stats.
if stats.reset:
self.reset()
# DO NOT appending empty stats to avoid helpful info get kicked out
# due to sliding window.
if stats.requests == 0:
return
# Update the metrics.
self.query_queue.append((stats.requests, stats.queries, stats.hits))
self.aggregated_requests += stats.requests
self.aggregated_query_total += stats.queries
self.aggregated_query_hit += stats.hits
# Remove the oldest stats until number of requests does not exceed
# the limit.
# NOTE: We preserve the latest added stats regardless.
while (
len(self.query_queue) > 1
and self.aggregated_requests > self.max_recent_requests
):
old_requests, old_queries, old_hits = self.query_queue.popleft()
self.aggregated_requests -= old_requests
self.aggregated_query_total -= old_queries
self.aggregated_query_hit -= old_hits
def reset(self):
"""Reset the metrics."""
self.aggregated_requests = 0
self.aggregated_query_total = 0
self.aggregated_query_hit = 0
self.query_queue.clear()
@property
def hit_rate(self) -> float:
"""Calculate the hit rate for the past N requests."""
if self.aggregated_query_total == 0:
return 0.0
return self.aggregated_query_hit / self.aggregated_query_total
@dataclass
class KVCacheBlock:
"""KV-cache block metadata."""

View File

@ -463,6 +463,7 @@ class AsyncLLM(EngineClient):
output_processor = self.output_processor
log_stats = self.log_stats
logger_manager = self.logger_manager
processor = self.processor
async def output_handler():
try:
@ -511,6 +512,7 @@ class AsyncLLM(EngineClient):
engine_idx=outputs.engine_index,
scheduler_stats=outputs.scheduler_stats,
iteration_stats=iteration_stats,
mm_cache_stats=processor.stat_mm_cache(),
)
except Exception as e:
logger.exception("AsyncLLM output_handler failed.")
@ -660,7 +662,7 @@ class AsyncLLM(EngineClient):
await asyncio.gather(*coros)
async def reset_mm_cache(self) -> None:
self.processor.clear_cache()
self.processor.clear_mm_cache()
await self.engine_core.reset_mm_cache_async()
async def reset_prefix_cache(self, device: Optional[Device] = None) -> None:

View File

@ -319,7 +319,7 @@ class EngineCore:
)
engine_core_outputs = self.scheduler.update_from_output(
scheduler_output, model_output
) # type: ignore
)
return (engine_core_outputs, scheduler_output.total_num_scheduled_tokens > 0)
@ -400,16 +400,19 @@ class EngineCore:
def reset_mm_cache(self):
# NOTE: Since this is mainly for debugging, we don't attempt to
# re-sync the internal caches (P0 processor, P0 mirror, P1 mirror)
# re-sync the internal caches (P0 sender, P1 receiver)
if self.scheduler.has_unfinished_requests():
logger.warning(
"Resetting the multi-modal cache when requests are "
"in progress may lead to desynced internal caches."
)
# The cache either exists in EngineCore or WorkerWrapperBase
if self.mm_receiver_cache is not None:
self.mm_receiver_cache.clear_cache()
self.model_executor.reset_mm_cache()
def reset_prefix_cache(self):
self.scheduler.reset_prefix_cache()

View File

@ -306,9 +306,11 @@ class LLMEngine:
# 4) Record stats
if self.logger_manager is not None:
assert outputs.scheduler_stats is not None
self.logger_manager.record(
scheduler_stats=outputs.scheduler_stats,
iteration_stats=iteration_stats,
mm_cache_stats=self.processor.stat_mm_cache(),
)
self.do_log_stats_with_interval()
@ -321,7 +323,7 @@ class LLMEngine:
self.engine_core.profile(False)
def reset_mm_cache(self):
self.processor.clear_cache()
self.processor.clear_mm_cache()
self.engine_core.reset_mm_cache()
def reset_prefix_cache(self, device: Optional[Device] = None):

View File

@ -21,6 +21,7 @@ from vllm.sampling_params import SamplingParams
from vllm.transformers_utils.tokenizer import AnyTokenizer
from vllm.utils import length_from_prompt_token_ids_or_embeds
from vllm.v1.engine import EngineCoreRequest
from vllm.v1.metrics.stats import MultiModalCacheStats
from vllm.v1.structured_output.backend_guidance import validate_guidance_grammar
from vllm.v1.structured_output.backend_lm_format_enforcer import (
validate_structured_output_request_lm_format_enforcer,
@ -573,5 +574,8 @@ class Processor:
# check that chunked prefill does not truncate them
# max_batch_len = self.scheduler_config.max_num_batched_tokens
def clear_cache(self) -> None:
self.input_preprocessor.clear_cache()
def stat_mm_cache(self) -> Optional[MultiModalCacheStats]:
return self.input_preprocessor.stat_mm_cache()
def clear_mm_cache(self) -> None:
self.input_preprocessor.clear_mm_cache()

View File

@ -33,8 +33,6 @@ from vllm.distributed.parallel_state import (
get_tp_group,
)
from vllm.logger import init_logger
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.cache import worker_receiver_cache_from_config
from vllm.utils import (
_maybe_force_spawn,
decorate_logs,
@ -46,7 +44,6 @@ from vllm.utils import (
)
from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.executor.abstract import Executor, FailureCallback
from vllm.v1.executor.utils import get_and_update_mm_cache
from vllm.v1.outputs import AsyncModelRunnerOutput, DraftTokenIds, ModelRunnerOutput
from vllm.v1.worker.worker_base import WorkerWrapperBase
@ -422,6 +419,7 @@ class WorkerProc:
"rank": rank,
"distributed_init_method": distributed_init_method,
"is_driver_worker": is_driver_worker,
"shared_worker_lock": shared_worker_lock,
}
wrapper.init_worker(all_kwargs)
self.worker = wrapper
@ -445,11 +443,6 @@ class WorkerProc:
)
self.async_output_copy_thread.start()
# Initialize multimodal receiver cache if needed
self.mm_receiver_cache = worker_receiver_cache_from_config(
vllm_config, MULTIMODAL_REGISTRY, shared_worker_lock
)
# Initialize device
self.worker.init_device()
@ -692,12 +685,7 @@ class WorkerProc:
func = getattr(self.worker, method)
elif isinstance(method, bytes):
func = partial(cloudpickle.loads(method), self.worker)
# retrieve from shm cache if available
if (
self.mm_receiver_cache is not None
and func.__name__ == "execute_model"
):
get_and_update_mm_cache(self.mm_receiver_cache, args)
output = func(*args, **kwargs)
except Exception as e:
# Notes have been introduced in python 3.11

View File

@ -1,24 +0,0 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from vllm.multimodal.cache import ShmObjectStoreReceiverCache
from vllm.v1.core.sched.output import SchedulerOutput
def get_and_update_mm_cache(
receiver_cache: ShmObjectStoreReceiverCache,
args: tuple[SchedulerOutput],
) -> None:
"""
For each MultiModalKwargsItem in SchedulerOutput, fetch from shared memory
cache as needed.
Args:
receiver_cache: The receiver cache to update.
args: According to the collective_rpc call of execute_model method in
executor, args is a tuple of only one SchedulerOutput element.
"""
scheduler_output = args[0]
for request_data in scheduler_output.scheduled_new_reqs:
request_data.mm_features = receiver_cache.get_and_update_features(
request_data.mm_features
)

View File

@ -11,10 +11,14 @@ import prometheus_client
from vllm.config import SupportsMetricsInfo, VllmConfig
from vllm.distributed.kv_transfer.kv_connector.v1.metrics import KVConnectorLogging
from vllm.logger import init_logger
from vllm.v1.core.kv_cache_utils import PrefixCachingMetrics
from vllm.v1.engine import FinishReason
from vllm.v1.metrics.prometheus import unregister_vllm_metrics
from vllm.v1.metrics.stats import IterationStats, SchedulerStats
from vllm.v1.metrics.stats import (
CachingMetrics,
IterationStats,
MultiModalCacheStats,
SchedulerStats,
)
from vllm.v1.spec_decode.metrics import SpecDecodingLogging, SpecDecodingProm
logger = init_logger(__name__)
@ -38,6 +42,7 @@ class StatLoggerBase(ABC):
self,
scheduler_stats: Optional[SchedulerStats],
iteration_stats: Optional[IterationStats],
mm_cache_stats: Optional[MultiModalCacheStats] = None,
engine_idx: int = 0,
): ...
@ -53,10 +58,15 @@ class LoggingStatLogger(StatLoggerBase):
self.engine_index = engine_index
self.vllm_config = vllm_config
self._reset(time.monotonic())
self.last_scheduler_stats = SchedulerStats()
# Prefix cache metrics. This cannot be reset.
self.last_mm_cache_stats: Optional[MultiModalCacheStats] = None
# Caching metrics. This cannot be reset.
# TODO: Make the interval configurable.
self.prefix_caching_metrics = PrefixCachingMetrics()
self.prefix_caching_metrics = CachingMetrics()
self.mm_caching_metrics = CachingMetrics()
self.spec_decoding_logging = SpecDecodingLogging()
kv_tranfer_config = self.vllm_config.kv_transfer_config
self.kv_connector_logging = KVConnectorLogging(kv_tranfer_config)
@ -86,6 +96,7 @@ class LoggingStatLogger(StatLoggerBase):
self,
scheduler_stats: Optional[SchedulerStats],
iteration_stats: Optional[IterationStats],
mm_cache_stats: Optional[MultiModalCacheStats] = None,
engine_idx: int = 0,
):
"""Log Stats to standard output."""
@ -101,6 +112,11 @@ class LoggingStatLogger(StatLoggerBase):
self.kv_connector_logging.observe(kv_connector_stats)
self.last_scheduler_stats = scheduler_stats
if mm_cache_stats:
self.mm_caching_metrics.observe(mm_cache_stats)
self.last_mm_cache_stats = mm_cache_stats
def log(self):
now = time.monotonic()
prompt_throughput = self._get_throughput(self.num_prompt_tokens, now)
@ -125,21 +141,32 @@ class LoggingStatLogger(StatLoggerBase):
self.last_prompt_throughput = prompt_throughput
# Format and print output.
log_fn(
"Engine %03d: "
"Avg prompt throughput: %.1f tokens/s, "
"Avg generation throughput: %.1f tokens/s, "
"Running: %d reqs, Waiting: %d reqs, "
"GPU KV cache usage: %.1f%%, "
log_parts = [
"Avg prompt throughput: %.1f tokens/s",
"Avg generation throughput: %.1f tokens/s",
"Running: %d reqs",
"Waiting: %d reqs",
"GPU KV cache usage: %.1f%%",
"Prefix cache hit rate: %.1f%%",
self.engine_index,
]
log_args = [
prompt_throughput,
generation_throughput,
scheduler_stats.num_running_reqs,
scheduler_stats.num_waiting_reqs,
scheduler_stats.kv_cache_usage * 100,
self.prefix_caching_metrics.hit_rate * 100,
]
if self.last_mm_cache_stats:
log_parts.append("MM cache hit rate: %.1f%%")
log_args.append(self.mm_caching_metrics.hit_rate * 100)
log_fn(
"Engine %03d: " + ", ".join(log_parts),
self.engine_index,
*log_args,
)
self.spec_decoding_logging.log(log_fn=log_fn)
self.kv_connector_logging.log(log_fn=log_fn)
@ -288,6 +315,32 @@ class PrometheusStatLogger(StatLoggerBase):
counter_prefix_cache_hits, engine_indexes, model_name
)
#
# Multi-modal cache
#
counter_mm_cache_queries = self._counter_cls(
name="vllm:mm_cache_queries",
documentation=(
"Multi-modal cache queries, in terms of number of queried items."
),
labelnames=labelnames,
)
self.counter_mm_cache_queries = make_per_engine(
counter_mm_cache_queries, engine_indexes, model_name
)
counter_mm_cache_hits = self._counter_cls(
name="vllm:mm_cache_hits",
documentation=(
"Multi-modal cache hits, in terms of number of cached items."
),
labelnames=labelnames,
)
self.counter_mm_cache_hits = make_per_engine(
counter_mm_cache_hits, engine_indexes, model_name
)
#
# Counters
#
@ -657,6 +710,7 @@ class PrometheusStatLogger(StatLoggerBase):
self,
scheduler_stats: Optional[SchedulerStats],
iteration_stats: Optional[IterationStats],
mm_cache_stats: Optional[MultiModalCacheStats] = None,
engine_idx: int = 0,
):
"""Log to prometheus."""
@ -694,6 +748,10 @@ class PrometheusStatLogger(StatLoggerBase):
scheduler_stats.spec_decoding_stats, engine_idx
)
if mm_cache_stats is not None:
self.counter_mm_cache_queries[engine_idx].inc(mm_cache_stats.queries)
self.counter_mm_cache_hits[engine_idx].inc(mm_cache_stats.hits)
if iteration_stats is None:
return
@ -871,6 +929,7 @@ class StatLoggerManager:
self,
scheduler_stats: Optional[SchedulerStats],
iteration_stats: Optional[IterationStats],
mm_cache_stats: Optional[MultiModalCacheStats] = None,
engine_idx: Optional[int] = None,
):
if engine_idx is None:
@ -878,9 +937,19 @@ class StatLoggerManager:
per_engine_loggers = self.per_engine_logger_dict[engine_idx]
for logger in per_engine_loggers:
logger.record(scheduler_stats, iteration_stats, engine_idx)
logger.record(
scheduler_stats,
iteration_stats,
mm_cache_stats=mm_cache_stats,
engine_idx=engine_idx,
)
self.prometheus_logger.record(scheduler_stats, iteration_stats, engine_idx)
self.prometheus_logger.record(
scheduler_stats,
iteration_stats,
mm_cache_stats=mm_cache_stats,
engine_idx=engine_idx,
)
def log(self):
for per_engine_loggers in self.per_engine_logger_dict.values():

View File

@ -2,6 +2,7 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import time
from collections import deque
from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Any, Optional
@ -13,24 +14,122 @@ if TYPE_CHECKING:
@dataclass
class PrefixCacheStats:
"""Stores prefix cache hit statistics."""
class BaseCacheStats:
"""Stores cache hit statistics."""
# Whether reset_prefix_cache was invoked.
reset: bool = False
# The number of new requests in this update.
"""Whether the cache was reset."""
requests: int = 0
# The number of queries in these requests. Note that "queries" here
# means the number of tokens that were queried from the cache.
"""The number of requests in this update."""
queries: int = 0
# The number of hits in these requests.
"""The number of queries in these requests."""
hits: int = 0
# The number of previously preempted requests in this update.
"""The number of hits in these requests."""
class CachingMetrics:
"""Metrics for caching with a hit rate of the most recent N requests.
Args:
interval: The number of the most recent requests to aggregate.
Defaults to 1000.
"""
def __init__(self, max_recent_requests: int = 1000) -> None:
super().__init__()
self.max_recent_requests = max_recent_requests
# The current aggregated values.
self.aggregated_requests = 0
self.aggregated_query_total = 0
self.aggregated_query_hit = 0
# A deque of (requests, queries, hits) for the most recent requests.
self.query_queue = deque[tuple[int, int, int]]()
def observe(self, stats: BaseCacheStats):
"""Observe the prefix caching for a set of requests.
This function is called with information gathered when new requests
are being scheduled and are looking for computed blocks.
When there are more than `max_recent_requests` requests, the oldest set
of requests are removed from the metrics.
Args:
stats: The prefix cache stats.
"""
# reset_prefix_cache was invoked before the current update.
# Reset the metrics before aggregating the current stats.
if stats.reset:
self.reset()
# DO NOT appending empty stats to avoid helpful info get kicked out
# due to sliding window.
if stats.requests == 0:
return
# Update the metrics.
self.query_queue.append((stats.requests, stats.queries, stats.hits))
self.aggregated_requests += stats.requests
self.aggregated_query_total += stats.queries
self.aggregated_query_hit += stats.hits
# Remove the oldest stats until number of requests does not exceed
# the limit.
# NOTE: We preserve the latest added stats regardless.
while (
len(self.query_queue) > 1
and self.aggregated_requests > self.max_recent_requests
):
old_requests, old_queries, old_hits = self.query_queue.popleft()
self.aggregated_requests -= old_requests
self.aggregated_query_total -= old_queries
self.aggregated_query_hit -= old_hits
def reset(self):
"""Reset the metrics."""
self.aggregated_requests = 0
self.aggregated_query_total = 0
self.aggregated_query_hit = 0
self.query_queue.clear()
@property
def hit_rate(self) -> float:
"""Calculate the hit rate for the past N requests."""
if self.aggregated_query_total == 0:
return 0.0
return self.aggregated_query_hit / self.aggregated_query_total
@dataclass
class PrefixCacheStats(BaseCacheStats):
"""
Stores prefix cache hit statistics.
- `reset`: Whether `reset_prefix_cache` was invoked.
- `queries`: Refers to the number of tokens that were queried.
"""
preempted_requests: int = 0
# The `queries` number for preempted requests.
"""The number of previously preempted requests in this update."""
preempted_queries: int = 0
# The `hits` number for preempted requests.
"""The `queries` number for preempted requests."""
preempted_hits: int = 0
"""The `hits` number for preempted requests."""
@dataclass
class MultiModalCacheStats(BaseCacheStats):
"""
Stores multi-modal cache hit statistics.
- `reset`: Whether `reset_mm_cache` was invoked.
- `queries`: Refers to the number of multi-modal data items
that were queried.
"""
@dataclass

View File

@ -508,6 +508,10 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
pin_memory=self.pin_memory,
)
def reset_mm_cache(self) -> None:
if self.mm_budget:
self.mm_budget.reset_cache()
def _get_positions(self, num_tokens: Any):
if isinstance(num_tokens, int):
if self.uses_mrope:

View File

@ -442,6 +442,9 @@ class Worker(WorkerBase):
# the model initialization and profiling.
set_random_seed(self.model_config.seed)
def reset_mm_cache(self) -> None:
self.model_runner.reset_mm_cache()
def get_model(self) -> nn.Module:
return self.model_runner.get_model()

View File

@ -371,6 +371,10 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
else:
self.sample_from_logits_func = self.sample_from_logits
def reset_mm_cache(self) -> None:
if self.mm_budget:
self.mm_budget.reset_cache()
def _update_num_xla_graphs(self, case_str):
check_comp = self.check_recompilation and not self.enforce_eager
if not check_comp:

View File

@ -293,6 +293,9 @@ class TPUWorker:
# the model initialization and profiling.
set_random_seed(self.model_config.seed)
def reset_mm_cache(self) -> None:
self.model_runner.reset_mm_cache()
def get_model(self) -> nn.Module:
return self.model_runner.get_model()

View File

@ -126,6 +126,10 @@ class MultiModalBudget:
return max_items_per_prompt, max_items_per_batch
def reset_cache(self) -> None:
if self.cache is not None:
self.cache.clear_cache()
@dataclass
class AttentionGroup:

View File

@ -4,7 +4,7 @@
from __future__ import annotations
import os
from typing import Any, Callable, TypeVar, Union
from typing import TYPE_CHECKING, Any, Callable, TypeVar, Union
import torch
import torch.nn as nn
@ -12,7 +12,8 @@ import torch.nn as nn
from vllm.config import VllmConfig, set_current_vllm_config
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.sequence import ExecuteModelRequest
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.cache import worker_receiver_cache_from_config
from vllm.utils import (
enable_trace_function_call_for_thread,
resolve_obj_by_qualname,
@ -21,7 +22,10 @@ from vllm.utils import (
warn_for_unimplemented_methods,
)
from vllm.v1.kv_cache_interface import KVCacheSpec
from vllm.v1.outputs import SamplerOutput
if TYPE_CHECKING:
from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.outputs import ModelRunnerOutput
logger = init_logger(__name__)
@ -103,6 +107,11 @@ class WorkerBase:
"""Initialize the KV cache with the given size in blocks."""
raise NotImplementedError
def reset_mm_cache(self) -> None:
reset_fn = getattr(self.model_runner, "reset_mm_cache", None)
if callable(reset_fn):
reset_fn()
def get_model(self) -> nn.Module:
raise NotImplementedError
@ -114,9 +123,7 @@ class WorkerBase:
"""Load model onto target device."""
raise NotImplementedError
def execute_model(
self, execute_model_req: ExecuteModelRequest | None = None
) -> list[SamplerOutput] | None:
def execute_model(self, scheduler_output: SchedulerOutput) -> ModelRunnerOutput:
raise NotImplementedError
def start_worker_execution_loop(self) -> None:
@ -125,11 +132,7 @@ class WorkerBase:
You can stop the loop by executing a driver worker with an empty output.
See `stop_remote_worker_execution_loop` for more details.
"""
with self.current_platform.inference_mode():
while True:
output = self.execute_model(execute_model_req=None)
if output is None:
return None
raise NotImplementedError("Dead V0 code")
def determine_num_available_blocks(self) -> tuple[int, int]:
"""Determine the number of available blocks for the GPU KV cache and
@ -289,6 +292,28 @@ class WorkerWrapperBase:
worker_class,
extended_calls,
)
shared_worker_lock = kwargs.pop("shared_worker_lock", None)
if shared_worker_lock is None:
msg = (
"Missing `shared_worker_lock` argument from executor. "
"This argument is needed for mm_processor_cache_type='shm'."
)
mm_config = self.vllm_config.model_config.multimodal_config
if mm_config and mm_config.mm_processor_cache_type == "shm":
raise ValueError(msg)
else:
logger.warning_once(msg)
self.mm_receiver_cache = None
else:
self.mm_receiver_cache = worker_receiver_cache_from_config(
self.vllm_config,
MULTIMODAL_REGISTRY,
shared_worker_lock,
)
with set_current_vllm_config(self.vllm_config):
# To make vLLM config available during worker initialization
self.worker = worker_class(**kwargs)
@ -323,5 +348,34 @@ class WorkerWrapperBase:
logger.exception(msg)
raise e
def __getattr__(self, attr):
def __getattr__(self, attr: str):
return getattr(self.worker, attr)
def _apply_mm_cache(self, scheduler_output: SchedulerOutput) -> None:
mm_cache = self.mm_receiver_cache
if mm_cache is None:
return
for req_data in scheduler_output.scheduled_new_reqs:
req_data.mm_features = mm_cache.get_and_update_features(
req_data.mm_features
)
def execute_model(
self,
scheduler_output: SchedulerOutput,
*args,
**kwargs,
) -> ModelRunnerOutput:
self._apply_mm_cache(scheduler_output)
assert self.worker is not None
return self.worker.execute_model(scheduler_output, *args, **kwargs)
def reset_mm_cache(self) -> None:
mm_receiver_cache = self.mm_receiver_cache
if mm_receiver_cache is not None:
mm_receiver_cache.clear_cache()
assert self.worker is not None
self.worker.reset_mm_cache()