[XPU] fix broken fp8 online quantization for XPU platform (#30831)

Signed-off-by: Yan Ma <yan.ma@intel.com>
Author: Yan Ma <yan.ma@intel.com> (committed by GitHub)
Date: 2025-12-17 16:28:13 +08:00
parent 0cd5353644
commit 4f735babb7
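For context, this code path is reached when a bf16/fp16 checkpoint is served with fp8 quantization requested at runtime, so the weights are quantized online after loading. A minimal way to exercise it on an XPU build (sketch only; the model name is a placeholder and option defaults may vary across vLLM versions):

from vllm import LLM

# Any bf16/fp16 checkpoint; with quantization="fp8" the weights are
# quantized online after loading rather than read from an fp8 checkpoint.
llm = LLM(model="facebook/opt-125m", quantization="fp8")
print(llm.generate("Hello")[0].outputs[0].text)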


@@ -27,6 +27,10 @@ from vllm.model_executor.layers.quantization.awq import AWQLinearMethod
from vllm.model_executor.layers.quantization.fp8 import Fp8Config, Fp8LinearMethod
from vllm.model_executor.layers.quantization.gptq import GPTQLinearMethod
from vllm.model_executor.layers.quantization.utils.quant_utils import is_layer_skipped
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
    maybe_create_device_identity,
)
from vllm.model_executor.parameter import ModelWeightParameter
from vllm.model_executor.utils import set_weight_attrs
from vllm.platforms import current_platform
@@ -305,6 +309,37 @@ class XPUFp8LinearMethod(Fp8LinearMethod):
    def __init__(self, quant_config: Fp8Config):
        super().__init__(quant_config)

    def create_weights(
        self,
        layer: torch.nn.Module,
        input_size_per_partition: int,
        output_partition_sizes: list[int],
        input_size: int,
        output_size: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        maybe_create_device_identity()
        output_size_per_partition = sum(output_partition_sizes)
        weight_loader = extra_weight_attrs.get("weight_loader")
        layer.logical_widths = output_partition_sizes
        layer.input_size_per_partition = input_size_per_partition
        layer.output_size_per_partition = output_size_per_partition
        layer.orig_dtype = params_dtype
        layer.weight_block_size = None

        weight = ModelWeightParameter(
            data=torch.empty(
                output_size_per_partition,
                input_size_per_partition,
                dtype=params_dtype,
            ),
            input_dim=1,
            output_dim=0,
            weight_loader=weight_loader,
        )
        layer.register_parameter("weight", weight)

    def process_weights_after_loading(self, layer: Module) -> None:
        # If checkpoint not serialized fp8, quantize the weights.
        if not self.quant_config.is_checkpoint_fp8_serialized:
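For reference, the branch that follows this check performs the online quantization of the bf16/fp16 weight registered in create_weights above: an fp8 tensor plus a per-tensor scale replace the original parameter. A standalone sketch of per-tensor fp8 quantization in plain PyTorch (illustrative only; vLLM's own helpers and kernels, e.g. scaled_fp8_quant, may differ in details such as clamping and scale dtype):

import torch

def quantize_fp8_per_tensor(weight: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
    # Map the tensor's max magnitude onto the fp8 e4m3 representable range.
    finfo = torch.finfo(torch.float8_e4m3fn)
    scale = weight.abs().max().float().clamp(min=1e-12) / finfo.max
    qweight = (weight.float() / scale).clamp(finfo.min, finfo.max).to(torch.float8_e4m3fn)
    return qweight, scale

w = torch.randn(4096, 4096, dtype=torch.bfloat16)
qw, s = quantize_fp8_per_tensor(w)
# Dequantization approximates the original weight: qw.to(torch.bfloat16) * s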