[Bugfix] Avoid importing AttentionMetadata explicitly in Mllama (#10593)

Signed-off-by: Isotr0py <2037008807@qq.com>
Author: Isotr0py · 2024-11-24 02:12:20 +08:00 · committed by GitHub
parent 651f6c31ac
commit 04668ebe7a
5 changed files with 21 additions and 11 deletions

View File

@@ -87,6 +87,11 @@ class BlocksparseParams:
 class BlocksparseFlashAttentionBackend(AttentionBackend):
 
+    @staticmethod
+    def get_name() -> str:
+        # For attention layer compatibility
+        return "FLASH_ATTN"
+
     @staticmethod
     def get_impl_cls() -> Type["BlocksparseFlashAttentionImpl"]:
         return BlocksparseFlashAttentionImpl

View File

@@ -6,7 +6,7 @@ import torch.nn as nn
 import vllm.envs as envs
 from vllm.attention import AttentionMetadata, AttentionType
-from vllm.attention.selector import get_attn_backend
+from vllm.attention.selector import backend_name_to_enum, get_attn_backend
 from vllm.config import CacheConfig
 from vllm.forward_context import ForwardContext, get_forward_context
 from vllm.model_executor.layers.quantization.base_config import (
@@ -98,6 +98,7 @@ class Attention(nn.Module):
         self.impl = impl_cls(num_heads, head_size, scale, num_kv_heads,
                              alibi_slopes, sliding_window, kv_cache_dtype,
                              blocksparse_params, logits_soft_cap)
+        self.backend = backend_name_to_enum(attn_backend.get_name())
 
         # For cuda-alike (CUDA and ROCM) and cpu platforms, we control how
         # torch.compile works by registering the attention as one giant
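
Note: the layer now resolves the backend enum once at construction from attn_backend.get_name(), so downstream model code can branch on self.backend instead of importing backend-specific metadata classes. Below is a minimal sketch of that pattern; the _Backend enum and backend_name_to_enum here are illustrative stand-ins (only the members that appear in this diff), not vLLM's actual definitions in vllm.attention.selector.

import enum
from typing import Optional

class _Backend(enum.Enum):
    # Toy enum for illustration; vLLM's real enum is larger.
    FLASH_ATTN = enum.auto()
    FLASH_ATTN_VLLM_V1 = enum.auto()
    XFORMERS = enum.auto()
    TORCH_SDPA = enum.auto()

def backend_name_to_enum(backend_name: str) -> Optional[_Backend]:
    # Map the string reported by a backend's get_name() to an enum
    # member; unrecognized names resolve to None instead of raising.
    if backend_name in _Backend.__members__:
        return _Backend[backend_name]
    return None

class Attention:
    def __init__(self, attn_backend) -> None:
        # Resolve the enum once here so forward paths never need to
        # import FlashAttentionMetadata, XFormersMetadata, etc.
        self.backend = backend_name_to_enum(attn_backend.get_name())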

View File

@@ -32,9 +32,8 @@ from transformers.models.mllama.processing_mllama import (
 import vllm.distributed.parallel_state as ps
 from vllm.attention import Attention, AttentionMetadata, AttentionType
-from vllm.attention.backends.flash_attn import FlashAttentionMetadata
-from vllm.attention.backends.xformers import XFormersMetadata
 from vllm.attention.ops.paged_attn import PagedAttention
+from vllm.attention.selector import _Backend
 from vllm.config import VllmConfig
 from vllm.distributed import get_tensor_model_parallel_world_size
 from vllm.inputs import (INPUT_REGISTRY, DummyData, EncoderDecoderInputs,
@@ -828,7 +827,8 @@ class MllamaTextCrossAttention(nn.Module):
     ) -> torch.Tensor:
         # Skip writing kv-cache for the initial profiling run.
         if len(kv_cache.shape) > 1:
-            if isinstance(attn_metadata, FlashAttentionMetadata):
+            if self.attn.backend in (_Backend.FLASH_ATTN,
+                                     _Backend.FLASH_ATTN_VLLM_V1):
                 cached_k = torch.cat([k[s:e] for s, e in kv_range_for_decode])
                 cached_v = torch.cat([v[s:e] for s, e in kv_range_for_decode])
                 torch.ops._C_cache_ops.reshape_and_cache_flash(
@@ -842,7 +842,7 @@
                     1.0,
                     1.0,
                 )
-            elif isinstance(attn_metadata, XFormersMetadata):
+            elif self.attn.backend in (_Backend.XFORMERS, _Backend.TORCH_SDPA):
                 key_cache, value_cache = PagedAttention.split_kv_cache(
                     kv_cache, self.num_local_key_value_heads, self.head_dim)
                 cached_k = torch.cat([k[s:e] for s, e in kv_range_for_decode])
@@ -852,9 +852,9 @@
                     attn_metadata.cross_slot_mapping, "auto", 1.0, 1.0)
             else:
                 raise ValueError(
-                    f"Unsupported AttentionMetadata {type(attn_metadata)} "
-                    f"class found. Expected the AttentionMetadata to "
-                    f"be either XFormersMetadata or FlashAttentionMetadata.")
+                    f"Unsupported Attention backend {self.attn.backend} "
+                    "enum found. Expected the Attention backend to be "
+                    "FLASH_ATTN, FLASH_ATTN_VLLM_V1, XFORMERS or TORCH_SDPA.")
 
         # We have to call torch.sdpa for prefill when using a
         # custom cross-attention mask. Because the mask is not a
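
The Mllama hunks above swap isinstance() checks on concrete metadata classes for membership tests on the layer's backend enum, which is what lets the FlashAttentionMetadata/XFormersMetadata imports be removed in the first hunk. A condensed sketch of the resulting dispatch shape, reusing the toy _Backend from the layer.py sketch; the two helper functions are hypothetical stand-ins for the cache-write code paths, not vLLM APIs.

def _write_flash_style_cache(k, v, kv_cache, kv_range_for_decode,
                             attn_metadata) -> None:
    ...  # e.g. the torch.ops._C_cache_ops.reshape_and_cache_flash() path

def _write_paged_style_cache(k, v, kv_cache, kv_range_for_decode,
                             attn_metadata) -> None:
    ...  # e.g. the PagedAttention.write_to_paged_cache() path

def write_cross_kv_cache(attn, k, v, kv_cache, kv_range_for_decode,
                         attn_metadata) -> None:
    # Branch on the enum resolved by the Attention layer rather than on
    # the concrete type of the metadata object.
    if attn.backend in (_Backend.FLASH_ATTN, _Backend.FLASH_ATTN_VLLM_V1):
        _write_flash_style_cache(k, v, kv_cache, kv_range_for_decode,
                                 attn_metadata)
    elif attn.backend in (_Backend.XFORMERS, _Backend.TORCH_SDPA):
        _write_paged_style_cache(k, v, kv_cache, kv_range_for_decode,
                                 attn_metadata)
    else:
        raise ValueError(f"Unsupported attention backend {attn.backend}")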

View File

@@ -1,7 +1,5 @@
 from typing import TYPE_CHECKING
 
-import openvino as ov
-import openvino.properties.hint as hints
 import torch
 
 import vllm.envs as envs
@@ -16,6 +14,12 @@ else:
 logger = init_logger(__name__)
 
+try:
+    import openvino as ov
+    import openvino.properties.hint as hints
+except ImportError as e:
+    logger.warning("Failed to import OpenVINO with %r", e)
+
 
 class OpenVinoPlatform(Platform):
     _enum = PlatformEnum.OPENVINO
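
The platform module above now defers the OpenVINO imports into a guarded try/except, so importing the module no longer fails outright when the openvino package is missing; only a warning is logged. Below is a self-contained sketch of that optional-dependency pattern, using the standard logging module in place of vLLM's init_logger; the _OPENVINO_AVAILABLE flag is an illustrative addition, not part of this diff.

import logging

logger = logging.getLogger(__name__)

try:
    import openvino as ov  # noqa: F401
    import openvino.properties.hint as hints  # noqa: F401
    _OPENVINO_AVAILABLE = True
except ImportError as e:
    # Degrade gracefully: log the failure and let callers check the flag.
    logger.warning("Failed to import OpenVINO with %r", e)
    _OPENVINO_AVAILABLE = False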

View File

@@ -19,7 +19,7 @@ class FlashAttentionBackend(AttentionBackend):
 
     @staticmethod
     def get_name() -> str:
-        return "flash-attn-vllm-v1"
+        return "FLASH_ATTN_VLLM_V1"
 
     @staticmethod
     def get_impl_cls() -> Type["FlashAttentionImpl"]:
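
Why the spelling change likely matters (an inference, not stated in this diff): the attention layer now maps get_name() to a _Backend member, and "FLASH_ATTN_VLLM_V1" matches the enum member name used in the Mllama hunk above, whereas the old hyphenated, lowercase string would not. Illustrated with the toy helpers from the layer.py sketch:

# Reuses the toy _Backend / backend_name_to_enum defined earlier.
assert backend_name_to_enum("FLASH_ATTN_VLLM_V1") is _Backend.FLASH_ATTN_VLLM_V1
# The previous spelling does not correspond to any member name:
assert backend_name_to_enum("flash-attn-vllm-v1") is None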