mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-21 07:05:01 +08:00
[Bugfix] Fix Marlin NVFP4 for modelopt (#23659)
Signed-off-by: mgoin <mgoin64@gmail.com>
This commit is contained in:
parent
082cc07ef8
commit
f9ca2b40a0
@ -891,7 +891,11 @@ class ModelOptNvFp4LinearMethod(LinearMethodBase):
|
||||
assert (layer.weight_scale.dtype == torch.float8_e4m3fn), (
|
||||
"Weight Block scale must be represented as FP8-E4M3")
|
||||
|
||||
if self.backend == "flashinfer-trtllm":
|
||||
if self.backend == "marlin":
|
||||
prepare_fp4_layer_for_marlin(layer)
|
||||
del layer.alpha
|
||||
del layer.input_scale
|
||||
elif self.backend == "flashinfer-trtllm":
|
||||
# FlashInfer TRTLLM FP4 GEMM requires a different weight layout.
|
||||
# FlashInfer provides nvfp4_quantize to quantize + shuffle the
|
||||
# layout but we use our own quantization so we have to call
|
||||
@ -916,11 +920,6 @@ class ModelOptNvFp4LinearMethod(LinearMethodBase):
|
||||
requires_grad=False)
|
||||
layer.weight = Parameter(layer.weight.data, requires_grad=False)
|
||||
|
||||
if self.backend == "marlin":
|
||||
prepare_fp4_layer_for_marlin(layer)
|
||||
del layer.alpha
|
||||
del layer.input_scale
|
||||
|
||||
def apply(
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
@ -1312,6 +1311,13 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
|
||||
del layer.w2_weight_scale
|
||||
del layer.w13_weight
|
||||
del layer.w13_weight_scale
|
||||
elif self.use_marlin:
|
||||
# Marlin processing
|
||||
prepare_moe_fp4_layer_for_marlin(layer)
|
||||
del layer.g1_alphas
|
||||
del layer.g2_alphas
|
||||
del layer.w13_input_scale_quant
|
||||
del layer.w2_input_scale_quant
|
||||
else:
|
||||
# Non-TRT-LLM processing (Cutlass or non-flashinfer)
|
||||
assert (layer.w13_weight_scale.shape[2] % 16 == 0), (
|
||||
@ -1333,13 +1339,6 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
|
||||
layer.w2_weight = Parameter(layer.w2_weight.data,
|
||||
requires_grad=False)
|
||||
|
||||
if self.use_marlin:
|
||||
prepare_moe_fp4_layer_for_marlin(layer)
|
||||
del layer.g1_alphas
|
||||
del layer.g2_alphas
|
||||
del layer.w13_input_scale_quant
|
||||
del layer.w2_input_scale_quant
|
||||
|
||||
def apply(
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user