[Bugfix] Avoid import AttentionMetadata explicitly in Mllama (#10593)

Signed-off-by: Isotr0py <2037008807@qq.com>

commit 04668ebe7a (parent 651f6c31ac)
@@ -87,6 +87,11 @@ class BlocksparseParams:
 
 class BlocksparseFlashAttentionBackend(AttentionBackend):
 
+    @staticmethod
+    def get_name() -> str:
+        # For attention layer compatibility
+        return "FLASH_ATTN"
+
     @staticmethod
     def get_impl_cls() -> Type["BlocksparseFlashAttentionImpl"]:
         return BlocksparseFlashAttentionImpl
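The blocksparse backend now reports "FLASH_ATTN" as its name, so the enum-based dispatch introduced by this commit treats it the same as the regular flash-attention backend. A minimal sketch of the name-to-enum mapping this relies on, with an illustrative _Backend subset and helper rather than the real vllm.attention.selector definitions:

import enum
from typing import Optional


class _Backend(enum.Enum):
    # Illustrative subset of the selector's backend enum.
    FLASH_ATTN = enum.auto()
    FLASH_ATTN_VLLM_V1 = enum.auto()
    XFORMERS = enum.auto()
    TORCH_SDPA = enum.auto()


def backend_name_to_enum(backend_name: str) -> Optional[_Backend]:
    # Map a backend's get_name() string onto the enum, or None if unknown.
    if backend_name in _Backend.__members__:
        return _Backend[backend_name]
    return None


# Because BlocksparseFlashAttentionBackend.get_name() returns "FLASH_ATTN",
# it resolves to the same enum member as the regular flash-attention backend.
assert backend_name_to_enum("FLASH_ATTN") is _Backend.FLASH_ATTN
assert backend_name_to_enum("FLASH_ATTN_VLLM_V1") is _Backend.FLASH_ATTN_VLLM_V1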
@@ -6,7 +6,7 @@ import torch.nn as nn
 
 import vllm.envs as envs
 from vllm.attention import AttentionMetadata, AttentionType
-from vllm.attention.selector import get_attn_backend
+from vllm.attention.selector import backend_name_to_enum, get_attn_backend
 from vllm.config import CacheConfig
 from vllm.forward_context import ForwardContext, get_forward_context
 from vllm.model_executor.layers.quantization.base_config import (
@@ -98,6 +98,7 @@ class Attention(nn.Module):
         self.impl = impl_cls(num_heads, head_size, scale, num_kv_heads,
                              alibi_slopes, sliding_window, kv_cache_dtype,
                              blocksparse_params, logits_soft_cap)
+        self.backend = backend_name_to_enum(attn_backend.get_name())
 
         # For cuda-alike (CUDA and ROCM) and cpu platforms, we control how
         # torch.compile works by registering the attention as one giant
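With backend_name_to_enum imported, the Attention layer resolves the backend's string name to an enum once at construction time and keeps it as self.backend, so model code can branch on it later without importing backend internals. A small stand-in sketch of that construction-time resolution, assuming simplified classes in place of the real vLLM ones:

import enum


class _Backend(enum.Enum):
    # Illustrative stand-in for the selector enum.
    FLASH_ATTN = enum.auto()
    XFORMERS = enum.auto()


class FakeFlashAttnBackend:
    # Stand-in for the class returned by get_attn_backend().
    @staticmethod
    def get_name() -> str:
        return "FLASH_ATTN"


class Attention:
    # Stand-in for vllm.attention.layer.Attention.__init__: resolve the
    # backend enum once (the real code calls backend_name_to_enum) and
    # keep it on the layer for later dispatch.
    def __init__(self, attn_backend) -> None:
        self.backend = _Backend[attn_backend.get_name()]


assert Attention(FakeFlashAttnBackend).backend is _Backend.FLASH_ATTN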
@@ -32,9 +32,8 @@ from transformers.models.mllama.processing_mllama import (
 
 import vllm.distributed.parallel_state as ps
 from vllm.attention import Attention, AttentionMetadata, AttentionType
-from vllm.attention.backends.flash_attn import FlashAttentionMetadata
-from vllm.attention.backends.xformers import XFormersMetadata
 from vllm.attention.ops.paged_attn import PagedAttention
+from vllm.attention.selector import _Backend
 from vllm.config import VllmConfig
 from vllm.distributed import get_tensor_model_parallel_world_size
 from vllm.inputs import (INPUT_REGISTRY, DummyData, EncoderDecoderInputs,
@@ -828,7 +827,8 @@ class MllamaTextCrossAttention(nn.Module):
     ) -> torch.Tensor:
         # Skip writing kv-cache for the initial profiling run.
         if len(kv_cache.shape) > 1:
-            if isinstance(attn_metadata, FlashAttentionMetadata):
+            if self.attn.backend in (_Backend.FLASH_ATTN,
+                                     _Backend.FLASH_ATTN_VLLM_V1):
                 cached_k = torch.cat([k[s:e] for s, e in kv_range_for_decode])
                 cached_v = torch.cat([v[s:e] for s, e in kv_range_for_decode])
                 torch.ops._C_cache_ops.reshape_and_cache_flash(
@@ -842,7 +842,7 @@ class MllamaTextCrossAttention(nn.Module):
                     1.0,
                     1.0,
                 )
-            elif isinstance(attn_metadata, XFormersMetadata):
+            elif self.attn.backend in (_Backend.XFORMERS, _Backend.TORCH_SDPA):
                 key_cache, value_cache = PagedAttention.split_kv_cache(
                     kv_cache, self.num_local_key_value_heads, self.head_dim)
                 cached_k = torch.cat([k[s:e] for s, e in kv_range_for_decode])
@@ -852,9 +852,9 @@ class MllamaTextCrossAttention(nn.Module):
                     attn_metadata.cross_slot_mapping, "auto", 1.0, 1.0)
             else:
                 raise ValueError(
-                    f"Unsupported AttentionMetadata {type(attn_metadata)} "
-                    f"class found. Expected the AttentionMetadata to "
-                    f"be either XFormersMetadata or FlashAttentionMetadata.")
+                    f"Unsupported Attention backend {self.attn.backend} "
+                    "enum found. Expected the Attention backend to be "
+                    "FLASH_ATTN, FLASH_ATTN_VLLM_V1, XFORMERS or TORCH_SDPA.")
 
         # We have to call torch.sdpa for prefill when using a
         # custom cross-attention mask. Because the mask is not a
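In mllama, the cross-attention kv-cache write now dispatches on self.attn.backend instead of isinstance checks against FlashAttentionMetadata and XFormersMetadata, so the model no longer needs to import those backend-specific classes. A self-contained sketch of the dispatch shape, with plain string labels standing in for the real cache-write paths:

import enum


class _Backend(enum.Enum):
    FLASH_ATTN = enum.auto()
    FLASH_ATTN_VLLM_V1 = enum.auto()
    XFORMERS = enum.auto()
    TORCH_SDPA = enum.auto()


def pick_kv_cache_write_path(backend: _Backend) -> str:
    # Dispatch purely on the enum value; the model code never inspects the
    # concrete AttentionMetadata class to choose a cache-write path.
    if backend in (_Backend.FLASH_ATTN, _Backend.FLASH_ATTN_VLLM_V1):
        return "flash-style kv-cache write"
    elif backend in (_Backend.XFORMERS, _Backend.TORCH_SDPA):
        return "paged-attention kv-cache write"
    raise ValueError(
        f"Unsupported Attention backend {backend} enum found. Expected "
        "FLASH_ATTN, FLASH_ATTN_VLLM_V1, XFORMERS or TORCH_SDPA.")


assert pick_kv_cache_write_path(_Backend.XFORMERS) == "paged-attention kv-cache write"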
@@ -1,7 +1,5 @@
 from typing import TYPE_CHECKING
 
-import openvino as ov
-import openvino.properties.hint as hints
 import torch
 
 import vllm.envs as envs
@@ -16,6 +14,12 @@ else:
 
 logger = init_logger(__name__)
 
+try:
+    import openvino as ov
+    import openvino.properties.hint as hints
+except ImportError as e:
+    logger.warning("Failed to import OpenVINO with %r", e)
+
 
 class OpenVinoPlatform(Platform):
     _enum = PlatformEnum.OPENVINO
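The OpenVINO imports are moved under a try/except ImportError guard so the platform module can still be imported on machines without the openvino package; the failure is only logged. A generic sketch of the same optional-import pattern, using the standard logging module in place of vLLM's init_logger:

import logging

logger = logging.getLogger(__name__)

try:
    # Optional dependency: only needed when this platform is actually used.
    import openvino as ov
    import openvino.properties.hint as hints
except ImportError as e:
    ov = None
    hints = None
    logger.warning("Failed to import OpenVINO with %r", e)


def openvino_available() -> bool:
    # Illustrative helper (not part of vLLM): callers can check availability
    # before touching any OpenVINO APIs.
    return ov is not None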
@@ -19,7 +19,7 @@ class FlashAttentionBackend(AttentionBackend):
 
     @staticmethod
     def get_name() -> str:
-        return "flash-attn-vllm-v1"
+        return "FLASH_ATTN_VLLM_V1"
 
     @staticmethod
     def get_impl_cls() -> Type["FlashAttentionImpl"]: