[Metrics] Log multi-modal cache stats and fix reset (#26285)

Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
Cyrus Leung 2025-10-10 16:45:55 +08:00 committed by GitHub
parent 6f0f570c43
commit ad430a67ca
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
25 changed files with 586 additions and 235 deletions

View File

@ -0,0 +1,74 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import pytest
from vllm import LLM
from vllm.entrypoints.chat_utils import ChatCompletionMessageParam
from vllm.v1.metrics.reader import Counter, Metric
from ..openai.test_vision import TEST_IMAGE_ASSETS
def _make_messages(image_url: str) -> list[ChatCompletionMessageParam]:
    """Build a single-turn chat whose only content part is one image URL."""
    user_turn: ChatCompletionMessageParam = {
        "role": "user",
        "content": [{"type": "image_url", "image_url": {"url": image_url}}],
    }
    return [user_turn]
def _get_counter_value(metrics: list[Metric], name: str):
    """Return the value of the counter metric called ``name``.

    Args:
        metrics: Metric snapshot to search (the test passes
            ``llm.get_metrics()``).
        name: Exact metric name to look up.

    Returns:
        The ``value`` attribute of the first metric whose name matches.

    Raises:
        AssertionError: If no metric has that name, or if the matching
            metric is not a ``Counter``.
    """
    # A default of None avoids leaking a bare StopIteration, which would
    # give no hint about which metric was missing.
    metric = next((m for m in metrics if m.name == name), None)
    if metric is None:
        available = sorted(m.name for m in metrics)
        raise AssertionError(
            f"Metric {name!r} not found; available metrics: {available}"
        )

    assert isinstance(metric, Counter), (
        f"Metric {name!r} is a {type(metric).__name__}, expected a Counter"
    )
    return metric.value
def _get_mm_cache_stats(metrics: list[Metric]):
    """Return the multi-modal cache counters as a ``(queries, hits)`` pair."""
    counter_names = ("vllm:mm_cache_queries", "vllm:mm_cache_hits")
    queries, hits = (_get_counter_value(metrics, n) for n in counter_names)
    return queries, hits
@pytest.mark.parametrize("image_urls", [TEST_IMAGE_ASSETS[:2]], indirect=True)
@pytest.mark.parametrize("mm_processor_cache_type", ["lru", "shm"])
def test_mm_cache_stats(
    num_gpus_available,
    image_urls,
    mm_processor_cache_type,
):
    """Check the raw multi-modal cache counters across repeats and a reset."""
    engine_kwargs = dict(
        model="llava-hf/llava-1.5-7b-hf",
        max_model_len=4096,
        max_num_seqs=5,
        enforce_eager=True,
        mm_processor_cache_type=mm_processor_cache_type,
        disable_log_stats=False,
        limit_mm_per_prompt={"image": 2},
    )
    llm = LLM(**engine_kwargs)

    def chat_and_stat(url_idx):
        # Send one image-only chat, then read back (queries, hits).
        llm.chat(_make_messages(image_urls[url_idx]))
        return _get_mm_cache_stats(llm.get_metrics())

    # The first request for each image is a miss; repeating image 0 hits.
    assert chat_and_stat(0) == (1, 0)
    assert chat_and_stat(1) == (2, 0)
    assert chat_and_stat(0) == (3, 1)

    # NOTE: This only resets hit rate stats in CachingMetrics
    # The raw queries and hits counts remain unaffected
    llm.reset_mm_cache()

    # Both images miss again after the reset, so only queries grows.
    assert chat_and_stat(0) == (4, 1)
    assert chat_and_stat(1) == (5, 1)

View File

@ -18,10 +18,18 @@ from vllm import version
from ...utils import RemoteOpenAIServer from ...utils import RemoteOpenAIServer
MODEL_NAME = "TinyLlama/TinyLlama-1.1B-Chat-v1.0" MODELS = {
"text": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
"multimodal": "HuggingFaceTB/SmolVLM-256M-Instruct",
}
PREV_MINOR_VERSION = version._prev_minor_version() PREV_MINOR_VERSION = version._prev_minor_version()
@pytest.fixture(scope="module", params=list(MODELS))
def model_key(request):
    """Module-scoped fixture parametrized over every key of ``MODELS``."""
    yield request.param
@pytest.fixture(scope="module") @pytest.fixture(scope="module")
def default_server_args(): def default_server_args():
return [ return [
@ -45,11 +53,12 @@ def default_server_args():
f"--show-hidden-metrics-for-version={PREV_MINOR_VERSION}", f"--show-hidden-metrics-for-version={PREV_MINOR_VERSION}",
], ],
) )
def server(default_server_args, request): def server(model_key, default_server_args, request):
if request.param: if request.param:
default_server_args.append(request.param) default_server_args.append(request.param)
with RemoteOpenAIServer(MODEL_NAME, default_server_args) as remote_server: model_name = MODELS[model_key]
with RemoteOpenAIServer(model_name, default_server_args) as remote_server:
yield remote_server yield remote_server
@ -60,64 +69,70 @@ async def client(server):
_PROMPT = "Hello my name is Robert and I love magic" _PROMPT = "Hello my name is Robert and I love magic"
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME) _IMAGE_URL = "https://upload.wikimedia.org/wikipedia/commons/thumb/d/dd/Gfp-wisconsin-madison-the-nature-boardwalk.jpg/2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg"
_TOKENIZED_PROMPT = tokenizer(_PROMPT)["input_ids"]
_NUM_REQUESTS = 10
_NUM_PROMPT_TOKENS_PER_REQUEST = len(_TOKENIZED_PROMPT)
_NUM_GENERATION_TOKENS_PER_REQUEST = 10
# {metric_family: [(suffix, expected_value)]} def _get_expected_values(num_requests: int, prompt_ids: list[int], max_tokens: int):
EXPECTED_VALUES = { num_prompt_tokens = len(prompt_ids)
"vllm:time_to_first_token_seconds": [("_count", _NUM_REQUESTS)],
"vllm:time_per_output_token_seconds": [ # {metric_family: [(suffix, expected_value)]}
("_count", _NUM_REQUESTS * (_NUM_GENERATION_TOKENS_PER_REQUEST - 1)) return {
], "vllm:time_to_first_token_seconds": [("_count", num_requests)],
"vllm:e2e_request_latency_seconds": [("_count", _NUM_REQUESTS)], "vllm:time_per_output_token_seconds": [
"vllm:request_queue_time_seconds": [("_count", _NUM_REQUESTS)], ("_count", num_requests * (max_tokens - 1))
"vllm:request_inference_time_seconds": [("_count", _NUM_REQUESTS)], ],
"vllm:request_prefill_time_seconds": [("_count", _NUM_REQUESTS)], "vllm:e2e_request_latency_seconds": [("_count", num_requests)],
"vllm:request_decode_time_seconds": [("_count", _NUM_REQUESTS)], "vllm:request_queue_time_seconds": [("_count", num_requests)],
"vllm:request_prompt_tokens": [ "vllm:request_inference_time_seconds": [("_count", num_requests)],
("_sum", _NUM_REQUESTS * _NUM_PROMPT_TOKENS_PER_REQUEST), "vllm:request_prefill_time_seconds": [("_count", num_requests)],
("_count", _NUM_REQUESTS), "vllm:request_decode_time_seconds": [("_count", num_requests)],
], "vllm:request_prompt_tokens": [
"vllm:request_generation_tokens": [ ("_sum", num_requests * num_prompt_tokens),
("_sum", _NUM_REQUESTS * _NUM_GENERATION_TOKENS_PER_REQUEST), ("_count", num_requests),
("_count", _NUM_REQUESTS), ],
], "vllm:request_generation_tokens": [
"vllm:request_params_n": [("_count", _NUM_REQUESTS)], ("_sum", num_requests * max_tokens),
"vllm:request_params_max_tokens": [ ("_count", num_requests),
("_sum", _NUM_REQUESTS * _NUM_GENERATION_TOKENS_PER_REQUEST), ],
("_count", _NUM_REQUESTS), "vllm:request_params_n": [("_count", num_requests)],
], "vllm:request_params_max_tokens": [
"vllm:iteration_tokens_total": [ ("_sum", num_requests * max_tokens),
( ("_count", num_requests),
"_sum", ],
_NUM_REQUESTS "vllm:iteration_tokens_total": [
* (_NUM_PROMPT_TOKENS_PER_REQUEST + _NUM_GENERATION_TOKENS_PER_REQUEST), (
), "_sum",
("_count", _NUM_REQUESTS * _NUM_GENERATION_TOKENS_PER_REQUEST), num_requests * (num_prompt_tokens + max_tokens),
], ),
"vllm:prompt_tokens": [("_total", _NUM_REQUESTS * _NUM_PROMPT_TOKENS_PER_REQUEST)], ("_count", num_requests * max_tokens),
"vllm:generation_tokens": [ ],
("_total", _NUM_REQUESTS * _NUM_PROMPT_TOKENS_PER_REQUEST) "vllm:prompt_tokens": [("_total", num_requests * num_prompt_tokens)],
], "vllm:generation_tokens": [("_total", num_requests * max_tokens)],
"vllm:request_success": [("_total", _NUM_REQUESTS)], "vllm:request_success": [("_total", num_requests)],
} }
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_metrics_counts( async def test_metrics_counts(
server: RemoteOpenAIServer, server: RemoteOpenAIServer,
client: openai.AsyncClient, client: openai.AsyncClient,
model_key: str,
): ):
for _ in range(_NUM_REQUESTS): if model_key == "multimodal":
pytest.skip("Unnecessary test")
model_name = MODELS[model_key]
tokenizer = AutoTokenizer.from_pretrained(model_name)
prompt_ids = tokenizer.encode(_PROMPT)
num_requests = 10
max_tokens = 10
for _ in range(num_requests):
# sending a request triggers the metrics to be logged. # sending a request triggers the metrics to be logged.
await client.completions.create( await client.completions.create(
model=MODEL_NAME, model=model_name,
prompt=_TOKENIZED_PROMPT, prompt=prompt_ids,
max_tokens=_NUM_GENERATION_TOKENS_PER_REQUEST, max_tokens=max_tokens,
) )
response = requests.get(server.url_for("metrics")) response = requests.get(server.url_for("metrics"))
@ -125,8 +140,9 @@ async def test_metrics_counts(
assert response.status_code == HTTPStatus.OK assert response.status_code == HTTPStatus.OK
# Loop over all expected metric_families # Loop over all expected metric_families
for metric_family, suffix_values_list in EXPECTED_VALUES.items(): expected_values = _get_expected_values(num_requests, prompt_ids, max_tokens)
if (metric_family not in EXPECTED_METRICS_V1) or ( for metric_family, suffix_values_list in expected_values.items():
if metric_family not in EXPECTED_METRICS_V1 or (
not server.show_hidden_metrics not server.show_hidden_metrics
and metric_family in HIDDEN_DEPRECATED_METRICS and metric_family in HIDDEN_DEPRECATED_METRICS
): ):
@ -217,6 +233,11 @@ EXPECTED_METRICS_V1 = [
"vllm:request_decode_time_seconds_count", "vllm:request_decode_time_seconds_count",
] ]
# Metrics checked in addition to EXPECTED_METRICS_V1 when the server runs
# the "multimodal" model (see test_metrics_exist).
EXPECTED_METRICS_MM = [
    "vllm:mm_cache_queries",
    "vllm:mm_cache_hits",
]
HIDDEN_DEPRECATED_METRICS: list[str] = [ HIDDEN_DEPRECATED_METRICS: list[str] = [
"vllm:gpu_cache_usage_perc", "vllm:gpu_cache_usage_perc",
"vllm:gpu_prefix_cache_queries", "vllm:gpu_prefix_cache_queries",
@ -231,19 +252,43 @@ HIDDEN_DEPRECATED_METRICS: list[str] = [
async def test_metrics_exist( async def test_metrics_exist(
server: RemoteOpenAIServer, server: RemoteOpenAIServer,
client: openai.AsyncClient, client: openai.AsyncClient,
model_key: str,
): ):
model_name = MODELS[model_key]
# sending a request triggers the metrics to be logged. # sending a request triggers the metrics to be logged.
await client.completions.create( if model_key == "text":
model=MODEL_NAME, await client.completions.create(
prompt="Hello, my name is", model=model_name,
max_tokens=5, prompt="Hello, my name is",
temperature=0.0, max_tokens=5,
) temperature=0.0,
)
else:
await client.chat.completions.create(
model=model_name,
messages=[
{
"role": "user",
"content": [
{"type": "image_url", "image_url": {"url": _IMAGE_URL}},
{"type": "text", "text": "What's in this image?"},
],
}
],
max_tokens=5,
temperature=0.0,
)
response = requests.get(server.url_for("metrics")) response = requests.get(server.url_for("metrics"))
assert response.status_code == HTTPStatus.OK assert response.status_code == HTTPStatus.OK
for metric in EXPECTED_METRICS_V1: expected_metrics = EXPECTED_METRICS_V1
if model_key == "multimodal":
# NOTE: Don't use in-place assignment
expected_metrics = expected_metrics + EXPECTED_METRICS_MM
for metric in expected_metrics:
if metric in HIDDEN_DEPRECATED_METRICS and not server.show_hidden_metrics: if metric in HIDDEN_DEPRECATED_METRICS and not server.show_hidden_metrics:
continue continue
assert metric in response.text assert metric in response.text
@ -253,9 +298,14 @@ async def test_metrics_exist(
async def test_abort_metrics_reset( async def test_abort_metrics_reset(
server: RemoteOpenAIServer, server: RemoteOpenAIServer,
client: openai.AsyncClient, client: openai.AsyncClient,
model_key: str,
): ):
model_name = MODELS[model_key]
tokenizer = AutoTokenizer.from_pretrained(model_name)
prompt_ids = tokenizer.encode(_PROMPT)
running_requests, waiting_requests, kv_cache_usage = _get_running_metrics_from_api( running_requests, waiting_requests, kv_cache_usage = _get_running_metrics_from_api(
server server,
) )
# Expect no running requests or kvcache usage # Expect no running requests or kvcache usage
@ -268,8 +318,8 @@ async def test_abort_metrics_reset(
for _ in range(3): for _ in range(3):
task = asyncio.create_task( task = asyncio.create_task(
client.completions.create( client.completions.create(
model=MODEL_NAME, model=model_name,
prompt=_TOKENIZED_PROMPT, prompt=prompt_ids,
max_tokens=100, # Long generation to give time to abort max_tokens=100, # Long generation to give time to abort
temperature=0.0, temperature=0.0,
) )
@ -281,7 +331,7 @@ async def test_abort_metrics_reset(
# Check that we have running requests # Check that we have running requests
running_requests, waiting_requests, kv_cache_usage = _get_running_metrics_from_api( running_requests, waiting_requests, kv_cache_usage = _get_running_metrics_from_api(
server server,
) )
# Expect running requests and kvcache usage # Expect running requests and kvcache usage

View File

@ -20,7 +20,6 @@ from vllm.v1.core.kv_cache_utils import (
BlockHash, BlockHash,
FreeKVCacheBlockQueue, FreeKVCacheBlockQueue,
KVCacheBlock, KVCacheBlock,
PrefixCachingMetrics,
estimate_max_model_len, estimate_max_model_len,
generate_block_hash_extra_keys, generate_block_hash_extra_keys,
generate_scheduler_kv_cache_config, generate_scheduler_kv_cache_config,
@ -42,7 +41,7 @@ from vllm.v1.kv_cache_interface import (
SlidingWindowSpec, SlidingWindowSpec,
UniformTypeKVCacheSpecs, UniformTypeKVCacheSpecs,
) )
from vllm.v1.metrics.stats import PrefixCacheStats from vllm.v1.metrics.stats import CachingMetrics, PrefixCacheStats
from vllm.v1.request import Request from vllm.v1.request import Request
pytestmark = pytest.mark.cpu_test pytestmark = pytest.mark.cpu_test
@ -536,7 +535,7 @@ def test_metrics():
""" """
Test the prefix caching metrics. Test the prefix caching metrics.
""" """
metrics = PrefixCachingMetrics(max_recent_requests=5) metrics = CachingMetrics(max_recent_requests=5)
assert metrics.hit_rate == 0.0 assert metrics.hit_rate == 0.0
metrics.observe(_stats(1, 20, 9)) metrics.observe(_stats(1, 20, 9))
@ -568,7 +567,7 @@ def test_metrics_empty_stats():
""" """
Test the prefix caching metrics with empty stats. Test the prefix caching metrics with empty stats.
""" """
metrics = PrefixCachingMetrics(max_recent_requests=5) metrics = CachingMetrics(max_recent_requests=5)
metrics.observe(_stats(0, 0, 0)) metrics.observe(_stats(0, 0, 0))
metrics.observe(_stats(1, 20, 9)) metrics.observe(_stats(1, 20, 9))
metrics.observe(_stats(0, 0, 0)) metrics.observe(_stats(0, 0, 0))

View File

@ -17,7 +17,7 @@ from vllm.sampling_params import RequestOutputKind
from vllm.v1.engine.async_llm import AsyncLLM from vllm.v1.engine.async_llm import AsyncLLM
from vllm.v1.engine.core_client import DPAsyncMPClient from vllm.v1.engine.core_client import DPAsyncMPClient
from vllm.v1.metrics.loggers import StatLoggerBase from vllm.v1.metrics.loggers import StatLoggerBase
from vllm.v1.metrics.stats import IterationStats, SchedulerStats from vllm.v1.metrics.stats import IterationStats, MultiModalCacheStats, SchedulerStats
DP_SIZE = int(os.getenv("DP_SIZE", 2)) DP_SIZE = int(os.getenv("DP_SIZE", 2))
@ -93,6 +93,7 @@ async def test_load(
self, self,
scheduler_stats: Optional[SchedulerStats], scheduler_stats: Optional[SchedulerStats],
iteration_stats: Optional[IterationStats], iteration_stats: Optional[IterationStats],
mm_cache_stats: Optional[MultiModalCacheStats] = None,
engine_idx: int = 0, engine_idx: int = 0,
): ):
if iteration_stats: if iteration_stats:

View File

@ -354,6 +354,10 @@ class LLM:
else: else:
self.llm_engine.tokenizer = get_cached_tokenizer(tokenizer) self.llm_engine.tokenizer = get_cached_tokenizer(tokenizer)
def reset_mm_cache(self) -> None:
    """Clear the multi-modal cache held by this object's processor, then
    ask the wrapped engine to reset its own multi-modal cache."""
    self.processor.clear_mm_cache()
    self.llm_engine.reset_mm_cache()
def get_default_sampling_params(self) -> SamplingParams: def get_default_sampling_params(self) -> SamplingParams:
if self.default_sampling_params is None: if self.default_sampling_params is None:
self.default_sampling_params = self.model_config.get_diff_sampling_param() self.default_sampling_params = self.model_config.get_diff_sampling_param()

View File

@ -274,6 +274,10 @@ class OpenAIServing:
self.model_config = self.models.model_config self.model_config = self.models.model_config
self.max_model_len = self.model_config.max_model_len self.max_model_len = self.model_config.max_model_len
async def reset_mm_cache(self) -> None:
    """Clear this object's processor-side multi-modal cache, then await
    the engine client's multi-modal cache reset."""
    self.processor.clear_mm_cache()
    await self.engine_client.reset_mm_cache()
async def beam_search( async def beam_search(
self, self,
prompt: PromptType, prompt: PromptType,

View File

@ -169,6 +169,10 @@ class ExecutorBase(ABC):
assert s == sets[0], "All workers should have the same LORAs." assert s == sets[0], "All workers should have the same LORAs."
return sets[0] return sets[0]
def reset_mm_cache(self) -> None:
    """Reset the multi-modal cache in each worker.

    Invokes the ``reset_mm_cache`` method through ``collective_rpc``;
    whatever that call returns is discarded.
    """
    self.collective_rpc("reset_mm_cache")
def start_profile(self) -> None: def start_profile(self) -> None:
self.collective_rpc("start_profile") self.collective_rpc("start_profile")

View File

@ -12,11 +12,8 @@ import torch.distributed as dist
import vllm.envs as envs import vllm.envs as envs
from vllm.executor.executor_base import ExecutorBase from vllm.executor.executor_base import ExecutorBase
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.cache import worker_receiver_cache_from_config
from vllm.utils import get_distributed_init_method, get_ip, get_open_port, run_method from vllm.utils import get_distributed_init_method, get_ip, get_open_port, run_method
from vllm.v1.engine import ReconfigureDistributedRequest, ReconfigureRankType from vllm.v1.engine import ReconfigureDistributedRequest, ReconfigureRankType
from vllm.v1.executor.utils import get_and_update_mm_cache
from vllm.v1.outputs import AsyncModelRunnerOutput from vllm.v1.outputs import AsyncModelRunnerOutput
from vllm.v1.worker.worker_base import WorkerWrapperBase from vllm.v1.worker.worker_base import WorkerWrapperBase
@ -30,16 +27,13 @@ class UniProcExecutor(ExecutorBase):
"""Initialize the worker and load the model.""" """Initialize the worker and load the model."""
self.driver_worker = WorkerWrapperBase(vllm_config=self.vllm_config, rpc_rank=0) self.driver_worker = WorkerWrapperBase(vllm_config=self.vllm_config, rpc_rank=0)
distributed_init_method, rank, local_rank = self._distributed_args() distributed_init_method, rank, local_rank = self._distributed_args()
is_driver_worker = True
kwargs = dict( kwargs = dict(
vllm_config=self.vllm_config, vllm_config=self.vllm_config,
local_rank=local_rank, local_rank=local_rank,
rank=rank, rank=rank,
distributed_init_method=distributed_init_method, distributed_init_method=distributed_init_method,
is_driver_worker=is_driver_worker, is_driver_worker=True,
) shared_worker_lock=Lock(),
self.mm_receiver_cache = worker_receiver_cache_from_config(
self.vllm_config, MULTIMODAL_REGISTRY, Lock()
) )
self.async_output_thread: Optional[ThreadPoolExecutor] = None self.async_output_thread: Optional[ThreadPoolExecutor] = None
@ -74,8 +68,6 @@ class UniProcExecutor(ExecutorBase):
) -> list[Any]: ) -> list[Any]:
if kwargs is None: if kwargs is None:
kwargs = {} kwargs = {}
if self.mm_receiver_cache is not None and method == "execute_model":
get_and_update_mm_cache(self.mm_receiver_cache, args)
if not non_block: if not non_block:
return [run_method(self.driver_worker, method, args, kwargs)] return [run_method(self.driver_worker, method, args, kwargs)]

View File

@ -19,6 +19,7 @@ from vllm.multimodal.inputs import (
from vllm.multimodal.processing import BaseMultiModalProcessor from vllm.multimodal.processing import BaseMultiModalProcessor
from vllm.transformers_utils.tokenizer import AnyTokenizer from vllm.transformers_utils.tokenizer import AnyTokenizer
from vllm.utils.jsontree import json_iter_leaves from vllm.utils.jsontree import json_iter_leaves
from vllm.v1.metrics.stats import MultiModalCacheStats
from .data import ( from .data import (
DecoderOnlyInputs, DecoderOnlyInputs,
@ -56,6 +57,8 @@ class InputPreprocessor:
self.mm_registry = mm_registry self.mm_registry = mm_registry
self.mm_processor_cache = mm_processor_cache self.mm_processor_cache = mm_processor_cache
self.mm_cache_stats = MultiModalCacheStats() if mm_processor_cache else None
def get_tokenizer(self) -> AnyTokenizer: def get_tokenizer(self) -> AnyTokenizer:
if self.tokenizer is None: if self.tokenizer is None:
raise ValueError( raise ValueError(
@ -664,14 +667,13 @@ class InputPreprocessor:
return self._build_decoder_only_llm_inputs(prompt_comps) return self._build_decoder_only_llm_inputs(prompt_comps)
def preprocess( def _preprocess(
self, self,
prompt: PromptType, prompt: PromptType,
tokenization_kwargs: Optional[dict[str, Any]] = None, tokenization_kwargs: Optional[dict[str, Any]] = None,
*, *,
mm_uuids: Optional[MultiModalUUIDDict] = None, mm_uuids: Optional[MultiModalUUIDDict] = None,
) -> ProcessorInputs: ) -> ProcessorInputs:
"""Preprocess the input prompt."""
if self.model_config.is_encoder_decoder: if self.model_config.is_encoder_decoder:
# Encoder-decoder model requires special mapping of # Encoder-decoder model requires special mapping of
# input prompts to encoder & decoder. # input prompts to encoder & decoder.
@ -694,6 +696,40 @@ class InputPreprocessor:
mm_uuids=mm_uuids, mm_uuids=mm_uuids,
) )
def clear_cache(self) -> None: def preprocess(
self,
prompt: PromptType,
tokenization_kwargs: Optional[dict[str, Any]] = None,
*,
mm_uuids: Optional[MultiModalUUIDDict] = None,
) -> ProcessorInputs:
"""Preprocess the input prompt."""
res = self._preprocess(
prompt,
tokenization_kwargs,
mm_uuids=mm_uuids,
)
if self.mm_processor_cache and self.mm_cache_stats is not None:
delta = self.mm_processor_cache.make_stats(delta=True)
self.mm_cache_stats.requests += 1
self.mm_cache_stats.queries += delta.total
self.mm_cache_stats.hits += delta.hits
return res
def stat_mm_cache(self) -> Optional[MultiModalCacheStats]:
    """Hand over the accumulated multi-modal cache stats.

    Returns:
        The stats object collected so far, or None when
        ``self.mm_cache_stats`` is None. When a stats object is returned,
        a fresh ``MultiModalCacheStats`` replaces it on the instance.
    """
    collected = self.mm_cache_stats
    if collected is not None:
        # Swap in an empty accumulator so the returned object is not
        # mutated by later updates to self.mm_cache_stats.
        self.mm_cache_stats = MultiModalCacheStats()
    return collected
def clear_mm_cache(self) -> None:
if self.mm_processor_cache is not None: if self.mm_processor_cache is not None:
self.mm_processor_cache.clear_cache() self.mm_processor_cache.clear_cache()
if self.mm_cache_stats is not None:
self.mm_cache_stats.reset = True

View File

@ -18,7 +18,7 @@ from vllm.distributed.device_communicators.shm_object_storage import (
from vllm.envs import VLLM_OBJECT_STORAGE_SHM_BUFFER_NAME from vllm.envs import VLLM_OBJECT_STORAGE_SHM_BUFFER_NAME
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.utils import GiB_bytes, MiB_bytes from vllm.utils import GiB_bytes, MiB_bytes
from vllm.utils.cache import LRUCache from vllm.utils.cache import CacheInfo, LRUCache
from vllm.utils.jsontree import json_count_leaves, json_map_leaves, json_reduce_leaves from vllm.utils.jsontree import json_count_leaves, json_map_leaves, json_reduce_leaves
from .inputs import ( from .inputs import (
@ -302,6 +302,16 @@ class BaseMultiModalProcessorCache(
""" """
return [self.is_cached_item(mm_hash) for mm_hash in mm_hashes] return [self.is_cached_item(mm_hash) for mm_hash in mm_hashes]
@abstractmethod
def make_stats(self, *, delta: bool = False) -> CacheInfo:
    """
    Get (and reset) the multi-modal cache stats.

    Args:
        delta: In `ShmObjectStoreSenderCache`, a true value returns only
            the change since the previous delta snapshot and advances
            that snapshot. The other implementations in this file
            forward it to `self._cache.stat` (presumably with the same
            meaning; verify against `LRUCache.stat`).

    Returns:
        The current multi-modal caching stats.
    """
    raise NotImplementedError
class MultiModalProcessorOnlyCache(BaseMultiModalProcessorCache): class MultiModalProcessorOnlyCache(BaseMultiModalProcessorCache):
""" """
@ -347,6 +357,10 @@ class MultiModalProcessorOnlyCache(BaseMultiModalProcessorCache):
def clear_cache(self) -> None: def clear_cache(self) -> None:
self._cache.clear() self._cache.clear()
@override
def make_stats(self, *, delta: bool = False) -> CacheInfo:
    # Stats come from the wrapped cache object; `delta` is passed through
    # unchanged to its `stat` method.
    return self._cache.stat(delta=delta)
class MultiModalProcessorSenderCache(BaseMultiModalProcessorCache): class MultiModalProcessorSenderCache(BaseMultiModalProcessorCache):
""" """
@ -397,6 +411,10 @@ class MultiModalProcessorSenderCache(BaseMultiModalProcessorCache):
def clear_cache(self) -> None: def clear_cache(self) -> None:
self._cache.clear() self._cache.clear()
@override
def make_stats(self, *, delta: bool = False) -> CacheInfo:
    # Stats come from the wrapped cache object; `delta` is passed through
    # unchanged to its `stat` method.
    return self._cache.stat(delta=delta)
class ShmObjectStoreSenderCache(BaseMultiModalProcessorCache): class ShmObjectStoreSenderCache(BaseMultiModalProcessorCache):
""" """
@ -430,6 +448,20 @@ class ShmObjectStoreSenderCache(BaseMultiModalProcessorCache):
# cache (prompt_updates, modality) for P0 only # cache (prompt_updates, modality) for P0 only
self._p0_cache: dict[str, tuple[Sequence[ResolvedPromptUpdate], str]] = {} self._p0_cache: dict[str, tuple[Sequence[ResolvedPromptUpdate], str]] = {}
self._hits = 0
self._total = 0
self._last_info = CacheInfo(hits=0, total=0)
def _stat(self, *, delta: bool = False) -> CacheInfo:
    """Report this cache's hit/total counters.

    Args:
        delta: When false, return the cumulative counters. When true,
            return the difference from the snapshot stored in
            ``self._last_info`` and store the current counters as the
            new snapshot.
    """
    current = CacheInfo(hits=self._hits, total=self._total)
    if not delta:
        return current

    # Compute the difference before advancing the stored snapshot.
    change = current - self._last_info
    self._last_info = current
    return change
@override @override
def is_cached_item(self, mm_hash: str) -> bool: def is_cached_item(self, mm_hash: str) -> bool:
return self._shm_cache.is_cached(mm_hash) return self._shm_cache.is_cached(mm_hash)
@ -441,12 +473,17 @@ class ShmObjectStoreSenderCache(BaseMultiModalProcessorCache):
mm_hash: str, mm_hash: str,
) -> MultiModalProcessorCacheOutItem: ) -> MultiModalProcessorCacheOutItem:
if self._shm_cache.is_cached(mm_hash): if self._shm_cache.is_cached(mm_hash):
self._hits += 1
self._total += 1
address, monotonic_id = self._shm_cache.get_cached(mm_hash) address, monotonic_id = self._shm_cache.get_cached(mm_hash)
prompt_updates, modality = self._p0_cache[mm_hash] prompt_updates, modality = self._p0_cache[mm_hash]
return self.address_as_item(address, monotonic_id, modality), prompt_updates return self.address_as_item(address, monotonic_id, modality), prompt_updates
assert mm_item is not None, f"Expected a cached item for {mm_hash=}" assert mm_item is not None, f"Expected a cached item for {mm_hash=}"
self._total += 1
try: try:
address, monotonic_id = self._shm_cache.put(mm_hash, mm_item[0]) address, monotonic_id = self._shm_cache.put(mm_hash, mm_item[0])
# Try to remove dangling items if p0 cache is too large. # Try to remove dangling items if p0 cache is too large.
@ -469,6 +506,14 @@ class ShmObjectStoreSenderCache(BaseMultiModalProcessorCache):
self._shm_cache.clear() self._shm_cache.clear()
self._p0_cache.clear() self._p0_cache.clear()
self._hits = 0
self._total = 0
self._last_info = CacheInfo(hits=0, total=0)
@override
def make_stats(self, *, delta: bool = False) -> CacheInfo:
    # Delegates to `_stat`, which builds the result from this object's own
    # `_hits` / `_total` counters.
    return self._stat(delta=delta)
def remove_dangling_items(self) -> None: def remove_dangling_items(self) -> None:
"""Remove items that are no longer in the shared memory cache.""" """Remove items that are no longer in the shared memory cache."""
cached_hashes = self._shm_cache.key_index.keys() cached_hashes = self._shm_cache.key_index.keys()

View File

@ -4,7 +4,7 @@
import copy import copy
import os import os
from collections import defaultdict, deque from collections import defaultdict
from collections.abc import Iterable, Sequence from collections.abc import Iterable, Sequence
from dataclasses import dataclass from dataclasses import dataclass
from typing import Any, Callable, NewType, Optional, Union from typing import Any, Callable, NewType, Optional, Union
@ -23,7 +23,6 @@ from vllm.v1.kv_cache_interface import (
SlidingWindowSpec, SlidingWindowSpec,
UniformTypeKVCacheSpecs, UniformTypeKVCacheSpecs,
) )
from vllm.v1.metrics.stats import PrefixCacheStats
from vllm.v1.request import Request from vllm.v1.request import Request
# BlockHash represents the hash of a single KV-cache block used for # BlockHash represents the hash of a single KV-cache block used for
@ -101,78 +100,6 @@ def init_none_hash(hash_fn: Callable[[Any], bytes]):
NONE_HASH = BlockHash(hash_fn(hash_seed)) NONE_HASH = BlockHash(hash_fn(hash_seed))
class PrefixCachingMetrics:
    """Sliding-window hit-rate tracker for prefix caching.

    The hit rate is computed over the most recently observed stats
    batches, trimmed so that they cover roughly the last
    ``max_recent_requests`` requests.

    Args:
        max_recent_requests: Upper bound on the number of recent requests
            to aggregate. Defaults to 1000.
    """

    def __init__(self, max_recent_requests: int = 1000):
        self.max_recent_requests = max_recent_requests
        # Running totals over the batches currently held in the window.
        self.aggregated_requests = 0
        self.aggregated_query_total = 0
        self.aggregated_query_hit = 0
        # One (requests, queries, hits) entry per observed batch, oldest first.
        self.query_queue: deque[tuple[int, int, int]] = deque()

    def observe(self, stats: PrefixCacheStats):
        """Fold one batch of prefix cache stats into the window.

        Once the window covers more than ``max_recent_requests`` requests,
        the oldest batches are evicted.

        Args:
            stats: The prefix cache stats to record.
        """
        # A reset happened before this batch was gathered, so start from
        # an empty window before counting it.
        if stats.reset:
            self.reset()

        # Skip empty batches: appending them would only push useful
        # entries out of the window.
        if stats.requests == 0:
            return

        batch = (stats.requests, stats.queries, stats.hits)
        self.query_queue.append(batch)
        self._apply(batch, sign=1)

        # Evict from the oldest end while over budget, but always keep the
        # batch that was just added.
        while (
            self.aggregated_requests > self.max_recent_requests
            and len(self.query_queue) > 1
        ):
            self._apply(self.query_queue.popleft(), sign=-1)

    def _apply(self, batch: tuple[int, int, int], *, sign: int) -> None:
        """Add (sign=1) or subtract (sign=-1) a batch from the running totals."""
        requests, queries, hits = batch
        self.aggregated_requests += sign * requests
        self.aggregated_query_total += sign * queries
        self.aggregated_query_hit += sign * hits

    def reset(self):
        """Drop every recorded batch and zero the running totals."""
        self.query_queue.clear()
        self.aggregated_requests = 0
        self.aggregated_query_total = 0
        self.aggregated_query_hit = 0

    @property
    def hit_rate(self) -> float:
        """Hit rate over the current window, or 0.0 with no queries."""
        total = self.aggregated_query_total
        return self.aggregated_query_hit / total if total != 0 else 0.0
@dataclass @dataclass
class KVCacheBlock: class KVCacheBlock:
"""KV-cache block metadata.""" """KV-cache block metadata."""

View File

@ -463,6 +463,7 @@ class AsyncLLM(EngineClient):
output_processor = self.output_processor output_processor = self.output_processor
log_stats = self.log_stats log_stats = self.log_stats
logger_manager = self.logger_manager logger_manager = self.logger_manager
processor = self.processor
async def output_handler(): async def output_handler():
try: try:
@ -511,6 +512,7 @@ class AsyncLLM(EngineClient):
engine_idx=outputs.engine_index, engine_idx=outputs.engine_index,
scheduler_stats=outputs.scheduler_stats, scheduler_stats=outputs.scheduler_stats,
iteration_stats=iteration_stats, iteration_stats=iteration_stats,
mm_cache_stats=processor.stat_mm_cache(),
) )
except Exception as e: except Exception as e:
logger.exception("AsyncLLM output_handler failed.") logger.exception("AsyncLLM output_handler failed.")
@ -660,7 +662,7 @@ class AsyncLLM(EngineClient):
await asyncio.gather(*coros) await asyncio.gather(*coros)
async def reset_mm_cache(self) -> None: async def reset_mm_cache(self) -> None:
self.processor.clear_cache() self.processor.clear_mm_cache()
await self.engine_core.reset_mm_cache_async() await self.engine_core.reset_mm_cache_async()
async def reset_prefix_cache(self, device: Optional[Device] = None) -> None: async def reset_prefix_cache(self, device: Optional[Device] = None) -> None:

View File

@ -319,7 +319,7 @@ class EngineCore:
) )
engine_core_outputs = self.scheduler.update_from_output( engine_core_outputs = self.scheduler.update_from_output(
scheduler_output, model_output scheduler_output, model_output
) # type: ignore )
return (engine_core_outputs, scheduler_output.total_num_scheduled_tokens > 0) return (engine_core_outputs, scheduler_output.total_num_scheduled_tokens > 0)
@ -400,16 +400,19 @@ class EngineCore:
def reset_mm_cache(self): def reset_mm_cache(self):
# NOTE: Since this is mainly for debugging, we don't attempt to # NOTE: Since this is mainly for debugging, we don't attempt to
# re-sync the internal caches (P0 processor, P0 mirror, P1 mirror) # re-sync the internal caches (P0 sender, P1 receiver)
if self.scheduler.has_unfinished_requests(): if self.scheduler.has_unfinished_requests():
logger.warning( logger.warning(
"Resetting the multi-modal cache when requests are " "Resetting the multi-modal cache when requests are "
"in progress may lead to desynced internal caches." "in progress may lead to desynced internal caches."
) )
# The cache either exists in EngineCore or WorkerWrapperBase
if self.mm_receiver_cache is not None: if self.mm_receiver_cache is not None:
self.mm_receiver_cache.clear_cache() self.mm_receiver_cache.clear_cache()
self.model_executor.reset_mm_cache()
def reset_prefix_cache(self): def reset_prefix_cache(self):
self.scheduler.reset_prefix_cache() self.scheduler.reset_prefix_cache()

View File

@ -306,9 +306,11 @@ class LLMEngine:
# 4) Record stats # 4) Record stats
if self.logger_manager is not None: if self.logger_manager is not None:
assert outputs.scheduler_stats is not None assert outputs.scheduler_stats is not None
self.logger_manager.record( self.logger_manager.record(
scheduler_stats=outputs.scheduler_stats, scheduler_stats=outputs.scheduler_stats,
iteration_stats=iteration_stats, iteration_stats=iteration_stats,
mm_cache_stats=self.processor.stat_mm_cache(),
) )
self.do_log_stats_with_interval() self.do_log_stats_with_interval()
@ -321,7 +323,7 @@ class LLMEngine:
self.engine_core.profile(False) self.engine_core.profile(False)
def reset_mm_cache(self): def reset_mm_cache(self):
self.processor.clear_cache() self.processor.clear_mm_cache()
self.engine_core.reset_mm_cache() self.engine_core.reset_mm_cache()
def reset_prefix_cache(self, device: Optional[Device] = None): def reset_prefix_cache(self, device: Optional[Device] = None):

View File

@ -21,6 +21,7 @@ from vllm.sampling_params import SamplingParams
from vllm.transformers_utils.tokenizer import AnyTokenizer from vllm.transformers_utils.tokenizer import AnyTokenizer
from vllm.utils import length_from_prompt_token_ids_or_embeds from vllm.utils import length_from_prompt_token_ids_or_embeds
from vllm.v1.engine import EngineCoreRequest from vllm.v1.engine import EngineCoreRequest
from vllm.v1.metrics.stats import MultiModalCacheStats
from vllm.v1.structured_output.backend_guidance import validate_guidance_grammar from vllm.v1.structured_output.backend_guidance import validate_guidance_grammar
from vllm.v1.structured_output.backend_lm_format_enforcer import ( from vllm.v1.structured_output.backend_lm_format_enforcer import (
validate_structured_output_request_lm_format_enforcer, validate_structured_output_request_lm_format_enforcer,
@ -573,5 +574,8 @@ class Processor:
# check that chunked prefill does not truncate them # check that chunked prefill does not truncate them
# max_batch_len = self.scheduler_config.max_num_batched_tokens # max_batch_len = self.scheduler_config.max_num_batched_tokens
def clear_cache(self) -> None: def stat_mm_cache(self) -> Optional[MultiModalCacheStats]:
self.input_preprocessor.clear_cache() return self.input_preprocessor.stat_mm_cache()
def clear_mm_cache(self) -> None:
self.input_preprocessor.clear_mm_cache()

View File

@ -33,8 +33,6 @@ from vllm.distributed.parallel_state import (
get_tp_group, get_tp_group,
) )
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.cache import worker_receiver_cache_from_config
from vllm.utils import ( from vllm.utils import (
_maybe_force_spawn, _maybe_force_spawn,
decorate_logs, decorate_logs,
@ -46,7 +44,6 @@ from vllm.utils import (
) )
from vllm.v1.core.sched.output import SchedulerOutput from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.executor.abstract import Executor, FailureCallback from vllm.v1.executor.abstract import Executor, FailureCallback
from vllm.v1.executor.utils import get_and_update_mm_cache
from vllm.v1.outputs import AsyncModelRunnerOutput, DraftTokenIds, ModelRunnerOutput from vllm.v1.outputs import AsyncModelRunnerOutput, DraftTokenIds, ModelRunnerOutput
from vllm.v1.worker.worker_base import WorkerWrapperBase from vllm.v1.worker.worker_base import WorkerWrapperBase
@ -422,6 +419,7 @@ class WorkerProc:
"rank": rank, "rank": rank,
"distributed_init_method": distributed_init_method, "distributed_init_method": distributed_init_method,
"is_driver_worker": is_driver_worker, "is_driver_worker": is_driver_worker,
"shared_worker_lock": shared_worker_lock,
} }
wrapper.init_worker(all_kwargs) wrapper.init_worker(all_kwargs)
self.worker = wrapper self.worker = wrapper
@ -445,11 +443,6 @@ class WorkerProc:
) )
self.async_output_copy_thread.start() self.async_output_copy_thread.start()
# Initialize multimodal receiver cache if needed
self.mm_receiver_cache = worker_receiver_cache_from_config(
vllm_config, MULTIMODAL_REGISTRY, shared_worker_lock
)
# Initialize device # Initialize device
self.worker.init_device() self.worker.init_device()
@ -692,12 +685,7 @@ class WorkerProc:
func = getattr(self.worker, method) func = getattr(self.worker, method)
elif isinstance(method, bytes): elif isinstance(method, bytes):
func = partial(cloudpickle.loads(method), self.worker) func = partial(cloudpickle.loads(method), self.worker)
# retrieve from shm cache if available
if (
self.mm_receiver_cache is not None
and func.__name__ == "execute_model"
):
get_and_update_mm_cache(self.mm_receiver_cache, args)
output = func(*args, **kwargs) output = func(*args, **kwargs)
except Exception as e: except Exception as e:
# Notes have been introduced in python 3.11 # Notes have been introduced in python 3.11

View File

@ -1,24 +0,0 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from vllm.multimodal.cache import ShmObjectStoreReceiverCache
from vllm.v1.core.sched.output import SchedulerOutput
def get_and_update_mm_cache(
    receiver_cache: ShmObjectStoreReceiverCache,
    args: tuple[SchedulerOutput],
) -> None:
    """
    Pass the multi-modal features of every newly scheduled request through
    the receiver cache and store the result back on the request.

    Args:
        receiver_cache: The receiver cache to query and update.
        args: Positional arguments of the `execute_model` collective RPC;
            only the first element (the SchedulerOutput) is read.
    """
    new_reqs = args[0].scheduled_new_reqs
    for new_req in new_reqs:
        refreshed = receiver_cache.get_and_update_features(new_req.mm_features)
        new_req.mm_features = refreshed

View File

@ -11,10 +11,14 @@ import prometheus_client
from vllm.config import SupportsMetricsInfo, VllmConfig from vllm.config import SupportsMetricsInfo, VllmConfig
from vllm.distributed.kv_transfer.kv_connector.v1.metrics import KVConnectorLogging from vllm.distributed.kv_transfer.kv_connector.v1.metrics import KVConnectorLogging
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.v1.core.kv_cache_utils import PrefixCachingMetrics
from vllm.v1.engine import FinishReason from vllm.v1.engine import FinishReason
from vllm.v1.metrics.prometheus import unregister_vllm_metrics from vllm.v1.metrics.prometheus import unregister_vllm_metrics
from vllm.v1.metrics.stats import IterationStats, SchedulerStats from vllm.v1.metrics.stats import (
CachingMetrics,
IterationStats,
MultiModalCacheStats,
SchedulerStats,
)
from vllm.v1.spec_decode.metrics import SpecDecodingLogging, SpecDecodingProm from vllm.v1.spec_decode.metrics import SpecDecodingLogging, SpecDecodingProm
logger = init_logger(__name__) logger = init_logger(__name__)
@ -38,6 +42,7 @@ class StatLoggerBase(ABC):
self, self,
scheduler_stats: Optional[SchedulerStats], scheduler_stats: Optional[SchedulerStats],
iteration_stats: Optional[IterationStats], iteration_stats: Optional[IterationStats],
mm_cache_stats: Optional[MultiModalCacheStats] = None,
engine_idx: int = 0, engine_idx: int = 0,
): ... ): ...
@ -53,10 +58,15 @@ class LoggingStatLogger(StatLoggerBase):
self.engine_index = engine_index self.engine_index = engine_index
self.vllm_config = vllm_config self.vllm_config = vllm_config
self._reset(time.monotonic()) self._reset(time.monotonic())
self.last_scheduler_stats = SchedulerStats() self.last_scheduler_stats = SchedulerStats()
# Prefix cache metrics. This cannot be reset. self.last_mm_cache_stats: Optional[MultiModalCacheStats] = None
# Caching metrics. This cannot be reset.
# TODO: Make the interval configurable. # TODO: Make the interval configurable.
self.prefix_caching_metrics = PrefixCachingMetrics() self.prefix_caching_metrics = CachingMetrics()
self.mm_caching_metrics = CachingMetrics()
self.spec_decoding_logging = SpecDecodingLogging() self.spec_decoding_logging = SpecDecodingLogging()
kv_tranfer_config = self.vllm_config.kv_transfer_config kv_tranfer_config = self.vllm_config.kv_transfer_config
self.kv_connector_logging = KVConnectorLogging(kv_tranfer_config) self.kv_connector_logging = KVConnectorLogging(kv_tranfer_config)
@ -86,6 +96,7 @@ class LoggingStatLogger(StatLoggerBase):
self, self,
scheduler_stats: Optional[SchedulerStats], scheduler_stats: Optional[SchedulerStats],
iteration_stats: Optional[IterationStats], iteration_stats: Optional[IterationStats],
mm_cache_stats: Optional[MultiModalCacheStats] = None,
engine_idx: int = 0, engine_idx: int = 0,
): ):
"""Log Stats to standard output.""" """Log Stats to standard output."""
@ -101,6 +112,11 @@ class LoggingStatLogger(StatLoggerBase):
self.kv_connector_logging.observe(kv_connector_stats) self.kv_connector_logging.observe(kv_connector_stats)
self.last_scheduler_stats = scheduler_stats self.last_scheduler_stats = scheduler_stats
if mm_cache_stats:
self.mm_caching_metrics.observe(mm_cache_stats)
self.last_mm_cache_stats = mm_cache_stats
def log(self): def log(self):
now = time.monotonic() now = time.monotonic()
prompt_throughput = self._get_throughput(self.num_prompt_tokens, now) prompt_throughput = self._get_throughput(self.num_prompt_tokens, now)
@ -125,21 +141,32 @@ class LoggingStatLogger(StatLoggerBase):
self.last_prompt_throughput = prompt_throughput self.last_prompt_throughput = prompt_throughput
# Format and print output. # Format and print output.
log_fn( log_parts = [
"Engine %03d: " "Avg prompt throughput: %.1f tokens/s",
"Avg prompt throughput: %.1f tokens/s, " "Avg generation throughput: %.1f tokens/s",
"Avg generation throughput: %.1f tokens/s, " "Running: %d reqs",
"Running: %d reqs, Waiting: %d reqs, " "Waiting: %d reqs",
"GPU KV cache usage: %.1f%%, " "GPU KV cache usage: %.1f%%",
"Prefix cache hit rate: %.1f%%", "Prefix cache hit rate: %.1f%%",
self.engine_index, ]
log_args = [
prompt_throughput, prompt_throughput,
generation_throughput, generation_throughput,
scheduler_stats.num_running_reqs, scheduler_stats.num_running_reqs,
scheduler_stats.num_waiting_reqs, scheduler_stats.num_waiting_reqs,
scheduler_stats.kv_cache_usage * 100, scheduler_stats.kv_cache_usage * 100,
self.prefix_caching_metrics.hit_rate * 100, self.prefix_caching_metrics.hit_rate * 100,
]
if self.last_mm_cache_stats:
log_parts.append("MM cache hit rate: %.1f%%")
log_args.append(self.mm_caching_metrics.hit_rate * 100)
log_fn(
"Engine %03d: " + ", ".join(log_parts),
self.engine_index,
*log_args,
) )
self.spec_decoding_logging.log(log_fn=log_fn) self.spec_decoding_logging.log(log_fn=log_fn)
self.kv_connector_logging.log(log_fn=log_fn) self.kv_connector_logging.log(log_fn=log_fn)
@ -288,6 +315,32 @@ class PrometheusStatLogger(StatLoggerBase):
counter_prefix_cache_hits, engine_indexes, model_name counter_prefix_cache_hits, engine_indexes, model_name
) )
#
# Multi-modal cache
#
counter_mm_cache_queries = self._counter_cls(
name="vllm:mm_cache_queries",
documentation=(
"Multi-modal cache queries, in terms of number of queried items."
),
labelnames=labelnames,
)
self.counter_mm_cache_queries = make_per_engine(
counter_mm_cache_queries, engine_indexes, model_name
)
counter_mm_cache_hits = self._counter_cls(
name="vllm:mm_cache_hits",
documentation=(
"Multi-modal cache hits, in terms of number of cached items."
),
labelnames=labelnames,
)
self.counter_mm_cache_hits = make_per_engine(
counter_mm_cache_hits, engine_indexes, model_name
)
# #
# Counters # Counters
# #
@ -657,6 +710,7 @@ class PrometheusStatLogger(StatLoggerBase):
self, self,
scheduler_stats: Optional[SchedulerStats], scheduler_stats: Optional[SchedulerStats],
iteration_stats: Optional[IterationStats], iteration_stats: Optional[IterationStats],
mm_cache_stats: Optional[MultiModalCacheStats] = None,
engine_idx: int = 0, engine_idx: int = 0,
): ):
"""Log to prometheus.""" """Log to prometheus."""
@ -694,6 +748,10 @@ class PrometheusStatLogger(StatLoggerBase):
scheduler_stats.spec_decoding_stats, engine_idx scheduler_stats.spec_decoding_stats, engine_idx
) )
if mm_cache_stats is not None:
self.counter_mm_cache_queries[engine_idx].inc(mm_cache_stats.queries)
self.counter_mm_cache_hits[engine_idx].inc(mm_cache_stats.hits)
if iteration_stats is None: if iteration_stats is None:
return return
@ -871,6 +929,7 @@ class StatLoggerManager:
self, self,
scheduler_stats: Optional[SchedulerStats], scheduler_stats: Optional[SchedulerStats],
iteration_stats: Optional[IterationStats], iteration_stats: Optional[IterationStats],
mm_cache_stats: Optional[MultiModalCacheStats] = None,
engine_idx: Optional[int] = None, engine_idx: Optional[int] = None,
): ):
if engine_idx is None: if engine_idx is None:
@ -878,9 +937,19 @@ class StatLoggerManager:
per_engine_loggers = self.per_engine_logger_dict[engine_idx] per_engine_loggers = self.per_engine_logger_dict[engine_idx]
for logger in per_engine_loggers: for logger in per_engine_loggers:
logger.record(scheduler_stats, iteration_stats, engine_idx) logger.record(
scheduler_stats,
iteration_stats,
mm_cache_stats=mm_cache_stats,
engine_idx=engine_idx,
)
self.prometheus_logger.record(scheduler_stats, iteration_stats, engine_idx) self.prometheus_logger.record(
scheduler_stats,
iteration_stats,
mm_cache_stats=mm_cache_stats,
engine_idx=engine_idx,
)
def log(self): def log(self):
for per_engine_loggers in self.per_engine_logger_dict.values(): for per_engine_loggers in self.per_engine_logger_dict.values():

View File

@ -2,6 +2,7 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import time import time
from collections import deque
from dataclasses import dataclass, field from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Any, Optional from typing import TYPE_CHECKING, Any, Optional
@ -13,24 +14,122 @@ if TYPE_CHECKING:
@dataclass @dataclass
class PrefixCacheStats: class BaseCacheStats:
"""Stores prefix cache hit statistics.""" """Stores cache hit statistics."""
# Whether reset_prefix_cache was invoked.
reset: bool = False reset: bool = False
# The number of new requests in this update. """Whether the cache was reset."""
requests: int = 0 requests: int = 0
# The number of queries in these requests. Note that "queries" here """The number of requests in this update."""
# means the number of tokens that were queried from the cache.
queries: int = 0 queries: int = 0
# The number of hits in these requests. """The number of queries in these requests."""
hits: int = 0 hits: int = 0
# The number of previously preempted requests in this update. """The number of hits in these requests."""
class CachingMetrics:
    """Metrics for caching with a hit rate of the most recent N requests.

    Args:
        max_recent_requests: The number of the most recent requests to
            aggregate. Defaults to 1000.
    """

    def __init__(self, max_recent_requests: int = 1000) -> None:
        super().__init__()

        self.max_recent_requests = max_recent_requests
        # The current aggregated values.
        self.aggregated_requests = 0
        self.aggregated_query_total = 0
        self.aggregated_query_hit = 0

        # A deque of (requests, queries, hits) for the most recent requests.
        self.query_queue = deque[tuple[int, int, int]]()

    def observe(self, stats: BaseCacheStats) -> None:
        """Fold one update of cache statistics into the sliding window.

        When the window holds more than `max_recent_requests` requests,
        the oldest entries are dropped from the aggregated values.

        Args:
            stats: The cache stats of the current update. Only its `reset`,
                `requests`, `queries` and `hits` fields are read.
        """
        # The cache was reset before the current update.
        # Reset the metrics before aggregating the current stats.
        if stats.reset:
            self.reset()

        # Do NOT append empty stats, so that useful info does not get
        # kicked out of the sliding window.
        if stats.requests == 0:
            return

        # Update the metrics.
        self.query_queue.append((stats.requests, stats.queries, stats.hits))
        self.aggregated_requests += stats.requests
        self.aggregated_query_total += stats.queries
        self.aggregated_query_hit += stats.hits

        # Remove the oldest stats until the number of requests does not
        # exceed the limit.
        # NOTE: We preserve the latest added stats regardless.
        while (
            len(self.query_queue) > 1
            and self.aggregated_requests > self.max_recent_requests
        ):
            old_requests, old_queries, old_hits = self.query_queue.popleft()
            self.aggregated_requests -= old_requests
            self.aggregated_query_total -= old_queries
            self.aggregated_query_hit -= old_hits

    def reset(self) -> None:
        """Reset the metrics."""
        self.aggregated_requests = 0
        self.aggregated_query_total = 0
        self.aggregated_query_hit = 0
        self.query_queue.clear()

    @property
    def hit_rate(self) -> float:
        """Calculate the hit rate for the past N requests.

        Returns 0.0 when no queries have been aggregated.
        """
        if self.aggregated_query_total == 0:
            return 0.0
        return self.aggregated_query_hit / self.aggregated_query_total
@dataclass
class PrefixCacheStats(BaseCacheStats):
"""
Stores prefix cache hit statistics.
- `reset`: Whether `reset_prefix_cache` was invoked.
- `queries`: Refers to the number of tokens that were queried.
"""
preempted_requests: int = 0 preempted_requests: int = 0
# The `queries` number for preempted requests. """The number of previously preempted requests in this update."""
preempted_queries: int = 0 preempted_queries: int = 0
# The `hits` number for preempted requests. """The `queries` number for preempted requests."""
preempted_hits: int = 0 preempted_hits: int = 0
"""The `hits` number for preempted requests."""
@dataclass
class MultiModalCacheStats(BaseCacheStats):
"""
Stores multi-modal cache hit statistics.
- `reset`: Whether `reset_mm_cache` was invoked.
- `queries`: Refers to the number of multi-modal data items
that were queried.
"""
@dataclass @dataclass

View File

@ -508,6 +508,10 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
pin_memory=self.pin_memory, pin_memory=self.pin_memory,
) )
def reset_mm_cache(self) -> None:
    """Reset the cache of `self.mm_budget`; do nothing if it is falsy."""
    budget = self.mm_budget
    if not budget:
        return
    budget.reset_cache()
def _get_positions(self, num_tokens: Any): def _get_positions(self, num_tokens: Any):
if isinstance(num_tokens, int): if isinstance(num_tokens, int):
if self.uses_mrope: if self.uses_mrope:

View File

@ -442,6 +442,9 @@ class Worker(WorkerBase):
# the model initialization and profiling. # the model initialization and profiling.
set_random_seed(self.model_config.seed) set_random_seed(self.model_config.seed)
def reset_mm_cache(self) -> None:
    """Forward the multi-modal cache reset to the model runner."""
    runner = self.model_runner
    runner.reset_mm_cache()
def get_model(self) -> nn.Module: def get_model(self) -> nn.Module:
return self.model_runner.get_model() return self.model_runner.get_model()

View File

@ -371,6 +371,10 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
else: else:
self.sample_from_logits_func = self.sample_from_logits self.sample_from_logits_func = self.sample_from_logits
def reset_mm_cache(self) -> None:
    """Reset the cache of `self.mm_budget`; do nothing if it is falsy."""
    budget = self.mm_budget
    if not budget:
        return
    budget.reset_cache()
def _update_num_xla_graphs(self, case_str): def _update_num_xla_graphs(self, case_str):
check_comp = self.check_recompilation and not self.enforce_eager check_comp = self.check_recompilation and not self.enforce_eager
if not check_comp: if not check_comp:

View File

@ -293,6 +293,9 @@ class TPUWorker:
# the model initialization and profiling. # the model initialization and profiling.
set_random_seed(self.model_config.seed) set_random_seed(self.model_config.seed)
def reset_mm_cache(self) -> None:
    """Forward the multi-modal cache reset to the model runner."""
    runner = self.model_runner
    runner.reset_mm_cache()
def get_model(self) -> nn.Module: def get_model(self) -> nn.Module:
return self.model_runner.get_model() return self.model_runner.get_model()

View File

@ -126,6 +126,10 @@ class MultiModalBudget:
return max_items_per_prompt, max_items_per_batch return max_items_per_prompt, max_items_per_batch
def reset_cache(self) -> None:
    """Clear `self.cache`; do nothing when no cache is set (`None`)."""
    cache = self.cache
    if cache is None:
        return
    cache.clear_cache()
@dataclass @dataclass
class AttentionGroup: class AttentionGroup:

View File

@ -4,7 +4,7 @@
from __future__ import annotations from __future__ import annotations
import os import os
from typing import Any, Callable, TypeVar, Union from typing import TYPE_CHECKING, Any, Callable, TypeVar, Union
import torch import torch
import torch.nn as nn import torch.nn as nn
@ -12,7 +12,8 @@ import torch.nn as nn
from vllm.config import VllmConfig, set_current_vllm_config from vllm.config import VllmConfig, set_current_vllm_config
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.lora.request import LoRARequest from vllm.lora.request import LoRARequest
from vllm.sequence import ExecuteModelRequest from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.cache import worker_receiver_cache_from_config
from vllm.utils import ( from vllm.utils import (
enable_trace_function_call_for_thread, enable_trace_function_call_for_thread,
resolve_obj_by_qualname, resolve_obj_by_qualname,
@ -21,7 +22,10 @@ from vllm.utils import (
warn_for_unimplemented_methods, warn_for_unimplemented_methods,
) )
from vllm.v1.kv_cache_interface import KVCacheSpec from vllm.v1.kv_cache_interface import KVCacheSpec
from vllm.v1.outputs import SamplerOutput
if TYPE_CHECKING:
from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.outputs import ModelRunnerOutput
logger = init_logger(__name__) logger = init_logger(__name__)
@ -103,6 +107,11 @@ class WorkerBase:
"""Initialize the KV cache with the given size in blocks.""" """Initialize the KV cache with the given size in blocks."""
raise NotImplementedError raise NotImplementedError
def reset_mm_cache(self) -> None:
    """Call the model runner's `reset_mm_cache` if it has a callable one.

    Runners without that attribute (or with a non-callable one) are
    left untouched.
    """
    runner = self.model_runner
    maybe_reset = getattr(runner, "reset_mm_cache", None)
    if not callable(maybe_reset):
        return
    maybe_reset()
def get_model(self) -> nn.Module: def get_model(self) -> nn.Module:
raise NotImplementedError raise NotImplementedError
@ -114,9 +123,7 @@ class WorkerBase:
"""Load model onto target device.""" """Load model onto target device."""
raise NotImplementedError raise NotImplementedError
def execute_model( def execute_model(self, scheduler_output: SchedulerOutput) -> ModelRunnerOutput:
self, execute_model_req: ExecuteModelRequest | None = None
) -> list[SamplerOutput] | None:
raise NotImplementedError raise NotImplementedError
def start_worker_execution_loop(self) -> None: def start_worker_execution_loop(self) -> None:
@ -125,11 +132,7 @@ class WorkerBase:
You can stop the loop by executing a driver worker with an empty output. You can stop the loop by executing a driver worker with an empty output.
See `stop_remote_worker_execution_loop` for more details. See `stop_remote_worker_execution_loop` for more details.
""" """
with self.current_platform.inference_mode(): raise NotImplementedError("Dead V0 code")
while True:
output = self.execute_model(execute_model_req=None)
if output is None:
return None
def determine_num_available_blocks(self) -> tuple[int, int]: def determine_num_available_blocks(self) -> tuple[int, int]:
"""Determine the number of available blocks for the GPU KV cache and """Determine the number of available blocks for the GPU KV cache and
@ -289,6 +292,28 @@ class WorkerWrapperBase:
worker_class, worker_class,
extended_calls, extended_calls,
) )
shared_worker_lock = kwargs.pop("shared_worker_lock", None)
if shared_worker_lock is None:
msg = (
"Missing `shared_worker_lock` argument from executor. "
"This argument is needed for mm_processor_cache_type='shm'."
)
mm_config = self.vllm_config.model_config.multimodal_config
if mm_config and mm_config.mm_processor_cache_type == "shm":
raise ValueError(msg)
else:
logger.warning_once(msg)
self.mm_receiver_cache = None
else:
self.mm_receiver_cache = worker_receiver_cache_from_config(
self.vllm_config,
MULTIMODAL_REGISTRY,
shared_worker_lock,
)
with set_current_vllm_config(self.vllm_config): with set_current_vllm_config(self.vllm_config):
# To make vLLM config available during worker initialization # To make vLLM config available during worker initialization
self.worker = worker_class(**kwargs) self.worker = worker_class(**kwargs)
@ -323,5 +348,34 @@ class WorkerWrapperBase:
logger.exception(msg) logger.exception(msg)
raise e raise e
def __getattr__(self, attr): def __getattr__(self, attr: str):
return getattr(self.worker, attr) return getattr(self.worker, attr)
def _apply_mm_cache(self, scheduler_output: SchedulerOutput) -> None:
mm_cache = self.mm_receiver_cache
if mm_cache is None:
return
for req_data in scheduler_output.scheduled_new_reqs:
req_data.mm_features = mm_cache.get_and_update_features(
req_data.mm_features
)
def execute_model(
    self,
    scheduler_output: SchedulerOutput,
    *args,
    **kwargs,
) -> ModelRunnerOutput:
    """Apply the multi-modal receiver cache, then run the wrapped worker.

    `scheduler_output` is first passed to `self._apply_mm_cache`, then
    forwarded together with any extra positional and keyword arguments
    to the worker's `execute_model`, whose result is returned.
    """
    self._apply_mm_cache(scheduler_output)

    worker = self.worker
    assert worker is not None
    return worker.execute_model(scheduler_output, *args, **kwargs)
def reset_mm_cache(self) -> None:
mm_receiver_cache = self.mm_receiver_cache
if mm_receiver_cache is not None:
mm_receiver_cache.clear_cache()
assert self.worker is not None
self.worker.reset_mm_cache()