[XPU] enable fp8 online streaming quantization (#30944)
Signed-off-by: Yan Ma <yan.ma@intel.com>
parent 1501a4070e
commit 560ae9638c
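Online (streaming) FP8 quantization means the checkpoint is loaded in its original dtype (e.g. bf16) and each weight is converted to FP8 with a per-tensor scale at load time, instead of reading pre-quantized FP8 weights. A minimal, self-contained sketch of that per-tensor conversion (the scale/clamp math below is illustrative, not vLLM's kernel):

import torch

def online_fp8_quant(weight: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
    # Per-tensor dynamic quantization: one scale for the whole weight tensor.
    fp8 = torch.float8_e4m3fn
    fp8_max = torch.finfo(fp8).max
    scale = weight.abs().max().clamp(min=1e-12).float() / fp8_max
    qweight = (weight.float() / scale).clamp(-fp8_max, fp8_max).to(fp8)
    return qweight, scale  # dequantize as qweight.float() * scale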
@@ -124,11 +124,13 @@ def get_fp8_moe_backend(
     block_quant: bool,
     moe_parallel_config: FusedMoEParallelConfig,
     with_lora_support: bool,
-) -> Fp8MoeBackend:
+) -> Fp8MoeBackend | None:
     """
     Select the primary FP8 MoE backend
     Note: Shape-specific fallbacks may still occur at runtime.
     """
+    if current_platform.is_xpu():
+        return None
     if with_lora_support:
         return Fp8MoeBackend.TRITON
     # Prefer FlashInfer backends on supported GPUs; allow SM90 and SM100.
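The hunk above lets get_fp8_moe_backend return None on XPU instead of forcing one of the CUDA-oriented backends. A tiny, self-contained illustration of that selection pattern (names below are invented for the example, not vLLM's):

from enum import Enum

class MoeBackend(Enum):
    TRITON = "triton"

def select_backend(is_xpu: bool, with_lora_support: bool) -> MoeBackend | None:
    if is_xpu:
        return None  # no dedicated backend; the XPU FP8 MoE method handles the layer
    if with_lora_support:
        return MoeBackend.TRITON
    return MoeBackend.TRITON  # placeholder for the remaining selection logic

assert select_backend(is_xpu=True, with_lora_support=False) is None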
@@ -292,6 +294,13 @@ class Fp8Config(QuantizationConfig):
                 return UnquantizedLinearMethod()
             return XPUFp8LinearMethod(fp8_config)
+        elif isinstance(layer, FusedMoE):
+            if is_layer_skipped(
+                prefix=prefix,
+                ignored_layers=self.ignored_layers,
+                fused_mapping=self.packed_modules_mapping,
+            ):
+                return UnquantizedFusedMoEMethod(layer.moe_config)
+
+            return XPUFp8MoEMethod(fp8_config, layer)
         elif isinstance(layer, Attention):
             return Fp8KVCacheMethod(self)
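With the dispatch above in place, FusedMoE layers on XPU are routed to XPUFp8MoEMethod, so an unquantized checkpoint can be served with on-the-fly FP8 quantization. A hedged usage sketch (the model name is only an example; requires an XPU build of vLLM):

from vllm import LLM

llm = LLM(model="mistralai/Mixtral-8x7B-Instruct-v0.1", quantization="fp8")
print(llm.generate("Hello")[0].outputs[0].text)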
@@ -1107,7 +1116,8 @@ class Fp8MoEMethod(FusedMoEMethodBase):
         routing_tables: tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None = None,
     ) -> mk.FusedMoEPrepareAndFinalize | None:
         if (
-            self.rocm_aiter_moe_enabled
+            current_platform.is_xpu()
+            or self.rocm_aiter_moe_enabled
             or self.use_marlin
             or self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM
         ):
@@ -6,13 +6,8 @@ from typing import Any, Optional
 import torch
 from packaging import version
 from torch.nn import Module
-from torch.nn.parameter import Parameter

 from vllm._ipex_ops import ipex_ops as ops
-from vllm.model_executor.layers.fused_moe import (
-    FusedMoEMethodBase,
-    FusedMoeWeightScaleSupported,
-)
 from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig
 from vllm.model_executor.layers.linear import (
     LinearBase,
@@ -24,14 +19,14 @@ from vllm.model_executor.layers.quantization import (
     QuantizationMethods,
 )
 from vllm.model_executor.layers.quantization.awq import AWQLinearMethod
-from vllm.model_executor.layers.quantization.fp8 import Fp8Config, Fp8LinearMethod
+from vllm.model_executor.layers.quantization.fp8 import (
+    Fp8Config,
+    Fp8LinearMethod,
+    Fp8OnlineMoEMethod,
+)
 from vllm.model_executor.layers.quantization.gptq import GPTQLinearMethod
 from vllm.model_executor.layers.quantization.utils.quant_utils import is_layer_skipped
-from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
-    maybe_create_device_identity,
-)
-from vllm.model_executor.parameter import ModelWeightParameter
-from vllm.model_executor.utils import set_weight_attrs
+from vllm.model_executor.utils import replace_parameter
 from vllm.platforms import current_platform

 MIN_IPEX_VERSION = "2.6.0"
@@ -309,44 +304,15 @@ class XPUFp8LinearMethod(Fp8LinearMethod):
     def __init__(self, quant_config: Fp8Config):
         super().__init__(quant_config)

-    def create_weights(
-        self,
-        layer: torch.nn.Module,
-        input_size_per_partition: int,
-        output_partition_sizes: list[int],
-        input_size: int,
-        output_size: int,
-        params_dtype: torch.dtype,
-        **extra_weight_attrs,
-    ):
-        maybe_create_device_identity()
-
-        output_size_per_partition = sum(output_partition_sizes)
-        weight_loader = extra_weight_attrs.get("weight_loader")
-        layer.logical_widths = output_partition_sizes
-        layer.input_size_per_partition = input_size_per_partition
-        layer.output_size_per_partition = output_size_per_partition
-        layer.orig_dtype = params_dtype
-        layer.weight_block_size = None
-        weight = ModelWeightParameter(
-            data=torch.empty(
-                output_size_per_partition,
-                input_size_per_partition,
-                dtype=params_dtype,
-            ),
-            input_dim=1,
-            output_dim=0,
-            weight_loader=weight_loader,
-        )
-        layer.register_parameter("weight", weight)
-
     def process_weights_after_loading(self, layer: Module) -> None:
+        if getattr(layer, "_already_called_process_weights_after_loading", False):
+            return
         # If checkpoint not serialized fp8, quantize the weights.
         if not self.quant_config.is_checkpoint_fp8_serialized:
             qweight, weight_scale = ops.scaled_fp8_quant(layer.weight, scale=None)
             # Update the layer with the new values.
-            layer.weight = Parameter(qweight, requires_grad=False)
-            layer.weight_scale = Parameter(weight_scale, requires_grad=False)
+            replace_parameter(layer, "weight", qweight.data)
+            replace_parameter(layer, "weight_scale", weight_scale.data)
             layer.input_scale = None

     def apply(
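Two things change in the linear path above: process_weights_after_loading becomes idempotent via the _already_called guard, and the freshly quantized tensors are installed with replace_parameter instead of rebinding new nn.Parameter objects. A self-contained sketch of what that replacement amounts to conceptually (the helper below is illustrative, not vLLM's replace_parameter):

import torch

def replace_parameter_sketch(module: torch.nn.Module, name: str, new: torch.Tensor) -> None:
    # Store the new data under the existing parameter name, frozen, so the
    # module's parameter registry keeps pointing at the quantized tensor.
    setattr(module, name, torch.nn.Parameter(new, requires_grad=False))

layer = torch.nn.Linear(16, 16, bias=False)
qweight = layer.weight.data.to(torch.float8_e4m3fn)  # stand-in for scaled_fp8_quant output
replace_parameter_sketch(layer, "weight", qweight)
assert layer.weight.dtype == torch.float8_e4m3fn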
@@ -363,69 +329,14 @@ class XPUFp8LinearMethod(Fp8LinearMethod):
         return output


-class XPUFp8MoEMethod(FusedMoEMethodBase):
+class XPUFp8MoEMethod(Fp8OnlineMoEMethod):
     def __init__(self, quant_config: Fp8Config, layer: torch.nn.Module):
-        super().__init__(layer.moe_config)
+        super().__init__(quant_config, layer)
         self.quant_config = quant_config

-    def create_weights(
-        self,
-        layer: Module,
-        num_experts: int,
-        hidden_size: int,
-        intermediate_size_per_partition: int,
-        params_dtype: torch.dtype,
-        **extra_weight_attrs,
-    ):
-        layer.intermediate_size_per_partition = intermediate_size_per_partition
-        layer.hidden_size = hidden_size
-        layer.num_experts = num_experts
-        layer.orig_dtype = params_dtype
-        layer.weight_block_size = None
-        # WEIGHTS
-        w13_weight = torch.nn.Parameter(
-            torch.empty(
-                num_experts,
-                2 * intermediate_size_per_partition,
-                hidden_size,
-                dtype=params_dtype,
-            ),
-            requires_grad=False,
-        )
-        layer.register_parameter("w13_weight", w13_weight)
-        set_weight_attrs(w13_weight, extra_weight_attrs)
-
-        w2_weight = torch.nn.Parameter(
-            torch.empty(
-                num_experts,
-                hidden_size,
-                intermediate_size_per_partition,
-                dtype=params_dtype,
-            ),
-            requires_grad=False,
-        )
-        layer.register_parameter("w2_weight", w2_weight)
-        set_weight_attrs(w2_weight, extra_weight_attrs)
-
-        # Allocate 2 scales for w1 and w3 respectively.
-        # They will be combined to a single scale after weight loading.
-        w13_weight_scale = torch.nn.Parameter(
-            torch.ones(num_experts, 2, dtype=torch.float32), requires_grad=False
-        )
-        w2_weight_scale = torch.nn.Parameter(
-            torch.ones(num_experts, dtype=torch.float32), requires_grad=False
-        )
-        layer.register_parameter("w13_weight_scale", w13_weight_scale)
-        layer.register_parameter("w2_weight_scale", w2_weight_scale)
-
-        extra_weight_attrs.update(
-            {"quant_method": FusedMoeWeightScaleSupported.TENSOR.value}
-        )
-        # INPUT_SCALES
-        layer.w13_input_scale = None
-        layer.w2_input_scale = None
-
     def process_weights_after_loading(self, layer: Module) -> None:
+        if getattr(layer, "_already_called_process_weights_after_loading", False):
+            return
         if not self.quant_config.is_checkpoint_fp8_serialized:
             fp8_dtype = current_platform.fp8_dtype()
             w13_weight = torch.empty_like(layer.w13_weight.data, dtype=fp8_dtype)
@@ -448,8 +359,9 @@ class XPUFp8MoEMethod(FusedMoEMethodBase):
                 w2_weight[expert, :, :], layer.w2_weight_scale[expert] = (
                     ops.scaled_fp8_quant(layer.w2_weight.data[expert, :, :])
                 )
-            layer.w13_weight = torch.nn.Parameter(w13_weight, requires_grad=False)
-            layer.w2_weight = torch.nn.Parameter(w2_weight, requires_grad=False)
+            replace_parameter(layer, "w13_weight", w13_weight)
+            replace_parameter(layer, "w2_weight", w2_weight)

         import intel_extension_for_pytorch as ipex

         ep_rank_start = self.moe.ep_rank * self.moe.num_local_experts
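Since ops.scaled_fp8_quant is what both the linear and the per-expert MoE paths above use for the online conversion, a quick round-trip check can make the behavior concrete. Hedged sketch: it assumes an XPU device and an IPEX-enabled vLLM build, and the call signature mirrors the one shown in the diff:

import torch
from vllm._ipex_ops import ipex_ops as ops

weight = torch.randn(4096, 4096, dtype=torch.bfloat16, device="xpu")
qweight, scale = ops.scaled_fp8_quant(weight, scale=None)  # dynamic per-tensor scale
dequant = qweight.to(torch.float32) * scale
print("max abs error:", (dequant - weight.float()).abs().max().item())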