diff --git a/tests/v1/attention/test_attention_backends.py b/tests/v1/attention/test_attention_backends.py index ac08b9052cd80..60e04ad9069e7 100644 --- a/tests/v1/attention/test_attention_backends.py +++ b/tests/v1/attention/test_attention_backends.py @@ -150,15 +150,15 @@ def create_and_prepopulate_kv_cache( # Permute the context blocks (excluding block 0 which is null) if randomize_blocks: - perm = torch.randperm( - blocks_end - 1) + 1 # Random permutation starting from block 1 + # Random permutation starting from block 1 + perm = torch.randperm(blocks_end - 1) + 1 else: - perm = torch.arange( - 1, blocks_end) # Sequential order starting from block 1 + # Sequential order starting from block 1 + perm = torch.arange(1, blocks_end) inv_perm = torch.zeros(blocks_end, dtype=torch.long, device=device) - inv_perm[1:] = torch.argsort( - perm) + 1 # Add 1 to account for starting from block 1 + # Add 1 to account for starting from block 1 + inv_perm[1:] = torch.argsort(perm) + 1 kv_cache[:, 1:blocks_end, ...] = kv_cache[:, perm, ...] # Construct the right block table @@ -281,7 +281,8 @@ def run_attention_backend(backend: _Backend, kv_cache_spec: FullAttentionSpec, @pytest.mark.parametrize("batch_spec_name", [ "small_decode", "small_prefill", "mixed_small", "medium_decode", - "medium_prefill", "mixed_medium" + "medium_prefill", "mixed_medium", "large_decode", "large_prefill", + "single_decode", "single_prefill" ]) @pytest.mark.parametrize("model", ["meta-llama/Meta-Llama-3-8B"]) def test_backend_correctness(batch_spec_name: str, model: str): @@ -302,7 +303,8 @@ def test_backend_correctness(batch_spec_name: str, model: str): """ batch_spec = BATCH_SPECS[batch_spec_name] vllm_config = create_vllm_config(model_name=model, - max_model_len=max(batch_spec.seq_lens)) + max_model_len=max(batch_spec.seq_lens), + num_gpu_blocks=8192) device = torch.device("cuda:0") kv_cache_spec = create_standard_kv_cache_spec(vllm_config) @@ -465,12 +467,6 @@ def test_backend_correctness(batch_spec_name: str, model: str): rtol=rtol, atol=atol) - if not all_close: - print(f"[{backend_name}] output differs from SDPA baseline. " - f"Max diff: {max_diff:.6f} (rel: {max_rel_diff:.6f})") - print(f"[{backend_name}] output: {backend_output}") - print(f"[{backend_name}] SDPA baseline: {sdpa_output}") - assert all_close, ( f"[{backend_name}] output differs from SDPA baseline. " - f"Max diff: {max_diff:.6f} (rel: {max_rel_diff:.6f})") + f"Max diff: {max_diff:.6f}, max rel diff: {max_rel_diff:.6f})") \ No newline at end of file diff --git a/tests/v1/attention/test_mla_backends.py b/tests/v1/attention/test_mla_backends.py new file mode 100644 index 0000000000000..24070358799ef --- /dev/null +++ b/tests/v1/attention/test_mla_backends.py @@ -0,0 +1,522 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +"""Tests for v1 MLA backends without GPUModelRunner dependency.""" + +import pytest +import torch + +from tests.v1.attention.utils import (BatchSpec, _Backend, + create_common_attn_metadata, + create_standard_kv_cache_spec, + create_vllm_config, + get_attention_backend) +from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, cdiv +from vllm.v1.attention.backends.utils import CommonAttentionMetadata +from vllm.v1.kv_cache_interface import FullAttentionSpec + +BACKENDS_TO_TEST = [ + _Backend.CUTLASS_MLA, _Backend.FLASHMLA_VLLM_V1, + _Backend.TRITON_MLA_VLLM_V1 +] + +# Remove CUTLASS_MLA from the list if not using sm100 +if not torch.cuda.is_available() or torch.cuda.get_device_properties( + 0).major < 10: + BACKENDS_TO_TEST.remove(_Backend.CUTLASS_MLA) + +torch.manual_seed(42) + + +def _convert_dtype_to_torch(dtype): + """Convert ModelDType to torch.dtype.""" + if isinstance(dtype, str): + if dtype == "auto": + return torch.float16 # Default dtype for testing + elif dtype in STR_DTYPE_TO_TORCH_DTYPE: + return STR_DTYPE_TO_TORCH_DTYPE[dtype] + else: + raise ValueError(f"Unknown dtype: {dtype}") + elif isinstance(dtype, torch.dtype): + return dtype + else: + raise ValueError(f"Unknown dtype: {dtype}") + + +# Define common batch configurations +BATCH_SPECS = { + "small_decode": + BatchSpec(seq_lens=[32, 40], query_lens=[1, 1]), + "small_prefill": + BatchSpec(seq_lens=[32, 40], query_lens=[8, 8]), + "mixed_small": + BatchSpec(seq_lens=[32, 40, 48, 56], query_lens=[1, 1, 5, 5]), + "medium_decode": + BatchSpec(seq_lens=[128, 256, 512, 1024, 128, 256, 512, 1024], + query_lens=[1, 1, 1, 1, 1, 1, 1, 1]), + "medium_prefill": + BatchSpec(seq_lens=[256, 512, 1024, 2048], query_lens=[16, 16, 16, 16]), + "mixed_medium": + BatchSpec(seq_lens=[512, 1024, 2048, 512, 1024, 2048], + query_lens=[1, 1, 1, 7, 7, 7]), + "large_decode": + BatchSpec(seq_lens=[2048] * 32, query_lens=[1] * 32), + "large_prefill": + BatchSpec(seq_lens=[4096] * 8, query_lens=[32] * 8), + "single_decode": + BatchSpec(seq_lens=[1024], query_lens=[1]), + "single_prefill": + BatchSpec(seq_lens=[1024], query_lens=[64]), +} + + +def create_dummy_kv_cache(kv_cache_spec: FullAttentionSpec, + device: torch.device, + num_blocks: int = 100) -> torch.Tensor: + """Create a dummy KV cache tensor for testing.""" + kv_cache = torch.randn( + num_blocks, + kv_cache_spec.block_size, + kv_cache_spec.head_size, # latent dimension + dtype=_convert_dtype_to_torch(kv_cache_spec.dtype), + device=device, + ) + return kv_cache + + +def create_and_prepopulate_kv_cache( + kv_c_contexts: list[torch.Tensor], + k_pe_contexts: list[torch.Tensor], + block_size: int, + num_kv_heads: int, + head_size: int, + dtype: torch.dtype, + device: torch.device, + num_blocks: int, + common_attn_metadata: CommonAttentionMetadata, + randomize_blocks: bool = True) -> torch.Tensor: + """Create and prepopulate an MLA KV cache with context data. + + Args: + kv_c_contexts: List of latent KV context tensors for each sequence + k_pe_contexts: List of key positional embedding context tensors + for each sequence + block_size: Size of each block + num_kv_heads: Number of KV heads (should be 1 for MLA) + head_size: Size of each head (latent dimension) + dtype: Data type for the cache + device: Device to create the cache on + num_blocks: Total number of blocks in the cache + common_attn_metadata: Common attention metadata + randomize_blocks: Whether to randomly permute blocks + or use sequential order + + Returns: + MLA KV cache tensor + """ + batch_size = len(kv_c_contexts) + seq_lens = common_attn_metadata.seq_lens_cpu + query_lens = common_attn_metadata.query_start_loc_cpu[ + 1:] - common_attn_metadata.query_start_loc_cpu[:-1] + context_lens = common_attn_metadata.num_computed_tokens_cpu + block_table = common_attn_metadata.block_table_tensor + slot_mapping = common_attn_metadata.slot_mapping + + # Create MLA KV cache: (num_blocks, block_size, head_size) + kv_cache = torch.empty(num_blocks, + block_size, + head_size, + dtype=dtype, + device=device) + kv_cache_flat = kv_cache.view(-1, head_size) + + # Populate the cache with the context tokens + # Start from block_id=1 since block_id=0 is considered the null block + start_block_idx = 1 + for i in range(batch_size): + kv_c_context, k_pe_context = kv_c_contexts[i], k_pe_contexts[i] + kv_context = torch.cat([kv_c_context, k_pe_context.squeeze(1)], dim=-1) + start = start_block_idx * block_size + end = start + kv_context.shape[0] + kv_cache_flat[start:end, ...] = kv_context + + # Stay block aligned and allocate enough blocks for the new tokens + start_block_idx += cdiv(int(seq_lens[i]), block_size) + + blocks_end = start_block_idx + + # Permute the context blocks (excluding block 0 which is null) + if randomize_blocks: + perm = torch.randperm( + blocks_end - 1) + 1 # Random permutation starting from block 1 + else: + perm = torch.arange( + 1, blocks_end) # Sequential order starting from block 1 + + inv_perm = torch.zeros(blocks_end, dtype=torch.long, device=device) + inv_perm[1:] = torch.argsort( + perm) + 1 # Add 1 to account for starting from block 1 + kv_cache[1:blocks_end, ...] = kv_cache[perm, ...] + + # Construct the right block table + # Start from block_id=1 since block_id=0 is considered the null block + start_block_idx = 1 + for i in range(batch_size): + num_blocks_for_seq = cdiv(int(seq_lens[i]), block_size) + start = start_block_idx + end = start + num_blocks_for_seq + block_table[i, :num_blocks_for_seq] = inv_perm[start:end] + start_block_idx += num_blocks_for_seq + + # Create a realistic slot mapping that corresponds to the block table + for i in range(batch_size): + token_offsets = torch.arange(int(query_lens[i])) + int(context_lens[i]) + block_indices = token_offsets // block_size + token_inter_block_offsets = token_offsets % block_size + start = common_attn_metadata.query_start_loc_cpu[i] + end = common_attn_metadata.query_start_loc_cpu[i + 1] + slot_mapping[start:end] = block_table[ + i, + block_indices] * block_size + token_inter_block_offsets.to(device) + + return kv_cache + + +class MockAttentionLayer: + """A mock attention layer for testing.""" + + def __init__(self, device: torch.device): + self._q_scale = torch.tensor(1.0, device=device) + self._k_scale = torch.tensor(1.0, device=device) + self._v_scale = torch.tensor(1.0, device=device) + + +def run_attention_backend(backend: _Backend, kv_cache_spec: FullAttentionSpec, + layer_names: list[str], vllm_config, + device: torch.device, + common_attn_metadata: CommonAttentionMetadata, + query: torch.Tensor, kv_c: torch.Tensor, + k_pe: torch.Tensor, kv_cache: torch.Tensor, + kv_lora_rank: int, qk_nope_head_dim: int, + qk_rope_head_dim: int, v_head_dim: int, + mock_kv_b_proj) -> torch.Tensor: + """Run attention computation using the specified backend's AttentionImpl.""" + + builder_cls, impl_cls = get_attention_backend(backend) + + # Build metadata + builder = builder_cls(kv_cache_spec, layer_names, vllm_config, device) + attn_metadata = builder.build( + common_prefix_len=0, + common_attn_metadata=common_attn_metadata, + ) + + # Instantiate MLA implementation + num_heads = vllm_config.model_config.get_num_attention_heads( + vllm_config.parallel_config) + num_kv_heads = vllm_config.model_config.get_num_kv_heads( + vllm_config.parallel_config) + head_size = vllm_config.model_config.get_head_size() + scale = 1.0 / (head_size**0.5) + impl = impl_cls( + num_heads=num_heads, + head_size=head_size, + scale=scale, + num_kv_heads=num_kv_heads, + alibi_slopes=None, + sliding_window=None, + kv_cache_dtype="auto", + logits_soft_cap=None, + attn_type="decoder", + kv_sharing_target_layer_name=None, + q_lora_rank=None, + kv_lora_rank=kv_lora_rank, + qk_nope_head_dim=qk_nope_head_dim, + qk_rope_head_dim=qk_rope_head_dim, + qk_head_dim=qk_nope_head_dim + qk_rope_head_dim, + v_head_dim=v_head_dim, + kv_b_proj=mock_kv_b_proj, + ) + + # Process weights to create W_UK_T and W_UV attributes needed by MLA + act_dtype = _convert_dtype_to_torch(vllm_config.model_config.dtype) + impl.process_weights_after_loading(act_dtype) + + # Create mock layer and output buffer + mock_layer = MockAttentionLayer(device) + num_tokens = query.shape[0] + output = torch.empty(num_tokens, + num_heads * v_head_dim, + dtype=query.dtype, + device=query.device) + + # Run forward pass + # NOTE: The query, key, and value are already shaped correctly + # in the calling test function. + output = impl.forward(mock_layer, + query, + kv_c, + k_pe, + kv_cache, + attn_metadata, + output=output) + + return output + + +@pytest.mark.parametrize("batch_spec_name", [ + "small_decode", "small_prefill", "mixed_small", "medium_decode", + "medium_prefill", "mixed_medium", "large_decode", "large_prefill", + "single_decode", "single_prefill" +]) +@pytest.mark.parametrize("model", ["deepseek-ai/DeepSeek-V2-Lite-Chat"]) +def test_backend_correctness(dist_init, batch_spec_name: str, model: str): + """ + Test that all backends produce similar outputs to a reference implementation + using torch.nn.functional.scaled_dot_product_attention. + + This test works by: + 1. Generating a batch of sequences with specified context and query lengths. + 2. Computing a ground-truth attention output using torch.sdpa on + contiguous Q, K, and V tensors. + 3. Simulating vLLM's paged KV cache: It takes the context portion of the + K/V tensors and manually places them into a paged buffer according to + the test's (randomly generated) block table. + 4. Running each vLLM attention backend with the new queries and the + simulated paged KV cache. + 5. Comparing the vLLM backend's output to the ground-truth SDPA output. + """ + batch_spec = BATCH_SPECS[batch_spec_name] + vllm_config = create_vllm_config(model_name=model, + max_model_len=max(batch_spec.seq_lens), + num_gpu_blocks=2048) + device = torch.device("cuda:0") + + kv_cache_spec = create_standard_kv_cache_spec(vllm_config) + + # 1. Setup + batch_size = batch_spec.batch_size + seq_lens = batch_spec.seq_lens + query_lens = batch_spec.query_lens + num_q_heads = vllm_config.model_config.get_num_attention_heads( + vllm_config.parallel_config) + num_kv_heads = vllm_config.model_config.get_num_kv_heads( + vllm_config.parallel_config) + head_size = vllm_config.model_config.get_head_size() + dtype = _convert_dtype_to_torch(vllm_config.model_config.dtype) + block_size = vllm_config.cache_config.block_size + kv_lora_rank = 512 + qk_rope_head_dim = 64 + qk_nope_head_dim = 128 + v_head_dim = 128 + total_head_size = kv_lora_rank + qk_rope_head_dim + assert kv_lora_rank + qk_rope_head_dim == head_size, \ + f"MLA dimensions don't match: {total_head_size} != {head_size}" + scale = 1.0 / (total_head_size**0.5) + + # 2. Generate data and compute SDPA reference output for MLA + all_q_vllm, all_kv_c_vllm, all_k_pe_vllm = [], [], [] + all_sdpa_outputs = [] + kv_c_contexts, k_pe_contexts = [], [] + + # Create shared MLA weight matrices for consistency across all sequences + W_UK = torch.randn(kv_lora_rank, + num_q_heads, + qk_nope_head_dim, + dtype=dtype, + device=device) + W_UV = torch.randn(kv_lora_rank, + num_q_heads, + v_head_dim, + dtype=dtype, + device=device) + kv_b_proj_weight = torch.cat([W_UK, W_UV], dim=-1) + + for i in range(batch_size): + s_len = seq_lens[i] + q_len = query_lens[i] + context_len = s_len - q_len + + # Generate MLA tensors + # Q has both nope and rope components: + # [q_len, num_heads, qk_nope_head_dim + qk_rope_head_dim] + q_c = torch.randn(q_len, + num_q_heads, + qk_nope_head_dim + qk_rope_head_dim, + dtype=dtype, + device=device) + + # KV_C (latent K/V): [s_len, kv_lora_rank] + kv_c_full = torch.randn(s_len, + kv_lora_rank, + dtype=dtype, + device=device) + + # K_PE (rope component): [s_len, 1, qk_rope_head_dim] + k_pe_full = torch.randn(s_len, + 1, + qk_rope_head_dim, + dtype=dtype, + device=device) + + # Determine if this is decode (single token) + # or prefill (multiple tokens) + is_decode = q_len == 1 + + # Split q into nope and rope components + q_nope, q_pe = q_c.split([qk_nope_head_dim, qk_rope_head_dim], dim=-1) + + if is_decode: + # Decode path: MQA-style attention in latent space + # Transform q_nope to latent space: q_nope @ W_UK + # q_nope: [1, num_heads, qk_nope_head_dim] + # W_UK: [kv_lora_rank, num_heads, qk_nope_head_dim] + ql_nope = torch.einsum("qnh,lnh->qnl", q_nope, + W_UK) # [1, num_heads, kv_lora_rank] + + # Build MQA attention inputs + # Q: [1, num_heads, kv_lora_rank + qk_rope_head_dim] + q_mqa = torch.cat([ql_nope, q_pe], dim=-1) + # K: [s_len, kv_lora_rank + qk_rope_head_dim] + # (broadcasted to all heads) + k_mqa = torch.cat([kv_c_full, k_pe_full.squeeze(1)], dim=-1) + k_mqa = k_mqa.unsqueeze(1).expand(-1, num_q_heads, -1) + # V: [s_len, kv_lora_rank] (broadcasted to all heads) + v_mqa = kv_c_full.unsqueeze(1).expand(-1, num_q_heads, -1) + + # SDPA expects (N, H, L, D) + q_sdpa_in = q_mqa.unsqueeze(0).transpose(1, 2) + k_sdpa_in = k_mqa.unsqueeze(0).transpose(1, 2) + v_sdpa_in = v_mqa.unsqueeze(0).transpose(1, 2) + + sdpa_out_i = torch.nn.functional.scaled_dot_product_attention( + q_sdpa_in, k_sdpa_in, v_sdpa_in, is_causal=False, scale=scale) + sdpa_out_i = sdpa_out_i.transpose(1, 2).squeeze( + 0) # [1, num_heads, kv_lora_rank] + + # Project back to output space: sdpa_out @ W_UV + sdpa_out_i = torch.einsum("qnl,lnv->qnv", sdpa_out_i, W_UV) + sdpa_out_i = sdpa_out_i.flatten(start_dim=-2) + else: + # Prefill path: MHA-style attention with full sequence + # Apply kv_b_proj to the full kv_c tensor + kv_nope_full = torch.einsum("sl,lnh->snh", kv_c_full, + kv_b_proj_weight) + k_nope_full, v_full = kv_nope_full.split( + [qk_nope_head_dim, v_head_dim], dim=-1) + + # Build attention inputs for full sequence + q_mha = torch.cat([q_nope, q_pe], + dim=-1) # [q_len, num_heads, total_dim] + k_pe_full_expanded = k_pe_full.expand(-1, num_q_heads, -1) + k_full = torch.cat([k_nope_full, k_pe_full_expanded], dim=-1) + + # Create custom attention mask: + # - Query tokens can attend to all context tokens + # - Query tokens can only attend to query tokens up to their pos + attn_mask = torch.ones(q_len, + s_len, + dtype=torch.bool, + device=device) + # Apply causal mask only to the query portion (context_len onwards) + causal_mask = torch.tril(torch.ones(q_len, q_len, device=device)) + attn_mask[:, context_len:] = causal_mask + + # SDPA expects (N, H, L, D) + q_sdpa_in = q_mha.unsqueeze(0).transpose(1, 2) + k_sdpa_in = k_full.unsqueeze(0).transpose(1, 2) + v_sdpa_in = v_full.unsqueeze(0).transpose(1, 2) + + # Single attention call with custom mask + sdpa_out_i = torch.nn.functional.scaled_dot_product_attention( + q_sdpa_in, + k_sdpa_in, + v_sdpa_in, + attn_mask=attn_mask, + scale=scale) + sdpa_out_i = sdpa_out_i.transpose(1, 2).squeeze(0) + sdpa_out_i = sdpa_out_i.flatten(start_dim=-2) + + all_sdpa_outputs.append(sdpa_out_i) + + # Inputs for vLLM MLA backends are just the new tokens + all_q_vllm.append(q_c) + all_kv_c_vllm.append(kv_c_full[context_len:]) # New kv_c tokens + all_k_pe_vllm.append(k_pe_full[context_len:]) # New k_pe tokens + + # Contextual K/V data used to populate the paged cache (MLA format) + kv_c_contexts.append(kv_c_full[:context_len]) + k_pe_contexts.append(k_pe_full[:context_len]) + + # Concatenate all sequences (no reordering needed) + query_vllm = torch.cat(all_q_vllm, dim=0) + kv_c_vllm = torch.cat(all_kv_c_vllm, dim=0) + k_pe_vllm = torch.cat(all_k_pe_vllm, dim=0) + sdpa_output = torch.cat(all_sdpa_outputs, dim=0) + + # Create mock kv_b_proj using the same weights as reference implementation + from vllm.model_executor.layers.linear import ColumnParallelLinear + mock_kv_b_proj = ColumnParallelLinear(input_size=kv_lora_rank, + output_size=num_q_heads * + (qk_nope_head_dim + v_head_dim), + bias=False).to(device=device, + dtype=dtype) + + # Set the mock weights to match our reference implementation + # Reshape W_UK and W_UV to match the expected kv_b_proj format + # [kv_lora_rank, num_heads, qk_nope_head_dim + v_head_dim] + kv_b_proj_weight = kv_b_proj_weight.view( + kv_lora_rank, num_q_heads * (qk_nope_head_dim + v_head_dim)) + mock_kv_b_proj.weight = torch.nn.Parameter(kv_b_proj_weight.T) + + # Create metadata using original batch spec + common_attn_metadata = create_common_attn_metadata( + batch_spec, vllm_config.cache_config.block_size, device) + + # 3. Simulate Paged KV Cache and a realistic slot_mapping + kv_cache = create_and_prepopulate_kv_cache( + kv_c_contexts=kv_c_contexts, + k_pe_contexts=k_pe_contexts, + block_size=block_size, + num_kv_heads=num_kv_heads, + head_size=head_size, + dtype=dtype, + device=device, + num_blocks=vllm_config.cache_config.num_gpu_blocks, + common_attn_metadata=common_attn_metadata, + randomize_blocks=True) + + # 4. Run vLLM backends and compare + for backend_name in BACKENDS_TO_TEST: + backend_output = run_attention_backend( + backend_name, kv_cache_spec, ["placeholder"], vllm_config, device, + common_attn_metadata, query_vllm, kv_c_vllm, k_pe_vllm, kv_cache, + kv_lora_rank, qk_nope_head_dim, qk_rope_head_dim, v_head_dim, + mock_kv_b_proj) + + # Check shape and dtype consistency + assert backend_output.shape == sdpa_output.shape, ( + f"[{backend_name}] shape {backend_output.shape} != " + f"SDPA shape {sdpa_output.shape}") + assert backend_output.dtype == sdpa_output.dtype, ( + f"[{backend_name}] dtype {backend_output.dtype} != " + f"SDPA dtype {sdpa_output.dtype}") + + assert torch.isfinite(backend_output).all(), ( + f"[{backend_name}] produced non-finite values") + + # Check numerical similarity + rtol = 1e-2 + atol = 5e-1 + + max_diff = torch.max(torch.abs(backend_output - sdpa_output)).item() + max_rel_diff = torch.max( + torch.abs(backend_output - sdpa_output) / + torch.abs(sdpa_output)).item() + all_close = torch.allclose(backend_output, + sdpa_output, + rtol=rtol, + atol=atol) + + assert all_close, ( + f"[{backend_name}] output differs from SDPA baseline. " + f"Max diff: {max_diff:.6f}, max rel diff: {max_rel_diff:.6f})") diff --git a/tests/v1/attention/utils.py b/tests/v1/attention/utils.py index e547e71e0cdb7..6a08cdc56f736 100644 --- a/tests/v1/attention/utils.py +++ b/tests/v1/attention/utils.py @@ -135,6 +135,12 @@ def get_attention_backend(backend_name: _Backend): "vllm.v1.attention.backends.tree_attn.TreeAttentionBackend", _Backend.XFORMERS_VLLM_V1: "vllm.v1.attention.backends.xformers.XFormersAttentionBackend", + _Backend.CUTLASS_MLA: + "vllm.v1.attention.backends.mla.cutlass_mla.CutlassMLABackend", + _Backend.FLASHMLA_VLLM_V1: + "vllm.v1.attention.backends.mla.flashmla.FlashMLABackend", + _Backend.TRITON_MLA_VLLM_V1: + "vllm.v1.attention.backends.mla.triton_mla.TritonMLABackend", } if backend_name not in backend_map: @@ -167,9 +173,11 @@ def create_vllm_config(model_name: str = "meta-llama/Meta-Llama-3-8B", tensor_parallel_size: int = 1, max_model_len: int = 1024, dtype: Union[ModelDType, torch.dtype] = "auto", + num_gpu_blocks: int = 1000, block_size: int = 16, max_num_seqs: int = 256, max_num_batched_tokens: int = 8192, + enable_chunked_prefill: bool = True, add_mock_model_methods: bool = True) -> VllmConfig: """Create a VllmConfig for testing with reasonable defaults.""" @@ -189,7 +197,7 @@ def create_vllm_config(model_name: str = "meta-llama/Meta-Llama-3-8B", ) # Set cache blocks for testing # (these may be set during initialization normally) - cache_config.num_gpu_blocks = 1000 + cache_config.num_gpu_blocks = num_gpu_blocks cache_config.num_cpu_blocks = 0 parallel_config = ParallelConfig( @@ -198,6 +206,7 @@ def create_vllm_config(model_name: str = "meta-llama/Meta-Llama-3-8B", scheduler_config = SchedulerConfig( max_num_seqs=max_num_seqs, max_num_batched_tokens=max_num_batched_tokens, + enable_chunked_prefill=enable_chunked_prefill, ) device_config = DeviceConfig() diff --git a/vllm/v1/attention/backends/mla/common.py b/vllm/v1/attention/backends/mla/common.py index 646e4fec836bd..03028ebfe76ad 100755 --- a/vllm/v1/attention/backends/mla/common.py +++ b/vllm/v1/attention/backends/mla/common.py @@ -24,7 +24,7 @@ Main reference: DeepseekV2 paper, and FlashInfer Implementation (https://arxiv.org/abs/2405.04434 and https://github.com/flashinfer-ai/flashinfer/pull/551). Deepseek's MLA attention works the following way: -* Use a single latent vector to represent the per-token entry of the KV cache. +* Use a single latent vector to represent the per-token entry of the KV cache. * For decode (i.e. the memory friendly approach) the attention "simulates" a multi-head attention, while the compute is similar to multi-query attention. @@ -82,7 +82,7 @@ spda_o = scaled_dot_product_attention( torch.cat([q_nope, q_pe], dim=-1), torch.cat([k_nope, k_pe.unsqueeze(1).expand(-1, N, -1)], dim=-1), v -) +) return spda_o @ W_O NOTE: in the actual code, @@ -120,20 +120,20 @@ return o.view(-1, N * V) @ self.num_heads @ W_O ## Chunked Prefill -For chunked prefill we want to use the compute friendly algorithm. We are -assuming sufficiently large Sq / Skv ratio, in the future may want to switch to +For chunked prefill we want to use the compute friendly algorithm. We are +assuming sufficiently large Sq / Skv ratio, in the future may want to switch to the data-movement friendly approach if the chunk (i.e. `Sq`) is small. However, the compute-friendly approach can potentially run out of memory if Skv is large due to: `k_nope = (kv_c @ W_UK).view(Skv, N, P)` -To mitigate this, we chunk the computation of attention with respect to the -current context (i.e. `cache_kv_c` and `cache_k_pe`) so that we can used a +To mitigate this, we chunk the computation of attention with respect to the +current context (i.e. `cache_kv_c` and `cache_k_pe`) so that we can used a fixed workspace size. The chunked prefill approach is as follows: -MCC Max chunk of context to process per iter, computed dynamically, +MCC Max chunk of context to process per iter, computed dynamically, used to bound the memory usage q_c = h_t @ W_DQ @@ -155,7 +155,7 @@ curr_o, curr_lse = scaled_dot_product_attention( new_v, casual=True, return_softmax_lse=True -) +) // Compute attention with the already existing context for chunk_idx in range(cdiv(C, MCC)):