mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-05-16 01:22:27 +08:00
[MM] Pass FA version in ViT Attn (#30756)
Signed-off-by: NickLucche <nlucches@redhat.com> Co-authored-by: Cyrus Leung <tlleungac@connect.ust.hk>
This commit is contained in:
parent
e80455ca8b
commit
e087fbc393
@ -10,6 +10,7 @@ from vllm.attention.ops.vit_attn_wrappers import (
|
|||||||
vit_flash_attn_wrapper,
|
vit_flash_attn_wrapper,
|
||||||
vit_torch_sdpa_wrapper,
|
vit_torch_sdpa_wrapper,
|
||||||
)
|
)
|
||||||
|
from vllm.attention.utils.fa_utils import get_flash_attn_version
|
||||||
from vllm.config import MultiModalConfig
|
from vllm.config import MultiModalConfig
|
||||||
from vllm.logger import init_logger
|
from vllm.logger import init_logger
|
||||||
from vllm.model_executor.custom_op import CustomOp
|
from vllm.model_executor.custom_op import CustomOp
|
||||||
@ -101,6 +102,10 @@ class MMEncoderAttention(CustomOp):
|
|||||||
self.attn_backend,
|
self.attn_backend,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if self.is_flash_attn_backend:
|
||||||
|
assert self.flash_attn_varlen_func is not None
|
||||||
|
self._fa_version = get_flash_attn_version()
|
||||||
|
|
||||||
logger.info_once(f"Using {self.attn_backend} for MMEncoderAttention.")
|
logger.info_once(f"Using {self.attn_backend} for MMEncoderAttention.")
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
@ -204,6 +209,7 @@ class MMEncoderAttention(CustomOp):
|
|||||||
max_seqlen=max_seqlen,
|
max_seqlen=max_seqlen,
|
||||||
batch_size=bsz,
|
batch_size=bsz,
|
||||||
is_rocm_aiter=(self.attn_backend == AttentionBackendEnum.ROCM_AITER_FA),
|
is_rocm_aiter=(self.attn_backend == AttentionBackendEnum.ROCM_AITER_FA),
|
||||||
|
fa_version=self._fa_version,
|
||||||
)
|
)
|
||||||
return output
|
return output
|
||||||
|
|
||||||
|
|||||||
@ -28,11 +28,15 @@ def flash_attn_maxseqlen_wrapper(
|
|||||||
max_seqlen: torch.Tensor,
|
max_seqlen: torch.Tensor,
|
||||||
batch_size: int,
|
batch_size: int,
|
||||||
is_rocm_aiter: bool,
|
is_rocm_aiter: bool,
|
||||||
|
fa_version: int,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
|
kwargs = {}
|
||||||
if is_rocm_aiter:
|
if is_rocm_aiter:
|
||||||
from aiter import flash_attn_varlen_func
|
from aiter import flash_attn_varlen_func
|
||||||
else:
|
else:
|
||||||
from vllm.attention.utils.fa_utils import flash_attn_varlen_func
|
from vllm.attention.utils.fa_utils import flash_attn_varlen_func
|
||||||
|
|
||||||
|
kwargs["fa_version"] = fa_version
|
||||||
q, k, v = (einops.rearrange(x, "b s ... -> (b s) ...") for x in [q, k, v])
|
q, k, v = (einops.rearrange(x, "b s ... -> (b s) ...") for x in [q, k, v])
|
||||||
output = flash_attn_varlen_func(
|
output = flash_attn_varlen_func(
|
||||||
q,
|
q,
|
||||||
@ -44,6 +48,7 @@ def flash_attn_maxseqlen_wrapper(
|
|||||||
max_seqlen_k=max_seqlen.item(),
|
max_seqlen_k=max_seqlen.item(),
|
||||||
dropout_p=0.0,
|
dropout_p=0.0,
|
||||||
causal=False,
|
causal=False,
|
||||||
|
**kwargs,
|
||||||
)
|
)
|
||||||
context_layer = einops.rearrange(output, "(b s) h d -> b s h d", b=batch_size)
|
context_layer = einops.rearrange(output, "(b s) h d -> b s h d", b=batch_size)
|
||||||
return context_layer
|
return context_layer
|
||||||
@ -57,6 +62,7 @@ def flash_attn_maxseqlen_wrapper_fake(
|
|||||||
max_seqlen: torch.Tensor,
|
max_seqlen: torch.Tensor,
|
||||||
batch_size: int,
|
batch_size: int,
|
||||||
is_rocm_aiter: bool,
|
is_rocm_aiter: bool,
|
||||||
|
fa_version: int,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
return torch.empty_like(q)
|
return torch.empty_like(q)
|
||||||
|
|
||||||
@ -76,9 +82,10 @@ def vit_flash_attn_wrapper(
|
|||||||
max_seqlen: torch.Tensor,
|
max_seqlen: torch.Tensor,
|
||||||
batch_size: int,
|
batch_size: int,
|
||||||
is_rocm_aiter: bool,
|
is_rocm_aiter: bool,
|
||||||
|
fa_version: int,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
return torch.ops.vllm.flash_attn_maxseqlen_wrapper(
|
return torch.ops.vllm.flash_attn_maxseqlen_wrapper(
|
||||||
q, k, v, cu_seqlens, max_seqlen, batch_size, is_rocm_aiter
|
q, k, v, cu_seqlens, max_seqlen, batch_size, is_rocm_aiter, fa_version
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user