Mirror of https://git.datalinker.icu/vllm-project/vllm.git (synced 2025-12-09 17:45:19 +08:00)
[Bugfix] Remove NVFP4 scales assertions to fix load_format=dummy (#18861)
Signed-off-by: mgoin <mgoin64@gmail.com>
parent 77b6e74fe2
commit 4d0a1541be
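The change below replaces two hard assertions on NVFP4 weight scales with one-time warnings. With load_format=dummy, vLLM skips reading the checkpoint and fills parameters with random data, so scale invariants that hold for real quantized weights no longer hold and the assertions abort model loading. A minimal sketch of how this load format is exercised (the checkpoint id is a placeholder, and passing load_format through the offline LLM API is an assumption here, not part of this commit):

# Sketch only: exercising the dummy load format that this fix unblocks.
from vllm import LLM

llm = LLM(
    model="some-org/some-nvfp4-model",  # placeholder NVFP4 checkpoint id
    load_format="dummy",                # initialize weights randomly instead of loading them
)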
@@ -585,9 +585,11 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
     def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
 
         # GEMM 1
-        assert torch.allclose(
-            layer.w13_weight_scale_2[:, 0], layer.w13_weight_scale_2[:, 1]), (
-                "w1_weight_scale_2 must match w3_weight_scale_2")
+        if not torch.allclose(layer.w13_weight_scale_2[:, 0],
+                              layer.w13_weight_scale_2[:, 1]):
+            logger.warning_once(
+                "w1_weight_scale_2 must match w3_weight_scale_2. "
+                "Accuracy may be affected.")
 
         w13_weight_scale_2 = layer.w13_weight_scale_2[:, 0]
         layer.w13_weight_scale_2 = Parameter(w13_weight_scale_2,
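Why the old assertion fails under dummy loading: w13_weight_scale_2 carries separate global scales for the fused w1 and w3 projections, and with randomly initialized dummy weights the two columns are essentially never allclose, so the hard assert kills loading. The patch demotes it to logger.warning_once since a mismatch only affects accuracy. A self-contained sketch of the same check (the tensor is a hypothetical stand-in and print replaces vLLM's logger):

import torch

num_experts = 8
# Hypothetical stand-in for layer.w13_weight_scale_2 after load_format=dummy:
# random values, so the w1 and w3 global scales (columns 0 and 1) disagree.
w13_weight_scale_2 = torch.rand(num_experts, 2)

# Old behavior: a hard assert aborts weight post-processing for dummy weights.
# assert torch.allclose(w13_weight_scale_2[:, 0], w13_weight_scale_2[:, 1])

# New behavior: warn once and keep going; only accuracy is at stake, which
# does not matter for dummy-weight profiling runs.
if not torch.allclose(w13_weight_scale_2[:, 0], w13_weight_scale_2[:, 1]):
    print("w1_weight_scale_2 must match w3_weight_scale_2. "
          "Accuracy may be affected.")

# The fused path then keeps a single scale per expert, as in the patch.
w13_weight_scale_2 = w13_weight_scale_2[:, 0]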
@@ -22,7 +22,12 @@ def is_fp4_marlin_supported():
 
 
 def fp4_marlin_process_scales(marlin_scales):
-    assert (marlin_scales >= 0).all()
+    if not (marlin_scales >= 0).all():
+        logger.warning_once(
+            "NVFP4 Marlin assumes the scales to be >=0, but has encountered "
+            "negative scales. Accuracy will likely be degraded. This is "
+            "because it changes the scales from FP8-S1E4M3 to a special "
+            "FP8-S0E5M3 format to speedup the dequantization.")
 
     # convert to half first, we would convert to fp8 later
     marlin_scales = marlin_scales.to(torch.half)
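The Marlin path gets the same treatment: its dequantization trick assumes non-negative FP8 scales, which randomly initialized dummy scales violate. A standalone sketch of the relaxed check (random data and print are stand-ins for the real scales and logger):

import torch

# Hypothetical dummy-initialized scales; roughly half will be negative.
marlin_scales = torch.randn(16, 8)

# Old behavior: assert (marlin_scales >= 0).all() aborts on dummy weights.
# New behavior: warn once and continue, accepting degraded accuracy, because
# the kernel reinterprets FP8-S1E4M3 scales in an unsigned FP8-S0E5M3 layout.
if not (marlin_scales >= 0).all():
    print("NVFP4 Marlin assumes scales >= 0 but found negative values; "
          "accuracy will likely be degraded.")

# Conversion proceeds regardless, mirroring the patched function.
marlin_scales = marlin_scales.to(torch.half)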