mirror of https://git.datalinker.icu/vllm-project/vllm.git (synced 2025-12-11 00:05:52 +08:00)

[perf] Add fused MLA QKV + strided layernorm (#21116)

Signed-off-by: Mickael Seznec <mickael@mistral.ai>
Co-authored-by: mgoin <mgoin64@gmail.com>

This commit is contained in:
parent 0df4d9b06b
commit 4fb56914c5
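In short: the DeepSeek-V2 MLA attention path can now run its q_a and kv_a down-projections as a single fused GEMM (the new MergedReplicatedLinear), and the RMSNorm kernels gain an explicit input_stride argument so the split views of that fused output can be normalized without .contiguous() copies. As a minimal sketch of why the split produces strided views (the sizes below are made up for illustration and are not taken from the commit):

    import torch

    # Hypothetical small sizes, for illustration only.
    num_tokens, q_lora_rank, kv_lora_rank, qk_rope_head_dim = 4, 16, 8, 4

    # One fused projection output, then split into the q and kv latent parts,
    # mirroring fused_qkv_a_proj followed by torch.split in the model code.
    qkv_lora = torch.randn(num_tokens,
                           q_lora_rank + kv_lora_rank + qk_rope_head_dim)
    q_c, kv_lora = qkv_lora.split(
        [q_lora_rank, kv_lora_rank + qk_rope_head_dim], dim=-1)

    # The split results are views into qkv_lora: the last dim is dense, but
    # the row stride is the full fused width, so neither view is contiguous.
    assert q_c.stride(-1) == 1 and not q_c.is_contiguous()
    assert kv_lora.stride(-2) == qkv_lora.size(-1)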
@@ -15,15 +15,16 @@ namespace vllm {
 // TODO(woosuk): Further optimize this kernel.
 template <typename scalar_t>
 __global__ void rms_norm_kernel(
     scalar_t* __restrict__ out,          // [..., hidden_size]
     const scalar_t* __restrict__ input,  // [..., hidden_size]
+    const int64_t input_stride,
     const scalar_t* __restrict__ weight,  // [hidden_size]
     const float epsilon, const int num_tokens, const int hidden_size) {
   __shared__ float s_variance;
   float variance = 0.0f;

   for (int idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) {
-    const float x = (float)input[blockIdx.x * hidden_size + idx];
+    const float x = (float)input[blockIdx.x * input_stride + idx];
     variance += x * x;
   }

@@ -37,7 +38,7 @@ __global__ void rms_norm_kernel(
   __syncthreads();

   for (int idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) {
-    float x = (float)input[blockIdx.x * hidden_size + idx];
+    float x = (float)input[blockIdx.x * input_stride + idx];
     out[blockIdx.x * hidden_size + idx] =
         ((scalar_t)(x * s_variance)) * weight[idx];
   }

@@ -50,7 +51,8 @@ __global__ void rms_norm_kernel(
 template <typename scalar_t, int width>
 __global__ std::enable_if_t<(width > 0) && _typeConvert<scalar_t>::exists>
 fused_add_rms_norm_kernel(
     scalar_t* __restrict__ input,     // [..., hidden_size]
+    const int64_t input_stride,
     scalar_t* __restrict__ residual,  // [..., hidden_size]
     const scalar_t* __restrict__ weight,  // [hidden_size]
     const float epsilon, const int num_tokens, const int hidden_size) {

@@ -59,6 +61,7 @@ fused_add_rms_norm_kernel(
   static_assert(sizeof(_f16Vec<scalar_t, width>) == sizeof(scalar_t) * width);

   const int vec_hidden_size = hidden_size / width;
+  const int64_t vec_input_stride = input_stride / width;
   __shared__ float s_variance;
   float variance = 0.0f;
   /* These and the argument pointers are all declared `restrict` as they are

@@ -73,7 +76,8 @@ fused_add_rms_norm_kernel(

   for (int idx = threadIdx.x; idx < vec_hidden_size; idx += blockDim.x) {
     int id = blockIdx.x * vec_hidden_size + idx;
-    _f16Vec<scalar_t, width> temp = input_v[id];
+    int64_t strided_id = blockIdx.x * vec_input_stride + idx;
+    _f16Vec<scalar_t, width> temp = input_v[strided_id];
     temp += residual_v[id];
     variance += temp.sum_squares();
     residual_v[id] = temp;

@@ -90,10 +94,11 @@ fused_add_rms_norm_kernel(

   for (int idx = threadIdx.x; idx < vec_hidden_size; idx += blockDim.x) {
     int id = blockIdx.x * vec_hidden_size + idx;
+    int64_t strided_id = blockIdx.x * vec_input_stride + idx;
     _f16Vec<scalar_t, width> temp = residual_v[id];
     temp *= s_variance;
     temp *= weight_v[idx];
-    input_v[id] = temp;
+    input_v[strided_id] = temp;
   }
 }

@@ -103,7 +108,8 @@ fused_add_rms_norm_kernel(
 template <typename scalar_t, int width>
 __global__ std::enable_if_t<(width == 0) || !_typeConvert<scalar_t>::exists>
 fused_add_rms_norm_kernel(
     scalar_t* __restrict__ input,     // [..., hidden_size]
+    const int64_t input_stride,
     scalar_t* __restrict__ residual,  // [..., hidden_size]
     const scalar_t* __restrict__ weight,  // [hidden_size]
     const float epsilon, const int num_tokens, const int hidden_size) {

@@ -111,7 +117,7 @@ fused_add_rms_norm_kernel(
   float variance = 0.0f;

   for (int idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) {
-    scalar_t z = input[blockIdx.x * hidden_size + idx];
+    scalar_t z = input[blockIdx.x * input_stride + idx];
     z += residual[blockIdx.x * hidden_size + idx];
     float x = (float)z;
     variance += x * x;

@@ -129,7 +135,7 @@ fused_add_rms_norm_kernel(

   for (int idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) {
     float x = (float)residual[blockIdx.x * hidden_size + idx];
-    input[blockIdx.x * hidden_size + idx] =
+    input[blockIdx.x * input_stride + idx] =
         ((scalar_t)(x * s_variance)) * weight[idx];
   }
 }

@@ -141,11 +147,12 @@ void rms_norm(torch::Tensor& out, // [..., hidden_size]
               torch::Tensor& weight,  // [hidden_size]
               double epsilon) {
   TORCH_CHECK(out.is_contiguous());
-  TORCH_CHECK(input.is_contiguous());
+  TORCH_CHECK(input.stride(-1) == 1);
   TORCH_CHECK(weight.is_contiguous());

   int hidden_size = input.size(-1);
   int num_tokens = input.numel() / hidden_size;
+  int64_t input_stride = input.stride(-2);

   dim3 grid(num_tokens);
   dim3 block(std::min(hidden_size, 1024));

@@ -153,26 +160,29 @@ void rms_norm(torch::Tensor& out, // [..., hidden_size]
   const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
   VLLM_DISPATCH_FLOATING_TYPES(input.scalar_type(), "rms_norm_kernel", [&] {
     vllm::rms_norm_kernel<scalar_t><<<grid, block, 0, stream>>>(
-        out.data_ptr<scalar_t>(), input.data_ptr<scalar_t>(),
+        out.data_ptr<scalar_t>(), input.data_ptr<scalar_t>(), input_stride,
         weight.data_ptr<scalar_t>(), epsilon, num_tokens, hidden_size);
   });
 }

 #define LAUNCH_FUSED_ADD_RMS_NORM(width) \
   VLLM_DISPATCH_FLOATING_TYPES( \
       input.scalar_type(), "fused_add_rms_norm_kernel", [&] { \
         vllm::fused_add_rms_norm_kernel<scalar_t, width> \
-            <<<grid, block, 0, stream>>>(input.data_ptr<scalar_t>(), \
-                                         residual.data_ptr<scalar_t>(), \
-                                         weight.data_ptr<scalar_t>(), epsilon, \
-                                         num_tokens, hidden_size); \
+            <<<grid, block, 0, stream>>>( \
+                input.data_ptr<scalar_t>(), input_stride, \
+                residual.data_ptr<scalar_t>(), weight.data_ptr<scalar_t>(), \
+                epsilon, num_tokens, hidden_size); \
       });

 void fused_add_rms_norm(torch::Tensor& input,     // [..., hidden_size]
                         torch::Tensor& residual,  // [..., hidden_size]
                         torch::Tensor& weight,    // [hidden_size]
                         double epsilon) {
+  TORCH_CHECK(residual.is_contiguous());
+  TORCH_CHECK(weight.is_contiguous());
   int hidden_size = input.size(-1);
+  int64_t input_stride = input.stride(-2);
   int num_tokens = input.numel() / hidden_size;

   dim3 grid(num_tokens);

@@ -194,9 +204,16 @@ void fused_add_rms_norm(torch::Tensor& input, // [..., hidden_size]
   auto inp_ptr = reinterpret_cast<std::uintptr_t>(input.data_ptr());
   auto res_ptr = reinterpret_cast<std::uintptr_t>(residual.data_ptr());
   auto wt_ptr = reinterpret_cast<std::uintptr_t>(weight.data_ptr());
-  bool ptrs_are_aligned =
-      inp_ptr % 16 == 0 && res_ptr % 16 == 0 && wt_ptr % 16 == 0;
-  if (ptrs_are_aligned && hidden_size % 8 == 0) {
+  constexpr int vector_width = 8;
+  constexpr int req_alignment_bytes =
+      vector_width * 2;  // vector_width * sizeof(bfloat16 or float16) (float32
+                         // falls back to non-vectorized version anyway)
+  bool ptrs_are_aligned = inp_ptr % req_alignment_bytes == 0 &&
+                          res_ptr % req_alignment_bytes == 0 &&
+                          wt_ptr % req_alignment_bytes == 0;
+  bool offsets_are_multiple_of_vector_width =
+      hidden_size % vector_width == 0 && input_stride % vector_width == 0;
+  if (ptrs_are_aligned && offsets_are_multiple_of_vector_width) {
     LAUNCH_FUSED_ADD_RMS_NORM(8);
   } else {
     LAUNCH_FUSED_ADD_RMS_NORM(0);
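For reference, a rough PyTorch emulation (not part of the commit; the helper name and eps value are placeholders) of what the strided rms_norm path above computes: block blockIdx.x reads its row starting at offset blockIdx.x * input_stride in the flat input buffer and writes a dense output row.

    import torch

    def rms_norm_strided_ref(inp_flat: torch.Tensor, weight: torch.Tensor,
                             num_tokens: int, hidden_size: int,
                             input_stride: int, eps: float = 1e-6):
        # Row i starts at offset i * input_stride; only the first hidden_size
        # elements of each row take part in the normalization.
        x = torch.as_strided(inp_flat, (num_tokens, hidden_size),
                             (input_stride, 1)).float()
        inv_rms = torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + eps)
        # Dense [num_tokens, hidden_size] result, like the kernel's `out`.
        return (x * inv_rms).to(weight.dtype) * weight

    # Example: rows live in a buffer that is twice as wide as hidden_size.
    hidden_size, num_tokens = 8, 3
    buf = torch.randn(num_tokens, 2 * hidden_size)
    out = rms_norm_strided_ref(buf.flatten(), torch.ones(hidden_size),
                               num_tokens, hidden_size, buf.stride(0))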
@@ -23,8 +23,9 @@ namespace vllm {
 // TODO(woosuk): Further optimize this kernel.
 template <typename scalar_t, typename fp8_type>
 __global__ void rms_norm_static_fp8_quant_kernel(
     fp8_type* __restrict__ out,          // [..., hidden_size]
     const scalar_t* __restrict__ input,  // [..., hidden_size]
+    const int input_stride,
     const scalar_t* __restrict__ weight,  // [hidden_size]
     const float* __restrict__ scale,      // [1]
     const float epsilon, const int num_tokens, const int hidden_size) {

@@ -32,7 +33,7 @@ __global__ void rms_norm_static_fp8_quant_kernel(
   float variance = 0.0f;

   for (int idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) {
-    const float x = (float)input[blockIdx.x * hidden_size + idx];
+    const float x = (float)input[blockIdx.x * input_stride + idx];
     variance += x * x;
   }

@@ -49,7 +50,7 @@ __global__ void rms_norm_static_fp8_quant_kernel(
   float const scale_inv = 1.0f / *scale;

   for (int idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) {
-    float x = (float)input[blockIdx.x * hidden_size + idx];
+    float x = (float)input[blockIdx.x * input_stride + idx];
     float const out_norm = ((scalar_t)(x * s_variance)) * weight[idx];
     out[blockIdx.x * hidden_size + idx] =
         scaled_fp8_conversion<true, fp8_type>(out_norm, scale_inv);

@@ -63,8 +64,9 @@ __global__ void rms_norm_static_fp8_quant_kernel(
 template <typename scalar_t, int width, typename fp8_type>
 __global__ std::enable_if_t<(width > 0) && _typeConvert<scalar_t>::exists>
 fused_add_rms_norm_static_fp8_quant_kernel(
     fp8_type* __restrict__ out,       // [..., hidden_size]
     scalar_t* __restrict__ input,     // [..., hidden_size]
+    const int input_stride,
     scalar_t* __restrict__ residual,  // [..., hidden_size]
     const scalar_t* __restrict__ weight,  // [hidden_size]
     const float* __restrict__ scale,      // [1]

@@ -74,6 +76,7 @@ fused_add_rms_norm_static_fp8_quant_kernel(
   static_assert(sizeof(_f16Vec<scalar_t, width>) == sizeof(scalar_t) * width);

   const int vec_hidden_size = hidden_size / width;
+  const int vec_input_stride = input_stride / width;
   __shared__ float s_variance;
   float variance = 0.0f;
   /* These and the argument pointers are all declared `restrict` as they are

@@ -87,8 +90,9 @@ fused_add_rms_norm_static_fp8_quant_kernel(
       reinterpret_cast<const _f16Vec<scalar_t, width>*>(weight);

   for (int idx = threadIdx.x; idx < vec_hidden_size; idx += blockDim.x) {
+    int stride_id = blockIdx.x * vec_input_stride + idx;
     int id = blockIdx.x * vec_hidden_size + idx;
-    _f16Vec<scalar_t, width> temp = input_v[id];
+    _f16Vec<scalar_t, width> temp = input_v[stride_id];
     temp += residual_v[id];
     variance += temp.sum_squares();
     residual_v[id] = temp;

@@ -125,8 +129,9 @@ fused_add_rms_norm_static_fp8_quant_kernel(
 template <typename scalar_t, int width, typename fp8_type>
 __global__ std::enable_if_t<(width == 0) || !_typeConvert<scalar_t>::exists>
 fused_add_rms_norm_static_fp8_quant_kernel(
     fp8_type* __restrict__ out,       // [..., hidden_size]
     scalar_t* __restrict__ input,     // [..., hidden_size]
+    const int input_stride,
     scalar_t* __restrict__ residual,  // [..., hidden_size]
     const scalar_t* __restrict__ weight,  // [hidden_size]
     const float* __restrict__ scale,      // [1]

@@ -135,7 +140,7 @@ fused_add_rms_norm_static_fp8_quant_kernel(
   float variance = 0.0f;

   for (int idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) {
-    scalar_t z = input[blockIdx.x * hidden_size + idx];
+    scalar_t z = input[blockIdx.x * input_stride + idx];
     z += residual[blockIdx.x * hidden_size + idx];
     float x = (float)z;
     variance += x * x;

@@ -169,7 +174,9 @@ void rms_norm_static_fp8_quant(torch::Tensor& out, // [..., hidden_size]
                                torch::Tensor& weight,  // [hidden_size]
                                torch::Tensor& scale,   // [1]
                                double epsilon) {
+  TORCH_CHECK(out.is_contiguous());
   int hidden_size = input.size(-1);
+  int input_stride = input.stride(-2);
   int num_tokens = input.numel() / hidden_size;

   dim3 grid(num_tokens);

@@ -183,8 +190,9 @@ void rms_norm_static_fp8_quant(torch::Tensor& out, // [..., hidden_size]
       vllm::rms_norm_static_fp8_quant_kernel<scalar_t, fp8_t>
           <<<grid, block, 0, stream>>>(
               out.data_ptr<fp8_t>(), input.data_ptr<scalar_t>(),
-              weight.data_ptr<scalar_t>(), scale.data_ptr<float>(),
-              epsilon, num_tokens, hidden_size);
+              input_stride, weight.data_ptr<scalar_t>(),
+              scale.data_ptr<float>(), epsilon, num_tokens,
+              hidden_size);
     });
   });
 }

@@ -198,7 +206,7 @@ void rms_norm_static_fp8_quant(torch::Tensor& out, // [..., hidden_size]
                                                      width, fp8_t> \
             <<<grid, block, 0, stream>>>( \
                 out.data_ptr<fp8_t>(), input.data_ptr<scalar_t>(), \
-                residual.data_ptr<scalar_t>(), \
+                input_stride, residual.data_ptr<scalar_t>(), \
                 weight.data_ptr<scalar_t>(), scale.data_ptr<float>(), \
                 epsilon, num_tokens, hidden_size); \
       }); \

@@ -210,7 +218,10 @@ void fused_add_rms_norm_static_fp8_quant(
     torch::Tensor& weight,  // [hidden_size]
     torch::Tensor& scale,   // [1]
     double epsilon) {
+  TORCH_CHECK(out.is_contiguous());
+  TORCH_CHECK(residual.is_contiguous());
   int hidden_size = input.size(-1);
+  int input_stride = input.stride(-2);
   int num_tokens = input.numel() / hidden_size;

   dim3 grid(num_tokens);

@@ -234,7 +245,7 @@ void fused_add_rms_norm_static_fp8_quant(
   auto wt_ptr = reinterpret_cast<std::uintptr_t>(weight.data_ptr());
   bool ptrs_are_aligned =
       inp_ptr % 16 == 0 && res_ptr % 16 == 0 && wt_ptr % 16 == 0;
-  if (ptrs_are_aligned && hidden_size % 8 == 0) {
+  if (ptrs_are_aligned && hidden_size % 8 == 0 && input_stride % 8 == 0) {
     LAUNCH_FUSED_ADD_RMS_NORM(8);
   } else {
     LAUNCH_FUSED_ADD_RMS_NORM(0);
@@ -88,6 +88,8 @@ void static_scaled_fp8_quant(torch::Tensor& out, // [..., d]
                              torch::Tensor const& input,  // [..., d]
                              torch::Tensor const& scale)  // [1]
 {
+  TORCH_CHECK(input.is_contiguous());
+  TORCH_CHECK(out.is_contiguous());
   int const block_size = 256;
   int const num_tokens = input.numel() / input.size(-1);
   int const num_elems = input.numel();

@@ -111,6 +113,8 @@ void dynamic_scaled_fp8_quant(torch::Tensor& out, // [..., d]
                               torch::Tensor const& input,  // [..., d]
                               torch::Tensor& scale)        // [1]
 {
+  TORCH_CHECK(input.is_contiguous());
+  TORCH_CHECK(out.is_contiguous());
   int const block_size = 256;
   int const num_tokens = input.numel() / input.size(-1);
   int const num_elems = input.numel();
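The two checks added above mean the fp8 quant entry points now reject non-contiguous tensors outright, which is why the updated test further below materializes the strided view with .contiguous() before calling static_scaled_fp8_quant. A minimal illustration of that pattern (shapes are arbitrary, not taken from the commit):

    import torch

    hidden_size = 8
    x = torch.randn(4, 2 * hidden_size)[..., :hidden_size]  # strided view
    assert not x.is_contiguous()
    x_dense = x.contiguous()  # dense copy, safe to hand to the quant ops
    assert x_dense.is_contiguous()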
@@ -26,6 +26,7 @@ CUDA_DEVICES = [
 @pytest.mark.parametrize("dtype", DTYPES)
 @pytest.mark.parametrize("seed", SEEDS)
 @pytest.mark.parametrize("device", CUDA_DEVICES)
+@pytest.mark.parametrize("strided_input", [False, True])
 @torch.inference_mode()
 def test_rms_norm(
     num_tokens: int,

@@ -34,13 +35,17 @@ def test_rms_norm(
     dtype: torch.dtype,
     seed: int,
     device: str,
+    strided_input: bool,
 ) -> None:
     current_platform.seed_everything(seed)
     torch.set_default_device(device)
     layer = RMSNorm(hidden_size).to(dtype=dtype)
     layer.weight.data.normal_(mean=1.0, std=0.1)
     scale = 1 / (2 * hidden_size)
-    x = torch.randn(num_tokens, hidden_size, dtype=dtype)
+    last_dim = 2 * hidden_size if strided_input else hidden_size
+    x = torch.randn(num_tokens, last_dim, dtype=dtype)
+    x = x[..., :hidden_size]
+    assert x.is_contiguous() != strided_input
     x *= scale
     residual = torch.randn_like(x) * scale if add_residual else None

@@ -72,6 +77,7 @@ def test_rms_norm(
 @pytest.mark.parametrize("quant_scale", [1.0, 0.01, 10.0])
 @pytest.mark.parametrize("seed", SEEDS)
 @pytest.mark.parametrize("device", CUDA_DEVICES)
+@pytest.mark.parametrize("strided_input", [False, True])
 def test_fused_rms_norm_quant(
     num_tokens: int,
     hidden_size: int,

@@ -80,13 +86,18 @@ def test_fused_rms_norm_quant(
     quant_scale: float,
     seed: int,
     device: str,
+    strided_input: bool,
 ) -> None:
     current_platform.seed_everything(seed)
     torch.set_default_device(device)

     weight = torch.empty(hidden_size, dtype=dtype).normal_(mean=1.0, std=0.1)
     scale = 1 / (2 * hidden_size)
-    x = torch.randn(num_tokens, hidden_size, dtype=dtype)
+    last_dim = 2 * hidden_size if strided_input else hidden_size
+    x_base = torch.randn(num_tokens, last_dim, dtype=dtype)
+    x = x_base[..., :hidden_size]
+    assert x.is_contiguous() != strided_input
+
     x *= scale
     if add_residual:
         residual = torch.randn_like(x) * scale

@@ -106,9 +117,11 @@ def test_fused_rms_norm_quant(

     # Unfused kernel is in-place so it goes second
     # Also use a separate clone of x to avoid modifying the input
-    x_unfused = x.clone()
+    x_unfused_base = x_base.clone()
+    x_unfused = x_unfused_base[..., :hidden_size]
+    assert x_unfused.is_contiguous() != strided_input
     torch.ops._C.fused_add_rms_norm(x_unfused, residual, weight, 1e-6)
-    torch.ops._C.static_scaled_fp8_quant(out_quant, x_unfused,
+    torch.ops._C.static_scaled_fp8_quant(out_quant, x_unfused.contiguous(),
                                          quant_scale_t)

     torch.cuda.synchronize()

@@ -116,7 +129,6 @@ def test_fused_rms_norm_quant(
                                residual,
                                atol=1e-2,
                                rtol=1e-2)
-
     opcheck(
         torch.ops._C.fused_add_rms_norm_static_fp8_quant,
         (out_quant_fused, x, residual_fused, weight, quant_scale_t, 1e-6))

@@ -131,7 +143,7 @@ def test_fused_rms_norm_quant(
         opcheck(torch.ops._C.rms_norm_static_fp8_quant,
                 (out_quant_fused, x, weight, quant_scale_t, 1e-6))

-    torch.testing.assert_close(out_quant_fused.to(dtype=torch.float32),
-                               out_quant.to(dtype=torch.float32),
+    torch.testing.assert_close(out_quant.to(dtype=torch.float32),
+                               out_quant_fused.to(dtype=torch.float32),
                                atol=1e-3,
                                rtol=1e-3)
@@ -259,6 +259,8 @@ class LinearBase(torch.nn.Module):
         if params_dtype is None:
             params_dtype = torch.get_default_dtype()
         self.params_dtype = params_dtype
+        self.quant_config = quant_config
+        self.prefix = prefix
         if quant_config is None:
             self.quant_method: Optional[
                 QuantizeMethodBase] = UnquantizedLinearMethod()

@@ -300,6 +302,12 @@ class ReplicatedLinear(LinearBase):
         *,
         return_bias: bool = True,
     ):
+        # If MergedReplicatedLinear, use output size of each partition.
+        if hasattr(self, "output_sizes"):
+            self.output_partition_sizes = self.output_sizes
+        else:
+            self.output_partition_sizes = [output_size]
+
         super().__init__(input_size,
                          output_size,
                          skip_bias_add,

@@ -311,7 +319,8 @@ class ReplicatedLinear(LinearBase):
         # All the linear layer supports quant method.
         assert self.quant_method is not None
         self.quant_method.create_weights(self,
-                                         self.input_size, [self.output_size],
+                                         self.input_size,
+                                         self.output_partition_sizes,
                                          self.input_size,
                                          self.output_size,
                                          self.params_dtype,

@@ -367,6 +376,73 @@ class ReplicatedLinear(LinearBase):
         return s


+class MergedReplicatedLinear(ReplicatedLinear):
+    """Replicated linear layer.
+
+    Args:
+        input_size: input dimension of the linear layer.
+        output_size: output dimension of the linear layer.
+        bias: If true, add bias.
+        skip_bias_add: If true, skip adding bias but instead return it.
+        params_dtype: Data type for the parameters.
+        quant_config: Quantization configure.
+        prefix: The name of the layer in the state dict, including all parents
+            (e.g. model.layers.0.qkv_proj)
+    """
+
+    def __init__(
+        self,
+        input_size: int,
+        output_sizes: list[int],
+        bias: bool = True,
+        skip_bias_add: bool = False,
+        params_dtype: Optional[torch.dtype] = None,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+        *,
+        return_bias: bool = True,
+    ):
+        self.output_sizes = output_sizes
+        super().__init__(input_size,
+                         sum(output_sizes),
+                         bias,
+                         skip_bias_add,
+                         params_dtype,
+                         quant_config,
+                         prefix=prefix,
+                         return_bias=return_bias)
+
+    def weight_loader(self,
+                      param: Union[Parameter, BasevLLMParameter],
+                      loaded_weight: torch.Tensor,
+                      loaded_shard_id: Optional[int] = None):
+        assert loaded_shard_id is not None
+        assert loaded_shard_id < len(self.output_sizes)
+
+        if isinstance(param, BlockQuantScaleParameter):
+            from vllm.model_executor.layers.quantization.fp8 import (
+                Fp8LinearMethod, Fp8MoEMethod)
+            assert self.quant_method is not None
+            assert isinstance(self.quant_method,
+                              (Fp8LinearMethod, Fp8MoEMethod))
+            weight_block_size = self.quant_method.quant_config.weight_block_size
+            assert weight_block_size is not None
+            block_n, _ = weight_block_size[0], weight_block_size[1]
+            shard_offset = (
+                (sum(self.output_sizes[:loaded_shard_id]) + block_n - 1) //
+                block_n)
+            shard_size = ((self.output_sizes[loaded_shard_id] + block_n - 1) //
+                          block_n)
+        elif isinstance(param, PerTensorScaleParameter):
+            shard_offset = loaded_shard_id
+            shard_size = 1
+        else:
+            shard_offset = sum(self.output_sizes[:loaded_shard_id])
+            shard_size = self.output_sizes[loaded_shard_id]
+
+        param[shard_offset:shard_offset + shard_size] = loaded_weight
+
+
 class ColumnParallelLinear(LinearBase):
     """Linear layer with column parallelism.

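A toy sketch of how the new MergedReplicatedLinear.weight_loader above lays each sub-projection's checkpoint into the single fused parameter (made-up sizes, illustration only; it mirrors just the unquantized else branch):

    import torch

    output_sizes = [16, 12]  # e.g. [q_lora_rank, kv_lora_rank + qk_rope_head_dim]
    input_size = 32
    fused_weight = torch.empty(sum(output_sizes), input_size)

    def load_shard(loaded_weight: torch.Tensor, loaded_shard_id: int) -> None:
        # Offset is the sum of the preceding output sizes; size is the shard's own.
        shard_offset = sum(output_sizes[:loaded_shard_id])
        shard_size = output_sizes[loaded_shard_id]
        fused_weight[shard_offset:shard_offset + shard_size] = loaded_weight

    load_shard(torch.randn(16, input_size), 0)  # would come from q_a_proj
    load_shard(torch.randn(12, input_size), 1)  # would come from kv_a_proj_with_mqa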
@@ -257,9 +257,16 @@ class Fp8LinearMethod(LinearMethodBase):
                     f"{input_size_per_partition} is not divisible by "
                     f"weight quantization block_k = {block_k}.")
             # Required by column parallel or enabling merged weights
-            if (tp_size > 1 and output_size // output_size_per_partition
-                    == tp_size) or len(output_partition_sizes) > 1:
-                for output_partition_size in output_partition_sizes:
+            is_tp_split = (tp_size > 1 and
+                           output_size // output_size_per_partition == tp_size)
+            is_merged_gemm = len(output_partition_sizes) > 1
+            if is_tp_split or is_merged_gemm:
+                sizes_to_check = output_partition_sizes
+                if not is_tp_split and is_merged_gemm:
+                    # In case of merged matrices, we allow the last
+                    # matrix to not be a multiple of block size
+                    sizes_to_check = output_partition_sizes[:-1]
+                for output_partition_size in sizes_to_check:
                     if output_partition_size % block_n != 0:
                         raise ValueError(
                             f"Weight output_partition_size = "
|
|||||||
from vllm.model_executor.layers.layernorm import RMSNorm
|
from vllm.model_executor.layers.layernorm import RMSNorm
|
||||||
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
|
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
|
||||||
MergedColumnParallelLinear,
|
MergedColumnParallelLinear,
|
||||||
|
MergedReplicatedLinear,
|
||||||
ReplicatedLinear,
|
ReplicatedLinear,
|
||||||
RowParallelLinear)
|
RowParallelLinear)
|
||||||
from vllm.model_executor.layers.logits_processor import LogitsProcessor
|
from vllm.model_executor.layers.logits_processor import LogitsProcessor
|
||||||
@ -336,7 +337,7 @@ class DeepseekV2Attention(nn.Module):
|
|||||||
kv_a, _ = latent_cache.split(
|
kv_a, _ = latent_cache.split(
|
||||||
[self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
|
[self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
|
||||||
latent_cache = latent_cache.unsqueeze(1)
|
latent_cache = latent_cache.unsqueeze(1)
|
||||||
kv_a = self.kv_a_layernorm(kv_a.contiguous())
|
kv_a = self.kv_a_layernorm(kv_a)
|
||||||
kv = self.kv_b_proj(kv_a)[0]
|
kv = self.kv_b_proj(kv_a)[0]
|
||||||
kv = kv.view(-1, self.num_local_heads,
|
kv = kv.view(-1, self.num_local_heads,
|
||||||
self.qk_nope_head_dim + self.v_head_dim)
|
self.qk_nope_head_dim + self.v_head_dim)
|
||||||
@ -407,14 +408,24 @@ class DeepseekV2MLAAttention(nn.Module):
|
|||||||
self.max_position_embeddings = max_position_embeddings
|
self.max_position_embeddings = max_position_embeddings
|
||||||
|
|
||||||
if self.q_lora_rank is not None:
|
if self.q_lora_rank is not None:
|
||||||
self.q_a_proj = ReplicatedLinear(self.hidden_size,
|
self.fused_qkv_a_proj = MergedReplicatedLinear(
|
||||||
self.q_lora_rank,
|
self.hidden_size,
|
||||||
bias=False,
|
[self.q_lora_rank, self.kv_lora_rank + self.qk_rope_head_dim],
|
||||||
quant_config=quant_config,
|
bias=False,
|
||||||
prefix=f"{prefix}.q_a_proj")
|
quant_config=quant_config,
|
||||||
|
prefix=f"{prefix}.fused_qkv_a_proj")
|
||||||
|
else:
|
||||||
|
self.kv_a_proj_with_mqa = ReplicatedLinear(
|
||||||
|
self.hidden_size,
|
||||||
|
self.kv_lora_rank + self.qk_rope_head_dim,
|
||||||
|
bias=False,
|
||||||
|
quant_config=quant_config,
|
||||||
|
prefix=f"{prefix}.kv_a_proj_with_mqa")
|
||||||
|
|
||||||
|
if self.q_lora_rank is not None:
|
||||||
self.q_a_layernorm = RMSNorm(self.q_lora_rank,
|
self.q_a_layernorm = RMSNorm(self.q_lora_rank,
|
||||||
eps=config.rms_norm_eps)
|
eps=config.rms_norm_eps)
|
||||||
self.q_b_proj = ColumnParallelLinear(q_lora_rank,
|
self.q_b_proj = ColumnParallelLinear(self.q_lora_rank,
|
||||||
self.num_heads *
|
self.num_heads *
|
||||||
self.qk_head_dim,
|
self.qk_head_dim,
|
||||||
bias=False,
|
bias=False,
|
||||||
@ -427,13 +438,6 @@ class DeepseekV2MLAAttention(nn.Module):
|
|||||||
bias=False,
|
bias=False,
|
||||||
quant_config=quant_config,
|
quant_config=quant_config,
|
||||||
prefix=f"{prefix}.q_proj")
|
prefix=f"{prefix}.q_proj")
|
||||||
|
|
||||||
self.kv_a_proj_with_mqa = ReplicatedLinear(
|
|
||||||
self.hidden_size,
|
|
||||||
self.kv_lora_rank + self.qk_rope_head_dim,
|
|
||||||
bias=False,
|
|
||||||
quant_config=quant_config,
|
|
||||||
prefix=f"{prefix}.kv_a_proj_with_mqa")
|
|
||||||
self.kv_a_layernorm = RMSNorm(self.kv_lora_rank,
|
self.kv_a_layernorm = RMSNorm(self.kv_lora_rank,
|
||||||
eps=config.rms_norm_eps)
|
eps=config.rms_norm_eps)
|
||||||
self.kv_b_proj = ColumnParallelLinear(
|
self.kv_b_proj = ColumnParallelLinear(
|
||||||
@ -495,15 +499,24 @@ class DeepseekV2MLAAttention(nn.Module):
|
|||||||
positions: torch.Tensor,
|
positions: torch.Tensor,
|
||||||
hidden_states: torch.Tensor,
|
hidden_states: torch.Tensor,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
|
q_c = None
|
||||||
|
kv_lora = None
|
||||||
|
|
||||||
if self.q_lora_rank is not None:
|
if self.q_lora_rank is not None:
|
||||||
q_c = self.q_a_proj(hidden_states)[0]
|
qkv_lora = self.fused_qkv_a_proj(hidden_states)[0]
|
||||||
|
q_c, kv_lora = qkv_lora.split(
|
||||||
|
[self.q_lora_rank, self.kv_lora_rank + self.qk_rope_head_dim],
|
||||||
|
dim=-1,
|
||||||
|
)
|
||||||
q_c = self.q_a_layernorm(q_c)
|
q_c = self.q_a_layernorm(q_c)
|
||||||
q = self.q_b_proj(q_c)[0]
|
q = self.q_b_proj(q_c)[0]
|
||||||
else:
|
else:
|
||||||
|
kv_lora = self.kv_a_proj_with_mqa(hidden_states)[0]
|
||||||
q = self.q_proj(hidden_states)[0]
|
q = self.q_proj(hidden_states)[0]
|
||||||
kv_c, k_pe = self.kv_a_proj_with_mqa(hidden_states)[0].split(
|
|
||||||
[self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
|
kv_c, k_pe = kv_lora.split([self.kv_lora_rank, self.qk_rope_head_dim],
|
||||||
kv_c_normed = self.kv_a_layernorm(kv_c.contiguous())
|
dim=-1)
|
||||||
|
kv_c_normed = self.kv_a_layernorm(kv_c)
|
||||||
|
|
||||||
q = q.view(-1, self.num_local_heads, self.qk_head_dim)
|
q = q.view(-1, self.num_local_heads, self.qk_head_dim)
|
||||||
# Add head dim of 1 to k_pe
|
# Add head dim of 1 to k_pe
|
||||||
@ -837,6 +850,8 @@ class DeepseekV2ForCausalLM(nn.Module, SupportsPP, MixtureOfExperts):
|
|||||||
# (param_name, shard_name, shard_id)
|
# (param_name, shard_name, shard_id)
|
||||||
("gate_up_proj", "gate_proj", 0),
|
("gate_up_proj", "gate_proj", 0),
|
||||||
("gate_up_proj", "up_proj", 1),
|
("gate_up_proj", "up_proj", 1),
|
||||||
|
("fused_qkv_a_proj", "q_a_proj", 0),
|
||||||
|
("fused_qkv_a_proj", "kv_a_proj_with_mqa", 1),
|
||||||
]
|
]
|
||||||
|
|
||||||
# Params for weights, fp8 weight scales, fp8 activation scales
|
# Params for weights, fp8 weight scales, fp8 activation scales
|
||||||
@ -871,6 +886,12 @@ class DeepseekV2ForCausalLM(nn.Module, SupportsPP, MixtureOfExperts):
|
|||||||
if (("mlp.experts." in name) and name not in params_dict):
|
if (("mlp.experts." in name) and name not in params_dict):
|
||||||
continue
|
continue
|
||||||
name = name.replace(weight_name, param_name)
|
name = name.replace(weight_name, param_name)
|
||||||
|
|
||||||
|
# QKV fusion is optional, fall back to normal
|
||||||
|
# weight loading if it's not enabled
|
||||||
|
if ((param_name == "fused_qkv_a_proj")
|
||||||
|
and name not in params_dict):
|
||||||
|
continue
|
||||||
# Skip loading extra bias for GPTQ models.
|
# Skip loading extra bias for GPTQ models.
|
||||||
if name.endswith(".bias") and name not in params_dict:
|
if name.endswith(".bias") and name not in params_dict:
|
||||||
continue
|
continue
|
||||||
|
|||||||