From f284d7bd0c55f929fa7912936b1d247089679191 Mon Sep 17 00:00:00 2001
From: Wentao Ye <44945378+yewentao256@users.noreply.github.com>
Date: Wed, 17 Dec 2025 05:00:35 -0500
Subject: [PATCH] [Bug] Fix AttributeError: 'ColumnParallelLinear' object has
 no attribute `weight_scale_inv` (#30823)

Signed-off-by: yewentao256
---
 vllm/model_executor/layers/quantization/utils/fp8_utils.py | 7 +++++--
 1 file changed, 5 insertions(+), 2 deletions(-)

diff --git a/vllm/model_executor/layers/quantization/utils/fp8_utils.py b/vllm/model_executor/layers/quantization/utils/fp8_utils.py
index ea68745585160..bdc3d1fc7232d 100644
--- a/vllm/model_executor/layers/quantization/utils/fp8_utils.py
+++ b/vllm/model_executor/layers/quantization/utils/fp8_utils.py
@@ -1437,14 +1437,17 @@ def maybe_post_process_fp8_weight_block(layer: torch.nn.Module):
         layer.orig_dtype, layer.weight
     )
     if should_use_deepgemm:
+        scale_attr = (
+            "weight_scale_inv" if hasattr(layer, "weight_scale_inv") else "weight_scale"
+        )
         dg_weight, dg_weight_scale = deepgemm_post_process_fp8_weight_block(
             wq=layer.weight.data,
-            ws=layer.weight_scale_inv.data,
+            ws=getattr(layer, scale_attr).data,
             quant_block_shape=tuple(layer.weight_block_size),
             use_e8m0=is_deep_gemm_e8m0_used(),
         )
         replace_parameter(layer, "weight", dg_weight)
-        replace_parameter(layer, "weight_scale_inv", dg_weight_scale)
+        replace_parameter(layer, scale_attr, dg_weight_scale)


 def expert_weight_is_col_major(x: torch.Tensor) -> bool:
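
Note (not part of the patch): a minimal, self-contained sketch of the failure
mode and the fallback pattern used above. FakeLayer and get_scale are
illustrative names, not vLLM APIs; the only assumption is that a layer may
register its block scale under either `weight_scale_inv` or `weight_scale`.

    import torch

    class FakeLayer(torch.nn.Module):
        """Hypothetical stand-in for a quantized linear layer."""

        def __init__(self, scale_name: str):
            super().__init__()
            self.weight = torch.nn.Parameter(torch.randn(8, 8))
            # Some quantization paths register the scale as `weight_scale`
            # rather than `weight_scale_inv`.
            setattr(self, scale_name, torch.nn.Parameter(torch.ones(2, 2)))

    def get_scale(layer: torch.nn.Module) -> torch.Tensor:
        # Same fallback as the patch: prefer `weight_scale_inv`, fall back
        # to `weight_scale` when the former is absent, instead of accessing
        # layer.weight_scale_inv unconditionally.
        attr = (
            "weight_scale_inv" if hasattr(layer, "weight_scale_inv") else "weight_scale"
        )
        return getattr(layer, attr).data

    print(get_scale(FakeLayer("weight_scale_inv")).shape)  # fine before and after
    print(get_scale(FakeLayer("weight_scale")).shape)  # previously: AttributeError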