[Bugfix] Remove NVFP4 scales assertions to fix load_format=dummy (#18861)

Signed-off-by: mgoin <mgoin64@gmail.com>
Michael Goin 2025-05-30 01:37:36 -04:00 committed by GitHub
parent 77b6e74fe2
commit 4d0a1541be
2 changed files with 11 additions and 4 deletions
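
Why the old assertions broke load_format=dummy: with dummy loading the checkpoint is never read and every tensor, including the NVFP4 scales, is filled with random values, so the invariants that real ModelOpt checkpoints guarantee no longer hold and the asserts fired before the model could even start. A minimal sketch of the failure mode (shapes and initialization are illustrative, not the vLLM dummy loader itself):

import torch

# Stand-ins for dummy-initialized NVFP4 MoE scales; shapes are assumptions.
num_experts = 8
w13_weight_scale_2 = torch.randn(num_experts, 2)   # column 0: w1 scale, column 1: w3 scale
marlin_scales = torch.randn(64, 16)                # per-group scales for the Marlin path

# Old assertion 1: the w1 and w3 global scales must match -- random values practically never do.
print(torch.allclose(w13_weight_scale_2[:, 0], w13_weight_scale_2[:, 1]))  # False

# Old assertion 2: every Marlin scale must be non-negative -- random values are not.
print((marlin_scales >= 0).all())  # tensor(False)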


@@ -585,9 +585,11 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
     def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
         # GEMM 1
-        assert torch.allclose(
-            layer.w13_weight_scale_2[:, 0], layer.w13_weight_scale_2[:, 1]), (
-                "w1_weight_scale_2 must match w3_weight_scale_2")
+        if not torch.allclose(layer.w13_weight_scale_2[:, 0],
+                              layer.w13_weight_scale_2[:, 1]):
+            logger.warning_once(
+                "w1_weight_scale_2 must match w3_weight_scale_2. "
+                "Accuracy may be affected.")
         w13_weight_scale_2 = layer.w13_weight_scale_2[:, 0]
         layer.w13_weight_scale_2 = Parameter(w13_weight_scale_2,
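
With the hard assertion gone, the code warns once and then proceeds with column 0 of the scale tensor exactly as before, so a mismatched (or random) scale pair degrades accuracy instead of aborting. That makes a dummy-weight smoke test possible again; a hedged usage sketch, where the model path is a placeholder for any ModelOpt NVFP4 MoE checkpoint:

from vllm import LLM

# load_format="dummy" skips reading real weights and fills tensors with random
# data, which previously tripped the assertion removed above.
llm = LLM(
    model="path/to/nvfp4-moe-checkpoint",  # placeholder model identifier
    load_format="dummy",
)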


@@ -22,7 +22,12 @@ def is_fp4_marlin_supported():
 def fp4_marlin_process_scales(marlin_scales):
-    assert (marlin_scales >= 0).all()
+    if not (marlin_scales >= 0).all():
+        logger.warning_once(
+            "NVFP4 Marlin assumes the scales to be >=0, but has encountered "
+            "negative scales. Accuracy will likely be degraded. This is "
+            "because it changes the scales from FP8-S1E4M3 to a special "
+            "FP8-S0E5M3 format to speedup the dequantization.")
     # convert to half first, we would convert to fp8 later
     marlin_scales = marlin_scales.to(torch.half)
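
The warning text refers to Marlin repacking the FP8-S1E4M3 scales into a sign-less S0E5M3 layout for faster dequantization: with no sign bit available, a negative scale has no faithful encoding and its sign is effectively dropped. A small illustration of that information loss with plain tensors (not the Marlin kernels themselves):

import torch

scales = torch.tensor([0.5, -0.25, 1.0])

# The fast dequantization path effectively assumes this invariant:
if not (scales >= 0).all():
    print("negative scales present; their sign cannot survive a sign-less format")

# Dropping the sign, as a sign-less layout forces, changes the value:
print(scales.abs())  # tensor([0.5000, 0.2500, 1.0000]) -- note -0.25 became 0.25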