From 135c404fbb248a6f09d46e7e7a5ef29992def2ad Mon Sep 17 00:00:00 2001 From: Lucas Wilkinson Date: Thu, 30 Jan 2025 15:11:58 +0000 Subject: [PATCH] review comments Signed-off-by: Lucas Wilkinson --- vllm/attention/backends/abstract.py | 16 ++++++++++++++++ vllm/attention/backends/mla/utils.py | 9 +++++---- vllm/attention/backends/triton_mla.py | 4 ++-- vllm/model_executor/models/deepseek_v2.py | 2 +- 4 files changed, 24 insertions(+), 7 deletions(-) diff --git a/vllm/attention/backends/abstract.py b/vllm/attention/backends/abstract.py index 7ad242b7001fa..75885947edda9 100644 --- a/vllm/attention/backends/abstract.py +++ b/vllm/attention/backends/abstract.py @@ -276,3 +276,19 @@ class AttentionImpl(ABC, Generic[T]): output: Optional[torch.Tensor] = None, ) -> torch.Tensor: raise NotImplementedError + + +class MLAAttentionImpl(AttentionImpl): + + @abstractmethod + def forward( + self, + layer: AttentionLayer, + hidden_states_or_cq: torch.Tensor, + kv_c_normed: torch.Tensor, + k_pe: torch.Tensor, + kv_cache: torch.Tensor, + attn_metadata: T, + output: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + raise NotImplementedError diff --git a/vllm/attention/backends/mla/utils.py b/vllm/attention/backends/mla/utils.py index a3b45fadffa6b..8f177c91b22eb 100644 --- a/vllm/attention/backends/mla/utils.py +++ b/vllm/attention/backends/mla/utils.py @@ -6,8 +6,9 @@ import torch from vllm import _custom_ops as ops from vllm import envs -from vllm.attention.backends.abstract import (AttentionImpl, AttentionLayer, - AttentionMetadata) +from vllm.attention.backends.abstract import (AttentionLayer, + AttentionMetadata, + MLAAttentionImpl) from vllm.distributed import get_tensor_model_parallel_world_size from vllm.model_executor.layers.linear import (ColumnParallelLinear, RowParallelLinear) @@ -18,11 +19,11 @@ from vllm.vllm_flash_attn import flash_attn_varlen_func @dataclass(kw_only=True) class MLAMetadataCommon(AttentionMetadata): # Input positions for rotrary embeddings since for MLA the rotarty - # position encoding + # position embeddings are applied inside the attention backend input_positions: torch.Tensor -class MLAImplCommon(AttentionImpl): +class MLACommonImpl(MLAAttentionImpl): """ Common class for implementing repeated parts diff --git a/vllm/attention/backends/triton_mla.py b/vllm/attention/backends/triton_mla.py index cf9151cd2b30a..3514b18df2d6d 100644 --- a/vllm/attention/backends/triton_mla.py +++ b/vllm/attention/backends/triton_mla.py @@ -20,7 +20,7 @@ from vllm.attention.backends.abstract import (AttentionBackend, AttentionMetadata, AttentionMetadataBuilder, AttentionState, AttentionType) -from vllm.attention.backends.mla.utils import MLAImplCommon, MLAMetadataCommon +from vllm.attention.backends.mla.utils import MLACommonImpl, MLAMetadataCommon from vllm.attention.backends.utils import (PAD_SLOT_ID, compute_slot_mapping, compute_slot_mapping_start_idx, is_block_tables_empty) @@ -585,7 +585,7 @@ class TritonMLAMetadataBuilder(AttentionMetadataBuilder[TritonMLAMetadata]): ) -class TritonMLAImpl(MLAImplCommon): +class TritonMLAImpl(MLACommonImpl): def __init__( self, diff --git a/vllm/model_executor/models/deepseek_v2.py b/vllm/model_executor/models/deepseek_v2.py index 28ae50e0770ea..73388cd269853 100644 --- a/vllm/model_executor/models/deepseek_v2.py +++ b/vllm/model_executor/models/deepseek_v2.py @@ -331,7 +331,7 @@ class DeepseekV2MLAAttention(nn.Module): Main reference: DeepseekV2 paper, and FlashInfer Implementation (https://arxiv.org/abs/2405.04434 and https://github.com/flashinfer-ai/flashinfer/pull/551). - For more info see MLAImplCommon in: vllm/attention/backends/mla/utils.py + For more info see MLACommonImpl in: vllm/attention/backends/mla/utils.py """ def __init__(