diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py index 721e36af2b28f..ae16a20cfaab9 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py @@ -34,6 +34,7 @@ __all__ = [ "CompressedTensorsMoEMethod", "CompressedTensorsW8A8Fp8MoEMethod", "CompressedTensorsW8A8Fp8MoECutlassMethod", + "CompressedTensorsW8A8Int8MoEMethod", "CompressedTensorsWNA16MarlinMoEMethod", "CompressedTensorsWNA16MoEMethod", ] @@ -71,6 +72,8 @@ class CompressedTensorsMoEMethod(FusedMoEMethodBase): return CompressedTensorsW8A8Fp8MoECutlassMethod(quant_config) elif quant_config._is_fp8_w8a8(weight_quant, input_quant): return CompressedTensorsW8A8Fp8MoEMethod(quant_config) + elif quant_config._is_dynamic_token_w8a8(weight_quant, input_quant): + return CompressedTensorsW8A8Int8MoEMethod(quant_config) else: raise RuntimeError( f"Unsupported FusedMoe scheme: {weight_quant}, {input_quant}") @@ -545,6 +548,138 @@ class CompressedTensorsW8A8Fp8MoECutlassMethod(CompressedTensorsMoEMethod): ) +class CompressedTensorsW8A8Int8MoEMethod(CompressedTensorsMoEMethod): + + def __init__( + self, + quant_config: "CompressedTensorsConfig" # type: ignore # noqa E501 + ): + self.quant_config = quant_config + self.weight_quant = self.quant_config.target_scheme_map["Linear"].get( + "weights") + self.input_quant = self.quant_config.target_scheme_map["Linear"].get( + "input_activations") + + per_channel = ( + self.weight_quant.strategy == QuantizationStrategy.CHANNEL + and self.input_quant.strategy == QuantizationStrategy.TOKEN) + if not per_channel: + raise ValueError( + "For INT8 Fused MoE layers, we require channelwise, " + "dynamic per token quantization. Found " + f"{self.weight_quant}, {self.input_quant}") + + self.static_input_scales = not self.input_quant.dynamic + if self.static_input_scales: + raise ValueError( + "For INT8 Fused MoE layers, we require channelwise, " + "dynamic per token quantization. Found static input scales.") + + def create_weights(self, layer: torch.nn.Module, num_experts: int, + hidden_size: int, intermediate_size_per_partition: int, + params_dtype: torch.dtype, **extra_weight_attrs): + + params_dtype = torch.int8 + + # WEIGHTS + w13_weight = torch.nn.Parameter(torch.empty( + num_experts, + 2 * intermediate_size_per_partition, + hidden_size, + dtype=params_dtype), + requires_grad=False) + layer.register_parameter("w13_weight", w13_weight) + set_weight_attrs(w13_weight, extra_weight_attrs) + + w2_weight = torch.nn.Parameter(torch.empty( + num_experts, + hidden_size, + intermediate_size_per_partition, + dtype=params_dtype), + requires_grad=False) + layer.register_parameter("w2_weight", w2_weight) + set_weight_attrs(w2_weight, extra_weight_attrs) + + # WEIGHT_SCALES + assert self.weight_quant.strategy == QuantizationStrategy.CHANNEL + w13_weight_scale = torch.nn.Parameter(torch.ones( + num_experts, + 2 * intermediate_size_per_partition, + 1, + dtype=torch.float32), + requires_grad=False) + layer.register_parameter("w13_weight_scale", w13_weight_scale) + w2_weight_scale = torch.nn.Parameter(torch.ones(num_experts, + hidden_size, + 1, + dtype=torch.float32), + requires_grad=False) + layer.register_parameter("w2_weight_scale", w2_weight_scale) + # Add PER-CHANNEL quantization for FusedMoE.weight_loader. + extra_weight_attrs.update( + {"quant_method": FusedMoeWeightScaleSupported.CHANNEL.value}) + set_weight_attrs(w13_weight_scale, extra_weight_attrs) + set_weight_attrs(w2_weight_scale, extra_weight_attrs) + + # INPUT_SCALES + assert not self.static_input_scales + layer.w13_input_scale = None + layer.w2_input_scale = None + + def process_weights_after_loading(self, layer: torch.nn.Module) -> None: + pass + + def apply( + self, + layer: torch.nn.Module, + x: torch.Tensor, + router_logits: torch.Tensor, + top_k: int, + renormalize: bool, + use_grouped_topk: bool = False, + topk_group: Optional[int] = None, + num_expert_group: Optional[int] = None, + global_num_experts: int = -1, + expert_map: Optional[torch.Tensor] = None, + custom_routing_function: Optional[Callable] = None, + scoring_func: str = "softmax", + e_score_correction_bias: Optional[torch.Tensor] = None, + apply_router_weight_on_input: bool = False, + activation: str = "silu", + ) -> torch.Tensor: + from vllm.model_executor.layers.fused_moe import fused_experts + + topk_weights, topk_ids = FusedMoE.select_experts( + hidden_states=x, + router_logits=router_logits, + use_grouped_topk=use_grouped_topk, + top_k=top_k, + renormalize=renormalize, + topk_group=topk_group, + num_expert_group=num_expert_group, + custom_routing_function=custom_routing_function, + scoring_func=scoring_func, + e_score_correction_bias=e_score_correction_bias) + + return fused_experts( + hidden_states=x, + w1=layer.w13_weight, + w2=layer.w2_weight, + topk_weights=topk_weights, + topk_ids=topk_ids, + inplace=True, + activation=activation, + apply_router_weight_on_input=apply_router_weight_on_input, + use_int8_w8a8=True, + per_channel_quant=True, + global_num_experts=global_num_experts, + expert_map=expert_map, + w1_scale=layer.w13_weight_scale, + w2_scale=layer.w2_weight_scale, + a1_scale=layer.w13_input_scale, + a2_scale=layer.w2_input_scale) + + class CompressedTensorsWNA16MarlinMoEMethod(CompressedTensorsMoEMethod): def __init__( diff --git a/vllm/model_executor/layers/quantization/kernels/scaled_mm/cutlass.py b/vllm/model_executor/layers/quantization/kernels/scaled_mm/cutlass.py index 2bf21a05c46d9..047724129522b 100644 --- a/vllm/model_executor/layers/quantization/kernels/scaled_mm/cutlass.py +++ b/vllm/model_executor/layers/quantization/kernels/scaled_mm/cutlass.py @@ -111,7 +111,7 @@ class CutlassScaledMMLinearKernel(ScaledMMLinearKernel): # * dynamic, i_s is None and x_s computed from x. # * static, i_s is scalar and x_s is i_s. symmetric = azp_adj is None - x_q, x_s, x_zp = ops.scaled_int8_quant(x, + x_q, x_s, x_zp = ops.scaled_int8_quant(x.contiguous(), i_s, i_zp, symmetric=symmetric)