[Bugfix] Fix Marlin NVFP4 for modelopt (#23659)

Signed-off-by: mgoin <mgoin64@gmail.com>
Michael Goin 2025-08-27 17:48:16 -04:00 committed by GitHub
parent 082cc07ef8
commit f9ca2b40a0

@@ -891,7 +891,11 @@ class ModelOptNvFp4LinearMethod(LinearMethodBase):
         assert (layer.weight_scale.dtype == torch.float8_e4m3fn), (
             "Weight Block scale must be represented as FP8-E4M3")
-        if self.backend == "flashinfer-trtllm":
+        if self.backend == "marlin":
+            prepare_fp4_layer_for_marlin(layer)
+            del layer.alpha
+            del layer.input_scale
+        elif self.backend == "flashinfer-trtllm":
             # FlashInfer TRTLLM FP4 GEMM requires a different weight layout.
             # FlashInfer provides nvfp4_quantize to quantize + shuffle the
             # layout but we use our own quantization so we have to call
@@ -916,11 +920,6 @@ class ModelOptNvFp4LinearMethod(LinearMethodBase):
                                        requires_grad=False)
         layer.weight = Parameter(layer.weight.data, requires_grad=False)
 
-        if self.backend == "marlin":
-            prepare_fp4_layer_for_marlin(layer)
-            del layer.alpha
-            del layer.input_scale
-
     def apply(
         self,
         layer: torch.nn.Module,
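
For orientation, here is a minimal sketch (not the vLLM source) of the dispatch that ModelOptNvFp4LinearMethod.process_weights_after_loading ends up with: Marlin preparation now runs as its own branch before any FlashInfer/CUTLASS-specific repacking, instead of as a trailing block that previously ran only after the other backends had already reshaped the weights and scales. The import path for prepare_fp4_layer_for_marlin and the bodies of the non-Marlin branches are assumptions here, not copied from the repository.

import torch
from torch.nn import Parameter

# Assumed import path for the Marlin helper referenced in the diff above.
from vllm.model_executor.layers.quantization.utils.marlin_utils_fp4 import (
    prepare_fp4_layer_for_marlin)


def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
    # Block scales must already be FP8-E4M3, regardless of backend.
    assert layer.weight_scale.dtype == torch.float8_e4m3fn, (
        "Weight Block scale must be represented as FP8-E4M3")

    if self.backend == "marlin":
        # Repack the NVFP4 weights for the Marlin kernel and drop the
        # tensors that only the other backends consume.
        prepare_fp4_layer_for_marlin(layer)
        del layer.alpha
        del layer.input_scale
    elif self.backend == "flashinfer-trtllm":
        # FlashInfer TRTLLM FP4 GEMM requires a shuffled weight layout
        # (details omitted in this sketch).
        ...
    else:
        # CUTLASS / default path: swizzle block scales and re-wrap the
        # weights as Parameters (details omitted in this sketch).
        layer.weight = Parameter(layer.weight.data, requires_grad=False)
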
@@ -1312,6 +1311,13 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
             del layer.w2_weight_scale
             del layer.w13_weight
             del layer.w13_weight_scale
+        elif self.use_marlin:
+            # Marlin processing
+            prepare_moe_fp4_layer_for_marlin(layer)
+            del layer.g1_alphas
+            del layer.g2_alphas
+            del layer.w13_input_scale_quant
+            del layer.w2_input_scale_quant
         else:
             # Non-TRT-LLM processing (Cutlass or non-flashinfer)
             assert (layer.w13_weight_scale.shape[2] % 16 == 0), (
@@ -1333,13 +1339,6 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
         layer.w2_weight = Parameter(layer.w2_weight.data,
                                     requires_grad=False)
 
-        if self.use_marlin:
-            prepare_moe_fp4_layer_for_marlin(layer)
-            del layer.g1_alphas
-            del layer.g2_alphas
-            del layer.w13_input_scale_quant
-            del layer.w2_input_scale_quant
-
     def apply(
         self,
         layer: torch.nn.Module,
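
The fused-MoE path gets the same treatment: the Marlin handling becomes an elif in the backend dispatch of ModelOptNvFp4FusedMoE.process_weights_after_loading rather than a trailing block. Below is a minimal sketch under the same assumptions as above; flashinfer_trtllm_enabled is a hypothetical stand-in for whatever condition actually guards the TRT-LLM branch, and the non-Marlin branch bodies are elided.

# Assumed import path for the MoE Marlin helper referenced in the diff above.
from vllm.model_executor.layers.quantization.utils.marlin_utils_fp4 import (
    prepare_moe_fp4_layer_for_marlin)


def process_weights_after_loading(self, layer):
    if self.flashinfer_trtllm_enabled:  # hypothetical flag name
        # FlashInfer TRT-LLM processing: shuffle weights/scales, then drop
        # the unshuffled copies (the dels visible as context above).
        ...
    elif self.use_marlin:
        # Marlin processing: repack both expert weight groups and drop the
        # scales that only the other backends consume.
        prepare_moe_fp4_layer_for_marlin(layer)
        del layer.g1_alphas
        del layer.g2_alphas
        del layer.w13_input_scale_quant
        del layer.w2_input_scale_quant
    else:
        # Non-TRT-LLM processing (CUTLASS or non-flashinfer): block scales
        # are expected in groups of 16 and are swizzled here (details omitted).
        assert layer.w13_weight_scale.shape[2] % 16 == 0
        ...
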