[XPU] enable fp8 online streaming quantization (#30944)

Signed-off-by: Yan Ma <yan.ma@intel.com>
Authored by Yan Ma on 2025-12-20 21:45:27 +08:00; committed by GitHub
parent 1501a4070e
commit 560ae9638c
2 changed files with 29 additions and 107 deletions
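
With this change, XPU no longer requires an FP8-serialized checkpoint: when the loaded weights are not already FP8, XPUFp8LinearMethod and XPUFp8MoEMethod quantize them in process_weights_after_loading using dynamic per-tensor scales (ops.scaled_fp8_quant from ipex_ops), and skipped layers fall back to unquantized methods. A minimal usage sketch, assuming an XPU-enabled vLLM build; the model name is a placeholder and not part of this commit:

# Hypothetical end-to-end example: load a bf16/fp16 checkpoint and let vLLM
# quantize the weights to FP8 online at load time.
from vllm import LLM, SamplingParams

llm = LLM(
    model="meta-llama/Llama-3.1-8B-Instruct",  # placeholder checkpoint
    quantization="fp8",                        # online (streaming) FP8 quantization
)
out = llm.generate(["Hello from XPU"], SamplingParams(max_tokens=32))
print(out[0].outputs[0].text)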

File 1 of 2

@@ -124,11 +124,13 @@ def get_fp8_moe_backend(
     block_quant: bool,
     moe_parallel_config: FusedMoEParallelConfig,
     with_lora_support: bool,
-) -> Fp8MoeBackend:
+) -> Fp8MoeBackend | None:
     """
     Select the primary FP8 MoE backend
     Note: Shape-specific fallbacks may still occur at runtime.
     """
+    if current_platform.is_xpu():
+        return None
     if with_lora_support:
         return Fp8MoeBackend.TRITON
     # Prefer FlashInfer backends on supported GPUs; allow SM90 and SM100.
@@ -292,6 +294,13 @@ class Fp8Config(QuantizationConfig):
                 return UnquantizedLinearMethod()
             return XPUFp8LinearMethod(fp8_config)
         elif isinstance(layer, FusedMoE):
+            if is_layer_skipped(
+                prefix=prefix,
+                ignored_layers=self.ignored_layers,
+                fused_mapping=self.packed_modules_mapping,
+            ):
+                return UnquantizedFusedMoEMethod(layer.moe_config)
             return XPUFp8MoEMethod(fp8_config, layer)
         elif isinstance(layer, Attention):
             return Fp8KVCacheMethod(self)
@@ -1107,7 +1116,8 @@ class Fp8MoEMethod(FusedMoEMethodBase):
         routing_tables: tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None = None,
     ) -> mk.FusedMoEPrepareAndFinalize | None:
         if (
-            self.rocm_aiter_moe_enabled
+            current_platform.is_xpu()
+            or self.rocm_aiter_moe_enabled
             or self.use_marlin
             or self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM
         ):
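
Taken together, the hunks above route XPU away from the CUDA/ROCm FP8 MoE machinery: the backend selector returns None on XPU, the modular-kernel prepare/finalize path is bypassed, and Fp8Config hands FusedMoE layers to XPUFp8MoEMethod (or the unquantized method when the layer is skipped). A rough sketch of that selection logic, simplified and not the vLLM source; string values stand in for the real Fp8MoeBackend members:

def pick_fp8_moe_backend(is_xpu: bool, with_lora_support: bool) -> str | None:
    # On XPU there is no dedicated Fp8MoeBackend; returning None lets the
    # platform-specific XPUFp8MoEMethod handle the fused-MoE layer instead.
    if is_xpu:
        return None
    if with_lora_support:
        return "TRITON"
    # Otherwise prefer the FlashInfer path on supported NVIDIA GPUs.
    return "FLASHINFER"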

File 2 of 2

@@ -6,13 +6,8 @@ from typing import Any, Optional
 import torch
 from packaging import version
 from torch.nn import Module
-from torch.nn.parameter import Parameter
 from vllm._ipex_ops import ipex_ops as ops
-from vllm.model_executor.layers.fused_moe import (
-    FusedMoEMethodBase,
-    FusedMoeWeightScaleSupported,
-)
 from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig
 from vllm.model_executor.layers.linear import (
     LinearBase,
@@ -24,14 +19,14 @@ from vllm.model_executor.layers.quantization import (
     QuantizationMethods,
 )
 from vllm.model_executor.layers.quantization.awq import AWQLinearMethod
-from vllm.model_executor.layers.quantization.fp8 import Fp8Config, Fp8LinearMethod
+from vllm.model_executor.layers.quantization.fp8 import (
+    Fp8Config,
+    Fp8LinearMethod,
+    Fp8OnlineMoEMethod,
+)
 from vllm.model_executor.layers.quantization.gptq import GPTQLinearMethod
 from vllm.model_executor.layers.quantization.utils.quant_utils import is_layer_skipped
-from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
-    maybe_create_device_identity,
-)
-from vllm.model_executor.parameter import ModelWeightParameter
-from vllm.model_executor.utils import set_weight_attrs
+from vllm.model_executor.utils import replace_parameter
 from vllm.platforms import current_platform
 
 MIN_IPEX_VERSION = "2.6.0"
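
The "online" part of the change comes down to dynamic per-tensor scaling at load time: a weight tensor's scale is derived from its own maximum magnitude, then the tensor is cast to FP8. A plain-PyTorch sketch of roughly what ops.scaled_fp8_quant(weight, scale=None) computes; this is an illustration, not the IPEX kernel:

import torch

def scaled_fp8_quant_reference(w: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
    # Dynamic per-tensor scale: map the tensor's max magnitude onto the FP8 range.
    fp8_dtype = torch.float8_e4m3fn
    fp8_max = torch.finfo(fp8_dtype).max
    scale = w.abs().max().float() / fp8_max
    qweight = (w.float() / scale).clamp(-fp8_max, fp8_max).to(fp8_dtype)
    return qweight, scale

w = torch.randn(4096, 4096, dtype=torch.bfloat16)
qw, s = scaled_fp8_quant_reference(w)
# Dequantized round trip stays within FP8 e4m3 quantization error.
print((qw.float() * s - w.float()).abs().max())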
@@ -309,44 +304,15 @@ class XPUFp8LinearMethod(Fp8LinearMethod):
     def __init__(self, quant_config: Fp8Config):
         super().__init__(quant_config)
 
-    def create_weights(
-        self,
-        layer: torch.nn.Module,
-        input_size_per_partition: int,
-        output_partition_sizes: list[int],
-        input_size: int,
-        output_size: int,
-        params_dtype: torch.dtype,
-        **extra_weight_attrs,
-    ):
-        maybe_create_device_identity()
-        output_size_per_partition = sum(output_partition_sizes)
-        weight_loader = extra_weight_attrs.get("weight_loader")
-        layer.logical_widths = output_partition_sizes
-        layer.input_size_per_partition = input_size_per_partition
-        layer.output_size_per_partition = output_size_per_partition
-        layer.orig_dtype = params_dtype
-        layer.weight_block_size = None
-        weight = ModelWeightParameter(
-            data=torch.empty(
-                output_size_per_partition,
-                input_size_per_partition,
-                dtype=params_dtype,
-            ),
-            input_dim=1,
-            output_dim=0,
-            weight_loader=weight_loader,
-        )
-        layer.register_parameter("weight", weight)
     def process_weights_after_loading(self, layer: Module) -> None:
         if getattr(layer, "_already_called_process_weights_after_loading", False):
             return
         # If checkpoint not serialized fp8, quantize the weights.
         if not self.quant_config.is_checkpoint_fp8_serialized:
             qweight, weight_scale = ops.scaled_fp8_quant(layer.weight, scale=None)
             # Update the layer with the new values.
-            layer.weight = Parameter(qweight, requires_grad=False)
-            layer.weight_scale = Parameter(weight_scale, requires_grad=False)
+            replace_parameter(layer, "weight", qweight.data)
+            replace_parameter(layer, "weight_scale", weight_scale.data)
             layer.input_scale = None
 
     def apply(
@@ -363,69 +329,14 @@ class XPUFp8LinearMethod(Fp8LinearMethod):
         return output
 
-class XPUFp8MoEMethod(FusedMoEMethodBase):
+class XPUFp8MoEMethod(Fp8OnlineMoEMethod):
     def __init__(self, quant_config: Fp8Config, layer: torch.nn.Module):
-        super().__init__(layer.moe_config)
+        super().__init__(quant_config, layer)
         self.quant_config = quant_config
 
-    def create_weights(
-        self,
-        layer: Module,
-        num_experts: int,
-        hidden_size: int,
-        intermediate_size_per_partition: int,
-        params_dtype: torch.dtype,
-        **extra_weight_attrs,
-    ):
-        layer.intermediate_size_per_partition = intermediate_size_per_partition
-        layer.hidden_size = hidden_size
-        layer.num_experts = num_experts
-        layer.orig_dtype = params_dtype
-        layer.weight_block_size = None
-        # WEIGHTS
-        w13_weight = torch.nn.Parameter(
-            torch.empty(
-                num_experts,
-                2 * intermediate_size_per_partition,
-                hidden_size,
-                dtype=params_dtype,
-            ),
-            requires_grad=False,
-        )
-        layer.register_parameter("w13_weight", w13_weight)
-        set_weight_attrs(w13_weight, extra_weight_attrs)
-        w2_weight = torch.nn.Parameter(
-            torch.empty(
-                num_experts,
-                hidden_size,
-                intermediate_size_per_partition,
-                dtype=params_dtype,
-            ),
-            requires_grad=False,
-        )
-        layer.register_parameter("w2_weight", w2_weight)
-        set_weight_attrs(w2_weight, extra_weight_attrs)
-        # Allocate 2 scales for w1 and w3 respectively.
-        # They will be combined to a single scale after weight loading.
-        w13_weight_scale = torch.nn.Parameter(
-            torch.ones(num_experts, 2, dtype=torch.float32), requires_grad=False
-        )
-        w2_weight_scale = torch.nn.Parameter(
-            torch.ones(num_experts, dtype=torch.float32), requires_grad=False
-        )
-        layer.register_parameter("w13_weight_scale", w13_weight_scale)
-        layer.register_parameter("w2_weight_scale", w2_weight_scale)
-        extra_weight_attrs.update(
-            {"quant_method": FusedMoeWeightScaleSupported.TENSOR.value}
-        )
-        # INPUT_SCALES
-        layer.w13_input_scale = None
-        layer.w2_input_scale = None
     def process_weights_after_loading(self, layer: Module) -> None:
         if getattr(layer, "_already_called_process_weights_after_loading", False):
             return
         if not self.quant_config.is_checkpoint_fp8_serialized:
             fp8_dtype = current_platform.fp8_dtype()
             w13_weight = torch.empty_like(layer.w13_weight.data, dtype=fp8_dtype)
@@ -448,8 +359,9 @@ class XPUFp8MoEMethod(FusedMoEMethodBase):
                 w2_weight[expert, :, :], layer.w2_weight_scale[expert] = (
                     ops.scaled_fp8_quant(layer.w2_weight.data[expert, :, :])
                 )
-            layer.w13_weight = torch.nn.Parameter(w13_weight, requires_grad=False)
-            layer.w2_weight = torch.nn.Parameter(w2_weight, requires_grad=False)
+            replace_parameter(layer, "w13_weight", w13_weight)
+            replace_parameter(layer, "w2_weight", w2_weight)
 
         import intel_extension_for_pytorch as ipex
 
         ep_rank_start = self.moe.ep_rank * self.moe.num_local_experts
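
The hunk is truncated above; judging from the ipex import and the expert-parallel rank offset, the remainder wires the quantized experts into IPEX's fused MoE. For orientation, a self-contained sketch of the per-expert online quantization pattern the loop above implements, reusing the dynamic-scale idea sketched earlier; plain tensors stand in for the registered parameters and only the w2 path is shown (the real code also handles w13 and swaps results in via replace_parameter):

import torch

fp8_dtype = torch.float8_e4m3fn
fp8_max = torch.finfo(fp8_dtype).max
num_experts, hidden_size, intermediate_size = 4, 256, 512

# Stand-in for layer.w2_weight: one (hidden_size, intermediate_size) matrix per expert.
w2 = torch.randn(num_experts, hidden_size, intermediate_size, dtype=torch.bfloat16)
w2_fp8 = torch.empty_like(w2, dtype=fp8_dtype)
w2_scale = torch.ones(num_experts, dtype=torch.float32)

for expert in range(num_experts):
    # Dynamic per-tensor scale for this expert, then cast its slice to FP8.
    scale = w2[expert].abs().max().float() / fp8_max
    w2_fp8[expert] = (w2[expert].float() / scale).clamp(-fp8_max, fp8_max).to(fp8_dtype)
    w2_scale[expert] = scale

# The real method then replaces layer.w2_weight / layer.w2_weight_scale with these
# tensors so the fused-MoE kernel consumes FP8 weights plus per-expert scales.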