mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-01-04 19:27:39 +08:00
[Bug] Fix AttributeError: 'ColumnParallelLinear' object has no attribute weight_scale_inv (#30823)
Signed-off-by: yewentao256 <zhyanwentao@126.com>
This commit is contained in:
parent
53cd7f868b
commit
f284d7bd0c
@@ -1437,14 +1437,17 @@ def maybe_post_process_fp8_weight_block(layer: torch.nn.Module):
|
||||
layer.orig_dtype, layer.weight
|
||||
)
|
||||
if should_use_deepgemm:
|
||||
scale_attr = (
|
||||
"weight_scale_inv" if hasattr(layer, "weight_scale_inv") else "weight_scale"
|
||||
)
|
||||
dg_weight, dg_weight_scale = deepgemm_post_process_fp8_weight_block(
|
||||
wq=layer.weight.data,
|
||||
ws=layer.weight_scale_inv.data,
|
||||
ws=getattr(layer, scale_attr).data,
|
||||
quant_block_shape=tuple(layer.weight_block_size),
|
||||
use_e8m0=is_deep_gemm_e8m0_used(),
|
||||
)
|
||||
replace_parameter(layer, "weight", dg_weight)
|
||||
replace_parameter(layer, "weight_scale_inv", dg_weight_scale)
|
||||
replace_parameter(layer, scale_attr, dg_weight_scale)
|
||||
|
||||
|
||||
def expert_weight_is_col_major(x: torch.Tensor) -> bool:
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user