[Kernel] Add Conch backend for mixed-precision linear layer (#19818)

Signed-off-by: Jacob Manning <jmanning+oss@stackav.com>
Author: Jacob Manning, 2025-07-09 16:17:55 -04:00, committed by GitHub
parent 47043eb678
commit bf03ff3575
5 changed files with 105 additions and 1 deletion


@@ -17,3 +17,4 @@ setuptools>=77.0.3,<80.0.0
setuptools-scm>=8
runai-model-streamer==0.11.0
runai-model-streamer-s3==0.11.0
conch-triton-kernels==1.2.1
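The pinned package is an optional runtime dependency: the new ConchLinearKernel (added below) does not import it at module load time but probes for it with importlib and reports a readable failure reason if it is missing. A minimal sketch of that probe, assuming only that the package's import name is "conch" as used in the kernel:

# Sketch mirroring the optional-dependency check in ConchLinearKernel.can_implement:
# find_spec returns None when the package is absent, so kernel selection can skip
# Conch with a helpful message instead of failing with an ImportError.
from importlib.util import find_spec

if find_spec("conch") is None:
    print("conch-triton-kernels is not installed; "
          "install it via `pip install conch-triton-kernels`")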

vllm/model_executor/layers/quantization/kernels/mixed_precision/__init__.py

@@ -8,6 +8,8 @@ from vllm.model_executor.layers.quantization.kernels.mixed_precision.allspark import ( # noqa: E501
    AllSparkLinearKernel)
from vllm.model_executor.layers.quantization.kernels.mixed_precision.bitblas import ( # noqa: E501
    BitBLASLinearKernel)
from vllm.model_executor.layers.quantization.kernels.mixed_precision.conch import ( # noqa: E501
    ConchLinearKernel)
from vllm.model_executor.layers.quantization.kernels.mixed_precision.exllama import ( # noqa: E501
    ExllamaLinearKernel)
from vllm.model_executor.layers.quantization.kernels.mixed_precision.machete import ( # noqa: E501
@@ -24,6 +26,7 @@ _POSSIBLE_KERNELS: list[type[MPLinearKernel]] = [
    AllSparkLinearKernel,
    MarlinLinearKernel,
    BitBLASLinearKernel,
    ConchLinearKernel,
    ExllamaLinearKernel,
]
@@ -80,4 +83,4 @@ def choose_mp_linear_kernel(
    raise ValueError(
        "Failed to find a kernel that can implement the "\
        "WNA16 linear layer. Reasons: \n"
-        + '\n'.join(failure_reasons))
+        + '\n'.join(failure_reasons))
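For context on where these registrations land: choose_mp_linear_kernel walks _POSSIBLE_KERNELS in order, asks each class whether it can implement the layer config, and only raises the ValueError above when every candidate declines. The sketch below is an illustrative reconstruction of that loop, not the verbatim vLLM function; the name choose_mp_linear_kernel_sketch and the simplified capability handling are assumptions.

# Illustrative selection loop: each kernel reports (ok, reason); reasons are
# collected so the final error explains why every backend was rejected.
def choose_mp_linear_kernel_sketch(config, compute_capability: int):
    failure_reasons = []
    for kernel in _POSSIBLE_KERNELS:
        if kernel.get_min_capability() > compute_capability:
            failure_reasons.append(
                f"{kernel.__name__} requires compute capability "
                f">= {kernel.get_min_capability()}")
            continue
        ok, reason = kernel.can_implement(config)
        if ok:
            return kernel
        failure_reasons.append(f"{kernel.__name__}: {reason}")
    raise ValueError("Failed to find a kernel that can implement the "
                     "WNA16 linear layer. Reasons: \n"
                     + '\n'.join(failure_reasons))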

vllm/model_executor/layers/quantization/kernels/mixed_precision/conch.py (new file)

@@ -0,0 +1,92 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from importlib.util import find_spec
from typing import Final, Optional

import torch

from vllm.model_executor.parameter import (BasevLLMParameter,
                                           permute_param_layout_)
from vllm.scalar_type import scalar_types

from .MPLinearKernel import MPLinearKernel, MPLinearLayerConfig

_CONCH_SUPPORTED_WEIGHT_TYPES: Final = [
    scalar_types.uint4, scalar_types.uint8, scalar_types.uint4b8,
    scalar_types.uint8b128
]
_CONCH_SUPPORTED_GROUP_SIZES: Final = [-1, 128]


class ConchLinearKernel(MPLinearKernel):

    @classmethod
    def get_min_capability(cls) -> int:
        return 80

    @classmethod
    def can_implement(cls,
                      c: MPLinearLayerConfig) -> tuple[bool, Optional[str]]:
        if c.weight_type not in _CONCH_SUPPORTED_WEIGHT_TYPES:
            error_msg = f"Weight type ({c.weight_type}) not supported by "\
                        "ConchLinearKernel, supported types are: " \
                        f"{_CONCH_SUPPORTED_WEIGHT_TYPES}"
            return False, error_msg
        if c.group_size not in _CONCH_SUPPORTED_GROUP_SIZES:
            error_msg = f"Group size ({c.group_size}) not supported by "\
                        "ConchLinearKernel, supported group sizes are: " \
                        f"{_CONCH_SUPPORTED_GROUP_SIZES}"
            return False, error_msg
        if find_spec("conch") is None:
            error_msg = "conch-triton-kernels is not installed, please "\
                        "install it via `pip install conch-triton-kernels` "\
                        "and try again!"
            return False, error_msg
        return True, None

    # note assumes that
    # `weight_packed` is: {input_dim = 0, output_dim = 1, packed_dim = 0}
    # `weight_scale` is: {input_dim = 0, output_dim = 1}
    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:

        def transform_w_q(x):
            assert isinstance(x, BasevLLMParameter)
            permute_param_layout_(x, input_dim=0, output_dim=1, packed_dim=0)
            x.data = x.data.contiguous()
            return x

        def transform_w_s(x):
            assert isinstance(x, BasevLLMParameter)
            permute_param_layout_(x, input_dim=0, output_dim=1)
            x.data = x.data.contiguous()
            return x

        self._transform_param(layer, self.w_q_name, transform_w_q)
        self._transform_param(layer, self.w_s_name, transform_w_s)

    def apply_weights(self,
                      layer: torch.nn.Module,
                      x: torch.Tensor,
                      bias: Optional[torch.Tensor] = None) -> torch.Tensor:
        from conch.ops.quantization.gemm import mixed_precision_gemm

        w_q, w_s, w_zp, _ = self._get_weight_params(layer)

        output = mixed_precision_gemm(
            x=x,
            w_q_packed=w_q.data,
            w_s=w_s.data,
            w_zp=w_zp.data if w_zp is not None else None,
            weight_size_bits=self.config.weight_type.size_bits,
            weight_bias=self.config.weight_type.bias,
            group_size=self.config.group_size,
        )

        if bias is not None:
            output.add_(bias)  # In-place add

        return output
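A note on the arguments forwarded to mixed_precision_gemm: a biased unsigned type such as scalar_types.uint4b8 stores low-bit codes with an implicit offset (weight_bias), and every group_size input channels share one scale from weight_scale. The fused Triton kernel never materializes the dequantized matrix; the toy snippet below only illustrates the arithmetic those parameters imply, and the shapes and values are invented for the example (assumed semantics: w_fp = (q - weight_bias) * scale, applied per group).

# Toy illustration of the per-group dequantization implied by weight_bias /
# group_size; this is not the kernel's actual Triton implementation.
import torch

weight_bias = 8           # e.g. a 4-bit biased type stores codes 0..15 around 8
group_size = 128          # one scale per 128 input channels

q = torch.randint(0, 16, (256, 64))         # unpacked 4-bit codes (toy K x N shape)
scales = torch.rand(256 // group_size, 64)  # one scale row per input-channel group

w_fp = (q.float() - weight_bias) * scales.repeat_interleave(group_size, dim=0)
# The fused kernel effectively computes y = x @ w_fp without building w_fp.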

vllm/model_executor/layers/quantization/kernels/mixed_precision/machete.py

@@ -14,6 +14,7 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import (
    pack_quantized_values_into_int32, unpack_quantized_values_into_int32)
from vllm.model_executor.parameter import (BasevLLMParameter,
                                           permute_param_layout_)
from vllm.platforms import current_platform

from .MPLinearKernel import MPLinearKernel, MPLinearLayerConfig

@@ -27,6 +28,9 @@ class MacheteLinearKernel(MPLinearKernel):
    @classmethod
    def can_implement(cls,
                      c: MPLinearLayerConfig) -> tuple[bool, Optional[str]]:
        # Machete uses CUTLASS, so it can only be compatible with Nvidia
        if not current_platform.is_cuda():
            return False, "Machete only supported on CUDA"

        if c.has_g_idx and\
            c.partition_weight_shape[0] != c.full_weight_shape[0]:
vllm/model_executor/layers/quantization/kernels/mixed_precision/marlin.py

@@ -13,6 +13,7 @@ from vllm.model_executor.layers.quantization.utils.marlin_utils import (
    marlin_zero_points, query_marlin_supported_quant_types, unpack_cols)
from vllm.model_executor.parameter import (BasevLLMParameter,
                                           permute_param_layout_)
from vllm.platforms import current_platform

from .MPLinearKernel import MPLinearKernel, MPLinearLayerConfig

@@ -26,6 +27,9 @@ class MarlinLinearKernel(MPLinearKernel):
    @classmethod
    def can_implement(cls,
                      c: MPLinearLayerConfig) -> tuple[bool, Optional[str]]:
        # Marlin uses inline PTX, so it can only be compatible with Nvidia
        if not current_platform.is_cuda():
            return False, "Marlin only supported on CUDA"

        quant_types = query_marlin_supported_quant_types(c.zero_points)
        if c.weight_type not in quant_types:
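Both guards follow the same pattern: a vendor-specific backend declines early with a reason string on non-CUDA platforms, so kernel selection can fall through to a portable backend such as the new Triton-based Conch kernel. A minimal sketch of that pattern; SomeCudaOnlyKernel and the bare config parameter are illustrative placeholders, not vLLM names.

# Sketch of the platform-gating pattern added above: backends that depend on
# CUTLASS or inline PTX reject themselves up front with a human-readable reason,
# which is then surfaced by the kernel-selection error if nothing else fits.
from typing import Optional

from vllm.platforms import current_platform


class SomeCudaOnlyKernel:

    @classmethod
    def can_implement(cls, config) -> tuple[bool, Optional[str]]:
        if not current_platform.is_cuda():
            return False, "only supported on CUDA"
        # ... remaining backend-specific checks (weight type, group size, ...) ...
        return True, None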