diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py
index 8f69636dda7bf..17b41e8a1c23c 100644
--- a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py
+++ b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py
@@ -45,7 +45,6 @@ class GPTQMarlinState(Enum):
 __all__ = [
     "CompressedTensorsMoEMethod",
     "CompressedTensorsW8A8Fp8MoEMethod",
-    "CompressedTensorsW8A8Fp8MoECutlassMethod",
     "CompressedTensorsW8A8Int8MoEMethod",
     "CompressedTensorsWNA16MarlinMoEMethod",
     "CompressedTensorsWNA16MoEMethod", "CompressedTensorsW4A4MoeMethod"
@@ -84,9 +83,8 @@ class CompressedTensorsMoEMethod(FusedMoEMethodBase):
         elif quant_config._is_fp4a4_nvfp4(weight_quant, input_quant):
             return CompressedTensorsW4A4MoeMethod()
         elif (quant_config._is_fp8_w8a8_sm90(weight_quant, input_quant)
-              or quant_config._is_fp8_w8a8_sm100(weight_quant, input_quant)):
-            return CompressedTensorsW8A8Fp8MoECutlassMethod(quant_config)
-        elif quant_config._is_fp8_w8a8(weight_quant, input_quant):
+              or quant_config._is_fp8_w8a8_sm100(weight_quant, input_quant)
+              or quant_config._is_fp8_w8a8(weight_quant, input_quant)):
             return CompressedTensorsW8A8Fp8MoEMethod(quant_config)
         elif quant_config._is_dynamic_token_w8a8(weight_quant, input_quant):
             return CompressedTensorsW8A8Int8MoEMethod(quant_config)
@@ -378,6 +376,14 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
 
         self.rocm_aiter_moe_enabled = is_rocm_aiter_moe_enabled()
 
+        # cutlass path
+        self.is_fp8_w8a8_sm100 = quant_config._is_fp8_w8a8_sm100(
+            self.weight_quant, self.input_quant)
+        self.use_cutlass = (quant_config._is_fp8_w8a8_sm90(
+            self.weight_quant, self.input_quant) or self.is_fp8_w8a8_sm100)
+        self.fused_experts = None  # type: ignore[assignment]
+        self.disable_expert_map = False
+
     def create_weights(self, layer: torch.nn.Module, num_experts: int,
                        hidden_size: int, intermediate_size_per_partition: int,
                        params_dtype: torch.dtype, **extra_weight_attrs):
@@ -558,6 +564,34 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
         prepare_finalize: FusedMoEPrepareAndFinalize,
         moe: FusedMoEConfig,
     ) -> FusedMoEPermuteExpertsUnpermute:
+        # cutlass path
+        if self.use_cutlass:
+            from vllm.model_executor.layers.fused_moe import CutlassExpertsFp8
+
+            use_batched_format = (prepare_finalize.activation_format ==
+                                  FusedMoEActivationFormat.BatchedExperts)
+
+            num_dispatchers = prepare_finalize.num_dispatchers()
+            num_experts = (moe.num_local_experts
+                           if use_batched_format else moe.num_experts)
+
+            logger.debug("CutlassExpertsFp8(%s)", self.__class__.__name__)
+
+            experts = CutlassExpertsFp8(
+                num_experts,
+                moe.in_dtype,
+                self.input_quant.strategy == QuantizationStrategy.TOKEN,
+                self.weight_quant.strategy == QuantizationStrategy.CHANNEL,
+                num_dispatchers=num_dispatchers,
+                use_batched_format=use_batched_format,
+            )
+
+            self.disable_expert_map = (num_dispatchers > 1
+                                       or not experts.supports_expert_map())
+
+            return experts
+
+        # triton path
         from vllm.model_executor.layers.fused_moe import TritonExperts
         from vllm.model_executor.layers.fused_moe.fused_batched_moe import (
             BatchedTritonExperts)
@@ -629,6 +663,68 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
             indices_type=self.topk_indices_dtype,
         )
 
+        # cutlass path
+        if self.use_cutlass:
+            per_act_token = (
+                self.input_quant.strategy == QuantizationStrategy.TOKEN)
+            per_channel_quant = (
+                self.weight_quant.strategy == QuantizationStrategy.CHANNEL)
+
+            # small-batch fallback on SM100
+            if self.is_fp8_w8a8_sm100 and topk_ids.shape[0] <= 8:
+                from vllm.model_executor.layers.fused_moe import fused_experts
+                return fused_experts(
+                    hidden_states=x,
+                    w1=layer.w13_weight,
+                    w2=layer.w2_weight,
+                    topk_weights=topk_weights,
+                    topk_ids=topk_ids,
+                    inplace=True,
+                    activation=activation,
+                    apply_router_weight_on_input=apply_router_weight_on_input,
+                    use_fp8_w8a8=True,
+                    per_channel_quant=per_channel_quant,
+                    global_num_experts=global_num_experts,
+                    expert_map=None if self.disable_expert_map else expert_map,
+                    w1_scale=layer.w13_weight_scale,
+                    w2_scale=layer.w2_weight_scale,
+                    a1_scale=layer.w13_input_scale,
+                    a2_scale=layer.w2_input_scale)
+
+            if self.fused_experts is None:
+                from vllm.model_executor.layers.fused_moe.cutlass_moe import (
+                    cutlass_moe_fp8)
+                return cutlass_moe_fp8(
+                    x,
+                    layer.w13_weight,
+                    layer.w2_weight,
+                    topk_weights,
+                    topk_ids,
+                    per_act_token=per_act_token,
+                    activation=activation,
+                    global_num_experts=global_num_experts,
+                    expert_map=None if self.disable_expert_map else expert_map,
+                    w1_scale=layer.w13_weight_scale,
+                    w2_scale=layer.w2_weight_scale,
+                    a1_scale=layer.w13_input_scale,
+                    a2_scale=layer.w2_input_scale,
+                )
+            else:
+                return self.fused_experts(
+                    x,
+                    layer.w13_weight,
+                    layer.w2_weight,
+                    topk_weights,
+                    topk_ids,
+                    activation=activation,
+                    global_num_experts=global_num_experts,
+                    expert_map=None if self.disable_expert_map else expert_map,
+                    w1_scale=layer.w13_weight_scale,
+                    w2_scale=layer.w2_weight_scale,
+                    a1_scale=layer.w13_input_scale,
+                    a2_scale=layer.w2_input_scale,
+                )
+
         if self.rocm_aiter_moe_enabled:
             return self.rocm_aiter_fused_experts_func(
                 hidden_states=x,
@@ -685,291 +781,6 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
             a2_scale=layer.w2_input_scale)
 
 
-class CompressedTensorsW8A8Fp8MoECutlassMethod(CompressedTensorsMoEMethod):
-
-    def __init__(
-        self,
-        quant_config: "CompressedTensorsConfig"  # type: ignore # noqa E501
-    ):
-        self.quant_config = quant_config
-        self.weight_quant = self.quant_config.target_scheme_map["Linear"].get(
-            "weights")
-        self.input_quant = self.quant_config.target_scheme_map["Linear"].get(
-            "input_activations")
-
-        per_tensor = (self.weight_quant.strategy == QuantizationStrategy.TENSOR
-                      and self.input_quant.strategy
-                      == QuantizationStrategy.TENSOR)
-        per_channel = (
-            self.weight_quant.strategy == QuantizationStrategy.CHANNEL
-            and self.input_quant.strategy == QuantizationStrategy.TOKEN)
-        if not (per_tensor or per_channel):
-            raise ValueError(
-                "For FP8 Fused MoE layers, we require per tensor "
-                "or channelwise, dynamic per token quantization. Found "
-                f"{self.weight_quant}, {self.input_quant}")
-
-        self.static_input_scales = not self.input_quant.dynamic
-        if self.static_input_scales and per_channel:
-            raise ValueError(
-                "For FP8 Fused MoE layer, we require either per tensor or "
-                "channelwise, dynamic per token quantization.")
-
-        self.topk_indices_dtype = None
-        self.fused_experts = None  # type: ignore
-        self.disable_expert_map = False
-        self.is_fp8_w8a8_sm100 = self.quant_config._is_fp8_w8a8_sm100(
-            self.weight_quant, self.input_quant)
-
-    def create_weights(self, layer: torch.nn.Module, num_experts: int,
-                       hidden_size: int, intermediate_size_per_partition: int,
-                       params_dtype: torch.dtype, **extra_weight_attrs):
-
-        params_dtype = torch.float8_e4m3fn
-
-        # WEIGHTS
-        w13_weight = torch.nn.Parameter(torch.empty(
-            num_experts,
-            2 * intermediate_size_per_partition,
-            hidden_size,
-            dtype=params_dtype),
-                                        requires_grad=False)
-        layer.register_parameter("w13_weight", w13_weight)
-        set_weight_attrs(w13_weight, extra_weight_attrs)
-
-        w2_weight = torch.nn.Parameter(torch.empty(
-            num_experts,
-            hidden_size,
-            intermediate_size_per_partition,
-            dtype=params_dtype),
-                                       requires_grad=False)
-        layer.register_parameter("w2_weight", w2_weight)
-        set_weight_attrs(w2_weight, extra_weight_attrs)
-
-        # WEIGHT_SCALES
-        if self.weight_quant.strategy == QuantizationStrategy.TENSOR:
-            # Allocate 2 scales for w1 and w3 respectively.
-            # They are combined to a single scale after weight loading.
-            w13_weight_scale = torch.nn.Parameter(torch.ones(
-                num_experts, 2, dtype=torch.float32),
-                                                  requires_grad=False)
-            layer.register_parameter("w13_weight_scale", w13_weight_scale)
-            w2_weight_scale = torch.nn.Parameter(torch.ones(
-                num_experts, dtype=torch.float32),
-                                                 requires_grad=False)
-            layer.register_parameter("w2_weight_scale", w2_weight_scale)
-            # Add PER-TENSOR quantization for FusedMoE.weight_loader.
-            extra_weight_attrs.update(
-                {"quant_method": FusedMoeWeightScaleSupported.TENSOR.value})
-            set_weight_attrs(w13_weight_scale, extra_weight_attrs)
-            set_weight_attrs(w2_weight_scale, extra_weight_attrs)
-
-        elif self.weight_quant.strategy == QuantizationStrategy.CHANNEL:
-            w13_weight_scale = torch.nn.Parameter(torch.ones(
-                num_experts,
-                2 * intermediate_size_per_partition,
-                1,
-                dtype=torch.float32),
-                                                  requires_grad=False)
-            layer.register_parameter("w13_weight_scale", w13_weight_scale)
-            w2_weight_scale = torch.nn.Parameter(torch.ones(
-                num_experts, hidden_size, 1, dtype=torch.float32),
-                                                 requires_grad=False)
-            layer.register_parameter("w2_weight_scale", w2_weight_scale)
-            # Add PER-CHANNEL quantization for FusedMoE.weight_loader.
-            extra_weight_attrs.update(
-                {"quant_method": FusedMoeWeightScaleSupported.CHANNEL.value})
-            set_weight_attrs(w13_weight_scale, extra_weight_attrs)
-            set_weight_attrs(w2_weight_scale, extra_weight_attrs)
-
-        # INPUT_SCALES
-        if self.static_input_scales:
-            w13_input_scale = torch.nn.Parameter(torch.ones(
-                num_experts, dtype=torch.float32),
-                                                 requires_grad=False)
-            layer.register_parameter("w13_input_scale", w13_input_scale)
-            set_weight_attrs(w13_input_scale, extra_weight_attrs)
-
-            w2_input_scale = torch.nn.Parameter(torch.ones(
-                num_experts, dtype=torch.float32),
-                                                requires_grad=False)
-            layer.register_parameter("w2_input_scale", w2_input_scale)
-            set_weight_attrs(w2_input_scale, extra_weight_attrs)
-        else:
-            layer.w13_input_scale = None
-            layer.w2_input_scale = None
-
-    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
-        # Fp8 moe kernels require a single activation scale.
-        # We take the max of all the scales in case they differ.
-        if self.static_input_scales:
-            assert self.input_quant.strategy == QuantizationStrategy.TENSOR
-            if (layer.w13_input_scale is None or layer.w2_input_scale is None):
-                raise ValueError(
-                    "QuantConfig has static quantization, but found "
-                    "activation scales are None.")
-            if (not all_close_1d(layer.w13_input_scale)
-                    or not all_close_1d(layer.w2_input_scale)):
-                logger.warning_once(
-                    "Found input_scales that are not equal for "
-                    "fp8 MoE layer. Using the maximum across experts "
-                    "for each layer.")
-            layer.w13_input_scale = torch.nn.Parameter(
-                layer.w13_input_scale.max(), requires_grad=False)
-            layer.w2_input_scale = torch.nn.Parameter(
-                layer.w2_input_scale.max(), requires_grad=False)
-
-        # For Per-TENSOR case, Fp8 moe kernel needs single weight scale
-        # for w13 per expert. Use max then dequant and requant each expert.
-        if self.weight_quant.strategy == QuantizationStrategy.TENSOR:
-            assert layer.w13_weight_scale is not None
-            shard_size = layer.intermediate_size_per_partition
-            max_w13_scales = layer.w13_weight_scale.max(dim=1).values
-            for expert_id in range(layer.local_num_experts):
-                start = 0
-                for shard_id in range(2):
-                    dq_weight = per_tensor_dequantize(
-                        layer.w13_weight[expert_id][start:start +
-                                                    shard_size, :],
-                        layer.w13_weight_scale[expert_id][shard_id])
-                    layer.w13_weight[expert_id][
-                        start:start + shard_size, :], _ = ops.scaled_fp8_quant(
-                            dq_weight, max_w13_scales[expert_id])
-                    start += shard_size
-            layer.w13_weight_scale = torch.nn.Parameter(max_w13_scales,
-                                                        requires_grad=False)
-
-    def select_gemm_impl(
-        self,
-        prepare_finalize: FusedMoEPrepareAndFinalize,
-        moe: FusedMoEConfig,
-    ) -> FusedMoEPermuteExpertsUnpermute:
-        from vllm.model_executor.layers.fused_moe import CutlassExpertsFp8
-
-        use_batched_format = (prepare_finalize.activation_format ==
-                              FusedMoEActivationFormat.BatchedExperts)
-
-        num_dispatchers = prepare_finalize.num_dispatchers()
-
-        num_experts = (moe.num_local_experts
-                       if use_batched_format else moe.num_experts)
-
-        logger.debug("CutlassExpertsFp8(%s)", self.__class__.__name__)
-
-        experts = CutlassExpertsFp8(
-            num_experts,
-            moe.in_dtype,
-            self.input_quant.strategy == QuantizationStrategy.TOKEN,
-            self.weight_quant.strategy == QuantizationStrategy.CHANNEL,
-            num_dispatchers=num_dispatchers,
-            use_batched_format=use_batched_format,
-        )
-
-        self.disable_expert_map = (num_dispatchers > 1
-                                   or not experts.supports_expert_map())
-
-        return experts
-
-    def apply(
-        self,
-        layer: torch.nn.Module,
-        x: torch.Tensor,
-        router_logits: torch.Tensor,
-        top_k: int,
-        renormalize: bool,
-        use_grouped_topk: bool = False,
-        topk_group: Optional[int] = None,
-        num_expert_group: Optional[int] = None,
-        global_num_experts: int = -1,
-        expert_map: Optional[torch.Tensor] = None,
-        custom_routing_function: Optional[Callable] = None,
-        scoring_func: str = "softmax",
-        e_score_correction_bias: Optional[torch.Tensor] = None,
-        apply_router_weight_on_input: bool = False,
-        activation: str = "silu",
-        enable_eplb: bool = False,
-        expert_load_view: Optional[torch.Tensor] = None,
-        logical_to_physical_map: Optional[torch.Tensor] = None,
-        logical_replica_count: Optional[torch.Tensor] = None,
-    ) -> torch.Tensor:
-        if enable_eplb:
-            raise NotImplementedError(
-                "EPLB not supported for "
-                "`CompressedTensorsW8A8Fp8MoECutlassMethod` yet.")
-
-        topk_weights, topk_ids = FusedMoE.select_experts(
-            hidden_states=x,
-            router_logits=router_logits,
-            use_grouped_topk=use_grouped_topk,
-            top_k=top_k,
-            renormalize=renormalize,
-            topk_group=topk_group,
-            num_expert_group=num_expert_group,
-            custom_routing_function=custom_routing_function,
-            scoring_func=scoring_func,
-            e_score_correction_bias=e_score_correction_bias)
-
-        per_act_token = (
-            self.input_quant.strategy == QuantizationStrategy.TOKEN)
-        per_channel_quant = (
-            self.weight_quant.strategy == QuantizationStrategy.CHANNEL)
-        # Triton fused_experts is faster in small batch sizes on SM100.
-        # Fall back to fused_experts in small batch sizes.
-        if self.is_fp8_w8a8_sm100 and topk_ids.shape[0] <= 8:
-            from vllm.model_executor.layers.fused_moe import fused_experts
-            return fused_experts(
-                x,
-                layer.w13_weight,
-                layer.w2_weight,
-                topk_weights,
-                topk_ids,
-                inplace=True,
-                activation=activation,
-                apply_router_weight_on_input=apply_router_weight_on_input,
-                use_fp8_w8a8=True,
-                per_channel_quant=per_channel_quant,
-                global_num_experts=global_num_experts,
-                expert_map=None if self.disable_expert_map else expert_map,
-                w1_scale=layer.w13_weight_scale,
-                w2_scale=layer.w2_weight_scale,
-                a1_scale=layer.w13_input_scale,
-                a2_scale=layer.w2_input_scale)
-        if self.fused_experts is None:
-            # If no modular kernel is provided, use cutlass_moe_fp8
-            from vllm.model_executor.layers.fused_moe.cutlass_moe import (
-                cutlass_moe_fp8)
-            return cutlass_moe_fp8(
-                x,
-                layer.w13_weight,
-                layer.w2_weight,
-                topk_weights,
-                topk_ids,
-                per_act_token=per_act_token,
-                activation=activation,
-                global_num_experts=global_num_experts,
-                expert_map=None if self.disable_expert_map else expert_map,
-                w1_scale=layer.w13_weight_scale,
-                w2_scale=layer.w2_weight_scale,
-                a1_scale=layer.w13_input_scale,
-                a2_scale=layer.w2_input_scale,
-            )
-        else:
-            return self.fused_experts(
-                x,
-                layer.w13_weight,
-                layer.w2_weight,
-                topk_weights,
-                topk_ids,
-                activation=activation,
-                global_num_experts=global_num_experts,
-                expert_map=None if self.disable_expert_map else expert_map,
-                w1_scale=layer.w13_weight_scale,
-                w2_scale=layer.w2_weight_scale,
-                a1_scale=layer.w13_input_scale,
-                a2_scale=layer.w2_input_scale,
-            )
-
-
 class CompressedTensorsW8A8Int8MoEMethod(CompressedTensorsMoEMethod):
 
     def __init__(