mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-03-20 01:21:19 +08:00
review comments
Signed-off-by: Lucas Wilkinson <lwilkinson@neuralmagic.com>
This commit is contained in:
parent
7241acbd64
commit
135c404fbb
@ -276,3 +276,19 @@ class AttentionImpl(ABC, Generic[T]):
|
||||
output: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class MLAAttentionImpl(AttentionImpl):
|
||||
|
||||
@abstractmethod
|
||||
def forward(
|
||||
self,
|
||||
layer: AttentionLayer,
|
||||
hidden_states_or_cq: torch.Tensor,
|
||||
kv_c_normed: torch.Tensor,
|
||||
k_pe: torch.Tensor,
|
||||
kv_cache: torch.Tensor,
|
||||
attn_metadata: T,
|
||||
output: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
raise NotImplementedError
|
||||
|
||||
@ -6,8 +6,9 @@ import torch
|
||||
|
||||
from vllm import _custom_ops as ops
|
||||
from vllm import envs
|
||||
from vllm.attention.backends.abstract import (AttentionImpl, AttentionLayer,
|
||||
AttentionMetadata)
|
||||
from vllm.attention.backends.abstract import (AttentionLayer,
|
||||
AttentionMetadata,
|
||||
MLAAttentionImpl)
|
||||
from vllm.distributed import get_tensor_model_parallel_world_size
|
||||
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
|
||||
RowParallelLinear)
|
||||
@ -18,11 +19,11 @@ from vllm.vllm_flash_attn import flash_attn_varlen_func
|
||||
@dataclass(kw_only=True)
|
||||
class MLAMetadataCommon(AttentionMetadata):
|
||||
# Input positions for rotary embeddings since for MLA the rotary
|
||||
# position encoding
|
||||
# position embeddings are applied inside the attention backend
|
||||
input_positions: torch.Tensor
|
||||
|
||||
|
||||
class MLAImplCommon(AttentionImpl):
|
||||
class MLACommonImpl(MLAAttentionImpl):
|
||||
"""
|
||||
Common class for implementing repeated parts
|
||||
|
||||
|
||||
@ -20,7 +20,7 @@ from vllm.attention.backends.abstract import (AttentionBackend,
|
||||
AttentionMetadata,
|
||||
AttentionMetadataBuilder,
|
||||
AttentionState, AttentionType)
|
||||
from vllm.attention.backends.mla.utils import MLAImplCommon, MLAMetadataCommon
|
||||
from vllm.attention.backends.mla.utils import MLACommonImpl, MLAMetadataCommon
|
||||
from vllm.attention.backends.utils import (PAD_SLOT_ID, compute_slot_mapping,
|
||||
compute_slot_mapping_start_idx,
|
||||
is_block_tables_empty)
|
||||
@ -585,7 +585,7 @@ class TritonMLAMetadataBuilder(AttentionMetadataBuilder[TritonMLAMetadata]):
|
||||
)
|
||||
|
||||
|
||||
class TritonMLAImpl(MLAImplCommon):
|
||||
class TritonMLAImpl(MLACommonImpl):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
|
||||
@ -331,7 +331,7 @@ class DeepseekV2MLAAttention(nn.Module):
|
||||
Main reference: DeepseekV2 paper, and FlashInfer Implementation
|
||||
(https://arxiv.org/abs/2405.04434 and https://github.com/flashinfer-ai/flashinfer/pull/551).
|
||||
|
||||
For more info see MLAImplCommon in: vllm/attention/backends/mla/utils.py
|
||||
For more info see MLACommonImpl in: vllm/attention/backends/mla/utils.py
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user