From baeded25699f9f4851843306f27f685c4d4ee7c5 Mon Sep 17 00:00:00 2001 From: Lucas Wilkinson Date: Sat, 1 Feb 2025 00:52:51 -0500 Subject: [PATCH] [Attention] Deepseek v3 MLA support with FP8 compute (#12601) This PR implements the Deepseek V3 support by performing matrix absorption the fp8 weights --------- Signed-off-by: Lucas Wilkinson Co-authored-by: Woosuk Kwon Co-authored-by: simon-mo Co-authored-by: Michael Goin Co-authored-by: Zhuohan Li Co-authored-by: Tyler Michael Smith Co-authored-by: Alexander Matveev <59768536+alexm-neuralmagic@users.noreply.github.com> --- vllm/attention/backends/mla/utils.py | 218 +++++++++++++++--- vllm/attention/backends/triton_mla.py | 18 +- vllm/attention/layer.py | 4 +- vllm/config.py | 39 +++- vllm/envs.py | 12 +- .../layers/quantization/utils/fp8_utils.py | 74 ++++-- .../layers/quantization/utils/quant_utils.py | 116 +++++++++- vllm/model_executor/model_loader/loader.py | 24 +- vllm/model_executor/models/deepseek_v3.py | 154 ++++++++++++- vllm/worker/cache_engine.py | 4 +- 10 files changed, 579 insertions(+), 84 deletions(-) diff --git a/vllm/attention/backends/mla/utils.py b/vllm/attention/backends/mla/utils.py index c6c8a6034e20f..e8fec234c0225 100644 --- a/vllm/attention/backends/mla/utils.py +++ b/vllm/attention/backends/mla/utils.py @@ -1,17 +1,29 @@ from abc import abstractmethod from dataclasses import dataclass -from typing import Any, Dict, Generic, List, Optional +from typing import Any, Dict, Generic, List, Optional, Tuple import torch +from compressed_tensors.quantization import QuantizationStrategy from vllm import _custom_ops as ops from vllm import envs from vllm.attention.backends.abstract import (AttentionLayer, AttentionMetadata, MLAAttentionImpl, T) -from vllm.distributed import get_tensor_model_parallel_world_size +from vllm.distributed import (get_tensor_model_parallel_world_size, + tensor_model_parallel_all_reduce) from vllm.model_executor.layers.linear import (ColumnParallelLinear, - RowParallelLinear) + LinearBase, RowParallelLinear, + UnquantizedLinearMethod) +from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors import ( # noqa: E501 + CompressedTensorsLinearMethod) +from vllm.model_executor.layers.quantization.compressed_tensors.schemes import ( + CompressedTensorsW8A8Fp8) +from vllm.model_executor.layers.quantization.fp8 import Fp8LinearMethod +from vllm.model_executor.layers.quantization.utils.fp8_utils import ( + apply_fp8_linear_generic, current_platform_fp8_dtype, is_fp8) +from vllm.model_executor.layers.quantization.utils.quant_utils import ( + scaled_dequantize, scaled_quantize) from vllm.model_executor.layers.rotary_embedding import RotaryEmbedding from vllm.vllm_flash_attn import flash_attn_varlen_func @@ -25,11 +37,11 @@ class MLACommonMetadata(AttentionMetadata): class MLACommonImpl(MLAAttentionImpl[T], Generic[T]): """ - Common class for implementing repeated parts - + Common class for implementing repeated parts + Main reference: DeepseekV2 paper, and FlashInfer Implementation (https://arxiv.org/abs/2405.04434 and https://github.com/flashinfer-ai/flashinfer/pull/551). - + Deepseek's MLA attention works the following way: * Use a single latent vector to represent the entire KV cache. * The attention "simulates" a multi-head attention, while the compute is @@ -46,7 +58,7 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]): * V: V head dim. 
* kv_c: latent/compressed KV * q_c: latent/compressed Q - + # # Outside the MLA attention backend # @@ -55,21 +67,21 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]): kv_c_k_pe (B, Lkv+R). 2. The kv_c_k_pe is split into kv_c (B, Lkv) and k_pe (B, R). cq and kv_c are normalized. - + # # Inside the MLA attention backend # * if prefill: - - 3. The q_c is then projected up into the multi-head version. - * q_c goes from (B, Lq) to (B, N, (P+R)), which is split into q_nope - (B, N, P) and q_pe (B, N, R). + + 3. The q_c is then projected up into the multi-head version. + * q_c goes from (B, Lq) to (B, N, (P+R)), which is split into q_nope + (B, N, P) and q_pe (B, N, R). 4. q_pe, k_pe are then passed through rotary embeddings. 5. kv_c and k_pe are concatenated and inserted into the cache - 6. The kv_c is then projected up into the multi-head version. - * kv_c goes from (B, Lkv) to (B, N, (P+V)) which has the nope - dimensions for K and V, which is split into k_nope (B, N, P) + 6. The kv_c is then projected up into the multi-head version. + * kv_c goes from (B, Lkv) to (B, N, (P+V)) which has the nope + dimensions for K and V, which is split into k_nope (B, N, P) and v (B, N, V). 7. q (B, N, (P+R)) and k (B, N, (P+R)) matrices are assembled from q_nope, q_pe, k_nope, k_pe. @@ -112,7 +124,7 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]): From @tsu-bin's calculation, we only want to use the absorption technique for decode. The prefill algorithm should still use the up-projected MHA for less flops and memory usage. - + """ def __init__( @@ -162,8 +174,19 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]): def _v_up_proj_and_o_proj(self, x): if envs.VLLM_MLA_PERFORM_MATRIX_ABSORPTION: - return self.o_proj_absorbed( - x.reshape(-1, self.num_heads * self.kv_lora_rank))[0] + if is_fp8(self.W_UV_O): + output_parallel = apply_fp8_linear_generic( + x.flatten(start_dim=1), self.W_UV_O, self.W_UV_O_scales, + self.reqaunt_input_group_shape, + self.reqaunt_weight_group_shape) + else: + output_parallel = torch.matmul(x.flatten(start_dim=1), + self.W_UV_O) + if self.tp_size > 1: + output = tensor_model_parallel_all_reduce(output_parallel) + else: + output = output_parallel + return output else: x = torch.einsum("bnl,lnv->bnv", x, self.W_UV) return self.o_proj(x.reshape(-1, @@ -171,6 +194,12 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]): def _q_proj_and_k_up_proj(self, x): if envs.VLLM_MLA_PERFORM_MATRIX_ABSORPTION: + if is_fp8(self.W_Q_UK): + return apply_fp8_linear_generic( + x, self.W_Q_UK, self.W_Q_UK_scales, + self.reqaunt_input_group_shape, + self.reqaunt_weight_group_shape).view( + -1, self.num_heads, self.kv_lora_rank) return torch.matmul(x, self.W_Q_UK)\ .view(-1, self.num_heads, self.kv_lora_rank) else: @@ -179,8 +208,91 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]): return torch.einsum("bnp,lnp->bnl", x, self.W_UK)\ .view(-1, self.num_heads, self.kv_lora_rank) - def process_weights_after_loading(self): - kv_b_proj_weight = self.kv_b_proj.weight.T + def process_weights_after_loading(self, act_dtype: torch.dtype): + + def is_layer_fp8(layer: LinearBase) -> bool: + return isinstance(layer.quant_method, Fp8LinearMethod) or\ + (isinstance(layer.quant_method, CompressedTensorsLinearMethod)\ + and isinstance(layer.scheme, CompressedTensorsW8A8Fp8)) + + def quantization_scheme_supported(layer: LinearBase) -> bool: + return isinstance(layer.quant_method, UnquantizedLinearMethod) or \ + is_layer_fp8(layer) + + # TODO(lucas) This is very gross, we need a more wide scale 
refactor of + # all the FP8 code with a more standard way of + # defining schemes/group-shapes, we should also potentially force + # quant_methods to support a decompress function + # + # returns input_group_shape, weight_group_shape + def get_scale_group_shapes_for_fp8(layer: LinearBase) -> \ + Tuple[Tuple[int, int], Tuple[int, int]]: + if isinstance(layer.quant_method, Fp8LinearMethod): + if layer.quant_method.block_quant is not None: + weight_block_size = \ + layer.quant_method.quant_config.weight_block_size + # per-token-group (1, X), block-quantized (X, Y) + return (1, weight_block_size[-1]), weight_block_size + else: + return (-1, -1), (-1, -1) # per-tensor, per-tensor + elif isinstance(layer.quant_method, CompressedTensorsLinearMethod)\ + and isinstance(layer.scheme, CompressedTensorsW8A8Fp8): + # this is hacky but we always assume the for + # CompressedTensorsW8A8Fp8 the input is dynamic per-token + # we ignore if it is static-per-tensor since we are going to + # requantize after later anyways + strategy = layer.scheme.strategy + if strategy == QuantizationStrategy.TENSOR: + return (1, -1), (-1, -1) # per-token, per-tensor + elif strategy == QuantizationStrategy.CHANNEL: + return (1, -1), (-1, 1) # per-token, per-channel + else: + raise NotImplementedError( + f"QuantizationStrategy.{strategy} is not supported for " + "fp8 MLA, please run with VLLM_MLA_DISABLE=1") + else: + raise NotImplementedError( + "Can't determine scale group shapes for " + f"{layer.quant_method}, please run with VLLM_MLA_DISABLE=1" + ) + + def get_scales(layer: LinearBase) -> torch.Tensor: + if hasattr(layer, "weight_scale_inv"): + return layer.weight_scale_inv + return layer.weight_scale + + def get_and_maybe_dequant_weights(layer: LinearBase): + if is_layer_fp8(layer): + if isinstance(layer.quant_method, \ + CompressedTensorsLinearMethod) and \ + isinstance(layer.scheme, CompressedTensorsW8A8Fp8): + # NOTE(lucas): note sure why but `CompressedTensorsW8A8Fp8` + # seems to store weights as (input, output) instead of + # (output, input) so we need to transpose + weight = layer.weight.T # standardize to (output, input) + else: + weight = layer.weight + _, weight_scale_group_shape = \ + get_scale_group_shapes_for_fp8(layer) + scales = get_scales(layer) + + return scaled_dequantize(weight, scales, + weight_scale_group_shape) + else: + return layer.weight + + if not (quantization_scheme_supported(self.kv_b_proj) and\ + quantization_scheme_supported(self.q_proj) and\ + quantization_scheme_supported(self.o_proj)): + raise NotImplementedError( + "Only FP8 and UnquantizedLinearMethod are supported for MLA" + ", please run with VLLM_MLA_DISABLE=1") + + weight_dtype = self.kv_b_proj.weight.dtype + assert self.o_proj.weight.dtype == weight_dtype + assert self.q_proj.weight.dtype == weight_dtype + + kv_b_proj_weight = get_and_maybe_dequant_weights(self.kv_b_proj).T assert kv_b_proj_weight.shape == ( self.kv_lora_rank, self.num_heads * (self.qk_nope_head_dim + self.v_head_dim)), ( @@ -198,18 +310,35 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]): W_UK, W_UV = kv_b_proj_weight.split( [self.qk_nope_head_dim, self.v_head_dim], dim=-1) - q_proj = self.q_proj.weight.T\ + q_proj_weight = get_and_maybe_dequant_weights(self.q_proj).T\ .view(-1, self.num_heads, self.qk_head_dim) # can be W_Q or W_UQ depending q_lora_rank, the former if # q_lora_rank is None, the latter otherwise. 
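Side note for reviewers, not part of the diff: the decode-path absorption used above (self.W_Q_UK in _q_proj_and_k_up_proj, self.W_UV_O in _v_up_proj_and_o_proj) rests on folding the up-projection into the neighbouring linear layer. A minimal standalone sketch of that identity, with made-up dimensions, is:

import torch

# Made-up small sizes: B tokens, D = q_c dim, N heads, P = qk_nope_head_dim,
# L = kv_lora_rank (not the real DeepSeek dimensions).
B, D, N, P, L = 2, 64, 8, 16, 32
x = torch.randn(B, D)           # latent q_c
W_Q = torch.randn(D, N, P)      # per-head nope slice of q_proj
W_UK = torch.randn(L, N, P)     # per-head K up-projection from kv_b_proj

# Two-step path (no absorption): project q, then contract with W_UK per head.
q_nope = torch.einsum("bd,dnp->bnp", x, W_Q)
two_step = torch.einsum("bnp,lnp->bnl", q_nope, W_UK)

# Absorbed path: fold W_UK into W_Q once at load time, one matmul per step.
W_Q_UK = torch.einsum("qnd,lnd->qnl", W_Q, W_UK)                   # (D, N, L)
absorbed = torch.matmul(x, W_Q_UK.flatten(start_dim=1)).view(B, N, L)

torch.testing.assert_close(two_step, absorbed, rtol=1e-4, atol=1e-4)

The output side is symmetric: W_UV folded into o_proj gives W_UV_O, which is why only the absorbed matrices need to be kept (and, for fp8, requantized) after loading.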
From the Attention backend # perspective though we call these both W_Q and rely on the layer # to pass in the correct matrix - W_Q = q_proj[..., :self.qk_nope_head_dim] - self.W_QR = q_proj[..., self.qk_nope_head_dim:]\ + W_Q = q_proj_weight[..., :self.qk_nope_head_dim] + self.W_QR = q_proj_weight[..., self.qk_nope_head_dim:]\ .flatten(start_dim=1).contiguous() + # W_QR is small so for simplicity we dont bother requantizing it + self.W_QR = self.W_QR.to(act_dtype) + if envs.VLLM_MLA_PERFORM_MATRIX_ABSORPTION: + requantization_enabled = not envs.VLLM_MLA_DISABLE_REQUANTIZATION + if is_fp8(weight_dtype) and requantization_enabled: + # This assumes it wise to requantize using the same group shapes + # (i.e. strategy, per-tensor, per-channel, block etc.) that the + # weights were originally quantized + requant_input_group_shape, requant_weight_group_shape = \ + get_scale_group_shapes_for_fp8(self.q_proj) + assert (requant_input_group_shape, requant_weight_group_shape)\ + == get_scale_group_shapes_for_fp8(self.kv_b_proj) + assert (requant_input_group_shape, requant_weight_group_shape)\ + == get_scale_group_shapes_for_fp8(self.o_proj) + self.reqaunt_input_group_shape = requant_input_group_shape + self.reqaunt_weight_group_shape = requant_weight_group_shape + # # Perform matrix-absorption following # https://github.com/flashinfer-ai/flashinfer/pull/551 @@ -223,25 +352,44 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]): # latter otherwise # basically if q_lora_rank is none we are absorbing into q_proj # instead of UQ - self.W_Q_UK = torch.einsum("qnd,lnd -> qnl", W_Q, W_UK)\ + W_Q_UK = torch.einsum("qnd,lnd -> qnl", W_Q, W_UK)\ .flatten(start_dim=1).contiguous() - W_O = self.o_proj.weight\ + if is_fp8(weight_dtype) and requantization_enabled: + W_Q_UK, W_Q_UK_scales = scaled_quantize( + W_Q_UK, + self.reqaunt_weight_group_shape, + quant_dtype=current_platform_fp8_dtype) + # For FP8 save the transpose so we can use + # `apply_w8a8_block_fp8_linear` directly + self.W_Q_UK = W_Q_UK.T.contiguous() + self.W_Q_UK_scales = W_Q_UK_scales.T.contiguous() + else: + self.W_Q_UK = W_Q_UK.to(act_dtype) + + W_O = get_and_maybe_dequant_weights(self.o_proj)\ .view(-1, self.num_heads, self.v_head_dim) - self.W_UV_O = torch.einsum("lnd,hnd -> nlh", W_UV, W_O)\ + W_UV_O = torch.einsum("lnd,hnd -> nlh", W_UV, W_O)\ .flatten(start_dim=0, end_dim=1).contiguous() - tp_size = get_tensor_model_parallel_world_size() - self.o_proj_absorbed = RowParallelLinear( - self.W_UV_O.shape[0] * tp_size, - self.W_UV_O.shape[1], - bias=False, - # TODO(lucas) figure out how to properly forward quant_method - #quant_config=self.o_proj.quant_method, - ) + if is_fp8(weight_dtype) and requantization_enabled: + W_UV_O, W_UV_O_scales = scaled_quantize( + W_UV_O, + self.reqaunt_weight_group_shape, + quant_dtype=current_platform_fp8_dtype) + # For FP8 save the transpose so we can use + # `apply_w8a8_block_fp8_linear` directly + self.W_UV_O = W_UV_O.T.contiguous() + self.W_UV_O_scales = W_UV_O_scales.T.contiguous() + else: + self.W_UV_O = W_UV_O.to(act_dtype) - self.o_proj_absorbed.weight = torch.nn.Parameter(self.W_UV_O.T) + self.tp_size = get_tensor_model_parallel_world_size() else: + if is_fp8(weight_dtype): + raise NotImplementedError( + "Currently fp8 requires matrix absorption") + self.W_UV = W_UV self.W_UK = W_UK self.W_Q = W_Q.flatten(start_dim=1) diff --git a/vllm/attention/backends/triton_mla.py b/vllm/attention/backends/triton_mla.py index da09bb70b4f1a..95dc119a47bb5 100644 --- a/vllm/attention/backends/triton_mla.py +++ 
b/vllm/attention/backends/triton_mla.py @@ -57,14 +57,12 @@ class TritonMLABackend(AttentionBackend): @staticmethod def get_kv_cache_shape( - num_blocks: int, - block_size: int, - num_kv_heads: int, # assumed to be 1 for MLA - kv_lora_rank: int, # passed via head_size + num_blocks: int, + block_size: int, + num_kv_heads: int, # assumed to be 1 for MLA + head_size: int, ) -> Tuple[int, ...]: - # TODO(lucas): remove hardcoding k_pe size as 1/8th of kv_lora_rank - k_pe_size = kv_lora_rank // 8 - return (num_blocks, block_size, kv_lora_rank + k_pe_size) + return (num_blocks, block_size, head_size) @staticmethod def swap_blocks( @@ -83,7 +81,7 @@ class TritonMLABackend(AttentionBackend): @staticmethod def get_supported_head_sizes() -> List[int]: - return [512] + return [576] class TritonMLAState(AttentionState): @@ -624,8 +622,6 @@ class TritonMLAMetadataBuilder(AttentionMetadataBuilder[TritonMLAMetadata]): self.multimodal_placeholder_maps.items() } - num_kv_splits = 8 - return TritonMLAMetadata( num_prefills=self.num_prefills, slot_mapping=slot_mapping_tensor, @@ -645,7 +641,7 @@ class TritonMLAMetadataBuilder(AttentionMetadataBuilder[TritonMLAMetadata]): context_lens_tensor=context_lens_tensor, block_tables=block_tables, use_cuda_graph=use_captured_graph, - num_kv_splits=num_kv_splits, + num_kv_splits=4, # TODO(lucas) add heuristic head_dim=self.runner.model_config.get_head_size(), ) diff --git a/vllm/attention/layer.py b/vllm/attention/layer.py index 9b804a29a485d..b97165f625e51 100644 --- a/vllm/attention/layer.py +++ b/vllm/attention/layer.py @@ -200,9 +200,9 @@ class Attention(nn.Module): s += f", backend={self.impl.__class__.__name__}" return s - def process_weights_after_loading(self): + def process_weights_after_loading(self, act_dtype: torch.dtype): if hasattr(self.impl, "process_weights_after_loading"): - self.impl.process_weights_after_loading() + self.impl.process_weights_after_loading(act_dtype) class MultiHeadAttention(nn.Module): diff --git a/vllm/config.py b/vllm/config.py index f6bd8b1ad8f14..f998502eef0da 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -739,18 +739,19 @@ class ModelConfig: @property def is_deepseek_mla(self) -> bool: # TODO add deepseek_v3 - return hasattr(self.hf_text_config, - "model_type") and (self.hf_text_config.model_type - in ('deepseek_v2')) + return (hasattr(self.hf_text_config, "model_type")) \ + and (self.hf_text_config.model_type in \ + ('deepseek_v2', 'deepseek_v3'))\ + and (self.hf_text_config.kv_lora_rank is not None) def get_head_size(self) -> int: # TODO remove hard code if self.is_deepseek_mla: + qk_rope_head_dim = getattr(self.hf_text_config, "qk_rope_head_dim", + 0) if self.use_mla: - return self.hf_text_config.kv_lora_rank + return self.hf_text_config.kv_lora_rank + qk_rope_head_dim else: - qk_rope_head_dim = getattr(self.hf_text_config, - "qk_rope_head_dim", 0) qk_nope_head_dim = getattr(self.hf_text_config, "qk_nope_head_dim", 0) if qk_rope_head_dim and qk_nope_head_dim: @@ -969,6 +970,32 @@ class ModelConfig: @property def use_mla(self) -> bool: + if self.quantization is not None and self.quantization not in [\ + "fp8", "compressed-tensors"]: + logger.warning( + "MLA is not supported with %s quantization. " + "Disabling MLA.", self.quantization) + return False + + # If using a "compressed-tensors" checkpoint, check that all groups + # have fp8 for both weights and activations. 
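Side note for reviewers, not part of the diff: for context on the 576 head size used by the triton backend and get_head_size above, a small sketch with DeepSeek-V2/V3-style values (assumed here rather than read from a config):

kv_lora_rank = 512        # latent dim (Lkv), assumed DeepSeek value
qk_rope_head_dim = 64     # rope dim (R), assumed DeepSeek value

head_size = kv_lora_rank + qk_rope_head_dim   # 576, the only supported size
num_blocks, block_size = 8, 16                # arbitrary
kv_cache_shape = (num_blocks, block_size, head_size)
# -> (8, 16, 576): one joint latent (kv_c concatenated with k_pe) per slot,
#    with num_kv_heads assumed to be 1.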
+ if self.quantization == "compressed-tensors": + quant_config = self._parse_quant_hf_config() + for group_name, cfg in quant_config.get("config_groups", + ("", {})).items(): + act_cfg = cfg.get("input_activations", {}) + act_type = None if act_cfg is None else act_cfg.get("type", "") + w_cfg = cfg.get("weights", {}) + w_type = None if w_cfg is None else w_cfg.get("type", "") + if act_type != "fp8" or w_type != "fp8": + logger.warning( + "compressed-tensors MLA support requires fp8 " + "activations and weights in group '%s', but got " + "activations type '%s' and weights type '%s'.\n " + "Full config: %s", group_name, act_type, w_type, + quant_config) + return False + use_mla = (self.is_deepseek_mla and not envs.VLLM_MLA_DISABLE) return use_mla diff --git a/vllm/envs.py b/vllm/envs.py index 2a18e3b9bc51d..25098070b00c9 100644 --- a/vllm/envs.py +++ b/vllm/envs.py @@ -79,6 +79,7 @@ if TYPE_CHECKING: VLLM_V1_OUTPUT_PROC_CHUNK_SIZE: int = 128 VLLM_MLA_DISABLE: bool = False VLLM_MLA_PERFORM_MATRIX_ABSORPTION: bool = True + VLLM_MLA_DISABLE_REQUANTIZATION: bool = False def get_default_cache_root(): @@ -519,7 +520,16 @@ environment_variables: Dict[str, Callable[[], Any]] = { # storing more weights, W_Q_UK and W_UV_O, so can increase memory usage, # the is enabled by default "VLLM_MLA_PERFORM_MATRIX_ABSORPTION": - lambda: bool(int(os.getenv("VLLM_MLA_PERFORM_MATRIX_ABSORPTION", "1"))) + lambda: bool(int(os.getenv("VLLM_MLA_PERFORM_MATRIX_ABSORPTION", "1"))), + + # When running MLA with matrix-absorption enabled and fp8 quantized weights + # we perform the matrix-absorption in float32 precision, after the matrices + # are absorbed we requantize the weights back to fp8, this flag can be used + # to disable the requantization step, and instead convert the absorbed + # matrices to match the activation type. This can lead to higher memory and + # compute usage but better preserves the accuracy of the original model. 
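Side note for reviewers, not part of the diff: the config_groups check above expects an hf quantization config shaped roughly as follows; the group name and layout are hypothetical and trimmed to the keys the check actually reads:

# Hypothetical compressed-tensors config fragment that would pass the check:
# every group quantizes both weights and input activations to fp8.  Real
# configs carry more fields (strategy, dynamic, ...); only "type" is read here.
quant_config = {
    "config_groups": {
        "group_0": {
            "weights": {"type": "fp8"},
            "input_activations": {"type": "fp8"},
        },
    },
}

for group_name, cfg in quant_config.get("config_groups", {}).items():
    act_type = (cfg.get("input_activations") or {}).get("type", "")
    w_type = (cfg.get("weights") or {}).get("type", "")
    assert act_type == "fp8" and w_type == "fp8", group_name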
+ "VLLM_MLA_DISABLE_REQUANTIZATION": + lambda: bool(int(os.getenv("VLLM_MLA_DISABLE_REQUANTIZATION", "0"))) } # end-env-vars-definition diff --git a/vllm/model_executor/layers/quantization/utils/fp8_utils.py b/vllm/model_executor/layers/quantization/utils/fp8_utils.py index ccebff341a7ed..850820f66ff90 100644 --- a/vllm/model_executor/layers/quantization/utils/fp8_utils.py +++ b/vllm/model_executor/layers/quantization/utils/fp8_utils.py @@ -2,7 +2,7 @@ import functools import json import os -from typing import Any, Dict, List, Optional, Tuple +from typing import Any, Dict, List, Optional, Tuple, Union import torch import triton @@ -10,10 +10,24 @@ import triton.language as tl from vllm import _custom_ops as ops from vllm.logger import init_logger +from vllm.model_executor.layers.quantization.utils.quant_utils import ( + _normalize_quant_group_shape, scaled_dequantize) +from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( + apply_fp8_linear) from vllm.platforms import current_platform logger = init_logger(__name__) +current_platform_fp8_dtype = (torch.float8_e4m3fnuz + if current_platform.is_rocm() else + torch.float8_e4m3fn) + + +def is_fp8(x: Union[torch.dtype, torch.Tensor]) -> bool: + if isinstance(x, torch.Tensor): + x = x.dtype + return x == torch.float8_e4m3fn or x == torch.float8_e4m3fnuz + def apply_w8a8_block_fp8_linear( input: torch.Tensor, @@ -55,6 +69,42 @@ def apply_w8a8_block_fp8_linear( return output.to(dtype=input.dtype).view(*output_shape) +# Unify the interface between `apply_w8a8_block_fp8_linear` and +# `apply_fp8_linear` +# NOTE(lucas): this is quite messy, we should think through this more formally +def apply_fp8_linear_generic( + input: torch.Tensor, + weight: torch.Tensor, + weight_scale: torch.Tensor, + input_group_shape: Tuple[int, int], + weight_group_shape: Tuple[int, int], + input_scale: Optional[torch.Tensor] = None, # static scale if one +) -> torch.Tensor: + # View input as 2D matrix for fp8 methods + input = input.view(-1, input.shape[-1]) + + weight_group_shape = _normalize_quant_group_shape(\ + weight, weight_group_shape) + input_group_shape = _normalize_quant_group_shape(input, input_group_shape) + + def is_dim_blocked(dim, shape, group_shape): + return group_shape < shape[dim] and group_shape > 1 + + if is_dim_blocked(0, weight.shape, weight_group_shape[0])\ + and is_dim_blocked(1, weight.shape, weight_group_shape[1]) and\ + input_group_shape == (1, weight_group_shape[1]): + return apply_w8a8_block_fp8_linear(input, weight, + list(weight_group_shape), + weight_scale) + else: + # Despite having linear in the it doesn't conform to + # `torch.nn.functional.linear` which is defined as `input @ weight.T` + # so we explicitly transpose the weight matrix here + return apply_fp8_linear(input, weight.T, weight_scale.T, + use_per_token_if_dynamic=\ + (input_group_shape == (1, input.shape[1]))) + + def input_to_float8( x: torch.Tensor, dtype: Optional[torch.dtype] = None @@ -75,7 +125,6 @@ def input_to_float8( def block_quant_to_tensor_quant( x_q_block: torch.Tensor, x_s: torch.Tensor, - block_size: List[int], ) -> Tuple[torch.Tensor, torch.Tensor]: """This function converts block-wise quantization to tensor-wise quantization. The inputs are block-wise quantization tensor `x_q_block`, @@ -83,26 +132,7 @@ def block_quant_to_tensor_quant( The outputs are tensor-wise quantization tensor and tensor-wise quantization scale. Note only float8 is supported for now. 
""" - block_n, block_k = block_size[0], block_size[1] - n, k = x_q_block.shape - n_tiles = (n + block_n - 1) // block_n - k_tiles = (k + block_k - 1) // block_k - assert n_tiles == x_s.shape[0] - assert k_tiles == x_s.shape[1] - - x_dq_block = x_q_block.to(torch.float32) - - x_dq_block_tiles = [[ - x_dq_block[ - j * block_n:min((j + 1) * block_n, n), - i * block_k:min((i + 1) * block_k, k), - ] for i in range(k_tiles) - ] for j in range(n_tiles)] - - for i in range(k_tiles): - for j in range(n_tiles): - x_dq_block_tiles[j][i][:, :] = x_dq_block_tiles[j][i] * x_s[j][i] - + x_dq_block = scaled_dequantize(x_q_block, x_s) x_q_tensor, scale = input_to_float8(x_dq_block, dtype=x_q_block.dtype) return x_q_tensor, scale diff --git a/vllm/model_executor/layers/quantization/utils/quant_utils.py b/vllm/model_executor/layers/quantization/utils/quant_utils.py index 83055d6000d83..95e785dcc4078 100644 --- a/vllm/model_executor/layers/quantization/utils/quant_utils.py +++ b/vllm/model_executor/layers/quantization/utils/quant_utils.py @@ -1,5 +1,5 @@ """This file is used for /tests and /benchmarks""" -from typing import List, Optional +from typing import List, Optional, Tuple import numpy import torch @@ -20,6 +20,120 @@ FUSED_LAYER_NAME_MAPPING = { } +# Normalize the group_shape to the full extent for any dims that are -1 +def _normalize_quant_group_shape(x: torch.Tensor, group_shape: Tuple[int, + int]): + # -1 means full extent + return (group_shape[0] if group_shape[0] > 0 else x.shape[-2], + group_shape[1] if group_shape[1] > 0 else x.shape[-1]) + + +# Useful when treating N-dimensional group scaling as extended numpy-style +# broadcasting in numpy simply stretches dimensions with an extent of 1 to match +# the target shape by repeating the data along that dimension (broadcasting) +# , we extend these semantics to say if the extent of a dimension in the +# source shape is not 1 and does not match the target shape we repeat each +# element along that dimension src_shape[dim] // target_shape[dim] times +# example if we have: +# a = [[1, 2], and target_shape = (2, 4) +# [3, 4]] +# then we would expand a to: +# a = [[1, 1, 2, 2], +# [3, 3, 4, 4]] +# NOTE this function this function does not explicitly broadcast dimensions +# with an extent of 1, since this can be done implicitly by pytorch +def group_broadcast(t, shape): + for i, s in enumerate(shape): + if t.shape[i] != s and t.shape[i] != 1: + assert s % t.shape[i] == 0 + t = t.unsqueeze(i + 1)\ + .expand(*t.shape[:i+1], s // t.shape[i], *t.shape[i+1:])\ + .flatten(i, i + 1) + return t + + +# Quantize assuming once scale per group of elements with shape group_shape, +# example group shapes: +# * (-1, -1) for per-tensor quantization +# * (1, -1) for per-row quantization +# * (-1, 1) for per-column quantization +# * (128, 128) for 128x128 deepseek style block quantization +# * (1, 128) for deepseek style activation quantization +# (i.e. 
per-token-per-group) +def scaled_quantize( + x: torch.Tensor, + group_shape: Tuple[int, int], + quant_dtype: torch.dtype, +) -> Tuple[torch.Tensor, torch.Tensor]: + group_shape = _normalize_quant_group_shape(x, group_shape) + assert quant_dtype.is_floating_point, \ + "currently `scaled_quantize` only supports floating point dtypes " \ + "but could be extended to support other dtypes" + + finfo = torch.finfo(quant_dtype) + + # Reshape (M, N) into (BLK_M, BLOCK_SIZE_M, BLK_N, BLOCK_SIZE_N) + assert x.ndim == 2 + assert x.shape[0] % group_shape[0] == 0 and x.shape[1] % group_shape[1] == 0 + blk_m, blk_n = x.shape[0] // group_shape[0], x.shape[1] // group_shape[1] + x_blkd = x.reshape(blk_m, group_shape[0], blk_n, group_shape[1]) + + # Permute to (BLK_M, BLK_N, BLOCK_SIZE_M, BLOCK_SIZE_N) + x_blkd_permd = x_blkd.permute(0, 2, 1, 3) + # Flatten to (BLK_M, BLK_N, BLOCK_SIZE_M * BLOCK_SIZE_N) + x_blkd_permd = x_blkd_permd.flatten(start_dim=2) + + # Compute scales + min_val, max_val = x_blkd_permd.aminmax(dim=-1) + amax = torch.maximum(min_val.abs(), max_val.abs()).clamp(min=1e-12) + scale = finfo.max / amax + + # Apply scale and convert form: + # (BLK_M, BLK_N, BLOCK_SIZE_M * BLOCK_SIZE_N) to (M, N) + x_scl_sat = (x_blkd_permd * scale.unsqueeze(-1))\ + .clamp(min=finfo.min, max=finfo.max)\ + .reshape(blk_m, blk_n, group_shape[0], group_shape[1])\ + .permute(0, 2, 1, 3)\ + .reshape(x.shape) + + return x_scl_sat.to(quant_dtype).contiguous(), scale.float().reciprocal() + + +# inverses `scaled_quantize` +def scaled_dequantize( + x_q: torch.Tensor, + x_s: torch.Tensor, + group_shape: Optional[Tuple[int, int]] = None, + out_dtype: torch.dtype = torch.float32, +) -> Tuple[torch.Tensor, torch.Tensor]: + if group_shape is not None: + group_shape = _normalize_quant_group_shape(x_q, group_shape) + + if x_s.ndim == 0: # scalar + x_s = x_s.unsqueeze(-1).unsqueeze(-1) # convert to (1, 1) tensor + if x_s.ndim == 1: + if group_shape is None: + raise AssertionError( + "if x_s is 1D tensor, group_shape must be provided otherwise " + "its ambiguous which dimension to broadcast x_s to") + # unsqueeze the scales for the dimension where we want to broadcast + # across the full extent + if group_shape[0] == x_q.shape[-2]: + x_s = x_s.unsqueeze(-2) + elif group_shape[1] == x_q.shape[-1]: + x_s = x_s.unsqueeze(-1) + else: + raise AssertionError( + "if x_s is a vector we should be broadcasting it to the full " + "extent of one of the dimensions") + + if group_shape is not None: + assert x_s.shape[-1] == x_q.shape[-1] // group_shape[1] + assert x_s.shape[-2] == x_q.shape[-2] // group_shape[0] + x_s = group_broadcast(x_s.to(torch.float32), x_q.shape) + return (x_q.to(torch.float32) * x_s).to(out_dtype) + + def pack_quantized_values_into_int32(w_q: torch.Tensor, wtype: ScalarType, packed_dim: int = 0): diff --git a/vllm/model_executor/model_loader/loader.py b/vllm/model_executor/model_loader/loader.py index 62babcddd61b1..4be511d12838d 100644 --- a/vllm/model_executor/model_loader/loader.py +++ b/vllm/model_executor/model_loader/loader.py @@ -398,11 +398,13 @@ class DefaultModelLoader(BaseModelLoader): # parameters onto device for processing and back off after. 
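Side note for reviewers, not part of the diff: a usage sketch of the two helpers added above. It assumes this patch is applied and a PyTorch build with float8_e4m3fn support; the tensor sizes are arbitrary.

import torch
from vllm.model_executor.layers.quantization.utils.quant_utils import (
    scaled_quantize, scaled_dequantize)

w = torch.randn(256, 512, dtype=torch.float32)

# DeepSeek-style 128x128 block quantization to fp8.
w_q, w_s = scaled_quantize(w, (128, 128), quant_dtype=torch.float8_e4m3fn)
assert w_q.shape == (256, 512) and w_q.dtype == torch.float8_e4m3fn
assert w_s.shape == (256 // 128, 512 // 128)   # one scale per 128x128 block

# Round-trip back to float32; only fp8 rounding error should remain.
w_dq = scaled_dequantize(w_q, w_s, group_shape=(128, 128))
torch.testing.assert_close(w_dq, w, rtol=0.2, atol=0.02)

This is the same pair used in process_weights_after_loading: dequantize to float, absorb, then scaled_quantize the absorbed matrices back to fp8 unless VLLM_MLA_DISABLE_REQUANTIZATION is set.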
with device_loading_context(module, target_device): quant_method.process_weights_after_loading(module) - elif isinstance(module, Attention) and \ + if isinstance(module, Attention) and \ hasattr(module, "process_weights_after_loading"): # When attention modules need to process weights after # currently only used by MLA - module.process_weights_after_loading() + # TODO(lucas): see if there is a way to unify the signatures + # of process_weights_after_loading + module.process_weights_after_loading(model_config.dtype) return model.eval() @@ -439,6 +441,11 @@ class DummyModelLoader(BaseModelLoader): with device_loading_context( module, torch.device(device_config.device)): quant_method.process_weights_after_loading(module) + if isinstance(module, Attention) and \ + hasattr(module, "process_weights_after_loading"): + # When attention modules need to process weights after + # currently only used by MLA + module.process_weights_after_loading(model_config.dtype) return model.eval() @@ -633,6 +640,12 @@ class ShardedStateLoader(BaseModelLoader): quant_method = getattr(module, "quant_method", None) if quant_method is not None: quant_method.process_weights_after_loading(module) + if isinstance(module, Attention) and \ + hasattr(module, "process_weights_after_loading"): + # When attention modules need to process weights after + # currently only used by MLA + module.process_weights_after_loading( + model_config.dtype) rank = get_tensor_model_parallel_rank() pattern = os.path.join( local_model_path, @@ -1272,7 +1285,7 @@ class GGUFModelLoader(BaseModelLoader): class RunaiModelStreamerLoader(BaseModelLoader): """ - Model loader that can load safetensors + Model loader that can load safetensors files from local FS or S3 bucket. """ @@ -1369,6 +1382,11 @@ class RunaiModelStreamerLoader(BaseModelLoader): if quant_method is not None: with device_loading_context(module, target_device): quant_method.process_weights_after_loading(module) + if isinstance(module, Attention) and \ + hasattr(module, "process_weights_after_loading"): + # When attention modules need to process weights after + # currently only used by MLA + module.process_weights_after_loading(model_config.dtype) return model.eval() diff --git a/vllm/model_executor/models/deepseek_v3.py b/vllm/model_executor/models/deepseek_v3.py index 0b44f0d062c40..f6ab53c85faa3 100644 --- a/vllm/model_executor/models/deepseek_v3.py +++ b/vllm/model_executor/models/deepseek_v3.py @@ -27,7 +27,7 @@ from torch import nn from transformers import PretrainedConfig from vllm.attention import Attention, AttentionMetadata -from vllm.config import CacheConfig, VllmConfig +from vllm.config import CacheConfig, ModelConfig, VllmConfig from vllm.distributed import (get_pp_group, get_tensor_model_parallel_world_size, tensor_model_parallel_all_reduce) @@ -333,12 +333,156 @@ class DeepseekV3Attention(nn.Module): return output +class DeepseekV3MLAAttention(nn.Module): + """ + Main reference: DeepseekV2 paper, and FlashInfer Implementation + (https://arxiv.org/abs/2405.04434 and https://github.com/flashinfer-ai/flashinfer/pull/551). 
+ + For more info see MLACommonImpl in: vllm/attention/backends/mla/utils.py + """ + + def __init__( + self, + config: PretrainedConfig, + hidden_size: int, + num_heads: int, + qk_nope_head_dim: int, + qk_rope_head_dim: int, + v_head_dim: int, + q_lora_rank: Optional[int], + kv_lora_rank: int, + rope_theta: float = 10000, + rope_scaling: Optional[Dict[str, Any]] = None, + max_position_embeddings: int = 8192, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ) -> None: + super().__init__() + self.hidden_size = hidden_size + self.qk_nope_head_dim = qk_nope_head_dim + self.qk_rope_head_dim = qk_rope_head_dim + self.qk_head_dim = qk_nope_head_dim + qk_rope_head_dim + self.v_head_dim = v_head_dim + + self.q_lora_rank = q_lora_rank + self.kv_lora_rank = kv_lora_rank + + self.num_heads = num_heads + tp_size = get_tensor_model_parallel_world_size() + assert num_heads % tp_size == 0 + self.num_local_heads = num_heads // tp_size + + self.scaling = self.qk_head_dim**-0.5 + self.rope_theta = rope_theta + self.max_position_embeddings = max_position_embeddings + + if self.q_lora_rank is not None: + self.q_a_proj = ReplicatedLinear(self.hidden_size, + self.q_lora_rank, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.q_a_proj") + self.q_a_layernorm = RMSNorm(self.q_lora_rank, + eps=config.rms_norm_eps) + self.q_b_proj = ColumnParallelLinear(q_lora_rank, + self.num_heads * + self.qk_head_dim, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.q_b_proj") + else: + self.q_proj = ColumnParallelLinear(self.hidden_size, + self.num_heads * + self.qk_head_dim, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.q_proj") + + self.kv_a_proj_with_mqa = ReplicatedLinear( + self.hidden_size, + self.kv_lora_rank + self.qk_rope_head_dim, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.kv_a_proj_with_mqa") + self.kv_a_layernorm = RMSNorm(self.kv_lora_rank, + eps=config.rms_norm_eps) + self.kv_b_proj = ColumnParallelLinear( + self.kv_lora_rank, + self.num_heads * (self.qk_nope_head_dim + self.v_head_dim), + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.kv_b_proj") + self.o_proj = RowParallelLinear(self.num_heads * self.v_head_dim, + self.hidden_size, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.o_proj") + + rope_scaling["rope_type"] = 'deepseek_yarn' + self.rotary_emb = get_rope(qk_rope_head_dim, + rotary_dim=qk_rope_head_dim, + max_position=max_position_embeddings, + base=rope_theta, + rope_scaling=rope_scaling, + is_neox_style=False) + if rope_scaling: + mscale_all_dim = rope_scaling.get("mscale_all_dim", False) + scaling_factor = rope_scaling["factor"] + mscale = yarn_get_mscale(scaling_factor, float(mscale_all_dim)) + self.scaling = self.scaling * mscale * mscale + + self.mla_attn = Attention( + num_heads=self.num_local_heads, + head_size=self.kv_lora_rank, + scale=self.scaling, + num_kv_heads=1, + cache_config=cache_config, + quant_config=quant_config, + prefix=f"{prefix}.attn", + use_mla=True, + # MLA Args + q_lora_rank=self.q_lora_rank, + kv_lora_rank=self.kv_lora_rank, + qk_nope_head_dim=self.qk_nope_head_dim, + qk_rope_head_dim=self.qk_rope_head_dim, + qk_head_dim=self.qk_head_dim, + v_head_dim=self.v_head_dim, + rotary_emb=self.rotary_emb, + q_proj=self.q_proj if self.q_lora_rank is None else self.q_b_proj, + kv_b_proj=self.kv_b_proj, + o_proj=self.o_proj, + ) + + self.prefix = prefix + self.debug_layer_idx = int(self.prefix.split(".")[-2]) + + def forward( 
+ self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + kv_cache: torch.Tensor, + attn_metadata: AttentionMetadata, + ) -> torch.Tensor: + if self.q_lora_rank is not None: + ckq = self.q_a_proj(hidden_states)[0] + hidden_states_or_q_c = self.q_a_layernorm(ckq) + else: + hidden_states_or_q_c = hidden_states + kv_c, k_pe = self.kv_a_proj_with_mqa(hidden_states)[0].split( + [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1) + kv_c_normed = self.kv_a_layernorm(kv_c.contiguous()) + return self.mla_attn(hidden_states_or_q_c, kv_c_normed, k_pe, kv_cache, + attn_metadata) + + class DeepseekV3DecoderLayer(nn.Module): def __init__( self, config: PretrainedConfig, prefix: str, + model_config: ModelConfig, cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, ) -> None: @@ -351,7 +495,11 @@ class DeepseekV3DecoderLayer(nn.Module): # DecoderLayers are created with `make_layers` which passes the prefix # with the layer's index. layer_idx = int(prefix.split(sep='.')[-1]) - self.self_attn = DeepseekV3Attention( + if model_config.use_mla: + attn_cls = DeepseekV3MLAAttention + else: + attn_cls = DeepseekV3Attention + self.self_attn = attn_cls( config=config, hidden_size=self.hidden_size, num_heads=config.num_attention_heads, @@ -428,6 +576,7 @@ class DeepseekV3Model(nn.Module): super().__init__() config = vllm_config.model_config.hf_config + model_config = vllm_config.model_config cache_config = vllm_config.cache_config quant_config = vllm_config.quant_config @@ -447,6 +596,7 @@ class DeepseekV3Model(nn.Module): lambda prefix: DeepseekV3DecoderLayer( config, prefix, + model_config=model_config, cache_config=cache_config, quant_config=quant_config, ), diff --git a/vllm/worker/cache_engine.py b/vllm/worker/cache_engine.py index 08316ba74aad8..c427b759b2e97 100644 --- a/vllm/worker/cache_engine.py +++ b/vllm/worker/cache_engine.py @@ -110,7 +110,9 @@ class CacheEngine: parallel_config, LayerBlockType.attention) key_cache_block = cache_config.block_size * num_heads * head_size - value_cache_block = key_cache_block + # For MLA there is no value cache, since the latent vector + # is joint keys and values. + value_cache_block = key_cache_block if not model_config.use_mla else 0 total = num_attention_layers * (key_cache_block + value_cache_block) if cache_config.cache_dtype == "auto": dtype = model_config.dtype
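Side note for reviewers, not part of the diff: a quick sanity check on why value_cache_block can be zero for MLA, using DeepSeek-V3-like dimensions (assumed here, not read from a checkpoint):

# Per-token, per-layer KV-cache elements.
num_heads = 128
qk_nope_head_dim, qk_rope_head_dim, v_head_dim = 128, 64, 128
kv_lora_rank = 512

# Conventional per-head K/V cache: full K and V for every head.
mha_elems = num_heads * ((qk_nope_head_dim + qk_rope_head_dim) + v_head_dim)

# MLA cache: a single joint latent (kv_c + k_pe), num_kv_heads == 1 and no
# separate value cache (value_cache_block == 0 above).
mla_elems = kv_lora_rank + qk_rope_head_dim

print(mha_elems, mla_elems, mha_elems / mla_elems)   # 40960 576 ~71x smaller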