mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-05-30 19:47:11 +08:00
[Attention] MLA move o_proj q_proj into cuda-graph region (#17484)
Signed-off-by: Lucas Wilkinson <lwilkinson@neuralmagic.com>
This commit is contained in:
parent
afb12e4294
commit
afcb3f8863
@ -281,8 +281,7 @@ class CPUMLAImpl(MLACommonImpl[CPUMLAMetadata]):
|
|||||||
# remove padding
|
# remove padding
|
||||||
output = output.view(-1, self.num_heads,
|
output = output.view(-1, self.num_heads,
|
||||||
q.shape[-1])[..., :v.shape[-1]]
|
q.shape[-1])[..., :v.shape[-1]]
|
||||||
output = output.reshape(-1, self.num_heads * v.shape[-1])
|
return output.reshape(-1, self.num_heads * v.shape[-1])
|
||||||
return self.o_proj(output)[0]
|
|
||||||
|
|
||||||
def _forward_decode(
|
def _forward_decode(
|
||||||
self,
|
self,
|
||||||
@ -303,4 +302,4 @@ class CPUMLAImpl(MLACommonImpl[CPUMLAMetadata]):
|
|||||||
ops.mla_decode_kvcache_cpu(o, q, kv_c_and_k_pe_cache, self.scale,
|
ops.mla_decode_kvcache_cpu(o, q, kv_c_and_k_pe_cache, self.scale,
|
||||||
decode_meta.block_tables,
|
decode_meta.block_tables,
|
||||||
decode_meta.seq_lens_tensor)
|
decode_meta.seq_lens_tensor)
|
||||||
return self._v_up_proj_and_o_proj(o)
|
return self._v_up_proj(o)
|
||||||
|
|||||||
@ -239,4 +239,4 @@ class FlashMLAImpl(MLACommonImpl[FlashMLAMetadata]):
|
|||||||
causal=True,
|
causal=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
return self._v_up_proj_and_o_proj(o)
|
return self._v_up_proj(o)
|
||||||
|
|||||||
@ -207,7 +207,7 @@ from vllm.attention.backends.utils import (PAD_SLOT_ID, compute_slot_mapping,
|
|||||||
from vllm.attention.ops.merge_attn_states import merge_attn_states
|
from vllm.attention.ops.merge_attn_states import merge_attn_states
|
||||||
from vllm.attention.utils.fa_utils import get_flash_attn_version
|
from vllm.attention.utils.fa_utils import get_flash_attn_version
|
||||||
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
|
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
|
||||||
LinearBase, RowParallelLinear,
|
LinearBase,
|
||||||
UnquantizedLinearMethod)
|
UnquantizedLinearMethod)
|
||||||
from vllm.model_executor.layers.rotary_embedding import (
|
from vllm.model_executor.layers.rotary_embedding import (
|
||||||
DeepseekScalingRotaryEmbedding, RotaryEmbedding)
|
DeepseekScalingRotaryEmbedding, RotaryEmbedding)
|
||||||
@ -1032,12 +1032,7 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]):
|
|||||||
qk_head_dim: int,
|
qk_head_dim: int,
|
||||||
v_head_dim: int,
|
v_head_dim: int,
|
||||||
rotary_emb: RotaryEmbedding,
|
rotary_emb: RotaryEmbedding,
|
||||||
# q_proj should be q_b_proj if q_lora_rank is not None, but from an
|
|
||||||
# attention backend perspective we rely on the layer to pass in the
|
|
||||||
# correct matrix
|
|
||||||
q_proj: ColumnParallelLinear,
|
|
||||||
kv_b_proj: ColumnParallelLinear,
|
kv_b_proj: ColumnParallelLinear,
|
||||||
o_proj: RowParallelLinear,
|
|
||||||
) -> None:
|
) -> None:
|
||||||
self.num_heads = num_heads
|
self.num_heads = num_heads
|
||||||
self.head_size = head_size
|
self.head_size = head_size
|
||||||
@ -1055,9 +1050,7 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]):
|
|||||||
self.rotary_emb = rotary_emb
|
self.rotary_emb = rotary_emb
|
||||||
self.use_yarn_rope = isinstance(rotary_emb,
|
self.use_yarn_rope = isinstance(rotary_emb,
|
||||||
DeepseekScalingRotaryEmbedding)
|
DeepseekScalingRotaryEmbedding)
|
||||||
self.q_proj = q_proj
|
|
||||||
self.kv_b_proj = kv_b_proj
|
self.kv_b_proj = kv_b_proj
|
||||||
self.o_proj = o_proj
|
|
||||||
|
|
||||||
self.triton_fa_func = triton_attention
|
self.triton_fa_func = triton_attention
|
||||||
# Handle the differences between the flash_attn_varlen from flash_attn
|
# Handle the differences between the flash_attn_varlen from flash_attn
|
||||||
@ -1141,27 +1134,13 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]):
|
|||||||
return attn_out, rest[0]
|
return attn_out, rest[0]
|
||||||
return attn_out
|
return attn_out
|
||||||
|
|
||||||
def _v_up_proj_and_o_proj(self, x):
|
def _v_up_proj(self, x):
|
||||||
# Convert from (B, N, L) to (N, B, L)
|
# Convert from (B, N, L) to (N, B, L)
|
||||||
x = x.view(-1, self.num_heads, self.kv_lora_rank).transpose(0, 1)
|
x = x.view(-1, self.num_heads, self.kv_lora_rank).transpose(0, 1)
|
||||||
# Multiply (N, B, L) x (N, L, V) -> (N, B, V)
|
# Multiply (N, B, L) x (N, L, V) -> (N, B, V)
|
||||||
x = torch.bmm(x, self.W_UV)
|
x = torch.bmm(x, self.W_UV)
|
||||||
# Convert from (N, B, V) to (B, N * V)
|
# Convert from (N, B, V) to (B, N * V)
|
||||||
x = x.transpose(0, 1).reshape(-1, self.num_heads * self.v_head_dim)
|
return x.transpose(0, 1).reshape(-1, self.num_heads * self.v_head_dim)
|
||||||
return self.o_proj(x)[0]
|
|
||||||
|
|
||||||
# Return `ql_nope`, `q_pe`
|
|
||||||
def _q_proj_and_k_up_proj(self, x):
|
|
||||||
q_nope, q_pe = self.q_proj(x)[0]\
|
|
||||||
.view(-1, self.num_heads, self.qk_head_dim)\
|
|
||||||
.split([self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)
|
|
||||||
|
|
||||||
# Convert from (B, N, P) to (N, B, P)
|
|
||||||
q_nope = q_nope.transpose(0, 1)
|
|
||||||
# Multiply (N, B, P) x (N, P, L) -> (N, B, L)
|
|
||||||
ql_nope = torch.bmm(q_nope, self.W_UK_T)
|
|
||||||
# Convert from (N, B, L) to (B, N, L)
|
|
||||||
return ql_nope.transpose(0, 1), q_pe
|
|
||||||
|
|
||||||
def process_weights_after_loading(self, act_dtype: torch.dtype):
|
def process_weights_after_loading(self, act_dtype: torch.dtype):
|
||||||
|
|
||||||
@ -1345,7 +1324,7 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]):
|
|||||||
suffix_lse=suffix_lse,
|
suffix_lse=suffix_lse,
|
||||||
)
|
)
|
||||||
|
|
||||||
return self.o_proj(output.flatten(start_dim=-2))[0]
|
return output.flatten(start_dim=-2)
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def _forward_decode(
|
def _forward_decode(
|
||||||
@ -1360,7 +1339,7 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]):
|
|||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
layer: AttentionLayer,
|
layer: AttentionLayer,
|
||||||
hidden_states_or_q_c: torch.Tensor, # query in unified attn
|
q: torch.Tensor, # query in unified attn
|
||||||
k_c_normed: torch.Tensor, # key in unified attn
|
k_c_normed: torch.Tensor, # key in unified attn
|
||||||
k_pe: torch.Tensor, # value in unified attn
|
k_pe: torch.Tensor, # value in unified attn
|
||||||
kv_cache: torch.Tensor,
|
kv_cache: torch.Tensor,
|
||||||
@ -1391,27 +1370,32 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]):
|
|||||||
assert hasattr(attn_metadata, "input_positions")
|
assert hasattr(attn_metadata, "input_positions")
|
||||||
|
|
||||||
num_prefill_tokens: int = attn_metadata.num_prefill_tokens
|
num_prefill_tokens: int = attn_metadata.num_prefill_tokens
|
||||||
|
q = q.view(-1, self.num_heads, self.qk_head_dim)
|
||||||
|
|
||||||
decode_hs_or_q_c = hidden_states_or_q_c[num_prefill_tokens:]
|
decode_q = q[num_prefill_tokens:]
|
||||||
decode_k_pe = k_pe[num_prefill_tokens:]
|
decode_k_pe = k_pe[num_prefill_tokens:]
|
||||||
decode_input_positions = \
|
decode_input_positions = \
|
||||||
attn_metadata.input_positions[num_prefill_tokens:]
|
attn_metadata.input_positions[num_prefill_tokens:]
|
||||||
|
|
||||||
prefill_hs_or_q_c = hidden_states_or_q_c[:num_prefill_tokens]
|
prefill_q = q[:num_prefill_tokens]
|
||||||
prefill_k_pe = k_pe[:num_prefill_tokens]
|
prefill_k_pe = k_pe[:num_prefill_tokens]
|
||||||
prefill_input_positions = \
|
prefill_input_positions = \
|
||||||
attn_metadata.input_positions[:num_prefill_tokens]
|
attn_metadata.input_positions[:num_prefill_tokens]
|
||||||
prefill_k_c_normed = k_c_normed[:num_prefill_tokens]
|
prefill_k_c_normed = k_c_normed[:num_prefill_tokens]
|
||||||
|
|
||||||
if has_decode:
|
if has_decode:
|
||||||
decode_ql_nope, decode_q_pe = \
|
decode_q_nope, decode_q_pe = decode_q.split(
|
||||||
self._q_proj_and_k_up_proj(decode_hs_or_q_c)
|
[self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)
|
||||||
|
# Convert from (B, N, P) to (N, B, P)
|
||||||
|
decode_q_nope = decode_q_nope.transpose(0, 1)
|
||||||
|
# Multiply (N, B, P) x (N, P, L) -> (N, B, L)
|
||||||
|
decode_ql_nope = torch.bmm(decode_q_nope, self.W_UK_T)
|
||||||
|
# Convert from (N, B, L) to (B, N, L)
|
||||||
|
decode_ql_nope = decode_ql_nope.transpose(0, 1)
|
||||||
decode_q_pe[...], decode_k_pe[...] = self.rotary_emb(
|
decode_q_pe[...], decode_k_pe[...] = self.rotary_emb(
|
||||||
decode_input_positions, decode_q_pe, decode_k_pe)
|
decode_input_positions, decode_q_pe, decode_k_pe)
|
||||||
|
|
||||||
if has_prefill:
|
if has_prefill:
|
||||||
prefill_q = self.q_proj(prefill_hs_or_q_c)[0]\
|
|
||||||
.view(-1, self.num_heads, self.qk_head_dim)
|
|
||||||
prefill_q_pe = prefill_q[..., self.qk_nope_head_dim:]
|
prefill_q_pe = prefill_q[..., self.qk_nope_head_dim:]
|
||||||
prefill_q_pe[...], prefill_k_pe[...] = self.rotary_emb(
|
prefill_q_pe[...], prefill_k_pe[...] = self.rotary_emb(
|
||||||
prefill_input_positions, prefill_q_pe, prefill_k_pe)
|
prefill_input_positions, prefill_q_pe, prefill_k_pe)
|
||||||
@ -1429,9 +1413,9 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]):
|
|||||||
|
|
||||||
output = torch.empty(attn_metadata.num_prefill_tokens +
|
output = torch.empty(attn_metadata.num_prefill_tokens +
|
||||||
attn_metadata.num_decode_tokens,
|
attn_metadata.num_decode_tokens,
|
||||||
self.o_proj.output_size,
|
self.v_head_dim * self.num_heads,
|
||||||
device=hidden_states_or_q_c.device,
|
device=q.device,
|
||||||
dtype=hidden_states_or_q_c.dtype)
|
dtype=q.dtype)
|
||||||
if has_prefill:
|
if has_prefill:
|
||||||
output[:num_prefill_tokens] = self._forward_prefill(
|
output[:num_prefill_tokens] = self._forward_prefill(
|
||||||
prefill_q, prefill_k_c_normed, prefill_k_pe, kv_cache,
|
prefill_q, prefill_k_c_normed, prefill_k_pe, kv_cache,
|
||||||
|
|||||||
@ -409,4 +409,4 @@ class AiterMLAImpl(MLACommonImpl[AiterMLAMetadata]):
|
|||||||
attn_metadata.paged_kv_indices,
|
attn_metadata.paged_kv_indices,
|
||||||
attn_metadata.paged_kv_last_page_lens)
|
attn_metadata.paged_kv_last_page_lens)
|
||||||
|
|
||||||
return self._v_up_proj_and_o_proj(o)
|
return self._v_up_proj(o)
|
||||||
|
|||||||
@ -110,4 +110,4 @@ class TritonMLAImpl(MLACommonImpl[MLACommonMetadata]):
|
|||||||
decode_meta.seq_lens_tensor, attn_logits,
|
decode_meta.seq_lens_tensor, attn_logits,
|
||||||
num_kv_splits, self.scale, PAGE_SIZE)
|
num_kv_splits, self.scale, PAGE_SIZE)
|
||||||
|
|
||||||
return self._v_up_proj_and_o_proj(o)
|
return self._v_up_proj(o)
|
||||||
|
|||||||
@ -454,9 +454,7 @@ class DeepseekV2MLAAttention(nn.Module):
|
|||||||
qk_head_dim=self.qk_head_dim,
|
qk_head_dim=self.qk_head_dim,
|
||||||
v_head_dim=self.v_head_dim,
|
v_head_dim=self.v_head_dim,
|
||||||
rotary_emb=self.rotary_emb,
|
rotary_emb=self.rotary_emb,
|
||||||
q_proj=self.q_proj if self.q_lora_rank is None else self.q_b_proj,
|
|
||||||
kv_b_proj=self.kv_b_proj,
|
kv_b_proj=self.kv_b_proj,
|
||||||
o_proj=self.o_proj,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
self.prefix = prefix
|
self.prefix = prefix
|
||||||
@ -468,17 +466,22 @@ class DeepseekV2MLAAttention(nn.Module):
|
|||||||
hidden_states: torch.Tensor,
|
hidden_states: torch.Tensor,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
if self.q_lora_rank is not None:
|
if self.q_lora_rank is not None:
|
||||||
ckq = self.q_a_proj(hidden_states)[0]
|
q_c = self.q_a_proj(hidden_states)[0]
|
||||||
hidden_states_or_q_c = self.q_a_layernorm(ckq)
|
q_c = self.q_a_layernorm(q_c)
|
||||||
|
q = self.q_b_proj(q_c)[0]
|
||||||
else:
|
else:
|
||||||
hidden_states_or_q_c = hidden_states
|
q = self.q_proj(hidden_states)[0]
|
||||||
kv_c, k_pe = self.kv_a_proj_with_mqa(hidden_states)[0].split(
|
kv_c, k_pe = self.kv_a_proj_with_mqa(hidden_states)[0].split(
|
||||||
[self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
|
[self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
|
||||||
kv_c_normed = self.kv_a_layernorm(kv_c.contiguous())
|
kv_c_normed = self.kv_a_layernorm(kv_c.contiguous())
|
||||||
return self.mla_attn(hidden_states_or_q_c,
|
|
||||||
kv_c_normed,
|
attn_out = self.mla_attn(
|
||||||
k_pe,
|
q,
|
||||||
output_shape=hidden_states.shape)
|
kv_c_normed,
|
||||||
|
k_pe,
|
||||||
|
output_shape=(hidden_states.shape[0],
|
||||||
|
self.num_local_heads * self.v_head_dim))
|
||||||
|
return self.o_proj(attn_out)[0]
|
||||||
|
|
||||||
|
|
||||||
class DeepseekV2DecoderLayer(nn.Module):
|
class DeepseekV2DecoderLayer(nn.Module):
|
||||||
|
|||||||
@ -200,7 +200,7 @@ from vllm.attention.ops.merge_attn_states import merge_attn_states
|
|||||||
from vllm.attention.utils.fa_utils import get_flash_attn_version
|
from vllm.attention.utils.fa_utils import get_flash_attn_version
|
||||||
from vllm.logger import init_logger
|
from vllm.logger import init_logger
|
||||||
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
|
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
|
||||||
LinearBase, RowParallelLinear,
|
LinearBase,
|
||||||
UnquantizedLinearMethod)
|
UnquantizedLinearMethod)
|
||||||
from vllm.model_executor.layers.rotary_embedding import RotaryEmbedding
|
from vllm.model_executor.layers.rotary_embedding import RotaryEmbedding
|
||||||
from vllm.platforms import current_platform
|
from vllm.platforms import current_platform
|
||||||
@ -597,12 +597,7 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
|
|||||||
qk_head_dim: int,
|
qk_head_dim: int,
|
||||||
v_head_dim: int,
|
v_head_dim: int,
|
||||||
rotary_emb: RotaryEmbedding,
|
rotary_emb: RotaryEmbedding,
|
||||||
# q_proj should be q_b_proj if q_lora_rank is not None, but from an
|
|
||||||
# attention backend perspective we rely on the layer to pass in the
|
|
||||||
# correct matrix
|
|
||||||
q_proj: ColumnParallelLinear,
|
|
||||||
kv_b_proj: ColumnParallelLinear,
|
kv_b_proj: ColumnParallelLinear,
|
||||||
o_proj: RowParallelLinear,
|
|
||||||
) -> None:
|
) -> None:
|
||||||
self.num_heads = num_heads
|
self.num_heads = num_heads
|
||||||
self.head_size = head_size
|
self.head_size = head_size
|
||||||
@ -625,9 +620,7 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
|
|||||||
if current_platform.is_cuda():
|
if current_platform.is_cuda():
|
||||||
self.rotary_emb = rotary_emb.forward_cuda
|
self.rotary_emb = rotary_emb.forward_cuda
|
||||||
|
|
||||||
self.q_proj = q_proj
|
|
||||||
self.kv_b_proj = kv_b_proj
|
self.kv_b_proj = kv_b_proj
|
||||||
self.o_proj = o_proj
|
|
||||||
self.vllm_flash_attn_version = get_flash_attn_version()
|
self.vllm_flash_attn_version = get_flash_attn_version()
|
||||||
|
|
||||||
# Handle the differences between the flash_attn_varlen from flash_attn
|
# Handle the differences between the flash_attn_varlen from flash_attn
|
||||||
@ -684,27 +677,13 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
|
|||||||
return attn_out, lse
|
return attn_out, lse
|
||||||
return attn_out
|
return attn_out
|
||||||
|
|
||||||
def _v_up_proj_and_o_proj(self, x):
|
def _v_up_proj(self, x):
|
||||||
# Convert from (B, N, L) to (N, B, L)
|
# Convert from (B, N, L) to (N, B, L)
|
||||||
x = x.view(-1, self.num_heads, self.kv_lora_rank).transpose(0, 1)
|
x = x.view(-1, self.num_heads, self.kv_lora_rank).transpose(0, 1)
|
||||||
# Multiply (N, B, L) x (N, L, V) -> (N, B, V)
|
# Multiply (N, B, L) x (N, L, V) -> (N, B, V)
|
||||||
x = torch.bmm(x, self.W_UV)
|
x = torch.bmm(x, self.W_UV)
|
||||||
# Convert from (N, B, V) to (B, N * V)
|
# Convert from (N, B, V) to (B, N * V)
|
||||||
x = x.transpose(0, 1).reshape(-1, self.num_heads * self.v_head_dim)
|
return x.transpose(0, 1).reshape(-1, self.num_heads * self.v_head_dim)
|
||||||
return self.o_proj(x)[0]
|
|
||||||
|
|
||||||
# Return `ql_nope`, `q_pe`
|
|
||||||
def _q_proj_and_k_up_proj(self, x):
|
|
||||||
q_nope, q_pe = self.q_proj(x)[0]\
|
|
||||||
.view(-1, self.num_heads, self.qk_head_dim)\
|
|
||||||
.split([self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)
|
|
||||||
|
|
||||||
# Convert from (B, N, P) to (N, B, P)
|
|
||||||
q_nope = q_nope.transpose(0, 1)
|
|
||||||
# Multiply (N, B, P) x (N, P, L) -> (N, B, L)
|
|
||||||
ql_nope = torch.bmm(q_nope, self.W_UK_T)
|
|
||||||
# Convert from (N, B, L) to (B, N, L)
|
|
||||||
return ql_nope.transpose(0, 1), q_pe
|
|
||||||
|
|
||||||
def process_weights_after_loading(self, act_dtype: torch.dtype):
|
def process_weights_after_loading(self, act_dtype: torch.dtype):
|
||||||
|
|
||||||
@ -874,7 +853,7 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
|
|||||||
suffix_lse=suffix_lse,
|
suffix_lse=suffix_lse,
|
||||||
)
|
)
|
||||||
|
|
||||||
return self.o_proj(output.flatten(start_dim=-2))[0]
|
return output.flatten(start_dim=-2)
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def _forward_decode(
|
def _forward_decode(
|
||||||
@ -889,7 +868,7 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
|
|||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
layer: AttentionLayer,
|
layer: AttentionLayer,
|
||||||
hidden_states_or_q_c: torch.Tensor, # query in unified attn
|
q: torch.Tensor,
|
||||||
k_c_normed: torch.Tensor, # key in unified attn
|
k_c_normed: torch.Tensor, # key in unified attn
|
||||||
k_pe: torch.Tensor, # value in unified attn
|
k_pe: torch.Tensor, # value in unified attn
|
||||||
kv_cache: torch.Tensor,
|
kv_cache: torch.Tensor,
|
||||||
@ -908,7 +887,7 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
|
|||||||
# Inputs and outputs may be padded for CUDA graphs
|
# Inputs and outputs may be padded for CUDA graphs
|
||||||
output_padded = output
|
output_padded = output
|
||||||
output = output[:num_actual_toks, ...]
|
output = output[:num_actual_toks, ...]
|
||||||
hidden_states_or_q_c = hidden_states_or_q_c[:num_actual_toks, ...]
|
q = q[:num_actual_toks, ...]
|
||||||
k_c_normed = k_c_normed[:num_actual_toks, ...]
|
k_c_normed = k_c_normed[:num_actual_toks, ...]
|
||||||
k_pe = k_pe[:num_actual_toks, ...]
|
k_pe = k_pe[:num_actual_toks, ...]
|
||||||
|
|
||||||
@ -923,24 +902,29 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
|
|||||||
has_prefill = attn_metadata.num_prefills > 0
|
has_prefill = attn_metadata.num_prefills > 0
|
||||||
num_decode_tokens = attn_metadata.num_decode_tokens
|
num_decode_tokens = attn_metadata.num_decode_tokens
|
||||||
|
|
||||||
decode_hs_or_q_c = hidden_states_or_q_c[:num_decode_tokens]
|
q = q.view(-1, self.num_heads, self.qk_head_dim)
|
||||||
|
decode_q = q[:num_decode_tokens]
|
||||||
decode_k_pe = k_pe[:num_decode_tokens]
|
decode_k_pe = k_pe[:num_decode_tokens]
|
||||||
|
|
||||||
prefill_hs_or_q_c = hidden_states_or_q_c[num_decode_tokens:]
|
prefill_q = q[num_decode_tokens:]
|
||||||
prefill_k_pe = k_pe[num_decode_tokens:]
|
prefill_k_pe = k_pe[num_decode_tokens:]
|
||||||
prefill_k_c_normed = k_c_normed[num_decode_tokens:]
|
prefill_k_c_normed = k_c_normed[num_decode_tokens:]
|
||||||
|
|
||||||
if has_decode:
|
if has_decode:
|
||||||
assert attn_metadata.decode is not None
|
assert attn_metadata.decode is not None
|
||||||
decode_ql_nope, decode_q_pe = \
|
decode_q_nope, decode_q_pe = decode_q.split(
|
||||||
self._q_proj_and_k_up_proj(decode_hs_or_q_c)
|
[self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)
|
||||||
|
# Convert from (B, N, P) to (N, B, P)
|
||||||
|
decode_q_nope = decode_q_nope.transpose(0, 1)
|
||||||
|
# Multiply (N, B, P) x (N, P, L) -> (N, B, L)
|
||||||
|
decode_ql_nope = torch.bmm(decode_q_nope, self.W_UK_T)
|
||||||
|
# Convert from (N, B, L) to (B, N, L)
|
||||||
|
decode_ql_nope = decode_ql_nope.transpose(0, 1)
|
||||||
decode_q_pe[...], decode_k_pe[...] = self.rotary_emb(
|
decode_q_pe[...], decode_k_pe[...] = self.rotary_emb(
|
||||||
attn_metadata.decode.input_positions, decode_q_pe, decode_k_pe)
|
attn_metadata.decode.input_positions, decode_q_pe, decode_k_pe)
|
||||||
|
|
||||||
if has_prefill:
|
if has_prefill:
|
||||||
assert attn_metadata.prefill is not None
|
assert attn_metadata.prefill is not None
|
||||||
prefill_q = self.q_proj(prefill_hs_or_q_c)[0]\
|
|
||||||
.view(-1, self.num_heads, self.qk_head_dim)
|
|
||||||
prefill_q_pe = prefill_q[..., self.qk_nope_head_dim:]
|
prefill_q_pe = prefill_q[..., self.qk_nope_head_dim:]
|
||||||
|
|
||||||
prefill_q_pe[...], prefill_k_pe[...] = self.rotary_emb(
|
prefill_q_pe[...], prefill_k_pe[...] = self.rotary_emb(
|
||||||
|
|||||||
@ -146,4 +146,4 @@ class FlashMLAImpl(MLACommonImpl[FlashMLAMetadata]):
|
|||||||
causal=True,
|
causal=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
return self._v_up_proj_and_o_proj(o)
|
return self._v_up_proj(o)
|
||||||
|
|||||||
@ -115,4 +115,4 @@ class TritonMLAImpl(MLACommonImpl[MLACommonMetadata]):
|
|||||||
attn_metadata.decode.seq_lens, attn_logits,
|
attn_metadata.decode.seq_lens, attn_logits,
|
||||||
num_kv_splits, self.scale, PAGE_SIZE)
|
num_kv_splits, self.scale, PAGE_SIZE)
|
||||||
|
|
||||||
return self._v_up_proj_and_o_proj(o)
|
return self._v_up_proj(o)
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user