mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-05-23 02:44:26 +08:00
Merge 437ac4e0477dd8ddad8ddccacd7c7965d24bb029 into 254f6b986720c92ddf97fbb1a6a6465da8e87e29
This commit is contained in:
commit
85489c9979
@ -1292,8 +1292,12 @@ def test_allocate_with_lookahead():
|
|||||||
|
|
||||||
def test_get_kv_cache_config_one_worker():
|
def test_get_kv_cache_config_one_worker():
|
||||||
# pass max_model_len to pass check_enough_kv_cache_memory
|
# pass max_model_len to pass check_enough_kv_cache_memory
|
||||||
model_config = ModelConfig(max_model_len=16)
|
# Use max_model_len=256 and max_num_batched_tokens=4 so that
|
||||||
|
# full attention layers (16 blocks) >> sliding window layers (2 blocks),
|
||||||
|
# making the overhead calculations work correctly for grouping
|
||||||
|
model_config = ModelConfig(max_model_len=256)
|
||||||
vllm_config = VllmConfig(model_config=model_config)
|
vllm_config = VllmConfig(model_config=model_config)
|
||||||
|
vllm_config.scheduler_config.max_num_batched_tokens = 4
|
||||||
|
|
||||||
mem_per_block_per_layer = 16 * 2 * 64 * 4 * 2
|
mem_per_block_per_layer = 16 * 2 * 64 * 4 * 2
|
||||||
# all layers are full attention -> single group
|
# all layers are full attention -> single group
|
||||||
@ -1855,3 +1859,149 @@ def test_auto_fit_max_model_len_not_triggered():
|
|||||||
vllm_config, [kv_cache_specs], [mem_per_block_per_layer * 2 * 32]
|
vllm_config, [kv_cache_specs], [mem_per_block_per_layer * 2 * 32]
|
||||||
)
|
)
|
||||||
assert vllm_config.model_config.max_model_len == 16
|
assert vllm_config.model_config.max_model_len == 16
|
||||||
|
|
||||||
|
|
||||||
|
class TestFindBestGroupSize:
    """
    Unit tests for ``kv_cache_utils._find_best_group_size``.

    Behaviors exercised here:
    - an empty mapping is rejected with ValueError
    - homogeneous layers end up in one group spanning every layer
    - when several group sizes need no padding, the larger one is chosen
    - group sizes below 3 are only chosen when the larger candidates would
      cost more than 10% extra padding
    """

    @pytest.fixture
    def vllm_config(self):
        """Provide a VllmConfig wrapping a ModelConfig with max_model_len=4096."""
        return VllmConfig(model_config=ModelConfig(max_model_len=4096))

    def test_empty_input_raises(self, vllm_config):
        """An empty spec-to-layers mapping must raise ValueError."""
        no_layers = {}
        with pytest.raises(ValueError, match="must not be empty"):
            kv_cache_utils._find_best_group_size(no_layers, vllm_config)

    def test_single_layer_type_returns_layer_count(self, vllm_config):
        """Five layers of one type: expect a single group of five layers."""
        layers_by_spec = {
            new_kv_cache_spec(): [f"layer_{i}" for i in range(5)],
        }
        chosen = kv_cache_utils._find_best_group_size(layers_by_spec, vllm_config)
        # group_size=5 needs no padding and is the largest candidate.
        assert chosen == 5

    def test_single_layer_returns_one(self, vllm_config):
        """A lone layer can only form a group of size 1."""
        layers_by_spec = {new_kv_cache_spec(): ["layer_0"]}
        chosen = kv_cache_utils._find_best_group_size(layers_by_spec, vllm_config)
        assert chosen == 1

    def test_two_layers_returns_two(self, vllm_config):
        """Two layers of one type: expect group_size=2."""
        layers_by_spec = {new_kv_cache_spec(): ["layer_0", "layer_1"]}
        chosen = kv_cache_utils._find_best_group_size(layers_by_spec, vllm_config)
        # The preferred minimum of 3 cannot be reached with only two layers,
        # so the largest possible size (2, zero padding) is expected.
        assert chosen == 2

    def test_gemma3_pattern_regression(self, vllm_config):
        """
        Regression test for a Gemma3-like 5:1 sliding-window/full layout.

        25 sw + 5 full layers: group_size=5 divides both counts evenly.
        """
        full_attn = new_kv_cache_spec()
        sliding = new_sliding_window_spec(sliding_window=512)

        layers_by_spec = {
            sliding: [f"sw_{i}" for i in range(25)],
            full_attn: [f"full_{i}" for i in range(5)],
        }

        chosen = kv_cache_utils._find_best_group_size(layers_by_spec, vllm_config)
        # GCD(25, 5) = 5, so 5 is the largest size that needs no padding.
        # Anything bigger (e.g. 25) would pad the 5 full-attention layers.
        assert chosen == 5

    def test_llama4_pattern_regression(self, vllm_config):
        """
        Regression test for a LLaMA4-like 3:1 local/full layout.

        24 local + 8 full layers: both 4 and 8 divide the counts evenly, and
        8 is expected because it is the larger of the two.
        """
        full_attn = new_kv_cache_spec()
        local_attn = new_sliding_window_spec(sliding_window=256)

        layers_by_spec = {
            local_attn: [f"local_{i}" for i in range(24)],
            full_attn: [f"full_{i}" for i in range(8)],
        }

        chosen = kv_cache_utils._find_best_group_size(layers_by_spec, vllm_config)
        # GCD(24, 8) = 8: the largest zero-padding size, i.e. the fewest groups.
        assert chosen == 8

    def test_mixed_20_30_prefers_larger_group(self, vllm_config):
        """
        20 full + 30 sw layers.

        Group sizes 5 and 10 both divide the counts evenly; 10 is expected
        because it is larger.
        """
        full_attn = new_kv_cache_spec()
        sliding = new_sliding_window_spec(sliding_window=512)

        layers_by_spec = {
            full_attn: [f"full_{i}" for i in range(20)],
            sliding: [f"sw_{i}" for i in range(30)],
        }

        chosen = kv_cache_utils._find_best_group_size(layers_by_spec, vllm_config)
        # GCD(20, 30) = 10: the largest zero-padding size.
        assert chosen == 10

    def test_eagle_gpt_oss_20b_pattern_regression(self, vllm_config):
        """
        Regression test for the GPT-OSS-20B + Eagle layout (12 sw + 13 full).

        group_size=13 pads the sliding-window layers by one and leaves the
        full-attention layers unpadded, which is expected to be accepted.
        """
        full_attn = new_kv_cache_spec()
        sliding = new_sliding_window_spec(sliding_window=512)

        layers_by_spec = {
            sliding: [f"sw_{i}" for i in range(12)],
            full_attn: [f"full_{i}" for i in range(13)],
        }

        chosen = kv_cache_utils._find_best_group_size(layers_by_spec, vllm_config)
        # One padded sw layer out of 25 real layers; the test relies on that
        # memory-weighted overhead staying below the 10% threshold.
        assert chosen == 13

    def test_fallback_when_overhead_exceeds_threshold(self, vllm_config):
        """
        When every group size >= 3 costs more than 10% overhead, expect 1.

        Example: 1 full + 5 sw layers, counting padding in layers:
        - group_size=1: no padding (the unconstrained optimum)
        - group_size=3: 2 padded full layers + 1 padded sw layer = 3 padding
          layers, i.e. 3 on top of 6 real layers = 50%
        - group_size=5: 4 padded full layers = 4 padding layers, i.e. 4 on
          top of 6 real layers = 67%

        Every candidate >= 3 is far above the threshold, so 1 is expected.
        """
        full_attn = new_kv_cache_spec()
        sliding = new_sliding_window_spec(sliding_window=512)

        layers_by_spec = {
            full_attn: ["full_0"],  # a single full-attention layer
            sliding: [f"sw_{i}" for i in range(5)],  # five sliding-window layers
        }

        chosen = kv_cache_utils._find_best_group_size(layers_by_spec, vllm_config)
        # Enforcing group_size >= 3 is too expensive here, hence the fallback.
        assert chosen == 1
|
||||||
|
|||||||
@ -946,8 +946,102 @@ def is_kv_cache_type_attention_free(kv_cache_spec: dict[str, KVCacheSpec]) -> bo
|
|||||||
return not kv_cache_spec
|
return not kv_cache_spec
|
||||||
|
|
||||||
|
|
||||||
|
def _find_best_group_size(
|
||||||
|
same_type_layers: dict["KVCacheSpec", list[str]],
|
||||||
|
vllm_config: "VllmConfig",
|
||||||
|
min_preferred_group_size: int = 3,
|
||||||
|
overhead_threshold: float = 0.10,
|
||||||
|
) -> int:
|
||||||
|
"""
|
||||||
|
Find the optimal group size that minimizes padding memory, preferring
|
||||||
|
larger group sizes.
|
||||||
|
|
||||||
|
For each layer type, padding = (group_size - count % group_size) % group_size
|
||||||
|
weighted by that layer's max_memory_usage_bytes. Different layer types
|
||||||
|
contribute differently to total padding based on their actual memory usage
|
||||||
|
(e.g., full attention vs sliding window).
|
||||||
|
|
||||||
|
This function prefers LARGER group sizes. Empirically, small group sizes (1-2)
|
||||||
|
lead to KV cache memory being concentrated in just a few large tensors, which
|
||||||
|
can reduce performance due to memory allocation patterns.
|
||||||
|
|
||||||
|
The algorithm enforces group_size >= min_preferred_group_size (default 3),
|
||||||
|
unless doing so would add more than overhead_threshold (default 10%) extra
|
||||||
|
padding memory compared to the optimal unconstrained group size.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
same_type_layers: Dict mapping KVCacheSpec to list of layer names.
|
||||||
|
Must not be empty.
|
||||||
|
vllm_config: The global VllmConfig, used to compute max_memory_usage_bytes
|
||||||
|
min_preferred_group_size: Preferred minimum group size (default 3).
|
||||||
|
Group sizes below this are avoided unless overhead exceeds threshold.
|
||||||
|
overhead_threshold: Maximum allowed overhead ratio (default 0.10 = 10%)
|
||||||
|
before falling back to smaller group sizes.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The optimal group size (minimizes padding, ties broken by larger group size)
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: If same_type_layers is empty
|
||||||
|
"""
|
||||||
|
if not same_type_layers:
|
||||||
|
raise ValueError("same_type_layers must not be empty")
|
||||||
|
|
||||||
|
# Extract (layer_count, max_memory_usage_bytes) per spec
|
||||||
|
# max_memory_usage_bytes properly weights full attention vs sliding window
|
||||||
|
layer_info = [
|
||||||
|
(len(layers), spec.max_memory_usage_bytes(vllm_config))
|
||||||
|
for spec, layers in same_type_layers.items()
|
||||||
|
]
|
||||||
|
|
||||||
|
max_layers = max(count for count, _ in layer_info)
|
||||||
|
total_base_memory = sum(count * mem_size for count, mem_size in layer_info)
|
||||||
|
|
||||||
|
def calc_padding_memory(group_size: int) -> int:
|
||||||
|
"""Total padding memory, weighted by each layer type's memory size."""
|
||||||
|
return sum(
|
||||||
|
((group_size - count % group_size) % group_size) * mem_size
|
||||||
|
for count, mem_size in layer_info
|
||||||
|
)
|
||||||
|
|
||||||
|
def find_best_in_range(start: int, end: int) -> int:
|
||||||
|
"""Find best group size in [start, end] range.
|
||||||
|
|
||||||
|
Prefers larger group sizes when padding is equal.
|
||||||
|
Key: (padding_memory, -group_size) so larger group_size wins ties.
|
||||||
|
"""
|
||||||
|
return min(range(start, end + 1), key=lambda gs: (calc_padding_memory(gs), -gs))
|
||||||
|
|
||||||
|
# Calculate baseline: optimal group size with no minimum constraint
|
||||||
|
baseline_group_size = find_best_in_range(1, max_layers)
|
||||||
|
baseline_padding = calc_padding_memory(baseline_group_size)
|
||||||
|
|
||||||
|
# If preferred minimum is >= max_layers, just use max_layers
|
||||||
|
if min_preferred_group_size >= max_layers:
|
||||||
|
return max_layers
|
||||||
|
|
||||||
|
# Calculate preferred: optimal group size with minimum constraint
|
||||||
|
preferred_group_size = find_best_in_range(min_preferred_group_size, max_layers)
|
||||||
|
preferred_padding = calc_padding_memory(preferred_group_size)
|
||||||
|
|
||||||
|
# Check if enforcing the minimum preference adds too much overhead
|
||||||
|
# Overhead is measured relative to total memory
|
||||||
|
overhead = (
|
||||||
|
(preferred_padding - baseline_padding) / total_base_memory
|
||||||
|
if total_base_memory > 0
|
||||||
|
else 0.0
|
||||||
|
)
|
||||||
|
|
||||||
|
if overhead > overhead_threshold:
|
||||||
|
# Fallback to baseline (allowing smaller group sizes)
|
||||||
|
return baseline_group_size
|
||||||
|
|
||||||
|
return preferred_group_size
|
||||||
|
|
||||||
|
|
||||||
def _get_kv_cache_groups_uniform_page_size(
|
def _get_kv_cache_groups_uniform_page_size(
|
||||||
kv_cache_spec: dict[str, KVCacheSpec],
|
kv_cache_spec: dict[str, KVCacheSpec],
|
||||||
|
vllm_config: "VllmConfig",
|
||||||
) -> list[KVCacheGroupSpec]:
|
) -> list[KVCacheGroupSpec]:
|
||||||
"""
|
"""
|
||||||
Generates the KV cache groups for hybrid models with multiple
|
Generates the KV cache groups for hybrid models with multiple
|
||||||
@ -1023,23 +1117,10 @@ def _get_kv_cache_groups_uniform_page_size(
|
|||||||
# E.g., (full.0, full.1), (sw.0, sw.1, sw.2)
|
# E.g., (full.0, full.1), (sw.0, sw.1, sw.2)
|
||||||
# split to 3 groups with 2 layers each:
|
# split to 3 groups with 2 layers each:
|
||||||
# (full.0, full.1), (sw.0, sw.2), (sw.1, padding).
|
# (full.0, full.1), (sw.0, sw.2), (sw.1, padding).
|
||||||
# FIXME(Chen): At the moment of writing this code (2025-06-02), all
|
# Find optimal group_size by trying all options and choosing the one with
|
||||||
# open-source hybrid model follows a n:1 pattern between different attention
|
# minimal padding (weighted by layer memory size). Prefers larger group sizes
|
||||||
# types (e.g., Gemma3 5:1 between sw and full, LLaMA4 3:1 between local and
|
# and enforces group_size >= 3 unless overhead exceeds the threshold.
|
||||||
# full), so we can use the "1" in the n:1 pattern as the group size, which
|
group_size = _find_best_group_size(same_type_layers, vllm_config)
|
||||||
# is the minimum number of layers among all attention types. Need a better
|
|
||||||
# strategy if we want to support more complex patterns (e.g., 20 full + 30
|
|
||||||
# sw, where the group size should be 10).
|
|
||||||
min_num_layers = min([len(layers) for layers in same_type_layers.values()])
|
|
||||||
group_size = min_num_layers
|
|
||||||
max_num_layers = max([len(layers) for layers in same_type_layers.values()])
|
|
||||||
if max_num_layers < min_num_layers * 1.25:
|
|
||||||
# If the number of layers is not much larger than the minimum number of layers,
|
|
||||||
# use the maximum number of layers as the group size to avoid too many padding
|
|
||||||
# layers. A typical example is gpt-oss-20b + eagle, with 12 sw + 13 full. We
|
|
||||||
# pad it to (13 sw, 13 full) instead of (12 sw, 24 full). 1.25 is just a
|
|
||||||
# magic number to avoid too many padding layers.
|
|
||||||
group_size = max_num_layers
|
|
||||||
grouped_layers = []
|
grouped_layers = []
|
||||||
for layers in same_type_layers.values():
|
for layers in same_type_layers.values():
|
||||||
num_padding_layers = group_size - len(layers) % group_size
|
num_padding_layers = group_size - len(layers) % group_size
|
||||||
@ -1245,7 +1326,7 @@ def get_kv_cache_groups(
|
|||||||
# have the same physical memory per block per layer. Split the layers
|
# have the same physical memory per block per layer. Split the layers
|
||||||
# into groups with the same number of layers, and thus same total page
|
# into groups with the same number of layers, and thus same total page
|
||||||
# size.
|
# size.
|
||||||
return _get_kv_cache_groups_uniform_page_size(kv_cache_spec)
|
return _get_kv_cache_groups_uniform_page_size(kv_cache_spec, vllm_config)
|
||||||
|
|
||||||
|
|
||||||
def generate_scheduler_kv_cache_config(
|
def generate_scheduler_kv_cache_config(
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user