[Quantization] [Performance] Enable Marlin GEMM kernels for the calibration-free RTN-based quantization (#26051)
Signed-off-by: Alex Kogan <alex.kogan@oracle.com>
Signed-off-by: Alex Kogan <82225080+sakogan@users.noreply.github.com>
parent f89f599395
commit 89342ce4c0
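
For context, RTN is vLLM's calibration-free round-to-nearest weight quantization; this change routes its linear and MoE layers through Marlin GEMM kernels instead of dequantize-then-matmul. A minimal usage sketch (the model name is a placeholder, the envvar overrides are optional, and "rtn" is assumed to be the registered method name):

import os

os.environ["RTN_NUM_BITS"] = "4"      # default is 8-bit weights
os.environ["RTN_GROUP_SIZE"] = "128"  # per-group scale granularity

from vllm import LLM

llm = LLM(model="meta-llama/Llama-3.1-8B-Instruct", quantization="rtn")
print(llm.generate("Hello")[0].outputs[0].text)
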
@@ -6,21 +6,16 @@ import os
from collections.abc import Callable
from typing import Any, Optional

import numpy as np
import torch
import torch.nn.functional as F
from torch.nn.parameter import Parameter

from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe import (
    FusedMoE,
    FusedMoEConfig,
    FusedMoEMethodBase,
)
from vllm.model_executor.layers.fused_moe.config import (
    FusedMoEConfig,
    FusedMoEQuantConfig,
    int4_w4a16_moe_quant_config,
    int8_w8a16_moe_quant_config,
)
from vllm.model_executor.layers.fused_moe.layer import FusedMoE, FusedMoEMethodBase
from vllm.model_executor.layers.linear import (
    LinearBase,
    LinearMethodBase,
@@ -31,6 +26,12 @@ from vllm.model_executor.layers.quantization.base_config import (
    QuantizationConfig,
    QuantizeMethodBase,
)
from vllm.model_executor.layers.quantization.utils import replace_parameter
from vllm.model_executor.layers.quantization.utils.marlin_utils import (
    apply_rtn_marlin_linear,
    marlin_make_workspace_new,
)
from vllm.scalar_type import scalar_types

logger = init_logger(__name__)
"""By default, use 8 bit as target precision, but it can be
@@ -41,6 +42,9 @@ NUM_BITS = os.getenv("RTN_NUM_BITS", "8")
overridden by setting the RTN_GROUP_SIZE envvar
"""
GROUP_SIZE = os.getenv("RTN_GROUP_SIZE", "128")
"""Global Marlin workspace shared by all modules
"""
workspace = None


class RTNConfig(QuantizationConfig):
@@ -60,6 +64,10 @@ class RTNConfig(QuantizationConfig):
                f"supported for RTN, but got {self.weight_bits} bits."
            )

        self.quant_type = (
            scalar_types.uint8b128 if self.weight_bits == 8 else scalar_types.uint4b8
        )

    def __repr__(self) -> str:
        return (
            f"RTNConfig(weight_bits={self.weight_bits}, group_size={self.group_size})"
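
The scalar type encodes the storage convention: uint8b128 and uint4b8 are unsigned codes with a fixed bias of 128 and 8 respectively, so a real value is recovered as (code - bias) * scale. A toy illustration of that convention (not the repo's rtn_quantize/rtn_dequantize implementation):

import torch

def toy_dequant(codes: torch.Tensor, scale: torch.Tensor, bias: int) -> torch.Tensor:
    # uint4b8 -> bias 8, uint8b128 -> bias 128
    return (codes.float() - bias) * scale

codes = torch.tensor([0, 8, 15], dtype=torch.uint8)    # 4-bit codes, bias 8
print(toy_dequant(codes, torch.tensor(0.05), bias=8))  # tensor([-0.4000, 0.0000, 0.3500])
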
@@ -221,7 +229,15 @@ class RTNLinearMethod(LinearMethodBase):
        layer.output_size_per_partition = output_size_per_partition

    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        fix_weights(layer, "weight")
        """Repack weights and scales for Marlin kernels."""
        weight_bits = self.quant_config.weight_bits

        weight, scale = repack_weights(layer.weight, layer.scale, weight_bits)

        replace_parameter(layer, "weight", weight)
        replace_parameter(layer, "scale", scale)

        init_workspace(layer.weight.device)

    def apply(
        self,
@@ -229,16 +245,16 @@ class RTNLinearMethod(LinearMethodBase):
        x: torch.Tensor,
        bias: torch.Tensor | None = None,
    ) -> torch.Tensor:
        qweight = layer.weight
        scale = layer.scale

        weight = rtn_dequantize(qweight, scale)
        out = F.linear(x, weight)
        del weight
        if bias is not None:
            out.add_(bias)

        return out
        return apply_rtn_marlin_linear(
            input=x,
            weight=layer.weight,
            weight_scale=layer.scale,
            workspace=workspace,
            quant_type=self.quant_config.quant_type,
            output_size_per_partition=layer.output_size_per_partition,
            input_size_per_partition=layer.input_size_per_partition,
            bias=bias,
        )


class RTNMoEMethod(FusedMoEMethodBase):
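
The deleted branch was the slow reference path: dequantize the whole weight, then run a dense F.linear. That path still makes a handy correctness oracle for the new Marlin call; a test sketch under assumed shapes (not part of the commit):

import torch
import torch.nn.functional as F

def reference_forward(x, qweight, scale, bias=None):
    # What the removed code computed: full dequantize, then dense matmul.
    weight = rtn_dequantize(qweight, scale)
    out = F.linear(x, weight)
    if bias is not None:
        out.add_(bias)
    return out

# marlin_out = apply_rtn_marlin_linear(input=x, weight=repacked_w, weight_scale=repacked_s, ...)
# torch.testing.assert_close(marlin_out, reference_forward(x, qweight, scale), rtol=1e-2, atol=1e-2)
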
@@ -315,28 +331,27 @@ class RTNMoEMethod(FusedMoEMethodBase):
        set_weight_attrs(w2_weight, extra_weight_attrs)

    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        """Repack weights and scales for Marlin kernels."""
        weight_bits = self.quant_config.weight_bits
        fix_weights(layer, "w13_weight", weight_bits == 4)
        fix_weights(layer, "w2_weight", weight_bits == 4)

        w13_weight, w13_scale = repack_weights(
            layer.w13_weight, layer.w13_scale, weight_bits
        )
        replace_parameter(layer, "w13_weight", w13_weight)
        replace_parameter(layer, "w13_scale", w13_scale)

        w2_weight, w2_scale = repack_weights(
            layer.w2_weight, layer.w2_scale, weight_bits
        )
        replace_parameter(layer, "w2_weight", w2_weight)
        replace_parameter(layer, "w2_scale", w2_scale)

        init_workspace(layer.w13_weight.device)

    def get_fused_moe_quant_config(
        self, layer: torch.nn.Module
    ) -> FusedMoEQuantConfig | None:
        weight_bits = self.quant_config.weight_bits
        group_size = self.quant_config.group_size
        assert weight_bits == 4 or weight_bits == 8
        config_builder = (
            int4_w4a16_moe_quant_config
            if weight_bits == 4
            else int8_w8a16_moe_quant_config
        )
        return config_builder(
            w1_scale=layer.w13_scale,
            w2_scale=layer.w2_scale,
            w1_zp=None,
            w2_zp=None,
            block_shape=[0, group_size],
        )
        return None

    def apply(
        self,
@@ -366,8 +381,6 @@ class RTNMoEMethod(FusedMoEMethodBase):
        if enable_eplb:
            raise NotImplementedError("EPLB not supported for `RTNMoEMethod` yet.")

        from vllm.model_executor.layers.fused_moe import fused_experts

        topk_weights, topk_ids, _ = FusedMoE.select_experts(
            hidden_states=x,
            router_logits=router_logits,
@@ -383,18 +396,22 @@ class RTNMoEMethod(FusedMoEMethodBase):
            indices_type=self.topk_indices_dtype,
        )

        return fused_experts(
        return torch.ops.vllm.fused_marlin_moe(
            x,
            layer.w13_weight,
            layer.w2_weight,
            topk_weights=topk_weights,
            topk_ids=topk_ids,
            inplace=True,
            activation=activation,
            getattr(layer, "w13_bias", None),
            getattr(layer, "w2_bias", None),
            layer.w13_scale,
            layer.w2_scale,
            router_logits,
            topk_weights,
            topk_ids,
            quant_type_id=self.quant_config.quant_type.id,
            apply_router_weight_on_input=apply_router_weight_on_input,
            global_num_experts=global_num_experts,
            expert_map=expert_map,
            quant_config=self.moe_quant_config,
            workspace=workspace,
        )

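For orientation, what the fused call computes per token, written out as a naive dense loop (a sketch with assumed shapes and a SiLU-and-mul activation; the real kernel works directly on the quantized, Marlin-repacked experts):

import torch
import torch.nn.functional as F

def naive_moe(x, w13, w2, topk_weights, topk_ids):
    # x: (T, H); w13: (E, 2*I, H); w2: (E, H, I); topk_*: (T, K)
    out = torch.zeros_like(x)
    for t in range(x.shape[0]):
        for wt, e in zip(topk_weights[t].tolist(), topk_ids[t].tolist()):
            gate_up = w13[e] @ x[t]        # (2*I,)
            gate, up = gate_up.chunk(2)
            h = F.silu(gate) * up          # assumed "silu" activation
            out[t] += wt * (w2[e] @ h)     # weighted combine over the top-k experts
    return out
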
@@ -504,18 +521,133 @@ def rtn_dequantize(tensor: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
    return input_deq


def fix_weights(layer: torch.nn.Module, param_name: str, reshape: bool = False):
    """torch.compile does not know how to deal with a Parameter subclass
    (aka RTNParameter). As we don't really need RTNParameters for the
    forward pass, we replace them with equivalent instances of Parameters.
def _get_perms():
    perm = []
    for i in range(32):
        perm1 = []
        col = i // 4
        for block in [0, 1]:
            for row in [
                2 * (i % 4),
                2 * (i % 4) + 1,
                2 * (i % 4 + 4),
                2 * (i % 4 + 4) + 1,
            ]:
                perm1.append(16 * row + col + 8 * block)
        for j in range(4):
            perm.extend([p + 256 * j for p in perm1])

    perm_arr = np.array(perm)
    interleave = np.array([0, 2, 4, 6, 1, 3, 5, 7])
    perm_arr = perm_arr.reshape((-1, 8))[:, interleave].ravel()
    perm_tensor = torch.from_numpy(perm_arr)
    scale_perm = []
    for i in range(8):
        scale_perm.extend([i + 8 * j for j in range(8)])
    scale_perm_single = []
    for i in range(4):
        scale_perm_single.extend([2 * i + j for j in [0, 1, 8, 9, 16, 17, 24, 25]])
    return perm_tensor, scale_perm, scale_perm_single


_perm, _scale_perm, _scale_perm_single = _get_perms()

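A quick property check of the tables above (not part of the commit): each should be a permutation of its index range.

assert sorted(_perm.tolist()) == list(range(1024))
assert sorted(_scale_perm) == list(range(64))
assert sorted(_scale_perm_single) == list(range(32))
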
def pack_for_marlin(weight, scale, qbits):
    batch = weight.shape[0]

    n = weight.size(1)
    k = weight.size(2)
    groupsize = k // scale.size(2)

    tile = 16
    s = scale.permute(0, 2, 1)  # transpose
    w = weight.permute(0, 2, 1)  # transpose
    if groupsize != k:
        w = w.reshape((batch, -1, groupsize, n))
        w = w.permute(0, 2, 1, 3)
        w = w.reshape((batch, groupsize, -1))
        s = s.reshape((batch, 1, -1))

    if groupsize != k:
        w = w.reshape((batch, groupsize, -1, n))
        w = w.permute(0, 2, 1, 3)
        w = w.reshape((batch, k, n)).contiguous()
        s = s.reshape((batch, -1, len(_scale_perm)))[:, :, _scale_perm]
    else:
        s = s.reshape((batch, -1, len(_scale_perm_single)))[:, :, _scale_perm_single]
    s = s.reshape((batch, -1, n)).contiguous()
    w = w.reshape((batch, k // tile, tile, n // tile, tile))
    w = w.permute((0, 1, 3, 2, 4))
    w = w.reshape((batch, k // tile, n * tile))
    res = w
    res = res.reshape((batch, -1, _perm.numel()))[:, :, _perm].reshape(res.shape)
    if qbits == 4:
        q = torch.zeros(
            (batch, res.shape[1], res.shape[2] // 2), dtype=torch.int8, device=w.device
        )
        for i in range(2):
            q |= res[:, :, i::2] << 4 * i
        q = q.reshape(batch, -1, n).contiguous()
    else:
        q = res.clone()
        q[:, :, 2::8] = res[:, :, 4::8]
        q[:, :, 3::8] = res[:, :, 5::8]
        q[:, :, 4::8] = res[:, :, 2::8]
        q[:, :, 5::8] = res[:, :, 3::8]
        q = q.reshape(batch, -1, n).to(torch.int8).contiguous()

    return q, s

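In the 4-bit branch above, each byte holds two adjacent (already permuted) values: the even-indexed one in the low nibble, the odd-indexed one in the high nibble. A two-value illustration (not from the commit):

lo, hi = 3, 11                  # res[..., 0::2] and res[..., 1::2] entries
byte = (lo << 0) | (hi << 4)    # matches q |= res[:, :, i::2] << 4 * i
assert byte == 0xB3
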
def repack_8bit_into_32bit(input):
    output = torch.zeros(
        (input.shape[0], input.shape[1], input.shape[2] // 4),
        dtype=torch.int32,
        device=input.device,
    )
    for i in range(4):
        output |= (input[:, :, i::4] & 0xFF).to(torch.int32) << 8 * i

    return output

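Each int32 output word packs four consecutive input bytes little-endian, i.e. byte j of a group lands at bits 8*j. For example (not from the commit, using the helper defined above):

import torch

inp = torch.tensor([[[0x01, 0x02, 0x03, 0x04]]], dtype=torch.uint8)
assert repack_8bit_into_32bit(inp).item() == 0x04030201
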
def repack_weights(qweight, scale, weight_bits):
    batch_present = len(qweight.shape) == 3
    if not batch_present:
        qweight = qweight.unsqueeze(0)
        scale = scale.unsqueeze(0)

    if weight_bits == 4:
        """Unpack two 4-bit values from each byte.
        """
    old_weight = getattr(layer, param_name)
    assert isinstance(old_weight, RTNParameter)
    data = old_weight.data.data
        qweight_unpacked = torch.empty(
            (qweight.shape[0], qweight.shape[1] * 2, qweight.shape[2]),
            dtype=torch.uint8,
            device=qweight.device,
        )
        for i in range(2):
            qweight_unpacked[:, :, i::2] = ((qweight << 4 * (1 - i)) >> 4).reshape(
                qweight.shape[0], qweight.shape[1] * 2, qweight.shape[2] // 2
            )
    else:
        qweight_unpacked = qweight

    delattr(layer, param_name)
    qweight_packed, scale_packed = pack_for_marlin(qweight_unpacked, scale, weight_bits)
    """Marlin kernels expect tensors in int32 format in a certain shape
    """
    qweight_repacked = repack_8bit_into_32bit(qweight_packed.to(torch.uint8))
    qweight_reshaped = qweight_repacked.reshape(
        qweight.shape[0], qweight.shape[2] // 16, -1
    )
    if not batch_present:
        qweight_reshaped = qweight_reshaped.squeeze(0)
        scale_packed = scale_packed.squeeze(0)

    if reshape:
        data = data.reshape(old_weight.shape[0], old_weight.shape[1] * 2, -1)
    new_weight = Parameter(data=data, requires_grad=False)
    layer.register_parameter(param_name, new_weight)
    return qweight_reshaped, scale_packed


def init_workspace(device):
    global workspace
    if workspace is None:
        workspace = marlin_make_workspace_new(device, 4)

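Rough shape bookkeeping for the 8-bit linear case, derived from the code above (a sketch assuming the stored weight is (N, K) = (output features, input features)): after packing and the final reshape, the weight should be an int32 tensor of shape (K // 16, 4 * N), and the scales (K // G, N) for group size G.

N, K = 4096, 11008                     # hypothetical layer sizes
total_int32_words = N * K // 4         # four 8-bit values per int32 word
assert total_int32_words // (K // 16) == 4 * N   # columns after reshape(K // 16, -1)
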
@@ -528,3 +528,48 @@ def apply_awq_marlin_linear(
    )

    return output.reshape(out_shape)


def apply_rtn_marlin_linear(
    input: torch.Tensor,
    weight: torch.Tensor,
    weight_scale: torch.Tensor,
    workspace: torch.Tensor,
    quant_type: ScalarType,
    output_size_per_partition: int,
    input_size_per_partition: int,
    bias: torch.Tensor | None = None,
    use_fp32_reduce: bool = USE_FP32_REDUCE_DEFAULT,
) -> torch.Tensor:
    reshaped_x = input.reshape(-1, input.shape[-1])
    out_shape = input.shape[:-1] + (output_size_per_partition,)

    use_atomic_add = should_use_atomic_add_reduce(
        m=reshaped_x.size(0),
        n=output_size_per_partition,
        k=reshaped_x.size(1),
        device=input.device,
        dtype=input.dtype,
    )

    output = ops.gptq_marlin_gemm(
        reshaped_x,
        None,
        weight,
        bias,
        weight_scale,
        None,
        None,
        None,
        None,
        workspace,
        quant_type,
        size_m=reshaped_x.shape[0],
        size_n=output_size_per_partition,
        size_k=input_size_per_partition,
        use_atomic_add=use_atomic_add,
        use_fp32_reduce=use_fp32_reduce,
        is_zp_float=False,
    )

    return output.reshape(out_shape)
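
The reshapes at the top and bottom of apply_rtn_marlin_linear only fold all leading dimensions into the GEMM's M dimension and restore them afterwards; a shape-only illustration (hypothetical sizes):

import torch

x = torch.randn(2, 7, 4096)                  # (batch, seq, hidden)
output_size_per_partition = 11008
reshaped_x = x.reshape(-1, x.shape[-1])      # (14, 4096): size_m=14, size_k=4096
out_shape = x.shape[:-1] + (output_size_per_partition,)
assert reshaped_x.shape == (14, 4096) and tuple(out_shape) == (2, 7, 11008)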