review comments

Signed-off-by: Lucas Wilkinson <lwilkinson@neuralmagic.com>
This commit is contained in:
Lucas Wilkinson 2025-01-30 15:11:58 +00:00
parent 7241acbd64
commit 135c404fbb
4 changed files with 24 additions and 7 deletions

View File

@ -276,3 +276,19 @@ class AttentionImpl(ABC, Generic[T]):
output: Optional[torch.Tensor] = None,
) -> torch.Tensor:
raise NotImplementedError
class MLAAttentionImpl(AttentionImpl):
    """Abstract interface for MLA (Multi-head Latent Attention) backends.

    Specializes ``AttentionImpl.forward`` for MLA: instead of full
    key/value tensors, implementations receive the compressed latent KV
    pieces separately (``kv_c_normed`` and the rotary part ``k_pe``).
    """

    @abstractmethod
    def forward(
        self,
        # The attention layer this impl services.
        layer: AttentionLayer,
        # Hidden states, or the already-projected compressed query
        # (name suggests either form is accepted — confirm with callers).
        hidden_states_or_cq: torch.Tensor,
        # Normalized compressed (latent) KV representation.
        kv_c_normed: torch.Tensor,
        # Decoupled rotary key component.
        k_pe: torch.Tensor,
        kv_cache: torch.Tensor,
        # Backend-specific metadata (generic type T from AttentionImpl).
        attn_metadata: T,
        # Optional pre-allocated output buffer.
        output: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        raise NotImplementedError

View File

@ -6,8 +6,9 @@ import torch
from vllm import _custom_ops as ops
from vllm import envs
from vllm.attention.backends.abstract import (AttentionImpl, AttentionLayer,
AttentionMetadata)
from vllm.attention.backends.abstract import (AttentionLayer,
AttentionMetadata,
MLAAttentionImpl)
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
RowParallelLinear)
@ -18,11 +19,11 @@ from vllm.vllm_flash_attn import flash_attn_varlen_func
@dataclass(kw_only=True)
class MLAMetadataCommon(AttentionMetadata):
# Input positions for rotary embeddings since for MLA the rotary
# position encoding
# position embeddings are applied inside the attention backend
input_positions: torch.Tensor
class MLAImplCommon(AttentionImpl):
class MLACommonImpl(MLAAttentionImpl):
"""
Common class for implementing repeated parts

View File

@ -20,7 +20,7 @@ from vllm.attention.backends.abstract import (AttentionBackend,
AttentionMetadata,
AttentionMetadataBuilder,
AttentionState, AttentionType)
from vllm.attention.backends.mla.utils import MLAImplCommon, MLAMetadataCommon
from vllm.attention.backends.mla.utils import MLACommonImpl, MLAMetadataCommon
from vllm.attention.backends.utils import (PAD_SLOT_ID, compute_slot_mapping,
compute_slot_mapping_start_idx,
is_block_tables_empty)
@ -585,7 +585,7 @@ class TritonMLAMetadataBuilder(AttentionMetadataBuilder[TritonMLAMetadata]):
)
class TritonMLAImpl(MLAImplCommon):
class TritonMLAImpl(MLACommonImpl):
def __init__(
self,

View File

@ -331,7 +331,7 @@ class DeepseekV2MLAAttention(nn.Module):
Main reference: DeepseekV2 paper, and FlashInfer Implementation
(https://arxiv.org/abs/2405.04434 and https://github.com/flashinfer-ai/flashinfer/pull/551).
For more info see MLAImplCommon in: vllm/attention/backends/mla/utils.py
For more info see MLACommonImpl in: vllm/attention/backends/mla/utils.py
"""
def __init__(