diff --git a/.buildkite/test-pipeline.yaml b/.buildkite/test-pipeline.yaml
index 6cbc25b4b3bff..03268beecfc0b 100644
--- a/.buildkite/test-pipeline.yaml
+++ b/.buildkite/test-pipeline.yaml
@@ -313,6 +313,15 @@ steps:
   - pip install -U git+https://github.com/robertgshaw2-redhat/lm-evaluation-harness.git@streaming-api
   - pytest -v -s entrypoints/openai/correctness/test_lmeval.py::test_lm_eval_accuracy_v1_engine
 
+- label: V1 Test attention (H100) # 10min
+  timeout_in_minutes: 30
+  gpu: h100
+  source_file_dependencies:
+  - vllm/v1/attention
+  - tests/v1/attention
+  commands:
+  - pytest -v -s v1/attention
+
 - label: V1 Test others (CPU) # 5 mins
   source_file_dependencies:
   - vllm/
diff --git a/tests/v1/attention/test_attention_backends.py b/tests/v1/attention/test_attention_backends.py
index 351cff246d614..6659b3eb1e98f 100644
--- a/tests/v1/attention/test_attention_backends.py
+++ b/tests/v1/attention/test_attention_backends.py
@@ -428,6 +428,7 @@ def _test_backend_correctness(
         # [num_blocks, 2, block_size, num_kv_heads, head_size]
         # Select the appropriate KV cache format for each backend
         kv_cache_for_backend = kv_cache
+        reset_kv_cache_layout = False
 
         if backend_name in (_Backend.FLASHINFER, _Backend.TRITON_ATTN):
             kv_cache_for_backend = kv_cache.transpose(0, 1)
@@ -437,20 +438,27 @@
                     kv_cache_for_backend.transpose(2, 3).contiguous().transpose(2, 3)
                 )
                 set_kv_cache_layout("HND")
+                reset_kv_cache_layout = True
+            elif backend_name == _Backend.TRITON_ATTN:
+                kv_cache_for_backend = kv_cache_for_backend.contiguous()
 
-        backend_output = run_attention_backend(
-            backend_name,
-            kv_cache_spec,
-            ["placeholder"],
-            vllm_config,
-            device,
-            common_attn_metadata,
-            query_vllm,
-            key_vllm,
-            value_vllm,
-            kv_cache_for_backend,
-            sliding_window=sliding_window,
-        )
+        try:
+            backend_output = run_attention_backend(
+                backend_name,
+                kv_cache_spec,
+                ["placeholder"],
+                vllm_config,
+                device,
+                common_attn_metadata,
+                query_vllm,
+                key_vllm,
+                value_vllm,
+                kv_cache_for_backend,
+                sliding_window=sliding_window,
+            )
+        finally:
+            if reset_kv_cache_layout:
+                set_kv_cache_layout(None)
 
         # Check shape and dtype consistency
         assert backend_output.shape == sdpa_output.shape, (
diff --git a/tests/v1/attention/test_mla_backends.py b/tests/v1/attention/test_mla_backends.py
index 1a256a6e192ad..1b17532884841 100644
--- a/tests/v1/attention/test_mla_backends.py
+++ b/tests/v1/attention/test_mla_backends.py
@@ -155,7 +155,7 @@ def create_and_prepopulate_kv_cache(
         scale_tensor = scale_tensor.to(device=device, dtype=torch.float32)
     else:
         # Create MLA KV cache: (num_blocks, block_size, head_size)
-        kv_cache = torch.empty(
+        kv_cache = torch.zeros(
             num_blocks, block_size, head_size, dtype=dtype, device=device
         )
         kv_cache_flat = kv_cache.view(-1, head_size)
@@ -212,6 +212,7 @@ def create_and_prepopulate_kv_cache(
         start = start_block_idx
         end = start + num_blocks_for_seq
         block_table[i, :num_blocks_for_seq] = inv_perm[start:end]
+        block_table[i, num_blocks_for_seq:] = 0
         start_block_idx += num_blocks_for_seq
 
     # Create a realistic slot mapping that corresponds to the block table
diff --git a/vllm/v1/attention/backends/flex_attention.py b/vllm/v1/attention/backends/flex_attention.py
index e12cc581dd1a7..c16a77c093cfb 100644
--- a/vllm/v1/attention/backends/flex_attention.py
+++ b/vllm/v1/attention/backends/flex_attention.py
@@ -2,6 +2,7 @@
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 """Attention layer with FlexAttention."""
 
+import math
 from dataclasses import dataclass
 
 import torch
@@ -592,9 +593,10 @@ class FlexAttentionMetadataBuilder(AttentionMetadataBuilder[FlexAttentionMetadat
         self.headdim = self.model_config.get_head_size()
         self.block_size = kv_cache_spec.block_size
         self.kv_cache_spec = kv_cache_spec
-        self.direct_build: bool = is_torch_equal_or_newer("2.9.0.dev0")
-        self.q_block_size: int = 16 if is_torch_equal_or_newer("2.9.0.dev0") else 128
-        self.kv_block_size: int = 16 if is_torch_equal_or_newer("2.9.0.dev0") else 128
+        supports_small_blocks = is_torch_equal_or_newer("2.9.0.dev0")
+        self.direct_build: bool = supports_small_blocks
+        self.q_block_size: int = 16 if supports_small_blocks else 128
+        self.kv_block_size: int = self.block_size if supports_small_blocks else 128
 
     def build(
         self,
@@ -867,6 +869,22 @@ def get_kernel_options(
     kernel_options: dict[str, int | bool] = {
         "FORCE_USE_FLEX_ATTENTION": True,
     }
+
+    def ensure_divisible(candidate: int, block_size: int) -> int:
+        """Pick a kernel block size that divides the logical block."""
+        if block_size <= 0:
+            return candidate
+        candidate = min(candidate, block_size)
+        if candidate <= 0:
+            return block_size
+        if block_size % candidate == 0:
+            return candidate
+
+        candidate = math.gcd(candidate, block_size)
+        if candidate <= 1:
+            return block_size
+        return candidate
+
     if vllm_is_batch_invariant():
         kernel_options["BLOCK_M"] = 16
         kernel_options["BLOCK_N"] = 16
@@ -877,17 +895,22 @@ def get_kernel_options(
             kernel_options["BLOCK_N"] = block_n
         return kernel_options
     else:
-        kernel_options["BLOCK_M"] = 64
-        kernel_options["BLOCK_N"] = 64
-        if query.dtype == torch.float32:
-            kernel_options["BLOCK_M"] = 32
-            kernel_options["BLOCK_N"] = 32
-        # if current_platform.is_cuda():
+        preferred_block = 32 if query.dtype == torch.float32 else 64
+        block_m_candidate = ensure_divisible(preferred_block, block_m)
+        block_n_candidate = ensure_divisible(preferred_block, block_n)
+
         if torch.cuda.is_available():
            device_props = torch.cuda.get_device_properties()
            max_shared_memory = device_props.shared_memory_per_block_optin
            if max_shared_memory < 144 * 1024:
-                kernel_options["BLOCK_M"] = kernel_options["BLOCK_M"] // 2
-                kernel_options["BLOCK_N"] = kernel_options["BLOCK_N"] // 2
+                block_m_candidate = ensure_divisible(
+                    max(1, block_m_candidate // 2), block_m
+                )
+                block_n_candidate = ensure_divisible(
+                    max(1, block_n_candidate // 2), block_n
+                )
+
+        kernel_options["BLOCK_M"] = block_m_candidate
+        kernel_options["BLOCK_N"] = block_n_candidate
 
     return kernel_options