Add FLASHINFER_MLA to test_mla_backends and add B200 CI run (#27663)

Signed-off-by: Matthew Bonanni <mbonanni@redhat.com>
Author: Matthew Bonanni, 2025-10-31 14:12:19 -04:00 (committed by GitHub)
Parent: 5e8862e9e0
Commit: f29aeb5a25
4 changed files with 208 additions and 64 deletions

View File

@@ -340,6 +340,16 @@ steps:
commands:
- pytest -v -s v1/attention
- label: V1 Test attention (B200) # 10min
timeout_in_minutes: 30
gpu: b200
source_file_dependencies:
- vllm/v1/attention
- tests/v1/attention
commands:
- export VLLM_DISABLE_FLASHINFER_PREFILL=1 # TODO: FI prefill is bugged and causes incorrectness, fix this
- pytest -v -s v1/attention
- label: V1 Test others (CPU) # 5 mins
source_file_dependencies:
- vllm/

View File

@@ -14,16 +14,19 @@ import torch
from tests.v1.attention.utils import (
BatchSpec,
create_common_attn_metadata,
create_standard_kv_cache_spec,
create_vllm_config,
try_get_attention_backend,
)
from vllm import _custom_ops as ops
from vllm.attention.backends.registry import _Backend
from vllm.attention.backends.registry import _Backend, backend_to_class_str
from vllm.attention.ops.flashmla import is_flashmla_dense_supported
from vllm.attention.utils.fa_utils import flash_attn_supports_mla
from vllm.config.vllm import set_current_vllm_config
from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase
from vllm.utils.import_utils import resolve_obj_by_qualname
from vllm.utils.math_utils import cdiv
from vllm.utils.torch_utils import STR_DTYPE_TO_TORCH_DTYPE
from vllm.v1.attention.backends.mla.common import QueryLenSupport
from vllm.v1.attention.backends.utils import CommonAttentionMetadata
from vllm.v1.kv_cache_interface import FullAttentionSpec
@@ -31,17 +34,46 @@ BACKENDS_TO_TEST = [
_Backend.CUTLASS_MLA,
_Backend.FLASHMLA,
_Backend.FLASH_ATTN_MLA,
_Backend.FLASHINFER_MLA,
_Backend.TRITON_MLA,
]
# Remove CUTLASS_MLA from the list if not using sm100
# Remove sm100 backends from the list if not using sm100
if not torch.cuda.is_available() or torch.cuda.get_device_properties(0).major < 10:
BACKENDS_TO_TEST.remove(_Backend.CUTLASS_MLA)
BACKENDS_TO_TEST.remove(_Backend.FLASHINFER_MLA)
# Remove FLASH_ATTN_MLA from the list if not supported
if not flash_attn_supports_mla():
BACKENDS_TO_TEST.remove(_Backend.FLASH_ATTN_MLA)
# Remove FLASHMLA from the list if not supported
if not is_flashmla_dense_supported()[0]:
BACKENDS_TO_TEST.remove(_Backend.FLASHMLA)
SPEC_DECODE_BACKENDS = []
for backend in BACKENDS_TO_TEST:
builder_cls, _ = try_get_attention_backend(backend)
query_len_support = getattr(
builder_cls, "query_len_support", QueryLenSupport.SINGLE_ONLY
)
if query_len_support != QueryLenSupport.SINGLE_ONLY:
SPEC_DECODE_BACKENDS.append(backend)
BACKEND_BLOCK_SIZES = {}
for backend in BACKENDS_TO_TEST:
backend_class_str = backend_to_class_str(backend)
backend_class = resolve_obj_by_qualname(backend_class_str)
supported_sizes = backend_class.get_supported_kernel_block_size()
if supported_sizes:
default_size = supported_sizes[0]
block_size = (
default_size if isinstance(default_size, int) else default_size.base
)
else:
block_size = 16
BACKEND_BLOCK_SIZES[backend] = block_size
torch.manual_seed(42)
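
The block-size resolution above can be summarized with a small self-contained sketch. The MultipleOf stand-in and pick_block_size helper below are illustrative names only; the real MultipleOf class is the one imported from vllm.attention.backends.abstract in the backend change further down.

    class MultipleOf:  # hypothetical stand-in for illustration
        def __init__(self, base: int):
            self.base = base

    def pick_block_size(supported_sizes) -> int:
        """First supported entry wins; a MultipleOf entry falls back to its base; empty means 16."""
        if not supported_sizes:
            return 16
        first = supported_sizes[0]
        return first if isinstance(first, int) else first.base

    assert pick_block_size([32, 64]) == 32          # e.g. FlashInfer MLA, per this change
    assert pick_block_size([MultipleOf(16)]) == 16
    assert pick_block_size([]) == 16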
@@ -236,6 +268,26 @@ class MockAttentionLayer:
self._q_scale = torch.tensor(1.0, device=device)
self._k_scale = torch.tensor(1.0, device=device)
self._v_scale = torch.tensor(1.0, device=device)
self._prob_scale = torch.tensor(1.0, device=device)
self._q_scale_float = 1.0
self._k_scale_float = 1.0
self._v_scale_float = 1.0
def forward(self, *_args, **_kwargs):
raise NotImplementedError
class MockMLAAttentionLayer(AttentionLayerBase):
"""A mock MLA attention layer for populating static_forward_context."""
def __init__(self, impl):
self.impl = impl
def get_attn_backend(self):
raise NotImplementedError
def get_kv_cache_spec(self, vllm_config):
raise NotImplementedError
def run_attention_backend(
@@ -262,13 +314,6 @@ def run_attention_backend(
# Set the current vllm config so that get_current_vllm_config() works
# in the backend implementations
with set_current_vllm_config(vllm_config):
# Build metadata
builder = builder_cls(kv_cache_spec, layer_names, vllm_config, device)
attn_metadata = builder.build(
common_prefix_len=0,
common_attn_metadata=common_attn_metadata,
)
# Instantiate MLA implementation
num_heads = vllm_config.model_config.get_num_attention_heads(
vllm_config.parallel_config
@@ -302,6 +347,19 @@
act_dtype = _convert_dtype_to_torch(vllm_config.model_config.dtype)
impl.process_weights_after_loading(act_dtype)
# Populate static_forward_context with mock attention layers
for layer_name in layer_names:
vllm_config.compilation_config.static_forward_context[layer_name] = (
MockMLAAttentionLayer(impl)
)
# Build metadata
builder = builder_cls(kv_cache_spec, layer_names, vllm_config, device)
attn_metadata = builder.build(
common_prefix_len=0,
common_attn_metadata=common_attn_metadata,
)
# Create mock layer and output buffer
mock_layer = MockAttentionLayer(device)
num_tokens = query.shape[0]
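
A hedged reading of the reorder in this hunk: metadata is now built only after the mock MLA layers are registered in static_forward_context, presumably because some metadata builders resolve their layers from that context by name. A toy illustration with hypothetical names:

    class ToyBuilder:
        """Hypothetical builder that, like some real ones, looks layers up by name."""
        def __init__(self, layer_names, forward_context):
            # Raises KeyError if the layers were not registered beforehand.
            self.layers = [forward_context[name] for name in layer_names]

    forward_context = {}
    forward_context["placeholder"] = object()     # register the (mock) layer first ...
    ToyBuilder(["placeholder"], forward_context)  # ... then building succeeds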
@@ -353,15 +411,14 @@ def test_backend_correctness(dist_init, batch_spec_name: str, model: str):
simulated paged KV cache.
5. Comparing the vLLM backend's output to the ground-truth SDPA output.
"""
from vllm.v1.attention.backends.mla.common import QueryLenSupport
batch_spec = BATCH_SPECS[batch_spec_name]
is_spec_decode_test = batch_spec_name.startswith("spec_decode")
spec_decode_backends = {_Backend.FLASH_ATTN_MLA, _Backend.FLASHMLA}
block_size = 16
unique_block_sizes = sorted(set(BACKEND_BLOCK_SIZES.values()))
default_block_size = unique_block_sizes[0]
required_blocks = sum(
(seq_len + block_size - 1) // block_size for seq_len in batch_spec.seq_lens
(seq_len + default_block_size - 1) // default_block_size
for seq_len in batch_spec.seq_lens
)
# Add 1 for null block at index 0, and some buffer
num_gpu_blocks = required_blocks + 1 + 100
@@ -370,7 +427,7 @@ def test_backend_correctness(dist_init, batch_spec_name: str, model: str):
model_name=model,
max_model_len=max(batch_spec.seq_lens),
num_gpu_blocks=num_gpu_blocks,
block_size=block_size,
block_size=default_block_size,
)
# For spec decode tests, add a speculative_config to set the reorder_batch_threshold
@@ -388,8 +445,6 @@ def test_backend_correctness(dist_init, batch_spec_name: str, model: str):
device = torch.device("cuda:0")
kv_cache_spec = create_standard_kv_cache_spec(vllm_config)
# 1. Setup
batch_size = batch_spec.batch_size
seq_lens = batch_spec.seq_lens
@@ -399,7 +454,6 @@ def test_backend_correctness(dist_init, batch_spec_name: str, model: str):
)
head_size = vllm_config.model_config.get_head_size()
dtype = _convert_dtype_to_torch(vllm_config.model_config.dtype)
block_size = vllm_config.cache_config.block_size
kv_lora_rank = 512
qk_rope_head_dim = 64
qk_nope_head_dim = 128
@@ -598,33 +652,83 @@ def test_backend_correctness(dist_init, batch_spec_name: str, model: str):
)
mock_kv_b_proj.weight = torch.nn.Parameter(kv_b_proj_weight.T, requires_grad=False)
# Create metadata using original batch spec
common_attn_metadata = create_common_attn_metadata(
batch_spec, vllm_config.cache_config.block_size, device
)
# 3. Create metadata and KV caches for each block size
# Group backends by block size and test each group
metadata_per_block_size = {}
kv_cache_per_block_size = {}
# 3. Simulate Paged KV Cache and a realistic slot_mapping
kv_cache = create_and_prepopulate_kv_cache(
kv_c_contexts=kv_c_contexts,
k_pe_contexts=k_pe_contexts,
block_size=block_size,
head_size=head_size,
dtype=dtype,
device=device,
num_blocks=vllm_config.cache_config.num_gpu_blocks,
common_attn_metadata=common_attn_metadata,
randomize_blocks=True,
)
for block_size in unique_block_sizes:
# Create metadata for this block size
common_attn_metadata = create_common_attn_metadata(
batch_spec, block_size, device
)
# Pad block table to meet requirement:
# block_num % (128 / block_size) == 0
required_divisor = int(128 / block_size)
current_block_num = common_attn_metadata.block_table_tensor.shape[1]
if current_block_num % required_divisor != 0:
# Pad to next multiple of required_divisor
padded_block_num = (
(current_block_num + required_divisor - 1) // required_divisor
) * required_divisor
padding_cols = padded_block_num - current_block_num
padding = torch.zeros(
(common_attn_metadata.block_table_tensor.shape[0], padding_cols),
dtype=torch.int32,
device=device,
)
common_attn_metadata.block_table_tensor = torch.cat(
[common_attn_metadata.block_table_tensor, padding], dim=1
)
metadata_per_block_size[block_size] = common_attn_metadata
# Create KV cache for this block size
required_blocks_for_size = sum(
(seq_len + block_size - 1) // block_size for seq_len in batch_spec.seq_lens
)
num_blocks_for_size = required_blocks_for_size + 1 + 100
kv_cache = create_and_prepopulate_kv_cache(
kv_c_contexts=kv_c_contexts,
k_pe_contexts=k_pe_contexts,
block_size=block_size,
head_size=head_size,
dtype=dtype,
device=device,
num_blocks=num_blocks_for_size,
common_attn_metadata=common_attn_metadata,
randomize_blocks=True,
)
kv_cache_per_block_size[block_size] = kv_cache
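
The block-table padding requirement above (block_num % (128 / block_size) == 0) reduces to a round-up-to-multiple computation; a minimal standalone sketch of the same arithmetic, with a hypothetical helper name:

    def pad_block_num(current_block_num: int, block_size: int) -> int:
        """Round the block-table width up so it divides evenly by 128 // block_size."""
        required_divisor = 128 // block_size
        return (
            (current_block_num + required_divisor - 1) // required_divisor
        ) * required_divisor

    assert pad_block_num(5, 32) == 8    # divisor 4: 5 columns padded to 8
    assert pad_block_num(8, 64) == 8    # divisor 2: already aligned, no padding
    assert pad_block_num(7, 16) == 8    # divisor 8: padded to 8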
# 4. Run vLLM backends and compare
failures = []
for backend_idx, backend_name in enumerate(BACKENDS_TO_TEST):
# Skip backends that don't support spec decode for spec decode tests
if is_spec_decode_test and backend_name not in spec_decode_backends:
if is_spec_decode_test and backend_name not in SPEC_DECODE_BACKENDS:
continue
# Get the appropriate block_size, metadata, and cache for this backend
block_size = BACKEND_BLOCK_SIZES[backend_name]
common_attn_metadata = metadata_per_block_size[block_size]
kv_cache = kv_cache_per_block_size[block_size]
# Create kv_cache_spec with the correct block_size for this backend
backend_kv_cache_spec = FullAttentionSpec(
block_size=block_size,
num_kv_heads=vllm_config.model_config.get_num_kv_heads(
vllm_config.parallel_config
),
head_size=vllm_config.model_config.get_head_size(),
dtype=vllm_config.model_config.dtype,
sliding_window=vllm_config.model_config.get_sliding_window(),
)
backend_output = run_attention_backend(
backend_name,
kv_cache_spec,
backend_kv_cache_spec,
["placeholder"],
vllm_config,
device,
@@ -644,32 +748,48 @@ def test_backend_correctness(dist_init, batch_spec_name: str, model: str):
expected_output = sdpa_outputs[backend_name]
# Check shape and dtype consistency
assert backend_output.shape == expected_output.shape, (
f"[{backend_name}] shape {backend_output.shape} != "
f"SDPA shape {expected_output.shape}"
)
assert backend_output.dtype == expected_output.dtype, (
f"[{backend_name}] dtype {backend_output.dtype} != "
f"SDPA dtype {expected_output.dtype}"
)
try:
assert backend_output.shape == expected_output.shape, (
f"[{backend_name}] shape {backend_output.shape} != "
f"SDPA shape {expected_output.shape}"
)
assert backend_output.dtype == expected_output.dtype, (
f"[{backend_name}] dtype {backend_output.dtype} != "
f"SDPA dtype {expected_output.dtype}"
)
assert torch.isfinite(backend_output).all(), (
f"[{backend_name}] produced non-finite values"
)
assert torch.isfinite(backend_output).all(), (
f"[{backend_name}] produced non-finite values"
)
# Check numerical similarity
rtol = 1e-2
atol = 5e-1
# Check numerical similarity
rtol = 1e-2
atol = 5e-1
max_diff = torch.max(torch.abs(backend_output - expected_output)).item()
max_rel_diff = torch.max(
torch.abs(backend_output - expected_output) / torch.abs(expected_output)
).item()
all_close = torch.allclose(
backend_output, expected_output, rtol=rtol, atol=atol
)
max_diff = torch.max(torch.abs(backend_output - expected_output)).item()
max_rel_diff = torch.max(
torch.abs(backend_output - expected_output) / torch.abs(expected_output)
).item()
all_close = torch.allclose(
backend_output, expected_output, rtol=rtol, atol=atol
)
assert all_close, (
f"[{backend_name}] output differs from SDPA baseline. "
f"Max diff: {max_diff:.6f}, max rel diff: {max_rel_diff:.6f})"
)
assert all_close, (
f"[{backend_name}] output differs from SDPA baseline. "
f"Max diff: {max_diff:.6f}, max rel diff: {max_rel_diff:.6f})"
)
except AssertionError as e:
failures.append(str(e))
# Report all failures at once
if failures:
# Create a summary for the single-line failure message
backend_names = []
for f in failures:
if "[_Backend." in f:
backend_name = f.split("[")[1].split("]")[0]
backend_names.append(backend_name)
summary = f"{len(failures)} backend(s) failed: {', '.join(backend_names)}"
detailed_msg = "\n".join(failures)
pytest.fail(f"{summary}\n{detailed_msg}")
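
The summary line relies on each collected failure message starting with a bracketed backend tag; a quick check of that extraction (the sample message below is invented for illustration):

    failure = "[_Backend.FLASHMLA] output differs from SDPA baseline. Max diff: 0.73, max rel diff: 0.05"
    assert failure.split("[")[1].split("]")[0] == "_Backend.FLASHMLA"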

View File

@@ -285,7 +285,17 @@ full_cg_backend_configs = {
name="CutlassMLA",
env_vars={
"VLLM_ATTENTION_BACKEND": "CUTLASS_MLA",
"FORCE_NUM_KV_SPLITS": "1", # TODO: remove this when hang issue is fixed
},
comp_config={
"cudagraph_mode": "FULL_AND_PIECEWISE",
},
specific_gpu_arch=(10, 0),
),
# FlashInfer MLA on Blackwell
"FlashInferMLA": BackendConfig(
name="FlashInferMLA",
env_vars={
"VLLM_ATTENTION_BACKEND": "FLASHINFER_MLA",
},
comp_config={
"cudagraph_mode": "FULL_AND_PIECEWISE",

View File

@@ -6,7 +6,7 @@ from typing import ClassVar
import torch
from flashinfer.decode import trtllm_batch_decode_with_kv_cache_mla
from vllm.attention.backends.abstract import AttentionLayer, AttentionType
from vllm.attention.backends.abstract import AttentionLayer, AttentionType, MultipleOf
from vllm.logger import init_logger
from vllm.v1.attention.backends.mla.common import (
MLACommonBackend,
@@ -40,6 +40,10 @@ class FlashInferMLABackend(MLACommonBackend):
def get_builder_cls() -> type["FlashInferMLAMetadataBuilder"]:
return FlashInferMLAMetadataBuilder
@classmethod
def get_supported_kernel_block_size(cls) -> list[int | MultipleOf]:
return [32, 64]
g_fi_workspace = torch.zeros(
FLASHINFER_MLA_WORKSPACE_BUFFER_SIZE,