[Feature] Add support for MoE models in the calibration-free RTN-based quantization (#20766)
Signed-off-by: Alex Kogan <alex.kogan@oracle.com>
parent f1b286b2fb
commit 7ae75fa6d0
@@ -8,7 +8,10 @@ import pytest
 from tests.quantization.utils import is_quant_method_supported
 
-MODELS = ["microsoft/Phi-3-mini-4k-instruct"]
+MODELS = [
+    "microsoft/Phi-3-mini-4k-instruct",  # dense model
+    "ai21labs/Jamba-tiny-dev",  # MoE model
+]
 
 
 @pytest.mark.skipif(not is_quant_method_supported("rtn"),
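For context, a minimal sketch of how the new MoE entry could be exercised outside the test suite. It assumes the offline vllm.LLM entrypoint accepts quantization="rtn" (the method name this file registers); the prompt and sampling settings are illustrative only, not taken from the test.

# Illustrative sketch, not part of the commit: run the MoE test model with
# calibration-free RTN quantization through vLLM's offline API.
from vllm import LLM, SamplingParams

llm = LLM(model="ai21labs/Jamba-tiny-dev", quantization="rtn")
outputs = llm.generate(["The capital of France is"],
                       SamplingParams(max_tokens=8))
print(outputs[0].outputs[0].text)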
@@ -3,18 +3,19 @@
 # Copyright © 2025, Oracle and/or its affiliates.
 
 import os
-from typing import Any, Optional
+from typing import Any, Callable, Optional
 
 import torch
 import torch.nn.functional as F
 from torch.nn.parameter import Parameter
 
 from vllm.logger import init_logger
+from vllm.model_executor.layers.fused_moe import FusedMoE, FusedMoEMethodBase
 from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase,
                                                set_weight_attrs)
 from vllm.model_executor.layers.quantization import QuantizationMethods
 from vllm.model_executor.layers.quantization.base_config import (
-    QuantizationConfig)
+    QuantizationConfig, QuantizeMethodBase)
 
 logger = init_logger(__name__)
 """By default, use 8 bit as target precision, but it can be
@@ -71,9 +72,11 @@ class RTNConfig(QuantizationConfig):
         return cls(weight_bits, group_size)
 
     def get_quant_method(self, layer: torch.nn.Module,
-                         prefix: str) -> Optional["RTNLinearMethod"]:
+                         prefix: str) -> Optional["QuantizeMethodBase"]:
         if isinstance(layer, LinearBase):
             return RTNLinearMethod(self)
+        elif isinstance(layer, FusedMoE):
+            return RTNMoEMethod(self)
         return None
 
 
@@ -94,11 +97,18 @@ class RTNTensor:
             self.data.narrow(dim, start // factor, length // factor),
             self.scale.narrow(dim, start, length), self.quant_config)
 
+    def __getitem__(self, key):
+        return RTNTensor(self.data[key], self.scale[key], self.quant_config)
+
     @property
     def shape(self):
         shape = self.data.shape
         factor = 1 if self.quant_config.weight_bits == 8 else 2
-        return torch.Size((shape[0] * factor, shape[1]))
+        batch_present = len(shape) == 3
+        if batch_present:
+            return torch.Size((shape[0], shape[1] * factor, shape[2]))
+        else:
+            return torch.Size((shape[0] * factor, shape[1]))
 
     def copy_(self, loaded_weight: torch.Tensor) -> None:
         qweight, weight_scale = rtn_quantize(loaded_weight.cuda(),
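The new 3-D branch of shape mirrors the existing 2-D logic: for 4-bit weights two values share a byte, so the logical size along the packed dimension is twice the stored size, while the leading expert (batch) dimension is left alone. A small sketch of that arithmetic, with hypothetical sizes:

# Illustrative sketch, not part of the commit: how the reported (logical)
# shape relates to the stored (packed) shape for 4-bit expert weights.
import torch

num_experts, n, k = 8, 128, 64                   # hypothetical sizes
packed = torch.empty(num_experts, n // 2, k, dtype=torch.uint8)

factor = 2                                       # two 4-bit codes per byte
logical = torch.Size((packed.shape[0], packed.shape[1] * factor, packed.shape[2]))
assert logical == torch.Size((num_experts, n, k))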
@@ -165,7 +175,7 @@ class RTNLinearMethod(LinearMethodBase):
         weight = RTNParameter(data=torch.empty(output_size_per_partition //
                                                factor,
                                                input_size_per_partition,
-                                               dtype=torch.int8),
+                                               dtype=torch.uint8),
                               scale=scale,
                               quant_config=self.quant_config)
 
@@ -180,18 +190,7 @@ class RTNLinearMethod(LinearMethodBase):
         layer.output_size_per_partition = output_size_per_partition
 
     def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
-        """torch.compile does not know how to deal with a Parameter subclass
-        (aka RTNParameter). As we don't really need RTNParameters for the
-        forward pass, we replace them with equivalent instances of Parameters.
-        """
-        old_weight = layer.weight
-        assert isinstance(old_weight, RTNParameter)
-        data = old_weight.data.data
-
-        delattr(layer, "weight")
-
-        new_weight = Parameter(data=data, requires_grad=False)
-        layer.register_parameter("weight", new_weight)
+        fix_weights(layer, "weight")
 
     def apply(self,
               layer: torch.nn.Module,
@@ -209,6 +208,128 @@ class RTNLinearMethod(LinearMethodBase):
         return out
 
 
+class RTNMoEMethod(FusedMoEMethodBase):
+
+    def __init__(self, quant_config: RTNConfig):
+        self.quant_config = quant_config
+
+    def create_weights(self, layer: torch.nn.Module, num_experts: int,
+                       hidden_size: int, intermediate_size_per_partition: int,
+                       params_dtype: torch.dtype, **extra_weight_attrs):
+
+        factor = 1 if self.quant_config.weight_bits == 8 else 2
+
+        # Fused gate_up_proj (column parallel)
+        num_groups_per_col = (hidden_size // self.quant_config.group_size
+                              if self.quant_config.group_size != -1 else 1)
+        w13_scale = Parameter(
+            torch.empty(num_experts,
+                        2 * intermediate_size_per_partition,
+                        num_groups_per_col,
+                        dtype=params_dtype),
+            requires_grad=False,
+        )
+        layer.register_parameter("w13_scale", w13_scale)
+
+        w13_weight = RTNParameter(data=torch.empty(
+            num_experts,
+            2 * intermediate_size_per_partition // factor,
+            hidden_size,
+            dtype=torch.uint8),
+                                  scale=w13_scale,
+                                  quant_config=self.quant_config)
+        layer.register_parameter("w13_weight", w13_weight)
+        set_weight_attrs(w13_weight, extra_weight_attrs)
+
+        # down_proj (row parallel)
+        num_groups_per_col = (intermediate_size_per_partition //
+                              self.quant_config.group_size
+                              if self.quant_config.group_size != -1 else 1)
+        w2_scale = Parameter(torch.zeros(num_experts,
+                                         hidden_size,
+                                         num_groups_per_col,
+                                         dtype=params_dtype),
+                             requires_grad=False)
+        layer.register_parameter("w2_scale", w2_scale)
+
+        w2_weight = RTNParameter(data=torch.empty(
+            num_experts,
+            hidden_size // factor,
+            intermediate_size_per_partition,
+            dtype=torch.uint8),
+                                 scale=w2_scale,
+                                 quant_config=self.quant_config)
+        layer.register_parameter("w2_weight", w2_weight)
+        set_weight_attrs(w2_weight, extra_weight_attrs)
+
+    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
+        weight_bits = self.quant_config.weight_bits
+        fix_weights(layer, "w13_weight", weight_bits == 4)
+        fix_weights(layer, "w2_weight", weight_bits == 4)
+
+    def apply(
+        self,
+        layer: torch.nn.Module,
+        x: torch.Tensor,
+        router_logits: torch.Tensor,
+        top_k: int,
+        renormalize: bool,
+        use_grouped_topk: bool = False,
+        topk_group: Optional[int] = None,
+        num_expert_group: Optional[int] = None,
+        global_num_experts: int = -1,
+        expert_map: Optional[torch.Tensor] = None,
+        custom_routing_function: Optional[Callable] = None,
+        scoring_func: str = "softmax",
+        e_score_correction_bias: Optional[torch.Tensor] = None,
+        apply_router_weight_on_input: bool = False,
+        activation: str = "silu",
+        enable_eplb: bool = False,
+        expert_load_view: Optional[torch.Tensor] = None,
+        logical_to_physical_map: Optional[torch.Tensor] = None,
+        logical_replica_count: Optional[torch.Tensor] = None,
+    ) -> torch.Tensor:
+        if enable_eplb:
+            raise NotImplementedError(
+                "EPLB not supported for `RTNMoEMethod` yet.")
+
+        from vllm.model_executor.layers.fused_moe import fused_experts
+
+        topk_weights, topk_ids = FusedMoE.select_experts(
+            hidden_states=x,
+            router_logits=router_logits,
+            use_grouped_topk=use_grouped_topk,
+            top_k=top_k,
+            renormalize=renormalize,
+            topk_group=topk_group,
+            num_expert_group=num_expert_group,
+            custom_routing_function=custom_routing_function,
+            scoring_func=scoring_func,
+            e_score_correction_bias=e_score_correction_bias)
+
+        weight_bits = self.quant_config.weight_bits
+        group_size = self.quant_config.group_size
+
+        ret = fused_experts(
+            x,
+            layer.w13_weight,
+            layer.w2_weight,
+            topk_weights=topk_weights,
+            topk_ids=topk_ids,
+            inplace=True,
+            activation=activation,
+            use_int4_w4a16=weight_bits == 4,
+            use_int8_w8a16=weight_bits == 8,
+            global_num_experts=global_num_experts,
+            w1_scale=layer.w13_scale,
+            w2_scale=layer.w2_scale,
+            apply_router_weight_on_input=apply_router_weight_on_input,
+            expert_map=expert_map,
+            block_shape=[0, group_size])
+
+        return ret
+
+
 def rtn_quantize(tensor: torch.Tensor, num_bits: int,
                  group_size: int) -> tuple[torch.Tensor, torch.Tensor]:
     """Quantize a tensor using per-group static scaling factor.
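As a rough shape walkthrough of create_weights above: with 4-bit weights the packing factor is 2, so each expert's fused gate/up and down projections are stored as uint8 with half as many packed rows, and the scale tensors carry one entry per group of group_size input elements. The sizes below are hypothetical, not taken from any particular model.

# Illustrative sketch, not part of the commit: parameter shapes produced by
# RTNMoEMethod.create_weights for hypothetical sizes.
num_experts = 16
hidden_size = 1024
intermediate = 2048          # intermediate_size_per_partition
group_size = 128
weight_bits = 4

factor = 1 if weight_bits == 8 else 2
groups_w13 = hidden_size // group_size        # groups along gate/up input dim
groups_w2 = intermediate // group_size        # groups along down-proj input dim

w13_weight_shape = (num_experts, 2 * intermediate // factor, hidden_size)
w13_scale_shape = (num_experts, 2 * intermediate, groups_w13)
w2_weight_shape = (num_experts, hidden_size // factor, intermediate)
w2_scale_shape = (num_experts, hidden_size, groups_w2)

print(w13_weight_shape, w13_scale_shape)   # (16, 2048, 1024) (16, 4096, 8)
print(w2_weight_shape, w2_scale_shape)     # (16, 512, 2048) (16, 1024, 16)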
@@ -221,34 +342,44 @@ def rtn_quantize(tensor: torch.Tensor, num_bits: int,
         If equal to -1, each row in the input tensor is treated
         as one group.
     """
+    batch_present = len(tensor.shape) == 3
+    if not batch_present:
+        tensor = tensor.unsqueeze(0)
+
     q_range = 2**num_bits
-    num_groups = (tensor.shape[0] * tensor.shape[1] //
-                  group_size if group_size != -1 else tensor.shape[0])
+    num_groups = (tensor.shape[1] * tensor.shape[2] //
+                  group_size if group_size != -1 else tensor.shape[1])
     """Calculate a scaling factor per input group.
     """
-    input_flat = tensor.reshape(num_groups, -1)
-    input_min = torch.min(input_flat, dim=1, keepdim=True)[0]
-    input_max = torch.max(input_flat, dim=1, keepdim=True)[0]
+    input_flat = tensor.reshape(tensor.shape[0], num_groups, -1)
+    input_min = torch.min(input_flat, dim=2, keepdim=True)[0]
+    input_max = torch.max(input_flat, dim=2, keepdim=True)[0]
     input_max_abs = torch.max(input_min.abs(), input_max.abs())
     scale = (input_max_abs * 2.0 / (q_range - 1))
-    """Scale each input group, truncate and round to the nearest integer.
+    """Scale each input group, round to the nearest integer, shift
+    the range and truncate.
     """
     scaled_input = input_flat / scale
-    scaled_input = scaled_input.clamp(-q_range // 2, q_range // 2 - 1)
     scaled_input = scaled_input.round()
+    scaled_input += q_range // 2
+    scaled_input = scaled_input.clamp(0, q_range - 1)
 
-    scale = scale.reshape(tensor.shape[0], -1).contiguous()
-    inputs_q = scaled_input.reshape(tensor.shape).to(torch.int8)
+    scale = scale.reshape(tensor.shape[0], tensor.shape[1], -1).contiguous()
+    inputs_q = scaled_input.reshape(tensor.shape).to(torch.uint8)
     inputs_q = inputs_q.contiguous()
 
     if num_bits == 4:
         """Pack two 4-bit values into each byte.
         """
-        inputs_q = (inputs_q[:, 1::2] << 4) | (inputs_q[:, ::2] & 0xf)
-        inputs_q = inputs_q.reshape(tensor.shape[0] // 2, tensor.shape[1])
+        inputs_q = (inputs_q[:, :, 1::2] << 4) | (inputs_q[:, :, ::2] & 0xf)
+        inputs_q = inputs_q.reshape(tensor.shape[0], tensor.shape[1] // 2,
+                                    tensor.shape[2])
         inputs_q = inputs_q.contiguous()
 
+    if not batch_present:
+        inputs_q = inputs_q.squeeze(0)
+        scale = scale.squeeze(0)
+
     return inputs_q, scale
 
 
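The rounding change is the key semantic difference in this hunk: instead of clamping to a signed range and storing int8, values are rounded, shifted up by half the quantization range, and clamped to [0, q_range - 1], so they fit an unsigned byte and two 4-bit codes can be packed per byte without sign handling. A tiny worked example of that arithmetic (a sketch that mirrors the code above; the input values are chosen arbitrarily):

# Illustrative sketch, not part of the commit: the shifted unsigned encoding
# used by rtn_quantize, on one group of four values at 4-bit precision.
import torch

num_bits = 4
q_range = 2**num_bits                           # 16 codes: 0..15
group = torch.tensor([-1.5, -0.2, 0.7, 1.5])

max_abs = group.abs().max()
scale = max_abs * 2.0 / (q_range - 1)           # ~0.2 per code step
codes = (group / scale).round() + q_range // 2  # shift into the unsigned range
codes = codes.clamp(0, q_range - 1).to(torch.uint8)
print(codes)                                    # tensor([ 0,  7, 12, 15], dtype=torch.uint8)

decoded = (codes.to(torch.int8) - q_range // 2) * scale
print(decoded)                                  # approximately the original group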
@@ -259,31 +390,60 @@ def rtn_dequantize(tensor: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
         tensor: The input tensor.
         scale: The tensor with per-group scale factors.
     """
+    batch_present = len(tensor.shape) == 3
+    if not batch_present:
+        tensor = tensor.unsqueeze(0)
+        scale = scale.unsqueeze(0)
+
-    num_groups = scale.size(0) * scale.size(1)
-    input_dim, output_dim = tensor.shape
+    num_groups = scale.size(1) * scale.size(2)
+    batch, input_dim, output_dim = tensor.shape
 
-    num_bits = 8 if input_dim == scale.size(0) else 4
+    num_bits = 8 if input_dim == scale.size(1) else 4
+    q_range = 2**num_bits
     if num_bits == 4:
         input_dim *= 2
 
-    data = torch.empty((input_dim, output_dim),
+    data = torch.empty((batch, input_dim, output_dim),
                        dtype=scale.dtype,
                        device=tensor.device)
 
     if num_bits == 8:
         data.copy_(tensor)
+        data -= q_range // 2
     else:
         """Unpack two 4-bit values from each byte.
         """
-        tensor = tensor.reshape(input_dim, output_dim // 2)
+        tensor = tensor.reshape(batch, input_dim, output_dim // 2)
         for i in range(2):
-            data[:, i::2] = (tensor << 4 * (1 - i)) >> 4
+            data[:, :, i::2] = ((tensor << 4 *
+                                 (1 - i)) >> 4).to(torch.int8) - q_range // 2
     """Scale each input group with its scaling factor.
     """
-    scale = scale.reshape(num_groups, -1)
-    data = data.reshape(num_groups, -1)
+    scale = scale.reshape(batch, num_groups, -1)
+    data = data.reshape(batch, num_groups, -1)
     data = torch.mul(data, scale)
 
-    input_deq = data.reshape((input_dim, output_dim)).contiguous()
+    input_deq = data.reshape((batch, input_dim, output_dim)).contiguous()
+    if not batch_present:
+        input_deq = input_deq.squeeze(0)
+
     return input_deq
 
 
+def fix_weights(layer: torch.nn.Module,
+                param_name: str,
+                reshape: bool = False):
+    """torch.compile does not know how to deal with a Parameter subclass
+    (aka RTNParameter). As we don't really need RTNParameters for the
+    forward pass, we replace them with equivalent instances of Parameters.
+    """
+    old_weight = getattr(layer, param_name)
+    assert isinstance(old_weight, RTNParameter)
+    data = old_weight.data.data
+
+    delattr(layer, param_name)
+
+    if reshape:
+        data = data.reshape(old_weight.shape[0], old_weight.shape[1] * 2, -1)
+    new_weight = Parameter(data=data, requires_grad=False)
+    layer.register_parameter(param_name, new_weight)
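A hedged round-trip sketch exercising the new batched (per-expert) path. It assumes the helpers live in vllm.model_executor.layers.quantization.rtn, as the imports in this file suggest; the tensor sizes and group size are arbitrary. A 3-D weight is quantized to 8 bits and dequantized back, and the reconstruction error stays within roughly one quantization step per group.

# Illustrative sketch, not part of the commit: 8-bit round trip over a
# batched (per-expert) weight tensor using the module's helpers.
import torch

from vllm.model_executor.layers.quantization.rtn import (rtn_dequantize,
                                                         rtn_quantize)

weights = torch.randn(4, 64, 32)                 # (num_experts, rows, cols)
qweight, scale = rtn_quantize(weights, num_bits=8, group_size=32)

assert qweight.shape == weights.shape and qweight.dtype == torch.uint8
recovered = rtn_dequantize(qweight, scale)
print((recovered - weights).abs().max())         # small, bounded by the scale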