diff --git a/csrc/layernorm_kernels.cu b/csrc/layernorm_kernels.cu
index d073dd6d2dee..f051eb070222 100644
--- a/csrc/layernorm_kernels.cu
+++ b/csrc/layernorm_kernels.cu
@@ -15,15 +15,16 @@ namespace vllm {
 // TODO(woosuk): Further optimize this kernel.
 template <typename scalar_t>
 __global__ void rms_norm_kernel(
-    scalar_t* __restrict__ out,          // [..., hidden_size]
-    const scalar_t* __restrict__ input,  // [..., hidden_size]
+    scalar_t* __restrict__ out,           // [..., hidden_size]
+    const scalar_t* __restrict__ input,   // [..., hidden_size]
+    const int64_t input_stride,
     const scalar_t* __restrict__ weight,  // [hidden_size]
     const float epsilon, const int num_tokens, const int hidden_size) {
   __shared__ float s_variance;
   float variance = 0.0f;
 
   for (int idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) {
-    const float x = (float)input[blockIdx.x * hidden_size + idx];
+    const float x = (float)input[blockIdx.x * input_stride + idx];
     variance += x * x;
   }
 
@@ -37,7 +38,7 @@ __global__ void rms_norm_kernel(
   __syncthreads();
 
   for (int idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) {
-    float x = (float)input[blockIdx.x * hidden_size + idx];
+    float x = (float)input[blockIdx.x * input_stride + idx];
     out[blockIdx.x * hidden_size + idx] =
         ((scalar_t)(x * s_variance)) * weight[idx];
   }
@@ -50,7 +51,8 @@ __global__ void rms_norm_kernel(
 template <typename scalar_t, int width>
 __global__ std::enable_if_t<(width > 0) && _typeConvert<scalar_t>::exists>
 fused_add_rms_norm_kernel(
-    scalar_t* __restrict__ input,     // [..., hidden_size]
+    scalar_t* __restrict__ input,      // [..., hidden_size]
+    const int64_t input_stride,
     scalar_t* __restrict__ residual,  // [..., hidden_size]
     const scalar_t* __restrict__ weight,  // [hidden_size]
     const float epsilon, const int num_tokens, const int hidden_size) {
@@ -59,6 +61,7 @@ fused_add_rms_norm_kernel(
   static_assert(sizeof(_f16Vec<scalar_t, width>) == sizeof(scalar_t) * width);
   const int vec_hidden_size = hidden_size / width;
+  const int64_t vec_input_stride = input_stride / width;
   __shared__ float s_variance;
   float variance = 0.0f;
   /* These and the argument pointers are all declared `restrict` as they are
     not aliased in practice. Argument pointers should not be dereferenced
@@ -73,7 +76,8 @@ fused_add_rms_norm_kernel(
   for (int idx = threadIdx.x; idx < vec_hidden_size; idx += blockDim.x) {
     int id = blockIdx.x * vec_hidden_size + idx;
-    _f16Vec<scalar_t, width> temp = input_v[id];
+    int64_t strided_id = blockIdx.x * vec_input_stride + idx;
+    _f16Vec<scalar_t, width> temp = input_v[strided_id];
     temp += residual_v[id];
     variance += temp.sum_squares();
     residual_v[id] = temp;
   }
@@ -90,10 +94,11 @@ fused_add_rms_norm_kernel(
   for (int idx = threadIdx.x; idx < vec_hidden_size; idx += blockDim.x) {
     int id = blockIdx.x * vec_hidden_size + idx;
+    int64_t strided_id = blockIdx.x * vec_input_stride + idx;
     _f16Vec<scalar_t, width> temp = residual_v[id];
     temp *= s_variance;
     temp *= weight_v[idx];
-    input_v[id] = temp;
+    input_v[strided_id] = temp;
   }
 }
 
@@ -103,7 +108,8 @@ fused_add_rms_norm_kernel(
 template <typename scalar_t, int width>
 __global__ std::enable_if_t<(width == 0) || !_typeConvert<scalar_t>::exists>
 fused_add_rms_norm_kernel(
-    scalar_t* __restrict__ input,     // [..., hidden_size]
+    scalar_t* __restrict__ input,      // [..., hidden_size]
+    const int64_t input_stride,
     scalar_t* __restrict__ residual,  // [..., hidden_size]
     const scalar_t* __restrict__ weight,  // [hidden_size]
     const float epsilon, const int num_tokens, const int hidden_size) {
@@ -111,7 +117,7 @@ fused_add_rms_norm_kernel(
   float variance = 0.0f;
 
   for (int idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) {
-    scalar_t z = input[blockIdx.x * hidden_size + idx];
+    scalar_t z = input[blockIdx.x * input_stride + idx];
     z += residual[blockIdx.x * hidden_size + idx];
     float x = (float)z;
     variance += x * x;
@@ -129,7 +135,7 @@ fused_add_rms_norm_kernel(
 
   for (int idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) {
     float x = (float)residual[blockIdx.x * hidden_size + idx];
-    input[blockIdx.x * hidden_size + idx] =
+    input[blockIdx.x * input_stride + idx] =
         ((scalar_t)(x * s_variance)) * weight[idx];
   }
 }
@@ -141,11 +147,12 @@ void rms_norm(torch::Tensor& out,    // [..., hidden_size]
               torch::Tensor& weight,  // [hidden_size]
               double epsilon) {
   TORCH_CHECK(out.is_contiguous());
-  TORCH_CHECK(input.is_contiguous());
+  TORCH_CHECK(input.stride(-1) == 1);
   TORCH_CHECK(weight.is_contiguous());
 
   int hidden_size = input.size(-1);
   int num_tokens = input.numel() / hidden_size;
+  int64_t input_stride = input.stride(-2);
 
   dim3 grid(num_tokens);
   dim3 block(std::min(hidden_size, 1024));
@@ -153,26 +160,29 @@ void rms_norm(torch::Tensor& out,    // [..., hidden_size]
   const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
   VLLM_DISPATCH_FLOATING_TYPES(input.scalar_type(), "rms_norm_kernel", [&] {
     vllm::rms_norm_kernel<scalar_t><<<grid, block, 0, stream>>>(
-        out.data_ptr<scalar_t>(), input.data_ptr<scalar_t>(),
+        out.data_ptr<scalar_t>(), input.data_ptr<scalar_t>(), input_stride,
         weight.data_ptr<scalar_t>(), epsilon, num_tokens, hidden_size);
   });
 }
 
-#define LAUNCH_FUSED_ADD_RMS_NORM(width)                                       \
-  VLLM_DISPATCH_FLOATING_TYPES(                                                \
-      input.scalar_type(), "fused_add_rms_norm_kernel", [&] {                  \
-        vllm::fused_add_rms_norm_kernel<scalar_t, width>                       \
-            <<<grid, block, 0, stream>>>(input.data_ptr<scalar_t>(),           \
-                                         residual.data_ptr<scalar_t>(),        \
-                                         weight.data_ptr<scalar_t>(), epsilon, \
-                                         num_tokens, hidden_size);             \
+#define LAUNCH_FUSED_ADD_RMS_NORM(width)                                    \
+  VLLM_DISPATCH_FLOATING_TYPES(                                             \
+      input.scalar_type(), "fused_add_rms_norm_kernel", [&] {               \
+        vllm::fused_add_rms_norm_kernel<scalar_t, width>                    \
+            <<<grid, block, 0, stream>>>(                                   \
+                input.data_ptr<scalar_t>(), input_stride,                   \
+                residual.data_ptr<scalar_t>(), weight.data_ptr<scalar_t>(), \
+                epsilon, num_tokens, hidden_size);                          \
       });
 
 void fused_add_rms_norm(torch::Tensor& input,     // [..., hidden_size]
                         torch::Tensor& residual,  // [..., hidden_size]
                         torch::Tensor& weight,    // [hidden_size]
                         double epsilon) {
+  TORCH_CHECK(residual.is_contiguous());
+  TORCH_CHECK(weight.is_contiguous());
   int hidden_size = input.size(-1);
+  int64_t input_stride = input.stride(-2);
   int num_tokens = input.numel() / hidden_size;
 
   dim3 grid(num_tokens);
@@ -194,9 +204,16 @@ void fused_add_rms_norm(torch::Tensor& input,     // [..., hidden_size]
   auto inp_ptr = reinterpret_cast<std::uintptr_t>(input.data_ptr());
   auto res_ptr = reinterpret_cast<std::uintptr_t>(residual.data_ptr());
   auto wt_ptr = reinterpret_cast<std::uintptr_t>(weight.data_ptr());
-  bool ptrs_are_aligned =
-      inp_ptr % 16 == 0 && res_ptr % 16 == 0 && wt_ptr % 16 == 0;
-  if (ptrs_are_aligned && hidden_size % 8 == 0) {
+  constexpr int vector_width = 8;
+  constexpr int req_alignment_bytes =
+      vector_width * 2;  // vector_width * sizeof(bfloat16 or float16) (float32
+                         // falls back to non-vectorized version anyway)
+  bool ptrs_are_aligned = inp_ptr % req_alignment_bytes == 0 &&
+                          res_ptr % req_alignment_bytes == 0 &&
+                          wt_ptr % req_alignment_bytes == 0;
+  bool offsets_are_multiple_of_vector_width =
+      hidden_size % vector_width == 0 && input_stride % vector_width == 0;
+  if (ptrs_are_aligned && offsets_are_multiple_of_vector_width) {
     LAUNCH_FUSED_ADD_RMS_NORM(8);
   } else {
     LAUNCH_FUSED_ADD_RMS_NORM(0);
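For reference, a minimal PyTorch sketch (hypothetical names, chosen only for illustration) of the input layout these strided kernels are meant to handle: the last dimension stays contiguous (stride(-1) == 1) while rows sit input_stride elements apart, so a token's row begins at token_idx * input_stride rather than token_idx * hidden_size.

    import torch

    num_tokens, hidden_size = 4, 8
    # A row-sliced view: last dim contiguous, but rows 2 * hidden_size apart.
    base = torch.randn(num_tokens, 2 * hidden_size)
    x = base[..., :hidden_size]
    assert x.stride(-1) == 1 and not x.is_contiguous()

    input_stride = x.stride(-2)  # what the host code now passes to the kernel
    flat = base.reshape(-1)      # the raw buffer the kernel indexes into
    for token in range(num_tokens):
        row = flat[token * input_stride:token * input_stride + hidden_size]
        # Mirrors input[blockIdx.x * input_stride + idx] in rms_norm_kernel.
        inv_rms = torch.rsqrt(row.pow(2).mean() + 1e-6)
        assert torch.allclose(row * inv_rms, x[token] * inv_rms)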
diff --git a/csrc/layernorm_quant_kernels.cu b/csrc/layernorm_quant_kernels.cu
index d595b9e889c8..0fd5849d9626 100644
--- a/csrc/layernorm_quant_kernels.cu
+++ b/csrc/layernorm_quant_kernels.cu
@@ -23,8 +23,9 @@ namespace vllm {
 // TODO(woosuk): Further optimize this kernel.
 template <typename scalar_t, typename fp8_type>
 __global__ void rms_norm_static_fp8_quant_kernel(
-    fp8_type* __restrict__ out,          // [..., hidden_size]
-    const scalar_t* __restrict__ input,  // [..., hidden_size]
+    fp8_type* __restrict__ out,           // [..., hidden_size]
+    const scalar_t* __restrict__ input,   // [..., hidden_size]
+    const int input_stride,
     const scalar_t* __restrict__ weight,  // [hidden_size]
     const float* __restrict__ scale,      // [1]
     const float epsilon, const int num_tokens, const int hidden_size) {
@@ -32,7 +33,7 @@ __global__ void rms_norm_static_fp8_quant_kernel(
   float variance = 0.0f;
 
   for (int idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) {
-    const float x = (float)input[blockIdx.x * hidden_size + idx];
+    const float x = (float)input[blockIdx.x * input_stride + idx];
     variance += x * x;
   }
 
@@ -49,7 +50,7 @@ __global__ void rms_norm_static_fp8_quant_kernel(
   float const scale_inv = 1.0f / *scale;
 
   for (int idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) {
-    float x = (float)input[blockIdx.x * hidden_size + idx];
+    float x = (float)input[blockIdx.x * input_stride + idx];
     float const out_norm = ((scalar_t)(x * s_variance)) * weight[idx];
     out[blockIdx.x * hidden_size + idx] =
         scaled_fp8_conversion<true, fp8_type>(out_norm, scale_inv);
@@ -63,8 +64,9 @@ __global__ void rms_norm_static_fp8_quant_kernel(
 template <typename scalar_t, int width, typename fp8_type>
 __global__ std::enable_if_t<(width > 0) && _typeConvert<scalar_t>::exists>
 fused_add_rms_norm_static_fp8_quant_kernel(
-    fp8_type* __restrict__ out,     // [..., hidden_size]
-    scalar_t* __restrict__ input,   // [..., hidden_size]
+    fp8_type* __restrict__ out,      // [..., hidden_size]
+    scalar_t* __restrict__ input,    // [..., hidden_size]
+    const int input_stride,
     scalar_t* __restrict__ residual,  // [..., hidden_size]
     const scalar_t* __restrict__ weight,  // [hidden_size]
     const float* __restrict__ scale,      // [1]
@@ -74,6 +76,7 @@ fused_add_rms_norm_static_fp8_quant_kernel(
   static_assert(sizeof(_f16Vec<scalar_t, width>) == sizeof(scalar_t) * width);
   const int vec_hidden_size = hidden_size / width;
+  const int vec_input_stride = input_stride / width;
   __shared__ float s_variance;
   float variance = 0.0f;
   /* These and the argument pointers are all declared `restrict` as they are
     not aliased in practice. Argument pointers should not be dereferenced
@@ -87,8 +90,9 @@ fused_add_rms_norm_static_fp8_quant_kernel(
       reinterpret_cast<const _f16Vec<scalar_t, width>*>(weight);
 
   for (int idx = threadIdx.x; idx < vec_hidden_size; idx += blockDim.x) {
+    int stride_id = blockIdx.x * vec_input_stride + idx;
     int id = blockIdx.x * vec_hidden_size + idx;
-    _f16Vec<scalar_t, width> temp = input_v[id];
+    _f16Vec<scalar_t, width> temp = input_v[stride_id];
     temp += residual_v[id];
     variance += temp.sum_squares();
     residual_v[id] = temp;
@@ -125,8 +129,9 @@ fused_add_rms_norm_static_fp8_quant_kernel(
 template <typename scalar_t, int width, typename fp8_type>
 __global__ std::enable_if_t<(width == 0) || !_typeConvert<scalar_t>::exists>
 fused_add_rms_norm_static_fp8_quant_kernel(
-    fp8_type* __restrict__ out,     // [..., hidden_size]
-    scalar_t* __restrict__ input,   // [..., hidden_size]
+    fp8_type* __restrict__ out,      // [..., hidden_size]
+    scalar_t* __restrict__ input,    // [..., hidden_size]
+    const int input_stride,
     scalar_t* __restrict__ residual,  // [..., hidden_size]
     const scalar_t* __restrict__ weight,  // [hidden_size]
     const float* __restrict__ scale,      // [1]
@@ -135,7 +140,7 @@ fused_add_rms_norm_static_fp8_quant_kernel(
   float variance = 0.0f;
 
   for (int idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) {
-    scalar_t z = input[blockIdx.x * hidden_size + idx];
+    scalar_t z = input[blockIdx.x * input_stride + idx];
     z += residual[blockIdx.x * hidden_size + idx];
     float x = (float)z;
     variance += x * x;
@@ -169,7 +174,9 @@ void rms_norm_static_fp8_quant(torch::Tensor& out,  // [..., hidden_size]
                                torch::Tensor& weight,  // [hidden_size]
                                torch::Tensor& scale,   // [1]
                                double epsilon) {
+  TORCH_CHECK(out.is_contiguous());
   int hidden_size = input.size(-1);
+  int input_stride = input.stride(-2);
   int num_tokens = input.numel() / hidden_size;
 
   dim3 grid(num_tokens);
@@ -183,8 +190,9 @@ void rms_norm_static_fp8_quant(torch::Tensor& out,  // [..., hidden_size]
           vllm::rms_norm_static_fp8_quant_kernel<scalar_t, fp8_t>
               <<<grid, block, 0, stream>>>(
                   out.data_ptr<fp8_t>(), input.data_ptr<scalar_t>(),
-                  weight.data_ptr<scalar_t>(), scale.data_ptr<float>(),
-                  epsilon, num_tokens, hidden_size);
+                  input_stride, weight.data_ptr<scalar_t>(),
+                  scale.data_ptr<float>(), epsilon, num_tokens,
+                  hidden_size);
         });
   });
 }
@@ -198,7 +206,7 @@ void rms_norm_static_fp8_quant(torch::Tensor& out,  // [..., hidden_size]
                                                                width, fp8_t>  \
                   <<<grid, block, 0, stream>>>(                               \
                       out.data_ptr<fp8_t>(), input.data_ptr<scalar_t>(),      \
-                      residual.data_ptr<scalar_t>(),                          \
+                      input_stride, residual.data_ptr<scalar_t>(),            \
                       weight.data_ptr<scalar_t>(), scale.data_ptr<float>(),   \
                       epsilon, num_tokens, hidden_size);                      \
             });                                                               \
@@ -210,7 +218,10 @@ void fused_add_rms_norm_static_fp8_quant(
     torch::Tensor& weight,  // [hidden_size]
     torch::Tensor& scale,   // [1]
     double epsilon) {
+  TORCH_CHECK(out.is_contiguous());
+  TORCH_CHECK(residual.is_contiguous());
   int hidden_size = input.size(-1);
+  int input_stride = input.stride(-2);
   int num_tokens = input.numel() / hidden_size;
 
   dim3 grid(num_tokens);
@@ -234,7 +245,7 @@ void fused_add_rms_norm_static_fp8_quant(
   auto wt_ptr = reinterpret_cast<std::uintptr_t>(weight.data_ptr());
   bool ptrs_are_aligned =
       inp_ptr % 16 == 0 && res_ptr % 16 == 0 && wt_ptr % 16 == 0;
-  if (ptrs_are_aligned && hidden_size % 8 == 0) {
+  if (ptrs_are_aligned && hidden_size % 8 == 0 && input_stride % 8 == 0) {
     LAUNCH_FUSED_ADD_RMS_NORM(8);
   } else {
     LAUNCH_FUSED_ADD_RMS_NORM(0);
diff --git a/csrc/quantization/fp8/common.cu b/csrc/quantization/fp8/common.cu
index f3f9f669e00a..0e1eab66f0b9 100644
--- a/csrc/quantization/fp8/common.cu
+++ b/csrc/quantization/fp8/common.cu
@@ -88,6 +88,8 @@ void static_scaled_fp8_quant(torch::Tensor& out,          // [..., d]
                              torch::Tensor const& input,  // [..., d]
                              torch::Tensor const& scale)  // [1]
 {
+  TORCH_CHECK(input.is_contiguous());
+  TORCH_CHECK(out.is_contiguous());
   int const block_size = 256;
   int const num_tokens = input.numel() / input.size(-1);
   int const num_elems = input.numel();
@@ -111,6 +113,8 @@ void dynamic_scaled_fp8_quant(torch::Tensor& out,          // [..., d]
                               torch::Tensor const& input,  // [..., d]
                               torch::Tensor& scale)        // [1]
 {
+  TORCH_CHECK(input.is_contiguous());
+  TORCH_CHECK(out.is_contiguous());
   int const block_size = 256;
   int const num_tokens = input.numel() / input.size(-1);
   int const num_elems = input.numel();
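The width-8 specializations are only launched when alignment and stride permit whole-vector loads; a rough Python model of that eligibility test follows (an assumption-level sketch mirroring the host-side checks, not the dispatch code itself).

    def can_use_vectorized_path(inp_ptr: int, res_ptr: int, wt_ptr: int,
                                hidden_size: int, input_stride: int,
                                vector_width: int = 8,
                                elem_size_bytes: int = 2) -> bool:
        """Sketch of the width-8 eligibility check.

        Pointers must be aligned to a full vector (8 half/bfloat16 elements =
        16 bytes), and both the row length and the row stride must be whole
        vectors so vectorized loads never straddle a row boundary.
        """
        req_alignment = vector_width * elem_size_bytes
        ptrs_ok = all(p % req_alignment == 0 for p in (inp_ptr, res_ptr, wt_ptr))
        offsets_ok = (hidden_size % vector_width == 0
                      and input_stride % vector_width == 0)
        return ptrs_ok and offsets_ok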
diff --git a/tests/kernels/core/test_layernorm.py b/tests/kernels/core/test_layernorm.py
index 3eac062738f8..02316ceaac73 100644
--- a/tests/kernels/core/test_layernorm.py
+++ b/tests/kernels/core/test_layernorm.py
@@ -26,6 +26,7 @@ CUDA_DEVICES = [
 @pytest.mark.parametrize("dtype", DTYPES)
 @pytest.mark.parametrize("seed", SEEDS)
 @pytest.mark.parametrize("device", CUDA_DEVICES)
+@pytest.mark.parametrize("strided_input", [False, True])
 @torch.inference_mode()
 def test_rms_norm(
     num_tokens: int,
@@ -34,13 +35,17 @@ def test_rms_norm(
     dtype: torch.dtype,
     seed: int,
     device: str,
+    strided_input: bool,
 ) -> None:
     current_platform.seed_everything(seed)
     torch.set_default_device(device)
     layer = RMSNorm(hidden_size).to(dtype=dtype)
     layer.weight.data.normal_(mean=1.0, std=0.1)
     scale = 1 / (2 * hidden_size)
-    x = torch.randn(num_tokens, hidden_size, dtype=dtype)
+    last_dim = 2 * hidden_size if strided_input else hidden_size
+    x = torch.randn(num_tokens, last_dim, dtype=dtype)
+    x = x[..., :hidden_size]
+    assert x.is_contiguous() != strided_input
     x *= scale
     residual = torch.randn_like(x) * scale if add_residual else None
 
@@ -72,6 +77,7 @@
 @pytest.mark.parametrize("quant_scale", [1.0, 0.01, 10.0])
 @pytest.mark.parametrize("seed", SEEDS)
 @pytest.mark.parametrize("device", CUDA_DEVICES)
+@pytest.mark.parametrize("strided_input", [False, True])
 def test_fused_rms_norm_quant(
     num_tokens: int,
     hidden_size: int,
@@ -80,13 +86,18 @@ def test_fused_rms_norm_quant(
     dtype: torch.dtype,
     quant_scale: float,
     seed: int,
     device: str,
+    strided_input: bool,
 ) -> None:
     current_platform.seed_everything(seed)
     torch.set_default_device(device)
     weight = torch.empty(hidden_size, dtype=dtype).normal_(mean=1.0, std=0.1)
     scale = 1 / (2 * hidden_size)
-    x = torch.randn(num_tokens, hidden_size, dtype=dtype)
+    last_dim = 2 * hidden_size if strided_input else hidden_size
+    x_base = torch.randn(num_tokens, last_dim, dtype=dtype)
+    x = x_base[..., :hidden_size]
+    assert x.is_contiguous() != strided_input
+
     x *= scale
 
     if add_residual:
         residual = torch.randn_like(x) * scale
@@ -106,9 +117,11 @@
 
     # Unfused kernel is in-place so it goes second
     # Also use a separate clone of x to avoid modifying the input
-    x_unfused = x.clone()
+    x_unfused_base = x_base.clone()
+    x_unfused = x_unfused_base[..., :hidden_size]
+    assert x_unfused.is_contiguous() != strided_input
     torch.ops._C.fused_add_rms_norm(x_unfused, residual, weight, 1e-6)
-    torch.ops._C.static_scaled_fp8_quant(out_quant, x_unfused,
+    torch.ops._C.static_scaled_fp8_quant(out_quant, x_unfused.contiguous(),
                                          quant_scale_t)
 
     torch.cuda.synchronize()
@@ -116,7 +129,6 @@ def test_fused_rms_norm_quant(
                                    residual,
                                    atol=1e-2,
                                    rtol=1e-2)
-
     opcheck(
         torch.ops._C.fused_add_rms_norm_static_fp8_quant,
         (out_quant_fused, x, residual_fused, weight, quant_scale_t, 1e-6))
@@ -131,7 +143,7 @@ def test_fused_rms_norm_quant(
         opcheck(torch.ops._C.rms_norm_static_fp8_quant,
                 (out_quant_fused, x, weight, quant_scale_t, 1e-6))
 
-    torch.testing.assert_close(out_quant_fused.to(dtype=torch.float32),
-                               out_quant.to(dtype=torch.float32),
+    torch.testing.assert_close(out_quant.to(dtype=torch.float32),
+                               out_quant_fused.to(dtype=torch.float32),
                                atol=1e-3,
                                rtol=1e-3)
diff --git a/vllm/model_executor/layers/linear.py b/vllm/model_executor/layers/linear.py
index 366dfd97d816..bb81a663d454 100644
--- a/vllm/model_executor/layers/linear.py
+++ b/vllm/model_executor/layers/linear.py
@@ -259,6 +259,8 @@ class LinearBase(torch.nn.Module):
         if params_dtype is None:
             params_dtype = torch.get_default_dtype()
         self.params_dtype = params_dtype
+        self.quant_config = quant_config
+        self.prefix = prefix
         if quant_config is None:
             self.quant_method: Optional[
                 QuantizeMethodBase] = UnquantizedLinearMethod()
@@ -300,6 +302,12 @@ class ReplicatedLinear(LinearBase):
         *,
         return_bias: bool = True,
     ):
+        # If MergedReplicatedLinear, use output size of each partition.
+        if hasattr(self, "output_sizes"):
+            self.output_partition_sizes = self.output_sizes
+        else:
+            self.output_partition_sizes = [output_size]
+
         super().__init__(input_size,
                          output_size,
                          skip_bias_add,
@@ -311,7 +319,8 @@ class ReplicatedLinear(LinearBase):
         # All the linear layer supports quant method.
         assert self.quant_method is not None
         self.quant_method.create_weights(self,
-                                         self.input_size, [self.output_size],
+                                         self.input_size,
+                                         self.output_partition_sizes,
                                          self.input_size,
                                          self.output_size,
                                          self.params_dtype,
@@ -367,6 +376,73 @@ class ReplicatedLinear(LinearBase):
         return s
 
 
+class MergedReplicatedLinear(ReplicatedLinear):
+    """Replicated linear layer.
+
+    Args:
+        input_size: input dimension of the linear layer.
+        output_size: output dimension of the linear layer.
+        bias: If true, add bias.
+        skip_bias_add: If true, skip adding bias but instead return it.
+        params_dtype: Data type for the parameters.
+        quant_config: Quantization configure.
+        prefix: The name of the layer in the state dict, including all parents
+            (e.g. model.layers.0.qkv_proj)
+    """
+
+    def __init__(
+        self,
+        input_size: int,
+        output_sizes: list[int],
+        bias: bool = True,
+        skip_bias_add: bool = False,
+        params_dtype: Optional[torch.dtype] = None,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+        *,
+        return_bias: bool = True,
+    ):
+        self.output_sizes = output_sizes
+        super().__init__(input_size,
+                         sum(output_sizes),
+                         bias,
+                         skip_bias_add,
+                         params_dtype,
+                         quant_config,
+                         prefix=prefix,
+                         return_bias=return_bias)
+
+    def weight_loader(self,
+                      param: Union[Parameter, BasevLLMParameter],
+                      loaded_weight: torch.Tensor,
+                      loaded_shard_id: Optional[int] = None):
+        assert loaded_shard_id is not None
+        assert loaded_shard_id < len(self.output_sizes)
+
+        if isinstance(param, BlockQuantScaleParameter):
+            from vllm.model_executor.layers.quantization.fp8 import (
+                Fp8LinearMethod, Fp8MoEMethod)
+            assert self.quant_method is not None
+            assert isinstance(self.quant_method,
+                              (Fp8LinearMethod, Fp8MoEMethod))
+            weight_block_size = self.quant_method.quant_config.weight_block_size
+            assert weight_block_size is not None
+            block_n, _ = weight_block_size[0], weight_block_size[1]
+            shard_offset = (
+                (sum(self.output_sizes[:loaded_shard_id]) + block_n - 1) //
+                block_n)
+            shard_size = ((self.output_sizes[loaded_shard_id] + block_n - 1) //
+                          block_n)
+        elif isinstance(param, PerTensorScaleParameter):
+            shard_offset = loaded_shard_id
+            shard_size = 1
+        else:
+            shard_offset = sum(self.output_sizes[:loaded_shard_id])
+            shard_size = self.output_sizes[loaded_shard_id]
+
+        param[shard_offset:shard_offset + shard_size] = loaded_weight
+
+
 class ColumnParallelLinear(LinearBase):
     """Linear layer with column parallelism.
 
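A toy sketch of how MergedReplicatedLinear's weight_loader places each source matrix into the merged weight in the plain (non-quantized) case, using bare tensors instead of vLLM parameter classes (illustrative only):

    import torch

    input_size = 16
    output_sizes = [6, 10]  # e.g. [q_lora_rank, kv_lora_rank + qk_rope_head_dim]
    merged_weight = torch.empty(sum(output_sizes), input_size)

    def load_shard(merged: torch.Tensor, shard_id: int, loaded: torch.Tensor):
        # Mirrors the unquantized branch of weight_loader: each source matrix
        # occupies a contiguous row range of the merged weight.
        offset = sum(output_sizes[:shard_id])
        size = output_sizes[shard_id]
        merged[offset:offset + size] = loaded

    load_shard(merged_weight, 0, torch.randn(output_sizes[0], input_size))
    load_shard(merged_weight, 1, torch.randn(output_sizes[1], input_size))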
diff --git a/vllm/model_executor/layers/quantization/fp8.py b/vllm/model_executor/layers/quantization/fp8.py
index 35d7545d8c6a..75f8adf34f7d 100644
--- a/vllm/model_executor/layers/quantization/fp8.py
+++ b/vllm/model_executor/layers/quantization/fp8.py
@@ -257,9 +257,16 @@ class Fp8LinearMethod(LinearMethodBase):
                     f"{input_size_per_partition} is not divisible by "
                     f"weight quantization block_k = {block_k}.")
             # Required by column parallel or enabling merged weights
-            if (tp_size > 1 and output_size // output_size_per_partition
-                    == tp_size) or len(output_partition_sizes) > 1:
-                for output_partition_size in output_partition_sizes:
+            is_tp_split = (tp_size > 1 and
+                           output_size // output_size_per_partition == tp_size)
+            is_merged_gemm = len(output_partition_sizes) > 1
+            if is_tp_split or is_merged_gemm:
+                sizes_to_check = output_partition_sizes
+                if not is_tp_split and is_merged_gemm:
+                    # In case of merged matrices, we allow the last
+                    # matrix to not be a multiple of block size
+                    sizes_to_check = output_partition_sizes[:-1]
+                for output_partition_size in sizes_to_check:
                     if output_partition_size % block_n != 0:
                         raise ValueError(
                             f"Weight output_partition_size = "
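The relaxed check can be read as: when the layer is merged but not TP-split, only the leading partitions must be multiples of block_n. A small illustrative helper (hypothetical, not vLLM API):

    def partitions_to_check(output_partition_sizes: list[int],
                            is_tp_split: bool) -> list[int]:
        # Merged-but-replicated weights may end with a partition that is not a
        # multiple of block_n (e.g. kv_lora_rank + qk_rope_head_dim), so only
        # the leading partitions need to be block-aligned in that case.
        is_merged_gemm = len(output_partition_sizes) > 1
        if not is_tp_split and is_merged_gemm:
            return output_partition_sizes[:-1]
        return output_partition_sizes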
diff --git a/vllm/model_executor/models/deepseek_v2.py b/vllm/model_executor/models/deepseek_v2.py
index 5106b9914b5e..649109777b3f 100644
--- a/vllm/model_executor/models/deepseek_v2.py
+++ b/vllm/model_executor/models/deepseek_v2.py
@@ -42,6 +42,7 @@ from vllm.model_executor.layers.fused_moe import FusedMoE
 from vllm.model_executor.layers.layernorm import RMSNorm
 from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                                MergedColumnParallelLinear,
+                                               MergedReplicatedLinear,
                                                ReplicatedLinear,
                                                RowParallelLinear)
 from vllm.model_executor.layers.logits_processor import LogitsProcessor
@@ -336,7 +337,7 @@ class DeepseekV2Attention(nn.Module):
         kv_a, _ = latent_cache.split(
             [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
         latent_cache = latent_cache.unsqueeze(1)
-        kv_a = self.kv_a_layernorm(kv_a.contiguous())
+        kv_a = self.kv_a_layernorm(kv_a)
         kv = self.kv_b_proj(kv_a)[0]
         kv = kv.view(-1, self.num_local_heads,
                      self.qk_nope_head_dim + self.v_head_dim)
@@ -407,14 +408,24 @@ class DeepseekV2MLAAttention(nn.Module):
         self.max_position_embeddings = max_position_embeddings
 
         if self.q_lora_rank is not None:
-            self.q_a_proj = ReplicatedLinear(self.hidden_size,
-                                             self.q_lora_rank,
-                                             bias=False,
-                                             quant_config=quant_config,
-                                             prefix=f"{prefix}.q_a_proj")
+            self.fused_qkv_a_proj = MergedReplicatedLinear(
+                self.hidden_size,
+                [self.q_lora_rank, self.kv_lora_rank + self.qk_rope_head_dim],
+                bias=False,
+                quant_config=quant_config,
+                prefix=f"{prefix}.fused_qkv_a_proj")
+        else:
+            self.kv_a_proj_with_mqa = ReplicatedLinear(
+                self.hidden_size,
+                self.kv_lora_rank + self.qk_rope_head_dim,
+                bias=False,
+                quant_config=quant_config,
+                prefix=f"{prefix}.kv_a_proj_with_mqa")
+
+        if self.q_lora_rank is not None:
             self.q_a_layernorm = RMSNorm(self.q_lora_rank,
                                          eps=config.rms_norm_eps)
-            self.q_b_proj = ColumnParallelLinear(q_lora_rank,
+            self.q_b_proj = ColumnParallelLinear(self.q_lora_rank,
                                                  self.num_heads *
                                                  self.qk_head_dim,
                                                  bias=False,
@@ -427,13 +438,6 @@ class DeepseekV2MLAAttention(nn.Module):
                                                bias=False,
                                                quant_config=quant_config,
                                                prefix=f"{prefix}.q_proj")
-
-        self.kv_a_proj_with_mqa = ReplicatedLinear(
-            self.hidden_size,
-            self.kv_lora_rank + self.qk_rope_head_dim,
-            bias=False,
-            quant_config=quant_config,
-            prefix=f"{prefix}.kv_a_proj_with_mqa")
         self.kv_a_layernorm = RMSNorm(self.kv_lora_rank,
                                       eps=config.rms_norm_eps)
         self.kv_b_proj = ColumnParallelLinear(
@@ -495,15 +499,24 @@ class DeepseekV2MLAAttention(nn.Module):
         positions: torch.Tensor,
         hidden_states: torch.Tensor,
     ) -> torch.Tensor:
+        q_c = None
+        kv_lora = None
+
         if self.q_lora_rank is not None:
-            q_c = self.q_a_proj(hidden_states)[0]
+            qkv_lora = self.fused_qkv_a_proj(hidden_states)[0]
+            q_c, kv_lora = qkv_lora.split(
+                [self.q_lora_rank, self.kv_lora_rank + self.qk_rope_head_dim],
+                dim=-1,
+            )
             q_c = self.q_a_layernorm(q_c)
             q = self.q_b_proj(q_c)[0]
         else:
+            kv_lora = self.kv_a_proj_with_mqa(hidden_states)[0]
             q = self.q_proj(hidden_states)[0]
-        kv_c, k_pe = self.kv_a_proj_with_mqa(hidden_states)[0].split(
-            [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
-        kv_c_normed = self.kv_a_layernorm(kv_c.contiguous())
+
+        kv_c, k_pe = kv_lora.split([self.kv_lora_rank, self.qk_rope_head_dim],
+                                   dim=-1)
+        kv_c_normed = self.kv_a_layernorm(kv_c)
         q = q.view(-1, self.num_local_heads, self.qk_head_dim)
 
         # Add head dim of 1 to k_pe
@@ -837,6 +850,8 @@ class DeepseekV2ForCausalLM(nn.Module, SupportsPP, MixtureOfExperts):
             # (param_name, shard_name, shard_id)
             ("gate_up_proj", "gate_proj", 0),
             ("gate_up_proj", "up_proj", 1),
+            ("fused_qkv_a_proj", "q_a_proj", 0),
+            ("fused_qkv_a_proj", "kv_a_proj_with_mqa", 1),
         ]
 
         # Params for weights, fp8 weight scales, fp8 activation scales
@@ -871,6 +886,12 @@ class DeepseekV2ForCausalLM(nn.Module, SupportsPP, MixtureOfExperts):
                 if (("mlp.experts." in name) and name not in params_dict):
                     continue
                 name = name.replace(weight_name, param_name)
+
+                # QKV fusion is optional, fall back to normal
+                # weight loading if it's not enabled
+                if ((param_name == "fused_qkv_a_proj")
+                        and name not in params_dict):
+                    continue
                 # Skip loading extra bias for GPTQ models.
                 if name.endswith(".bias") and name not in params_dict:
                     continue
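Taken together, the model-side change replaces two small projections with one GEMM whose output is split afterwards; a simplified, self-contained sketch of that forward path with made-up sizes (plain torch.nn.Linear standing in for the vLLM layers):

    import torch

    hidden_size, q_lora_rank = 32, 12
    kv_lora_rank, qk_rope_head_dim = 8, 4

    # One merged projection instead of separate q_a_proj and kv_a_proj_with_mqa.
    fused_qkv_a_proj = torch.nn.Linear(
        hidden_size, q_lora_rank + kv_lora_rank + qk_rope_head_dim, bias=False)

    hidden_states = torch.randn(3, hidden_size)
    qkv_lora = fused_qkv_a_proj(hidden_states)
    q_c, kv_lora = qkv_lora.split(
        [q_lora_rank, kv_lora_rank + qk_rope_head_dim], dim=-1)
    kv_c, k_pe = kv_lora.split([kv_lora_rank, qk_rope_head_dim], dim=-1)
    # kv_c is a non-contiguous slice here; with the strided RMSNorm kernels
    # above, kv_a_layernorm(kv_c) no longer needs a .contiguous() copy first.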