[Bugfix][CI] Fix v1 attention backend tests and add CI coverage (#26597)

Signed-off-by: Mohammad Miadh Angkad <MAngkad.BSDSBA2027@aim.edu>
Signed-off-by: Mohammad Miadh Angkad <mangkad.bsdsba2027@aim.edu>
Co-authored-by: Ye (Charlotte) Qi <yeq@meta.com>
Author: Mohammad Miadh Angkad
Date: 2025-10-28 23:42:05 +08:00 (committed by GitHub)
Commit: a8c02fb5bf (parent 02af36df36)
4 changed files with 66 additions and 25 deletions

.buildkite/test-pipeline.yaml

@@ -313,6 +313,15 @@ steps:
   - pip install -U git+https://github.com/robertgshaw2-redhat/lm-evaluation-harness.git@streaming-api
   - pytest -v -s entrypoints/openai/correctness/test_lmeval.py::test_lm_eval_accuracy_v1_engine
 
+- label: V1 Test attention (H100) # 10min
+  timeout_in_minutes: 30
+  gpu: h100
+  source_file_dependencies:
+  - vllm/v1/attention
+  - tests/v1/attention
+  commands:
+  - pytest -v -s v1/attention
+
 - label: V1 Test others (CPU) # 5 mins
   source_file_dependencies:
   - vllm/
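
The new H100 step only triggers when files under vllm/v1/attention or tests/v1/attention change, and it simply runs the v1 attention test suite. A rough local equivalent, sketched with pytest's Python API (assumptions: the vLLM tests/ directory is the working directory, as the relative path v1/attention suggests, and an H100-class GPU is available to match the gpu: h100 constraint):

# Hypothetical local reproduction of the CI step above; not part of the PR.
import sys

import pytest

if __name__ == "__main__":
    # Same arguments as the CI command: `pytest -v -s v1/attention`.
    sys.exit(pytest.main(["-v", "-s", "v1/attention"]))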

tests/v1/attention/test_attention_backends.py

@@ -428,6 +428,7 @@ def _test_backend_correctness(
         # [num_blocks, 2, block_size, num_kv_heads, head_size]
 
         # Select the appropriate KV cache format for each backend
         kv_cache_for_backend = kv_cache
+        reset_kv_cache_layout = False
         if backend_name in (_Backend.FLASHINFER, _Backend.TRITON_ATTN):
             kv_cache_for_backend = kv_cache.transpose(0, 1)
@@ -437,20 +438,27 @@
                     kv_cache_for_backend.transpose(2, 3).contiguous().transpose(2, 3)
                 )
                 set_kv_cache_layout("HND")
+                reset_kv_cache_layout = True
             elif backend_name == _Backend.TRITON_ATTN:
                 kv_cache_for_backend = kv_cache_for_backend.contiguous()
+
-        backend_output = run_attention_backend(
-            backend_name,
-            kv_cache_spec,
-            ["placeholder"],
-            vllm_config,
-            device,
-            common_attn_metadata,
-            query_vllm,
-            key_vllm,
-            value_vllm,
-            kv_cache_for_backend,
-            sliding_window=sliding_window,
-        )
+        try:
+            backend_output = run_attention_backend(
+                backend_name,
+                kv_cache_spec,
+                ["placeholder"],
+                vllm_config,
+                device,
+                common_attn_metadata,
+                query_vllm,
+                key_vllm,
+                value_vllm,
+                kv_cache_for_backend,
+                sliding_window=sliding_window,
+            )
+        finally:
+            if reset_kv_cache_layout:
+                set_kv_cache_layout(None)
+
         # Check shape and dtype consistency
         assert backend_output.shape == sdpa_output.shape, (
tests/v1/attention/utils.py

@@ -155,7 +155,7 @@ def create_and_prepopulate_kv_cache(
         scale_tensor = scale_tensor.to(device=device, dtype=torch.float32)
     else:
         # Create MLA KV cache: (num_blocks, block_size, head_size)
-        kv_cache = torch.empty(
+        kv_cache = torch.zeros(
             num_blocks, block_size, head_size, dtype=dtype, device=device
         )
         kv_cache_flat = kv_cache.view(-1, head_size)
@@ -212,6 +212,7 @@ def create_and_prepopulate_kv_cache(
         start = start_block_idx
         end = start + num_blocks_for_seq
         block_table[i, :num_blocks_for_seq] = inv_perm[start:end]
+        block_table[i, num_blocks_for_seq:] = 0
         start_block_idx += num_blocks_for_seq
 
     # Create a realistic slot mapping that corresponds to the block table
vllm/v1/attention/backends/flex_attention.py

@@ -2,6 +2,7 @@
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 """Attention layer with FlexAttention."""
 
+import math
 from dataclasses import dataclass
 
 import torch
@@ -592,9 +593,10 @@ class FlexAttentionMetadataBuilder(AttentionMetadataBuilder[FlexAttentionMetadat
         self.headdim = self.model_config.get_head_size()
         self.block_size = kv_cache_spec.block_size
         self.kv_cache_spec = kv_cache_spec
-        self.direct_build: bool = is_torch_equal_or_newer("2.9.0.dev0")
-        self.q_block_size: int = 16 if is_torch_equal_or_newer("2.9.0.dev0") else 128
-        self.kv_block_size: int = 16 if is_torch_equal_or_newer("2.9.0.dev0") else 128
+        supports_small_blocks = is_torch_equal_or_newer("2.9.0.dev0")
+        self.direct_build: bool = supports_small_blocks
+        self.q_block_size: int = 16 if supports_small_blocks else 128
+        self.kv_block_size: int = self.block_size if supports_small_blocks else 128
 
     def build(
         self,
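
On PyTorch 2.9+ the builder can use small FlexAttention blocks, so the KV kernel block size now follows the cache spec's block_size instead of a hardcoded 16, while older PyTorch keeps the 128-wide defaults. A condensed sketch of that selection; the function name is illustrative and the cutoff mirrors the is_torch_equal_or_newer("2.9.0.dev0") check above:

def select_flex_block_sizes(torch_2_9_or_newer: bool,
                            kv_cache_block_size: int) -> tuple[int, int]:
    """Return (q_block_size, kv_block_size) the way the builder above does."""
    if torch_2_9_or_newer:
        # Small blocks are supported: follow the KV cache block size on the KV side.
        return 16, kv_cache_block_size
    # Older PyTorch: keep the previous 128-wide defaults.
    return 128, 128


print(select_flex_block_sizes(True, 16))   # (16, 16)
print(select_flex_block_sizes(False, 16))  # (128, 128)
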
@@ -867,6 +869,22 @@ def get_kernel_options(
     kernel_options: dict[str, int | bool] = {
         "FORCE_USE_FLEX_ATTENTION": True,
     }
+
+    def ensure_divisible(candidate: int, block_size: int) -> int:
+        """Pick a kernel block size that divides the logical block."""
+        if block_size <= 0:
+            return candidate
+        candidate = min(candidate, block_size)
+        if candidate <= 0:
+            return block_size
+        if block_size % candidate == 0:
+            return candidate
+        candidate = math.gcd(candidate, block_size)
+        if candidate <= 1:
+            return block_size
+        return candidate
+
     if vllm_is_batch_invariant():
         kernel_options["BLOCK_M"] = 16
         kernel_options["BLOCK_N"] = 16
@@ -877,17 +895,22 @@
             kernel_options["BLOCK_N"] = block_n
         return kernel_options
     else:
-        kernel_options["BLOCK_M"] = 64
-        kernel_options["BLOCK_N"] = 64
-        if query.dtype == torch.float32:
-            kernel_options["BLOCK_M"] = 32
-            kernel_options["BLOCK_N"] = 32
-        # if current_platform.is_cuda():
+        preferred_block = 32 if query.dtype == torch.float32 else 64
+        block_m_candidate = ensure_divisible(preferred_block, block_m)
+        block_n_candidate = ensure_divisible(preferred_block, block_n)
         if torch.cuda.is_available():
             device_props = torch.cuda.get_device_properties()
             max_shared_memory = device_props.shared_memory_per_block_optin
             if max_shared_memory < 144 * 1024:
-                kernel_options["BLOCK_M"] = kernel_options["BLOCK_M"] // 2
-                kernel_options["BLOCK_N"] = kernel_options["BLOCK_N"] // 2
+                block_m_candidate = ensure_divisible(
+                    max(1, block_m_candidate // 2), block_m
+                )
+                block_n_candidate = ensure_divisible(
+                    max(1, block_n_candidate // 2), block_n
+                )
+        kernel_options["BLOCK_M"] = block_m_candidate
+        kernel_options["BLOCK_N"] = block_n_candidate
     return kernel_options
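
The upshot is that BLOCK_M and BLOCK_N always divide the logical FlexAttention block sizes, even after the halving applied on GPUs with less than 144 KiB of shared memory per block. A standalone copy of the helper with a few worked values; the printed results follow directly from the logic above:

import math


def ensure_divisible(candidate: int, block_size: int) -> int:
    """Mirror of the helper above: prefer a candidate that divides the
    logical block, else fall back to the gcd or the block size itself."""
    if block_size <= 0:
        return candidate
    candidate = min(candidate, block_size)
    if candidate <= 0:
        return block_size
    if block_size % candidate == 0:
        return candidate
    candidate = math.gcd(candidate, block_size)
    if candidate <= 1:
        return block_size
    return candidate


# fp16/bf16 queries start from 64-wide kernel blocks, fp32 from 32-wide.
print(ensure_divisible(64, 128))  # 64  -- already divides the logical block
print(ensure_divisible(64, 16))   # 16  -- clamped to the small logical block
print(ensure_divisible(64, 96))   # 32  -- falls back to gcd(64, 96)
# Low-shared-memory GPUs halve the candidate first, e.g.:
print(ensure_divisible(max(1, 64 // 2), 128))  # 32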