diff --git a/tests/v1/worker/test_gpu_model_runner.py b/tests/v1/worker/test_gpu_model_runner.py
index 0bdf1f9820d3..6ddcbfea24ad 100644
--- a/tests/v1/worker/test_gpu_model_runner.py
+++ b/tests/v1/worker/test_gpu_model_runner.py
@@ -3,15 +3,19 @@
 import random
 
+import numpy as np
 import pytest
 import torch
 
 from vllm.attention import Attention
 from vllm.config import (CacheConfig, ModelConfig, ParallelConfig,
                          SchedulerConfig, VllmConfig, set_current_vllm_config)
+from vllm.distributed.parallel_state import (init_distributed_environment,
+                                             initialize_model_parallel)
+from vllm.model_executor.layers.mamba.mamba_mixer2 import MambaMixer2
 from vllm.platforms import current_platform
 from vllm.sampling_params import SamplingParams
-from vllm.utils import GiB_bytes
+from vllm.utils import GiB_bytes, update_environment_variables
 from vllm.v1.core.kv_cache_utils import (estimate_max_model_len,
                                          get_kv_cache_config)
 from vllm.v1.core.sched.output import (CachedRequestData, NewRequestData,
@@ -686,3 +690,147 @@ def test_init_kv_cache_with_kv_sharing_valid():
     assert len(kv_cache_config.kv_cache_groups[0].layer_names) == 2
     assert kv_cache_config.kv_cache_groups[0].layer_names[0] == layer_0
     assert kv_cache_config.kv_cache_groups[0].layer_names[1] == layer_1
+
+
+def test_hybrid_attention_mamba_tensor_shapes(monkeypatch):
+    '''
+    The GPU model runner creates different views into the
+    KVCacheTensors for the attention and mamba layers
+    (via the _reshape_kv_cache_tensors function). This test verifies
+    that the views are compatible: writing a mamba block
+    will not corrupt an attention block and vice versa.
+    '''
+
+    current_platform.seed_everything(42)
+
+    update_environment_variables({
+        'RANK': "0",
+        'LOCAL_RANK': "0",
+        'WORLD_SIZE': "1",
+        'MASTER_ADDR': 'localhost',
+        'MASTER_PORT': '12345',
+    })
+    init_distributed_environment()
+    initialize_model_parallel(tensor_model_parallel_size=1)
+    torch.set_default_dtype(torch.float16)
+
+    scheduler_config = SchedulerConfig(
+        max_num_seqs=10,
+        max_num_batched_tokens=512,
+        max_model_len=512,
+    )
+    model_config = ModelConfig(
+        model="ibm-granite/granite-4.0-tiny-preview",
+        dtype="float16",
+    )
+    cache_config = CacheConfig(
+        block_size=BLOCK_SIZE,
+        gpu_memory_utilization=0.9,
+        swap_space=0,
+        cache_dtype="auto",
+    )
+    parallel_config = ParallelConfig()
+    vllm_config = VllmConfig(
+        model_config=model_config,
+        cache_config=cache_config,
+        scheduler_config=scheduler_config,
+        parallel_config=parallel_config,
+    )
+
+    layer_0 = "model.layers.0.self_attn.attn"
+    layer_1 = "model.layers.1.self_attn.attn"
+    layer_2 = "model.layers.2.mixer"
+    layer_3 = "model.layers.3.mixer"
+    layer_4 = "model.layers.4.mixer"
+    layer_5 = "model.layers.5.mixer"
+
+    with set_current_vllm_config(vllm_config):
+        hf_config = vllm_config.model_config.hf_config
+        fwd_context = {}
+        for key in [layer_0, layer_1]:
+            fwd_context[key] = Attention(
+                num_heads=model_config.get_num_attention_heads(
+                    parallel_config),
+                num_kv_heads=model_config.get_num_kv_heads(parallel_config),
+                head_size=model_config.get_head_size(),
+                scale=1.0,
+                prefix=key,
+            )
+        for key in [layer_2, layer_3, layer_4, layer_5]:
+            fwd_context[key] = MambaMixer2(
+                hidden_size=hf_config.hidden_size,
+                ssm_state_size=hf_config.mamba_d_state,
+                conv_kernel_size=hf_config.mamba_d_conv,
+                intermediate_size=hf_config.mamba_expand *
+                hf_config.hidden_size,
+                use_conv_bias=hf_config.mamba_conv_bias,
+                use_bias=hf_config.mamba_proj_bias,
+                n_groups=hf_config.mamba_n_groups,
+                num_heads=hf_config.mamba_n_heads,
+                head_dim=hf_config.mamba_d_head,
+                rms_norm_eps=hf_config.rms_norm_eps,
+                activation=hf_config.hidden_act,
+                prefix=key,
+            )
+        # suppress unused-variable error
+        assert fwd_context is not None
+        vllm_ctx = vllm_config.compilation_config.static_forward_context
+
+    with monkeypatch.context() as m:
+
+        m.setenv("VLLM_ATTENTION_BACKEND", "FLASHINFER")
+
+        runner = GPUModelRunner(vllm_config, DEVICE)
+        kv_cache_spec = runner.get_kv_cache_spec()
+
+        available_memory = 5 * GiB_bytes
+        kv_cache_config = get_kv_cache_config(vllm_config, kv_cache_spec,
+                                              available_memory)
+        runner.initialize_kv_cache(kv_cache_config)
+
+        # random partition of blocks:
+        # blocks0 will be assigned to the attention layers,
+        # blocks1 will be assigned to the mamba layers
+        num_blocks = kv_cache_config.num_blocks
+        ind = np.arange(num_blocks)
+        np.random.shuffle(ind)
+        blocks0, blocks1 = ind[:(num_blocks // 2)], ind[(num_blocks // 2):]
+
+        attn_shape = vllm_ctx[layer_0].kv_cache[0].shape
+        conv_shape = vllm_ctx[layer_2].kv_cache[0][0].shape
+        ssm_shape = vllm_ctx[layer_2].kv_cache[0][1].shape
+
+        # assert we are using FlashInfer, whose layout is (num_blocks, ...)
+        assert attn_shape[0] == num_blocks
+
+        attn_blocks_constant = torch.full((len(blocks0), *attn_shape[1:]),
+                                          device=DEVICE,
+                                          fill_value=3.33)
+        conv_blocks_constant = torch.full((len(blocks1), *conv_shape[1:]),
+                                          device=DEVICE,
+                                          fill_value=6.66)
+        ssm_blocks_constant = torch.full((len(blocks1), *ssm_shape[1:]),
+                                         device=DEVICE,
+                                         fill_value=9.99)
+
+        # fill all attention blocks with a constant
+        for layer in [layer_0, layer_1]:
+            vllm_ctx[layer].kv_cache[0][
+                blocks0, :] = attn_blocks_constant.detach().clone()
+
+        # fill all mamba blocks (conv and ssm states) with constants
+        for layer in [layer_2, layer_3, layer_4, layer_5]:
+            vllm_ctx[layer].kv_cache[0][0][
+                blocks1, :] = conv_blocks_constant.detach().clone()
+            vllm_ctx[layer].kv_cache[0][1][
+                blocks1, :] = ssm_blocks_constant.detach().clone()
+
+        # verify the attention and mamba contents are intact
+        for layer in [layer_0, layer_1]:
+            assert torch.equal(vllm_ctx[layer].kv_cache[0][blocks0, :],
+                               attn_blocks_constant)
+        for layer in [layer_2, layer_3, layer_4, layer_5]:
+            assert torch.equal(vllm_ctx[layer].kv_cache[0][0][blocks1, :],
+                               conv_blocks_constant)
+            assert torch.equal(vllm_ctx[layer].kv_cache[0][1][blocks1, :],
+                               ssm_blocks_constant)
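For context on the invariant the new test checks: `_reshape_kv_cache_tensors` hands the attention and mamba layers differently shaped views of the same flat KV-cache allocation, and correctness hinges on block `i` covering the same memory range in every view. Below is a minimal standalone sketch of that idea; all names and shapes are invented for illustration and do not reflect vLLM's actual layout.

```python
import torch

# One flat backing buffer, standing in for a vLLM KVCacheTensor.
# Sizes are invented for this sketch.
num_blocks = 8
elems_per_block = 256
raw = torch.zeros(num_blocks * elems_per_block)

# Attention-style view: (num_blocks, 2, block_size, num_kv_heads, head_size).
attn_view = raw.view(num_blocks, 2, 4, 4, 8)  # 2 * 4 * 4 * 8 == 256

# Mamba-style view over the *same* memory, with a different per-block shape.
mamba_view = raw.view(num_blocks, elems_per_block)

# Assign disjoint block sets to each layer type and write constants,
# mirroring what the test does with blocks0/blocks1.
attn_blocks, mamba_blocks = [0, 2, 4, 6], [1, 3, 5, 7]
attn_view[attn_blocks] = 3.33
mamba_view[mamba_blocks] = 6.66

# Because block i occupies the same element range in both views,
# writes to disjoint block sets cannot corrupt each other.
assert torch.all(attn_view[attn_blocks] == 3.33)
assert torch.all(mamba_view[mamba_blocks] == 6.66)
```

If the two views disagreed on the per-block extent, these asserts would fail; that is the corruption scenario the test exercises against the real runner, where each mamba block is further split into conv and ssm state regions.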