Mirror of https://git.datalinker.icu/vllm-project/vllm.git (synced 2025-12-18 06:45:01 +08:00)
[Kernel] Add Conch backend for mixed-precision linear layer (#19818)
Signed-off-by: Jacob Manning <jmanning+oss@stackav.com>
parent 47043eb678
commit bf03ff3575
@@ -17,3 +17,4 @@ setuptools>=77.0.3,<80.0.0
 setuptools-scm>=8
 runai-model-streamer==0.11.0
 runai-model-streamer-s3==0.11.0
+conch-triton-kernels==1.2.1
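As a quick sanity check that the pinned dependency is present, a minimal sketch using only the standard library (the distribution name is taken from the requirements line above):

# Sketch: verify the conch-triton-kernels pin at runtime (standard library only).
from importlib.metadata import PackageNotFoundError, version

try:
    print("conch-triton-kernels", version("conch-triton-kernels"), "(pinned: 1.2.1)")
except PackageNotFoundError:
    print("conch-triton-kernels is not installed; "
          "run `pip install conch-triton-kernels==1.2.1`")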
@@ -8,6 +8,8 @@ from vllm.model_executor.layers.quantization.kernels.mixed_precision.allspark im
     AllSparkLinearKernel)
 from vllm.model_executor.layers.quantization.kernels.mixed_precision.bitblas import ( # noqa: E501
     BitBLASLinearKernel)
+from vllm.model_executor.layers.quantization.kernels.mixed_precision.conch import ( # noqa: E501
+    ConchLinearKernel)
 from vllm.model_executor.layers.quantization.kernels.mixed_precision.exllama import ( # noqa: E501
     ExllamaLinearKernel)
 from vllm.model_executor.layers.quantization.kernels.mixed_precision.machete import ( # noqa: E501
@@ -24,6 +26,7 @@ _POSSIBLE_KERNELS: list[type[MPLinearKernel]] = [
     AllSparkLinearKernel,
     MarlinLinearKernel,
     BitBLASLinearKernel,
+    ConchLinearKernel,
     ExllamaLinearKernel,
 ]
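For orientation, kernels registered in _POSSIBLE_KERNELS are tried in order and the first one whose can_implement() accepts the layer config is used. A minimal sketch of that selection pattern follows; choose_kernel is an illustrative helper, not the module's actual selector, and the real ordering and fallback logic may differ.

# Illustrative sketch only: first-match selection over a list of MPLinearKernel
# subclasses, mirroring how _POSSIBLE_KERNELS is consumed. Not vLLM's real selector.
from typing import Optional


def choose_kernel(config, kernels, compute_capability: Optional[int] = None):
    failure_reasons = []
    for kernel_cls in kernels:
        # Skip kernels that need a newer GPU than available
        # (e.g. ConchLinearKernel reports a minimum capability of 80).
        if (compute_capability is not None
                and kernel_cls.get_min_capability() > compute_capability):
            failure_reasons.append(
                f"{kernel_cls.__name__}: requires capability "
                f">= {kernel_cls.get_min_capability()}")
            continue
        can_use, reason = kernel_cls.can_implement(config)
        if can_use:
            return kernel_cls
        failure_reasons.append(f"{kernel_cls.__name__}: {reason}")
    raise ValueError("No compatible kernel found:\n" + "\n".join(failure_reasons))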
@@ -0,0 +1,92 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+
+from importlib.util import find_spec
+from typing import Final, Optional
+
+import torch
+
+from vllm.model_executor.parameter import (BasevLLMParameter,
+                                           permute_param_layout_)
+from vllm.scalar_type import scalar_types
+
+from .MPLinearKernel import MPLinearKernel, MPLinearLayerConfig
+
+_CONCH_SUPPORTED_WEIGHT_TYPES: Final = [
+    scalar_types.uint4, scalar_types.uint8, scalar_types.uint4b8,
+    scalar_types.uint8b128
+]
+_CONCH_SUPPORTED_GROUP_SIZES: Final = [-1, 128]
+
+
+class ConchLinearKernel(MPLinearKernel):
+
+    @classmethod
+    def get_min_capability(cls) -> int:
+        return 80
+
+    @classmethod
+    def can_implement(cls,
+                      c: MPLinearLayerConfig) -> tuple[bool, Optional[str]]:
+        if c.weight_type not in _CONCH_SUPPORTED_WEIGHT_TYPES:
+            error_msg = f"Weight type ({c.weight_type}) not supported by "\
+                        "ConchLinearKernel, supported types are: " \
+                        f"{_CONCH_SUPPORTED_WEIGHT_TYPES}"
+            return False, error_msg
+
+        if c.group_size not in _CONCH_SUPPORTED_GROUP_SIZES:
+            error_msg = f"Group size ({c.group_size}) not supported by "\
+                        "ConchLinearKernel, supported group sizes are: " \
+                        f"{_CONCH_SUPPORTED_GROUP_SIZES}"
+            return False, error_msg
+
+        if find_spec("conch") is None:
+            error_msg = "conch-triton-kernels is not installed, please "\
+                        "install it via `pip install conch-triton-kernels` "\
+                        "and try again!"
+            return False, error_msg
+
+        return True, None
+
+    # note assumes that
+    #  `weight_packed` is: {input_dim = 0, output_dim = 1, packed_dim = 0}
+    #  `weight_scale` is: {input_dim = 0, output_dim = 1}
+    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
+
+        def transform_w_q(x):
+            assert isinstance(x, BasevLLMParameter)
+            permute_param_layout_(x, input_dim=0, output_dim=1, packed_dim=0)
+            x.data = x.data.contiguous()
+            return x
+
+        def transform_w_s(x):
+            assert isinstance(x, BasevLLMParameter)
+            permute_param_layout_(x, input_dim=0, output_dim=1)
+            x.data = x.data.contiguous()
+            return x
+
+        self._transform_param(layer, self.w_q_name, transform_w_q)
+        self._transform_param(layer, self.w_s_name, transform_w_s)
+
+    def apply_weights(self,
+                      layer: torch.nn.Module,
+                      x: torch.Tensor,
+                      bias: Optional[torch.Tensor] = None) -> torch.Tensor:
+        from conch.ops.quantization.gemm import mixed_precision_gemm
+
+        w_q, w_s, w_zp, _ = self._get_weight_params(layer)
+
+        output = mixed_precision_gemm(
+            x=x,
+            w_q_packed=w_q.data,
+            w_s=w_s.data,
+            w_zp=w_zp.data if w_zp is not None else None,
+            weight_size_bits=self.config.weight_type.size_bits,
+            weight_bias=self.config.weight_type.bias,
+            group_size=self.config.group_size,
+        )
+
+        if bias is not None:
+            output.add_(bias)  # In-place add
+
+        return output
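A quick way to see what apply_weights() forwards as weight_size_bits and weight_bias is to inspect the supported scalar types directly. A small sketch, assuming vLLM is installed; size_bits and bias are the same attributes read in apply_weights above:

# Sketch: inspect the size_bits / bias values that ConchLinearKernel forwards
# to conch's mixed_precision_gemm for each supported weight type.
from vllm.scalar_type import scalar_types

for weight_type in (scalar_types.uint4, scalar_types.uint8,
                    scalar_types.uint4b8, scalar_types.uint8b128):
    print(f"{weight_type}: size_bits={weight_type.size_bits}, "
          f"bias={weight_type.bias}")
# e.g. uint4b8 is a 4-bit unsigned type stored with a bias of 8, so the kernel
# receives weight_size_bits=4 and weight_bias=8.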
@@ -14,6 +14,7 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import (
     pack_quantized_values_into_int32, unpack_quantized_values_into_int32)
 from vllm.model_executor.parameter import (BasevLLMParameter,
                                            permute_param_layout_)
+from vllm.platforms import current_platform
 
 from .MPLinearKernel import MPLinearKernel, MPLinearLayerConfig
 
@@ -27,6 +28,9 @@ class MacheteLinearKernel(MPLinearKernel):
     @classmethod
     def can_implement(cls,
                       c: MPLinearLayerConfig) -> tuple[bool, Optional[str]]:
+        # Machete uses CUTLASS, so it can only be compatible with Nvidia
+        if not current_platform.is_cuda():
+            return False, "Machete only supported on CUDA"
+
         if c.has_g_idx and\
             c.partition_weight_shape[0] != c.full_weight_shape[0]:
@@ -13,6 +13,7 @@ from vllm.model_executor.layers.quantization.utils.marlin_utils import (
     marlin_zero_points, query_marlin_supported_quant_types, unpack_cols)
 from vllm.model_executor.parameter import (BasevLLMParameter,
                                            permute_param_layout_)
+from vllm.platforms import current_platform
 
 from .MPLinearKernel import MPLinearKernel, MPLinearLayerConfig
 
@@ -26,6 +27,9 @@ class MarlinLinearKernel(MPLinearKernel):
     @classmethod
     def can_implement(cls,
                       c: MPLinearLayerConfig) -> tuple[bool, Optional[str]]:
+        # Marlin uses inline PTX, so it can only be compatible with Nvidia
+        if not current_platform.is_cuda():
+            return False, "Marlin only supported on CUDA"
+
         quant_types = query_marlin_supported_quant_types(c.zero_points)
         if c.weight_type not in quant_types:
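Both the Machete and Marlin hunks gate on the same platform check. A small sketch, assuming vLLM is importable, that reports what the new guards evaluate to on the local machine:

# Sketch: evaluate the same platform check the new guards use.
from vllm.platforms import current_platform

if current_platform.is_cuda():
    print("CUDA platform: the new Machete/Marlin guards do not reject this machine.")
else:
    print("Non-CUDA platform: Machete and Marlin can_implement() now return "
          "(False, '... only supported on CUDA').")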