From 560ae9638c00e761d6f1ca33882249dbfebbe8aa Mon Sep 17 00:00:00 2001
From: Yan Ma
Date: Sat, 20 Dec 2025 21:45:27 +0800
Subject: [PATCH] [XPU] enable fp8 online streaming quantization (#30944)

Signed-off-by: Yan Ma
---
 .../model_executor/layers/quantization/fp8.py |  14 +-
 .../layers/quantization/ipex_quant.py         | 122 +++---------------
 2 files changed, 29 insertions(+), 107 deletions(-)

diff --git a/vllm/model_executor/layers/quantization/fp8.py b/vllm/model_executor/layers/quantization/fp8.py
index a86fb3d309525..30ca64238ae94 100644
--- a/vllm/model_executor/layers/quantization/fp8.py
+++ b/vllm/model_executor/layers/quantization/fp8.py
@@ -124,11 +124,13 @@ def get_fp8_moe_backend(
     block_quant: bool,
     moe_parallel_config: FusedMoEParallelConfig,
     with_lora_support: bool,
-) -> Fp8MoeBackend:
+) -> Fp8MoeBackend | None:
     """
     Select the primary FP8 MoE backend
     Note: Shape-specific fallbacks may still occur at runtime.
     """
+    if current_platform.is_xpu():
+        return None
     if with_lora_support:
         return Fp8MoeBackend.TRITON
     # Prefer FlashInfer backends on supported GPUs; allow SM90 and SM100.
@@ -292,6 +294,13 @@ class Fp8Config(QuantizationConfig):
                 return UnquantizedLinearMethod()
             return XPUFp8LinearMethod(fp8_config)
         elif isinstance(layer, FusedMoE):
+            if is_layer_skipped(
+                prefix=prefix,
+                ignored_layers=self.ignored_layers,
+                fused_mapping=self.packed_modules_mapping,
+            ):
+                return UnquantizedFusedMoEMethod(layer.moe_config)
+
             return XPUFp8MoEMethod(fp8_config, layer)
         elif isinstance(layer, Attention):
             return Fp8KVCacheMethod(self)
@@ -1107,7 +1116,8 @@ class Fp8MoEMethod(FusedMoEMethodBase):
         routing_tables: tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None = None,
     ) -> mk.FusedMoEPrepareAndFinalize | None:
         if (
-            self.rocm_aiter_moe_enabled
+            current_platform.is_xpu()
+            or self.rocm_aiter_moe_enabled
             or self.use_marlin
             or self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM
         ):
diff --git a/vllm/model_executor/layers/quantization/ipex_quant.py b/vllm/model_executor/layers/quantization/ipex_quant.py
index f33ee43727f19..9de2924ec71b1 100644
--- a/vllm/model_executor/layers/quantization/ipex_quant.py
+++ b/vllm/model_executor/layers/quantization/ipex_quant.py
@@ -6,13 +6,8 @@ from typing import Any, Optional
 import torch
 from packaging import version
 from torch.nn import Module
-from torch.nn.parameter import Parameter
 
 from vllm._ipex_ops import ipex_ops as ops
-from vllm.model_executor.layers.fused_moe import (
-    FusedMoEMethodBase,
-    FusedMoeWeightScaleSupported,
-)
 from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig
 from vllm.model_executor.layers.linear import (
     LinearBase,
@@ -24,14 +19,14 @@ from vllm.model_executor.layers.quantization import (
     QuantizationMethods,
 )
 from vllm.model_executor.layers.quantization.awq import AWQLinearMethod
-from vllm.model_executor.layers.quantization.fp8 import Fp8Config, Fp8LinearMethod
+from vllm.model_executor.layers.quantization.fp8 import (
+    Fp8Config,
+    Fp8LinearMethod,
+    Fp8OnlineMoEMethod,
+)
 from vllm.model_executor.layers.quantization.gptq import GPTQLinearMethod
 from vllm.model_executor.layers.quantization.utils.quant_utils import is_layer_skipped
-from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
-    maybe_create_device_identity,
-)
-from vllm.model_executor.parameter import ModelWeightParameter
-from vllm.model_executor.utils import set_weight_attrs
+from vllm.model_executor.utils import replace_parameter
 from vllm.platforms import current_platform
 
 MIN_IPEX_VERSION = "2.6.0"
@@ -309,44 +304,15 @@ class XPUFp8LinearMethod(Fp8LinearMethod):
     def __init__(self, quant_config: Fp8Config):
         super().__init__(quant_config)
 
-    def create_weights(
-        self,
-        layer: torch.nn.Module,
-        input_size_per_partition: int,
-        output_partition_sizes: list[int],
-        input_size: int,
-        output_size: int,
-        params_dtype: torch.dtype,
-        **extra_weight_attrs,
-    ):
-        maybe_create_device_identity()
-
-        output_size_per_partition = sum(output_partition_sizes)
-        weight_loader = extra_weight_attrs.get("weight_loader")
-        layer.logical_widths = output_partition_sizes
-        layer.input_size_per_partition = input_size_per_partition
-        layer.output_size_per_partition = output_size_per_partition
-        layer.orig_dtype = params_dtype
-        layer.weight_block_size = None
-        weight = ModelWeightParameter(
-            data=torch.empty(
-                output_size_per_partition,
-                input_size_per_partition,
-                dtype=params_dtype,
-            ),
-            input_dim=1,
-            output_dim=0,
-            weight_loader=weight_loader,
-        )
-        layer.register_parameter("weight", weight)
-
     def process_weights_after_loading(self, layer: Module) -> None:
+        if getattr(layer, "_already_called_process_weights_after_loading", False):
+            return
         # If checkpoint not serialized fp8, quantize the weights.
         if not self.quant_config.is_checkpoint_fp8_serialized:
             qweight, weight_scale = ops.scaled_fp8_quant(layer.weight, scale=None)
             # Update the layer with the new values.
-            layer.weight = Parameter(qweight, requires_grad=False)
-            layer.weight_scale = Parameter(weight_scale, requires_grad=False)
+            replace_parameter(layer, "weight", qweight.data)
+            replace_parameter(layer, "weight_scale", weight_scale.data)
             layer.input_scale = None
 
     def apply(
@@ -363,69 +329,14 @@ class XPUFp8LinearMethod(Fp8LinearMethod):
         return output
 
 
-class XPUFp8MoEMethod(FusedMoEMethodBase):
+class XPUFp8MoEMethod(Fp8OnlineMoEMethod):
     def __init__(self, quant_config: Fp8Config, layer: torch.nn.Module):
-        super().__init__(layer.moe_config)
+        super().__init__(quant_config, layer)
         self.quant_config = quant_config
 
-    def create_weights(
-        self,
-        layer: Module,
-        num_experts: int,
-        hidden_size: int,
-        intermediate_size_per_partition: int,
-        params_dtype: torch.dtype,
-        **extra_weight_attrs,
-    ):
-        layer.intermediate_size_per_partition = intermediate_size_per_partition
-        layer.hidden_size = hidden_size
-        layer.num_experts = num_experts
-        layer.orig_dtype = params_dtype
-        layer.weight_block_size = None
-        # WEIGHTS
-        w13_weight = torch.nn.Parameter(
-            torch.empty(
-                num_experts,
-                2 * intermediate_size_per_partition,
-                hidden_size,
-                dtype=params_dtype,
-            ),
-            requires_grad=False,
-        )
-        layer.register_parameter("w13_weight", w13_weight)
-        set_weight_attrs(w13_weight, extra_weight_attrs)
-
-        w2_weight = torch.nn.Parameter(
-            torch.empty(
-                num_experts,
-                hidden_size,
-                intermediate_size_per_partition,
-                dtype=params_dtype,
-            ),
-            requires_grad=False,
-        )
-        layer.register_parameter("w2_weight", w2_weight)
-        set_weight_attrs(w2_weight, extra_weight_attrs)
-
-        # Allocate 2 scales for w1 and w3 respectively.
-        # They will be combined to a single scale after weight loading.
-        w13_weight_scale = torch.nn.Parameter(
-            torch.ones(num_experts, 2, dtype=torch.float32), requires_grad=False
-        )
-        w2_weight_scale = torch.nn.Parameter(
-            torch.ones(num_experts, dtype=torch.float32), requires_grad=False
-        )
-        layer.register_parameter("w13_weight_scale", w13_weight_scale)
-        layer.register_parameter("w2_weight_scale", w2_weight_scale)
-
-        extra_weight_attrs.update(
-            {"quant_method": FusedMoeWeightScaleSupported.TENSOR.value}
-        )
-        # INPUT_SCALES
-        layer.w13_input_scale = None
-        layer.w2_input_scale = None
-
     def process_weights_after_loading(self, layer: Module) -> None:
+        if getattr(layer, "_already_called_process_weights_after_loading", False):
+            return
         if not self.quant_config.is_checkpoint_fp8_serialized:
             fp8_dtype = current_platform.fp8_dtype()
             w13_weight = torch.empty_like(layer.w13_weight.data, dtype=fp8_dtype)
@@ -448,8 +359,9 @@ class XPUFp8MoEMethod(FusedMoEMethodBase):
                 w2_weight[expert, :, :], layer.w2_weight_scale[expert] = (
                     ops.scaled_fp8_quant(layer.w2_weight.data[expert, :, :])
                 )
-            layer.w13_weight = torch.nn.Parameter(w13_weight, requires_grad=False)
-            layer.w2_weight = torch.nn.Parameter(w2_weight, requires_grad=False)
+            replace_parameter(layer, "w13_weight", w13_weight)
+            replace_parameter(layer, "w2_weight", w2_weight)
+
         import intel_extension_for_pytorch as ipex
 
         ep_rank_start = self.moe.ep_rank * self.moe.num_local_experts
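
Note on the online quantization path: when the checkpoint is not fp8-serialized, the patch quantizes the already-loaded bf16/fp16 weights to FP8 at startup via ops.scaled_fp8_quant and stores the quantized weight plus its scale back on the layer with replace_parameter. A minimal stand-alone sketch of that per-tensor quantization step, using only stock PyTorch; the helper name quantize_fp8_per_tensor and the 448.0 E4M3 constant below are illustrative, not vLLM APIs:

    import torch

    FP8_E4M3_MAX = 448.0  # largest finite value of torch.float8_e4m3fn

    def quantize_fp8_per_tensor(weight: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        # One scale for the whole tensor: map the absolute max onto the FP8 range.
        scale = weight.abs().max().float().clamp(min=1e-12) / FP8_E4M3_MAX
        qweight = (weight.float() / scale).clamp(-FP8_E4M3_MAX, FP8_E4M3_MAX)
        return qweight.to(torch.float8_e4m3fn), scale

    # Example: quantize a freshly loaded bf16 weight; the scale is kept for the fp8 GEMM.
    w = torch.randn(1024, 1024, dtype=torch.bfloat16)
    qw, s = quantize_fp8_per_tensor(w)

In the MoE case the same operation is applied expert by expert (w13 and w2 separately), which is why the loop above calls ops.scaled_fp8_quant per expert slice before swapping the parameters.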