Mirror of https://git.datalinker.icu/vllm-project/vllm.git
[Attention] Deepseek v3 MLA support with FP8 compute (#12601)
This PR implements Deepseek V3 MLA support by performing matrix absorption of the fp8 weights.

---------

Signed-off-by: Lucas Wilkinson <lwilkinson@neuralmagic.com>
Co-authored-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
Co-authored-by: simon-mo <simon.mo@hey.com>
Co-authored-by: Michael Goin <mgoin64@gmail.com>
Co-authored-by: Zhuohan Li <zhuohan123@gmail.com>
Co-authored-by: Tyler Michael Smith <tysmith@redhat.com>
Co-authored-by: Alexander Matveev <59768536+alexm-neuralmagic@users.noreply.github.com>
This commit is contained in:
parent
3e1c76cf3a
commit
baeded2569
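In outline, matrix absorption folds the per-head up-projections into the surrounding projections so that decode can work directly on the compressed latent. A minimal sketch with symbolic shapes (the names mirror the W_Q, W_UK, W_UV, and W_O matrices used in the diff below; the helper itself is illustrative, not vLLM code):

import torch

def absorb(W_Q, W_UK, W_UV, W_O):
    # W_Q:  (Lq, N, P)   query projection (nope part)
    # W_UK: (Lkv, N, P)  K up-projection from the latent
    # W_UV: (Lkv, N, V)  V up-projection from the latent
    # W_O:  (H, N, V)    output projection
    # Fold K's up-projection into the query projection: (Lq, N * Lkv)
    W_Q_UK = torch.einsum("qnd,lnd->qnl", W_Q, W_UK).flatten(start_dim=1)
    # Fold V's up-projection into the output projection: (N * Lkv, H)
    W_UV_O = torch.einsum("lnd,hnd->nlh", W_UV, W_O) \
        .flatten(start_dim=0, end_dim=1)
    return W_Q_UK, W_UV_O

For fp8 checkpoints the diff dequantizes the weights before this step and, by default, requantizes the absorbed matrices back to fp8 afterwards.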
@@ -1,17 +1,29 @@
from abc import abstractmethod
from dataclasses import dataclass
from typing import Any, Dict, Generic, List, Optional
from typing import Any, Dict, Generic, List, Optional, Tuple

import torch
from compressed_tensors.quantization import QuantizationStrategy

from vllm import _custom_ops as ops
from vllm import envs
from vllm.attention.backends.abstract import (AttentionLayer,
                                              AttentionMetadata,
                                              MLAAttentionImpl, T)
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.distributed import (get_tensor_model_parallel_world_size,
                              tensor_model_parallel_all_reduce)
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                               RowParallelLinear)
                                               LinearBase, RowParallelLinear,
                                               UnquantizedLinearMethod)
from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors import (  # noqa: E501
    CompressedTensorsLinearMethod)
from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
    CompressedTensorsW8A8Fp8)
from vllm.model_executor.layers.quantization.fp8 import Fp8LinearMethod
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
    apply_fp8_linear_generic, current_platform_fp8_dtype, is_fp8)
from vllm.model_executor.layers.quantization.utils.quant_utils import (
    scaled_dequantize, scaled_quantize)
from vllm.model_executor.layers.rotary_embedding import RotaryEmbedding
from vllm.vllm_flash_attn import flash_attn_varlen_func
@@ -25,11 +37,11 @@ class MLACommonMetadata(AttentionMetadata):

class MLACommonImpl(MLAAttentionImpl[T], Generic[T]):
    """
    Common class for implementing repeated parts

    Main reference: DeepseekV2 paper, and FlashInfer Implementation
    (https://arxiv.org/abs/2405.04434 and https://github.com/flashinfer-ai/flashinfer/pull/551).

    Deepseek's MLA attention works the following way:
    * Use a single latent vector to represent the entire KV cache.
    * The attention "simulates" a multi-head attention, while the compute is
@@ -46,7 +58,7 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]):
    * V: V head dim.
    * kv_c: latent/compressed KV
    * q_c: latent/compressed Q

    #
    # Outside the MLA attention backend
    #
@@ -55,21 +67,21 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]):
       kv_c_k_pe (B, Lkv+R).
    2. The kv_c_k_pe is split into kv_c (B, Lkv) and k_pe (B, R). cq
       and kv_c are normalized.

    #
    # Inside the MLA attention backend
    #

    * if prefill:

    3. The q_c is then projected up into the multi-head version.
       * q_c goes from (B, Lq) to (B, N, (P+R)), which is split into q_nope
         (B, N, P) and q_pe (B, N, R).
    4. q_pe, k_pe are then passed through rotary embeddings.
    5. kv_c and k_pe are concatenated and inserted into the cache
    6. The kv_c is then projected up into the multi-head version.
       * kv_c goes from (B, Lkv) to (B, N, (P+V)) which has the nope
         dimensions for K and V, which is split into k_nope (B, N, P)
         and v (B, N, V).
    7. q (B, N, (P+R)) and k (B, N, (P+R)) matrices are assembled from
       q_nope, q_pe, k_nope, k_pe.
@@ -112,7 +124,7 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]):
    From @tsu-bin's calculation, we only want to use the absorption technique
    for decode. The prefill algorithm should still use the up-projected MHA
    for less flops and memory usage.
    """

    def __init__(
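A hedged shape walk-through of the decode path the docstring describes, using the absorbed matrices (a pure-torch stand-in with symbolic sizes, not the backend's actual kernels):

import torch

B, N, Lq, Lkv, H = 2, 8, 1536, 512, 1024       # symbolic sizes
q_c = torch.randn(B, Lq)                        # compressed query
W_Q_UK = torch.randn(Lq, N * Lkv)               # absorbed q / k up-projection
W_UV_O = torch.randn(N * Lkv, H)                # absorbed v / output projection

# _q_proj_and_k_up_proj: project q_c straight into the latent space per head.
q_latent = (q_c @ W_Q_UK).view(B, N, Lkv)

# Attention then runs against the cached latent kv_c, producing a per-head
# latent context x of shape (B, N, Lkv); a dummy tensor stands in for it here.
x = torch.randn(B, N, Lkv)

# _v_up_proj_and_o_proj: map the latent context back to the hidden size.
out = x.reshape(B, N * Lkv) @ W_UV_O
assert out.shape == (B, H)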
@@ -162,8 +174,19 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]):

    def _v_up_proj_and_o_proj(self, x):
        if envs.VLLM_MLA_PERFORM_MATRIX_ABSORPTION:
            return self.o_proj_absorbed(
                x.reshape(-1, self.num_heads * self.kv_lora_rank))[0]
            if is_fp8(self.W_UV_O):
                output_parallel = apply_fp8_linear_generic(
                    x.flatten(start_dim=1), self.W_UV_O, self.W_UV_O_scales,
                    self.reqaunt_input_group_shape,
                    self.reqaunt_weight_group_shape)
            else:
                output_parallel = torch.matmul(x.flatten(start_dim=1),
                                               self.W_UV_O)
            if self.tp_size > 1:
                output = tensor_model_parallel_all_reduce(output_parallel)
            else:
                output = output_parallel
            return output
        else:
            x = torch.einsum("bnl,lnv->bnv", x, self.W_UV)
            return self.o_proj(x.reshape(-1,
@@ -171,6 +194,12 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]):

    def _q_proj_and_k_up_proj(self, x):
        if envs.VLLM_MLA_PERFORM_MATRIX_ABSORPTION:
            if is_fp8(self.W_Q_UK):
                return apply_fp8_linear_generic(
                    x, self.W_Q_UK, self.W_Q_UK_scales,
                    self.reqaunt_input_group_shape,
                    self.reqaunt_weight_group_shape).view(
                        -1, self.num_heads, self.kv_lora_rank)
            return torch.matmul(x, self.W_Q_UK)\
                .view(-1, self.num_heads, self.kv_lora_rank)
        else:
@@ -179,8 +208,91 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]):
            return torch.einsum("bnp,lnp->bnl", x, self.W_UK)\
                .view(-1, self.num_heads, self.kv_lora_rank)

    def process_weights_after_loading(self):
        kv_b_proj_weight = self.kv_b_proj.weight.T
    def process_weights_after_loading(self, act_dtype: torch.dtype):

        def is_layer_fp8(layer: LinearBase) -> bool:
            return isinstance(layer.quant_method, Fp8LinearMethod) or\
                (isinstance(layer.quant_method, CompressedTensorsLinearMethod)\
                 and isinstance(layer.scheme, CompressedTensorsW8A8Fp8))

        def quantization_scheme_supported(layer: LinearBase) -> bool:
            return isinstance(layer.quant_method, UnquantizedLinearMethod) or \
                is_layer_fp8(layer)

        # TODO(lucas) This is very gross, we need a more wide scale refactor of
        # all the FP8 code with a more standard way of
        # defining schemes/group-shapes, we should also potentially force
        # quant_methods to support a decompress function
        #
        # returns input_group_shape, weight_group_shape
        def get_scale_group_shapes_for_fp8(layer: LinearBase) -> \
                Tuple[Tuple[int, int], Tuple[int, int]]:
            if isinstance(layer.quant_method, Fp8LinearMethod):
                if layer.quant_method.block_quant is not None:
                    weight_block_size = \
                        layer.quant_method.quant_config.weight_block_size
                    # per-token-group (1, X), block-quantized (X, Y)
                    return (1, weight_block_size[-1]), weight_block_size
                else:
                    return (-1, -1), (-1, -1)  # per-tensor, per-tensor
            elif isinstance(layer.quant_method, CompressedTensorsLinearMethod)\
                    and isinstance(layer.scheme, CompressedTensorsW8A8Fp8):
                # this is hacky but we always assume the for
                # CompressedTensorsW8A8Fp8 the input is dynamic per-token
                # we ignore if it is static-per-tensor since we are going to
                # requantize after later anyways
                strategy = layer.scheme.strategy
                if strategy == QuantizationStrategy.TENSOR:
                    return (1, -1), (-1, -1)  # per-token, per-tensor
                elif strategy == QuantizationStrategy.CHANNEL:
                    return (1, -1), (-1, 1)  # per-token, per-channel
                else:
                    raise NotImplementedError(
                        f"QuantizationStrategy.{strategy} is not supported for "
                        "fp8 MLA, please run with VLLM_MLA_DISABLE=1")
            else:
                raise NotImplementedError(
                    "Can't determine scale group shapes for "
                    f"{layer.quant_method}, please run with VLLM_MLA_DISABLE=1"
                )

        def get_scales(layer: LinearBase) -> torch.Tensor:
            if hasattr(layer, "weight_scale_inv"):
                return layer.weight_scale_inv
            return layer.weight_scale

        def get_and_maybe_dequant_weights(layer: LinearBase):
            if is_layer_fp8(layer):
                if isinstance(layer.quant_method, \
                        CompressedTensorsLinearMethod) and \
                        isinstance(layer.scheme, CompressedTensorsW8A8Fp8):
                    # NOTE(lucas): note sure why but `CompressedTensorsW8A8Fp8`
                    # seems to store weights as (input, output) instead of
                    # (output, input) so we need to transpose
                    weight = layer.weight.T  # standardize to (output, input)
                else:
                    weight = layer.weight
                _, weight_scale_group_shape = \
                    get_scale_group_shapes_for_fp8(layer)
                scales = get_scales(layer)

                return scaled_dequantize(weight, scales,
                                         weight_scale_group_shape)
            else:
                return layer.weight

        if not (quantization_scheme_supported(self.kv_b_proj) and\
                quantization_scheme_supported(self.q_proj) and\
                quantization_scheme_supported(self.o_proj)):
            raise NotImplementedError(
                "Only FP8 and UnquantizedLinearMethod are supported for MLA"
                ", please run with VLLM_MLA_DISABLE=1")

        weight_dtype = self.kv_b_proj.weight.dtype
        assert self.o_proj.weight.dtype == weight_dtype
        assert self.q_proj.weight.dtype == weight_dtype

        kv_b_proj_weight = get_and_maybe_dequant_weights(self.kv_b_proj).T
        assert kv_b_proj_weight.shape == (
            self.kv_lora_rank,
            self.num_heads * (self.qk_nope_head_dim + self.v_head_dim)), (
@@ -198,18 +310,35 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]):
        W_UK, W_UV = kv_b_proj_weight.split(
            [self.qk_nope_head_dim, self.v_head_dim], dim=-1)

        q_proj = self.q_proj.weight.T\
        q_proj_weight = get_and_maybe_dequant_weights(self.q_proj).T\
            .view(-1, self.num_heads, self.qk_head_dim)

        # can be W_Q or W_UQ depending q_lora_rank, the former if
        # q_lora_rank is None, the latter otherwise. From the Attention backend
        # perspective though we call these both W_Q and rely on the layer
        # to pass in the correct matrix
        W_Q = q_proj[..., :self.qk_nope_head_dim]
        self.W_QR = q_proj[..., self.qk_nope_head_dim:]\
        W_Q = q_proj_weight[..., :self.qk_nope_head_dim]
        self.W_QR = q_proj_weight[..., self.qk_nope_head_dim:]\
            .flatten(start_dim=1).contiguous()

        # W_QR is small so for simplicity we dont bother requantizing it
        self.W_QR = self.W_QR.to(act_dtype)

        if envs.VLLM_MLA_PERFORM_MATRIX_ABSORPTION:
            requantization_enabled = not envs.VLLM_MLA_DISABLE_REQUANTIZATION
            if is_fp8(weight_dtype) and requantization_enabled:
                # This assumes it wise to requantize using the same group shapes
                # (i.e. strategy, per-tensor, per-channel, block etc.) that the
                # weights were originally quantized
                requant_input_group_shape, requant_weight_group_shape = \
                    get_scale_group_shapes_for_fp8(self.q_proj)
                assert (requant_input_group_shape, requant_weight_group_shape)\
                    == get_scale_group_shapes_for_fp8(self.kv_b_proj)
                assert (requant_input_group_shape, requant_weight_group_shape)\
                    == get_scale_group_shapes_for_fp8(self.o_proj)
                self.reqaunt_input_group_shape = requant_input_group_shape
                self.reqaunt_weight_group_shape = requant_weight_group_shape

            #
            # Perform matrix-absorption following
            # https://github.com/flashinfer-ai/flashinfer/pull/551
@@ -223,25 +352,44 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]):
            # latter otherwise
            # basically if q_lora_rank is none we are absorbing into q_proj
            # instead of UQ
            self.W_Q_UK = torch.einsum("qnd,lnd -> qnl", W_Q, W_UK)\
            W_Q_UK = torch.einsum("qnd,lnd -> qnl", W_Q, W_UK)\
                .flatten(start_dim=1).contiguous()

            W_O = self.o_proj.weight\
            if is_fp8(weight_dtype) and requantization_enabled:
                W_Q_UK, W_Q_UK_scales = scaled_quantize(
                    W_Q_UK,
                    self.reqaunt_weight_group_shape,
                    quant_dtype=current_platform_fp8_dtype)
                # For FP8 save the transpose so we can use
                # `apply_w8a8_block_fp8_linear` directly
                self.W_Q_UK = W_Q_UK.T.contiguous()
                self.W_Q_UK_scales = W_Q_UK_scales.T.contiguous()
            else:
                self.W_Q_UK = W_Q_UK.to(act_dtype)

            W_O = get_and_maybe_dequant_weights(self.o_proj)\
                .view(-1, self.num_heads, self.v_head_dim)
            self.W_UV_O = torch.einsum("lnd,hnd -> nlh", W_UV, W_O)\
            W_UV_O = torch.einsum("lnd,hnd -> nlh", W_UV, W_O)\
                .flatten(start_dim=0, end_dim=1).contiguous()

            tp_size = get_tensor_model_parallel_world_size()
            self.o_proj_absorbed = RowParallelLinear(
                self.W_UV_O.shape[0] * tp_size,
                self.W_UV_O.shape[1],
                bias=False,
                # TODO(lucas) figure out how to properly forward quant_method
                #quant_config=self.o_proj.quant_method,
            )
            if is_fp8(weight_dtype) and requantization_enabled:
                W_UV_O, W_UV_O_scales = scaled_quantize(
                    W_UV_O,
                    self.reqaunt_weight_group_shape,
                    quant_dtype=current_platform_fp8_dtype)
                # For FP8 save the transpose so we can use
                # `apply_w8a8_block_fp8_linear` directly
                self.W_UV_O = W_UV_O.T.contiguous()
                self.W_UV_O_scales = W_UV_O_scales.T.contiguous()
            else:
                self.W_UV_O = W_UV_O.to(act_dtype)

            self.o_proj_absorbed.weight = torch.nn.Parameter(self.W_UV_O.T)
            self.tp_size = get_tensor_model_parallel_world_size()
        else:
            if is_fp8(weight_dtype):
                raise NotImplementedError(
                    "Currently fp8 requires matrix absorption")

            self.W_UV = W_UV
            self.W_UK = W_UK
            self.W_Q = W_Q.flatten(start_dim=1)
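Condensing the FP8 handling in process_weights_after_loading above: weights are dequantized, absorbed in higher precision, and then either requantized to fp8 (the default) or cast to the activation dtype. A hedged outline using the quantization helpers this PR adds (assumes a vLLM install with this change and a torch build with float8 support; the tensors and 128x128 group shape are illustrative):

import torch
from vllm.model_executor.layers.quantization.utils.quant_utils import (
    scaled_dequantize, scaled_quantize)

group_shape = (128, 128)                                  # deepseek-style blocks
w_fp8 = torch.randn(512, 512).to(torch.float8_e4m3fn)    # pretend checkpoint weight
w_scales = torch.ones(512 // 128, 512 // 128)             # one scale per block

# 1. Dequantize to float32 before absorbing.
w = scaled_dequantize(w_fp8, w_scales, group_shape)

# 2. "Absorb": fold two projections into one matrix (identity as a stand-in).
w_absorbed = w @ torch.eye(512)

# 3. Requantize back to fp8, or skip this step and cast to the activation
#    dtype when VLLM_MLA_DISABLE_REQUANTIZATION=1.
w_q, w_q_scales = scaled_quantize(w_absorbed, group_shape,
                                  quant_dtype=torch.float8_e4m3fn)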
@@ -57,14 +57,12 @@ class TritonMLABackend(AttentionBackend):

    @staticmethod
    def get_kv_cache_shape(
        num_blocks: int,
        block_size: int,
        num_kv_heads: int,  # assumed to be 1 for MLA
        kv_lora_rank: int,  # passed via head_size
        num_blocks: int,
        block_size: int,
        num_kv_heads: int,  # assumed to be 1 for MLA
        head_size: int,
    ) -> Tuple[int, ...]:
        # TODO(lucas): remove hardcoding k_pe size as 1/8th of kv_lora_rank
        k_pe_size = kv_lora_rank // 8
        return (num_blocks, block_size, kv_lora_rank + k_pe_size)
        return (num_blocks, block_size, head_size)

    @staticmethod
    def swap_blocks(
@@ -83,7 +81,7 @@ class TritonMLABackend(AttentionBackend):

    @staticmethod
    def get_supported_head_sizes() -> List[int]:
        return [512]
        return [576]


class TritonMLAState(AttentionState):
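The new supported head size follows from the cache layout above: the latent KV vector and the rotary key share a single "head". A short arithmetic check, assuming the published DeepSeek V2/V3 dimensions (these values are not read from this diff at runtime):

kv_lora_rank = 512        # latent/compressed KV dimension (Lkv)
qk_rope_head_dim = 64     # rotary sub-dimension (R); also kv_lora_rank // 8
head_size = kv_lora_rank + qk_rope_head_dim
assert head_size == 576   # matches get_supported_head_sizes() above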
@@ -624,8 +622,6 @@ class TritonMLAMetadataBuilder(AttentionMetadataBuilder[TritonMLAMetadata]):
                self.multimodal_placeholder_maps.items()
            }

            num_kv_splits = 8

            return TritonMLAMetadata(
                num_prefills=self.num_prefills,
                slot_mapping=slot_mapping_tensor,
@@ -645,7 +641,7 @@ class TritonMLAMetadataBuilder(AttentionMetadataBuilder[TritonMLAMetadata]):
                context_lens_tensor=context_lens_tensor,
                block_tables=block_tables,
                use_cuda_graph=use_captured_graph,
                num_kv_splits=num_kv_splits,
                num_kv_splits=4,  # TODO(lucas) add heuristic
                head_dim=self.runner.model_config.get_head_size(),
            )
@@ -200,9 +200,9 @@ class Attention(nn.Module):
        s += f", backend={self.impl.__class__.__name__}"
        return s

    def process_weights_after_loading(self):
    def process_weights_after_loading(self, act_dtype: torch.dtype):
        if hasattr(self.impl, "process_weights_after_loading"):
            self.impl.process_weights_after_loading()
            self.impl.process_weights_after_loading(act_dtype)


class MultiHeadAttention(nn.Module):
@@ -739,18 +739,19 @@ class ModelConfig:
    @property
    def is_deepseek_mla(self) -> bool:
        # TODO add deepseek_v3
        return hasattr(self.hf_text_config,
                       "model_type") and (self.hf_text_config.model_type
                                          in ('deepseek_v2'))
        return (hasattr(self.hf_text_config, "model_type")) \
            and (self.hf_text_config.model_type in \
                ('deepseek_v2', 'deepseek_v3'))\
            and (self.hf_text_config.kv_lora_rank is not None)

    def get_head_size(self) -> int:
        # TODO remove hard code
        if self.is_deepseek_mla:
            qk_rope_head_dim = getattr(self.hf_text_config, "qk_rope_head_dim",
                                       0)
            if self.use_mla:
                return self.hf_text_config.kv_lora_rank
            return self.hf_text_config.kv_lora_rank + qk_rope_head_dim
        else:
            qk_rope_head_dim = getattr(self.hf_text_config,
                                       "qk_rope_head_dim", 0)
            qk_nope_head_dim = getattr(self.hf_text_config,
                                       "qk_nope_head_dim", 0)
            if qk_rope_head_dim and qk_nope_head_dim:
@@ -969,6 +970,32 @@ class ModelConfig:

    @property
    def use_mla(self) -> bool:
        if self.quantization is not None and self.quantization not in [\
                "fp8", "compressed-tensors"]:
            logger.warning(
                "MLA is not supported with %s quantization. "
                "Disabling MLA.", self.quantization)
            return False

        # If using a "compressed-tensors" checkpoint, check that all groups
        # have fp8 for both weights and activations.
        if self.quantization == "compressed-tensors":
            quant_config = self._parse_quant_hf_config()
            for group_name, cfg in quant_config.get("config_groups",
                                                    ("", {})).items():
                act_cfg = cfg.get("input_activations", {})
                act_type = None if act_cfg is None else act_cfg.get("type", "")
                w_cfg = cfg.get("weights", {})
                w_type = None if w_cfg is None else w_cfg.get("type", "")
                if act_type != "fp8" or w_type != "fp8":
                    logger.warning(
                        "compressed-tensors MLA support requires fp8 "
                        "activations and weights in group '%s', but got "
                        "activations type '%s' and weights type '%s'.\n "
                        "Full config: %s", group_name, act_type, w_type,
                        quant_config)
                    return False

        use_mla = (self.is_deepseek_mla and not envs.VLLM_MLA_DISABLE)
        return use_mla
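A condensed restatement of the gating logic above (illustrative pseudologic; the argument names are placeholders, not ModelConfig attributes):

from typing import Optional

def should_use_mla(is_deepseek_mla: bool, quantization: Optional[str],
                   mla_disabled_by_env: bool) -> bool:
    # Only unquantized, fp8, or fp8 compressed-tensors checkpoints qualify;
    # compressed-tensors additionally requires fp8 weights and activations
    # in every config group (checked above).
    if quantization is not None and quantization not in ("fp8",
                                                         "compressed-tensors"):
        return False
    return is_deepseek_mla and not mla_disabled_by_env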
vllm/envs.py
@@ -79,6 +79,7 @@ if TYPE_CHECKING:
    VLLM_V1_OUTPUT_PROC_CHUNK_SIZE: int = 128
    VLLM_MLA_DISABLE: bool = False
    VLLM_MLA_PERFORM_MATRIX_ABSORPTION: bool = True
    VLLM_MLA_DISABLE_REQUANTIZATION: bool = False


def get_default_cache_root():
@@ -519,7 +520,16 @@ environment_variables: Dict[str, Callable[[], Any]] = {
    # storing more weights, W_Q_UK and W_UV_O, so can increase memory usage,
    # the is enabled by default
    "VLLM_MLA_PERFORM_MATRIX_ABSORPTION":
    lambda: bool(int(os.getenv("VLLM_MLA_PERFORM_MATRIX_ABSORPTION", "1")))
    lambda: bool(int(os.getenv("VLLM_MLA_PERFORM_MATRIX_ABSORPTION", "1"))),

    # When running MLA with matrix-absorption enabled and fp8 quantized weights
    # we perform the matrix-absorption in float32 precision, after the matrices
    # are absorbed we requantize the weights back to fp8, this flag can be used
    # to disable the requantization step, and instead convert the absorbed
    # matrices to match the activation type. This can lead to higher memory and
    # compute usage but better preserves the accuracy of the original model.
    "VLLM_MLA_DISABLE_REQUANTIZATION":
    lambda: bool(int(os.getenv("VLLM_MLA_DISABLE_REQUANTIZATION", "0")))
}

# end-env-vars-definition
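A hedged usage sketch of the three MLA flags defined above; the overrides must be set before vLLM first reads them, and the values shown are the defaults from this diff:

import os

os.environ["VLLM_MLA_DISABLE"] = "0"                     # keep MLA enabled
os.environ["VLLM_MLA_PERFORM_MATRIX_ABSORPTION"] = "1"   # absorb W_Q_UK / W_UV_O
os.environ["VLLM_MLA_DISABLE_REQUANTIZATION"] = "0"      # requantize back to fp8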
@@ -2,7 +2,7 @@
import functools
import json
import os
from typing import Any, Dict, List, Optional, Tuple
from typing import Any, Dict, List, Optional, Tuple, Union

import torch
import triton
@@ -10,10 +10,24 @@ import triton.language as tl

from vllm import _custom_ops as ops
from vllm.logger import init_logger
from vllm.model_executor.layers.quantization.utils.quant_utils import (
    _normalize_quant_group_shape, scaled_dequantize)
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
    apply_fp8_linear)
from vllm.platforms import current_platform

logger = init_logger(__name__)

current_platform_fp8_dtype = (torch.float8_e4m3fnuz
                              if current_platform.is_rocm() else
                              torch.float8_e4m3fn)


def is_fp8(x: Union[torch.dtype, torch.Tensor]) -> bool:
    if isinstance(x, torch.Tensor):
        x = x.dtype
    return x == torch.float8_e4m3fn or x == torch.float8_e4m3fnuz


def apply_w8a8_block_fp8_linear(
    input: torch.Tensor,
@@ -55,6 +69,42 @@ def apply_w8a8_block_fp8_linear(
    return output.to(dtype=input.dtype).view(*output_shape)


# Unify the interface between `apply_w8a8_block_fp8_linear` and
# `apply_fp8_linear`
# NOTE(lucas): this is quite messy, we should think through this more formally
def apply_fp8_linear_generic(
    input: torch.Tensor,
    weight: torch.Tensor,
    weight_scale: torch.Tensor,
    input_group_shape: Tuple[int, int],
    weight_group_shape: Tuple[int, int],
    input_scale: Optional[torch.Tensor] = None,  # static scale if one
) -> torch.Tensor:
    # View input as 2D matrix for fp8 methods
    input = input.view(-1, input.shape[-1])

    weight_group_shape = _normalize_quant_group_shape(\
        weight, weight_group_shape)
    input_group_shape = _normalize_quant_group_shape(input, input_group_shape)

    def is_dim_blocked(dim, shape, group_shape):
        return group_shape < shape[dim] and group_shape > 1

    if is_dim_blocked(0, weight.shape, weight_group_shape[0])\
            and is_dim_blocked(1, weight.shape, weight_group_shape[1]) and\
            input_group_shape == (1, weight_group_shape[1]):
        return apply_w8a8_block_fp8_linear(input, weight,
                                           list(weight_group_shape),
                                           weight_scale)
    else:
        # Despite having linear in the it doesn't conform to
        # `torch.nn.functional.linear` which is defined as `input @ weight.T`
        # so we explicitly transpose the weight matrix here
        return apply_fp8_linear(input, weight.T, weight_scale.T,
                                use_per_token_if_dynamic=\
                                    (input_group_shape == (1, input.shape[1])))


def input_to_float8(
    x: torch.Tensor,
    dtype: Optional[torch.dtype] = None
@@ -75,7 +125,6 @@ def input_to_float8(
def block_quant_to_tensor_quant(
    x_q_block: torch.Tensor,
    x_s: torch.Tensor,
    block_size: List[int],
) -> Tuple[torch.Tensor, torch.Tensor]:
    """This function converts block-wise quantization to tensor-wise
    quantization. The inputs are block-wise quantization tensor `x_q_block`,
@@ -83,26 +132,7 @@ def block_quant_to_tensor_quant(
    The outputs are tensor-wise quantization tensor and tensor-wise
    quantization scale. Note only float8 is supported for now.
    """
    block_n, block_k = block_size[0], block_size[1]
    n, k = x_q_block.shape
    n_tiles = (n + block_n - 1) // block_n
    k_tiles = (k + block_k - 1) // block_k
    assert n_tiles == x_s.shape[0]
    assert k_tiles == x_s.shape[1]

    x_dq_block = x_q_block.to(torch.float32)

    x_dq_block_tiles = [[
        x_dq_block[
            j * block_n:min((j + 1) * block_n, n),
            i * block_k:min((i + 1) * block_k, k),
        ] for i in range(k_tiles)
    ] for j in range(n_tiles)]

    for i in range(k_tiles):
        for j in range(n_tiles):
            x_dq_block_tiles[j][i][:, :] = x_dq_block_tiles[j][i] * x_s[j][i]

    x_dq_block = scaled_dequantize(x_q_block, x_s)
    x_q_tensor, scale = input_to_float8(x_dq_block, dtype=x_q_block.dtype)
    return x_q_tensor, scale
@@ -1,5 +1,5 @@
"""This file is used for /tests and /benchmarks"""
from typing import List, Optional
from typing import List, Optional, Tuple

import numpy
import torch
@@ -20,6 +20,120 @@ FUSED_LAYER_NAME_MAPPING = {
}


# Normalize the group_shape to the full extent for any dims that are -1
def _normalize_quant_group_shape(x: torch.Tensor, group_shape: Tuple[int,
                                                                     int]):
    # -1 means full extent
    return (group_shape[0] if group_shape[0] > 0 else x.shape[-2],
            group_shape[1] if group_shape[1] > 0 else x.shape[-1])


# Useful when treating N-dimensional group scaling as extended numpy-style
# broadcasting in numpy simply stretches dimensions with an extent of 1 to match
# the target shape by repeating the data along that dimension (broadcasting)
# , we extend these semantics to say if the extent of a dimension in the
# source shape is not 1 and does not match the target shape we repeat each
# element along that dimension src_shape[dim] // target_shape[dim] times
# example if we have:
# a = [[1, 2], and target_shape = (2, 4)
#      [3, 4]]
# then we would expand a to:
# a = [[1, 1, 2, 2],
#      [3, 3, 4, 4]]
# NOTE this function this function does not explicitly broadcast dimensions
# with an extent of 1, since this can be done implicitly by pytorch
def group_broadcast(t, shape):
    for i, s in enumerate(shape):
        if t.shape[i] != s and t.shape[i] != 1:
            assert s % t.shape[i] == 0
            t = t.unsqueeze(i + 1)\
                .expand(*t.shape[:i+1], s // t.shape[i], *t.shape[i+1:])\
                .flatten(i, i + 1)
    return t


# Quantize assuming once scale per group of elements with shape group_shape,
# example group shapes:
# * (-1, -1)   for per-tensor quantization
# * (1, -1)    for per-row quantization
# * (-1, 1)    for per-column quantization
# * (128, 128) for 128x128 deepseek style block quantization
# * (1, 128)   for deepseek style activation quantization
#              (i.e. per-token-per-group)
def scaled_quantize(
    x: torch.Tensor,
    group_shape: Tuple[int, int],
    quant_dtype: torch.dtype,
) -> Tuple[torch.Tensor, torch.Tensor]:
    group_shape = _normalize_quant_group_shape(x, group_shape)
    assert quant_dtype.is_floating_point, \
        "currently `scaled_quantize` only supports floating point dtypes " \
        "but could be extended to support other dtypes"

    finfo = torch.finfo(quant_dtype)

    # Reshape (M, N) into (BLK_M, BLOCK_SIZE_M, BLK_N, BLOCK_SIZE_N)
    assert x.ndim == 2
    assert x.shape[0] % group_shape[0] == 0 and x.shape[1] % group_shape[1] == 0
    blk_m, blk_n = x.shape[0] // group_shape[0], x.shape[1] // group_shape[1]
    x_blkd = x.reshape(blk_m, group_shape[0], blk_n, group_shape[1])

    # Permute to (BLK_M, BLK_N, BLOCK_SIZE_M, BLOCK_SIZE_N)
    x_blkd_permd = x_blkd.permute(0, 2, 1, 3)
    # Flatten to (BLK_M, BLK_N, BLOCK_SIZE_M * BLOCK_SIZE_N)
    x_blkd_permd = x_blkd_permd.flatten(start_dim=2)

    # Compute scales
    min_val, max_val = x_blkd_permd.aminmax(dim=-1)
    amax = torch.maximum(min_val.abs(), max_val.abs()).clamp(min=1e-12)
    scale = finfo.max / amax

    # Apply scale and convert form:
    # (BLK_M, BLK_N, BLOCK_SIZE_M * BLOCK_SIZE_N) to (M, N)
    x_scl_sat = (x_blkd_permd * scale.unsqueeze(-1))\
        .clamp(min=finfo.min, max=finfo.max)\
        .reshape(blk_m, blk_n, group_shape[0], group_shape[1])\
        .permute(0, 2, 1, 3)\
        .reshape(x.shape)

    return x_scl_sat.to(quant_dtype).contiguous(), scale.float().reciprocal()


# inverses `scaled_quantize`
def scaled_dequantize(
    x_q: torch.Tensor,
    x_s: torch.Tensor,
    group_shape: Optional[Tuple[int, int]] = None,
    out_dtype: torch.dtype = torch.float32,
) -> Tuple[torch.Tensor, torch.Tensor]:
    if group_shape is not None:
        group_shape = _normalize_quant_group_shape(x_q, group_shape)

    if x_s.ndim == 0:  # scalar
        x_s = x_s.unsqueeze(-1).unsqueeze(-1)  # convert to (1, 1) tensor
    if x_s.ndim == 1:
        if group_shape is None:
            raise AssertionError(
                "if x_s is 1D tensor, group_shape must be provided otherwise "
                "its ambiguous which dimension to broadcast x_s to")
        # unsqueeze the scales for the dimension where we want to broadcast
        # across the full extent
        if group_shape[0] == x_q.shape[-2]:
            x_s = x_s.unsqueeze(-2)
        elif group_shape[1] == x_q.shape[-1]:
            x_s = x_s.unsqueeze(-1)
        else:
            raise AssertionError(
                "if x_s is a vector we should be broadcasting it to the full "
                "extent of one of the dimensions")

    if group_shape is not None:
        assert x_s.shape[-1] == x_q.shape[-1] // group_shape[1]
        assert x_s.shape[-2] == x_q.shape[-2] // group_shape[0]
    x_s = group_broadcast(x_s.to(torch.float32), x_q.shape)
    return (x_q.to(torch.float32) * x_s).to(out_dtype)


def pack_quantized_values_into_int32(w_q: torch.Tensor,
                                     wtype: ScalarType,
                                     packed_dim: int = 0):
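To make the group-shape conventions above concrete, a small round-trip sketch (assumes a vLLM install containing this change and a torch build with float8 support; the shapes and values are arbitrary):

import torch
from vllm.model_executor.layers.quantization.utils.quant_utils import (
    scaled_dequantize, scaled_quantize)

x = torch.randn(256, 512, dtype=torch.float32)

# 128x128 deepseek-style block quantization: one scale per 128x128 tile,
# so the scale tensor has shape (256 // 128, 512 // 128) = (2, 4).
x_q, x_s = scaled_quantize(x, (128, 128), quant_dtype=torch.float8_e4m3fn)
assert x_q.shape == x.shape and x_s.shape == (2, 4)

# Dequantize back to float32; the result approximates the original x.
x_dq = scaled_dequantize(x_q, x_s, group_shape=(128, 128))
print((x - x_dq).abs().max())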
@@ -398,11 +398,13 @@ class DefaultModelLoader(BaseModelLoader):
                # parameters onto device for processing and back off after.
                with device_loading_context(module, target_device):
                    quant_method.process_weights_after_loading(module)
            elif isinstance(module, Attention) and \
            if isinstance(module, Attention) and \
                hasattr(module, "process_weights_after_loading"):
                # When attention modules need to process weights after
                # currently only used by MLA
                module.process_weights_after_loading()
                # TODO(lucas): see if there is a way to unify the signatures
                # of process_weights_after_loading
                module.process_weights_after_loading(model_config.dtype)
        return model.eval()


@@ -439,6 +441,11 @@ class DummyModelLoader(BaseModelLoader):
                    with device_loading_context(
                            module, torch.device(device_config.device)):
                        quant_method.process_weights_after_loading(module)
                if isinstance(module, Attention) and \
                    hasattr(module, "process_weights_after_loading"):
                    # When attention modules need to process weights after
                    # currently only used by MLA
                    module.process_weights_after_loading(model_config.dtype)
        return model.eval()


@@ -633,6 +640,12 @@ class ShardedStateLoader(BaseModelLoader):
                quant_method = getattr(module, "quant_method", None)
                if quant_method is not None:
                    quant_method.process_weights_after_loading(module)
                if isinstance(module, Attention) and \
                    hasattr(module, "process_weights_after_loading"):
                    # When attention modules need to process weights after
                    # currently only used by MLA
                    module.process_weights_after_loading(
                        model_config.dtype)
            rank = get_tensor_model_parallel_rank()
            pattern = os.path.join(
                local_model_path,
@@ -1272,7 +1285,7 @@ class GGUFModelLoader(BaseModelLoader):

class RunaiModelStreamerLoader(BaseModelLoader):
    """
    Model loader that can load safetensors
    files from local FS or S3 bucket.
    """

@@ -1369,6 +1382,11 @@ class RunaiModelStreamerLoader(BaseModelLoader):
                if quant_method is not None:
                    with device_loading_context(module, target_device):
                        quant_method.process_weights_after_loading(module)
                if isinstance(module, Attention) and \
                    hasattr(module, "process_weights_after_loading"):
                    # When attention modules need to process weights after
                    # currently only used by MLA
                    module.process_weights_after_loading(model_config.dtype)
        return model.eval()
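The four loaders above share the same post-load hook. A minimal sketch of that pattern (simplified and duck-typed instead of checking isinstance(module, Attention); names other than process_weights_after_loading are assumptions, not the loaders' exact code):

import torch
import torch.nn as nn

def run_post_load_hooks(model: nn.Module, act_dtype: torch.dtype) -> nn.Module:
    """Illustrative version of the post-load pass the loaders above perform."""
    for _, module in model.named_modules():
        quant_method = getattr(module, "quant_method", None)
        if quant_method is not None:
            # Quantization methods repack/rescale their own weights first.
            quant_method.process_weights_after_loading(module)
        if hasattr(module, "process_weights_after_loading"):
            # Attention layers (currently only MLA) receive the activation
            # dtype so they can cast or requantize the absorbed matrices.
            module.process_weights_after_loading(act_dtype)
    return model.eval()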
@@ -27,7 +27,7 @@ from torch import nn
from transformers import PretrainedConfig

from vllm.attention import Attention, AttentionMetadata
from vllm.config import CacheConfig, VllmConfig
from vllm.config import CacheConfig, ModelConfig, VllmConfig
from vllm.distributed import (get_pp_group,
                              get_tensor_model_parallel_world_size,
                              tensor_model_parallel_all_reduce)
@@ -333,12 +333,156 @@ class DeepseekV3Attention(nn.Module):
        return output


class DeepseekV3MLAAttention(nn.Module):
    """
    Main reference: DeepseekV2 paper, and FlashInfer Implementation
    (https://arxiv.org/abs/2405.04434 and https://github.com/flashinfer-ai/flashinfer/pull/551).

    For more info see MLACommonImpl in: vllm/attention/backends/mla/utils.py
    """

    def __init__(
        self,
        config: PretrainedConfig,
        hidden_size: int,
        num_heads: int,
        qk_nope_head_dim: int,
        qk_rope_head_dim: int,
        v_head_dim: int,
        q_lora_rank: Optional[int],
        kv_lora_rank: int,
        rope_theta: float = 10000,
        rope_scaling: Optional[Dict[str, Any]] = None,
        max_position_embeddings: int = 8192,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        super().__init__()
        self.hidden_size = hidden_size
        self.qk_nope_head_dim = qk_nope_head_dim
        self.qk_rope_head_dim = qk_rope_head_dim
        self.qk_head_dim = qk_nope_head_dim + qk_rope_head_dim
        self.v_head_dim = v_head_dim

        self.q_lora_rank = q_lora_rank
        self.kv_lora_rank = kv_lora_rank

        self.num_heads = num_heads
        tp_size = get_tensor_model_parallel_world_size()
        assert num_heads % tp_size == 0
        self.num_local_heads = num_heads // tp_size

        self.scaling = self.qk_head_dim**-0.5
        self.rope_theta = rope_theta
        self.max_position_embeddings = max_position_embeddings

        if self.q_lora_rank is not None:
            self.q_a_proj = ReplicatedLinear(self.hidden_size,
                                             self.q_lora_rank,
                                             bias=False,
                                             quant_config=quant_config,
                                             prefix=f"{prefix}.q_a_proj")
            self.q_a_layernorm = RMSNorm(self.q_lora_rank,
                                         eps=config.rms_norm_eps)
            self.q_b_proj = ColumnParallelLinear(q_lora_rank,
                                                 self.num_heads *
                                                 self.qk_head_dim,
                                                 bias=False,
                                                 quant_config=quant_config,
                                                 prefix=f"{prefix}.q_b_proj")
        else:
            self.q_proj = ColumnParallelLinear(self.hidden_size,
                                               self.num_heads *
                                               self.qk_head_dim,
                                               bias=False,
                                               quant_config=quant_config,
                                               prefix=f"{prefix}.q_proj")

        self.kv_a_proj_with_mqa = ReplicatedLinear(
            self.hidden_size,
            self.kv_lora_rank + self.qk_rope_head_dim,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.kv_a_proj_with_mqa")
        self.kv_a_layernorm = RMSNorm(self.kv_lora_rank,
                                      eps=config.rms_norm_eps)
        self.kv_b_proj = ColumnParallelLinear(
            self.kv_lora_rank,
            self.num_heads * (self.qk_nope_head_dim + self.v_head_dim),
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.kv_b_proj")
        self.o_proj = RowParallelLinear(self.num_heads * self.v_head_dim,
                                        self.hidden_size,
                                        bias=False,
                                        quant_config=quant_config,
                                        prefix=f"{prefix}.o_proj")

        rope_scaling["rope_type"] = 'deepseek_yarn'
        self.rotary_emb = get_rope(qk_rope_head_dim,
                                   rotary_dim=qk_rope_head_dim,
                                   max_position=max_position_embeddings,
                                   base=rope_theta,
                                   rope_scaling=rope_scaling,
                                   is_neox_style=False)
        if rope_scaling:
            mscale_all_dim = rope_scaling.get("mscale_all_dim", False)
            scaling_factor = rope_scaling["factor"]
            mscale = yarn_get_mscale(scaling_factor, float(mscale_all_dim))
            self.scaling = self.scaling * mscale * mscale

        self.mla_attn = Attention(
            num_heads=self.num_local_heads,
            head_size=self.kv_lora_rank,
            scale=self.scaling,
            num_kv_heads=1,
            cache_config=cache_config,
            quant_config=quant_config,
            prefix=f"{prefix}.attn",
            use_mla=True,
            # MLA Args
            q_lora_rank=self.q_lora_rank,
            kv_lora_rank=self.kv_lora_rank,
            qk_nope_head_dim=self.qk_nope_head_dim,
            qk_rope_head_dim=self.qk_rope_head_dim,
            qk_head_dim=self.qk_head_dim,
            v_head_dim=self.v_head_dim,
            rotary_emb=self.rotary_emb,
            q_proj=self.q_proj if self.q_lora_rank is None else self.q_b_proj,
            kv_b_proj=self.kv_b_proj,
            o_proj=self.o_proj,
        )

        self.prefix = prefix
        self.debug_layer_idx = int(self.prefix.split(".")[-2])

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        if self.q_lora_rank is not None:
            ckq = self.q_a_proj(hidden_states)[0]
            hidden_states_or_q_c = self.q_a_layernorm(ckq)
        else:
            hidden_states_or_q_c = hidden_states
        kv_c, k_pe = self.kv_a_proj_with_mqa(hidden_states)[0].split(
            [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
        kv_c_normed = self.kv_a_layernorm(kv_c.contiguous())
        return self.mla_attn(hidden_states_or_q_c, kv_c_normed, k_pe, kv_cache,
                             attn_metadata)
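A shape sketch of the forward pass above (the dimensions are symbolic placeholders, not read from a real DeepSeek config):

import torch

B, hidden, q_lora, kv_lora, rope = 4, 1024, 256, 128, 64
hidden_states = torch.randn(B, hidden)

# q path: hidden -> compressed q_c (when q_lora_rank is set)
q_c = torch.randn(B, q_lora)                 # stands in for q_a_proj + layernorm

# kv path: hidden -> joint latent + rope vector, then split as in forward()
kv_c_k_pe = torch.randn(B, kv_lora + rope)   # stands in for kv_a_proj_with_mqa
kv_c, k_pe = kv_c_k_pe.split([kv_lora, rope], dim=-1)
assert kv_c.shape == (B, kv_lora) and k_pe.shape == (B, rope)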
class DeepseekV3DecoderLayer(nn.Module):

    def __init__(
        self,
        config: PretrainedConfig,
        prefix: str,
        model_config: ModelConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
    ) -> None:
@@ -351,7 +495,11 @@ class DeepseekV3DecoderLayer(nn.Module):
        # DecoderLayers are created with `make_layers` which passes the prefix
        # with the layer's index.
        layer_idx = int(prefix.split(sep='.')[-1])
        self.self_attn = DeepseekV3Attention(
        if model_config.use_mla:
            attn_cls = DeepseekV3MLAAttention
        else:
            attn_cls = DeepseekV3Attention
        self.self_attn = attn_cls(
            config=config,
            hidden_size=self.hidden_size,
            num_heads=config.num_attention_heads,
@@ -428,6 +576,7 @@ class DeepseekV3Model(nn.Module):
        super().__init__()

        config = vllm_config.model_config.hf_config
        model_config = vllm_config.model_config
        cache_config = vllm_config.cache_config
        quant_config = vllm_config.quant_config

@@ -447,6 +596,7 @@ class DeepseekV3Model(nn.Module):
            lambda prefix: DeepseekV3DecoderLayer(
                config,
                prefix,
                model_config=model_config,
                cache_config=cache_config,
                quant_config=quant_config,
            ),

@@ -110,7 +110,9 @@ class CacheEngine:
            parallel_config, LayerBlockType.attention)

        key_cache_block = cache_config.block_size * num_heads * head_size
        value_cache_block = key_cache_block
        # For MLA there is no value cache, since the latent vector
        # is joint keys and values.
        value_cache_block = key_cache_block if not model_config.use_mla else 0
        total = num_attention_layers * (key_cache_block + value_cache_block)
        if cache_config.cache_dtype == "auto":
            dtype = model_config.dtype
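To see what the CacheEngine change above buys, a hedged per-block accounting sketch; block_size, dtype width, and the MHA numbers are illustrative placeholders, while the MLA head size follows the cache layout in this diff:

# Illustrative cache-block accounting for MLA vs. standard attention.
block_size = 16                       # tokens per cache block (example value)
elem_bytes = 2                        # e.g. a bf16 cache dtype

# MLA: a single KV "head" of size kv_lora_rank + qk_rope_head_dim = 576,
# and no value cache (value_cache_block is 0 when use_mla).
mla_key_block = block_size * 1 * 576
mla_value_block = 0
print("MLA elements/block:", mla_key_block + mla_value_block,
      "bytes/block:", (mla_key_block + mla_value_block) * elem_bytes)

# Standard MHA for comparison (num_heads and head_size are placeholders).
num_heads, head_size = 128, 128
mha_key_block = block_size * num_heads * head_size
mha_value_block = mha_key_block
print("MHA elements/block:", mha_key_block + mha_value_block,
      "bytes/block:", (mha_key_block + mha_value_block) * elem_bytes)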