Support W8A8 INT8 MoE for compressed-tensors (#16745)

Signed-off-by: mgoin <mgoin64@gmail.com>
Michael Goin 2025-05-02 08:03:32 -06:00 committed by GitHub
parent 99404f53c7
commit 868c546da4
2 changed files with 136 additions and 1 deletion
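For orientation, here is a minimal sketch (not part of the diff) of the scheme this commit enables: per-channel INT8 weights and dynamically, per-token quantized INT8 activations. The sketch assumes symmetric quantization, matching the scale-only (no zero point) parameters in the diff; helper names are illustrative only.

import torch

def quantize_weight_per_channel(w: torch.Tensor):
    # w: [out_features, in_features]; one scale per output channel (row), computed offline.
    scale = w.abs().amax(dim=1, keepdim=True).clamp(min=1e-8) / 127.0
    w_q = torch.clamp(torch.round(w / scale), -128, 127).to(torch.int8)
    return w_q, scale  # scale: [out_features, 1]

def quantize_activation_per_token(x: torch.Tensor):
    # x: [num_tokens, hidden_size]; one scale per token (row), computed at runtime.
    scale = x.abs().amax(dim=1, keepdim=True).clamp(min=1e-8) / 127.0
    x_q = torch.clamp(torch.round(x / scale), -128, 127).to(torch.int8)
    return x_q, scale  # scale: [num_tokens, 1]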


@@ -34,6 +34,7 @@ __all__ = [
"CompressedTensorsMoEMethod",
"CompressedTensorsW8A8Fp8MoEMethod",
"CompressedTensorsW8A8Fp8MoECutlassMethod",
"CompressedTensorsW8A8Int8MoEMethod",
"CompressedTensorsWNA16MarlinMoEMethod",
"CompressedTensorsWNA16MoEMethod",
]
@@ -71,6 +72,8 @@ class CompressedTensorsMoEMethod(FusedMoEMethodBase):
return CompressedTensorsW8A8Fp8MoECutlassMethod(quant_config)
elif quant_config._is_fp8_w8a8(weight_quant, input_quant):
return CompressedTensorsW8A8Fp8MoEMethod(quant_config)
elif quant_config._is_dynamic_token_w8a8(weight_quant, input_quant):
return CompressedTensorsW8A8Int8MoEMethod(quant_config)
else:
raise RuntimeError(
f"Unsupported FusedMoe scheme: {weight_quant}, {input_quant}")
@@ -545,6 +548,138 @@ class CompressedTensorsW8A8Fp8MoECutlassMethod(CompressedTensorsMoEMethod):
)

class CompressedTensorsW8A8Int8MoEMethod(CompressedTensorsMoEMethod):
def __init__(
self,
quant_config: "CompressedTensorsConfig" # type: ignore # noqa E501
):
self.quant_config = quant_config
self.weight_quant = self.quant_config.target_scheme_map["Linear"].get(
"weights")
self.input_quant = self.quant_config.target_scheme_map["Linear"].get(
"input_activations")
per_channel = (
self.weight_quant.strategy == QuantizationStrategy.CHANNEL
and self.input_quant.strategy == QuantizationStrategy.TOKEN)
if not per_channel:
raise ValueError(
"For INT8 Fused MoE layers, we require channelwise, "
"dynamic per token quantization. Found "
f"{self.weight_quant}, {self.input_quant}")
self.static_input_scales = not self.input_quant.dynamic
if self.static_input_scales:
raise ValueError(
"For INT8 Fused MoE layers, we require channelwise, "
"dynamic per token quantization. Found static input scales.")

def create_weights(self, layer: torch.nn.Module, num_experts: int,
hidden_size: int, intermediate_size_per_partition: int,
params_dtype: torch.dtype, **extra_weight_attrs):
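# NOTE: weights are stored as INT8 regardless of the model's activation dtype.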
params_dtype = torch.int8
# WEIGHTS
w13_weight = torch.nn.Parameter(torch.empty(
num_experts,
2 * intermediate_size_per_partition,
hidden_size,
dtype=params_dtype),
requires_grad=False)
layer.register_parameter("w13_weight", w13_weight)
set_weight_attrs(w13_weight, extra_weight_attrs)
w2_weight = torch.nn.Parameter(torch.empty(
num_experts,
hidden_size,
intermediate_size_per_partition,
dtype=params_dtype),
requires_grad=False)
layer.register_parameter("w2_weight", w2_weight)
set_weight_attrs(w2_weight, extra_weight_attrs)
# WEIGHT_SCALES
assert self.weight_quant.strategy == QuantizationStrategy.CHANNEL
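# One scale per output channel: [num_experts, 2 * intermediate_size_per_partition, 1]
# for w13 and [num_experts, hidden_size, 1] for w2.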
w13_weight_scale = torch.nn.Parameter(torch.ones(
num_experts,
2 * intermediate_size_per_partition,
1,
dtype=torch.float32),
requires_grad=False)
layer.register_parameter("w13_weight_scale", w13_weight_scale)
w2_weight_scale = torch.nn.Parameter(torch.ones(num_experts,
hidden_size,
1,
dtype=torch.float32),
requires_grad=False)
layer.register_parameter("w2_weight_scale", w2_weight_scale)
# Add PER-CHANNEL quantization for FusedMoE.weight_loader.
extra_weight_attrs.update(
{"quant_method": FusedMoeWeightScaleSupported.CHANNEL.value})
set_weight_attrs(w13_weight_scale, extra_weight_attrs)
set_weight_attrs(w2_weight_scale, extra_weight_attrs)
# INPUT_SCALES
assert not self.static_input_scales
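# Dynamic per-token quantization computes activation scales on the fly inside
# the fused kernel, so no input-scale parameters are created.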
layer.w13_input_scale = None
layer.w2_input_scale = None

def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
pass

def apply(
self,
layer: torch.nn.Module,
x: torch.Tensor,
router_logits: torch.Tensor,
top_k: int,
renormalize: bool,
use_grouped_topk: bool = False,
topk_group: Optional[int] = None,
num_expert_group: Optional[int] = None,
global_num_experts: int = -1,
expert_map: Optional[torch.Tensor] = None,
custom_routing_function: Optional[Callable] = None,
scoring_func: str = "softmax",
e_score_correction_bias: Optional[torch.Tensor] = None,
apply_router_weight_on_input: bool = False,
activation: str = "silu",
) -> torch.Tensor:
from vllm.model_executor.layers.fused_moe import fused_experts
topk_weights, topk_ids = FusedMoE.select_experts(
hidden_states=x,
router_logits=router_logits,
use_grouped_topk=use_grouped_topk,
top_k=top_k,
renormalize=renormalize,
topk_group=topk_group,
num_expert_group=num_expert_group,
custom_routing_function=custom_routing_function,
scoring_func=scoring_func,
e_score_correction_bias=e_score_correction_bias)
return fused_experts(
hidden_states=x,
w1=layer.w13_weight,
w2=layer.w2_weight,
topk_weights=topk_weights,
topk_ids=topk_ids,
inplace=True,
activation=activation,
apply_router_weight_on_input=apply_router_weight_on_input,
use_int8_w8a8=True,
per_channel_quant=True,
global_num_experts=global_num_experts,
expert_map=expert_map,
w1_scale=layer.w13_weight_scale,
w2_scale=layer.w2_weight_scale,
a1_scale=layer.w13_input_scale,
a2_scale=layer.w2_input_scale)
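# Rough sketch (not part of the diff) of how the scales registered above are
# presumably combined after each INT8 GEMM inside fused_experts: the
# accumulator is rescaled by the per-token activation scale and the
# per-channel weight scale. All names below are illustrative stand-ins, and an
# fp32 matmul stands in for the kernel's int32 accumulation.
import torch

tokens, hidden, out = 4, 8, 6
a_q = torch.randint(-128, 128, (tokens, hidden), dtype=torch.int8)  # quantized activations
a_scale = torch.rand(tokens, 1)                                     # per-token scales
w_q = torch.randint(-128, 128, (out, hidden), dtype=torch.int8)     # per-channel-quantized weights
w_scale = torch.rand(out, 1)                                        # per-output-channel scales

acc = a_q.to(torch.float32) @ w_q.to(torch.float32).T  # the kernel accumulates in int32
y = acc * a_scale * w_scale.T                          # dequantized output: [tokens, out]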

class CompressedTensorsWNA16MarlinMoEMethod(CompressedTensorsMoEMethod):
def __init__(


@@ -111,7 +111,7 @@ class CutlassScaledMMLinearKernel(ScaledMMLinearKernel):
# * dynamic, i_s is None and x_s computed from x.
# * static, i_s is scalar and x_s is i_s.
symmetric = azp_adj is None
-x_q, x_s, x_zp = ops.scaled_int8_quant(x,
+x_q, x_s, x_zp = ops.scaled_int8_quant(x.contiguous(),
i_s,
i_zp,
symmetric=symmetric)
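The one-line change above quantizes x.contiguous() instead of x, presumably because the activations reaching this kernel can be non-contiguous views (for example after slicing or permutation) and the INT8 quant op expects a densely laid-out tensor. A minimal illustration of the difference (an assumed example, not from the diff):

import torch

x = torch.randn(8, 16)[:, ::2]           # a strided view is not contiguous
assert not x.is_contiguous()
assert x.contiguous().is_contiguous()    # .contiguous() copies into a dense layout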