diff --git a/vllm/model_executor/layers/quantization/modelopt.py b/vllm/model_executor/layers/quantization/modelopt.py
index 1c5680f952ab..2abe16a08a26 100644
--- a/vllm/model_executor/layers/quantization/modelopt.py
+++ b/vllm/model_executor/layers/quantization/modelopt.py
@@ -585,9 +585,11 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
     def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
 
         # GEMM 1
-        assert torch.allclose(
-            layer.w13_weight_scale_2[:, 0], layer.w13_weight_scale_2[:, 1]), (
-                "w1_weight_scale_2 must match w3_weight_scale_2")
+        if not torch.allclose(layer.w13_weight_scale_2[:, 0],
+                              layer.w13_weight_scale_2[:, 1]):
+            logger.warning_once(
+                "w1_weight_scale_2 must match w3_weight_scale_2. "
+                "Accuracy may be affected.")
 
         w13_weight_scale_2 = layer.w13_weight_scale_2[:, 0]
         layer.w13_weight_scale_2 = Parameter(w13_weight_scale_2,
diff --git a/vllm/model_executor/layers/quantization/utils/marlin_utils_fp4.py b/vllm/model_executor/layers/quantization/utils/marlin_utils_fp4.py
index 15177af58ae6..13dcdc00a215 100644
--- a/vllm/model_executor/layers/quantization/utils/marlin_utils_fp4.py
+++ b/vllm/model_executor/layers/quantization/utils/marlin_utils_fp4.py
@@ -22,7 +22,12 @@ def is_fp4_marlin_supported():
 
 
 def fp4_marlin_process_scales(marlin_scales):
-    assert (marlin_scales >= 0).all()
+    if not (marlin_scales >= 0).all():
+        logger.warning_once(
+            "NVFP4 Marlin assumes the scales to be >=0, but has encountered "
+            "negative scales. Accuracy will likely be degraded. This is "
+            "because it changes the scales from FP8-S1E4M3 to a special "
+            "FP8-S0E5M3 format to speedup the dequantization.")
 
     # convert to half first, we would convert to fp8 later
     marlin_scales = marlin_scales.to(torch.half)
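Both hunks follow the same pattern: a hard `assert` that would abort weight loading is downgraded to a one-time warning, after which processing continues (falling back to the first scale column in the `modelopt.py` case). Below is a minimal, self-contained sketch of that pattern under stated assumptions: the `warn_once` helper is an illustrative stand-in for vLLM's `logger.warning_once`, and the tensor shape is hypothetical; this is not the actual vLLM implementation.

```python
import logging

import torch

logger = logging.getLogger(__name__)
_seen_messages: set[str] = set()


def warn_once(msg: str) -> None:
    # Illustrative stand-in for vLLM's logger.warning_once: emit each
    # distinct message only on its first occurrence.
    if msg not in _seen_messages:
        _seen_messages.add(msg)
        logger.warning(msg)


def fuse_w13_scale_2(w13_weight_scale_2: torch.Tensor) -> torch.Tensor:
    # Mirrors the modelopt.py hunk: instead of asserting that the w1 and w3
    # global scales agree, warn once and fall back to the first column so
    # weight loading can proceed (possibly with degraded accuracy).
    if not torch.allclose(w13_weight_scale_2[:, 0], w13_weight_scale_2[:, 1]):
        warn_once("w1_weight_scale_2 must match w3_weight_scale_2. "
                  "Accuracy may be affected.")
    return w13_weight_scale_2[:, 0]


if __name__ == "__main__":
    logging.basicConfig(level=logging.WARNING)
    # Hypothetical (num_experts, 2) scale tensor; the second row mismatches.
    scales = torch.tensor([[1.0, 1.0], [2.0, 3.0]])
    print(fuse_w13_scale_2(scales))  # tensor([1., 2.]) plus a single warning
```

The design trade-off is the same as in the diff: a mismatch or a negative scale no longer crashes the server at load time, at the cost of silently (apart from one warning) accepting weights that may reduce accuracy.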