[perf] Add fused MLA QKV + strided layernorm (#21116)

Signed-off-by: Mickael Seznec <mickael@mistral.ai>
Co-authored-by: mgoin <mgoin64@gmail.com>
Mickaël Seznec 2025-07-22 16:07:44 +02:00 committed by GitHub
parent 0df4d9b06b
commit 4fb56914c5
7 changed files with 214 additions and 66 deletions

View File

@ -17,13 +17,14 @@ template <typename scalar_t>
__global__ void rms_norm_kernel(
scalar_t* __restrict__ out, // [..., hidden_size]
const scalar_t* __restrict__ input, // [..., hidden_size]
const int64_t input_stride,
const scalar_t* __restrict__ weight, // [hidden_size]
const float epsilon, const int num_tokens, const int hidden_size) {
__shared__ float s_variance;
float variance = 0.0f;
for (int idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) {
const float x = (float)input[blockIdx.x * hidden_size + idx];
const float x = (float)input[blockIdx.x * input_stride + idx];
variance += x * x;
}
@ -37,7 +38,7 @@ __global__ void rms_norm_kernel(
__syncthreads();
for (int idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) {
float x = (float)input[blockIdx.x * hidden_size + idx];
float x = (float)input[blockIdx.x * input_stride + idx];
out[blockIdx.x * hidden_size + idx] =
((scalar_t)(x * s_variance)) * weight[idx];
}
@ -51,6 +52,7 @@ template <typename scalar_t, int width>
__global__ std::enable_if_t<(width > 0) && _typeConvert<scalar_t>::exists>
fused_add_rms_norm_kernel(
scalar_t* __restrict__ input, // [..., hidden_size]
const int64_t input_stride,
scalar_t* __restrict__ residual, // [..., hidden_size]
const scalar_t* __restrict__ weight, // [hidden_size]
const float epsilon, const int num_tokens, const int hidden_size) {
@ -59,6 +61,7 @@ fused_add_rms_norm_kernel(
static_assert(sizeof(_f16Vec<scalar_t, width>) == sizeof(scalar_t) * width);
const int vec_hidden_size = hidden_size / width;
const int64_t vec_input_stride = input_stride / width;
__shared__ float s_variance;
float variance = 0.0f;
/* These and the argument pointers are all declared `restrict` as they are
@ -73,7 +76,8 @@ fused_add_rms_norm_kernel(
for (int idx = threadIdx.x; idx < vec_hidden_size; idx += blockDim.x) {
int id = blockIdx.x * vec_hidden_size + idx;
_f16Vec<scalar_t, width> temp = input_v[id];
int64_t strided_id = blockIdx.x * vec_input_stride + idx;
_f16Vec<scalar_t, width> temp = input_v[strided_id];
temp += residual_v[id];
variance += temp.sum_squares();
residual_v[id] = temp;
@ -90,10 +94,11 @@ fused_add_rms_norm_kernel(
for (int idx = threadIdx.x; idx < vec_hidden_size; idx += blockDim.x) {
int id = blockIdx.x * vec_hidden_size + idx;
int64_t strided_id = blockIdx.x * vec_input_stride + idx;
_f16Vec<scalar_t, width> temp = residual_v[id];
temp *= s_variance;
temp *= weight_v[idx];
input_v[id] = temp;
input_v[strided_id] = temp;
}
}
@ -104,6 +109,7 @@ template <typename scalar_t, int width>
__global__ std::enable_if_t<(width == 0) || !_typeConvert<scalar_t>::exists>
fused_add_rms_norm_kernel(
scalar_t* __restrict__ input, // [..., hidden_size]
const int64_t input_stride,
scalar_t* __restrict__ residual, // [..., hidden_size]
const scalar_t* __restrict__ weight, // [hidden_size]
const float epsilon, const int num_tokens, const int hidden_size) {
@ -111,7 +117,7 @@ fused_add_rms_norm_kernel(
float variance = 0.0f;
for (int idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) {
scalar_t z = input[blockIdx.x * hidden_size + idx];
scalar_t z = input[blockIdx.x * input_stride + idx];
z += residual[blockIdx.x * hidden_size + idx];
float x = (float)z;
variance += x * x;
@ -129,7 +135,7 @@ fused_add_rms_norm_kernel(
for (int idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) {
float x = (float)residual[blockIdx.x * hidden_size + idx];
input[blockIdx.x * hidden_size + idx] =
input[blockIdx.x * input_stride + idx] =
((scalar_t)(x * s_variance)) * weight[idx];
}
}
@ -141,11 +147,12 @@ void rms_norm(torch::Tensor& out, // [..., hidden_size]
torch::Tensor& weight, // [hidden_size]
double epsilon) {
TORCH_CHECK(out.is_contiguous());
TORCH_CHECK(input.is_contiguous());
TORCH_CHECK(input.stride(-1) == 1);
TORCH_CHECK(weight.is_contiguous());
int hidden_size = input.size(-1);
int num_tokens = input.numel() / hidden_size;
int64_t input_stride = input.stride(-2);
dim3 grid(num_tokens);
dim3 block(std::min(hidden_size, 1024));
@ -153,7 +160,7 @@ void rms_norm(torch::Tensor& out, // [..., hidden_size]
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
VLLM_DISPATCH_FLOATING_TYPES(input.scalar_type(), "rms_norm_kernel", [&] {
vllm::rms_norm_kernel<scalar_t><<<grid, block, 0, stream>>>(
out.data_ptr<scalar_t>(), input.data_ptr<scalar_t>(),
out.data_ptr<scalar_t>(), input.data_ptr<scalar_t>(), input_stride,
weight.data_ptr<scalar_t>(), epsilon, num_tokens, hidden_size);
});
}
@ -162,17 +169,20 @@ void rms_norm(torch::Tensor& out, // [..., hidden_size]
VLLM_DISPATCH_FLOATING_TYPES( \
input.scalar_type(), "fused_add_rms_norm_kernel", [&] { \
vllm::fused_add_rms_norm_kernel<scalar_t, width> \
<<<grid, block, 0, stream>>>(input.data_ptr<scalar_t>(), \
residual.data_ptr<scalar_t>(), \
weight.data_ptr<scalar_t>(), epsilon, \
num_tokens, hidden_size); \
<<<grid, block, 0, stream>>>( \
input.data_ptr<scalar_t>(), input_stride, \
residual.data_ptr<scalar_t>(), weight.data_ptr<scalar_t>(), \
epsilon, num_tokens, hidden_size); \
});
void fused_add_rms_norm(torch::Tensor& input, // [..., hidden_size]
torch::Tensor& residual, // [..., hidden_size]
torch::Tensor& weight, // [hidden_size]
double epsilon) {
TORCH_CHECK(residual.is_contiguous());
TORCH_CHECK(weight.is_contiguous());
int hidden_size = input.size(-1);
int64_t input_stride = input.stride(-2);
int num_tokens = input.numel() / hidden_size;
dim3 grid(num_tokens);
@ -194,9 +204,16 @@ void fused_add_rms_norm(torch::Tensor& input, // [..., hidden_size]
auto inp_ptr = reinterpret_cast<std::uintptr_t>(input.data_ptr());
auto res_ptr = reinterpret_cast<std::uintptr_t>(residual.data_ptr());
auto wt_ptr = reinterpret_cast<std::uintptr_t>(weight.data_ptr());
bool ptrs_are_aligned =
inp_ptr % 16 == 0 && res_ptr % 16 == 0 && wt_ptr % 16 == 0;
if (ptrs_are_aligned && hidden_size % 8 == 0) {
constexpr int vector_width = 8;
constexpr int req_alignment_bytes =
vector_width * 2; // vector_width * sizeof(bfloat16 or float16) (float32
// falls back to non-vectorized version anyway)
bool ptrs_are_aligned = inp_ptr % req_alignment_bytes == 0 &&
res_ptr % req_alignment_bytes == 0 &&
wt_ptr % req_alignment_bytes == 0;
bool offsets_are_multiple_of_vector_width =
hidden_size % vector_width == 0 && input_stride % vector_width == 0;
if (ptrs_are_aligned && offsets_are_multiple_of_vector_width) {
LAUNCH_FUSED_ADD_RMS_NORM(8);
} else {
LAUNCH_FUSED_ADD_RMS_NORM(0);

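The changes in this file replace the implicit row pitch (hidden_size) with an explicit input_stride, so rms_norm and fused_add_rms_norm only require a dense last dimension (stride(-1) == 1) rather than a fully contiguous input. A minimal PyTorch sketch, using plain tensors rather than the vLLM ops, of the strided view this enables and of the row pitch the host code reads via input.stride(-2):

import torch

# Illustrative shapes: a wide buffer whose first hidden_size columns are the
# tensor to normalize, mirroring a slice of a fused projection output.
num_tokens, hidden_size = 4, 8
base = torch.randn(num_tokens, 2 * hidden_size)
x = base[..., :hidden_size]                 # non-contiguous view

assert not x.is_contiguous()
assert x.stride(-1) == 1                    # last dim must stay dense
assert x.stride(-2) == 2 * hidden_size      # this is the kernel's input_stride

# Reference RMSNorm on the strided view; the CUDA kernel reads rows at
# blockIdx.x * input_stride + idx and writes a contiguous output.
weight = torch.ones(hidden_size)
variance = x.float().pow(2).mean(dim=-1, keepdim=True)
out = (x.float() * torch.rsqrt(variance + 1e-6) * weight).to(x.dtype)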
View File

@ -25,6 +25,7 @@ template <typename scalar_t, typename fp8_type>
__global__ void rms_norm_static_fp8_quant_kernel(
fp8_type* __restrict__ out, // [..., hidden_size]
const scalar_t* __restrict__ input, // [..., hidden_size]
const int input_stride,
const scalar_t* __restrict__ weight, // [hidden_size]
const float* __restrict__ scale, // [1]
const float epsilon, const int num_tokens, const int hidden_size) {
@ -32,7 +33,7 @@ __global__ void rms_norm_static_fp8_quant_kernel(
float variance = 0.0f;
for (int idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) {
const float x = (float)input[blockIdx.x * hidden_size + idx];
const float x = (float)input[blockIdx.x * input_stride + idx];
variance += x * x;
}
@ -49,7 +50,7 @@ __global__ void rms_norm_static_fp8_quant_kernel(
float const scale_inv = 1.0f / *scale;
for (int idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) {
float x = (float)input[blockIdx.x * hidden_size + idx];
float x = (float)input[blockIdx.x * input_stride + idx];
float const out_norm = ((scalar_t)(x * s_variance)) * weight[idx];
out[blockIdx.x * hidden_size + idx] =
scaled_fp8_conversion<true, fp8_type>(out_norm, scale_inv);
@ -65,6 +66,7 @@ __global__ std::enable_if_t<(width > 0) && _typeConvert<scalar_t>::exists>
fused_add_rms_norm_static_fp8_quant_kernel(
fp8_type* __restrict__ out, // [..., hidden_size]
scalar_t* __restrict__ input, // [..., hidden_size]
const int input_stride,
scalar_t* __restrict__ residual, // [..., hidden_size]
const scalar_t* __restrict__ weight, // [hidden_size]
const float* __restrict__ scale, // [1]
@ -74,6 +76,7 @@ fused_add_rms_norm_static_fp8_quant_kernel(
static_assert(sizeof(_f16Vec<scalar_t, width>) == sizeof(scalar_t) * width);
const int vec_hidden_size = hidden_size / width;
const int vec_input_stride = input_stride / width;
__shared__ float s_variance;
float variance = 0.0f;
/* These and the argument pointers are all declared `restrict` as they are
@ -87,8 +90,9 @@ fused_add_rms_norm_static_fp8_quant_kernel(
reinterpret_cast<const _f16Vec<scalar_t, width>*>(weight);
for (int idx = threadIdx.x; idx < vec_hidden_size; idx += blockDim.x) {
int stride_id = blockIdx.x * vec_input_stride + idx;
int id = blockIdx.x * vec_hidden_size + idx;
_f16Vec<scalar_t, width> temp = input_v[id];
_f16Vec<scalar_t, width> temp = input_v[stride_id];
temp += residual_v[id];
variance += temp.sum_squares();
residual_v[id] = temp;
@ -127,6 +131,7 @@ __global__ std::enable_if_t<(width == 0) || !_typeConvert<scalar_t>::exists>
fused_add_rms_norm_static_fp8_quant_kernel(
fp8_type* __restrict__ out, // [..., hidden_size]
scalar_t* __restrict__ input, // [..., hidden_size]
const int input_stride,
scalar_t* __restrict__ residual, // [..., hidden_size]
const scalar_t* __restrict__ weight, // [hidden_size]
const float* __restrict__ scale, // [1]
@ -135,7 +140,7 @@ fused_add_rms_norm_static_fp8_quant_kernel(
float variance = 0.0f;
for (int idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) {
scalar_t z = input[blockIdx.x * hidden_size + idx];
scalar_t z = input[blockIdx.x * input_stride + idx];
z += residual[blockIdx.x * hidden_size + idx];
float x = (float)z;
variance += x * x;
@ -169,7 +174,9 @@ void rms_norm_static_fp8_quant(torch::Tensor& out, // [..., hidden_size]
torch::Tensor& weight, // [hidden_size]
torch::Tensor& scale, // [1]
double epsilon) {
TORCH_CHECK(out.is_contiguous());
int hidden_size = input.size(-1);
int input_stride = input.stride(-2);
int num_tokens = input.numel() / hidden_size;
dim3 grid(num_tokens);
@ -183,8 +190,9 @@ void rms_norm_static_fp8_quant(torch::Tensor& out, // [..., hidden_size]
vllm::rms_norm_static_fp8_quant_kernel<scalar_t, fp8_t>
<<<grid, block, 0, stream>>>(
out.data_ptr<fp8_t>(), input.data_ptr<scalar_t>(),
weight.data_ptr<scalar_t>(), scale.data_ptr<float>(),
epsilon, num_tokens, hidden_size);
input_stride, weight.data_ptr<scalar_t>(),
scale.data_ptr<float>(), epsilon, num_tokens,
hidden_size);
});
});
}
@ -198,7 +206,7 @@ void rms_norm_static_fp8_quant(torch::Tensor& out, // [..., hidden_size]
width, fp8_t> \
<<<grid, block, 0, stream>>>( \
out.data_ptr<fp8_t>(), input.data_ptr<scalar_t>(), \
residual.data_ptr<scalar_t>(), \
input_stride, residual.data_ptr<scalar_t>(), \
weight.data_ptr<scalar_t>(), scale.data_ptr<float>(), \
epsilon, num_tokens, hidden_size); \
}); \
@ -210,7 +218,10 @@ void fused_add_rms_norm_static_fp8_quant(
torch::Tensor& weight, // [hidden_size]
torch::Tensor& scale, // [1]
double epsilon) {
TORCH_CHECK(out.is_contiguous());
TORCH_CHECK(residual.is_contiguous());
int hidden_size = input.size(-1);
int input_stride = input.stride(-2);
int num_tokens = input.numel() / hidden_size;
dim3 grid(num_tokens);
@ -234,7 +245,7 @@ void fused_add_rms_norm_static_fp8_quant(
auto wt_ptr = reinterpret_cast<std::uintptr_t>(weight.data_ptr());
bool ptrs_are_aligned =
inp_ptr % 16 == 0 && res_ptr % 16 == 0 && wt_ptr % 16 == 0;
if (ptrs_are_aligned && hidden_size % 8 == 0) {
if (ptrs_are_aligned && hidden_size % 8 == 0 && input_stride % 8 == 0) {
LAUNCH_FUSED_ADD_RMS_NORM(8);
} else {
LAUNCH_FUSED_ADD_RMS_NORM(0);

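The launchers in this file (and the one above) only take the 8-wide vectorized path when the pointers are 16-byte aligned and both hidden_size and the row stride are multiples of the vector width, so every row of a strided input starts on a vector boundary. A rough Python restatement of that dispatch condition; the function name and arguments are illustrative, not vLLM API:

def can_use_vectorized_kernel(inp_ptr: int, res_ptr: int, wt_ptr: int,
                              hidden_size: int, input_stride: int,
                              vector_width: int = 8,
                              elem_size_bytes: int = 2) -> bool:
    # 8-wide vectors of fp16/bf16 span 16 bytes, so the pointers must be
    # 16-byte aligned; hidden_size and input_stride must both be divisible
    # by the vector width so each row begins on a vector boundary.
    alignment = vector_width * elem_size_bytes
    ptrs_ok = all(p % alignment == 0 for p in (inp_ptr, res_ptr, wt_ptr))
    shape_ok = (hidden_size % vector_width == 0
                and input_stride % vector_width == 0)
    return ptrs_ok and shape_ok

# Example: a 512-wide slice taken out of a 576-wide fused output still
# qualifies, because both 512 and 576 are multiples of 8.
assert can_use_vectorized_kernel(0x1000, 0x2000, 0x3000,
                                 hidden_size=512, input_stride=576)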
View File

@ -88,6 +88,8 @@ void static_scaled_fp8_quant(torch::Tensor& out, // [..., d]
torch::Tensor const& input, // [..., d]
torch::Tensor const& scale) // [1]
{
TORCH_CHECK(input.is_contiguous());
TORCH_CHECK(out.is_contiguous());
int const block_size = 256;
int const num_tokens = input.numel() / input.size(-1);
int const num_elems = input.numel();
@ -111,6 +113,8 @@ void dynamic_scaled_fp8_quant(torch::Tensor& out, // [..., d]
torch::Tensor const& input, // [..., d]
torch::Tensor& scale) // [1]
{
TORCH_CHECK(input.is_contiguous());
TORCH_CHECK(out.is_contiguous());
int const block_size = 256;
int const num_tokens = input.numel() / input.size(-1);
int const num_elems = input.numel();

View File

@ -26,6 +26,7 @@ CUDA_DEVICES = [
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.parametrize("device", CUDA_DEVICES)
@pytest.mark.parametrize("strided_input", [False, True])
@torch.inference_mode()
def test_rms_norm(
num_tokens: int,
@ -34,13 +35,17 @@ def test_rms_norm(
dtype: torch.dtype,
seed: int,
device: str,
strided_input: bool,
) -> None:
current_platform.seed_everything(seed)
torch.set_default_device(device)
layer = RMSNorm(hidden_size).to(dtype=dtype)
layer.weight.data.normal_(mean=1.0, std=0.1)
scale = 1 / (2 * hidden_size)
x = torch.randn(num_tokens, hidden_size, dtype=dtype)
last_dim = 2 * hidden_size if strided_input else hidden_size
x = torch.randn(num_tokens, last_dim, dtype=dtype)
x = x[..., :hidden_size]
assert x.is_contiguous() != strided_input
x *= scale
residual = torch.randn_like(x) * scale if add_residual else None
@ -72,6 +77,7 @@ def test_rms_norm(
@pytest.mark.parametrize("quant_scale", [1.0, 0.01, 10.0])
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.parametrize("device", CUDA_DEVICES)
@pytest.mark.parametrize("strided_input", [False, True])
def test_fused_rms_norm_quant(
num_tokens: int,
hidden_size: int,
@ -80,13 +86,18 @@ def test_fused_rms_norm_quant(
quant_scale: float,
seed: int,
device: str,
strided_input: bool,
) -> None:
current_platform.seed_everything(seed)
torch.set_default_device(device)
weight = torch.empty(hidden_size, dtype=dtype).normal_(mean=1.0, std=0.1)
scale = 1 / (2 * hidden_size)
x = torch.randn(num_tokens, hidden_size, dtype=dtype)
last_dim = 2 * hidden_size if strided_input else hidden_size
x_base = torch.randn(num_tokens, last_dim, dtype=dtype)
x = x_base[..., :hidden_size]
assert x.is_contiguous() != strided_input
x *= scale
if add_residual:
residual = torch.randn_like(x) * scale
@ -106,9 +117,11 @@ def test_fused_rms_norm_quant(
# Unfused kernel is in-place so it goes second
# Also use a separate clone of x to avoid modifying the input
x_unfused = x.clone()
x_unfused_base = x_base.clone()
x_unfused = x_unfused_base[..., :hidden_size]
assert x_unfused.is_contiguous() != strided_input
torch.ops._C.fused_add_rms_norm(x_unfused, residual, weight, 1e-6)
torch.ops._C.static_scaled_fp8_quant(out_quant, x_unfused,
torch.ops._C.static_scaled_fp8_quant(out_quant, x_unfused.contiguous(),
quant_scale_t)
torch.cuda.synchronize()
@ -116,7 +129,6 @@ def test_fused_rms_norm_quant(
residual,
atol=1e-2,
rtol=1e-2)
opcheck(
torch.ops._C.fused_add_rms_norm_static_fp8_quant,
(out_quant_fused, x, residual_fused, weight, quant_scale_t, 1e-6))
@ -131,7 +143,7 @@ def test_fused_rms_norm_quant(
opcheck(torch.ops._C.rms_norm_static_fp8_quant,
(out_quant_fused, x, weight, quant_scale_t, 1e-6))
torch.testing.assert_close(out_quant_fused.to(dtype=torch.float32),
out_quant.to(dtype=torch.float32),
torch.testing.assert_close(out_quant.to(dtype=torch.float32),
out_quant_fused.to(dtype=torch.float32),
atol=1e-3,
rtol=1e-3)

View File

@ -259,6 +259,8 @@ class LinearBase(torch.nn.Module):
if params_dtype is None:
params_dtype = torch.get_default_dtype()
self.params_dtype = params_dtype
self.quant_config = quant_config
self.prefix = prefix
if quant_config is None:
self.quant_method: Optional[
QuantizeMethodBase] = UnquantizedLinearMethod()
@ -300,6 +302,12 @@ class ReplicatedLinear(LinearBase):
*,
return_bias: bool = True,
):
# If MergedReplicatedLinear, use output size of each partition.
if hasattr(self, "output_sizes"):
self.output_partition_sizes = self.output_sizes
else:
self.output_partition_sizes = [output_size]
super().__init__(input_size,
output_size,
skip_bias_add,
@ -311,7 +319,8 @@ class ReplicatedLinear(LinearBase):
# All the linear layer supports quant method.
assert self.quant_method is not None
self.quant_method.create_weights(self,
self.input_size, [self.output_size],
self.input_size,
self.output_partition_sizes,
self.input_size,
self.output_size,
self.params_dtype,
@ -367,6 +376,73 @@ class ReplicatedLinear(LinearBase):
return s
class MergedReplicatedLinear(ReplicatedLinear):
"""Replicated linear layer.
Args:
input_size: input dimension of the linear layer.
output_size: output dimension of the linear layer.
bias: If true, add bias.
skip_bias_add: If true, skip adding bias but instead return it.
params_dtype: Data type for the parameters.
quant_config: Quantization configure.
prefix: The name of the layer in the state dict, including all parents
(e.g. model.layers.0.qkv_proj)
"""
def __init__(
self,
input_size: int,
output_sizes: list[int],
bias: bool = True,
skip_bias_add: bool = False,
params_dtype: Optional[torch.dtype] = None,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
*,
return_bias: bool = True,
):
self.output_sizes = output_sizes
super().__init__(input_size,
sum(output_sizes),
bias,
skip_bias_add,
params_dtype,
quant_config,
prefix=prefix,
return_bias=return_bias)
def weight_loader(self,
param: Union[Parameter, BasevLLMParameter],
loaded_weight: torch.Tensor,
loaded_shard_id: Optional[int] = None):
assert loaded_shard_id is not None
assert loaded_shard_id < len(self.output_sizes)
if isinstance(param, BlockQuantScaleParameter):
from vllm.model_executor.layers.quantization.fp8 import (
Fp8LinearMethod, Fp8MoEMethod)
assert self.quant_method is not None
assert isinstance(self.quant_method,
(Fp8LinearMethod, Fp8MoEMethod))
weight_block_size = self.quant_method.quant_config.weight_block_size
assert weight_block_size is not None
block_n, _ = weight_block_size[0], weight_block_size[1]
shard_offset = (
(sum(self.output_sizes[:loaded_shard_id]) + block_n - 1) //
block_n)
shard_size = ((self.output_sizes[loaded_shard_id] + block_n - 1) //
block_n)
elif isinstance(param, PerTensorScaleParameter):
shard_offset = loaded_shard_id
shard_size = 1
else:
shard_offset = sum(self.output_sizes[:loaded_shard_id])
shard_size = self.output_sizes[loaded_shard_id]
param[shard_offset:shard_offset + shard_size] = loaded_weight
class ColumnParallelLinear(LinearBase):
"""Linear layer with column parallelism.

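MergedReplicatedLinear concatenates several replicated projections into a single GEMM, and its weight_loader copies each checkpoint shard into the corresponding slice of the merged weight. A hedged sketch of the unquantized path of that offset arithmetic on plain tensors (shapes are DeepSeek-V2-like but illustrative; this is not the vLLM class itself):

import torch

hidden_size = 1024
output_sizes = [1536, 512 + 64]   # e.g. q_lora_rank and kv_lora_rank + rope dim
merged_weight = torch.empty(sum(output_sizes), hidden_size)

def load_shard(merged: torch.Tensor, loaded: torch.Tensor, shard_id: int) -> None:
    # Unquantized path: shard shard_id occupies rows
    # [sum(sizes[:id]), sum(sizes[:id]) + sizes[id]) of the merged weight.
    shard_offset = sum(output_sizes[:shard_id])
    shard_size = output_sizes[shard_id]
    merged[shard_offset:shard_offset + shard_size] = loaded

load_shard(merged_weight, torch.randn(output_sizes[0], hidden_size), 0)
load_shard(merged_weight, torch.randn(output_sizes[1], hidden_size), 1)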
View File

@ -257,9 +257,16 @@ class Fp8LinearMethod(LinearMethodBase):
f"{input_size_per_partition} is not divisible by "
f"weight quantization block_k = {block_k}.")
# Required by column parallel or enabling merged weights
if (tp_size > 1 and output_size // output_size_per_partition
== tp_size) or len(output_partition_sizes) > 1:
for output_partition_size in output_partition_sizes:
is_tp_split = (tp_size > 1 and
output_size // output_size_per_partition == tp_size)
is_merged_gemm = len(output_partition_sizes) > 1
if is_tp_split or is_merged_gemm:
sizes_to_check = output_partition_sizes
if not is_tp_split and is_merged_gemm:
# In case of merged matrices, we allow the last
# matrix to not be a multiple of block size
sizes_to_check = output_partition_sizes[:-1]
for output_partition_size in sizes_to_check:
if output_partition_size % block_n != 0:
raise ValueError(
f"Weight output_partition_size = "

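This change lets a merged, non-TP-split weight load even when its last partition is not a multiple of the fp8 block size; only the earlier partitions still have to align to block_n. A small standalone sketch of the relaxed check (values are illustrative, not the full Fp8LinearMethod):

def check_block_divisibility(output_partition_sizes: list[int],
                             block_n: int, is_tp_split: bool) -> None:
    # For merged weights that are not TP-split, the last partition is allowed
    # to be a non-multiple of block_n, so it is skipped here.
    is_merged_gemm = len(output_partition_sizes) > 1
    if not (is_tp_split or is_merged_gemm):
        return
    sizes_to_check = output_partition_sizes
    if not is_tp_split and is_merged_gemm:
        sizes_to_check = output_partition_sizes[:-1]
    for size in sizes_to_check:
        if size % block_n != 0:
            raise ValueError(f"output_partition_size {size} is not divisible "
                             f"by block_n = {block_n}")

# e.g. fused q_a + kv_a with 128x128 blocks: only the first partition (1536)
# must be a multiple of 128; the trailing 576 is allowed.
check_block_divisibility([1536, 576], block_n=128, is_tp_split=False)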
View File

@ -42,6 +42,7 @@ from vllm.model_executor.layers.fused_moe import FusedMoE
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
MergedColumnParallelLinear,
MergedReplicatedLinear,
ReplicatedLinear,
RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
@ -336,7 +337,7 @@ class DeepseekV2Attention(nn.Module):
kv_a, _ = latent_cache.split(
[self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
latent_cache = latent_cache.unsqueeze(1)
kv_a = self.kv_a_layernorm(kv_a.contiguous())
kv_a = self.kv_a_layernorm(kv_a)
kv = self.kv_b_proj(kv_a)[0]
kv = kv.view(-1, self.num_local_heads,
self.qk_nope_head_dim + self.v_head_dim)
@ -407,14 +408,24 @@ class DeepseekV2MLAAttention(nn.Module):
self.max_position_embeddings = max_position_embeddings
if self.q_lora_rank is not None:
self.q_a_proj = ReplicatedLinear(self.hidden_size,
self.q_lora_rank,
self.fused_qkv_a_proj = MergedReplicatedLinear(
self.hidden_size,
[self.q_lora_rank, self.kv_lora_rank + self.qk_rope_head_dim],
bias=False,
quant_config=quant_config,
prefix=f"{prefix}.q_a_proj")
prefix=f"{prefix}.fused_qkv_a_proj")
else:
self.kv_a_proj_with_mqa = ReplicatedLinear(
self.hidden_size,
self.kv_lora_rank + self.qk_rope_head_dim,
bias=False,
quant_config=quant_config,
prefix=f"{prefix}.kv_a_proj_with_mqa")
if self.q_lora_rank is not None:
self.q_a_layernorm = RMSNorm(self.q_lora_rank,
eps=config.rms_norm_eps)
self.q_b_proj = ColumnParallelLinear(q_lora_rank,
self.q_b_proj = ColumnParallelLinear(self.q_lora_rank,
self.num_heads *
self.qk_head_dim,
bias=False,
@ -427,13 +438,6 @@ class DeepseekV2MLAAttention(nn.Module):
bias=False,
quant_config=quant_config,
prefix=f"{prefix}.q_proj")
self.kv_a_proj_with_mqa = ReplicatedLinear(
self.hidden_size,
self.kv_lora_rank + self.qk_rope_head_dim,
bias=False,
quant_config=quant_config,
prefix=f"{prefix}.kv_a_proj_with_mqa")
self.kv_a_layernorm = RMSNorm(self.kv_lora_rank,
eps=config.rms_norm_eps)
self.kv_b_proj = ColumnParallelLinear(
@ -495,15 +499,24 @@ class DeepseekV2MLAAttention(nn.Module):
positions: torch.Tensor,
hidden_states: torch.Tensor,
) -> torch.Tensor:
q_c = None
kv_lora = None
if self.q_lora_rank is not None:
q_c = self.q_a_proj(hidden_states)[0]
qkv_lora = self.fused_qkv_a_proj(hidden_states)[0]
q_c, kv_lora = qkv_lora.split(
[self.q_lora_rank, self.kv_lora_rank + self.qk_rope_head_dim],
dim=-1,
)
q_c = self.q_a_layernorm(q_c)
q = self.q_b_proj(q_c)[0]
else:
kv_lora = self.kv_a_proj_with_mqa(hidden_states)[0]
q = self.q_proj(hidden_states)[0]
kv_c, k_pe = self.kv_a_proj_with_mqa(hidden_states)[0].split(
[self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
kv_c_normed = self.kv_a_layernorm(kv_c.contiguous())
kv_c, k_pe = kv_lora.split([self.kv_lora_rank, self.qk_rope_head_dim],
dim=-1)
kv_c_normed = self.kv_a_layernorm(kv_c)
q = q.view(-1, self.num_local_heads, self.qk_head_dim)
# Add head dim of 1 to k_pe
@ -837,6 +850,8 @@ class DeepseekV2ForCausalLM(nn.Module, SupportsPP, MixtureOfExperts):
# (param_name, shard_name, shard_id)
("gate_up_proj", "gate_proj", 0),
("gate_up_proj", "up_proj", 1),
("fused_qkv_a_proj", "q_a_proj", 0),
("fused_qkv_a_proj", "kv_a_proj_with_mqa", 1),
]
# Params for weights, fp8 weight scales, fp8 activation scales
@ -871,6 +886,12 @@ class DeepseekV2ForCausalLM(nn.Module, SupportsPP, MixtureOfExperts):
if (("mlp.experts." in name) and name not in params_dict):
continue
name = name.replace(weight_name, param_name)
# QKV fusion is optional, fall back to normal
# weight loading if it's not enabled
if ((param_name == "fused_qkv_a_proj")
and name not in params_dict):
continue
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
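With the fused projection in place, the new forward pass runs a single down-projection GEMM and splits its output into the q and kv latents; the kv slice is a strided view, which is what the stride-aware RMSNorm kernels above make legal without the old .contiguous() copy. A shape-level sketch with plain tensors (sizes are illustrative DeepSeek-V2-like values, not read from a config):

import torch

hidden_size, q_lora_rank = 1024, 1536
kv_lora_rank, qk_rope_head_dim = 512, 64
hidden_states = torch.randn(3, hidden_size)

# Stand-in for fused_qkv_a_proj: one weight covering both down-projections.
w_fused = torch.randn(q_lora_rank + kv_lora_rank + qk_rope_head_dim, hidden_size)
qkv_lora = hidden_states @ w_fused.t()

# Same splits as the new forward(): first q_c vs kv_lora, then kv_c vs k_pe.
q_c, kv_lora = qkv_lora.split([q_lora_rank, kv_lora_rank + qk_rope_head_dim],
                              dim=-1)
kv_c, k_pe = kv_lora.split([kv_lora_rank, qk_rope_head_dim], dim=-1)

# kv_c is a non-contiguous view with a dense last dimension, so it can be fed
# to kv_a_layernorm directly under the strided-kernel changes.
assert not kv_c.is_contiguous() and kv_c.stride(-1) == 1
assert kv_c.stride(-2) == q_lora_rank + kv_lora_rank + qk_rope_head_dim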