From d0502b4928fb683491952c6cd4f31b3d63e6d25c Mon Sep 17 00:00:00 2001
From: Robert Shaw <114415538+robertgshaw2-redhat@users.noreply.github.com>
Date: Mon, 15 Dec 2025 09:54:53 -0500
Subject: [PATCH] [MoE][Refactor 1/N] Separate Online Quantization (#30627)

Signed-off-by: Robert Shaw
Co-authored-by: Robert Shaw
---
 .../model_executor/layers/quantization/fp8.py | 243 +++++++++++-------
 1 file changed, 154 insertions(+), 89 deletions(-)

diff --git a/vllm/model_executor/layers/quantization/fp8.py b/vllm/model_executor/layers/quantization/fp8.py
index 6909bac1efc7c..f2b66a2beb6d7 100644
--- a/vllm/model_executor/layers/quantization/fp8.py
+++ b/vllm/model_executor/layers/quantization/fp8.py
@@ -332,7 +332,10 @@ class Fp8Config(QuantizationConfig):
                 fused_mapping=self.packed_modules_mapping,
             ):
                 return UnquantizedFusedMoEMethod(layer.moe_config)
-            moe_quant_method = Fp8MoEMethod(self, layer)
+            if self.is_checkpoint_fp8_serialized:
+                moe_quant_method = Fp8MoEMethod(self, layer)
+            else:
+                moe_quant_method = Fp8OnlineMoEMethod(self, layer)
             moe_quant_method.marlin_input_dtype = get_marlin_input_dtype(prefix)
             return moe_quant_method
         elif isinstance(layer, Attention):
@@ -745,8 +748,9 @@ class Fp8MoEMethod(FusedMoEMethodBase):
         layer.orig_dtype = params_dtype
         layer.weight_block_size = None
 
-        if self.quant_config.is_checkpoint_fp8_serialized:
-            params_dtype = torch.float8_e4m3fn
+        assert self.quant_config.is_checkpoint_fp8_serialized
+        params_dtype = torch.float8_e4m3fn
+
         if self.block_quant:
             assert self.weight_block_size is not None
             layer.weight_block_size = self.weight_block_size
@@ -773,41 +777,6 @@ class Fp8MoEMethod(FusedMoEMethodBase):
                     f"weight quantization block_k = {block_k}."
                 )
 
-        # if we are doing online quantization, patch the weight
-        # loaded to call `process_weights_after_loading` in a streaming fashion
-        # as soon as the last weight chunk is loaded
-        if not self.quant_config.is_checkpoint_fp8_serialized:
-            weight_loader = extra_weight_attrs["weight_loader"]
-            # create a new holder to prevent modifying behavior of any other
-            # objects which might depend on the old one
-            new_extra_weight_attrs = extra_weight_attrs
-
-            def patched_weight_loader(param, loaded_weight, *args, **kwargs):
-                # load the current weight chunk
-                res = weight_loader(param, loaded_weight, *args, **kwargs)  # type: ignore[misc]
-
-                # add a counter to track how many elements we have updated
-                if not hasattr(layer, "_loaded_numel"):
-                    layer._loaded_numel = 0
-                layer._loaded_numel += loaded_weight.numel()
-
-                # if we have loaded all of the elements, call
-                # process_weights_after_loading
-                target_loaded_numel = layer.w13_weight.numel() + layer.w2_weight.numel()
-                if layer._loaded_numel == target_loaded_numel:
-                    self.process_weights_after_loading(layer)
-
-                    # Delete the bookkeeping
-                    del layer._loaded_numel
-                    # Prevent the usual `process_weights_after_loading` call
-                    # from doing anything
-                    layer._already_called_process_weights_after_loading = True
-
-                return res
-
-            new_extra_weight_attrs["weight_loader"] = patched_weight_loader
-            extra_weight_attrs = new_extra_weight_attrs
-
         # WEIGHTS
         w13_weight = torch.nn.Parameter(
             torch.empty(
@@ -875,21 +844,11 @@ class Fp8MoEMethod(FusedMoEMethodBase):
             if self.block_quant
             else {"quant_method": FusedMoeWeightScaleSupported.TENSOR.value}
         )
-        # If loading fp8 checkpoint, pass the weight loaders.
-        # If loading an fp16 checkpoint, do not (we will quantize in
-        # process_weights_after_loading()
-        if self.quant_config.is_checkpoint_fp8_serialized:
-            set_weight_attrs(w13_weight_scale, extra_weight_attrs)
-            set_weight_attrs(w2_weight_scale, extra_weight_attrs)
+        set_weight_attrs(w13_weight_scale, extra_weight_attrs)
+        set_weight_attrs(w2_weight_scale, extra_weight_attrs)
 
         # INPUT_SCALES
         if self.quant_config.activation_scheme == "static":
-            if not self.quant_config.is_checkpoint_fp8_serialized:
-                raise ValueError(
-                    "Found static activation scheme for checkpoint that "
-                    "was not serialized fp8."
-                )
-
             w13_input_scale = torch.nn.Parameter(
                 torch.ones(num_experts, dtype=torch.float32), requires_grad=False
             )
@@ -986,45 +945,6 @@ class Fp8MoEMethod(FusedMoEMethodBase):
             layer.w2_weight_scale_inv = Parameter(
                 dg_w2_weight_scale_inv, requires_grad=False
             )
-
-        # If checkpoint is fp16, quantize in place.
-        elif not self.quant_config.is_checkpoint_fp8_serialized:
-            fp8_dtype = current_platform.fp8_dtype()
-            w13_weight = torch.empty_like(layer.w13_weight.data, dtype=fp8_dtype)
-            w2_weight = torch.empty_like(layer.w2_weight.data, dtype=fp8_dtype)
-
-            # Re-initialize w13_scale because we directly quantize
-            # merged w13 weights and generate a single scaling factor.
-            replace_parameter(
-                layer,
-                "w13_weight_scale",
-                torch.ones(
-                    layer.local_num_experts,
-                    dtype=torch.float32,
-                    device=w13_weight.device,
-                ),
-            )
-            for expert in range(layer.local_num_experts):
-                w13_weight[expert, :, :], layer.w13_weight_scale[expert] = (
-                    ops.scaled_fp8_quant(layer.w13_weight.data[expert, :, :])
-                )
-                w2_weight[expert, :, :], layer.w2_weight_scale[expert] = (
-                    ops.scaled_fp8_quant(layer.w2_weight.data[expert, :, :])
-                )
-            replace_parameter(layer, "w13_weight", w13_weight)
-            replace_parameter(layer, "w2_weight", w2_weight)
-
-            if self.rocm_aiter_moe_enabled:
-                # reshaping weights is required for aiter moe kernel.
-                shuffled_w13, shuffled_w2 = rocm_aiter_ops.shuffle_weights(
-                    layer.w13_weight, layer.w2_weight
-                )
-
-                replace_parameter(layer, "w13_weight", shuffled_w13)
-                replace_parameter(layer, "w2_weight", shuffled_w2)
-        # If checkpoint is fp8, we need to handle that the
-        # MoE kernels require single activation scale and single weight
-        # scale for w13 per expert.
         else:
             # Fp8 moe kernels require a single activation scale.
             # We take the max of all the scales in case they differ.
@@ -1387,6 +1307,151 @@ class Fp8MoEMethod(FusedMoEMethodBase):
         return result
 
 
+class Fp8OnlineMoEMethod(Fp8MoEMethod):
+    """MoE method for online FP8 quantization.
+    Supports loading unquantized FP16/BF16 model checkpoints and quantizing
+    them to FP8 at load time, with dynamic activation scaling. The weight
+    scaling factors are initialized once the model weights are loaded.
+
+    Args:
+        quant_config: The quantization config.
+ """ + + def __init__(self, quant_config: Fp8Config, layer: torch.nn.Module): + super().__init__(quant_config, layer) + assert not quant_config.is_checkpoint_fp8_serialized + assert quant_config.activation_scheme == "dynamic" + assert quant_config.weight_block_size is None + assert self.flashinfer_moe_backend is None + + def create_weights( + self, + layer: Module, + num_experts: int, + hidden_size: int, + intermediate_size_per_partition: int, + params_dtype: torch.dtype, + **extra_weight_attrs, + ): + layer.intermediate_size_per_partition = intermediate_size_per_partition + layer.hidden_size = hidden_size + layer.num_experts = num_experts + layer.orig_dtype = params_dtype + layer.weight_block_size = None + + # We are doing online quantization, patch the weight loaded + # to call `process_weights_after_loading` in a streaming fashion + # as soon as the last weight chunk is loaded. + weight_loader = extra_weight_attrs["weight_loader"] + # create a new holder to prevent modifying behavior of any other + # objects which might depend on the old one + new_extra_weight_attrs = extra_weight_attrs + + def patched_weight_loader(param, loaded_weight, *args, **kwargs): + # load the current weight chunk + res = weight_loader(param, loaded_weight, *args, **kwargs) # type: ignore[misc] + + # add a counter to track how many elements we have updated + if not hasattr(layer, "_loaded_numel"): + layer._loaded_numel = 0 + layer._loaded_numel += loaded_weight.numel() + + # if we have loaded all of the elements, call + # process_weights_after_loading + target_loaded_numel = layer.w13_weight.numel() + layer.w2_weight.numel() + if layer._loaded_numel == target_loaded_numel: + self.process_weights_after_loading(layer) + + # Delete the bookkeeping + del layer._loaded_numel + # Prevent the usual `process_weights_after_loading` call + # from doing anything + layer._already_called_process_weights_after_loading = True + + return res + + new_extra_weight_attrs["weight_loader"] = patched_weight_loader + extra_weight_attrs = new_extra_weight_attrs + + # WEIGHTS + w13_weight = torch.nn.Parameter( + torch.empty( + num_experts, + 2 * intermediate_size_per_partition, + hidden_size, + dtype=params_dtype, + ), + requires_grad=False, + ) + layer.register_parameter("w13_weight", w13_weight) + set_weight_attrs(w13_weight, extra_weight_attrs) + + w2_weight = torch.nn.Parameter( + torch.empty( + num_experts, + hidden_size, + intermediate_size_per_partition, + dtype=params_dtype, + ), + requires_grad=False, + ) + layer.register_parameter("w2_weight", w2_weight) + set_weight_attrs(w2_weight, extra_weight_attrs) + + # WEIGHT_SCALES + # Allocate 2 scales for w1 and w3 respectively. + # They will be combined to a single scale after weight loading. + w13_weight_scale = torch.nn.Parameter( + torch.ones(num_experts, dtype=torch.float32), requires_grad=False + ) + w2_weight_scale = torch.nn.Parameter( + torch.ones(num_experts, dtype=torch.float32), requires_grad=False + ) + layer.register_parameter("w13_weight_scale", w13_weight_scale) + layer.register_parameter("w2_weight_scale", w2_weight_scale) + + layer.w13_input_scale = None + layer.w2_input_scale = None + + self.rocm_aiter_moe_enabled = False + + def process_weights_after_loading(self, layer: Module) -> None: + if getattr(layer, "_already_called_process_weights_after_loading", False): + return + + # Lazy import to avoid importing triton too early. + self.rocm_aiter_moe_enabled = rocm_aiter_ops.is_fused_moe_enabled() + + # If checkpoint is fp16, quantize in place. 
+        fp8_dtype = current_platform.fp8_dtype()
+        w13_weight = torch.empty_like(layer.w13_weight.data, dtype=fp8_dtype)
+        w2_weight = torch.empty_like(layer.w2_weight.data, dtype=fp8_dtype)
+
+        for expert in range(layer.local_num_experts):
+            w13_weight[expert, :, :], layer.w13_weight_scale[expert] = (
+                ops.scaled_fp8_quant(layer.w13_weight.data[expert, :, :])
+            )
+            w2_weight[expert, :, :], layer.w2_weight_scale[expert] = (
+                ops.scaled_fp8_quant(layer.w2_weight.data[expert, :, :])
+            )
+        replace_parameter(layer, "w13_weight", w13_weight)
+        replace_parameter(layer, "w2_weight", w2_weight)
+
+        # Reshuffle weights for AITER if needed.
+        if self.rocm_aiter_moe_enabled:
+            shuffled_w13, shuffled_w2 = rocm_aiter_ops.shuffle_weights(
+                layer.w13_weight, layer.w2_weight
+            )
+            replace_parameter(layer, "w13_weight", shuffled_w13)
+            replace_parameter(layer, "w2_weight", shuffled_w2)
+
+        # Reshuffle weights for MARLIN if needed.
+        if self.use_marlin:
+            prepare_moe_fp8_layer_for_marlin(
+                layer, False, input_dtype=self.marlin_input_dtype
+            )
+
+
 class Fp8KVCacheMethod(BaseKVCacheMethod):
     """
     Supports loading kv-cache scaling factors from FP8 checkpoints.
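
Illustration (not part of the patch): the heart of Fp8OnlineMoEMethod.create_weights is the streaming hook that wraps the FusedMoE weight loader and fires quantization as soon as the last chunk of w13_weight/w2_weight has landed. The following is a minimal, self-contained sketch of that pattern only; make_streaming_loader, copy_loader, and quantize are hypothetical stand-ins for the patched_weight_loader closure, the layer's real weight loader, and process_weights_after_loading.

    import torch


    def make_streaming_loader(layer: dict, base_loader, on_complete):
        # Total number of elements expected for this MoE layer's weights.
        target = layer["w13_weight"].numel() + layer["w2_weight"].numel()

        def patched(param, loaded_weight, *args, **kwargs):
            # Load the current chunk with the original loader.
            res = base_loader(param, loaded_weight, *args, **kwargs)
            # Count loaded elements; fire the callback once everything arrived.
            layer["_loaded_numel"] = layer.get("_loaded_numel", 0) + loaded_weight.numel()
            if layer["_loaded_numel"] == target:
                on_complete(layer)
                del layer["_loaded_numel"]
            return res

        return patched


    if __name__ == "__main__":
        layer = {
            "w13_weight": torch.empty(2, 8, 4),
            "w2_weight": torch.empty(2, 4, 8),
        }

        def copy_loader(param, loaded_weight):
            param.copy_(loaded_weight)

        def quantize(layer):
            print("last chunk loaded -> quantize to fp8 here")

        loader = make_streaming_loader(layer, copy_loader, quantize)
        # Simulate the checkpoint streaming one tensor (chunk) at a time.
        loader(layer["w13_weight"], torch.randn(2, 8, 4))
        loader(layer["w2_weight"], torch.randn(2, 4, 8))  # triggers quantize()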