mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-09 14:35:27 +08:00
[V1] [Hybrid] Add new test to verify that hybrid views into KVCacheTensor are compatible (#21300)
Signed-off-by: Thomas Parnell <tpa@zurich.ibm.com>
This commit is contained in:
parent
af376ca19d
commit
488d8a986a
@ -3,15 +3,19 @@
|
||||
|
||||
import random
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from vllm.attention import Attention
|
||||
from vllm.config import (CacheConfig, ModelConfig, ParallelConfig,
|
||||
SchedulerConfig, VllmConfig, set_current_vllm_config)
|
||||
from vllm.distributed.parallel_state import (init_distributed_environment,
|
||||
initialize_model_parallel)
|
||||
from vllm.model_executor.layers.mamba.mamba_mixer2 import MambaMixer2
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.sampling_params import SamplingParams
|
||||
from vllm.utils import GiB_bytes
|
||||
from vllm.utils import GiB_bytes, update_environment_variables
|
||||
from vllm.v1.core.kv_cache_utils import (estimate_max_model_len,
|
||||
get_kv_cache_config)
|
||||
from vllm.v1.core.sched.output import (CachedRequestData, NewRequestData,
|
||||
@ -686,3 +690,147 @@ def test_init_kv_cache_with_kv_sharing_valid():
|
||||
assert len(kv_cache_config.kv_cache_groups[0].layer_names) == 2
|
||||
assert kv_cache_config.kv_cache_groups[0].layer_names[0] == layer_0
|
||||
assert kv_cache_config.kv_cache_groups[0].layer_names[1] == layer_1
|
||||
|
||||
|
||||
def test_hybrid_attention_mamba_tensor_shapes(monkeypatch):
|
||||
'''
|
||||
The GPU model runner creates different views into the
|
||||
KVCacheTensors for the attention and mamba layers
|
||||
(via _reshape_kv_cache_tensors function). This test verifies
|
||||
that the views are compatible: writing a mamba block
|
||||
will not corrupt an attention block and vice-versa
|
||||
'''
|
||||
|
||||
current_platform.seed_everything(42)
|
||||
|
||||
update_environment_variables({
|
||||
'RANK': "0",
|
||||
'LOCAL_RANK': "0",
|
||||
'WORLD_SIZE': "1",
|
||||
'MASTER_ADDR': 'localhost',
|
||||
'MASTER_PORT': '12345',
|
||||
})
|
||||
init_distributed_environment()
|
||||
initialize_model_parallel(tensor_model_parallel_size=1)
|
||||
torch.set_default_dtype(torch.float16)
|
||||
|
||||
scheduler_config = SchedulerConfig(
|
||||
max_num_seqs=10,
|
||||
max_num_batched_tokens=512,
|
||||
max_model_len=512,
|
||||
)
|
||||
model_config = ModelConfig(
|
||||
model="ibm-granite/granite-4.0-tiny-preview",
|
||||
dtype="float16",
|
||||
)
|
||||
cache_config = CacheConfig(
|
||||
block_size=BLOCK_SIZE,
|
||||
gpu_memory_utilization=0.9,
|
||||
swap_space=0,
|
||||
cache_dtype="auto",
|
||||
)
|
||||
parallel_config = ParallelConfig()
|
||||
vllm_config = VllmConfig(
|
||||
model_config=model_config,
|
||||
cache_config=cache_config,
|
||||
scheduler_config=scheduler_config,
|
||||
parallel_config=parallel_config,
|
||||
)
|
||||
|
||||
layer_0 = "model.layers.0.self_attn.attn"
|
||||
layer_1 = "model.layers.1.self_attn.attn"
|
||||
layer_2 = "model.layers.2.mixer"
|
||||
layer_3 = "model.layers.3.mixer"
|
||||
layer_4 = "model.layers.4.mixer"
|
||||
layer_5 = "model.layers.5.mixer"
|
||||
|
||||
with set_current_vllm_config(vllm_config):
|
||||
hf_config = vllm_config.model_config.hf_config
|
||||
fwd_context = {}
|
||||
for key in [layer_0, layer_1]:
|
||||
fwd_context[key] = Attention(
|
||||
num_heads=model_config.get_num_attention_heads(
|
||||
parallel_config),
|
||||
num_kv_heads=model_config.get_num_kv_heads(parallel_config),
|
||||
head_size=model_config.get_head_size(),
|
||||
scale=1.0,
|
||||
prefix=key,
|
||||
)
|
||||
for key in [layer_2, layer_3, layer_4, layer_5]:
|
||||
fwd_context[key] = MambaMixer2(
|
||||
hidden_size = hf_config.hidden_size,
|
||||
ssm_state_size = hf_config.mamba_d_state,
|
||||
conv_kernel_size = hf_config.mamba_d_conv,
|
||||
intermediate_size = hf_config.mamba_expand *\
|
||||
hf_config.hidden_size,
|
||||
use_conv_bias = hf_config.mamba_conv_bias,
|
||||
use_bias = hf_config.mamba_proj_bias,
|
||||
n_groups=hf_config.mamba_n_groups,
|
||||
num_heads=hf_config.mamba_n_heads,
|
||||
head_dim=hf_config.mamba_d_head,
|
||||
rms_norm_eps=hf_config.rms_norm_eps,
|
||||
activation=hf_config.hidden_act,
|
||||
prefix=key,
|
||||
)
|
||||
# suppress var not used error
|
||||
assert fwd_context is not None
|
||||
vllm_ctx = vllm_config.compilation_config.static_forward_context
|
||||
|
||||
with monkeypatch.context() as m:
|
||||
|
||||
m.setenv("VLLM_ATTENTION_BACKEND", "FLASHINFER")
|
||||
|
||||
runner = GPUModelRunner(vllm_config, DEVICE)
|
||||
kv_cache_spec = runner.get_kv_cache_spec()
|
||||
|
||||
available_memory = 5 * GiB_bytes
|
||||
kv_cache_config = get_kv_cache_config(vllm_config, kv_cache_spec,
|
||||
available_memory)
|
||||
runner.initialize_kv_cache(kv_cache_config)
|
||||
|
||||
# random partition of blocks
|
||||
# blocks0 will be assigned to attention layers
|
||||
# blocks1 will be assigned to mamba layers
|
||||
num_blocks = kv_cache_config.num_blocks
|
||||
ind = np.arange(num_blocks)
|
||||
np.random.shuffle(ind)
|
||||
blocks0, blocks1 = ind[:(num_blocks // 2)], ind[(num_blocks // 2):]
|
||||
|
||||
attn_shape = vllm_ctx[layer_0].kv_cache[0].shape
|
||||
conv_shape = vllm_ctx[layer_2].kv_cache[0][0].shape
|
||||
ssm_shape = vllm_ctx[layer_2].kv_cache[0][1].shape
|
||||
|
||||
# assert we are using FlashInfer
|
||||
assert attn_shape[0] == num_blocks
|
||||
|
||||
attn_blocks_constant = torch.full((len(blocks0), *attn_shape[1:]),
|
||||
device=DEVICE,
|
||||
fill_value=3.33)
|
||||
conv_blocks_constant = torch.full((len(blocks1), *conv_shape[1:]),
|
||||
device=DEVICE,
|
||||
fill_value=6.66)
|
||||
ssm_blocks_constant = torch.full((len(blocks1), *ssm_shape[1:]),
|
||||
device=DEVICE,
|
||||
fill_value=9.99)
|
||||
|
||||
# fill all attention blocks with constant
|
||||
for layer in [layer_0, layer_1]:
|
||||
vllm_ctx[layer].kv_cache[0][
|
||||
blocks0, :] = attn_blocks_constant.detach().clone()
|
||||
|
||||
# fill all mamba blocks with constant
|
||||
for layer in [layer_2, layer_3, layer_4, layer_5]:
|
||||
vllm_ctx[layer].kv_cache[0][0][
|
||||
blocks1, :] = conv_blocks_constant.detach().clone()
|
||||
vllm_ctx[layer].kv_cache[0][1][
|
||||
blocks1, :] = ssm_blocks_constant.detach().clone()
|
||||
|
||||
# verify attention and mamba contents are correct
|
||||
for layer in [layer_0, layer_1]:
|
||||
assert torch.equal(vllm_ctx[layer].kv_cache[0][blocks0, :],
|
||||
attn_blocks_constant)
|
||||
for layer in [layer_2, layer_3, layer_4, layer_5]:
|
||||
assert torch.equal(vllm_ctx[layer].kv_cache[0][0][blocks1, :],
|
||||
conv_blocks_constant)
|
||||
assert torch.equal(vllm_ctx[layer].kv_cache[0][1][blocks1, :],
|
||||
ssm_blocks_constant)
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user