mirror of https://git.datalinker.icu/vllm-project/vllm.git (synced 2025-12-10 05:04:58 +08:00)
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import pytest
import regex as re
import requests
import torch

from tests.utils import RemoteOpenAIServer
from vllm.v1.worker.utils import bind_kv_cache


def test_bind_kv_cache():
    from vllm.attention import Attention

    ctx = {
        'layers.0.self_attn': Attention(32, 128, 0.1),
        'layers.1.self_attn': Attention(32, 128, 0.1),
        'layers.2.self_attn': Attention(32, 128, 0.1),
        'layers.3.self_attn': Attention(32, 128, 0.1),
    }
    kv_cache = {
        'layers.0.self_attn': torch.zeros((1, )),
        'layers.1.self_attn': torch.zeros((1, )),
        'layers.2.self_attn': torch.zeros((1, )),
        'layers.3.self_attn': torch.zeros((1, )),
    }
    runner_kv_caches: list[torch.Tensor] = []
    bind_kv_cache(kv_cache, ctx, runner_kv_caches)
    assert ctx['layers.0.self_attn'].kv_cache[0] is kv_cache[
        'layers.0.self_attn']
    assert ctx['layers.1.self_attn'].kv_cache[0] is kv_cache[
        'layers.1.self_attn']
    assert ctx['layers.2.self_attn'].kv_cache[0] is kv_cache[
        'layers.2.self_attn']
    assert ctx['layers.3.self_attn'].kv_cache[0] is kv_cache[
        'layers.3.self_attn']

    assert runner_kv_caches[0] is kv_cache['layers.0.self_attn']
    assert runner_kv_caches[1] is kv_cache['layers.1.self_attn']
    assert runner_kv_caches[2] is kv_cache['layers.2.self_attn']
    assert runner_kv_caches[3] is kv_cache['layers.3.self_attn']
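
# Attention.kv_cache appears to be a list with one slot per virtual engine
# (pipeline stage), which is why the assertions above index kv_cache[0]:
# bind_kv_cache() is expected to install each layer's tensor into that slot
# and to append the same tensors to runner_kv_caches in layer order.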


def test_bind_kv_cache_non_attention():
    from vllm.attention import Attention

    # example from Jamba PP=2
    ctx = {
        'model.layers.20.attn': Attention(32, 128, 0.1),
        'model.layers.28.attn': Attention(32, 128, 0.1),
    }
    kv_cache = {
        'model.layers.20.attn': torch.zeros((1, )),
        'model.layers.28.attn': torch.zeros((1, )),
    }

    runner_kv_caches: list[torch.Tensor] = []
    bind_kv_cache(kv_cache, ctx, runner_kv_caches)

    assert ctx['model.layers.20.attn'].kv_cache[0] is kv_cache[
        'model.layers.20.attn']
    assert ctx['model.layers.28.attn'].kv_cache[0] is kv_cache[
        'model.layers.28.attn']

    assert runner_kv_caches[0] is kv_cache['model.layers.20.attn']
    assert runner_kv_caches[1] is kv_cache['model.layers.28.attn']
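
# The test above mirrors a hybrid model such as Jamba run with pipeline
# parallelism (PP=2): only some layers are attention layers and, on a later
# pipeline stage, their indices (20 and 28 here) do not start at zero, yet
# bind_kv_cache() still has to produce a densely packed runner_kv_caches list.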


# Prometheus metrics utilities for testing


def get_prometheus_metrics(
        server: RemoteOpenAIServer) -> dict[str, dict[str, float]]:
    """Fetch and parse Prometheus metrics from the /metrics endpoint.

    Returns:
        Dict mapping metric names to their values grouped by label string.
        For example: {"vllm:request_success_total": {
            '{engine="0"}': 5.0, '{engine="1"}': 3.0}}
    """
    try:
        response = requests.get(server.url_for("metrics"), timeout=10)
        response.raise_for_status()

        metrics: dict[str, dict[str, float]] = {}

        # Regex patterns for Prometheus metrics
        metric_with_labels = re.compile(
            r'^([a-zA-Z_:][a-zA-Z0-9_:]*)\{([^}]*)\}\s+([\d\.\-\+e]+)$')
        metric_simple = re.compile(
            r'^([a-zA-Z_:][a-zA-Z0-9_:]*)\s+([\d\.\-\+e]+)$')

        for line in response.text.split('\n'):
            line = line.strip()
            # Skip comments and empty lines
            if not line or line.startswith('#'):
                continue

            # Try to match metric with labels first
            match = metric_with_labels.match(line)
            if match:
                metric_name, labels_part, value_str = match.groups()
                try:
                    value = float(value_str)
                    if metric_name not in metrics:
                        metrics[metric_name] = {}
                    metrics[metric_name][f'{{{labels_part}}}'] = value
                except ValueError:
                    continue
            else:
                # Try simple metric without labels
                match = metric_simple.match(line)
                if match:
                    metric_name, value_str = match.groups()
                    try:
                        value = float(value_str)
                        if metric_name not in metrics:
                            metrics[metric_name] = {}
                        metrics[metric_name][''] = value
                    except ValueError:
                        continue

        return metrics
    except Exception as e:
        pytest.fail(f"Failed to fetch Prometheus metrics: {e}")
        # Unreachable (pytest.fail raises), kept so the return type is total.
        return {}
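
# For reference, the two regexes above correspond to Prometheus exposition
# lines such as the following (values and labels are illustrative only):
#
#     vllm:request_success_total{engine="0",finished_reason="stop"} 5.0
#     process_cpu_seconds_total 12.0
#
# The first is matched by metric_with_labels and stored under the key
# '{engine="0",finished_reason="stop"}'; the second is matched by
# metric_simple and stored under the empty-string key.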


def get_engine_request_counts(
        metrics: dict[str, dict[str, float]]) -> dict[str, float]:
    """Extract request counts per engine from Prometheus metrics.

    Returns:
        Dict mapping engine indices to request counts.
        For example: {"0": 15.0, "1": 12.0}
    """
    engine_counts: dict[str, float] = {}

    # Look for request success metrics with engine labels
    success_metrics = metrics.get("vllm:request_success_total", {})
    engine_pattern = re.compile(r'engine="([^"]*)"')

    for labels, count in success_metrics.items():
        # Extract engine ID from labels using regex
        match = engine_pattern.search(labels)
        if match:
            engine_id = match.group(1)
            if engine_id not in engine_counts:
                engine_counts[engine_id] = 0.0
            engine_counts[engine_id] += count

    return engine_counts
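

# Small, self-contained sanity check for the parsing helper above. This test
# is an illustrative addition; the metric values and label sets are made up,
# and it exercises only pure-Python code, so no running server is needed.
def test_get_engine_request_counts_example():
    metrics = {
        "vllm:request_success_total": {
            '{engine="0",finished_reason="stop"}': 5.0,
            '{engine="0",finished_reason="length"}': 2.0,
            '{engine="1",finished_reason="stop"}': 3.0,
        },
    }
    counts = get_engine_request_counts(metrics)
    # Counts for the same engine are summed across finished_reason labels.
    assert counts == {"0": 7.0, "1": 3.0}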


def check_request_balancing(server: RemoteOpenAIServer, dp_size: int):
    """Check request balancing via Prometheus metrics if dp_size > 1.

    Args:
        server: The RemoteOpenAIServer instance
        dp_size: Number of data parallel ranks
    """
    if dp_size <= 1:
        return

    # Get metrics after all requests are completed
    metrics = get_prometheus_metrics(server)
    engine_counts = get_engine_request_counts(metrics)

    # Check that all dp_size engines received requests
    engines_with_requests = [
        engine for engine, count in engine_counts.items() if count > 0
    ]
    assert len(engines_with_requests) == dp_size, (
        f"Expected requests to be distributed across multiple engines,"
        f" but only engine(s) {engines_with_requests} received "
        f"requests. Engine counts: {engine_counts}")

    # Verify that the load is reasonably balanced: every engine should
    # handle more than a 1/(dp_size + 1) share of the total requests.
    total_requests = sum(engine_counts.values())

    for count in engine_counts.values():
        assert count > total_requests // (dp_size + 1), (
            f"requests are imbalanced: {engine_counts}")
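
# Typical usage (sketch): a data parallel test starts the server with
# --data-parallel-size N, sends a batch of requests, and then calls
# check_request_balancing(server, dp_size=N) to assert that every engine rank
# served a non-trivial share of the traffic. The flag shown is illustrative;
# the DP tests that import these helpers configure the server themselves.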