diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py
index d905cc9eb0eff..a74f1f7233af0 100644
--- a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py
+++ b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py
@@ -9,6 +9,7 @@
 from compressed_tensors import CompressionFormat
 from compressed_tensors.quantization import (ActivationOrdering,
                                              QuantizationStrategy)
+import vllm.envs as envs
 import vllm.model_executor.layers.fused_moe  # noqa
 from vllm import _custom_ops as ops
 from vllm.logger import init_logger
@@ -20,10 +21,13 @@
 from vllm.model_executor.layers.quantization.utils import replace_parameter
 from vllm.model_executor.layers.quantization.utils.marlin_utils import (
     check_moe_marlin_supports_layer, marlin_make_workspace_new,
     marlin_moe_permute_scales)
+from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import (
+    prepare_moe_fp8_layer_for_marlin)
 from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
     all_close_1d, normalize_e4m3fn_to_e4m3fnuz, per_tensor_dequantize)
 from vllm.model_executor.utils import set_weight_attrs
 from vllm.platforms import current_platform
+from vllm.scalar_type import scalar_types
 
 logger = init_logger(__name__)
@@ -114,10 +118,24 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
                 "For FP8 Fused MoE layer, we require either per tensor or "
                 "channelwise, dynamic per token quantization.")
 
+        # For GPUs that lack FP8 hardware support, we can leverage the Marlin
+        # kernel for fast weight-only FP8 quantization
+        self.use_marlin = (not current_platform.has_device_capability(89)
+                           or envs.VLLM_TEST_FORCE_FP8_MARLIN)
+        # Disable marlin for rocm
+        if current_platform.is_rocm():
+            self.use_marlin = False
+
     def create_weights(self, layer: torch.nn.Module, num_experts: int,
                        hidden_size: int,
                        intermediate_size_per_partition: int,
                        params_dtype: torch.dtype, **extra_weight_attrs):
+        layer.intermediate_size_per_partition = intermediate_size_per_partition
+        layer.hidden_size = hidden_size
+        layer.num_experts = num_experts
+        layer.orig_dtype = params_dtype
+        layer.weight_block_size = None
+
         params_dtype = torch.float8_e4m3fn
 
         # WEIGHTS
@@ -280,6 +298,12 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
             from vllm.model_executor.layers.fused_moe import fused_experts
             self.fused_experts_func = fused_experts
 
+        if self.use_marlin:
+            prepare_moe_fp8_layer_for_marlin(layer, False)
+            # Activations not quantized for marlin.
+            del layer.w13_input_scale
+            del layer.w2_input_scale
+
     def apply(
         self,
         layer: torch.nn.Module,
@@ -311,6 +335,24 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
             scoring_func=scoring_func,
             e_score_correction_bias=e_score_correction_bias)
 
+        if self.use_marlin:
+            assert activation == "silu", (
+                f"{activation} not supported for Marlin MoE.")
+            assert not apply_router_weight_on_input, (
+                "Apply router weight on input not supported for Marlin MoE.")
+            return torch.ops.vllm.fused_marlin_moe(
+                x,
+                layer.w13_weight,
+                layer.w2_weight,
+                layer.w13_weight_scale,
+                layer.w2_weight_scale,
+                router_logits,
+                topk_weights,
+                topk_ids,
+                quant_type_id=scalar_types.float8_e4m3fn.id,
+                global_num_experts=global_num_experts,
+                expert_map=expert_map)
+
         return self.fused_experts_func(
             hidden_states=x,
             w1=layer.w13_weight,
@@ -517,7 +559,8 @@ class CompressedTensorsW8A8Fp8MoECutlassMethod(CompressedTensorsMoEMethod):
         activation: str = "silu",
     ) -> torch.Tensor:
 
-        assert activation == "silu"
+        assert activation == "silu", (
+            f"{activation} not supported for Cutlass MoE.")
 
         topk_weights, topk_ids = FusedMoE.select_experts(
             hidden_states=x,
@@ -942,11 +985,10 @@ class CompressedTensorsWNA16MarlinMoEMethod(CompressedTensorsMoEMethod):
         apply_router_weight_on_input: bool = False,
         activation: str = "silu",
     ) -> torch.Tensor:
-        assert activation == "silu", "Only SiLU activation is supported."
-        if apply_router_weight_on_input:
-            raise NotImplementedError(
-                "Apply router weight on input is not supported for "
-                "fused Marlin MoE method.")
+        assert activation == "silu", (
+            f"{activation} not supported for Marlin MoE.")
+        assert not apply_router_weight_on_input, (
+            "Apply router weight on input not supported for Marlin MoE.")
 
         topk_weights, topk_ids = FusedMoE.select_experts(
             hidden_states=x,
diff --git a/vllm/model_executor/layers/quantization/fp8.py b/vllm/model_executor/layers/quantization/fp8.py
index 5b5f25909c33e..589ca7bed3294 100644
--- a/vllm/model_executor/layers/quantization/fp8.py
+++ b/vllm/model_executor/layers/quantization/fp8.py
@@ -811,6 +811,10 @@ class Fp8MoEMethod(FusedMoEMethodBase):
         )
 
         if self.use_marlin:
+            assert activation == "silu", (
+                f"{activation} not supported for Marlin MoE.")
+            assert not apply_router_weight_on_input, (
+                "Apply router weight on input not supported for Marlin MoE.")
             return torch.ops.vllm.fused_marlin_moe(
                 x,
                 layer.w13_weight,
diff --git a/vllm/model_executor/layers/quantization/utils/marlin_utils_fp8.py b/vllm/model_executor/layers/quantization/utils/marlin_utils_fp8.py
index 08812debd321b..1f6e74244c5d4 100644
--- a/vllm/model_executor/layers/quantization/utils/marlin_utils_fp8.py
+++ b/vllm/model_executor/layers/quantization/utils/marlin_utils_fp8.py
@@ -268,6 +268,7 @@ def prepare_moe_fp8_layer_for_marlin(layer: torch.nn.Module,
             tensor_list.append(marlin_scales)
 
         scales = torch.cat([x.unsqueeze(0) for x in tensor_list], 0)
+        scales = fp8_fused_exponent_bias_into_scales(scales)
         scales = torch.nn.Parameter(scales, requires_grad=False)
 
         setattr(layer, name + "_weight_scale", scales)