[V1] [Hybrid] Add new test to verify that hybrid views into KVCacheTensor are compatible (#21300)

Signed-off-by: Thomas Parnell <tpa@zurich.ibm.com>
This commit is contained in:
Thomas Parnell 2025-07-22 08:31:18 +02:00 committed by GitHub
parent af376ca19d
commit 488d8a986a
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -3,15 +3,19 @@
import random import random
import numpy as np
import pytest import pytest
import torch import torch
from vllm.attention import Attention from vllm.attention import Attention
from vllm.config import (CacheConfig, ModelConfig, ParallelConfig, from vllm.config import (CacheConfig, ModelConfig, ParallelConfig,
SchedulerConfig, VllmConfig, set_current_vllm_config) SchedulerConfig, VllmConfig, set_current_vllm_config)
from vllm.distributed.parallel_state import (init_distributed_environment,
initialize_model_parallel)
from vllm.model_executor.layers.mamba.mamba_mixer2 import MambaMixer2
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.sampling_params import SamplingParams from vllm.sampling_params import SamplingParams
from vllm.utils import GiB_bytes from vllm.utils import GiB_bytes, update_environment_variables
from vllm.v1.core.kv_cache_utils import (estimate_max_model_len, from vllm.v1.core.kv_cache_utils import (estimate_max_model_len,
get_kv_cache_config) get_kv_cache_config)
from vllm.v1.core.sched.output import (CachedRequestData, NewRequestData, from vllm.v1.core.sched.output import (CachedRequestData, NewRequestData,
@ -686,3 +690,147 @@ def test_init_kv_cache_with_kv_sharing_valid():
assert len(kv_cache_config.kv_cache_groups[0].layer_names) == 2 assert len(kv_cache_config.kv_cache_groups[0].layer_names) == 2
assert kv_cache_config.kv_cache_groups[0].layer_names[0] == layer_0 assert kv_cache_config.kv_cache_groups[0].layer_names[0] == layer_0
assert kv_cache_config.kv_cache_groups[0].layer_names[1] == layer_1 assert kv_cache_config.kv_cache_groups[0].layer_names[1] == layer_1
def test_hybrid_attention_mamba_tensor_shapes(monkeypatch):
'''
The GPU model runner creates different views into the
KVCacheTensors for the attention and mamba layers
(via _reshape_kv_cache_tensors function). This test verifies
that the views are compatible: writing a mamba block
will not corrupt an attention block and vice-versa
'''
current_platform.seed_everything(42)
update_environment_variables({
'RANK': "0",
'LOCAL_RANK': "0",
'WORLD_SIZE': "1",
'MASTER_ADDR': 'localhost',
'MASTER_PORT': '12345',
})
init_distributed_environment()
initialize_model_parallel(tensor_model_parallel_size=1)
torch.set_default_dtype(torch.float16)
scheduler_config = SchedulerConfig(
max_num_seqs=10,
max_num_batched_tokens=512,
max_model_len=512,
)
model_config = ModelConfig(
model="ibm-granite/granite-4.0-tiny-preview",
dtype="float16",
)
cache_config = CacheConfig(
block_size=BLOCK_SIZE,
gpu_memory_utilization=0.9,
swap_space=0,
cache_dtype="auto",
)
parallel_config = ParallelConfig()
vllm_config = VllmConfig(
model_config=model_config,
cache_config=cache_config,
scheduler_config=scheduler_config,
parallel_config=parallel_config,
)
layer_0 = "model.layers.0.self_attn.attn"
layer_1 = "model.layers.1.self_attn.attn"
layer_2 = "model.layers.2.mixer"
layer_3 = "model.layers.3.mixer"
layer_4 = "model.layers.4.mixer"
layer_5 = "model.layers.5.mixer"
with set_current_vllm_config(vllm_config):
hf_config = vllm_config.model_config.hf_config
fwd_context = {}
for key in [layer_0, layer_1]:
fwd_context[key] = Attention(
num_heads=model_config.get_num_attention_heads(
parallel_config),
num_kv_heads=model_config.get_num_kv_heads(parallel_config),
head_size=model_config.get_head_size(),
scale=1.0,
prefix=key,
)
for key in [layer_2, layer_3, layer_4, layer_5]:
fwd_context[key] = MambaMixer2(
hidden_size = hf_config.hidden_size,
ssm_state_size = hf_config.mamba_d_state,
conv_kernel_size = hf_config.mamba_d_conv,
intermediate_size = hf_config.mamba_expand *\
hf_config.hidden_size,
use_conv_bias = hf_config.mamba_conv_bias,
use_bias = hf_config.mamba_proj_bias,
n_groups=hf_config.mamba_n_groups,
num_heads=hf_config.mamba_n_heads,
head_dim=hf_config.mamba_d_head,
rms_norm_eps=hf_config.rms_norm_eps,
activation=hf_config.hidden_act,
prefix=key,
)
# suppress var not used error
assert fwd_context is not None
vllm_ctx = vllm_config.compilation_config.static_forward_context
with monkeypatch.context() as m:
m.setenv("VLLM_ATTENTION_BACKEND", "FLASHINFER")
runner = GPUModelRunner(vllm_config, DEVICE)
kv_cache_spec = runner.get_kv_cache_spec()
available_memory = 5 * GiB_bytes
kv_cache_config = get_kv_cache_config(vllm_config, kv_cache_spec,
available_memory)
runner.initialize_kv_cache(kv_cache_config)
# random partition of blocks
# blocks0 will be assigned to attention layers
# blocks1 will be assigned to mamba layers
num_blocks = kv_cache_config.num_blocks
ind = np.arange(num_blocks)
np.random.shuffle(ind)
blocks0, blocks1 = ind[:(num_blocks // 2)], ind[(num_blocks // 2):]
attn_shape = vllm_ctx[layer_0].kv_cache[0].shape
conv_shape = vllm_ctx[layer_2].kv_cache[0][0].shape
ssm_shape = vllm_ctx[layer_2].kv_cache[0][1].shape
# assert we are using FlashInfer
assert attn_shape[0] == num_blocks
attn_blocks_constant = torch.full((len(blocks0), *attn_shape[1:]),
device=DEVICE,
fill_value=3.33)
conv_blocks_constant = torch.full((len(blocks1), *conv_shape[1:]),
device=DEVICE,
fill_value=6.66)
ssm_blocks_constant = torch.full((len(blocks1), *ssm_shape[1:]),
device=DEVICE,
fill_value=9.99)
# fill all attention blocks with constant
for layer in [layer_0, layer_1]:
vllm_ctx[layer].kv_cache[0][
blocks0, :] = attn_blocks_constant.detach().clone()
# fill all mamba blocks with constant
for layer in [layer_2, layer_3, layer_4, layer_5]:
vllm_ctx[layer].kv_cache[0][0][
blocks1, :] = conv_blocks_constant.detach().clone()
vllm_ctx[layer].kv_cache[0][1][
blocks1, :] = ssm_blocks_constant.detach().clone()
# verify attention and mamba contents are correct
for layer in [layer_0, layer_1]:
assert torch.equal(vllm_ctx[layer].kv_cache[0][blocks0, :],
attn_blocks_constant)
for layer in [layer_2, layer_3, layer_4, layer_5]:
assert torch.equal(vllm_ctx[layer].kv_cache[0][0][blocks1, :],
conv_blocks_constant)
assert torch.equal(vllm_ctx[layer].kv_cache[0][1][blocks1, :],
ssm_blocks_constant)