diff --git a/.buildkite/test-pipeline.yaml b/.buildkite/test-pipeline.yaml
index 3bd5bd87fe6f..a020b0d276be 100644
--- a/.buildkite/test-pipeline.yaml
+++ b/.buildkite/test-pipeline.yaml
@@ -340,6 +340,16 @@ steps:
   commands:
   - pytest -v -s v1/attention
 
+- label: V1 Test attention (B200) # 10min
+  timeout_in_minutes: 30
+  gpu: b200
+  source_file_dependencies:
+  - vllm/v1/attention
+  - tests/v1/attention
+  commands:
+  - export VLLM_DISABLE_FLASHINFER_PREFILL=1 # TODO: FI prefill is bugged and causes incorrectness, fix this
+  - pytest -v -s v1/attention
+
 - label: V1 Test others (CPU) # 5 mins
   source_file_dependencies:
   - vllm/
diff --git a/tests/v1/attention/test_mla_backends.py b/tests/v1/attention/test_mla_backends.py
index 1b1753288484..cda4fb11c096 100644
--- a/tests/v1/attention/test_mla_backends.py
+++ b/tests/v1/attention/test_mla_backends.py
@@ -14,16 +14,19 @@
 import torch
 from tests.v1.attention.utils import (
     BatchSpec,
     create_common_attn_metadata,
-    create_standard_kv_cache_spec,
     create_vllm_config,
     try_get_attention_backend,
 )
 from vllm import _custom_ops as ops
-from vllm.attention.backends.registry import _Backend
+from vllm.attention.backends.registry import _Backend, backend_to_class_str
 from vllm.attention.ops.flashmla import is_flashmla_dense_supported
+from vllm.attention.utils.fa_utils import flash_attn_supports_mla
 from vllm.config.vllm import set_current_vllm_config
+from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase
+from vllm.utils.import_utils import resolve_obj_by_qualname
 from vllm.utils.math_utils import cdiv
 from vllm.utils.torch_utils import STR_DTYPE_TO_TORCH_DTYPE
+from vllm.v1.attention.backends.mla.common import QueryLenSupport
 from vllm.v1.attention.backends.utils import CommonAttentionMetadata
 from vllm.v1.kv_cache_interface import FullAttentionSpec
@@ -31,17 +34,46 @@
 BACKENDS_TO_TEST = [
     _Backend.CUTLASS_MLA,
     _Backend.FLASHMLA,
     _Backend.FLASH_ATTN_MLA,
+    _Backend.FLASHINFER_MLA,
     _Backend.TRITON_MLA,
 ]
 
-# Remove CUTLASS_MLA from the list if not using sm100
+# Remove sm100 backends from the list if not using sm100
 if not torch.cuda.is_available() or torch.cuda.get_device_properties(0).major < 10:
     BACKENDS_TO_TEST.remove(_Backend.CUTLASS_MLA)
+    BACKENDS_TO_TEST.remove(_Backend.FLASHINFER_MLA)
+
+# Remove FLASH_ATTN_MLA from the list if not supported
+if not flash_attn_supports_mla():
+    BACKENDS_TO_TEST.remove(_Backend.FLASH_ATTN_MLA)
 
 # Remove FLASHMLA from the list if not supported
 if not is_flashmla_dense_supported()[0]:
     BACKENDS_TO_TEST.remove(_Backend.FLASHMLA)
 
+SPEC_DECODE_BACKENDS = []
+for backend in BACKENDS_TO_TEST:
+    builder_cls, _ = try_get_attention_backend(backend)
+    query_len_support = getattr(
+        builder_cls, "query_len_support", QueryLenSupport.SINGLE_ONLY
+    )
+    if query_len_support != QueryLenSupport.SINGLE_ONLY:
+        SPEC_DECODE_BACKENDS.append(backend)
+
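+# Pick one concrete block size per backend: the first size the backend class
+# advertises via get_supported_kernel_block_size(), falling back to 16 when a
+# backend reports no constraint.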
+BACKEND_BLOCK_SIZES = {}
+for backend in BACKENDS_TO_TEST:
+    backend_class_str = backend_to_class_str(backend)
+    backend_class = resolve_obj_by_qualname(backend_class_str)
+    supported_sizes = backend_class.get_supported_kernel_block_size()
+    if supported_sizes:
+        default_size = supported_sizes[0]
+        block_size = (
+            default_size if isinstance(default_size, int) else default_size.base
+        )
+    else:
+        block_size = 16
+    BACKEND_BLOCK_SIZES[backend] = block_size
+
 torch.manual_seed(42)
 
@@ -236,6 +268,26 @@ class MockAttentionLayer:
         self._q_scale = torch.tensor(1.0, device=device)
         self._k_scale = torch.tensor(1.0, device=device)
         self._v_scale = torch.tensor(1.0, device=device)
+        self._prob_scale = torch.tensor(1.0, device=device)
+        self._q_scale_float = 1.0
+        self._k_scale_float = 1.0
+        self._v_scale_float = 1.0
+
+    def forward(self, *_args, **_kwargs):
+        raise NotImplementedError
+
+
+class MockMLAAttentionLayer(AttentionLayerBase):
+    """A mock MLA attention layer for populating static_forward_context."""
+
+    def __init__(self, impl):
+        self.impl = impl
+
+    def get_attn_backend(self):
+        raise NotImplementedError
+
+    def get_kv_cache_spec(self, vllm_config):
+        raise NotImplementedError
 
 
 def run_attention_backend(
@@ -262,13 +314,6 @@ def run_attention_backend(
     # Set the current vllm config so that get_current_vllm_config() works
     # in the backend implementations
     with set_current_vllm_config(vllm_config):
-        # Build metadata
-        builder = builder_cls(kv_cache_spec, layer_names, vllm_config, device)
-        attn_metadata = builder.build(
-            common_prefix_len=0,
-            common_attn_metadata=common_attn_metadata,
-        )
-
         # Instantiate MLA implementation
         num_heads = vllm_config.model_config.get_num_attention_heads(
             vllm_config.parallel_config
@@ -302,6 +347,19 @@
         act_dtype = _convert_dtype_to_torch(vllm_config.model_config.dtype)
         impl.process_weights_after_loading(act_dtype)
 
+        # Populate static_forward_context with mock attention layers
+        for layer_name in layer_names:
+            vllm_config.compilation_config.static_forward_context[layer_name] = (
+                MockMLAAttentionLayer(impl)
+            )
+
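+        # Note: metadata is now built only after the mock layers above are
+        # registered, since some builders look up layer objects through the
+        # forward context.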
+        # Build metadata
+        builder = builder_cls(kv_cache_spec, layer_names, vllm_config, device)
+        attn_metadata = builder.build(
+            common_prefix_len=0,
+            common_attn_metadata=common_attn_metadata,
+        )
+
         # Create mock layer and output buffer
         mock_layer = MockAttentionLayer(device)
         num_tokens = query.shape[0]
@@ -353,15 +411,14 @@ def test_backend_correctness(dist_init, batch_spec_name: str, model: str):
        simulated paged KV cache.
     5. Comparing the vLLM backend's output to the ground-truth SDPA output.
     """
-    from vllm.v1.attention.backends.mla.common import QueryLenSupport
 
     batch_spec = BATCH_SPECS[batch_spec_name]
     is_spec_decode_test = batch_spec_name.startswith("spec_decode")
-    spec_decode_backends = {_Backend.FLASH_ATTN_MLA, _Backend.FLASHMLA}
-
-    block_size = 16
+    unique_block_sizes = sorted(set(BACKEND_BLOCK_SIZES.values()))
+    default_block_size = unique_block_sizes[0]
     required_blocks = sum(
-        (seq_len + block_size - 1) // block_size for seq_len in batch_spec.seq_lens
+        (seq_len + default_block_size - 1) // default_block_size
+        for seq_len in batch_spec.seq_lens
     )
     # Add 1 for null block at index 0, and some buffer
     num_gpu_blocks = required_blocks + 1 + 100
@@ -370,7 +427,7 @@ def test_backend_correctness(dist_init, batch_spec_name: str, model: str):
         model_name=model,
         max_model_len=max(batch_spec.seq_lens),
         num_gpu_blocks=num_gpu_blocks,
-        block_size=block_size,
+        block_size=default_block_size,
     )
 
     # For spec decode tests, add a speculative_config to set the reorder_batch_threshold
@@ -388,8 +445,6 @@ def test_backend_correctness(dist_init, batch_spec_name: str, model: str):
 
     device = torch.device("cuda:0")
 
-    kv_cache_spec = create_standard_kv_cache_spec(vllm_config)
-
     # 1. Setup
     batch_size = batch_spec.batch_size
     seq_lens = batch_spec.seq_lens
@@ -399,7 +454,6 @@ def test_backend_correctness(dist_init, batch_spec_name: str, model: str):
     )
     head_size = vllm_config.model_config.get_head_size()
     dtype = _convert_dtype_to_torch(vllm_config.model_config.dtype)
-    block_size = vllm_config.cache_config.block_size
     kv_lora_rank = 512
     qk_rope_head_dim = 64
     qk_nope_head_dim = 128
@@ -598,33 +652,83 @@ def test_backend_correctness(dist_init, batch_spec_name: str, model: str):
     )
     mock_kv_b_proj.weight = torch.nn.Parameter(kv_b_proj_weight.T, requires_grad=False)
 
-    # Create metadata using original batch spec
-    common_attn_metadata = create_common_attn_metadata(
-        batch_spec, vllm_config.cache_config.block_size, device
-    )
+    # 3. Create metadata and KV caches for each block size
+    # Group backends by block size and test each group
+    metadata_per_block_size = {}
+    kv_cache_per_block_size = {}
 
-    # 3. Simulate Paged KV Cache and a realistic slot_mapping
-    kv_cache = create_and_prepopulate_kv_cache(
-        kv_c_contexts=kv_c_contexts,
-        k_pe_contexts=k_pe_contexts,
-        block_size=block_size,
-        head_size=head_size,
-        dtype=dtype,
-        device=device,
-        num_blocks=vllm_config.cache_config.num_gpu_blocks,
-        common_attn_metadata=common_attn_metadata,
-        randomize_blocks=True,
-    )
+    for block_size in unique_block_sizes:
+        # Create metadata for this block size
+        common_attn_metadata = create_common_attn_metadata(
+            batch_spec, block_size, device
+        )
+
+        # Pad block table to meet requirement:
+        # block_num % (128 / block_size) == 0
+        required_divisor = int(128 / block_size)
+        current_block_num = common_attn_metadata.block_table_tensor.shape[1]
+        if current_block_num % required_divisor != 0:
+            # Pad to next multiple of required_divisor
+            padded_block_num = (
+                (current_block_num + required_divisor - 1) // required_divisor
+            ) * required_divisor
+            padding_cols = padded_block_num - current_block_num
+            padding = torch.zeros(
+                (common_attn_metadata.block_table_tensor.shape[0], padding_cols),
+                dtype=torch.int32,
+                device=device,
+            )
+            common_attn_metadata.block_table_tensor = torch.cat(
+                [common_attn_metadata.block_table_tensor, padding], dim=1
+            )
+
+        metadata_per_block_size[block_size] = common_attn_metadata
+
+        # Create KV cache for this block size
+        required_blocks_for_size = sum(
+            (seq_len + block_size - 1) // block_size for seq_len in batch_spec.seq_lens
+        )
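+        # One extra block for the null block at index 0, plus some headroom
+        # (mirrors the num_gpu_blocks computation above).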
+        num_blocks_for_size = required_blocks_for_size + 1 + 100
+
+        kv_cache = create_and_prepopulate_kv_cache(
+            kv_c_contexts=kv_c_contexts,
+            k_pe_contexts=k_pe_contexts,
+            block_size=block_size,
+            head_size=head_size,
+            dtype=dtype,
+            device=device,
+            num_blocks=num_blocks_for_size,
+            common_attn_metadata=common_attn_metadata,
+            randomize_blocks=True,
+        )
+        kv_cache_per_block_size[block_size] = kv_cache
 
     # 4. Run vLLM backends and compare
+    failures = []
     for backend_idx, backend_name in enumerate(BACKENDS_TO_TEST):
         # Skip backends that don't support spec decode for spec decode tests
-        if is_spec_decode_test and backend_name not in spec_decode_backends:
+        if is_spec_decode_test and backend_name not in SPEC_DECODE_BACKENDS:
             continue
 
+        # Get the appropriate block_size, metadata, and cache for this backend
+        block_size = BACKEND_BLOCK_SIZES[backend_name]
+        common_attn_metadata = metadata_per_block_size[block_size]
+        kv_cache = kv_cache_per_block_size[block_size]
+
+        # Create kv_cache_spec with the correct block_size for this backend
+        backend_kv_cache_spec = FullAttentionSpec(
+            block_size=block_size,
+            num_kv_heads=vllm_config.model_config.get_num_kv_heads(
+                vllm_config.parallel_config
+            ),
+            head_size=vllm_config.model_config.get_head_size(),
+            dtype=vllm_config.model_config.dtype,
+            sliding_window=vllm_config.model_config.get_sliding_window(),
+        )
+
         backend_output = run_attention_backend(
             backend_name,
-            kv_cache_spec,
+            backend_kv_cache_spec,
             ["placeholder"],
             vllm_config,
             device,
@@ -644,32 +748,48 @@ def test_backend_correctness(dist_init, batch_spec_name: str, model: str):
         expected_output = sdpa_outputs[backend_name]
 
         # Check shape and dtype consistency
-        assert backend_output.shape == expected_output.shape, (
-            f"[{backend_name}] shape {backend_output.shape} != "
-            f"SDPA shape {expected_output.shape}"
-        )
-        assert backend_output.dtype == expected_output.dtype, (
-            f"[{backend_name}] dtype {backend_output.dtype} != "
-            f"SDPA dtype {expected_output.dtype}"
-        )
+        try:
+            assert backend_output.shape == expected_output.shape, (
+                f"[{backend_name}] shape {backend_output.shape} != "
+                f"SDPA shape {expected_output.shape}"
+            )
+            assert backend_output.dtype == expected_output.dtype, (
+                f"[{backend_name}] dtype {backend_output.dtype} != "
+                f"SDPA dtype {expected_output.dtype}"
+            )
 
-        assert torch.isfinite(backend_output).all(), (
-            f"[{backend_name}] produced non-finite values"
-        )
+            assert torch.isfinite(backend_output).all(), (
+                f"[{backend_name}] produced non-finite values"
+            )
 
-        # Check numerical similarity
-        rtol = 1e-2
-        atol = 5e-1
+            # Check numerical similarity
+            rtol = 1e-2
+            atol = 5e-1
 
-        max_diff = torch.max(torch.abs(backend_output - expected_output)).item()
-        max_rel_diff = torch.max(
-            torch.abs(backend_output - expected_output) / torch.abs(expected_output)
-        ).item()
-        all_close = torch.allclose(
-            backend_output, expected_output, rtol=rtol, atol=atol
-        )
+            max_diff = torch.max(torch.abs(backend_output - expected_output)).item()
+            max_rel_diff = torch.max(
+                torch.abs(backend_output - expected_output) / torch.abs(expected_output)
+            ).item()
+            all_close = torch.allclose(
+                backend_output, expected_output, rtol=rtol, atol=atol
+            )
 
-        assert all_close, (
-            f"[{backend_name}] output differs from SDPA baseline. "
-            f"Max diff: {max_diff:.6f}, max rel diff: {max_rel_diff:.6f})"
-        )
+            assert all_close, (
+                f"[{backend_name}] output differs from SDPA baseline. "
+                f"Max diff: {max_diff:.6f}, max rel diff: {max_rel_diff:.6f}"
+            )
+        except AssertionError as e:
+            failures.append(str(e))
+
+    # Report all failures at once
+    if failures:
+        # Create a summary for the single-line failure message
+        backend_names = []
+        for f in failures:
+            if "[_Backend." in f:
+                backend_name = f.split("[")[1].split("]")[0]
+                backend_names.append(backend_name)
+
+        summary = f"{len(failures)} backend(s) failed: {', '.join(backend_names)}"
+        detailed_msg = "\n".join(failures)
+        pytest.fail(f"{summary}\n{detailed_msg}")
diff --git a/tests/v1/attention/utils.py b/tests/v1/attention/utils.py
index 15ed7bdc835b..b166d9d4ff68 100644
--- a/tests/v1/attention/utils.py
+++ b/tests/v1/attention/utils.py
@@ -285,7 +285,17 @@ full_cg_backend_configs = {
         name="CutlassMLA",
         env_vars={
             "VLLM_ATTENTION_BACKEND": "CUTLASS_MLA",
-            "FORCE_NUM_KV_SPLITS": "1",  # TODO: remove this when hang issue is fixed
+        },
+        comp_config={
+            "cudagraph_mode": "FULL_AND_PIECEWISE",
+        },
+        specific_gpu_arch=(10, 0),
+    ),
+    # FlashInfer MLA on Blackwell
+    "FlashInferMLA": BackendConfig(
+        name="FlashInferMLA",
+        env_vars={
+            "VLLM_ATTENTION_BACKEND": "FLASHINFER_MLA",
         },
         comp_config={
             "cudagraph_mode": "FULL_AND_PIECEWISE",
diff --git a/vllm/v1/attention/backends/mla/flashinfer_mla.py b/vllm/v1/attention/backends/mla/flashinfer_mla.py
index 44807c39cad3..ebbcfd0eaa2f 100644
--- a/vllm/v1/attention/backends/mla/flashinfer_mla.py
+++ b/vllm/v1/attention/backends/mla/flashinfer_mla.py
@@ -6,7 +6,7 @@ from typing import ClassVar
 import torch
 from flashinfer.decode import trtllm_batch_decode_with_kv_cache_mla
 
-from vllm.attention.backends.abstract import AttentionLayer, AttentionType
+from vllm.attention.backends.abstract import AttentionLayer, AttentionType, MultipleOf
 from vllm.logger import init_logger
 from vllm.v1.attention.backends.mla.common import (
     MLACommonBackend,
@@ -40,6 +40,10 @@ class FlashInferMLABackend(MLACommonBackend):
     def get_builder_cls() -> type["FlashInferMLAMetadataBuilder"]:
         return FlashInferMLAMetadataBuilder
 
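+    # Page sizes accepted by the TRT-LLM MLA decode kernel that backs this backend.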
+    @classmethod
+    def get_supported_kernel_block_size(cls) -> list[int | MultipleOf]:
+        return [32, 64]
+
 
 g_fi_workspace = torch.zeros(
     FLASHINFER_MLA_WORKSPACE_BUFFER_SIZE,