[Feature] Add support for MoE models in the calibration-free RTN-based quantization (#20766)

Signed-off-by: Alex Kogan <alex.kogan@oracle.com>
This commit is contained in:
Alex Kogan 2025-07-25 21:09:34 -04:00 committed by GitHub
parent f1b286b2fb
commit 7ae75fa6d0
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 201 additions and 38 deletions

View File

@ -8,7 +8,10 @@ import pytest
from tests.quantization.utils import is_quant_method_supported
MODELS = ["microsoft/Phi-3-mini-4k-instruct"]
MODELS = [
"microsoft/Phi-3-mini-4k-instruct", # dense model
"ai21labs/Jamba-tiny-dev", # MoE model
]
@pytest.mark.skipif(not is_quant_method_supported("rtn"),

View File

@ -3,18 +3,19 @@
# Copyright © 2025, Oracle and/or its affiliates.
import os
from typing import Any, Optional
from typing import Any, Callable, Optional
import torch
import torch.nn.functional as F
from torch.nn.parameter import Parameter
from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe import FusedMoE, FusedMoEMethodBase
from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase,
set_weight_attrs)
from vllm.model_executor.layers.quantization import QuantizationMethods
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
QuantizationConfig, QuantizeMethodBase)
logger = init_logger(__name__)
"""By default, use 8 bit as target precision, but it can be
@ -71,9 +72,11 @@ class RTNConfig(QuantizationConfig):
return cls(weight_bits, group_size)
def get_quant_method(self, layer: torch.nn.Module,
prefix: str) -> Optional["RTNLinearMethod"]:
prefix: str) -> Optional["QuantizeMethodBase"]:
if isinstance(layer, LinearBase):
return RTNLinearMethod(self)
elif isinstance(layer, FusedMoE):
return RTNMoEMethod(self)
return None
@ -94,11 +97,18 @@ class RTNTensor:
self.data.narrow(dim, start // factor, length // factor),
self.scale.narrow(dim, start, length), self.quant_config)
def __getitem__(self, key):
return RTNTensor(self.data[key], self.scale[key], self.quant_config)
@property
def shape(self):
shape = self.data.shape
factor = 1 if self.quant_config.weight_bits == 8 else 2
return torch.Size((shape[0] * factor, shape[1]))
batch_present = len(shape) == 3
if batch_present:
return torch.Size((shape[0], shape[1] * factor, shape[2]))
else:
return torch.Size((shape[0] * factor, shape[1]))
def copy_(self, loaded_weight: torch.Tensor) -> None:
qweight, weight_scale = rtn_quantize(loaded_weight.cuda(),
@ -165,7 +175,7 @@ class RTNLinearMethod(LinearMethodBase):
weight = RTNParameter(data=torch.empty(output_size_per_partition //
factor,
input_size_per_partition,
dtype=torch.int8),
dtype=torch.uint8),
scale=scale,
quant_config=self.quant_config)
@ -180,18 +190,7 @@ class RTNLinearMethod(LinearMethodBase):
layer.output_size_per_partition = output_size_per_partition
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
"""torch.compile does not know how to deal with a Parameter subclass
(aka RTNParameter). As we don't really need RTNParameters for the
forward pass, we replace them with equivalent instances of Parameters.
"""
old_weight = layer.weight
assert isinstance(old_weight, RTNParameter)
data = old_weight.data.data
delattr(layer, "weight")
new_weight = Parameter(data=data, requires_grad=False)
layer.register_parameter("weight", new_weight)
fix_weights(layer, "weight")
def apply(self,
layer: torch.nn.Module,
@ -209,6 +208,128 @@ class RTNLinearMethod(LinearMethodBase):
return out
class RTNMoEMethod(FusedMoEMethodBase):
def __init__(self, quant_config: RTNConfig):
self.quant_config = quant_config
def create_weights(self, layer: torch.nn.Module, num_experts: int,
hidden_size: int, intermediate_size_per_partition: int,
params_dtype: torch.dtype, **extra_weight_attrs):
factor = 1 if self.quant_config.weight_bits == 8 else 2
# Fused gate_up_proj (column parallel)
num_groups_per_col = (hidden_size // self.quant_config.group_size
if self.quant_config.group_size != -1 else 1)
w13_scale = Parameter(
torch.empty(num_experts,
2 * intermediate_size_per_partition,
num_groups_per_col,
dtype=params_dtype),
requires_grad=False,
)
layer.register_parameter("w13_scale", w13_scale)
w13_weight = RTNParameter(data=torch.empty(
num_experts,
2 * intermediate_size_per_partition // factor,
hidden_size,
dtype=torch.uint8),
scale=w13_scale,
quant_config=self.quant_config)
layer.register_parameter("w13_weight", w13_weight)
set_weight_attrs(w13_weight, extra_weight_attrs)
# down_proj (row parallel)
num_groups_per_col = (intermediate_size_per_partition //
self.quant_config.group_size
if self.quant_config.group_size != -1 else 1)
w2_scale = Parameter(torch.zeros(num_experts,
hidden_size,
num_groups_per_col,
dtype=params_dtype),
requires_grad=False)
layer.register_parameter("w2_scale", w2_scale)
w2_weight = RTNParameter(data=torch.empty(
num_experts,
hidden_size // factor,
intermediate_size_per_partition,
dtype=torch.uint8),
scale=w2_scale,
quant_config=self.quant_config)
layer.register_parameter("w2_weight", w2_weight)
set_weight_attrs(w2_weight, extra_weight_attrs)
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
weight_bits = self.quant_config.weight_bits
fix_weights(layer, "w13_weight", weight_bits == 4)
fix_weights(layer, "w2_weight", weight_bits == 4)
def apply(
self,
layer: torch.nn.Module,
x: torch.Tensor,
router_logits: torch.Tensor,
top_k: int,
renormalize: bool,
use_grouped_topk: bool = False,
topk_group: Optional[int] = None,
num_expert_group: Optional[int] = None,
global_num_experts: int = -1,
expert_map: Optional[torch.Tensor] = None,
custom_routing_function: Optional[Callable] = None,
scoring_func: str = "softmax",
e_score_correction_bias: Optional[torch.Tensor] = None,
apply_router_weight_on_input: bool = False,
activation: str = "silu",
enable_eplb: bool = False,
expert_load_view: Optional[torch.Tensor] = None,
logical_to_physical_map: Optional[torch.Tensor] = None,
logical_replica_count: Optional[torch.Tensor] = None,
) -> torch.Tensor:
if enable_eplb:
raise NotImplementedError(
"EPLB not supported for `RTNMoEMethod` yet.")
from vllm.model_executor.layers.fused_moe import fused_experts
topk_weights, topk_ids = FusedMoE.select_experts(
hidden_states=x,
router_logits=router_logits,
use_grouped_topk=use_grouped_topk,
top_k=top_k,
renormalize=renormalize,
topk_group=topk_group,
num_expert_group=num_expert_group,
custom_routing_function=custom_routing_function,
scoring_func=scoring_func,
e_score_correction_bias=e_score_correction_bias)
weight_bits = self.quant_config.weight_bits
group_size = self.quant_config.group_size
ret = fused_experts(
x,
layer.w13_weight,
layer.w2_weight,
topk_weights=topk_weights,
topk_ids=topk_ids,
inplace=True,
activation=activation,
use_int4_w4a16=weight_bits == 4,
use_int8_w8a16=weight_bits == 8,
global_num_experts=global_num_experts,
w1_scale=layer.w13_scale,
w2_scale=layer.w2_scale,
apply_router_weight_on_input=apply_router_weight_on_input,
expert_map=expert_map,
block_shape=[0, group_size])
return ret
def rtn_quantize(tensor: torch.Tensor, num_bits: int,
group_size: int) -> tuple[torch.Tensor, torch.Tensor]:
"""Quantize a tensor using per-group static scaling factor.
@ -221,34 +342,44 @@ def rtn_quantize(tensor: torch.Tensor, num_bits: int,
If equal to -1, each row in the input tensor is treated
as one group.
"""
batch_present = len(tensor.shape) == 3
if not batch_present:
tensor = tensor.unsqueeze(0)
q_range = 2**num_bits
num_groups = (tensor.shape[0] * tensor.shape[1] //
group_size if group_size != -1 else tensor.shape[0])
num_groups = (tensor.shape[1] * tensor.shape[2] //
group_size if group_size != -1 else tensor.shape[1])
"""Calculate a scaling factor per input group.
"""
input_flat = tensor.reshape(num_groups, -1)
input_min = torch.min(input_flat, dim=1, keepdim=True)[0]
input_max = torch.max(input_flat, dim=1, keepdim=True)[0]
input_flat = tensor.reshape(tensor.shape[0], num_groups, -1)
input_min = torch.min(input_flat, dim=2, keepdim=True)[0]
input_max = torch.max(input_flat, dim=2, keepdim=True)[0]
input_max_abs = torch.max(input_min.abs(), input_max.abs())
scale = (input_max_abs * 2.0 / (q_range - 1))
"""Scale each input group, truncate and round to the nearest integer.
"""Scale each input group, round to the nearest integer, shift
the range and truncate.
"""
scaled_input = input_flat / scale
scaled_input = scaled_input.clamp(-q_range // 2, q_range // 2 - 1)
scaled_input = scaled_input.round()
scaled_input += q_range // 2
scaled_input = scaled_input.clamp(0, q_range - 1)
scale = scale.reshape(tensor.shape[0], -1).contiguous()
inputs_q = scaled_input.reshape(tensor.shape).to(torch.int8)
scale = scale.reshape(tensor.shape[0], tensor.shape[1], -1).contiguous()
inputs_q = scaled_input.reshape(tensor.shape).to(torch.uint8)
inputs_q = inputs_q.contiguous()
if num_bits == 4:
"""Pack two 4-bit values into each byte.
"""
inputs_q = (inputs_q[:, 1::2] << 4) | (inputs_q[:, ::2] & 0xf)
inputs_q = inputs_q.reshape(tensor.shape[0] // 2, tensor.shape[1])
inputs_q = (inputs_q[:, :, 1::2] << 4) | (inputs_q[:, :, ::2] & 0xf)
inputs_q = inputs_q.reshape(tensor.shape[0], tensor.shape[1] // 2,
tensor.shape[2])
inputs_q = inputs_q.contiguous()
if not batch_present:
inputs_q = inputs_q.squeeze(0)
scale = scale.squeeze(0)
return inputs_q, scale
@ -259,31 +390,60 @@ def rtn_dequantize(tensor: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
tensor: The input tensor.
scale: The tensor with per-group scale factors.
"""
batch_present = len(tensor.shape) == 3
if not batch_present:
tensor = tensor.unsqueeze(0)
scale = scale.unsqueeze(0)
num_groups = scale.size(0) * scale.size(1)
input_dim, output_dim = tensor.shape
num_groups = scale.size(1) * scale.size(2)
batch, input_dim, output_dim = tensor.shape
num_bits = 8 if input_dim == scale.size(0) else 4
num_bits = 8 if input_dim == scale.size(1) else 4
q_range = 2**num_bits
if num_bits == 4:
input_dim *= 2
data = torch.empty((input_dim, output_dim),
data = torch.empty((batch, input_dim, output_dim),
dtype=scale.dtype,
device=tensor.device)
if num_bits == 8:
data.copy_(tensor)
data -= q_range // 2
else:
"""Unpack two 4-bit values from each byte.
"""
tensor = tensor.reshape(input_dim, output_dim // 2)
tensor = tensor.reshape(batch, input_dim, output_dim // 2)
for i in range(2):
data[:, i::2] = (tensor << 4 * (1 - i)) >> 4
data[:, :, i::2] = ((tensor << 4 *
(1 - i)) >> 4).to(torch.int8) - q_range // 2
"""Scale each input group with its scaling factor.
"""
scale = scale.reshape(num_groups, -1)
data = data.reshape(num_groups, -1)
scale = scale.reshape(batch, num_groups, -1)
data = data.reshape(batch, num_groups, -1)
data = torch.mul(data, scale)
input_deq = data.reshape((input_dim, output_dim)).contiguous()
input_deq = data.reshape((batch, input_dim, output_dim)).contiguous()
if not batch_present:
input_deq = input_deq.squeeze(0)
return input_deq
def fix_weights(layer: torch.nn.Module,
param_name: str,
reshape: bool = False):
"""torch.compile does not know how to deal with a Parameter subclass
(aka RTNParameter). As we don't really need RTNParameters for the
forward pass, we replace them with equivalent instances of Parameters.
"""
old_weight = getattr(layer, param_name)
assert isinstance(old_weight, RTNParameter)
data = old_weight.data.data
delattr(layer, param_name)
if reshape:
data = data.reshape(old_weight.shape[0], old_weight.shape[1] * 2, -1)
new_weight = Parameter(data=data, requires_grad=False)
layer.register_parameter(param_name, new_weight)