diff --git a/vllm/attention/backends/blocksparse_attn.py b/vllm/attention/backends/blocksparse_attn.py
index 94002e36db2b..9e54c3b40c54 100644
--- a/vllm/attention/backends/blocksparse_attn.py
+++ b/vllm/attention/backends/blocksparse_attn.py
@@ -87,6 +87,11 @@ class BlocksparseParams:
 
 class BlocksparseFlashAttentionBackend(AttentionBackend):
 
+    @staticmethod
+    def get_name() -> str:
+        # For attention layer compatibility
+        return "FLASH_ATTN"
+
     @staticmethod
     def get_impl_cls() -> Type["BlocksparseFlashAttentionImpl"]:
         return BlocksparseFlashAttentionImpl
diff --git a/vllm/attention/layer.py b/vllm/attention/layer.py
index cb4dedf481c7..1bb335909484 100644
--- a/vllm/attention/layer.py
+++ b/vllm/attention/layer.py
@@ -6,7 +6,7 @@ import torch.nn as nn
 
 import vllm.envs as envs
 from vllm.attention import AttentionMetadata, AttentionType
-from vllm.attention.selector import get_attn_backend
+from vllm.attention.selector import backend_name_to_enum, get_attn_backend
 from vllm.config import CacheConfig
 from vllm.forward_context import ForwardContext, get_forward_context
 from vllm.model_executor.layers.quantization.base_config import (
@@ -98,6 +98,7 @@ class Attention(nn.Module):
         self.impl = impl_cls(num_heads, head_size, scale, num_kv_heads,
                              alibi_slopes, sliding_window, kv_cache_dtype,
                              blocksparse_params, logits_soft_cap)
+        self.backend = backend_name_to_enum(attn_backend.get_name())
 
         # For cuda-alike (CUDA and ROCM) and cpu platforms, we control how
         # torch.compile works by registering the attention as one giant
diff --git a/vllm/model_executor/models/mllama.py b/vllm/model_executor/models/mllama.py
index 41f62b37f3bd..9e6634a9a757 100644
--- a/vllm/model_executor/models/mllama.py
+++ b/vllm/model_executor/models/mllama.py
@@ -32,9 +32,8 @@ from transformers.models.mllama.processing_mllama import (
 
 import vllm.distributed.parallel_state as ps
 from vllm.attention import Attention, AttentionMetadata, AttentionType
-from vllm.attention.backends.flash_attn import FlashAttentionMetadata
-from vllm.attention.backends.xformers import XFormersMetadata
 from vllm.attention.ops.paged_attn import PagedAttention
+from vllm.attention.selector import _Backend
 from vllm.config import VllmConfig
 from vllm.distributed import get_tensor_model_parallel_world_size
 from vllm.inputs import (INPUT_REGISTRY, DummyData, EncoderDecoderInputs,
@@ -828,7 +827,8 @@ class MllamaTextCrossAttention(nn.Module):
     ) -> torch.Tensor:
         # Skip writing kv-cache for the initial profiling run.
         if len(kv_cache.shape) > 1:
-            if isinstance(attn_metadata, FlashAttentionMetadata):
+            if self.attn.backend in (_Backend.FLASH_ATTN,
+                                     _Backend.FLASH_ATTN_VLLM_V1):
                 cached_k = torch.cat([k[s:e] for s, e in kv_range_for_decode])
                 cached_v = torch.cat([v[s:e] for s, e in kv_range_for_decode])
                 torch.ops._C_cache_ops.reshape_and_cache_flash(
@@ -842,7 +842,7 @@ class MllamaTextCrossAttention(nn.Module):
                     1.0,
                     1.0,
                 )
-            elif isinstance(attn_metadata, XFormersMetadata):
+            elif self.attn.backend in (_Backend.XFORMERS, _Backend.TORCH_SDPA):
                 key_cache, value_cache = PagedAttention.split_kv_cache(
                     kv_cache, self.num_local_key_value_heads, self.head_dim)
                 cached_k = torch.cat([k[s:e] for s, e in kv_range_for_decode])
@@ -852,9 +852,9 @@ class MllamaTextCrossAttention(nn.Module):
                     attn_metadata.cross_slot_mapping, "auto", 1.0, 1.0)
             else:
                 raise ValueError(
-                    f"Unsupported AttentionMetadata {type(attn_metadata)} "
-                    f"class found. Expected the AttentionMetadata to "
-                    f"be either XFormersMetadata or FlashAttentionMetadata.")
+                    f"Unsupported Attention backend {self.attn.backend} "
+                    "enum found. Expected the Attention backend to be "
+                    "FLASH_ATTN, FLASH_ATTN_VLLM_V1, XFORMERS or TORCH_SDPA.")
 
         # We have to call torch.sdpa for prefill when using a
         # custom cross-attention mask. Because the mask is not a
diff --git a/vllm/platforms/openvino.py b/vllm/platforms/openvino.py
index 91e615481ff8..ea5ec7b40b95 100644
--- a/vllm/platforms/openvino.py
+++ b/vllm/platforms/openvino.py
@@ -1,7 +1,5 @@
 from typing import TYPE_CHECKING
 
-import openvino as ov
-import openvino.properties.hint as hints
 import torch
 
 import vllm.envs as envs
@@ -16,6 +14,12 @@ else:
 
 logger = init_logger(__name__)
 
+try:
+    import openvino as ov
+    import openvino.properties.hint as hints
+except ImportError as e:
+    logger.warning("Failed to import OpenVINO with %r", e)
+
 
 class OpenVinoPlatform(Platform):
     _enum = PlatformEnum.OPENVINO
diff --git a/vllm/v1/attention/backends/flash_attn.py b/vllm/v1/attention/backends/flash_attn.py
index d98bb5a716e9..5f8535eaa303 100644
--- a/vllm/v1/attention/backends/flash_attn.py
+++ b/vllm/v1/attention/backends/flash_attn.py
@@ -19,7 +19,7 @@ class FlashAttentionBackend(AttentionBackend):
 
     @staticmethod
     def get_name() -> str:
-        return "flash-attn-vllm-v1"
+        return "FLASH_ATTN_VLLM_V1"
 
     @staticmethod
     def get_impl_cls() -> Type["FlashAttentionImpl"]: