Add FLASHINFER_MLA to test_mla_backends and add B200 CI run (#27663)

Signed-off-by: Matthew Bonanni <mbonanni@redhat.com>
Matthew Bonanni 2025-10-31 14:12:19 -04:00 committed by GitHub
parent 5e8862e9e0
commit f29aeb5a25
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 208 additions and 64 deletions


@@ -340,6 +340,16 @@ steps:
   commands:
   - pytest -v -s v1/attention

+- label: V1 Test attention (B200) # 10min
+  timeout_in_minutes: 30
+  gpu: b200
+  source_file_dependencies:
+  - vllm/v1/attention
+  - tests/v1/attention
+  commands:
+  - export VLLM_DISABLE_FLASHINFER_PREFILL=1 # TODO: FI prefill is bugged and causes incorrectness, fix this
+  - pytest -v -s v1/attention
+
 - label: V1 Test others (CPU) # 5 mins
   source_file_dependencies:
   - vllm/
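For local reproduction, the new B200 step can be approximated outside Buildkite. This is a hedged sketch: the tests/v1/attention path assumes a vLLM source checkout, a Blackwell GPU is assumed, and the environment variable is taken verbatim from the step above.

import os
import subprocess

# Mirror the B200 step: disable FlashInfer prefill (see the TODO above) and run
# the V1 attention tests from the repository root.
env = dict(os.environ, VLLM_DISABLE_FLASHINFER_PREFILL="1")
subprocess.run(["pytest", "-v", "-s", "tests/v1/attention"], env=env, check=True)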


@@ -14,16 +14,19 @@ import torch

 from tests.v1.attention.utils import (
     BatchSpec,
     create_common_attn_metadata,
-    create_standard_kv_cache_spec,
     create_vllm_config,
     try_get_attention_backend,
 )
 from vllm import _custom_ops as ops
-from vllm.attention.backends.registry import _Backend
+from vllm.attention.backends.registry import _Backend, backend_to_class_str
 from vllm.attention.ops.flashmla import is_flashmla_dense_supported
+from vllm.attention.utils.fa_utils import flash_attn_supports_mla
 from vllm.config.vllm import set_current_vllm_config
+from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase
+from vllm.utils.import_utils import resolve_obj_by_qualname
 from vllm.utils.math_utils import cdiv
 from vllm.utils.torch_utils import STR_DTYPE_TO_TORCH_DTYPE
+from vllm.v1.attention.backends.mla.common import QueryLenSupport
 from vllm.v1.attention.backends.utils import CommonAttentionMetadata
 from vllm.v1.kv_cache_interface import FullAttentionSpec
@@ -31,17 +34,46 @@ BACKENDS_TO_TEST = [
     _Backend.CUTLASS_MLA,
     _Backend.FLASHMLA,
     _Backend.FLASH_ATTN_MLA,
+    _Backend.FLASHINFER_MLA,
     _Backend.TRITON_MLA,
 ]

-# Remove CUTLASS_MLA from the list if not using sm100
+# Remove sm100 backends from the list if not using sm100
 if not torch.cuda.is_available() or torch.cuda.get_device_properties(0).major < 10:
     BACKENDS_TO_TEST.remove(_Backend.CUTLASS_MLA)
+    BACKENDS_TO_TEST.remove(_Backend.FLASHINFER_MLA)
+
+# Remove FLASH_ATTN_MLA from the list if not supported
+if not flash_attn_supports_mla():
+    BACKENDS_TO_TEST.remove(_Backend.FLASH_ATTN_MLA)

 # Remove FLASHMLA from the list if not supported
 if not is_flashmla_dense_supported()[0]:
     BACKENDS_TO_TEST.remove(_Backend.FLASHMLA)
+
+SPEC_DECODE_BACKENDS = []
+for backend in BACKENDS_TO_TEST:
+    builder_cls, _ = try_get_attention_backend(backend)
+    query_len_support = getattr(
+        builder_cls, "query_len_support", QueryLenSupport.SINGLE_ONLY
+    )
+    if query_len_support != QueryLenSupport.SINGLE_ONLY:
+        SPEC_DECODE_BACKENDS.append(backend)
+
+BACKEND_BLOCK_SIZES = {}
+for backend in BACKENDS_TO_TEST:
+    backend_class_str = backend_to_class_str(backend)
+    backend_class = resolve_obj_by_qualname(backend_class_str)
+    supported_sizes = backend_class.get_supported_kernel_block_size()
+    if supported_sizes:
+        default_size = supported_sizes[0]
+        block_size = (
+            default_size if isinstance(default_size, int) else default_size.base
+        )
+    else:
+        block_size = 16
+    BACKEND_BLOCK_SIZES[backend] = block_size

 torch.manual_seed(42)
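As an aside, the default-block-size rule in the BACKEND_BLOCK_SIZES loop above can be illustrated standalone. In this sketch, MultipleOf is a stand-in dataclass for the constraint type in vllm.attention.backends.abstract, and the sample values are assumptions for illustration only.

from dataclasses import dataclass


@dataclass
class MultipleOf:
    # Stand-in for vllm.attention.backends.abstract.MultipleOf
    base: int


def pick_default_block_size(supported_sizes) -> int:
    # First declared size wins; a MultipleOf constraint contributes its base,
    # and a backend with no declared constraint falls back to 16.
    if not supported_sizes:
        return 16
    first = supported_sizes[0]
    return first if isinstance(first, int) else first.base


assert pick_default_block_size([32, 64]) == 32        # e.g. FLASHINFER_MLA below
assert pick_default_block_size([MultipleOf(64)]) == 64
assert pick_default_block_size([]) == 16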
@@ -236,6 +268,26 @@ class MockAttentionLayer:
         self._q_scale = torch.tensor(1.0, device=device)
         self._k_scale = torch.tensor(1.0, device=device)
         self._v_scale = torch.tensor(1.0, device=device)
+        self._prob_scale = torch.tensor(1.0, device=device)
+        self._q_scale_float = 1.0
+        self._k_scale_float = 1.0
+        self._v_scale_float = 1.0
+
+    def forward(self, *_args, **_kwargs):
+        raise NotImplementedError
+
+
+class MockMLAAttentionLayer(AttentionLayerBase):
+    """A mock MLA attention layer for populating static_forward_context."""
+
+    def __init__(self, impl):
+        self.impl = impl
+
+    def get_attn_backend(self):
+        raise NotImplementedError
+
+    def get_kv_cache_spec(self, vllm_config):
+        raise NotImplementedError


 def run_attention_backend(
@@ -262,13 +314,6 @@
     # Set the current vllm config so that get_current_vllm_config() works
     # in the backend implementations
     with set_current_vllm_config(vllm_config):
-        # Build metadata
-        builder = builder_cls(kv_cache_spec, layer_names, vllm_config, device)
-        attn_metadata = builder.build(
-            common_prefix_len=0,
-            common_attn_metadata=common_attn_metadata,
-        )
-
         # Instantiate MLA implementation
         num_heads = vllm_config.model_config.get_num_attention_heads(
             vllm_config.parallel_config
@@ -302,6 +347,19 @@
         act_dtype = _convert_dtype_to_torch(vllm_config.model_config.dtype)
         impl.process_weights_after_loading(act_dtype)

+        # Populate static_forward_context with mock attention layers
+        for layer_name in layer_names:
+            vllm_config.compilation_config.static_forward_context[layer_name] = (
+                MockMLAAttentionLayer(impl)
+            )
+
+        # Build metadata
+        builder = builder_cls(kv_cache_spec, layer_names, vllm_config, device)
+        attn_metadata = builder.build(
+            common_prefix_len=0,
+            common_attn_metadata=common_attn_metadata,
+        )
+
         # Create mock layer and output buffer
         mock_layer = MockAttentionLayer(device)
         num_tokens = query.shape[0]
@@ -353,15 +411,14 @@ def test_backend_correctness(dist_init, batch_spec_name: str, model: str):
     simulated paged KV cache.
     5. Comparing the vLLM backend's output to the ground-truth SDPA output.
     """
-    from vllm.v1.attention.backends.mla.common import QueryLenSupport
-
     batch_spec = BATCH_SPECS[batch_spec_name]
     is_spec_decode_test = batch_spec_name.startswith("spec_decode")
-    spec_decode_backends = {_Backend.FLASH_ATTN_MLA, _Backend.FLASHMLA}
+    unique_block_sizes = sorted(set(BACKEND_BLOCK_SIZES.values()))
+    default_block_size = unique_block_sizes[0]

-    block_size = 16
     required_blocks = sum(
-        (seq_len + block_size - 1) // block_size for seq_len in batch_spec.seq_lens
+        (seq_len + default_block_size - 1) // default_block_size
+        for seq_len in batch_spec.seq_lens
     )
     # Add 1 for null block at index 0, and some buffer
     num_gpu_blocks = required_blocks + 1 + 100
@@ -370,7 +427,7 @@ def test_backend_correctness(dist_init, batch_spec_name: str, model: str):
         model_name=model,
         max_model_len=max(batch_spec.seq_lens),
         num_gpu_blocks=num_gpu_blocks,
-        block_size=block_size,
+        block_size=default_block_size,
     )

     # For spec decode tests, add a speculative_config to set the reorder_batch_threshold
@@ -388,8 +445,6 @@ def test_backend_correctness(dist_init, batch_spec_name: str, model: str):
     device = torch.device("cuda:0")

-    kv_cache_spec = create_standard_kv_cache_spec(vllm_config)
-
     # 1. Setup
     batch_size = batch_spec.batch_size
     seq_lens = batch_spec.seq_lens
@@ -399,7 +454,6 @@ def test_backend_correctness(dist_init, batch_spec_name: str, model: str):
     )
     head_size = vllm_config.model_config.get_head_size()
     dtype = _convert_dtype_to_torch(vllm_config.model_config.dtype)
-    block_size = vllm_config.cache_config.block_size
     kv_lora_rank = 512
     qk_rope_head_dim = 64
     qk_nope_head_dim = 128
@@ -598,33 +652,83 @@ def test_backend_correctness(dist_init, batch_spec_name: str, model: str):
     )
     mock_kv_b_proj.weight = torch.nn.Parameter(kv_b_proj_weight.T, requires_grad=False)

-    # Create metadata using original batch spec
-    common_attn_metadata = create_common_attn_metadata(
-        batch_spec, vllm_config.cache_config.block_size, device
-    )
-
-    # 3. Simulate Paged KV Cache and a realistic slot_mapping
-    kv_cache = create_and_prepopulate_kv_cache(
-        kv_c_contexts=kv_c_contexts,
-        k_pe_contexts=k_pe_contexts,
-        block_size=block_size,
-        head_size=head_size,
-        dtype=dtype,
-        device=device,
-        num_blocks=vllm_config.cache_config.num_gpu_blocks,
-        common_attn_metadata=common_attn_metadata,
-        randomize_blocks=True,
-    )
+    # 3. Create metadata and KV caches for each block size
+    # Group backends by block size and test each group
+    metadata_per_block_size = {}
+    kv_cache_per_block_size = {}
+    for block_size in unique_block_sizes:
+        # Create metadata for this block size
+        common_attn_metadata = create_common_attn_metadata(
+            batch_spec, block_size, device
+        )
+
+        # Pad block table to meet requirement:
+        # block_num % (128 / block_size) == 0
+        required_divisor = int(128 / block_size)
+        current_block_num = common_attn_metadata.block_table_tensor.shape[1]
+        if current_block_num % required_divisor != 0:
+            # Pad to next multiple of required_divisor
+            padded_block_num = (
+                (current_block_num + required_divisor - 1) // required_divisor
+            ) * required_divisor
+            padding_cols = padded_block_num - current_block_num
+            padding = torch.zeros(
+                (common_attn_metadata.block_table_tensor.shape[0], padding_cols),
+                dtype=torch.int32,
+                device=device,
+            )
+            common_attn_metadata.block_table_tensor = torch.cat(
+                [common_attn_metadata.block_table_tensor, padding], dim=1
+            )
+
+        metadata_per_block_size[block_size] = common_attn_metadata
+
+        # Create KV cache for this block size
+        required_blocks_for_size = sum(
+            (seq_len + block_size - 1) // block_size for seq_len in batch_spec.seq_lens
+        )
+        num_blocks_for_size = required_blocks_for_size + 1 + 100
+        kv_cache = create_and_prepopulate_kv_cache(
+            kv_c_contexts=kv_c_contexts,
+            k_pe_contexts=k_pe_contexts,
+            block_size=block_size,
+            head_size=head_size,
+            dtype=dtype,
+            device=device,
+            num_blocks=num_blocks_for_size,
+            common_attn_metadata=common_attn_metadata,
+            randomize_blocks=True,
+        )
+        kv_cache_per_block_size[block_size] = kv_cache

     # 4. Run vLLM backends and compare
+    failures = []
     for backend_idx, backend_name in enumerate(BACKENDS_TO_TEST):
         # Skip backends that don't support spec decode for spec decode tests
-        if is_spec_decode_test and backend_name not in spec_decode_backends:
+        if is_spec_decode_test and backend_name not in SPEC_DECODE_BACKENDS:
             continue

+        # Get the appropriate block_size, metadata, and cache for this backend
+        block_size = BACKEND_BLOCK_SIZES[backend_name]
+        common_attn_metadata = metadata_per_block_size[block_size]
+        kv_cache = kv_cache_per_block_size[block_size]
+
+        # Create kv_cache_spec with the correct block_size for this backend
+        backend_kv_cache_spec = FullAttentionSpec(
+            block_size=block_size,
+            num_kv_heads=vllm_config.model_config.get_num_kv_heads(
+                vllm_config.parallel_config
+            ),
+            head_size=vllm_config.model_config.get_head_size(),
+            dtype=vllm_config.model_config.dtype,
+            sliding_window=vllm_config.model_config.get_sliding_window(),
+        )
+
         backend_output = run_attention_backend(
             backend_name,
-            kv_cache_spec,
+            backend_kv_cache_spec,
             ["placeholder"],
             vllm_config,
             device,
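A quick worked example of the padding rule in the loop above: with block_size = 32 the block-table width must be a multiple of 128 / 32 = 4, so a 10-column table is padded up to 12 zero-filled columns. A standalone sketch of the same arithmetic:

def padded_block_table_width(current_block_num: int, block_size: int) -> int:
    # Enforce block_num % (128 / block_size) == 0 by rounding up
    required_divisor = 128 // block_size
    return (
        (current_block_num + required_divisor - 1) // required_divisor
    ) * required_divisor


assert padded_block_table_width(10, 32) == 12
assert padded_block_table_width(8, 64) == 8  # already a multiple of 2, unchanged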
@@ -644,32 +748,48 @@ def test_backend_correctness(dist_init, batch_spec_name: str, model: str):
         expected_output = sdpa_outputs[backend_name]

         # Check shape and dtype consistency
+        try:
             assert backend_output.shape == expected_output.shape, (
                 f"[{backend_name}] shape {backend_output.shape} != "
                 f"SDPA shape {expected_output.shape}"
             )
             assert backend_output.dtype == expected_output.dtype, (
                 f"[{backend_name}] dtype {backend_output.dtype} != "
                 f"SDPA dtype {expected_output.dtype}"
             )

             assert torch.isfinite(backend_output).all(), (
                 f"[{backend_name}] produced non-finite values"
             )

             # Check numerical similarity
             rtol = 1e-2
             atol = 5e-1

             max_diff = torch.max(torch.abs(backend_output - expected_output)).item()
             max_rel_diff = torch.max(
                 torch.abs(backend_output - expected_output) / torch.abs(expected_output)
             ).item()
             all_close = torch.allclose(
                 backend_output, expected_output, rtol=rtol, atol=atol
             )

             assert all_close, (
                 f"[{backend_name}] output differs from SDPA baseline. "
                 f"Max diff: {max_diff:.6f}, max rel diff: {max_rel_diff:.6f})"
             )
+        except AssertionError as e:
+            failures.append(str(e))
+
+    # Report all failures at once
+    if failures:
+        # Create a summary for the single-line failure message
+        backend_names = []
+        for f in failures:
+            if "[_Backend." in f:
+                backend_name = f.split("[")[1].split("]")[0]
+                backend_names.append(backend_name)
+        summary = f"{len(failures)} backend(s) failed: {', '.join(backend_names)}"
+        detailed_msg = "\n".join(failures)
+        pytest.fail(f"{summary}\n{detailed_msg}")


@@ -285,7 +285,17 @@ full_cg_backend_configs = {
         name="CutlassMLA",
         env_vars={
             "VLLM_ATTENTION_BACKEND": "CUTLASS_MLA",
-            "FORCE_NUM_KV_SPLITS": "1",  # TODO: remove this when hang issue is fixed
+        },
+        comp_config={
+            "cudagraph_mode": "FULL_AND_PIECEWISE",
+        },
+        specific_gpu_arch=(10, 0),
+    ),
+    # FlashInfer MLA on Blackwell
+    "FlashInferMLA": BackendConfig(
+        name="FlashInferMLA",
+        env_vars={
+            "VLLM_ATTENTION_BACKEND": "FLASHINFER_MLA",
         },
         comp_config={
             "cudagraph_mode": "FULL_AND_PIECEWISE",


@@ -6,7 +6,7 @@ from typing import ClassVar
 import torch
 from flashinfer.decode import trtllm_batch_decode_with_kv_cache_mla

-from vllm.attention.backends.abstract import AttentionLayer, AttentionType
+from vllm.attention.backends.abstract import AttentionLayer, AttentionType, MultipleOf
 from vllm.logger import init_logger
 from vllm.v1.attention.backends.mla.common import (
     MLACommonBackend,
@@ -40,6 +40,10 @@ class FlashInferMLABackend(MLACommonBackend):
     def get_builder_cls() -> type["FlashInferMLAMetadataBuilder"]:
         return FlashInferMLAMetadataBuilder

+    @classmethod
+    def get_supported_kernel_block_size(cls) -> list[int | MultipleOf]:
+        return [32, 64]
+

 g_fi_workspace = torch.zeros(
     FLASHINFER_MLA_WORKSPACE_BUFFER_SIZE,
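A hedged check that this new classmethod is what the test's BACKEND_BLOCK_SIZES loop consumes; the imports are taken from the test file above and assume a vLLM install that contains this change.

from vllm.attention.backends.registry import _Backend, backend_to_class_str
from vllm.utils.import_utils import resolve_obj_by_qualname

backend_cls = resolve_obj_by_qualname(backend_to_class_str(_Backend.FLASHINFER_MLA))
print(backend_cls.get_supported_kernel_block_size())  # expected: [32, 64]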