mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-05-15 04:53:33 +08:00
review comments
Signed-off-by: Lucas Wilkinson <lwilkinson@neuralmagic.com>
This commit is contained in:
parent
7241acbd64
commit
135c404fbb
@ -276,3 +276,19 @@ class AttentionImpl(ABC, Generic[T]):
|
|||||||
output: Optional[torch.Tensor] = None,
|
output: Optional[torch.Tensor] = None,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
|
|
||||||
|
class MLAAttentionImpl(AttentionImpl):
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
layer: AttentionLayer,
|
||||||
|
hidden_states_or_cq: torch.Tensor,
|
||||||
|
kv_c_normed: torch.Tensor,
|
||||||
|
k_pe: torch.Tensor,
|
||||||
|
kv_cache: torch.Tensor,
|
||||||
|
attn_metadata: T,
|
||||||
|
output: Optional[torch.Tensor] = None,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
raise NotImplementedError
|
||||||
|
|||||||
@ -6,8 +6,9 @@ import torch
|
|||||||
|
|
||||||
from vllm import _custom_ops as ops
|
from vllm import _custom_ops as ops
|
||||||
from vllm import envs
|
from vllm import envs
|
||||||
from vllm.attention.backends.abstract import (AttentionImpl, AttentionLayer,
|
from vllm.attention.backends.abstract import (AttentionLayer,
|
||||||
AttentionMetadata)
|
AttentionMetadata,
|
||||||
|
MLAAttentionImpl)
|
||||||
from vllm.distributed import get_tensor_model_parallel_world_size
|
from vllm.distributed import get_tensor_model_parallel_world_size
|
||||||
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
|
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
|
||||||
RowParallelLinear)
|
RowParallelLinear)
|
||||||
@ -18,11 +19,11 @@ from vllm.vllm_flash_attn import flash_attn_varlen_func
|
|||||||
@dataclass(kw_only=True)
|
@dataclass(kw_only=True)
|
||||||
class MLAMetadataCommon(AttentionMetadata):
|
class MLAMetadataCommon(AttentionMetadata):
|
||||||
# Input positions for rotrary embeddings since for MLA the rotarty
|
# Input positions for rotrary embeddings since for MLA the rotarty
|
||||||
# position encoding
|
# position embeddings are applied inside the attention backend
|
||||||
input_positions: torch.Tensor
|
input_positions: torch.Tensor
|
||||||
|
|
||||||
|
|
||||||
class MLAImplCommon(AttentionImpl):
|
class MLACommonImpl(MLAAttentionImpl):
|
||||||
"""
|
"""
|
||||||
Common class for implementing repeated parts
|
Common class for implementing repeated parts
|
||||||
|
|
||||||
|
|||||||
@ -20,7 +20,7 @@ from vllm.attention.backends.abstract import (AttentionBackend,
|
|||||||
AttentionMetadata,
|
AttentionMetadata,
|
||||||
AttentionMetadataBuilder,
|
AttentionMetadataBuilder,
|
||||||
AttentionState, AttentionType)
|
AttentionState, AttentionType)
|
||||||
from vllm.attention.backends.mla.utils import MLAImplCommon, MLAMetadataCommon
|
from vllm.attention.backends.mla.utils import MLACommonImpl, MLAMetadataCommon
|
||||||
from vllm.attention.backends.utils import (PAD_SLOT_ID, compute_slot_mapping,
|
from vllm.attention.backends.utils import (PAD_SLOT_ID, compute_slot_mapping,
|
||||||
compute_slot_mapping_start_idx,
|
compute_slot_mapping_start_idx,
|
||||||
is_block_tables_empty)
|
is_block_tables_empty)
|
||||||
@ -585,7 +585,7 @@ class TritonMLAMetadataBuilder(AttentionMetadataBuilder[TritonMLAMetadata]):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
class TritonMLAImpl(MLAImplCommon):
|
class TritonMLAImpl(MLACommonImpl):
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
|
|||||||
@ -331,7 +331,7 @@ class DeepseekV2MLAAttention(nn.Module):
|
|||||||
Main reference: DeepseekV2 paper, and FlashInfer Implementation
|
Main reference: DeepseekV2 paper, and FlashInfer Implementation
|
||||||
(https://arxiv.org/abs/2405.04434 and https://github.com/flashinfer-ai/flashinfer/pull/551).
|
(https://arxiv.org/abs/2405.04434 and https://github.com/flashinfer-ai/flashinfer/pull/551).
|
||||||
|
|
||||||
For more info see MLAImplCommon in: vllm/attention/backends/mla/utils.py
|
For more info see MLACommonImpl in: vllm/attention/backends/mla/utils.py
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user