[Bugfix] Avoid import AttentionMetadata explicitly in Mllama (#10593)
Signed-off-by: Isotr0py <2037008807@qq.com>
This commit is contained in:
parent 651f6c31ac
commit 04668ebe7a
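Context for the hunks below: Mllama's cross-attention previously chose its KV-cache write path by checking isinstance(attn_metadata, FlashAttentionMetadata) or XFormersMetadata, which forced importing those backend-specific classes even on platforms where they cannot be imported. This commit records the selected backend as a _Backend enum on the Attention layer and branches on that instead. A minimal, self-contained sketch of the new dispatch pattern; the Backend enum and NewCrossAttn class here are stand-ins, not vLLM code:

from enum import Enum, auto


class Backend(Enum):
    # Stand-in for vllm.attention.selector._Backend
    FLASH_ATTN = auto()
    FLASH_ATTN_VLLM_V1 = auto()
    XFORMERS = auto()
    TORCH_SDPA = auto()


class NewCrossAttn:
    """Stand-in for MllamaTextCrossAttention after this commit."""

    def __init__(self, backend: Backend):
        # Resolved once at construction (see the Attention.__init__ hunk).
        self.backend = backend

    def write_kv(self) -> str:
        # Branch on the backend enum; no backend-specific metadata imports.
        if self.backend in (Backend.FLASH_ATTN, Backend.FLASH_ATTN_VLLM_V1):
            return "reshape_and_cache_flash"
        elif self.backend in (Backend.XFORMERS, Backend.TORCH_SDPA):
            return "PagedAttention.write_to_paged_cache"
        raise ValueError(f"Unsupported Attention backend {self.backend}")


print(NewCrossAttn(Backend.TORCH_SDPA).write_kv())  # paged-attention path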
@@ -87,6 +87,11 @@ class BlocksparseParams:
 
 class BlocksparseFlashAttentionBackend(AttentionBackend):
 
+    @staticmethod
+    def get_name() -> str:
+        # For attention layer compatibility
+        return "FLASH_ATTN"
+
     @staticmethod
     def get_impl_cls() -> Type["BlocksparseFlashAttentionImpl"]:
         return BlocksparseFlashAttentionImpl
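Note on the added get_name(): the blocksparse backend runs on FlashAttention kernels, so it now advertises "FLASH_ATTN" (hence the "For attention layer compatibility" comment), letting the enum lookup added in the next hunk resolve to an existing _Backend member. A hedged check of that assumption, using only helpers this commit itself imports:

# Assumes backend_name_to_enum maps get_name() strings onto matching
# _Backend members; that is what the Attention layer change below relies on.
from vllm.attention.backends.blocksparse_attn import (
    BlocksparseFlashAttentionBackend)
from vllm.attention.selector import _Backend, backend_name_to_enum

assert (backend_name_to_enum(BlocksparseFlashAttentionBackend.get_name())
        is _Backend.FLASH_ATTN)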
@@ -6,7 +6,7 @@ import torch.nn as nn
 
 import vllm.envs as envs
 from vllm.attention import AttentionMetadata, AttentionType
-from vllm.attention.selector import get_attn_backend
+from vllm.attention.selector import backend_name_to_enum, get_attn_backend
 from vllm.config import CacheConfig
 from vllm.forward_context import ForwardContext, get_forward_context
 from vllm.model_executor.layers.quantization.base_config import (
@@ -98,6 +98,7 @@ class Attention(nn.Module):
         self.impl = impl_cls(num_heads, head_size, scale, num_kv_heads,
                              alibi_slopes, sliding_window, kv_cache_dtype,
                              blocksparse_params, logits_soft_cap)
+        self.backend = backend_name_to_enum(attn_backend.get_name())
 
         # For cuda-alike (CUDA and ROCM) and cpu platforms, we control how
         # torch.compile works by registering the attention as one giant
@@ -32,9 +32,8 @@ from transformers.models.mllama.processing_mllama import (
 
 import vllm.distributed.parallel_state as ps
 from vllm.attention import Attention, AttentionMetadata, AttentionType
-from vllm.attention.backends.flash_attn import FlashAttentionMetadata
-from vllm.attention.backends.xformers import XFormersMetadata
 from vllm.attention.ops.paged_attn import PagedAttention
+from vllm.attention.selector import _Backend
 from vllm.config import VllmConfig
 from vllm.distributed import get_tensor_model_parallel_world_size
 from vllm.inputs import (INPUT_REGISTRY, DummyData, EncoderDecoderInputs,
@@ -828,7 +827,8 @@ class MllamaTextCrossAttention(nn.Module):
     ) -> torch.Tensor:
         # Skip writing kv-cache for the initial profiling run.
         if len(kv_cache.shape) > 1:
-            if isinstance(attn_metadata, FlashAttentionMetadata):
+            if self.attn.backend in (_Backend.FLASH_ATTN,
+                                     _Backend.FLASH_ATTN_VLLM_V1):
                 cached_k = torch.cat([k[s:e] for s, e in kv_range_for_decode])
                 cached_v = torch.cat([v[s:e] for s, e in kv_range_for_decode])
                 torch.ops._C_cache_ops.reshape_and_cache_flash(
@@ -842,7 +842,7 @@ class MllamaTextCrossAttention(nn.Module):
                     1.0,
                     1.0,
                 )
-            elif isinstance(attn_metadata, XFormersMetadata):
+            elif self.attn.backend in (_Backend.XFORMERS, _Backend.TORCH_SDPA):
                 key_cache, value_cache = PagedAttention.split_kv_cache(
                     kv_cache, self.num_local_key_value_heads, self.head_dim)
                 cached_k = torch.cat([k[s:e] for s, e in kv_range_for_decode])
@@ -852,9 +852,9 @@ class MllamaTextCrossAttention(nn.Module):
                     attn_metadata.cross_slot_mapping, "auto", 1.0, 1.0)
             else:
                 raise ValueError(
-                    f"Unsupported AttentionMetadata {type(attn_metadata)} "
-                    f"class found. Expected the AttentionMetadata to "
-                    f"be either XFormersMetadata or FlashAttentionMetadata.")
+                    f"Unsupported Attention backend {self.attn.backend} "
+                    "enum found. Expected the Attention backend to be "
+                    "FLASH_ATTN, FLASH_ATTN_VLLM_V1, XFORMERS or TORCH_SDPA.")
 
         # We have to call torch.sdpa for prefill when using a
         # custom cross-attention mask. Because the mask is not a
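Net effect of the Mllama hunks: the module no longer imports FlashAttentionMetadata or XFormersMetadata at all, and the TORCH_SDPA backend, which never matched the old isinstance checks, now shares the paged-attention write path. As a hypothetical illustration (not a test from the repo), the dispatch can be exercised by stubbing the backend attribute rather than constructing backend-specific metadata:

# Hypothetical sketch only; "attn" stands in for the Attention layer held by
# MllamaTextCrossAttention, which now exposes a .backend enum.
from types import SimpleNamespace

from vllm.attention.selector import _Backend

attn = SimpleNamespace(backend=_Backend.TORCH_SDPA)
assert attn.backend in (_Backend.XFORMERS, _Backend.TORCH_SDPA)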
@@ -1,7 +1,5 @@
 from typing import TYPE_CHECKING
 
-import openvino as ov
-import openvino.properties.hint as hints
 import torch
 
 import vllm.envs as envs
@@ -16,6 +14,12 @@ else:
 
 logger = init_logger(__name__)
 
+try:
+    import openvino as ov
+    import openvino.properties.hint as hints
+except ImportError as e:
+    logger.warning("Failed to import OpenVINO with %r", e)
+
 
 class OpenVinoPlatform(Platform):
     _enum = PlatformEnum.OPENVINO
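The OpenVINO imports now sit behind try/except, so importing the platform module no longer hard-fails when OpenVINO is absent; only a warning is logged. A generic sketch of how code can then guard actual use of such an optional import; the require_openvino helper is illustrative and not part of the platform code:

import logging

logger = logging.getLogger(__name__)

try:
    import openvino as ov  # optional dependency
except ImportError as e:
    logger.warning("Failed to import OpenVINO with %r", e)
    ov = None


def require_openvino():
    """Illustrative helper: fail only when OpenVINO is actually needed."""
    if ov is None:
        raise RuntimeError(
            "OpenVINO is required for this code path but is not installed.")
    return ov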
@@ -19,7 +19,7 @@ class FlashAttentionBackend(AttentionBackend):
 
     @staticmethod
     def get_name() -> str:
-        return "flash-attn-vllm-v1"
+        return "FLASH_ATTN_VLLM_V1"
 
     @staticmethod
     def get_impl_cls() -> Type["FlashAttentionImpl"]: