[Bugfix][CI] Fix v1 attention backend tests and add CI coverage (#26597)
Signed-off-by: Mohammad Miadh Angkad <MAngkad.BSDSBA2027@aim.edu>
Signed-off-by: Mohammad Miadh Angkad <mangkad.bsdsba2027@aim.edu>
Co-authored-by: Ye (Charlotte) Qi <yeq@meta.com>
commit a8c02fb5bf
parent 02af36df36
@@ -313,6 +313,15 @@ steps:
   - pip install -U git+https://github.com/robertgshaw2-redhat/lm-evaluation-harness.git@streaming-api
   - pytest -v -s entrypoints/openai/correctness/test_lmeval.py::test_lm_eval_accuracy_v1_engine
 
+- label: V1 Test attention (H100) # 10min
+  timeout_in_minutes: 30
+  gpu: h100
+  source_file_dependencies:
+  - vllm/v1/attention
+  - tests/v1/attention
+  commands:
+  - pytest -v -s v1/attention
+
 - label: V1 Test others (CPU) # 5 mins
   source_file_dependencies:
   - vllm/
@@ -428,6 +428,7 @@ def _test_backend_correctness(
     # [num_blocks, 2, block_size, num_kv_heads, head_size]
     # Select the appropriate KV cache format for each backend
     kv_cache_for_backend = kv_cache
+    reset_kv_cache_layout = False
     if backend_name in (_Backend.FLASHINFER, _Backend.TRITON_ATTN):
         kv_cache_for_backend = kv_cache.transpose(0, 1)
 
@@ -437,7 +438,11 @@ def _test_backend_correctness(
                 kv_cache_for_backend.transpose(2, 3).contiguous().transpose(2, 3)
             )
             set_kv_cache_layout("HND")
+            reset_kv_cache_layout = True
+        elif backend_name == _Backend.TRITON_ATTN:
+            kv_cache_for_backend = kv_cache_for_backend.contiguous()
 
+    try:
         backend_output = run_attention_backend(
             backend_name,
             kv_cache_spec,
@@ -451,6 +456,9 @@ def _test_backend_correctness(
             kv_cache_for_backend,
             sliding_window=sliding_window,
         )
+    finally:
+        if reset_kv_cache_layout:
+            set_kv_cache_layout(None)
 
     # Check shape and dtype consistency
     assert backend_output.shape == sdpa_output.shape, (
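The try/finally added above guarantees that the global KV cache layout override set for FlashInfer is cleared even when a backend run fails, so later test parameterizations are not polluted. A minimal, self-contained sketch of that reset pattern, written as a context manager purely for illustration (the names below are hypothetical stand-ins, not vLLM APIs):

from contextlib import contextmanager

# Hypothetical stand-in for the global override toggled by the test; in the
# diff this role is played by set_kv_cache_layout from the test helpers.
_KV_CACHE_LAYOUT: str | None = None

def set_kv_cache_layout(layout: str | None) -> None:
    global _KV_CACHE_LAYOUT
    _KV_CACHE_LAYOUT = layout

@contextmanager
def kv_cache_layout_override(layout: str):
    """Apply a layout override and always clear it again afterwards."""
    set_kv_cache_layout(layout)
    try:
        yield
    finally:
        set_kv_cache_layout(None)  # mirrors the new finally block above

# The override is gone even if the body raises.
with kv_cache_layout_override("HND"):
    pass
assert _KV_CACHE_LAYOUT is None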
@@ -155,7 +155,7 @@ def create_and_prepopulate_kv_cache(
         scale_tensor = scale_tensor.to(device=device, dtype=torch.float32)
     else:
         # Create MLA KV cache: (num_blocks, block_size, head_size)
-        kv_cache = torch.empty(
+        kv_cache = torch.zeros(
            num_blocks, block_size, head_size, dtype=dtype, device=device
         )
         kv_cache_flat = kv_cache.view(-1, head_size)
@@ -212,6 +212,7 @@ def create_and_prepopulate_kv_cache(
         start = start_block_idx
         end = start + num_blocks_for_seq
         block_table[i, :num_blocks_for_seq] = inv_perm[start:end]
+        block_table[i, num_blocks_for_seq:] = 0
         start_block_idx += num_blocks_for_seq
 
     # Create a realistic slot mapping that corresponds to the block table
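Both changes above are about keeping unwritten cache state deterministic: torch.zeros replaces torch.empty so the MLA cache does not start with arbitrary memory, and the padded block-table entries are explicitly set to 0. A small illustrative sketch (shapes are made up for the example):

import torch

# torch.empty returns uninitialized memory, so any slot the test never writes
# (including padded block-table entries) would hold arbitrary values and could
# make the SDPA-vs-backend comparison flaky. torch.zeros gives a well-defined
# baseline for every untouched slot.
num_blocks, block_size, head_size = 4, 16, 8
kv_cache = torch.zeros(num_blocks, block_size, head_size)

block_table = torch.zeros(2, num_blocks, dtype=torch.int32)
block_table[0, :3] = torch.tensor([2, 1, 3], dtype=torch.int32)  # blocks in use
# Unused entries stay 0 and therefore point at an all-zero block, not garbage.
assert kv_cache[block_table[0, 3]].abs().sum() == 0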
@@ -2,6 +2,7 @@
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 """Attention layer with FlexAttention."""
 
+import math
 from dataclasses import dataclass
 
 import torch
@@ -592,9 +593,10 @@ class FlexAttentionMetadataBuilder(AttentionMetadataBuilder[FlexAttentionMetadat
         self.headdim = self.model_config.get_head_size()
         self.block_size = kv_cache_spec.block_size
         self.kv_cache_spec = kv_cache_spec
-        self.direct_build: bool = is_torch_equal_or_newer("2.9.0.dev0")
-        self.q_block_size: int = 16 if is_torch_equal_or_newer("2.9.0.dev0") else 128
-        self.kv_block_size: int = 16 if is_torch_equal_or_newer("2.9.0.dev0") else 128
+        supports_small_blocks = is_torch_equal_or_newer("2.9.0.dev0")
+        self.direct_build: bool = supports_small_blocks
+        self.q_block_size: int = 16 if supports_small_blocks else 128
+        self.kv_block_size: int = self.block_size if supports_small_blocks else 128
 
     def build(
         self,
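The effect of the refactor above is that the kv block size now follows the KV cache spec instead of a hard-coded 16 when PyTorch is new enough. A tiny sketch of the resulting selection (the function and argument names are illustrative, not vLLM code):

def flex_block_sizes(supports_small_blocks: bool, kv_cache_block_size: int) -> tuple[int, int]:
    """Return (q_block_size, kv_block_size) as chosen in the builder above."""
    if supports_small_blocks:            # torch >= 2.9.0.dev0
        return 16, kv_cache_block_size   # kv side now tracks the cache spec
    return 128, 128                      # older torch keeps the 128 defaults

assert flex_block_sizes(True, 16) == (16, 16)
assert flex_block_sizes(False, 16) == (128, 128)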
@@ -867,6 +869,22 @@ def get_kernel_options(
     kernel_options: dict[str, int | bool] = {
         "FORCE_USE_FLEX_ATTENTION": True,
     }
+
+    def ensure_divisible(candidate: int, block_size: int) -> int:
+        """Pick a kernel block size that divides the logical block."""
+        if block_size <= 0:
+            return candidate
+        candidate = min(candidate, block_size)
+        if candidate <= 0:
+            return block_size
+        if block_size % candidate == 0:
+            return candidate
+
+        candidate = math.gcd(candidate, block_size)
+        if candidate <= 1:
+            return block_size
+        return candidate
+
     if vllm_is_batch_invariant():
         kernel_options["BLOCK_M"] = 16
         kernel_options["BLOCK_N"] = 16
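A standalone copy of the ensure_divisible helper with a few worked inputs, to make its fallback order concrete (clamp to the block size, keep an exact divisor, otherwise fall back to the gcd, and finally to the block size itself):

import math

def ensure_divisible(candidate: int, block_size: int) -> int:
    """Pick a kernel block size that divides the logical block."""
    if block_size <= 0:
        return candidate
    candidate = min(candidate, block_size)
    if candidate <= 0:
        return block_size
    if block_size % candidate == 0:
        return candidate
    candidate = math.gcd(candidate, block_size)
    if candidate <= 1:
        return block_size
    return candidate

assert ensure_divisible(64, 128) == 64  # 64 already divides 128
assert ensure_divisible(64, 48) == 48   # clamped to the block size, which divides itself
assert ensure_divisible(32, 48) == 16   # 48 % 32 != 0, so fall back to gcd(32, 48)
assert ensure_divisible(4, 9) == 9      # gcd is 1, so use the block size itself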
@@ -877,17 +895,22 @@ def get_kernel_options(
         kernel_options["BLOCK_N"] = block_n
         return kernel_options
     else:
-        kernel_options["BLOCK_M"] = 64
-        kernel_options["BLOCK_N"] = 64
-        if query.dtype == torch.float32:
-            kernel_options["BLOCK_M"] = 32
-            kernel_options["BLOCK_N"] = 32
-        # if current_platform.is_cuda():
+        preferred_block = 32 if query.dtype == torch.float32 else 64
+        block_m_candidate = ensure_divisible(preferred_block, block_m)
+        block_n_candidate = ensure_divisible(preferred_block, block_n)
+
         if torch.cuda.is_available():
             device_props = torch.cuda.get_device_properties()
             max_shared_memory = device_props.shared_memory_per_block_optin
             if max_shared_memory < 144 * 1024:
-                kernel_options["BLOCK_M"] = kernel_options["BLOCK_M"] // 2
-                kernel_options["BLOCK_N"] = kernel_options["BLOCK_N"] // 2
+                block_m_candidate = ensure_divisible(
+                    max(1, block_m_candidate // 2), block_m
+                )
+                block_n_candidate = ensure_divisible(
+                    max(1, block_n_candidate // 2), block_n
+                )
+
+        kernel_options["BLOCK_M"] = block_m_candidate
+        kernel_options["BLOCK_N"] = block_n_candidate
 
         return kernel_options
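Worked numbers for the new fp32 / low-shared-memory branch, assuming a logical block_m of 48 (the value is made up for illustration):

import math

block_m = 48
preferred_block = 32                         # query.dtype == torch.float32
candidate = min(preferred_block, block_m)    # 32
candidate = math.gcd(candidate, block_m)     # 48 % 32 != 0 -> gcd = 16
assert candidate == 16                       # ensure_divisible(32, 48) == 16
# On a GPU with < 144 KiB of shared memory per block the candidate is halved
# and re-fitted, so BLOCK_M ends up at 8, which still divides 48.
halved = max(1, candidate // 2)
assert block_m % halved == 0 and halved == 8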