[VLM] Add MLA with pure RoPE support for deepseek-vl2 models (#12729)
Parent: 249824c3bf
Commit: 98fd089fc9

@@ -26,7 +26,8 @@ from vllm.model_executor.layers.quantization.utils.fp8_utils import (
     apply_fp8_linear_generic, current_platform_fp8_dtype, is_fp8)
 from vllm.model_executor.layers.quantization.utils.quant_utils import (
     scaled_dequantize, scaled_quantize)
-from vllm.model_executor.layers.rotary_embedding import RotaryEmbedding
+from vllm.model_executor.layers.rotary_embedding import (
+    DeepseekScalingRotaryEmbedding, RotaryEmbedding)
 
 try:
     from vllm.vllm_flash_attn import flash_attn_varlen_func
@@ -174,6 +175,8 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]):
         self.v_head_dim = v_head_dim
 
         self.rotary_emb = rotary_emb
+        self.use_yarn_rope = isinstance(rotary_emb,
+                                        DeepseekScalingRotaryEmbedding)
         self.q_proj = q_proj
         self.kv_b_proj = kv_b_proj
         self.o_proj = o_proj
@@ -420,6 +423,24 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]):
     ) -> torch.Tensor:
         raise NotImplementedError
 
+    def apply_pure_rope(
+        self,
+        input_positions: torch.Tensor,
+        q_pe: torch.Tensor,
+        k_pe: torch.Tensor,
+    ) -> tuple[torch.Tensor, torch.Tensor]:
+        seq_len = input_positions.size(0)
+        ori_q_pe_shape, ori_k_pe_shape = q_pe.shape, k_pe.shape
+
+        q_pe, k_pe = self.rotary_emb(
+            input_positions,
+            q_pe.reshape(seq_len, -1),
+            k_pe.reshape(seq_len, -1),
+        )
+        q_pe, k_pe = q_pe.view(ori_q_pe_shape), k_pe.view(ori_k_pe_shape)
+
+        return q_pe, k_pe
+
     def forward(
         self,
         layer: AttentionLayer,
@@ -444,13 +465,14 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]):
         # Restore head dim (for rotary embedding)
         k_pe = k_pe.unsqueeze(1)
         assert hasattr(attn_metadata, "input_positions")
+        rope_fn = (self.rotary_emb
+                   if self.use_yarn_rope else self.apply_pure_rope)
 
         if is_decode:
             q_nope = self._q_proj_and_k_up_proj(hidden_states_or_q_c)
             q_pe = torch.matmul(hidden_states_or_q_c, self.W_QR)\
                 .view(-1, self.num_heads, self.qk_rope_head_dim)
-            q_pe, k_pe = \
-                self.rotary_emb(attn_metadata.input_positions, q_pe, k_pe)
+            q_pe, k_pe = rope_fn(attn_metadata.input_positions, q_pe, k_pe)
         else:
             assert is_prefill
             q = self.q_proj(hidden_states_or_q_c)[0]\
@@ -458,7 +480,7 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]):
 
             # TODO(lucas): there must be a nicer way to write this line
             q[..., self.qk_nope_head_dim:], k_pe = \
-                self.rotary_emb(
+                rope_fn(
                     attn_metadata.input_positions,
                     q[..., self.qk_nope_head_dim:], k_pe)
 
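The backend-side change above boils down to: leave the yarn-scaled rotary embedding call untouched, and for plain RoPE flatten the head dimension before calling the embedding, then restore the original shapes. Below is a minimal, self-contained sketch of that reshape-and-restore wrapper; toy_flat_rope, its base frequency, and the tensor sizes are illustrative stand-ins for vLLM's RotaryEmbedding, not the library code.

import torch


def toy_flat_rope(positions: torch.Tensor, q: torch.Tensor, k: torch.Tensor,
                  rot_dim: int) -> tuple[torch.Tensor, torch.Tensor]:
    """Stand-in for a plain RoPE forward that expects flattened
    (seq_len, num_heads * rot_dim) inputs and rotates each head."""
    inv_freq = 1.0 / (10000.0**(torch.arange(0, rot_dim, 2).float() / rot_dim))
    freqs = positions.float()[:, None] * inv_freq[None, :]       # (S, rot_dim/2)
    cos, sin = freqs.cos()[:, None, :], freqs.sin()[:, None, :]  # (S, 1, rot_dim/2)

    def rotate(x: torch.Tensor) -> torch.Tensor:
        xh = x.view(x.size(0), -1, rot_dim)        # (S, H, rot_dim)
        x1, x2 = xh[..., ::2], xh[..., 1::2]
        out = torch.empty_like(xh)
        out[..., ::2] = x1 * cos - x2 * sin
        out[..., 1::2] = x2 * cos + x1 * sin
        return out.reshape(x.shape)

    return rotate(q), rotate(k)


def apply_pure_rope(positions: torch.Tensor, q_pe: torch.Tensor,
                    k_pe: torch.Tensor,
                    rot_dim: int) -> tuple[torch.Tensor, torch.Tensor]:
    """Mirrors the new helper: flatten heads, apply RoPE, restore shapes."""
    seq_len = positions.size(0)
    q_shape, k_shape = q_pe.shape, k_pe.shape
    q_flat, k_flat = toy_flat_rope(positions, q_pe.reshape(seq_len, -1),
                                   k_pe.reshape(seq_len, -1), rot_dim)
    return q_flat.view(q_shape), k_flat.view(k_shape)


# Decode-style shapes: 4 tokens, 2 query heads, rope dim 8; MLA carries a
# single shared k_pe "head", hence the size-1 middle dimension.
positions = torch.arange(4)
q_pe, k_pe = torch.randn(4, 2, 8), torch.randn(4, 1, 8)
q_pe, k_pe = apply_pure_rope(positions, q_pe, k_pe, rot_dim=8)
assert q_pe.shape == (4, 2, 8) and k_pe.shape == (4, 1, 8)
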
@@ -414,6 +414,7 @@ class DeepseekV2MLAAttention(nn.Module):
                                         quant_config=quant_config,
                                         prefix=f"{prefix}.o_proj")
 
-        rope_scaling["rope_type"] = 'deepseek_yarn'
+        if rope_scaling:
+            rope_scaling["rope_type"] = 'deepseek_yarn'
         self.rotary_emb = get_rope(qk_rope_head_dim,
                                    rotary_dim=qk_rope_head_dim,
@@ -422,6 +422,7 @@ class DeepseekV3MLAAttention(nn.Module):
                                         quant_config=quant_config,
                                         prefix=f"{prefix}.o_proj")
 
-        rope_scaling["rope_type"] = 'deepseek_yarn'
+        if rope_scaling:
+            rope_scaling["rope_type"] = 'deepseek_yarn'
         self.rotary_emb = get_rope(qk_rope_head_dim,
                                    rotary_dim=qk_rope_head_dim,
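On the model side, the new guard matters presumably because deepseek-vl2 configs leave rope_scaling unset, so unconditionally tagging it as 'deepseek_yarn' is not possible; with the guard, such models build a plain rotary embedding and the backend's isinstance check routes them to apply_pure_rope. A minimal sketch of how the two pieces fit together, where the stub classes and the toy rope_scaling dict are illustrative stand-ins for vLLM's RotaryEmbedding and DeepseekScalingRotaryEmbedding, not the library API:

from typing import Optional


class RotaryEmbedding:                                    # plain (pure) RoPE
    pass


class DeepseekScalingRotaryEmbedding(RotaryEmbedding):    # yarn-scaled RoPE
    pass


def build_rotary_emb(rope_scaling: Optional[dict]) -> RotaryEmbedding:
    # Mirrors the guarded block in DeepseekV2/V3MLAAttention: only tag the
    # config as 'deepseek_yarn' when a scaling config is actually present.
    if rope_scaling:
        rope_scaling["rope_type"] = 'deepseek_yarn'
        return DeepseekScalingRotaryEmbedding()
    return RotaryEmbedding()


# Backend side (MLACommonImpl.__init__): record the variant once, then
# forward() picks self.rotary_emb or self.apply_pure_rope per call.
rotary_emb = build_rotary_emb(rope_scaling=None)           # deepseek-vl2-like
use_yarn_rope = isinstance(rotary_emb, DeepseekScalingRotaryEmbedding)
assert not use_yarn_rope                                   # -> pure-RoPE path

rotary_emb = build_rotary_emb({"factor": 40})              # DeepSeek-V2/V3-like
assert isinstance(rotary_emb, DeepseekScalingRotaryEmbedding)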