[Quantization] [Performance] Enable Marlin GEMM kernels for the calibration-free RTN-based quantization (#26051)

Signed-off-by: Alex Kogan <alex.kogan@oracle.com>
Signed-off-by: Alex Kogan <82225080+sakogan@users.noreply.github.com>
Authored by Alex Kogan on 2025-10-13 14:52:54 -04:00; committed by GitHub
parent f89f599395
commit 89342ce4c0
2 changed files with 233 additions and 56 deletions

vllm/model_executor/layers/quantization/rtn.py

@@ -6,21 +6,16 @@ import os
 from collections.abc import Callable
 from typing import Any, Optional

+import numpy as np
 import torch
-import torch.nn.functional as F
 from torch.nn.parameter import Parameter

 from vllm.logger import init_logger
-from vllm.model_executor.layers.fused_moe import (
-    FusedMoE,
-    FusedMoEConfig,
-    FusedMoEMethodBase,
-)
 from vllm.model_executor.layers.fused_moe.config import (
+    FusedMoEConfig,
     FusedMoEQuantConfig,
-    int4_w4a16_moe_quant_config,
-    int8_w8a16_moe_quant_config,
 )
+from vllm.model_executor.layers.fused_moe.layer import FusedMoE, FusedMoEMethodBase
 from vllm.model_executor.layers.linear import (
     LinearBase,
     LinearMethodBase,
@@ -31,6 +26,12 @@ from vllm.model_executor.layers.quantization.base_config import (
     QuantizationConfig,
     QuantizeMethodBase,
 )
+from vllm.model_executor.layers.quantization.utils import replace_parameter
+from vllm.model_executor.layers.quantization.utils.marlin_utils import (
+    apply_rtn_marlin_linear,
+    marlin_make_workspace_new,
+)
+from vllm.scalar_type import scalar_types

 logger = init_logger(__name__)

 """By default, use 8 bit as target precision, but it can be
@@ -41,6 +42,9 @@ NUM_BITS = os.getenv("RTN_NUM_BITS", "8")
 overridden by setting the RTN_GROUP_SIZE envvar
 """
 GROUP_SIZE = os.getenv("RTN_GROUP_SIZE", "128")
+"""Global Marlin workspace shared by all modules
+"""
+workspace = None


 class RTNConfig(QuantizationConfig):
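Note that NUM_BITS and GROUP_SIZE are read once at module import time, so the env vars must be set before vLLM is imported. A minimal usage sketch (illustrative only, not part of this commit; the model name is an arbitrary example, and it assumes "rtn" is the registered quantization method name):

# Sketch: select 4-bit RTN with group size 128 before vLLM imports this module.
import os

os.environ["RTN_NUM_BITS"] = "4"      # target precision for RTN weights
os.environ["RTN_GROUP_SIZE"] = "128"  # per-group scale granularity

from vllm import LLM  # imported after the env vars are set

llm = LLM(model="facebook/opt-125m", quantization="rtn")  # example model (assumption)
print(llm.generate("Hello, my name is")[0].outputs[0].text)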
@@ -60,6 +64,10 @@ class RTNConfig(QuantizationConfig):
                 f"supported for RTN, but got {self.weight_bits} bits."
             )

+        self.quant_type = (
+            scalar_types.uint8b128 if self.weight_bits == 8 else scalar_types.uint4b8
+        )
+
     def __repr__(self) -> str:
         return (
             f"RTNConfig(weight_bits={self.weight_bits}, group_size={self.group_size})"
@@ -221,7 +229,15 @@ class RTNLinearMethod(LinearMethodBase):
         layer.output_size_per_partition = output_size_per_partition

     def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
-        fix_weights(layer, "weight")
+        """Repack weights and scales for Marlin kernels."""
+        weight_bits = self.quant_config.weight_bits
+        weight, scale = repack_weights(layer.weight, layer.scale, weight_bits)
+
+        replace_parameter(layer, "weight", weight)
+        replace_parameter(layer, "scale", scale)
+
+        init_workspace(layer.weight.device)

     def apply(
         self,
@@ -229,16 +245,16 @@
         x: torch.Tensor,
         bias: torch.Tensor | None = None,
     ) -> torch.Tensor:
-        qweight = layer.weight
-        scale = layer.scale
-
-        weight = rtn_dequantize(qweight, scale)
-        out = F.linear(x, weight)
-        del weight
-        if bias is not None:
-            out.add_(bias)
-
-        return out
+        return apply_rtn_marlin_linear(
+            input=x,
+            weight=layer.weight,
+            weight_scale=layer.scale,
+            workspace=workspace,
+            quant_type=self.quant_config.quant_type,
+            output_size_per_partition=layer.output_size_per_partition,
+            input_size_per_partition=layer.input_size_per_partition,
+            bias=bias,
+        )


 class RTNMoEMethod(FusedMoEMethodBase):
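The deleted eager path above remains a useful numerical reference: it materializes the dequantized weight and runs a dense GEMM, which the Marlin kernel now fuses into a single call. A sketch restating that old path (illustrative only; rtn_dequantize is the helper defined elsewhere in this file):

import torch
import torch.nn.functional as F

def reference_linear(x, qweight, scale, bias=None):
    weight = rtn_dequantize(qweight, scale)  # eager dequantization (old code path)
    out = F.linear(x, weight)                # dense floating-point GEMM
    if bias is not None:
        out = out + bias
    return out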
@@ -315,28 +331,27 @@ class RTNMoEMethod(FusedMoEMethodBase):
         set_weight_attrs(w2_weight, extra_weight_attrs)

     def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
+        """Repack weights and scales for Marlin kernels."""
         weight_bits = self.quant_config.weight_bits
-        fix_weights(layer, "w13_weight", weight_bits == 4)
-        fix_weights(layer, "w2_weight", weight_bits == 4)
+
+        w13_weight, w13_scale = repack_weights(
+            layer.w13_weight, layer.w13_scale, weight_bits
+        )
+        replace_parameter(layer, "w13_weight", w13_weight)
+        replace_parameter(layer, "w13_scale", w13_scale)
+
+        w2_weight, w2_scale = repack_weights(
+            layer.w2_weight, layer.w2_scale, weight_bits
+        )
+        replace_parameter(layer, "w2_weight", w2_weight)
+        replace_parameter(layer, "w2_scale", w2_scale)
+
+        init_workspace(layer.w13_weight.device)

     def get_fused_moe_quant_config(
         self, layer: torch.nn.Module
     ) -> FusedMoEQuantConfig | None:
-        weight_bits = self.quant_config.weight_bits
-        group_size = self.quant_config.group_size
-        assert weight_bits == 4 or weight_bits == 8
-        config_builder = (
-            int4_w4a16_moe_quant_config
-            if weight_bits == 4
-            else int8_w8a16_moe_quant_config
-        )
-        return config_builder(
-            w1_scale=layer.w13_scale,
-            w2_scale=layer.w2_scale,
-            w1_zp=None,
-            w2_zp=None,
-            block_shape=[0, group_size],
-        )
+        return None

     def apply(
         self,
@@ -366,8 +381,6 @@
         if enable_eplb:
             raise NotImplementedError("EPLB not supported for `RTNMoEMethod` yet.")

-        from vllm.model_executor.layers.fused_moe import fused_experts
-
         topk_weights, topk_ids, _ = FusedMoE.select_experts(
             hidden_states=x,
             router_logits=router_logits,
@@ -383,18 +396,22 @@
             indices_type=self.topk_indices_dtype,
         )

-        return fused_experts(
+        return torch.ops.vllm.fused_marlin_moe(
             x,
             layer.w13_weight,
             layer.w2_weight,
-            topk_weights=topk_weights,
-            topk_ids=topk_ids,
-            inplace=True,
-            activation=activation,
+            getattr(layer, "w13_bias", None),
+            getattr(layer, "w2_bias", None),
+            layer.w13_scale,
+            layer.w2_scale,
+            router_logits,
+            topk_weights,
+            topk_ids,
+            quant_type_id=self.quant_config.quant_type.id,
             apply_router_weight_on_input=apply_router_weight_on_input,
             global_num_experts=global_num_experts,
             expert_map=expert_map,
-            quant_config=self.moe_quant_config,
+            workspace=workspace,
         )
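For orientation (not from this commit), select_experts produces per-token top-k routing weights and expert ids that the Marlin MoE op consumes; a shape-only sketch with made-up sizes:

import torch

num_tokens, num_experts, topk = 16, 8, 2  # assumed sizes
router_logits = torch.randn(num_tokens, num_experts)
topk_weights, topk_ids = torch.topk(torch.softmax(router_logits, dim=-1), k=topk, dim=-1)
print(topk_weights.shape, topk_ids.shape)  # torch.Size([16, 2]) for both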
@@ -504,18 +521,133 @@ def rtn_dequantize(tensor: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
     return input_deq


-def fix_weights(layer: torch.nn.Module, param_name: str, reshape: bool = False):
-    """torch.compile does not know how to deal with a Parameter subclass
-    (aka RTNParameter). As we don't really need RTNParameters for the
-    forward pass, we replace them with equivalent instances of Parameters.
-    """
-    old_weight = getattr(layer, param_name)
-    assert isinstance(old_weight, RTNParameter)
-    data = old_weight.data.data
-
-    delattr(layer, param_name)
-
-    if reshape:
-        data = data.reshape(old_weight.shape[0], old_weight.shape[1] * 2, -1)
-    new_weight = Parameter(data=data, requires_grad=False)
-    layer.register_parameter(param_name, new_weight)
+def _get_perms():
+    perm = []
+    for i in range(32):
+        perm1 = []
+        col = i // 4
+        for block in [0, 1]:
+            for row in [
+                2 * (i % 4),
+                2 * (i % 4) + 1,
+                2 * (i % 4 + 4),
+                2 * (i % 4 + 4) + 1,
+            ]:
+                perm1.append(16 * row + col + 8 * block)
+        for j in range(4):
+            perm.extend([p + 256 * j for p in perm1])
+
+    perm_arr = np.array(perm)
+    interleave = np.array([0, 2, 4, 6, 1, 3, 5, 7])
+    perm_arr = perm_arr.reshape((-1, 8))[:, interleave].ravel()
+    perm_tensor = torch.from_numpy(perm_arr)
+    scale_perm = []
+    for i in range(8):
+        scale_perm.extend([i + 8 * j for j in range(8)])
+    scale_perm_single = []
+    for i in range(4):
+        scale_perm_single.extend([2 * i + j for j in [0, 1, 8, 9, 16, 17, 24, 25]])
+    return perm_tensor, scale_perm, scale_perm_single
+
+
+_perm, _scale_perm, _scale_perm_single = _get_perms()
+
+
+def pack_for_marlin(weight, scale, qbits):
+    batch = weight.shape[0]
+    n = weight.size(1)
+    k = weight.size(2)
+    groupsize = k // scale.size(2)
+    tile = 16
+
+    s = scale.permute(0, 2, 1)  # transpose
+    w = weight.permute(0, 2, 1)  # transpose
+    if groupsize != k:
+        w = w.reshape((batch, -1, groupsize, n))
+        w = w.permute(0, 2, 1, 3)
+        w = w.reshape((batch, groupsize, -1))
+        s = s.reshape((batch, 1, -1))
+
+    if groupsize != k:
+        w = w.reshape((batch, groupsize, -1, n))
+        w = w.permute(0, 2, 1, 3)
+        w = w.reshape((batch, k, n)).contiguous()
+        s = s.reshape((batch, -1, len(_scale_perm)))[:, :, _scale_perm]
+    else:
+        s = s.reshape((batch, -1, len(_scale_perm_single)))[:, :, _scale_perm_single]
+    s = s.reshape((batch, -1, n)).contiguous()
+
+    w = w.reshape((batch, k // tile, tile, n // tile, tile))
+    w = w.permute((0, 1, 3, 2, 4))
+    w = w.reshape((batch, k // tile, n * tile))
+    res = w
+    res = res.reshape((batch, -1, _perm.numel()))[:, :, _perm].reshape(res.shape)
+
+    if qbits == 4:
+        q = torch.zeros(
+            (batch, res.shape[1], res.shape[2] // 2), dtype=torch.int8, device=w.device
+        )
+        for i in range(2):
+            q |= res[:, :, i::2] << 4 * i
+        q = q.reshape(batch, -1, n).contiguous()
+    else:
+        q = res.clone()
+        q[:, :, 2::8] = res[:, :, 4::8]
+        q[:, :, 3::8] = res[:, :, 5::8]
+        q[:, :, 4::8] = res[:, :, 2::8]
+        q[:, :, 5::8] = res[:, :, 3::8]
+        q = q.reshape(batch, -1, n).to(torch.int8).contiguous()
+
+    return q, s
+
+
+def repack_8bit_into_32bit(input):
+    output = torch.zeros(
+        (input.shape[0], input.shape[1], input.shape[2] // 4),
+        dtype=torch.int32,
+        device=input.device,
+    )
+    for i in range(4):
+        output |= (input[:, :, i::4] & 0xFF).to(torch.int32) << 8 * i
+
+    return output
+
+
+def repack_weights(qweight, scale, weight_bits):
+    batch_present = len(qweight.shape) == 3
+    if not batch_present:
+        qweight = qweight.unsqueeze(0)
+        scale = scale.unsqueeze(0)
+
+    if weight_bits == 4:
+        """Unpack two 4-bit values from each byte.
+        """
+        qweight_unpacked = torch.empty(
+            (qweight.shape[0], qweight.shape[1] * 2, qweight.shape[2]),
+            dtype=torch.uint8,
+            device=qweight.device,
+        )
+        for i in range(2):
+            qweight_unpacked[:, :, i::2] = ((qweight << 4 * (1 - i)) >> 4).reshape(
+                qweight.shape[0], qweight.shape[1] * 2, qweight.shape[2] // 2
+            )
+    else:
+        qweight_unpacked = qweight
+
+    qweight_packed, scale_packed = pack_for_marlin(qweight_unpacked, scale, weight_bits)
+
+    """Marlin kernels expect tensors in int32 format in a certain shape
+    """
+    qweight_repacked = repack_8bit_into_32bit(qweight_packed.to(torch.uint8))
+    qweight_reshaped = qweight_repacked.reshape(
+        qweight.shape[0], qweight.shape[2] // 16, -1
+    )
+
+    if not batch_present:
+        qweight_reshaped = qweight_reshaped.squeeze(0)
+        scale_packed = scale_packed.squeeze(0)
+
+    return qweight_reshaped, scale_packed
+
+
+def init_workspace(device):
+    global workspace
+    if workspace is None:
+        workspace = marlin_make_workspace_new(device, 4)
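The 4-bit branch in repack_weights reads the low nibble of each byte with (q << 4) >> 4 and the high nibble with q >> 4. A standalone round-trip sketch of that byte-level packing convention (helper names here are hypothetical, not from the diff):

import torch

def pack_pair(lo: torch.Tensor, hi: torch.Tensor) -> torch.Tensor:
    # two 4-bit values per byte, low nibble first
    return ((lo & 0xF) | ((hi & 0xF) << 4)).to(torch.uint8)

def unpack_pair(packed: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
    # recover both 4-bit values from each byte
    return packed & 0xF, (packed >> 4) & 0xF

byte = pack_pair(torch.tensor([3]), torch.tensor([9]))
lo, hi = unpack_pair(byte)
assert int(lo) == 3 and int(hi) == 9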

vllm/model_executor/layers/quantization/utils/marlin_utils.py

@@ -528,3 +528,48 @@ def apply_awq_marlin_linear(
     )

     return output.reshape(out_shape)
+
+
+def apply_rtn_marlin_linear(
+    input: torch.Tensor,
+    weight: torch.Tensor,
+    weight_scale: torch.Tensor,
+    workspace: torch.Tensor,
+    quant_type: ScalarType,
+    output_size_per_partition: int,
+    input_size_per_partition: int,
+    bias: torch.Tensor | None = None,
+    use_fp32_reduce: bool = USE_FP32_REDUCE_DEFAULT,
+) -> torch.Tensor:
+    reshaped_x = input.reshape(-1, input.shape[-1])
+    out_shape = input.shape[:-1] + (output_size_per_partition,)
+
+    use_atomic_add = should_use_atomic_add_reduce(
+        m=reshaped_x.size(0),
+        n=output_size_per_partition,
+        k=reshaped_x.size(1),
+        device=input.device,
+        dtype=input.dtype,
+    )
+
+    output = ops.gptq_marlin_gemm(
+        reshaped_x,
+        None,
+        weight,
+        bias,
+        weight_scale,
+        None,
+        None,
+        None,
+        None,
+        workspace,
+        quant_type,
+        size_m=reshaped_x.shape[0],
+        size_n=output_size_per_partition,
+        size_k=input_size_per_partition,
+        use_atomic_add=use_atomic_add,
+        use_fp32_reduce=use_fp32_reduce,
+        is_zp_float=False,
+    )
+
+    return output.reshape(out_shape)
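For intuition (not part of this commit), the reshape bookkeeping in apply_rtn_marlin_linear flattens all leading dimensions into a single M dimension for the GEMM and restores them afterwards; a shape-only sketch with made-up sizes:

import torch

x = torch.randn(2, 7, 4096)                               # e.g. (batch, seq, K)
output_size_per_partition = 11008                         # N (assumed)
reshaped_x = x.reshape(-1, x.shape[-1])                   # (14, 4096) == (M, K)
out_shape = x.shape[:-1] + (output_size_per_partition,)   # (2, 7, 11008)
gemm_out = torch.empty(reshaped_x.shape[0], output_size_per_partition)  # stand-in for the kernel's output
print(gemm_out.reshape(out_shape).shape)                  # torch.Size([2, 7, 11008])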