diff --git a/vllm/attention/backends/mla/common.py b/vllm/attention/backends/mla/common.py index fc5f3420e394d..ff411f75ae7ff 100644 --- a/vllm/attention/backends/mla/common.py +++ b/vllm/attention/backends/mla/common.py @@ -7,22 +7,22 @@ First we define: Sq as Q sequence length Skv as KV sequence length -MLA has two possible ways of computing, a data-movement friendly approach and a -compute friendly approach, we generally want to use the compute friendly -approach for "prefill" (i.e. the ratio Sq / Skv is "small", is near 1) -and the data-movement friendly approach for "decode" (i.e. the ratio -Sq / Skv is "large"). +MLA has two possible ways of computing, a data-movement friendly approach and a +compute friendly approach, we generally want to use the compute friendly +approach for "prefill" (i.e. the ratio Sq / Skv is "small", is near 1) +and the data-movement friendly approach for "decode" (i.e. the ratio +Sq / Skv is "large"). -NOTE what we deem small and large is currently determined by if its labelled -prefill or decode by the scheduler, but this is something we should probably +NOTE what we deem small and large is currently determined by if its labelled +prefill or decode by the scheduler, but this is something we should probably tune. Main reference: DeepseekV2 paper, and FlashInfer Implementation (https://arxiv.org/abs/2405.04434 and https://github.com/flashinfer-ai/flashinfer/pull/551). Deepseek's MLA attention works the following way: -* Use a single latent vector to represent the per-token entry of the KV cache. -* For decode (i.e. the memory friendly approach) the attention "simulates" a +* Use a single latent vector to represent the per-token entry of the KV cache. +* For decode (i.e. the memory friendly approach) the attention "simulates" a multi-head attention, while the compute is similar to multi-query attention. Below is example of both paths assuming batchsize = 1 @@ -54,9 +54,9 @@ W_DQ project h_t to q_c shape [H, Lq] W_UQ project q_c to q_nope shape [Lq, N * P] W_QR project q_c to q_pe shape [Lq, N * R] W_DKV project h_t to kv_c shape [H, Lkv] -W_UK project kv_c to k_nope shape [Lkv, N * P] -W_KR project h_t to k_pe shape [H, N * R] -W_UV project kv_c to v shape [Lkv, N * V] +W_UK project kv_c to k_nope shape [Lkv, N, P] +W_KR project h_t to k_pe shape [H, R] +W_UV project kv_c to v shape [Lkv, N, V] W_O project v to h_t shape [N * V, H] @@ -69,8 +69,8 @@ new_kv_c = h_t @ W_DKV new_k_pe = RoPE(h_t @ W_KR) kv_c = torch.cat([new_kv_c, cache_kv_c], dim=0) k_pe = torch.cat([new_k_pe, cache_k_pe], dim=0) -k_nope = (kv_c @ W_UK).view(Skv, N, P) -v = (kv_c @ W_UV).view(Skv, N, V) +k_nope = (kv_c @ W_UK.view(Lkv, N * P)).view(Skv, N, P) +v = (kv_c @ W_UV.view(Lkv, N * V)).view(Skv, N, V) // MHA with QK headdim = P + R // V headdim = V @@ -90,20 +90,10 @@ NOTE: in the actual code, ## Data-Movement Friendly Approach (i.e. "_forward_decode"): -Ahead of time, compute: - -% this projects from q_c to [Sq, N * Lkv] -W_UQ_UK = einsum("qnp,knp -> qnk" - W_UQ.view(Lq, N, P), W_UK.view(Lkv, N, P) - ).view(Lkv, N * Lkv) -% this projects from attn output [Sq, N * Lkv] to [Sq, H] -W_UV_O = einsum("knv,nvh -> nkh" - W_UV.view(Lkv, N, V), W_O.view(N, V, H) - ).view(N * Lkv, H) - Runtime q_c = h_t @ W_DQ -q_latent = q_c @ W_UQ_UK.view(Sq, N, Lkv) +q_nope = (q_c @ W_UQ).view(-1, N, P) +ql_nope = einsum("snh,lnh->snl", q, W_UK) q_pe = RoPE(q_c @ W_QR).view(Sq, N, R) new_kv_c = h_t @ W_DKV new_k_pe = RoPE(h_t @ W_KR) @@ -116,11 +106,13 @@ k_pe = torch.cat([new_k_pe, cache_k_pe], dim=0) // NOTE: this is less compute-friendly since Lkv > P // but is more data-movement friendly since its MQA vs MHA spda_o = scaled_dot_product_attention( - torch.cat([q_latent, q_pe], dim=-1), + torch.cat([ql_nope, q_pe], dim=-1), torch.cat([kv_c, k_pe], dim=-1), kv_c ) -return spda_o.reshape(-1, N * Lkv) @ W_UV_O + +o = einsum("snl,lnv->snv", spda_o.reshape(-1, N, Lkv), W_UV) +return o.view(-1, N * V) @ self.num_heads @ W_O ## Chunked Prefill @@ -146,8 +138,8 @@ q_nope = (q_c @ W_UQ).view(Sq, N, P) q_pe = RoPE(q_c @ W_QR).view(Sq, N, R) new_kv_c = h_t @ W_DKV new_k_pe = RoPE(h_t @ W_KR) -new_k_nope = (new_kv_c @ W_UK).view(Sq, N, P) -new_v = (new_kv_c @ W_UV).view(Sq, N, V) +new_k_nope = (new_kv_c @ W_UK.view(Lkv, N * P)).view(Sq, N, P) +new_v = (new_kv_c @ W_UV.view(Lkv, N * V)).view(Sq, N, V) // MHA between queries and new KV // with QK headdim = P + R @@ -171,17 +163,17 @@ for chunk_idx in range(cdiv(C, MCC)): cache_k_pe_chunk = cache_k_pe[chunk_start:chunk_end] cache_k_nope_chunk = (cache_kv_c_chunk @ W_UK).view(-1, N, P) cache_v_chunk = (cache_kv_c_chunk @ W_UV).view(-1, N, V) - + chunk_o, chunk_lse = scaled_dot_product_attention( torch.cat([q_nope, q_pe], dim=-1), - torch.cat([cache_k_nope_chunk, - cache_k_pe_chunk.unsqueeze(1).expand(-1, N, -1)], + torch.cat([cache_k_nope_chunk, + cache_k_pe_chunk.unsqueeze(1).expand(-1, N, -1)], dim=-1), cache_v_chunk, casual=False, return_softmax_lse=True ) - + curr_o, curr_lse = merge_attn_states( suffix_output=curr_o, suffix_lse=curr_lse, @@ -202,7 +194,6 @@ from typing import (TYPE_CHECKING, Any, Dict, Generic, List, Optional, Tuple, Type, TypeVar) import torch -from compressed_tensors.quantization import QuantizationStrategy from vllm import _custom_ops as ops from vllm import envs @@ -215,20 +206,9 @@ from vllm.attention.backends.utils import (PAD_SLOT_ID, compute_slot_mapping, get_flash_attn_version, is_block_tables_empty) from vllm.attention.ops.triton_merge_attn_states import merge_attn_states -from vllm.distributed import (get_tensor_model_parallel_world_size, - tensor_model_parallel_all_reduce) from vllm.model_executor.layers.linear import (ColumnParallelLinear, LinearBase, RowParallelLinear, UnquantizedLinearMethod) -from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors import ( # noqa: E501 - CompressedTensorsLinearMethod) -from vllm.model_executor.layers.quantization.compressed_tensors.schemes import ( - CompressedTensorsW8A8Fp8) -from vllm.model_executor.layers.quantization.fp8 import Fp8LinearMethod -from vllm.model_executor.layers.quantization.utils.fp8_utils import ( - Fp8LinearGenericOp, is_fp8) -from vllm.model_executor.layers.quantization.utils.quant_utils import ( - scaled_quantize) from vllm.model_executor.layers.rotary_embedding import ( DeepseekScalingRotaryEmbedding, RotaryEmbedding) from vllm.multimodal import MultiModalPlaceholderMap @@ -1057,7 +1037,6 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]): self.kv_b_proj = kv_b_proj self.o_proj = o_proj self.triton_fa_func = triton_attention - self.fp8_linear_generic = Fp8LinearGenericOp() # Handle the differences between the flash_attn_varlen from flash_attn # and the one from vllm_flash_attn. The former is used on RoCM and the @@ -1070,80 +1049,29 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]): fa_version=self.vllm_flash_attn_version) def _v_up_proj_and_o_proj(self, x): - if envs.VLLM_MLA_PERFORM_MATRIX_ABSORPTION: - if is_fp8(self.W_UV_O): - output_parallel = self.fp8_linear_generic.apply( - x.flatten(start_dim=1), self.W_UV_O, self.W_UV_O_scales, - self.reqaunt_input_group_shape, - self.reqaunt_weight_group_shape) - else: - output_parallel = torch.matmul(x.flatten(start_dim=1), - self.W_UV_O) - if self.tp_size > 1: - output = tensor_model_parallel_all_reduce(output_parallel) - else: - output = output_parallel - return output - else: - x = torch.einsum("bnl,lnv->bnv", x, self.W_UV) - return self.o_proj(x.reshape(-1, - self.num_heads * self.v_head_dim))[0] + # Convert from (B, N, L) to (N, B, L) + x = x.view(-1, self.num_heads, self.kv_lora_rank).transpose(0, 1) + # Multiply (N, B, L) x (N, L, V) -> (N, B, V) + x = torch.bmm(x, self.W_UV) + # Convert from (N, B, V) to (B, N * V) + x = x.transpose(0, 1).reshape(-1, self.num_heads * self.v_head_dim) + return self.o_proj(x)[0] + # Return `ql_nope`, `q_pe` def _q_proj_and_k_up_proj(self, x): - if envs.VLLM_MLA_PERFORM_MATRIX_ABSORPTION: - if is_fp8(self.W_Q_UK): - return self.fp8_linear_generic.apply( - x, self.W_Q_UK, self.W_Q_UK_scales, - self.reqaunt_input_group_shape, - self.reqaunt_weight_group_shape).view( - -1, self.num_heads, self.kv_lora_rank) - return torch.matmul(x, self.W_Q_UK)\ - .view(-1, self.num_heads, self.kv_lora_rank) - else: - x = torch.matmul(x, self.W_Q)\ - .view(-1, self.num_heads, self.qk_nope_head_dim) - return torch.einsum("bnp,lnp->bnl", x, self.W_UK)\ - .view(-1, self.num_heads, self.kv_lora_rank) + q_nope, q_pe = self.q_proj(x)[0]\ + .view(-1, self.num_heads, self.qk_head_dim)\ + .split([self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1) + + # Convert from (B, N, P) to (N, B, P) + q_nope = q_nope.transpose(0, 1) + # Multiply (N, B, P) x (N, P, L) -> (N, B, L) + ql_nope = torch.bmm(q_nope, self.W_UK_T) + # Convert from (N, B, L) to (B, N, L) + return ql_nope.transpose(0, 1), q_pe def process_weights_after_loading(self, act_dtype: torch.dtype): - # TODO(lucas) This is very gross, we need a more wide scale refactor of - # all the FP8 code with a more standard way of - # defining schemes/group-shapes, we should also potentially force - # quant_methods to support a decompress function - # - # returns input_group_shape, weight_group_shape - def get_scale_group_shapes_for_fp8(layer: LinearBase) -> \ - Tuple[Tuple[int, int], Tuple[int, int]]: - if isinstance(layer.quant_method, Fp8LinearMethod): - if layer.quant_method.block_quant: - weight_block_size = \ - layer.quant_method.quant_config.weight_block_size - # per-token-group (1, X), block-quantized (X, Y) - return (1, weight_block_size[-1]), weight_block_size - else: - return (-1, -1), (-1, -1) # per-tensor, per-tensor - elif isinstance(layer.quant_method, CompressedTensorsLinearMethod)\ - and isinstance(layer.scheme, CompressedTensorsW8A8Fp8): - # this is hacky but we always assume the for - # CompressedTensorsW8A8Fp8 the input is dynamic per-token - # we ignore if it is static-per-tensor since we are going to - # requantize after later anyways - strategy = layer.scheme.strategy - if strategy == QuantizationStrategy.TENSOR: - return (1, -1), (-1, -1) # per-token, per-tensor - elif strategy == QuantizationStrategy.CHANNEL: - return (1, -1), (-1, 1) # per-token, per-channel - else: - raise NotImplementedError( - f"QuantizationStrategy.{strategy} is not supported for " - "fp8 MLA, please run with VLLM_MLA_DISABLE=1") - else: - raise NotImplementedError( - "Can't determine scale group shapes for " - f"{layer.quant_method}, please run with VLLM_MLA_DISABLE=1" - ) - def get_layer_weight(layer): WEIGHT_NAMES = ("weight", "qweight", "weight_packed") for attr in WEIGHT_NAMES: @@ -1167,10 +1095,9 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]): return dequant_weights.T return layer.weight - weight_dtype = get_layer_weight(self.kv_b_proj).dtype - assert get_layer_weight(self.o_proj).dtype == weight_dtype - assert get_layer_weight(self.q_proj).dtype == weight_dtype - + # we currently do not have quantized bmm's which are needed for + # `W_UV` and `W_UK_T`, we we just store fp16/bf16 copies and perform + # the bmm's in 16-bit, the extra memory overhead of this is fairly low kv_b_proj_weight = get_and_maybe_dequant_weights(self.kv_b_proj).T assert kv_b_proj_weight.shape == ( self.kv_lora_rank, @@ -1189,89 +1116,10 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]): W_UK, W_UV = kv_b_proj_weight.split( [self.qk_nope_head_dim, self.v_head_dim], dim=-1) - q_proj_weight = get_and_maybe_dequant_weights(self.q_proj).T\ - .view(-1, self.num_heads, self.qk_head_dim) - - # can be W_Q or W_UQ depending q_lora_rank, the former if - # q_lora_rank is None, the latter otherwise. From the Attention backend - # perspective though we call these both W_Q and rely on the layer - # to pass in the correct matrix - W_Q = q_proj_weight[..., :self.qk_nope_head_dim] - self.W_QR = q_proj_weight[..., self.qk_nope_head_dim:]\ - .flatten(start_dim=1).contiguous() - - # W_QR is small so for simplicity we dont bother requantizing it - self.W_QR = self.W_QR.to(act_dtype) - - if envs.VLLM_MLA_PERFORM_MATRIX_ABSORPTION: - requantization_enabled = not envs.VLLM_MLA_DISABLE_REQUANTIZATION - if is_fp8(weight_dtype) and requantization_enabled: - # This assumes it wise to requantize using the same group shapes - # (i.e. strategy, per-tensor, per-channel, block etc.) that the - # weights were originally quantized - requant_input_group_shape, requant_weight_group_shape = \ - get_scale_group_shapes_for_fp8(self.q_proj) - assert (requant_input_group_shape, requant_weight_group_shape)\ - == get_scale_group_shapes_for_fp8(self.kv_b_proj) - assert (requant_input_group_shape, requant_weight_group_shape)\ - == get_scale_group_shapes_for_fp8(self.o_proj) - self.reqaunt_input_group_shape = requant_input_group_shape - self.reqaunt_weight_group_shape = requant_weight_group_shape - - # - # Perform matrix-absorption following - # https://github.com/flashinfer-ai/flashinfer/pull/551 - # for decode, as a result we end up with absorbed weights for decode - # and another copy of raw weights for prefill. - # - self.W_UK, self.W_UV = kv_b_proj_weight.split( - [self.qk_nope_head_dim, self.v_head_dim], dim=-1) - # We absorb `W_UK` into `W_Q` resulting in either W_Q_UK or W_UQ_UK - # depending q_lora_rank, the former if q_lora_rank is None, the - # latter otherwise - # basically if q_lora_rank is none we are absorbing into q_proj - # instead of UQ - W_Q_UK = torch.einsum("qnd,lnd -> qnl", W_Q, W_UK)\ - .flatten(start_dim=1).contiguous() - - if is_fp8(weight_dtype) and requantization_enabled: - W_Q_UK, W_Q_UK_scales = scaled_quantize( - W_Q_UK, - self.reqaunt_weight_group_shape, - quant_dtype=current_platform.fp8_dtype()) - # For FP8 save the transpose so we can use - # `apply_w8a8_block_fp8_linear` directly - self.W_Q_UK = W_Q_UK.T.contiguous() - self.W_Q_UK_scales = W_Q_UK_scales.T.contiguous() - else: - self.W_Q_UK = W_Q_UK.to(act_dtype) - - W_O = get_and_maybe_dequant_weights(self.o_proj)\ - .view(-1, self.num_heads, self.v_head_dim) - W_UV_O = torch.einsum("lnd,hnd -> nlh", W_UV, W_O)\ - .flatten(start_dim=0, end_dim=1).contiguous() - - if is_fp8(weight_dtype) and requantization_enabled: - W_UV_O, W_UV_O_scales = scaled_quantize( - W_UV_O, - self.reqaunt_weight_group_shape, - quant_dtype=current_platform.fp8_dtype()) - # For FP8 save the transpose so we can use - # `apply_w8a8_block_fp8_linear` directly - self.W_UV_O = W_UV_O.T.contiguous() - self.W_UV_O_scales = W_UV_O_scales.T.contiguous() - else: - self.W_UV_O = W_UV_O.to(act_dtype) - - self.tp_size = get_tensor_model_parallel_world_size() - else: - if is_fp8(weight_dtype): - raise NotImplementedError( - "Currently fp8 requires matrix absorption") - - self.W_UV = W_UV - self.W_UK = W_UK - self.W_Q = W_Q.flatten(start_dim=1) + # Convert from (L, N, V) to (N, L, V) + self.W_UV = W_UV.transpose(0, 1) + # Convert from (L, N, P) to (N, P, L) + self.W_UK_T = W_UK.permute(1, 2, 0) def _compute_prefill_context( self, @@ -1471,7 +1319,7 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]): @abstractmethod def _forward_decode( self, - q_nope: torch.Tensor, + ql_nope: torch.Tensor, q_pe: torch.Tensor, kv_c_and_k_pe_cache: torch.Tensor, attn_metadata: T, @@ -1525,9 +1373,8 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]): prefill_k_c_normed = k_c_normed[:num_prefill_tokens] if has_decode: - decode_q_nope = self._q_proj_and_k_up_proj(decode_hs_or_q_c) - decode_q_pe = torch.matmul(decode_hs_or_q_c, self.W_QR)\ - .view(-1, self.num_heads, self.qk_rope_head_dim) + decode_ql_nope, decode_q_pe = \ + self._q_proj_and_k_up_proj(decode_hs_or_q_c) decode_q_pe[...], decode_k_pe[...] = self.rotary_emb( decode_input_positions, decode_q_pe, decode_k_pe) @@ -1561,6 +1408,6 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]): if has_decode: output[num_prefill_tokens:] = self._forward_decode( - decode_q_nope, decode_q_pe, kv_cache, attn_metadata) + decode_ql_nope, decode_q_pe, kv_cache, attn_metadata) return output diff --git a/vllm/envs.py b/vllm/envs.py index 259501056cc3b..a36d20a4f8b50 100644 --- a/vllm/envs.py +++ b/vllm/envs.py @@ -84,8 +84,6 @@ if TYPE_CHECKING: VLLM_SERVER_DEV_MODE: bool = False VLLM_V1_OUTPUT_PROC_CHUNK_SIZE: int = 128 VLLM_MLA_DISABLE: bool = False - VLLM_MLA_PERFORM_MATRIX_ABSORPTION: bool = True - VLLM_MLA_DISABLE_REQUANTIZATION: bool = False VLLM_MLA_CUDA_MEM_ALIGN_KV_CACHE: bool = True VLLM_ENABLE_MOE_ALIGN_BLOCK_SIZE_TRITON: bool = False VLLM_RAY_PER_WORKER_GPUS: float = 1.0 @@ -563,23 +561,6 @@ environment_variables: dict[str, Callable[[], Any]] = { "VLLM_MLA_DISABLE": lambda: bool(int(os.getenv("VLLM_MLA_DISABLE", "0"))), - # Flag that can control whether or not we perform matrix-absorption for MLA - # decode, i.e. absorb W_UK into W_Q/W_UK and W_UV into W_O, absorbing the - # matrices reduces the runtime FLOPs needed to compute MLA but requires - # storing more weights, W_Q_UK and W_UV_O, so can increase memory usage, - # the is enabled by default - "VLLM_MLA_PERFORM_MATRIX_ABSORPTION": - lambda: bool(int(os.getenv("VLLM_MLA_PERFORM_MATRIX_ABSORPTION", "1"))), - - # When running MLA with matrix-absorption enabled and fp8 quantized weights - # we perform the matrix-absorption in float32 precision, after the matrices - # are absorbed we requantize the weights back to fp8, this flag can be used - # to disable the requantization step, and instead convert the absorbed - # matrices to match the activation type. This can lead to higher memory and - # compute usage but better preserves the accuracy of the original model. - "VLLM_MLA_DISABLE_REQUANTIZATION": - lambda: bool(int(os.getenv("VLLM_MLA_DISABLE_REQUANTIZATION", "0"))), - # If set, vLLM will use the Triton implementation of moe_align_block_size, # i.e. moe_align_block_size_triton in fused_moe.py. "VLLM_ENABLE_MOE_ALIGN_BLOCK_SIZE_TRITON": diff --git a/vllm/model_executor/layers/quantization/utils/fp8_utils.py b/vllm/model_executor/layers/quantization/utils/fp8_utils.py index 1e19302cbad81..ecb7996e1e8c5 100644 --- a/vllm/model_executor/layers/quantization/utils/fp8_utils.py +++ b/vllm/model_executor/layers/quantization/utils/fp8_utils.py @@ -13,10 +13,9 @@ import triton.language as tl from vllm import _custom_ops as ops from vllm.logger import init_logger from vllm.model_executor.layers.quantization.utils.quant_utils import ( - _normalize_quant_group_shape, scaled_dequantize) + scaled_dequantize) from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( - CUTLASS_BLOCK_FP8_SUPPORTED, Fp8LinearOp, cutlass_block_fp8_supported, - cutlass_fp8_supported) + CUTLASS_BLOCK_FP8_SUPPORTED) from vllm.platforms import current_platform from vllm.utils import direct_register_custom_op @@ -101,60 +100,6 @@ direct_register_custom_op( ) -# Unify the interface between `apply_w8a8_block_fp8_linear` and -# `apply_fp8_linear` -# NOTE(lucas): this is quite messy, we should think through this more formally -# TODO(luka): unify this better -# https://github.com/vllm-project/vllm/issues/14397 -class Fp8LinearGenericOp: - - def __init__( - self, - cutlass_fp8_supported: bool = cutlass_fp8_supported(), - cutlass_block_fp8_supported: bool = cutlass_block_fp8_supported(), - ): - self.cutlass_block_fp8_supported = cutlass_block_fp8_supported - self.fp8_linear = Fp8LinearOp( - cutlass_fp8_supported=cutlass_fp8_supported) - - def apply( - self, - input: torch.Tensor, - weight: torch.Tensor, - weight_scale: torch.Tensor, - input_group_shape: Tuple[int, int], - weight_group_shape: Tuple[int, int], - input_scale: Optional[torch.Tensor] = None, # static scale if one - ) -> torch.Tensor: - # View input as 2D matrix for fp8 methods - input = input.view(-1, input.shape[-1]) - - weight_group_shape = _normalize_quant_group_shape( \ - weight, weight_group_shape) - input_group_shape = _normalize_quant_group_shape( - input, input_group_shape) - - def is_dim_blocked(dim, shape, group_shape): - return group_shape < shape[dim] and group_shape > 1 - - if is_dim_blocked(0, weight.shape, weight_group_shape[0])\ - and is_dim_blocked(1, weight.shape, weight_group_shape[1]) and\ - input_group_shape == (1, weight_group_shape[1]): - return apply_w8a8_block_fp8_linear( - input, - weight, - list(weight_group_shape), - weight_scale, - cutlass_block_fp8_supported=self.cutlass_block_fp8_supported) - else: - # Despite having linear in the name it doesn't conform to - # `torch.nn.functional.linear` which is defined as - # `input @ weight.T` so we explicitly transpose the weight matrix - return self.fp8_linear.apply(input, weight.T, weight_scale.T, - use_per_token_if_dynamic=\ - (input_group_shape == (1, input.shape[1]))) - - def input_to_float8( x: torch.Tensor, dtype: Optional[torch.dtype] = None diff --git a/vllm/v1/attention/backends/mla/common.py b/vllm/v1/attention/backends/mla/common.py index 14a7bd3535222..f801745ab5c7d 100644 --- a/vllm/v1/attention/backends/mla/common.py +++ b/vllm/v1/attention/backends/mla/common.py @@ -21,7 +21,7 @@ Main reference: DeepseekV2 paper, and FlashInfer Implementation (https://arxiv.org/abs/2405.04434 and https://github.com/flashinfer-ai/flashinfer/pull/551). Deepseek's MLA attention works the following way: -* Use a single latent vector to represent the per-token entry of the KV cache. +* Use a single latent vector to represent the per-token entry of the KV cache. * For decode (i.e. the memory friendly approach) the attention "simulates" a multi-head attention, while the compute is similar to multi-query attention. @@ -54,9 +54,9 @@ W_DQ project h_t to q_c shape [H, Lq] W_UQ project q_c to q_nope shape [Lq, N * P] W_QR project q_c to q_pe shape [Lq, N * R] W_DKV project h_t to kv_c shape [H, Lkv] -W_UK project kv_c to k_nope shape [Lkv, N * P] -W_KR project h_t to k_pe shape [H, N * R] -W_UV project kv_c to v shape [Lkv, N * V] +W_UK project kv_c to k_nope shape [Lkv, N, P] +W_KR project h_t to k_pe shape [H, R] +W_UV project kv_c to v shape [Lkv, N, V] W_O project v to h_t shape [N * V, H] @@ -69,8 +69,8 @@ new_kv_c = h_t @ W_DKV new_k_pe = RoPE(h_t @ W_KR) kv_c = torch.cat([new_kv_c, cache_kv_c], dim=0) k_pe = torch.cat([new_k_pe, cache_k_pe], dim=0) -k_nope = (kv_c @ W_UK).view(Skv, N, P) -v = (kv_c @ W_UV).view(Skv, N, V) +k_nope = (kv_c @ W_UK.view(Lkv, N * P)).view(Skv, N, P) +v = (kv_c @ W_UV.view(Lkv, N * V)).view(Skv, N, V) // MHA with QK headdim = P + R // V headdim = V @@ -79,7 +79,7 @@ spda_o = scaled_dot_product_attention( torch.cat([q_nope, q_pe], dim=-1), torch.cat([k_nope, k_pe.unsqueeze(1).expand(-1, N, -1)], dim=-1), v -) +) return spda_o @ W_O NOTE: in the actual code, @@ -90,20 +90,10 @@ NOTE: in the actual code, ## Data-Movement Friendly Approach (i.e. "_forward_decode"): -Ahead of time, compute: - -% this projects from q_c to [Sq, N * Lkv] -W_UQ_UK = einsum("qnp,knp -> qnk" - W_UQ.view(Lq, N, P), W_UK.view(Lkv, N, P) - ).view(Lkv, N * Lkv) -% this projects from attn output [Sq, N * Lkv] to [Sq, H] -W_UV_O = einsum("knv,nvh -> nkh" - W_UV.view(Lkv, N, V), W_O.view(N, V, H) - ).view(N * Lkv, H) - Runtime q_c = h_t @ W_DQ -q_latent = q_c @ W_UQ_UK.view(Sq, N, Lkv) +q_nope = (q_c @ W_UQ).view(-1, N, P) +ql_nope = einsum("snh,lnh->snl", q, W_UK) q_pe = RoPE(q_c @ W_QR).view(Sq, N, R) new_kv_c = h_t @ W_DKV new_k_pe = RoPE(h_t @ W_KR) @@ -116,29 +106,31 @@ k_pe = torch.cat([new_k_pe, cache_k_pe], dim=0) // NOTE: this is less compute-friendly since Lkv > P // but is more data-movement friendly since its MQA vs MHA spda_o = scaled_dot_product_attention( - torch.cat([q_latent, q_pe], dim=-1), + torch.cat([ql_nope, q_pe], dim=-1), torch.cat([kv_c, k_pe], dim=-1), kv_c ) -return spda_o.reshape(-1, N * Lkv) @ W_UV_O + +o = einsum("snl,lnv->snv", spda_o.reshape(-1, N, Lkv), W_UV) +return o.view(-1, N * V) @ self.num_heads @ W_O ## Chunked Prefill -For chunked prefill we want to use the compute friendly algorithm. We are -assuming sufficiently large Sq / Skv ratio, in the future may want to switch to +For chunked prefill we want to use the compute friendly algorithm. We are +assuming sufficiently large Sq / Skv ratio, in the future may want to switch to the data-movement friendly approach if the chunk (i.e. `Sq`) is small. However, the compute-friendly approach can potentially run out of memory if Skv is large due to: `k_nope = (kv_c @ W_UK).view(Skv, N, P)` -To mitigate this, we chunk the computation of attention with respect to the -current context (i.e. `cache_kv_c` and `cache_k_pe`) so that we can used a +To mitigate this, we chunk the computation of attention with respect to the +current context (i.e. `cache_kv_c` and `cache_k_pe`) so that we can used a fixed workspace size. The chunked prefill approach is as follows: -MCC Max chunk of context to process per iter, computed dynamically, +MCC Max chunk of context to process per iter, computed dynamically, used to bound the memory usage q_c = h_t @ W_DQ @@ -146,8 +138,8 @@ q_nope = (q_c @ W_UQ).view(Sq, N, P) q_pe = RoPE(q_c @ W_QR).view(Sq, N, R) new_kv_c = h_t @ W_DKV new_k_pe = RoPE(h_t @ W_KR) -new_k_nope = (new_kv_c @ W_UK).view(Sq, N, P) -new_v = (new_kv_c @ W_UV).view(Sq, N, V) +new_k_nope = (new_kv_c @ W_UK.view(Lkv, N * P)).view(Sq, N, P) +new_v = (new_kv_c @ W_UV.view(Lkv, N * V)).view(Sq, N, V) // MHA between queries and new KV // with QK headdim = P + R @@ -160,7 +152,7 @@ curr_o, curr_lse = scaled_dot_product_attention( new_v, casual=True, return_softmax_lse=True -) +) // Compute attention with the already existing context for chunk_idx in range(cdiv(C, MCC)): @@ -198,30 +190,17 @@ from dataclasses import dataclass from typing import TYPE_CHECKING, Any, Generic, Optional, TypeVar import torch -from compressed_tensors.quantization import QuantizationStrategy from vllm import _custom_ops as ops -from vllm import envs from vllm.attention.backends.abstract import (AttentionBackend, AttentionLayer, AttentionMetadata, MLAAttentionImpl) from vllm.attention.backends.utils import get_flash_attn_version from vllm.attention.ops.triton_merge_attn_states import merge_attn_states -from vllm.distributed import (get_tensor_model_parallel_world_size, - tensor_model_parallel_all_reduce) from vllm.logger import init_logger from vllm.model_executor.layers.linear import (ColumnParallelLinear, LinearBase, RowParallelLinear, UnquantizedLinearMethod) -from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors import ( # noqa: E501 - CompressedTensorsLinearMethod) -from vllm.model_executor.layers.quantization.compressed_tensors.schemes import ( - CompressedTensorsW8A8Fp8) -from vllm.model_executor.layers.quantization.fp8 import Fp8LinearMethod -from vllm.model_executor.layers.quantization.utils.fp8_utils import ( - Fp8LinearGenericOp, is_fp8) -from vllm.model_executor.layers.quantization.utils.quant_utils import ( - scaled_quantize) from vllm.model_executor.layers.rotary_embedding import RotaryEmbedding from vllm.platforms import current_platform from vllm.utils import cdiv, round_down @@ -646,7 +625,6 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]): self.kv_b_proj = kv_b_proj self.o_proj = o_proj self.vllm_flash_attn_version = get_flash_attn_version() - self.fp8_linear_generic = Fp8LinearGenericOp() # Handle the differences between the flash_attn_varlen from flash_attn # and the one from vllm_flash_attn. The former is used on RoCM and the @@ -658,88 +636,37 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]): fa_version=self.vllm_flash_attn_version) def _v_up_proj_and_o_proj(self, x): - if envs.VLLM_MLA_PERFORM_MATRIX_ABSORPTION: - if is_fp8(self.W_UV_O): - output_parallel = self.fp8_linear_generic.apply( - x.flatten(start_dim=1), self.W_UV_O, self.W_UV_O_scales, - self.reqaunt_input_group_shape, - self.reqaunt_weight_group_shape) - else: - output_parallel = torch.matmul(x.flatten(start_dim=1), - self.W_UV_O) - if self.tp_size > 1: - output = tensor_model_parallel_all_reduce(output_parallel) - else: - output = output_parallel - return output - else: - x = torch.einsum("bnl,lnv->bnv", x, self.W_UV) - return self.o_proj(x.reshape(-1, - self.num_heads * self.v_head_dim))[0] + # Convert from (B, N, L) to (N, B, L) + x = x.view(-1, self.num_heads, self.kv_lora_rank).transpose(0, 1) + # Multiply (N, B, L) x (N, L, V) -> (N, B, V) + x = torch.bmm(x, self.W_UV) + # Convert from (N, B, V) to (B, N * V) + x = x.transpose(0, 1).reshape(-1, self.num_heads * self.v_head_dim) + return self.o_proj(x)[0] + # Return `ql_nope`, `q_pe` def _q_proj_and_k_up_proj(self, x): - if envs.VLLM_MLA_PERFORM_MATRIX_ABSORPTION: - if is_fp8(self.W_Q_UK): - return self.fp8_linear_generic.apply( - x, self.W_Q_UK, self.W_Q_UK_scales, - self.reqaunt_input_group_shape, - self.reqaunt_weight_group_shape).view( - -1, self.num_heads, self.kv_lora_rank) - return torch.matmul(x, self.W_Q_UK)\ - .view(-1, self.num_heads, self.kv_lora_rank) - else: - x = torch.matmul(x, self.W_Q)\ - .view(-1, self.num_heads, self.qk_nope_head_dim) - return torch.einsum("bnp,lnp->bnl", x, self.W_UK)\ - .view(-1, self.num_heads, self.kv_lora_rank) + q_nope, q_pe = self.q_proj(x)[0]\ + .view(-1, self.num_heads, self.qk_head_dim)\ + .split([self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1) + + # Convert from (B, N, P) to (N, B, P) + q_nope = q_nope.transpose(0, 1) + # Multiply (N, B, P) x (N, P, L) -> (N, B, L) + ql_nope = torch.bmm(q_nope, self.W_UK_T) + # Convert from (N, B, L) to (B, N, L) + return ql_nope.transpose(0, 1), q_pe def process_weights_after_loading(self, act_dtype: torch.dtype): - # TODO(lucas) This is very gross, we need a more wide scale refactor of - # all the FP8 code with a more standard way of - # defining schemes/group-shapes, we should also potentially force - # quant_methods to support a decompress function - # - # returns input_group_shape, weight_group_shape - def get_scale_group_shapes_for_fp8(layer: LinearBase) -> \ - tuple[tuple[int, int], tuple[int, int]]: - if isinstance(layer.quant_method, Fp8LinearMethod): - if layer.quant_method.block_quant: - weight_block_size = \ - layer.quant_method.quant_config.weight_block_size - # per-token-group (1, X), block-quantized (X, Y) - return (1, weight_block_size[-1]), weight_block_size - else: - return (-1, -1), (-1, -1) # per-tensor, per-tensor - elif isinstance(layer.quant_method, CompressedTensorsLinearMethod)\ - and isinstance(layer.scheme, CompressedTensorsW8A8Fp8): - # this is hacky but we always assume the for - # CompressedTensorsW8A8Fp8 the input is dynamic per-token - # we ignore if it is static-per-tensor since we are going to - # requantize after later anyways - strategy = layer.scheme.strategy - if strategy == QuantizationStrategy.TENSOR: - return (1, -1), (-1, -1) # per-token, per-tensor - elif strategy == QuantizationStrategy.CHANNEL: - return (1, -1), (-1, 1) # per-token, per-channel - else: - raise NotImplementedError( - f"QuantizationStrategy.{strategy} is not supported for " - "fp8 MLA, please run with VLLM_MLA_DISABLE=1") - else: - raise NotImplementedError( - "Can't determine scale group shapes for " - f"{layer.quant_method}, please run with VLLM_MLA_DISABLE=1" - ) - def get_layer_weight(layer): - if hasattr(layer, "weight"): - return layer.weight - elif hasattr(layer, "qweight"): - return layer.qweight - else: - raise AttributeError( - f"Layer '{layer}' has neither weight nor qweight") + WEIGHT_NAMES = ("weight", "qweight", "weight_packed") + for attr in WEIGHT_NAMES: + if hasattr(layer, attr): + return getattr(layer, attr) + raise AttributeError( + f"Layer '{layer}' has no recognized weight attribute:" + f" {WEIGHT_NAMES}.") def get_and_maybe_dequant_weights(layer: LinearBase): if not isinstance(layer.quant_method, UnquantizedLinearMethod): @@ -755,10 +682,9 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]): return dequant_weights.T return layer.weight - weight_dtype = get_layer_weight(self.kv_b_proj).dtype - assert get_layer_weight(self.o_proj).dtype == weight_dtype - assert get_layer_weight(self.q_proj).dtype == weight_dtype - + # we currently do not have quantized bmm's which are needed for + # `W_UV` and `W_UK_T`, we we just store fp16/bf16 copies and perform + # the bmm's in 16-bit, the extra memory overhead of this is fairly low kv_b_proj_weight = get_and_maybe_dequant_weights(self.kv_b_proj).T assert kv_b_proj_weight.shape == ( self.kv_lora_rank, @@ -777,89 +703,10 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]): W_UK, W_UV = kv_b_proj_weight.split( [self.qk_nope_head_dim, self.v_head_dim], dim=-1) - q_proj_weight = get_and_maybe_dequant_weights(self.q_proj).T\ - .view(-1, self.num_heads, self.qk_head_dim) - - # can be W_Q or W_UQ depending q_lora_rank, the former if - # q_lora_rank is None, the latter otherwise. From the Attention backend - # perspective though we call these both W_Q and rely on the layer - # to pass in the correct matrix - W_Q = q_proj_weight[..., :self.qk_nope_head_dim] - self.W_QR = q_proj_weight[..., self.qk_nope_head_dim:]\ - .flatten(start_dim=1).contiguous() - - # W_QR is small so for simplicity we dont bother requantizing it - self.W_QR = self.W_QR.to(act_dtype) - - if envs.VLLM_MLA_PERFORM_MATRIX_ABSORPTION: - requantization_enabled = not envs.VLLM_MLA_DISABLE_REQUANTIZATION - if is_fp8(weight_dtype) and requantization_enabled: - # This assumes it wise to requantize using the same group shapes - # (i.e. strategy, per-tensor, per-channel, block etc.) that the - # weights were originally quantized - requant_input_group_shape, requant_weight_group_shape = \ - get_scale_group_shapes_for_fp8(self.q_proj) - assert (requant_input_group_shape, requant_weight_group_shape)\ - == get_scale_group_shapes_for_fp8(self.kv_b_proj) - assert (requant_input_group_shape, requant_weight_group_shape)\ - == get_scale_group_shapes_for_fp8(self.o_proj) - self.reqaunt_input_group_shape = requant_input_group_shape - self.reqaunt_weight_group_shape = requant_weight_group_shape - - # - # Perform matrix-absorption following - # https://github.com/flashinfer-ai/flashinfer/pull/551 - # for decode, as a result we end up with absorbed weights for decode - # and another copy of raw weights for prefill. - # - self.W_UK, self.W_UV = kv_b_proj_weight.split( - [self.qk_nope_head_dim, self.v_head_dim], dim=-1) - # We absorb `W_UK` into `W_Q` resulting in either W_Q_UK or W_UQ_UK - # depending q_lora_rank, the former if q_lora_rank is None, the - # latter otherwise - # basically if q_lora_rank is none we are absorbing into q_proj - # instead of UQ - W_Q_UK = torch.einsum("qnd,lnd -> qnl", W_Q, W_UK)\ - .flatten(start_dim=1).contiguous() - - if is_fp8(weight_dtype) and requantization_enabled: - W_Q_UK, W_Q_UK_scales = scaled_quantize( - W_Q_UK, - self.reqaunt_weight_group_shape, - quant_dtype=current_platform.fp8_dtype()) - # For FP8 save the transpose so we can use - # `apply_w8a8_block_fp8_linear` directly - self.W_Q_UK = W_Q_UK.T.contiguous() - self.W_Q_UK_scales = W_Q_UK_scales.T.contiguous() - else: - self.W_Q_UK = W_Q_UK.to(act_dtype) - - W_O = get_and_maybe_dequant_weights(self.o_proj)\ - .view(-1, self.num_heads, self.v_head_dim) - W_UV_O = torch.einsum("lnd,hnd -> nlh", W_UV, W_O)\ - .flatten(start_dim=0, end_dim=1).contiguous() - - if is_fp8(weight_dtype) and requantization_enabled: - W_UV_O, W_UV_O_scales = scaled_quantize( - W_UV_O, - self.reqaunt_weight_group_shape, - quant_dtype=current_platform.fp8_dtype()) - # For FP8 save the transpose so we can use - # `apply_w8a8_block_fp8_linear` directly - self.W_UV_O = W_UV_O.T.contiguous() - self.W_UV_O_scales = W_UV_O_scales.T.contiguous() - else: - self.W_UV_O = W_UV_O.to(act_dtype) - - self.tp_size = get_tensor_model_parallel_world_size() - else: - if is_fp8(weight_dtype): - raise NotImplementedError( - "Currently fp8 requires matrix absorption") - - self.W_UV = W_UV - self.W_UK = W_UK - self.W_Q = W_Q.flatten(start_dim=1) + # Convert from (L, N, V) to (N, L, V) + self.W_UV = W_UV.transpose(0, 1) + # Convert from (L, N, P) to (N, P, L) + self.W_UK_T = W_UK.permute(1, 2, 0) def _compute_prefill_context( self, @@ -998,7 +845,7 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]): @abstractmethod def _forward_decode( self, - q_nope: torch.Tensor, + ql_nope: torch.Tensor, q_pe: torch.Tensor, kv_c_and_k_pe_cache: torch.Tensor, attn_metadata: M, @@ -1051,10 +898,8 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]): if has_decode: assert attn_metadata.decode is not None - decode_q_nope = self._q_proj_and_k_up_proj(decode_hs_or_q_c) - decode_q_pe = torch.matmul(decode_hs_or_q_c, self.W_QR)\ - .view(-1, self.num_heads, self.qk_rope_head_dim) - + decode_ql_nope, decode_q_pe = \ + self._q_proj_and_k_up_proj(decode_hs_or_q_c) decode_q_pe[...], decode_k_pe[...] = self.rotary_emb( attn_metadata.decode.input_positions, decode_q_pe.contiguous(), decode_k_pe) @@ -1087,6 +932,6 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]): if has_decode: output[:num_decode_tokens] = self._forward_decode( - decode_q_nope, decode_q_pe, kv_cache, attn_metadata) + decode_ql_nope, decode_q_pe, kv_cache, attn_metadata) return output_padded