[Kernel] [V1] Fix performance regression for triton unified attention (#18161)
Signed-off-by: Thomas Parnell <tpa@zurich.ibm.com>
Co-authored-by: Lucas Wilkinson <lwilkinson@neuralmagic.com>
parent 451da4bcbd
commit 01c22335ba
@@ -56,11 +56,11 @@ def kernel_unified_attention_2d(
         stride_k_cache_0: tl.int64,  # int
         stride_k_cache_1: tl.int64,  # int
         stride_k_cache_2: tl.int64,  # int
-        stride_k_cache_3: tl.int64,  # int
+        stride_k_cache_3: tl.constexpr,  # int
         stride_v_cache_0: tl.int64,  # int
         stride_v_cache_1: tl.int64,  # int
         stride_v_cache_2: tl.int64,  # int
-        stride_v_cache_3: tl.int64,  # int
+        stride_v_cache_3: tl.constexpr,  # int
         query_start_len_ptr,  # [num_seqs+1]
         BLOCK_Q: tl.constexpr,  # int
         num_seqs: tl.int32,
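Why the annotation change matters: a tl.constexpr argument is a compile-time constant, so Triton specializes the kernel for its value. When the innermost cache stride is 1, the multiply folds away and the loads become plain contiguous accesses, whereas a tl.int64 argument forces generic 64-bit address arithmetic at runtime. Below is a minimal, made-up sketch of that pattern (copy_rows_kernel is not from this commit; it assumes a CUDA device and a Triton version that, like the kernel above, accepts tl.int64 annotations on scalar arguments):

import torch
import triton
import triton.language as tl


@triton.jit
def copy_rows_kernel(
        src_ptr,
        dst_ptr,
        stride_row: tl.int64,      # runtime value: kept as a 64-bit scalar argument
        stride_col: tl.constexpr,  # compile-time constant: baked into the kernel
        N_COLS: tl.constexpr,
):
    row = tl.program_id(0)
    cols = tl.arange(0, N_COLS)
    vals = tl.load(src_ptr + row * stride_row + cols * stride_col)
    tl.store(dst_ptr + row * stride_row + cols * stride_col, vals)


if __name__ == "__main__":
    x = torch.randn(8, 128, device="cuda")
    y = torch.empty_like(x)
    # stride(1) == 1 here, so the constexpr lets Triton drop the multiply entirely.
    copy_rows_kernel[(x.shape[0], )](x, y, x.stride(0), x.stride(1), N_COLS=x.shape[1])
    assert torch.equal(x, y)

Passing x.stride(1) (here 1) as the constexpr means Triton compiles a variant with the unit stride baked in; a different stride value would simply trigger a recompile for that value.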
@@ -1,6 +1,6 @@
 # SPDX-License-Identifier: Apache-2.0
 """Attention layer with PagedAttention and Triton prefix prefill."""
-from typing import Any, Optional
+from typing import TYPE_CHECKING, Any, Optional
 
 import torch
 
@@ -12,10 +12,23 @@ from vllm.logger import init_logger
 from vllm.platforms import current_platform
 from vllm.v1.attention.backends.flash_attn import (
     FlashAttentionMetadata, FlashAttentionMetadataBuilder)
+from vllm.v1.kv_cache_interface import AttentionSpec
+from vllm.v1.worker.block_table import BlockTable
+
+if TYPE_CHECKING:
+    from vllm.v1.worker.gpu_model_runner import GPUModelRunner
 
 logger = init_logger(__name__)
 
 
+class TritonAttentionMetadataBuilder(FlashAttentionMetadataBuilder):
+
+    def __init__(self, runner: "GPUModelRunner", kv_cache_spec: AttentionSpec,
+                 block_table: BlockTable):
+        super().__init__(runner, kv_cache_spec, block_table)
+        self.aot_schedule = False
+
+
 class TritonAttentionBackend(AttentionBackend):
 
     accept_output_buffer: bool = True
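The new builder's constructor takes the GPU model runner, but GPUModelRunner is imported only under TYPE_CHECKING and referenced via a string annotation, so the import never runs at runtime and cannot create an import cycle with the worker module. A minimal sketch of that idiom (describe_runner is a hypothetical function, not part of the diff):

from typing import TYPE_CHECKING

if TYPE_CHECKING:
    # Seen only by static type checkers; never executed at runtime.
    from vllm.v1.worker.gpu_model_runner import GPUModelRunner


def describe_runner(runner: "GPUModelRunner") -> str:
    # The quoted annotation is resolved lazily, so no runtime import is needed.
    return type(runner).__name__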
@@ -52,8 +65,8 @@ class TritonAttentionBackend(AttentionBackend):
         return False
 
     @staticmethod
-    def get_builder_cls() -> type["FlashAttentionMetadataBuilder"]:
-        return FlashAttentionMetadataBuilder
+    def get_builder_cls() -> type["TritonAttentionMetadataBuilder"]:
+        return TritonAttentionMetadataBuilder
 
 
 class TritonAttentionImpl(AttentionImpl):
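Taken together, the backend now hands out a builder that behaves exactly like the FlashAttention one except that ahead-of-time scheduling is switched off. A toy, self-contained model of the pattern (these stand-in classes are illustrative only, not vllm code; the real FlashAttention builder chooses its aot_schedule value based on the detected FlashAttention version):

class FlashBuilder:
    def __init__(self) -> None:
        # Stand-in: the real builder decides this from the FlashAttention version.
        self.aot_schedule = True


class TritonBuilder(FlashBuilder):
    def __init__(self) -> None:
        super().__init__()
        # The Triton unified attention kernel has no AOT scheduling path.
        self.aot_schedule = False


class TritonBackend:
    @staticmethod
    def get_builder_cls() -> type[TritonBuilder]:
        return TritonBuilder


builder = TritonBackend.get_builder_cls()()
assert builder.aot_schedule is False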