diff --git a/vllm/v1/attention/backends/flex_attention.py b/vllm/v1/attention/backends/flex_attention.py
index c16a77c093cfb..928252636d583 100644
--- a/vllm/v1/attention/backends/flex_attention.py
+++ b/vllm/v1/attention/backends/flex_attention.py
@@ -896,6 +896,8 @@ def get_kernel_options(
         return kernel_options
     else:
         preferred_block = 32 if query.dtype == torch.float32 else 64
+        block_lower_bound = 16
+
         block_m_candidate = ensure_divisible(preferred_block, block_m)
         block_n_candidate = ensure_divisible(preferred_block, block_n)
 
@@ -910,6 +912,9 @@ def get_kernel_options(
             max(1, block_n_candidate // 2), block_n
         )
 
+        block_m_candidate = max(block_m_candidate, block_lower_bound)
+        block_n_candidate = max(block_n_candidate, block_lower_bound)
+
         kernel_options["BLOCK_M"] = block_m_candidate
         kernel_options["BLOCK_N"] = block_n_candidate
 