From bf03ff3575c8b0bc42517ffaef0df820dd3a806e Mon Sep 17 00:00:00 2001
From: Jacob Manning
Date: Wed, 9 Jul 2025 16:17:55 -0400
Subject: [PATCH] [Kernel] Add Conch backend for mixed-precision linear layer (#19818)

Signed-off-by: Jacob Manning
---
 requirements/rocm.txt                         |  1 +
 .../kernels/mixed_precision/__init__.py       |  5 +-
 .../kernels/mixed_precision/conch.py          | 92 +++++++++++++++++++
 .../kernels/mixed_precision/machete.py        |  4 +
 .../kernels/mixed_precision/marlin.py         |  4 +
 5 files changed, 105 insertions(+), 1 deletion(-)
 create mode 100644 vllm/model_executor/layers/quantization/kernels/mixed_precision/conch.py

diff --git a/requirements/rocm.txt b/requirements/rocm.txt
index 988329c3a212d..7038c9024c6b6 100644
--- a/requirements/rocm.txt
+++ b/requirements/rocm.txt
@@ -17,3 +17,4 @@ setuptools>=77.0.3,<80.0.0
 setuptools-scm>=8
 runai-model-streamer==0.11.0
 runai-model-streamer-s3==0.11.0
+conch-triton-kernels==1.2.1
diff --git a/vllm/model_executor/layers/quantization/kernels/mixed_precision/__init__.py b/vllm/model_executor/layers/quantization/kernels/mixed_precision/__init__.py
index 0bf0d530d2351..21e5ae793c3f5 100644
--- a/vllm/model_executor/layers/quantization/kernels/mixed_precision/__init__.py
+++ b/vllm/model_executor/layers/quantization/kernels/mixed_precision/__init__.py
@@ -8,6 +8,8 @@ from vllm.model_executor.layers.quantization.kernels.mixed_precision.allspark im
     AllSparkLinearKernel)
 from vllm.model_executor.layers.quantization.kernels.mixed_precision.bitblas import ( # noqa: E501
     BitBLASLinearKernel)
+from vllm.model_executor.layers.quantization.kernels.mixed_precision.conch import ( # noqa: E501
+    ConchLinearKernel)
 from vllm.model_executor.layers.quantization.kernels.mixed_precision.exllama import ( # noqa: E501
     ExllamaLinearKernel)
 from vllm.model_executor.layers.quantization.kernels.mixed_precision.machete import ( # noqa: E501
@@ -24,6 +26,7 @@ _POSSIBLE_KERNELS: list[type[MPLinearKernel]] = [
     AllSparkLinearKernel,
     MarlinLinearKernel,
     BitBLASLinearKernel,
+    ConchLinearKernel,
     ExllamaLinearKernel,
 ]

@@ -80,4 +83,4 @@ def choose_mp_linear_kernel(
     raise ValueError(
         "Failed to find a kernel that can implement the "\
         "WNA16 linear layer. Reasons: \n"
-        + '\n'.join(failure_reasons))
\ No newline at end of file
+        + '\n'.join(failure_reasons))
diff --git a/vllm/model_executor/layers/quantization/kernels/mixed_precision/conch.py b/vllm/model_executor/layers/quantization/kernels/mixed_precision/conch.py
new file mode 100644
index 0000000000000..f80af548f0199
--- /dev/null
+++ b/vllm/model_executor/layers/quantization/kernels/mixed_precision/conch.py
@@ -0,0 +1,92 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+
+from importlib.util import find_spec
+from typing import Final, Optional
+
+import torch
+
+from vllm.model_executor.parameter import (BasevLLMParameter,
+                                           permute_param_layout_)
+from vllm.scalar_type import scalar_types
+
+from .MPLinearKernel import MPLinearKernel, MPLinearLayerConfig
+
+_CONCH_SUPPORTED_WEIGHT_TYPES: Final = [
+    scalar_types.uint4, scalar_types.uint8, scalar_types.uint4b8,
+    scalar_types.uint8b128
+]
+_CONCH_SUPPORTED_GROUP_SIZES: Final = [-1, 128]
+
+
+class ConchLinearKernel(MPLinearKernel):
+
+    @classmethod
+    def get_min_capability(cls) -> int:
+        return 80
+
+    @classmethod
+    def can_implement(cls,
+                      c: MPLinearLayerConfig) -> tuple[bool, Optional[str]]:
+        if c.weight_type not in _CONCH_SUPPORTED_WEIGHT_TYPES:
+            error_msg = f"Weight type ({c.weight_type}) not supported by "\
+                "ConchLinearKernel, supported types are: " \
+                f"{_CONCH_SUPPORTED_WEIGHT_TYPES}"
+            return False, error_msg
+
+        if c.group_size not in _CONCH_SUPPORTED_GROUP_SIZES:
+            error_msg = f"Group size ({c.group_size}) not supported by "\
+                "ConchLinearKernel, supported group sizes are: " \
+                f"{_CONCH_SUPPORTED_GROUP_SIZES}"
+            return False, error_msg
+
+        if find_spec("conch") is None:
+            error_msg = "conch-triton-kernels is not installed, please "\
+                "install it via `pip install conch-triton-kernels` "\
+                "and try again!"
+            return False, error_msg
+
+        return True, None
+
+    # note assumes that
+    # `weight_packed` is: {input_dim = 0, output_dim = 1, packed_dim = 0}
+    # `weight_scale` is: {input_dim = 0, output_dim = 1}
+    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
+
+        def transform_w_q(x):
+            assert isinstance(x, BasevLLMParameter)
+            permute_param_layout_(x, input_dim=0, output_dim=1, packed_dim=0)
+            x.data = x.data.contiguous()
+            return x
+
+        def transform_w_s(x):
+            assert isinstance(x, BasevLLMParameter)
+            permute_param_layout_(x, input_dim=0, output_dim=1)
+            x.data = x.data.contiguous()
+            return x
+
+        self._transform_param(layer, self.w_q_name, transform_w_q)
+        self._transform_param(layer, self.w_s_name, transform_w_s)
+
+    def apply_weights(self,
+                      layer: torch.nn.Module,
+                      x: torch.Tensor,
+                      bias: Optional[torch.Tensor] = None) -> torch.Tensor:
+        from conch.ops.quantization.gemm import mixed_precision_gemm
+
+        w_q, w_s, w_zp, _ = self._get_weight_params(layer)
+
+        output = mixed_precision_gemm(
+            x=x,
+            w_q_packed=w_q.data,
+            w_s=w_s.data,
+            w_zp=w_zp.data if w_zp is not None else None,
+            weight_size_bits=self.config.weight_type.size_bits,
+            weight_bias=self.config.weight_type.bias,
+            group_size=self.config.group_size,
+        )
+
+        if bias is not None:
+            output.add_(bias) # In-place add
+
+        return output
diff --git a/vllm/model_executor/layers/quantization/kernels/mixed_precision/machete.py b/vllm/model_executor/layers/quantization/kernels/mixed_precision/machete.py
index 12eb9d104bf20..851fd155465d4 100644
--- a/vllm/model_executor/layers/quantization/kernels/mixed_precision/machete.py
+++ b/vllm/model_executor/layers/quantization/kernels/mixed_precision/machete.py
@@ -14,6 +14,7 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import (
     pack_quantized_values_into_int32, unpack_quantized_values_into_int32)
 from vllm.model_executor.parameter import (BasevLLMParameter,
                                            permute_param_layout_)
+from vllm.platforms import current_platform

 from .MPLinearKernel import MPLinearKernel, MPLinearLayerConfig

@@ -27,6 +28,9 @@ class MacheteLinearKernel(MPLinearKernel):
     @classmethod
     def can_implement(cls,
                       c: MPLinearLayerConfig) -> tuple[bool, Optional[str]]:
+        # Machete uses CUTLASS, so it can only be compatible with Nvidia
+        if not current_platform.is_cuda():
+            return False, "Machete only supported on CUDA"

         if c.has_g_idx and\
             c.partition_weight_shape[0] != c.full_weight_shape[0]:
diff --git a/vllm/model_executor/layers/quantization/kernels/mixed_precision/marlin.py b/vllm/model_executor/layers/quantization/kernels/mixed_precision/marlin.py
index 1597492a5cf65..73e0b17ea85aa 100644
--- a/vllm/model_executor/layers/quantization/kernels/mixed_precision/marlin.py
+++ b/vllm/model_executor/layers/quantization/kernels/mixed_precision/marlin.py
@@ -13,6 +13,7 @@ from vllm.model_executor.layers.quantization.utils.marlin_utils import (
     marlin_zero_points, query_marlin_supported_quant_types, unpack_cols)
 from vllm.model_executor.parameter import (BasevLLMParameter,
                                            permute_param_layout_)
+from vllm.platforms import current_platform

 from .MPLinearKernel import MPLinearKernel, MPLinearLayerConfig

@@ -26,6 +27,9 @@ class MarlinLinearKernel(MPLinearKernel):
     @classmethod
     def can_implement(cls,
                       c: MPLinearLayerConfig) -> tuple[bool, Optional[str]]:
+        # Marlin uses inline PTX, so it can only be compatible with Nvidia
+        if not current_platform.is_cuda():
+            return False, "Marlin only supported on CUDA"

         quant_types = query_marlin_supported_quant_types(c.zero_points)
         if c.weight_type not in quant_types: