Mirror of https://git.datalinker.icu/vllm-project/vllm.git (synced 2026-01-25 17:44:27 +08:00)
[Metrics] Log multi-modal cache stats and fix reset (#26285)
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
parent 6f0f570c43
commit ad430a67ca

tests/entrypoints/llm/test_mm_cache_stats.py (normal file, 74 lines)
@@ -0,0 +1,74 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import pytest

from vllm import LLM
from vllm.entrypoints.chat_utils import ChatCompletionMessageParam
from vllm.v1.metrics.reader import Counter, Metric

from ..openai.test_vision import TEST_IMAGE_ASSETS


def _make_messages(image_url: str) -> list[ChatCompletionMessageParam]:
    return [
        {
            "role": "user",
            "content": [
                {
                    "type": "image_url",
                    "image_url": {"url": image_url},
                },
            ],
        }
    ]


def _get_counter_value(metrics: list[Metric], name: str):
    metric = next(m for m in metrics if m.name == name)
    assert isinstance(metric, Counter)
    return metric.value


def _get_mm_cache_stats(metrics: list[Metric]):
    mm_cache_queries = _get_counter_value(metrics, "vllm:mm_cache_queries")
    mm_cache_hits = _get_counter_value(metrics, "vllm:mm_cache_hits")

    return mm_cache_queries, mm_cache_hits


@pytest.mark.parametrize("image_urls", [TEST_IMAGE_ASSETS[:2]], indirect=True)
@pytest.mark.parametrize("mm_processor_cache_type", ["lru", "shm"])
def test_mm_cache_stats(
    num_gpus_available,
    image_urls,
    mm_processor_cache_type,
):
    llm = LLM(
        model="llava-hf/llava-1.5-7b-hf",
        max_model_len=4096,
        max_num_seqs=5,
        enforce_eager=True,
        mm_processor_cache_type=mm_processor_cache_type,
        disable_log_stats=False,
        limit_mm_per_prompt={"image": 2},
    )

    llm.chat(_make_messages(image_urls[0]))
    assert _get_mm_cache_stats(llm.get_metrics()) == (1, 0)

    llm.chat(_make_messages(image_urls[1]))
    assert _get_mm_cache_stats(llm.get_metrics()) == (2, 0)

    llm.chat(_make_messages(image_urls[0]))
    assert _get_mm_cache_stats(llm.get_metrics()) == (3, 1)

    # NOTE: This only resets hit rate stats in CachingMetrics
    # The raw queries and hits counts remain unaffected
    llm.reset_mm_cache()

    llm.chat(_make_messages(image_urls[0]))
    assert _get_mm_cache_stats(llm.get_metrics()) == (4, 1)

    llm.chat(_make_messages(image_urls[1]))
    assert _get_mm_cache_stats(llm.get_metrics()) == (5, 1)
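As context for the new counters, the snippet below (not part of the diff) is a minimal sketch of reading them back through the same LLM.get_metrics() API the test uses and turning them into a hit rate; the helper name is ours, not a vLLM API.

# Sketch only: derive an overall multi-modal cache hit rate from the
# counters exposed via LLM.get_metrics().
from vllm.v1.metrics.reader import Counter

def mm_cache_hit_rate(llm) -> float:
    values = {
        m.name: m.value
        for m in llm.get_metrics()
        if isinstance(m, Counter)
        and m.name in ("vllm:mm_cache_queries", "vllm:mm_cache_hits")
    }
    queries = values.get("vllm:mm_cache_queries", 0)
    hits = values.get("vllm:mm_cache_hits", 0)
    return hits / queries if queries else 0.0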

@@ -18,10 +18,18 @@ from vllm import version

from ...utils import RemoteOpenAIServer

MODEL_NAME = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
MODELS = {
    "text": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
    "multimodal": "HuggingFaceTB/SmolVLM-256M-Instruct",
}
PREV_MINOR_VERSION = version._prev_minor_version()


@pytest.fixture(scope="module", params=list(MODELS.keys()))
def model_key(request):
    yield request.param


@pytest.fixture(scope="module")
def default_server_args():
    return [

@@ -45,11 +53,12 @@ def default_server_args():
        f"--show-hidden-metrics-for-version={PREV_MINOR_VERSION}",
    ],
)
def server(default_server_args, request):
def server(model_key, default_server_args, request):
    if request.param:
        default_server_args.append(request.param)

    with RemoteOpenAIServer(MODEL_NAME, default_server_args) as remote_server:
    model_name = MODELS[model_key]
    with RemoteOpenAIServer(model_name, default_server_args) as remote_server:
        yield remote_server


@@ -60,64 +69,70 @@ async def client(server):

_PROMPT = "Hello my name is Robert and I love magic"
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
_TOKENIZED_PROMPT = tokenizer(_PROMPT)["input_ids"]
_IMAGE_URL = "https://upload.wikimedia.org/wikipedia/commons/thumb/d/dd/Gfp-wisconsin-madison-the-nature-boardwalk.jpg/2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg"

_NUM_REQUESTS = 10
_NUM_PROMPT_TOKENS_PER_REQUEST = len(_TOKENIZED_PROMPT)
_NUM_GENERATION_TOKENS_PER_REQUEST = 10

# {metric_family: [(suffix, expected_value)]}
EXPECTED_VALUES = {
    "vllm:time_to_first_token_seconds": [("_count", _NUM_REQUESTS)],
    "vllm:time_per_output_token_seconds": [
        ("_count", _NUM_REQUESTS * (_NUM_GENERATION_TOKENS_PER_REQUEST - 1))
    ],
    "vllm:e2e_request_latency_seconds": [("_count", _NUM_REQUESTS)],
    "vllm:request_queue_time_seconds": [("_count", _NUM_REQUESTS)],
    "vllm:request_inference_time_seconds": [("_count", _NUM_REQUESTS)],
    "vllm:request_prefill_time_seconds": [("_count", _NUM_REQUESTS)],
    "vllm:request_decode_time_seconds": [("_count", _NUM_REQUESTS)],
    "vllm:request_prompt_tokens": [
        ("_sum", _NUM_REQUESTS * _NUM_PROMPT_TOKENS_PER_REQUEST),
        ("_count", _NUM_REQUESTS),
    ],
    "vllm:request_generation_tokens": [
        ("_sum", _NUM_REQUESTS * _NUM_GENERATION_TOKENS_PER_REQUEST),
        ("_count", _NUM_REQUESTS),
    ],
    "vllm:request_params_n": [("_count", _NUM_REQUESTS)],
    "vllm:request_params_max_tokens": [
        ("_sum", _NUM_REQUESTS * _NUM_GENERATION_TOKENS_PER_REQUEST),
        ("_count", _NUM_REQUESTS),
    ],
    "vllm:iteration_tokens_total": [
        (
            "_sum",
            _NUM_REQUESTS
            * (_NUM_PROMPT_TOKENS_PER_REQUEST + _NUM_GENERATION_TOKENS_PER_REQUEST),
        ),
        ("_count", _NUM_REQUESTS * _NUM_GENERATION_TOKENS_PER_REQUEST),
    ],
    "vllm:prompt_tokens": [("_total", _NUM_REQUESTS * _NUM_PROMPT_TOKENS_PER_REQUEST)],
    "vllm:generation_tokens": [
        ("_total", _NUM_REQUESTS * _NUM_PROMPT_TOKENS_PER_REQUEST)
    ],
    "vllm:request_success": [("_total", _NUM_REQUESTS)],
}
def _get_expected_values(num_requests: int, prompt_ids: list[int], max_tokens: int):
    num_prompt_tokens = len(prompt_ids)

    # {metric_family: [(suffix, expected_value)]}
    return {
        "vllm:time_to_first_token_seconds": [("_count", num_requests)],
        "vllm:time_per_output_token_seconds": [
            ("_count", num_requests * (max_tokens - 1))
        ],
        "vllm:e2e_request_latency_seconds": [("_count", num_requests)],
        "vllm:request_queue_time_seconds": [("_count", num_requests)],
        "vllm:request_inference_time_seconds": [("_count", num_requests)],
        "vllm:request_prefill_time_seconds": [("_count", num_requests)],
        "vllm:request_decode_time_seconds": [("_count", num_requests)],
        "vllm:request_prompt_tokens": [
            ("_sum", num_requests * num_prompt_tokens),
            ("_count", num_requests),
        ],
        "vllm:request_generation_tokens": [
            ("_sum", num_requests * max_tokens),
            ("_count", num_requests),
        ],
        "vllm:request_params_n": [("_count", num_requests)],
        "vllm:request_params_max_tokens": [
            ("_sum", num_requests * max_tokens),
            ("_count", num_requests),
        ],
        "vllm:iteration_tokens_total": [
            (
                "_sum",
                num_requests * (num_prompt_tokens + max_tokens),
            ),
            ("_count", num_requests * max_tokens),
        ],
        "vllm:prompt_tokens": [("_total", num_requests * num_prompt_tokens)],
        "vllm:generation_tokens": [("_total", num_requests * max_tokens)],
        "vllm:request_success": [("_total", num_requests)],
    }


@pytest.mark.asyncio
async def test_metrics_counts(
    server: RemoteOpenAIServer,
    client: openai.AsyncClient,
    model_key: str,
):
    for _ in range(_NUM_REQUESTS):
    if model_key == "multimodal":
        pytest.skip("Unnecessary test")

    model_name = MODELS[model_key]
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    prompt_ids = tokenizer.encode(_PROMPT)
    num_requests = 10
    max_tokens = 10

    for _ in range(num_requests):
        # sending a request triggers the metrics to be logged.
        await client.completions.create(
            model=MODEL_NAME,
            prompt=_TOKENIZED_PROMPT,
            max_tokens=_NUM_GENERATION_TOKENS_PER_REQUEST,
            model=model_name,
            prompt=prompt_ids,
            max_tokens=max_tokens,
        )

    response = requests.get(server.url_for("metrics"))

@@ -125,8 +140,9 @@ async def test_metrics_counts(
    assert response.status_code == HTTPStatus.OK

    # Loop over all expected metric_families
    for metric_family, suffix_values_list in EXPECTED_VALUES.items():
        if (metric_family not in EXPECTED_METRICS_V1) or (
    expected_values = _get_expected_values(num_requests, prompt_ids, max_tokens)
    for metric_family, suffix_values_list in expected_values.items():
        if metric_family not in EXPECTED_METRICS_V1 or (
            not server.show_hidden_metrics
            and metric_family in HIDDEN_DEPRECATED_METRICS
        ):

@@ -217,6 +233,11 @@ EXPECTED_METRICS_V1 = [
    "vllm:request_decode_time_seconds_count",
]

EXPECTED_METRICS_MM = [
    "vllm:mm_cache_queries",
    "vllm:mm_cache_hits",
]

HIDDEN_DEPRECATED_METRICS: list[str] = [
    "vllm:gpu_cache_usage_perc",
    "vllm:gpu_prefix_cache_queries",

@@ -231,19 +252,43 @@ HIDDEN_DEPRECATED_METRICS: list[str] = [
async def test_metrics_exist(
    server: RemoteOpenAIServer,
    client: openai.AsyncClient,
    model_key: str,
):
    model_name = MODELS[model_key]

    # sending a request triggers the metrics to be logged.
    await client.completions.create(
        model=MODEL_NAME,
        prompt="Hello, my name is",
        max_tokens=5,
        temperature=0.0,
    )
    if model_key == "text":
        await client.completions.create(
            model=model_name,
            prompt="Hello, my name is",
            max_tokens=5,
            temperature=0.0,
        )
    else:
        await client.chat.completions.create(
            model=model_name,
            messages=[
                {
                    "role": "user",
                    "content": [
                        {"type": "image_url", "image_url": {"url": _IMAGE_URL}},
                        {"type": "text", "text": "What's in this image?"},
                    ],
                }
            ],
            max_tokens=5,
            temperature=0.0,
        )

    response = requests.get(server.url_for("metrics"))
    assert response.status_code == HTTPStatus.OK

    for metric in EXPECTED_METRICS_V1:
    expected_metrics = EXPECTED_METRICS_V1
    if model_key == "multimodal":
        # NOTE: Don't use in-place assignment
        expected_metrics = expected_metrics + EXPECTED_METRICS_MM

    for metric in expected_metrics:
        if metric in HIDDEN_DEPRECATED_METRICS and not server.show_hidden_metrics:
            continue
        assert metric in response.text

@@ -253,9 +298,14 @@ async def test_metrics_exist(
async def test_abort_metrics_reset(
    server: RemoteOpenAIServer,
    client: openai.AsyncClient,
    model_key: str,
):
    model_name = MODELS[model_key]
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    prompt_ids = tokenizer.encode(_PROMPT)

    running_requests, waiting_requests, kv_cache_usage = _get_running_metrics_from_api(
        server
        server,
    )

    # Expect no running requests or kvcache usage

@@ -268,8 +318,8 @@ async def test_abort_metrics_reset(
    for _ in range(3):
        task = asyncio.create_task(
            client.completions.create(
                model=MODEL_NAME,
                prompt=_TOKENIZED_PROMPT,
                model=model_name,
                prompt=prompt_ids,
                max_tokens=100,  # Long generation to give time to abort
                temperature=0.0,
            )

@@ -281,7 +331,7 @@ async def test_abort_metrics_reset(

    # Check that we have running requests
    running_requests, waiting_requests, kv_cache_usage = _get_running_metrics_from_api(
        server
        server,
    )

    # Expect running requests and kvcache usage

@@ -20,7 +20,6 @@ from vllm.v1.core.kv_cache_utils import (
    BlockHash,
    FreeKVCacheBlockQueue,
    KVCacheBlock,
    PrefixCachingMetrics,
    estimate_max_model_len,
    generate_block_hash_extra_keys,
    generate_scheduler_kv_cache_config,

@@ -42,7 +41,7 @@ from vllm.v1.kv_cache_interface import (
    SlidingWindowSpec,
    UniformTypeKVCacheSpecs,
)
from vllm.v1.metrics.stats import PrefixCacheStats
from vllm.v1.metrics.stats import CachingMetrics, PrefixCacheStats
from vllm.v1.request import Request

pytestmark = pytest.mark.cpu_test

@@ -536,7 +535,7 @@ def test_metrics():
    """
    Test the prefix caching metrics.
    """
    metrics = PrefixCachingMetrics(max_recent_requests=5)
    metrics = CachingMetrics(max_recent_requests=5)
    assert metrics.hit_rate == 0.0

    metrics.observe(_stats(1, 20, 9))

@@ -568,7 +567,7 @@ def test_metrics_empty_stats():
    """
    Test the prefix caching metrics with empty stats.
    """
    metrics = PrefixCachingMetrics(max_recent_requests=5)
    metrics = CachingMetrics(max_recent_requests=5)
    metrics.observe(_stats(0, 0, 0))
    metrics.observe(_stats(1, 20, 9))
    metrics.observe(_stats(0, 0, 0))

@@ -17,7 +17,7 @@ from vllm.sampling_params import RequestOutputKind
from vllm.v1.engine.async_llm import AsyncLLM
from vllm.v1.engine.core_client import DPAsyncMPClient
from vllm.v1.metrics.loggers import StatLoggerBase
from vllm.v1.metrics.stats import IterationStats, SchedulerStats
from vllm.v1.metrics.stats import IterationStats, MultiModalCacheStats, SchedulerStats

DP_SIZE = int(os.getenv("DP_SIZE", 2))

@@ -93,6 +93,7 @@ async def test_load(
        self,
        scheduler_stats: Optional[SchedulerStats],
        iteration_stats: Optional[IterationStats],
        mm_cache_stats: Optional[MultiModalCacheStats] = None,
        engine_idx: int = 0,
    ):
        if iteration_stats:

@@ -354,6 +354,10 @@ class LLM:
        else:
            self.llm_engine.tokenizer = get_cached_tokenizer(tokenizer)

    def reset_mm_cache(self) -> None:
        self.processor.clear_mm_cache()
        self.llm_engine.reset_mm_cache()

    def get_default_sampling_params(self) -> SamplingParams:
        if self.default_sampling_params is None:
            self.default_sampling_params = self.model_config.get_diff_sampling_param()

@@ -274,6 +274,10 @@ class OpenAIServing:
        self.model_config = self.models.model_config
        self.max_model_len = self.model_config.max_model_len

    async def reset_mm_cache(self) -> None:
        self.processor.clear_mm_cache()
        await self.engine_client.reset_mm_cache()

    async def beam_search(
        self,
        prompt: PromptType,

@@ -169,6 +169,10 @@ class ExecutorBase(ABC):
            assert s == sets[0], "All workers should have the same LORAs."
        return sets[0]

    def reset_mm_cache(self) -> None:
        """Reset the multi-modal cache in each worker."""
        self.collective_rpc("reset_mm_cache")

    def start_profile(self) -> None:
        self.collective_rpc("start_profile")

@@ -12,11 +12,8 @@ import torch.distributed as dist
import vllm.envs as envs
from vllm.executor.executor_base import ExecutorBase
from vllm.logger import init_logger
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.cache import worker_receiver_cache_from_config
from vllm.utils import get_distributed_init_method, get_ip, get_open_port, run_method
from vllm.v1.engine import ReconfigureDistributedRequest, ReconfigureRankType
from vllm.v1.executor.utils import get_and_update_mm_cache
from vllm.v1.outputs import AsyncModelRunnerOutput
from vllm.v1.worker.worker_base import WorkerWrapperBase

@@ -30,16 +27,13 @@ class UniProcExecutor(ExecutorBase):
        """Initialize the worker and load the model."""
        self.driver_worker = WorkerWrapperBase(vllm_config=self.vllm_config, rpc_rank=0)
        distributed_init_method, rank, local_rank = self._distributed_args()
        is_driver_worker = True
        kwargs = dict(
            vllm_config=self.vllm_config,
            local_rank=local_rank,
            rank=rank,
            distributed_init_method=distributed_init_method,
            is_driver_worker=is_driver_worker,
        )
        self.mm_receiver_cache = worker_receiver_cache_from_config(
            self.vllm_config, MULTIMODAL_REGISTRY, Lock()
            is_driver_worker=True,
            shared_worker_lock=Lock(),
        )

        self.async_output_thread: Optional[ThreadPoolExecutor] = None

@@ -74,8 +68,6 @@ class UniProcExecutor(ExecutorBase):
    ) -> list[Any]:
        if kwargs is None:
            kwargs = {}
        if self.mm_receiver_cache is not None and method == "execute_model":
            get_and_update_mm_cache(self.mm_receiver_cache, args)

        if not non_block:
            return [run_method(self.driver_worker, method, args, kwargs)]

@@ -19,6 +19,7 @@ from vllm.multimodal.inputs import (
from vllm.multimodal.processing import BaseMultiModalProcessor
from vllm.transformers_utils.tokenizer import AnyTokenizer
from vllm.utils.jsontree import json_iter_leaves
from vllm.v1.metrics.stats import MultiModalCacheStats

from .data import (
    DecoderOnlyInputs,

@@ -56,6 +57,8 @@ class InputPreprocessor:
        self.mm_registry = mm_registry
        self.mm_processor_cache = mm_processor_cache

        self.mm_cache_stats = MultiModalCacheStats() if mm_processor_cache else None

    def get_tokenizer(self) -> AnyTokenizer:
        if self.tokenizer is None:
            raise ValueError(

@@ -664,14 +667,13 @@ class InputPreprocessor:

        return self._build_decoder_only_llm_inputs(prompt_comps)

    def preprocess(
    def _preprocess(
        self,
        prompt: PromptType,
        tokenization_kwargs: Optional[dict[str, Any]] = None,
        *,
        mm_uuids: Optional[MultiModalUUIDDict] = None,
    ) -> ProcessorInputs:
        """Preprocess the input prompt."""
        if self.model_config.is_encoder_decoder:
            # Encoder-decoder model requires special mapping of
            # input prompts to encoder & decoder.

@@ -694,6 +696,40 @@ class InputPreprocessor:
            mm_uuids=mm_uuids,
        )

    def clear_cache(self) -> None:
    def preprocess(
        self,
        prompt: PromptType,
        tokenization_kwargs: Optional[dict[str, Any]] = None,
        *,
        mm_uuids: Optional[MultiModalUUIDDict] = None,
    ) -> ProcessorInputs:
        """Preprocess the input prompt."""
        res = self._preprocess(
            prompt,
            tokenization_kwargs,
            mm_uuids=mm_uuids,
        )

        if self.mm_processor_cache and self.mm_cache_stats is not None:
            delta = self.mm_processor_cache.make_stats(delta=True)
            self.mm_cache_stats.requests += 1
            self.mm_cache_stats.queries += delta.total
            self.mm_cache_stats.hits += delta.hits

        return res

    def stat_mm_cache(self) -> Optional[MultiModalCacheStats]:
        mm_cache_stats = self.mm_cache_stats
        if mm_cache_stats is None:
            return None

        self.mm_cache_stats = MultiModalCacheStats()

        return mm_cache_stats

    def clear_mm_cache(self) -> None:
        if self.mm_processor_cache is not None:
            self.mm_processor_cache.clear_cache()

        if self.mm_cache_stats is not None:
            self.mm_cache_stats.reset = True
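The bookkeeping above is delta-based: every preprocess() call asks the processor cache for the change in its hit/total counts since the previous call (make_stats(delta=True)) and folds that into a pending MultiModalCacheStats, which stat_mm_cache() then hands off and replaces with a fresh object. The standalone sketch below illustrates the same pattern; the class and method names in it are illustrative stand-ins, not vLLM APIs.

# Illustrative sketch of delta-based stats hand-off (hypothetical names).
from dataclasses import dataclass

@dataclass
class _Snapshot:
    hits: int = 0
    total: int = 0

class _DeltaTracker:
    def __init__(self) -> None:
        self._hits = 0
        self._total = 0
        self._last = _Snapshot()

    def record(self, hit: bool) -> None:
        # Called once per cache lookup.
        self._hits += int(hit)
        self._total += 1

    def make_stats(self, *, delta: bool = False) -> _Snapshot:
        info = _Snapshot(self._hits, self._total)
        if not delta:
            return info
        # Return only the change since the previous call.
        diff = _Snapshot(info.hits - self._last.hits, info.total - self._last.total)
        self._last = info
        return diff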

@@ -18,7 +18,7 @@ from vllm.distributed.device_communicators.shm_object_storage import (
from vllm.envs import VLLM_OBJECT_STORAGE_SHM_BUFFER_NAME
from vllm.logger import init_logger
from vllm.utils import GiB_bytes, MiB_bytes
from vllm.utils.cache import LRUCache
from vllm.utils.cache import CacheInfo, LRUCache
from vllm.utils.jsontree import json_count_leaves, json_map_leaves, json_reduce_leaves

from .inputs import (

@@ -302,6 +302,16 @@ class BaseMultiModalProcessorCache(
        """
        return [self.is_cached_item(mm_hash) for mm_hash in mm_hashes]

    @abstractmethod
    def make_stats(self, *, delta: bool = False) -> CacheInfo:
        """
        Get (and reset) the multi-modal cache stats.

        Returns:
            The current multi-modal caching stats.
        """
        raise NotImplementedError


class MultiModalProcessorOnlyCache(BaseMultiModalProcessorCache):
    """

@@ -347,6 +357,10 @@ class MultiModalProcessorOnlyCache(BaseMultiModalProcessorCache):
    def clear_cache(self) -> None:
        self._cache.clear()

    @override
    def make_stats(self, *, delta: bool = False) -> CacheInfo:
        return self._cache.stat(delta=delta)


class MultiModalProcessorSenderCache(BaseMultiModalProcessorCache):
    """

@@ -397,6 +411,10 @@ class MultiModalProcessorSenderCache(BaseMultiModalProcessorCache):
    def clear_cache(self) -> None:
        self._cache.clear()

    @override
    def make_stats(self, *, delta: bool = False) -> CacheInfo:
        return self._cache.stat(delta=delta)


class ShmObjectStoreSenderCache(BaseMultiModalProcessorCache):
    """

@@ -430,6 +448,20 @@ class ShmObjectStoreSenderCache(BaseMultiModalProcessorCache):
        # cache (prompt_updates, modality) for P0 only
        self._p0_cache: dict[str, tuple[Sequence[ResolvedPromptUpdate], str]] = {}

        self._hits = 0
        self._total = 0
        self._last_info = CacheInfo(hits=0, total=0)

    def _stat(self, *, delta: bool = False) -> CacheInfo:
        info = CacheInfo(hits=self._hits, total=self._total)

        if delta:
            info_delta = info - self._last_info
            self._last_info = info
            info = info_delta

        return info

    @override
    def is_cached_item(self, mm_hash: str) -> bool:
        return self._shm_cache.is_cached(mm_hash)

@@ -441,12 +473,17 @@ class ShmObjectStoreSenderCache(BaseMultiModalProcessorCache):
        mm_hash: str,
    ) -> MultiModalProcessorCacheOutItem:
        if self._shm_cache.is_cached(mm_hash):
            self._hits += 1
            self._total += 1

            address, monotonic_id = self._shm_cache.get_cached(mm_hash)
            prompt_updates, modality = self._p0_cache[mm_hash]
            return self.address_as_item(address, monotonic_id, modality), prompt_updates

        assert mm_item is not None, f"Expected a cached item for {mm_hash=}"

        self._total += 1

        try:
            address, monotonic_id = self._shm_cache.put(mm_hash, mm_item[0])
            # Try to remove dangling items if p0 cache is too large.

@@ -469,6 +506,14 @@ class ShmObjectStoreSenderCache(BaseMultiModalProcessorCache):
        self._shm_cache.clear()
        self._p0_cache.clear()

        self._hits = 0
        self._total = 0
        self._last_info = CacheInfo(hits=0, total=0)

    @override
    def make_stats(self, *, delta: bool = False) -> CacheInfo:
        return self._stat(delta=delta)

    def remove_dangling_items(self) -> None:
        """Remove items that are no longer in the shared memory cache."""
        cached_hashes = self._shm_cache.key_index.keys()

@@ -4,7 +4,7 @@

import copy
import os
from collections import defaultdict, deque
from collections import defaultdict
from collections.abc import Iterable, Sequence
from dataclasses import dataclass
from typing import Any, Callable, NewType, Optional, Union

@@ -23,7 +23,6 @@ from vllm.v1.kv_cache_interface import (
    SlidingWindowSpec,
    UniformTypeKVCacheSpecs,
)
from vllm.v1.metrics.stats import PrefixCacheStats
from vllm.v1.request import Request

# BlockHash represents the hash of a single KV-cache block used for

@@ -101,78 +100,6 @@ def init_none_hash(hash_fn: Callable[[Any], bytes]):
    NONE_HASH = BlockHash(hash_fn(hash_seed))


class PrefixCachingMetrics:
    """Metrics for prefix caching with a hit rate of the max recent N requests.

    Args:
        max_recent_requests: The number of the max recent requests to aggregate.
            Defaults to 1000.
    """

    def __init__(self, max_recent_requests: int = 1000):
        self.max_recent_requests = max_recent_requests
        # The current aggregated values.
        self.aggregated_requests = 0
        self.aggregated_query_total = 0
        self.aggregated_query_hit = 0
        # A deque of (requests, queries, hits) for the most recent requests.
        self.query_queue: deque[tuple[int, int, int]] = deque()

    def observe(self, stats: PrefixCacheStats):
        """Observe the prefix caching for a set of requests.

        This function is called with information gathered when new requests
        are being scheduled and are looking for computed blocks.

        When there are more than `max_recent_requests` requests, the oldest set
        of requests are removed from the metrics.

        Args:
            stats: The prefix cache stats.
        """
        # reset_prefix_cache was invoked before the current update.
        # Reset the metrics before aggregating the current stats.
        if stats.reset:
            self.reset()

        # DO NOT appending empty stats to avoid helpful info get kicked out
        # due to sliding window.
        if stats.requests == 0:
            return

        # Update the metrics.
        self.query_queue.append((stats.requests, stats.queries, stats.hits))
        self.aggregated_requests += stats.requests
        self.aggregated_query_total += stats.queries
        self.aggregated_query_hit += stats.hits

        # Remove the oldest stats until number of requests does not exceed
        # the limit.
        # NOTE: We preserve the latest added stats regardless.
        while (
            len(self.query_queue) > 1
            and self.aggregated_requests > self.max_recent_requests
        ):
            old_requests, old_queries, old_hits = self.query_queue.popleft()
            self.aggregated_requests -= old_requests
            self.aggregated_query_total -= old_queries
            self.aggregated_query_hit -= old_hits

    def reset(self):
        """Reset the metrics."""
        self.aggregated_requests = 0
        self.aggregated_query_total = 0
        self.aggregated_query_hit = 0
        self.query_queue.clear()

    @property
    def hit_rate(self) -> float:
        """Calculate the hit rate for the past N requests."""
        if self.aggregated_query_total == 0:
            return 0.0
        return self.aggregated_query_hit / self.aggregated_query_total


@dataclass
class KVCacheBlock:
    """KV-cache block metadata."""

@@ -463,6 +463,7 @@ class AsyncLLM(EngineClient):
        output_processor = self.output_processor
        log_stats = self.log_stats
        logger_manager = self.logger_manager
        processor = self.processor

        async def output_handler():
            try:

@@ -511,6 +512,7 @@ class AsyncLLM(EngineClient):
                            engine_idx=outputs.engine_index,
                            scheduler_stats=outputs.scheduler_stats,
                            iteration_stats=iteration_stats,
                            mm_cache_stats=processor.stat_mm_cache(),
                        )
            except Exception as e:
                logger.exception("AsyncLLM output_handler failed.")

@@ -660,7 +662,7 @@ class AsyncLLM(EngineClient):
        await asyncio.gather(*coros)

    async def reset_mm_cache(self) -> None:
        self.processor.clear_cache()
        self.processor.clear_mm_cache()
        await self.engine_core.reset_mm_cache_async()

    async def reset_prefix_cache(self, device: Optional[Device] = None) -> None:

@@ -319,7 +319,7 @@ class EngineCore:
        )
        engine_core_outputs = self.scheduler.update_from_output(
            scheduler_output, model_output
        )  # type: ignore
        )

        return (engine_core_outputs, scheduler_output.total_num_scheduled_tokens > 0)

@@ -400,16 +400,19 @@ class EngineCore:

    def reset_mm_cache(self):
        # NOTE: Since this is mainly for debugging, we don't attempt to
        # re-sync the internal caches (P0 processor, P0 mirror, P1 mirror)
        # re-sync the internal caches (P0 sender, P1 receiver)
        if self.scheduler.has_unfinished_requests():
            logger.warning(
                "Resetting the multi-modal cache when requests are "
                "in progress may lead to desynced internal caches."
            )

        # The cache either exists in EngineCore or WorkerWrapperBase
        if self.mm_receiver_cache is not None:
            self.mm_receiver_cache.clear_cache()

        self.model_executor.reset_mm_cache()

    def reset_prefix_cache(self):
        self.scheduler.reset_prefix_cache()

@@ -306,9 +306,11 @@ class LLMEngine:
        # 4) Record stats
        if self.logger_manager is not None:
            assert outputs.scheduler_stats is not None

            self.logger_manager.record(
                scheduler_stats=outputs.scheduler_stats,
                iteration_stats=iteration_stats,
                mm_cache_stats=self.processor.stat_mm_cache(),
            )
            self.do_log_stats_with_interval()

@@ -321,7 +323,7 @@ class LLMEngine:
        self.engine_core.profile(False)

    def reset_mm_cache(self):
        self.processor.clear_cache()
        self.processor.clear_mm_cache()
        self.engine_core.reset_mm_cache()

    def reset_prefix_cache(self, device: Optional[Device] = None):

@@ -21,6 +21,7 @@ from vllm.sampling_params import SamplingParams
from vllm.transformers_utils.tokenizer import AnyTokenizer
from vllm.utils import length_from_prompt_token_ids_or_embeds
from vllm.v1.engine import EngineCoreRequest
from vllm.v1.metrics.stats import MultiModalCacheStats
from vllm.v1.structured_output.backend_guidance import validate_guidance_grammar
from vllm.v1.structured_output.backend_lm_format_enforcer import (
    validate_structured_output_request_lm_format_enforcer,

@@ -573,5 +574,8 @@ class Processor:
        # check that chunked prefill does not truncate them
        # max_batch_len = self.scheduler_config.max_num_batched_tokens

    def clear_cache(self) -> None:
        self.input_preprocessor.clear_cache()
    def stat_mm_cache(self) -> Optional[MultiModalCacheStats]:
        return self.input_preprocessor.stat_mm_cache()

    def clear_mm_cache(self) -> None:
        self.input_preprocessor.clear_mm_cache()

@@ -33,8 +33,6 @@ from vllm.distributed.parallel_state import (
    get_tp_group,
)
from vllm.logger import init_logger
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.cache import worker_receiver_cache_from_config
from vllm.utils import (
    _maybe_force_spawn,
    decorate_logs,

@@ -46,7 +44,6 @@ from vllm.utils import (
)
from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.executor.abstract import Executor, FailureCallback
from vllm.v1.executor.utils import get_and_update_mm_cache
from vllm.v1.outputs import AsyncModelRunnerOutput, DraftTokenIds, ModelRunnerOutput
from vllm.v1.worker.worker_base import WorkerWrapperBase

@@ -422,6 +419,7 @@ class WorkerProc:
            "rank": rank,
            "distributed_init_method": distributed_init_method,
            "is_driver_worker": is_driver_worker,
            "shared_worker_lock": shared_worker_lock,
        }
        wrapper.init_worker(all_kwargs)
        self.worker = wrapper

@@ -445,11 +443,6 @@ class WorkerProc:
        )
        self.async_output_copy_thread.start()

        # Initialize multimodal receiver cache if needed
        self.mm_receiver_cache = worker_receiver_cache_from_config(
            vllm_config, MULTIMODAL_REGISTRY, shared_worker_lock
        )

        # Initialize device
        self.worker.init_device()

@@ -692,12 +685,7 @@ class WorkerProc:
                    func = getattr(self.worker, method)
                elif isinstance(method, bytes):
                    func = partial(cloudpickle.loads(method), self.worker)
                # retrieve from shm cache if available
                if (
                    self.mm_receiver_cache is not None
                    and func.__name__ == "execute_model"
                ):
                    get_and_update_mm_cache(self.mm_receiver_cache, args)

                output = func(*args, **kwargs)
            except Exception as e:
                # Notes have been introduced in python 3.11

@@ -1,24 +0,0 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from vllm.multimodal.cache import ShmObjectStoreReceiverCache
from vllm.v1.core.sched.output import SchedulerOutput


def get_and_update_mm_cache(
    receiver_cache: ShmObjectStoreReceiverCache,
    args: tuple[SchedulerOutput],
) -> None:
    """
    For each MultiModalKwargsItem in SchedulerOutput, fetch from shared memory
    cache as needed.

    Args:
        receiver_cache: The receiver cache to update.
        args: According to the collective_rpc call of execute_model method in
            executor, args is a tuple of only one SchedulerOutput element.
    """
    scheduler_output = args[0]
    for request_data in scheduler_output.scheduled_new_reqs:
        request_data.mm_features = receiver_cache.get_and_update_features(
            request_data.mm_features
        )
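This helper file is deleted because the receiver-cache lookup no longer runs inside each executor's collective_rpc path; it is folded into WorkerWrapperBase instead (see the worker_base.py hunks near the end of this diff). Roughly, the replacement looks like the sketch below, which paraphrases those later hunks rather than adding new code on top of the commit.

# Paraphrased from the WorkerWrapperBase hunks later in this diff: the wrapper
# intercepts execute_model and rehydrates mm_features from the shared-memory
# receiver cache before delegating to the real worker.
def _apply_mm_cache(self, scheduler_output) -> None:
    if self.mm_receiver_cache is None:
        return
    for req_data in scheduler_output.scheduled_new_reqs:
        req_data.mm_features = self.mm_receiver_cache.get_and_update_features(
            req_data.mm_features
        )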

@@ -11,10 +11,14 @@ import prometheus_client
from vllm.config import SupportsMetricsInfo, VllmConfig
from vllm.distributed.kv_transfer.kv_connector.v1.metrics import KVConnectorLogging
from vllm.logger import init_logger
from vllm.v1.core.kv_cache_utils import PrefixCachingMetrics
from vllm.v1.engine import FinishReason
from vllm.v1.metrics.prometheus import unregister_vllm_metrics
from vllm.v1.metrics.stats import IterationStats, SchedulerStats
from vllm.v1.metrics.stats import (
    CachingMetrics,
    IterationStats,
    MultiModalCacheStats,
    SchedulerStats,
)
from vllm.v1.spec_decode.metrics import SpecDecodingLogging, SpecDecodingProm

logger = init_logger(__name__)

@@ -38,6 +42,7 @@ class StatLoggerBase(ABC):
        self,
        scheduler_stats: Optional[SchedulerStats],
        iteration_stats: Optional[IterationStats],
        mm_cache_stats: Optional[MultiModalCacheStats] = None,
        engine_idx: int = 0,
    ): ...

@@ -53,10 +58,15 @@ class LoggingStatLogger(StatLoggerBase):
        self.engine_index = engine_index
        self.vllm_config = vllm_config
        self._reset(time.monotonic())

        self.last_scheduler_stats = SchedulerStats()
        # Prefix cache metrics. This cannot be reset.
        self.last_mm_cache_stats: Optional[MultiModalCacheStats] = None

        # Caching metrics. This cannot be reset.
        # TODO: Make the interval configurable.
        self.prefix_caching_metrics = PrefixCachingMetrics()
        self.prefix_caching_metrics = CachingMetrics()
        self.mm_caching_metrics = CachingMetrics()

        self.spec_decoding_logging = SpecDecodingLogging()
        kv_tranfer_config = self.vllm_config.kv_transfer_config
        self.kv_connector_logging = KVConnectorLogging(kv_tranfer_config)

@@ -86,6 +96,7 @@ class LoggingStatLogger(StatLoggerBase):
        self,
        scheduler_stats: Optional[SchedulerStats],
        iteration_stats: Optional[IterationStats],
        mm_cache_stats: Optional[MultiModalCacheStats] = None,
        engine_idx: int = 0,
    ):
        """Log Stats to standard output."""

@@ -101,6 +112,11 @@ class LoggingStatLogger(StatLoggerBase):
                self.kv_connector_logging.observe(kv_connector_stats)
            self.last_scheduler_stats = scheduler_stats

        if mm_cache_stats:
            self.mm_caching_metrics.observe(mm_cache_stats)

        self.last_mm_cache_stats = mm_cache_stats

    def log(self):
        now = time.monotonic()
        prompt_throughput = self._get_throughput(self.num_prompt_tokens, now)

@@ -125,21 +141,32 @@ class LoggingStatLogger(StatLoggerBase):
        self.last_prompt_throughput = prompt_throughput

        # Format and print output.
        log_fn(
            "Engine %03d: "
            "Avg prompt throughput: %.1f tokens/s, "
            "Avg generation throughput: %.1f tokens/s, "
            "Running: %d reqs, Waiting: %d reqs, "
            "GPU KV cache usage: %.1f%%, "
        log_parts = [
            "Avg prompt throughput: %.1f tokens/s",
            "Avg generation throughput: %.1f tokens/s",
            "Running: %d reqs",
            "Waiting: %d reqs",
            "GPU KV cache usage: %.1f%%",
            "Prefix cache hit rate: %.1f%%",
            self.engine_index,
        ]
        log_args = [
            prompt_throughput,
            generation_throughput,
            scheduler_stats.num_running_reqs,
            scheduler_stats.num_waiting_reqs,
            scheduler_stats.kv_cache_usage * 100,
            self.prefix_caching_metrics.hit_rate * 100,
        ]
        if self.last_mm_cache_stats:
            log_parts.append("MM cache hit rate: %.1f%%")
            log_args.append(self.mm_caching_metrics.hit_rate * 100)

        log_fn(
            "Engine %03d: " + ", ".join(log_parts),
            self.engine_index,
            *log_args,
        )

        self.spec_decoding_logging.log(log_fn=log_fn)
        self.kv_connector_logging.log(log_fn=log_fn)
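The restructured logging above builds the message incrementally, so the MM cache line only appears when multi-modal stats were actually recorded. A quick standalone sketch of what the joined format string expands to (the numeric values here are made up):

# Sketch: the joined format string when MM stats are present.
log_parts = [
    "Avg prompt throughput: %.1f tokens/s",
    "Avg generation throughput: %.1f tokens/s",
    "Running: %d reqs",
    "Waiting: %d reqs",
    "GPU KV cache usage: %.1f%%",
    "Prefix cache hit rate: %.1f%%",
    "MM cache hit rate: %.1f%%",
]
msg = "Engine %03d: " + ", ".join(log_parts)
print(msg % (0, 12.3, 45.6, 2, 0, 7.8, 90.0, 75.0))
# -> Engine 000: Avg prompt throughput: 12.3 tokens/s, Avg generation throughput: 45.6 tokens/s,
#    Running: 2 reqs, Waiting: 0 reqs, GPU KV cache usage: 7.8%, Prefix cache hit rate: 90.0%,
#    MM cache hit rate: 75.0%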

@@ -288,6 +315,32 @@ class PrometheusStatLogger(StatLoggerBase):
            counter_prefix_cache_hits, engine_indexes, model_name
        )

        #
        # Multi-modal cache
        #

        counter_mm_cache_queries = self._counter_cls(
            name="vllm:mm_cache_queries",
            documentation=(
                "Multi-modal cache queries, in terms of number of queried items."
            ),
            labelnames=labelnames,
        )
        self.counter_mm_cache_queries = make_per_engine(
            counter_mm_cache_queries, engine_indexes, model_name
        )

        counter_mm_cache_hits = self._counter_cls(
            name="vllm:mm_cache_hits",
            documentation=(
                "Multi-modal cache hits, in terms of number of cached items."
            ),
            labelnames=labelnames,
        )
        self.counter_mm_cache_hits = make_per_engine(
            counter_mm_cache_hits, engine_indexes, model_name
        )

        #
        # Counters
        #

@@ -657,6 +710,7 @@ class PrometheusStatLogger(StatLoggerBase):
        self,
        scheduler_stats: Optional[SchedulerStats],
        iteration_stats: Optional[IterationStats],
        mm_cache_stats: Optional[MultiModalCacheStats] = None,
        engine_idx: int = 0,
    ):
        """Log to prometheus."""

@@ -694,6 +748,10 @@ class PrometheusStatLogger(StatLoggerBase):
                scheduler_stats.spec_decoding_stats, engine_idx
            )

        if mm_cache_stats is not None:
            self.counter_mm_cache_queries[engine_idx].inc(mm_cache_stats.queries)
            self.counter_mm_cache_hits[engine_idx].inc(mm_cache_stats.hits)

        if iteration_stats is None:
            return

@@ -871,6 +929,7 @@ class StatLoggerManager:
        self,
        scheduler_stats: Optional[SchedulerStats],
        iteration_stats: Optional[IterationStats],
        mm_cache_stats: Optional[MultiModalCacheStats] = None,
        engine_idx: Optional[int] = None,
    ):
        if engine_idx is None:

@@ -878,9 +937,19 @@ class StatLoggerManager:

        per_engine_loggers = self.per_engine_logger_dict[engine_idx]
        for logger in per_engine_loggers:
            logger.record(scheduler_stats, iteration_stats, engine_idx)
            logger.record(
                scheduler_stats,
                iteration_stats,
                mm_cache_stats=mm_cache_stats,
                engine_idx=engine_idx,
            )

        self.prometheus_logger.record(scheduler_stats, iteration_stats, engine_idx)
        self.prometheus_logger.record(
            scheduler_stats,
            iteration_stats,
            mm_cache_stats=mm_cache_stats,
            engine_idx=engine_idx,
        )

    def log(self):
        for per_engine_loggers in self.per_engine_logger_dict.values():

@@ -2,6 +2,7 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import time
from collections import deque
from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Any, Optional

@@ -13,24 +14,122 @@ if TYPE_CHECKING:


@dataclass
class PrefixCacheStats:
    """Stores prefix cache hit statistics."""
class BaseCacheStats:
    """Stores cache hit statistics."""

    # Whether reset_prefix_cache was invoked.
    reset: bool = False
    # The number of new requests in this update.
    """Whether the cache was reset."""

    requests: int = 0
    # The number of queries in these requests. Note that "queries" here
    # means the number of tokens that were queried from the cache.
    """The number of requests in this update."""

    queries: int = 0
    # The number of hits in these requests.
    """The number of queries in these requests."""

    hits: int = 0
    # The number of previously preempted requests in this update.
    """The number of hits in these requests."""


class CachingMetrics:
    """Metrics for caching with a hit rate of the most recent N requests.
    Args:
        interval: The number of the most recent requests to aggregate.
            Defaults to 1000.
    """

    def __init__(self, max_recent_requests: int = 1000) -> None:
        super().__init__()

        self.max_recent_requests = max_recent_requests
        # The current aggregated values.
        self.aggregated_requests = 0
        self.aggregated_query_total = 0
        self.aggregated_query_hit = 0

        # A deque of (requests, queries, hits) for the most recent requests.
        self.query_queue = deque[tuple[int, int, int]]()

    def observe(self, stats: BaseCacheStats):
        """Observe the prefix caching for a set of requests.

        This function is called with information gathered when new requests
        are being scheduled and are looking for computed blocks.

        When there are more than `max_recent_requests` requests, the oldest set
        of requests are removed from the metrics.

        Args:
            stats: The prefix cache stats.
        """
        # reset_prefix_cache was invoked before the current update.
        # Reset the metrics before aggregating the current stats.
        if stats.reset:
            self.reset()

        # DO NOT appending empty stats to avoid helpful info get kicked out
        # due to sliding window.
        if stats.requests == 0:
            return

        # Update the metrics.
        self.query_queue.append((stats.requests, stats.queries, stats.hits))
        self.aggregated_requests += stats.requests
        self.aggregated_query_total += stats.queries
        self.aggregated_query_hit += stats.hits

        # Remove the oldest stats until number of requests does not exceed
        # the limit.
        # NOTE: We preserve the latest added stats regardless.
        while (
            len(self.query_queue) > 1
            and self.aggregated_requests > self.max_recent_requests
        ):
            old_requests, old_queries, old_hits = self.query_queue.popleft()
            self.aggregated_requests -= old_requests
            self.aggregated_query_total -= old_queries
            self.aggregated_query_hit -= old_hits

    def reset(self):
        """Reset the metrics."""
        self.aggregated_requests = 0
        self.aggregated_query_total = 0
        self.aggregated_query_hit = 0
        self.query_queue.clear()

    @property
    def hit_rate(self) -> float:
        """Calculate the hit rate for the past N requests."""
        if self.aggregated_query_total == 0:
            return 0.0
        return self.aggregated_query_hit / self.aggregated_query_total
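For a feel of how this class behaves, here is a small usage sketch (not part of the diff), using the MultiModalCacheStats dataclass defined a few lines below: observing a batch of stats moves the hit rate, and a stats object carrying reset=True wipes the sliding-window aggregates before the new batch is folded in.

# Usage sketch for CachingMetrics with MultiModalCacheStats (defined below).
metrics = CachingMetrics(max_recent_requests=1000)

metrics.observe(MultiModalCacheStats(requests=1, queries=4, hits=3))
assert metrics.hit_rate == 0.75

# A reset flag (set by clear_mm_cache) clears the aggregates first,
# so only the stats carried in this update remain.
metrics.observe(MultiModalCacheStats(reset=True, requests=1, queries=2, hits=0))
assert metrics.hit_rate == 0.0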

@dataclass
class PrefixCacheStats(BaseCacheStats):
    """
    Stores prefix cache hit statistics.
    - `reset`: Whether `reset_prefix_cache` was invoked.
    - `queries`: Refers to the number of tokens that were queried.
    """

    preempted_requests: int = 0
    # The `queries` number for preempted requests.
    """The number of previously preempted requests in this update."""

    preempted_queries: int = 0
    # The `hits` number for preempted requests.
    """The `queries` number for preempted requests."""

    preempted_hits: int = 0
    """The `hits` number for preempted requests."""


@dataclass
class MultiModalCacheStats(BaseCacheStats):
    """
    Stores multi-modal cache hit statistics.
    - `reset`: Whether `reset_mm_cache` was invoked.
    - `queries`: Refers to the number of multi-modal data items
        that were queried.
    """


@dataclass
@@ -508,6 +508,10 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
            pin_memory=self.pin_memory,
        )

    def reset_mm_cache(self) -> None:
        if self.mm_budget:
            self.mm_budget.reset_cache()

    def _get_positions(self, num_tokens: Any):
        if isinstance(num_tokens, int):
            if self.uses_mrope:

@@ -442,6 +442,9 @@ class Worker(WorkerBase):
        # the model initialization and profiling.
        set_random_seed(self.model_config.seed)

    def reset_mm_cache(self) -> None:
        self.model_runner.reset_mm_cache()

    def get_model(self) -> nn.Module:
        return self.model_runner.get_model()

@@ -371,6 +371,10 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
        else:
            self.sample_from_logits_func = self.sample_from_logits

    def reset_mm_cache(self) -> None:
        if self.mm_budget:
            self.mm_budget.reset_cache()

    def _update_num_xla_graphs(self, case_str):
        check_comp = self.check_recompilation and not self.enforce_eager
        if not check_comp:

@@ -293,6 +293,9 @@ class TPUWorker:
        # the model initialization and profiling.
        set_random_seed(self.model_config.seed)

    def reset_mm_cache(self) -> None:
        self.model_runner.reset_mm_cache()

    def get_model(self) -> nn.Module:
        return self.model_runner.get_model()

@@ -126,6 +126,10 @@ class MultiModalBudget:

        return max_items_per_prompt, max_items_per_batch

    def reset_cache(self) -> None:
        if self.cache is not None:
            self.cache.clear_cache()


@dataclass
class AttentionGroup:

@@ -4,7 +4,7 @@
from __future__ import annotations

import os
from typing import Any, Callable, TypeVar, Union
from typing import TYPE_CHECKING, Any, Callable, TypeVar, Union

import torch
import torch.nn as nn

@@ -12,7 +12,8 @@ import torch.nn as nn
from vllm.config import VllmConfig, set_current_vllm_config
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.sequence import ExecuteModelRequest
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.cache import worker_receiver_cache_from_config
from vllm.utils import (
    enable_trace_function_call_for_thread,
    resolve_obj_by_qualname,

@@ -21,7 +22,10 @@ from vllm.utils import (
    warn_for_unimplemented_methods,
)
from vllm.v1.kv_cache_interface import KVCacheSpec
from vllm.v1.outputs import SamplerOutput

if TYPE_CHECKING:
    from vllm.v1.core.sched.output import SchedulerOutput
    from vllm.v1.outputs import ModelRunnerOutput

logger = init_logger(__name__)

@@ -103,6 +107,11 @@ class WorkerBase:
        """Initialize the KV cache with the given size in blocks."""
        raise NotImplementedError

    def reset_mm_cache(self) -> None:
        reset_fn = getattr(self.model_runner, "reset_mm_cache", None)
        if callable(reset_fn):
            reset_fn()

    def get_model(self) -> nn.Module:
        raise NotImplementedError

@@ -114,9 +123,7 @@ class WorkerBase:
        """Load model onto target device."""
        raise NotImplementedError

    def execute_model(
        self, execute_model_req: ExecuteModelRequest | None = None
    ) -> list[SamplerOutput] | None:
    def execute_model(self, scheduler_output: SchedulerOutput) -> ModelRunnerOutput:
        raise NotImplementedError

    def start_worker_execution_loop(self) -> None:

@@ -125,11 +132,7 @@ class WorkerBase:
        You can stop the loop by executing a driver worker with an empty output.
        See `stop_remote_worker_execution_loop` for more details.
        """
        with self.current_platform.inference_mode():
            while True:
                output = self.execute_model(execute_model_req=None)
                if output is None:
                    return None
        raise NotImplementedError("Dead V0 code")

    def determine_num_available_blocks(self) -> tuple[int, int]:
        """Determine the number of available blocks for the GPU KV cache and

@@ -289,6 +292,28 @@ class WorkerWrapperBase:
                worker_class,
                extended_calls,
            )

        shared_worker_lock = kwargs.pop("shared_worker_lock", None)
        if shared_worker_lock is None:
            msg = (
                "Missing `shared_worker_lock` argument from executor. "
                "This argument is needed for mm_processor_cache_type='shm'."
            )

            mm_config = self.vllm_config.model_config.multimodal_config
            if mm_config and mm_config.mm_processor_cache_type == "shm":
                raise ValueError(msg)
            else:
                logger.warning_once(msg)

            self.mm_receiver_cache = None
        else:
            self.mm_receiver_cache = worker_receiver_cache_from_config(
                self.vllm_config,
                MULTIMODAL_REGISTRY,
                shared_worker_lock,
            )

        with set_current_vllm_config(self.vllm_config):
            # To make vLLM config available during worker initialization
            self.worker = worker_class(**kwargs)

@@ -323,5 +348,34 @@ class WorkerWrapperBase:
            logger.exception(msg)
            raise e

    def __getattr__(self, attr):
    def __getattr__(self, attr: str):
        return getattr(self.worker, attr)

    def _apply_mm_cache(self, scheduler_output: SchedulerOutput) -> None:
        mm_cache = self.mm_receiver_cache
        if mm_cache is None:
            return

        for req_data in scheduler_output.scheduled_new_reqs:
            req_data.mm_features = mm_cache.get_and_update_features(
                req_data.mm_features
            )

    def execute_model(
        self,
        scheduler_output: SchedulerOutput,
        *args,
        **kwargs,
    ) -> ModelRunnerOutput:
        self._apply_mm_cache(scheduler_output)

        assert self.worker is not None
        return self.worker.execute_model(scheduler_output, *args, **kwargs)

    def reset_mm_cache(self) -> None:
        mm_receiver_cache = self.mm_receiver_cache
        if mm_receiver_cache is not None:
            mm_receiver_cache.clear_cache()

        assert self.worker is not None
        self.worker.reset_mm_cache()
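Taken together, the reset now fans out along a single path from the user-facing API down to the per-worker caches. The outline below is a simplified summary of the hunks in this commit, not additional code in the tree.

# LLM.reset_mm_cache()
#   -> Processor.clear_mm_cache() -> InputPreprocessor.clear_mm_cache()
#        (clears the P0 processor cache and marks mm_cache_stats.reset = True)
#   -> LLMEngine / AsyncLLM -> EngineCore.reset_mm_cache()
#        -> clears the P1 receiver cache held by EngineCore, if any
#        -> Executor.reset_mm_cache() -> collective_rpc("reset_mm_cache")
#             -> WorkerWrapperBase.reset_mm_cache()  (clears the shm receiver cache)
#                  -> Worker.reset_mm_cache() -> ModelRunner.reset_mm_cache()
#                       -> MultiModalBudget.reset_cache()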