[Bugfix] Fix Marlin NVFP4 for modelopt (#23659)
Signed-off-by: mgoin <mgoin64@gmail.com>
parent 082cc07ef8
commit f9ca2b40a0
@@ -891,7 +891,11 @@ class ModelOptNvFp4LinearMethod(LinearMethodBase):
         assert (layer.weight_scale.dtype == torch.float8_e4m3fn), (
             "Weight Block scale must be represented as FP8-E4M3")

-        if self.backend == "flashinfer-trtllm":
+        if self.backend == "marlin":
+            prepare_fp4_layer_for_marlin(layer)
+            del layer.alpha
+            del layer.input_scale
+        elif self.backend == "flashinfer-trtllm":
             # FlashInfer TRTLLM FP4 GEMM requires a different weight layout.
             # FlashInfer provides nvfp4_quantize to quantize + shuffle the
             # layout but we use our own quantization so we have to call
@@ -916,11 +920,6 @@ class ModelOptNvFp4LinearMethod(LinearMethodBase):
                                                 requires_grad=False)
         layer.weight = Parameter(layer.weight.data, requires_grad=False)

-        if self.backend == "marlin":
-            prepare_fp4_layer_for_marlin(layer)
-            del layer.alpha
-            del layer.input_scale
-
     def apply(
         self,
         layer: torch.nn.Module,
@@ -1312,6 +1311,13 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
             del layer.w2_weight_scale
             del layer.w13_weight
             del layer.w13_weight_scale
+        elif self.use_marlin:
+            # Marlin processing
+            prepare_moe_fp4_layer_for_marlin(layer)
+            del layer.g1_alphas
+            del layer.g2_alphas
+            del layer.w13_input_scale_quant
+            del layer.w2_input_scale_quant
         else:
             # Non-TRT-LLM processing (Cutlass or non-flashinfer)
             assert (layer.w13_weight_scale.shape[2] % 16 == 0), (
@@ -1333,13 +1339,6 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
         layer.w2_weight = Parameter(layer.w2_weight.data,
                                     requires_grad=False)

-        if self.use_marlin:
-            prepare_moe_fp4_layer_for_marlin(layer)
-            del layer.g1_alphas
-            del layer.g2_alphas
-            del layer.w13_input_scale_quant
-            del layer.w2_input_scale_quant
-
     def apply(
         self,
         layer: torch.nn.Module,
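For context, here is a minimal, standalone sketch of the dispatch order this commit establishes: the "marlin" backend now gets its own branch ahead of the "flashinfer-trtllm" path, instead of being re-packed after the generic weight processing at the end of the method. All names below (process_weights_after_loading as a free function, FakeLayer, the string "layouts") are illustrative stand-ins, not vLLM's actual classes or helpers.

# Illustrative stand-ins only -- not vLLM's real APIs.
from dataclasses import dataclass, field


@dataclass
class FakeLayer:
    # Mimics a module carrying the quantized-weight buffers seen in the diff
    # (weight, alpha, input_scale).
    buffers: dict = field(default_factory=lambda: {
        "weight": "fp4_packed",
        "alpha": 1.0,
        "input_scale": 1.0,
    })


def process_weights_after_loading(backend: str, layer: FakeLayer) -> None:
    """Sketch of the backend dispatch order after this commit."""
    if backend == "marlin":
        # Marlin re-packs the FP4 weight and no longer needs alpha or
        # input_scale (mirrors prepare_fp4_layer_for_marlin plus the dels).
        layer.buffers["weight"] = "marlin_layout"
        del layer.buffers["alpha"]
        del layer.buffers["input_scale"]
    elif backend == "flashinfer-trtllm":
        # FlashInfer TRTLLM FP4 GEMM requires a different (shuffled) layout.
        layer.buffers["weight"] = "trtllm_shuffled_layout"
    else:
        # Default CUTLASS-style path keeps the layout produced at load time.
        pass


layer = FakeLayer()
process_weights_after_loading("marlin", layer)
print(layer.buffers)  # {'weight': 'marlin_layout'}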