[Hardware/NVIDIA/Modelopt] Fix modelopt forward method for v1 torch.compile (#18101)

Signed-off-by: Pavani Majety <pmajety@nvidia.com>
This commit is contained in:
Pavani Majety 2025-05-13 19:33:00 -07:00 committed by GitHub
parent 176a95c670
commit 65f0f74b66
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 14 additions and 10 deletions

View File

@ -1,5 +1,6 @@
# SPDX-License-Identifier: Apache-2.0
""" CUTLASS based Fused MoE kernels."""
import os
from typing import Optional
import torch
@ -183,7 +184,8 @@ def cutlass_moe_fp8(
FLOAT4_E2M1_MAX = scalar_types.float4_e2m1f.max()
FLOAT8_E4M3_MAX = torch.finfo(torch.float8_e4m3fn).max
MAX_TOKENS_PER_EXPERT = 65536
MAX_TOKENS_PER_EXPERT = int(
os.environ.get('VLLM_MODELOPT_MAX_TOKENS_PER_EXPERT', '65536'))
def cutlass_moe_fp4(a: torch.Tensor, a1_gscale: torch.Tensor,
@ -243,7 +245,8 @@ def cutlass_moe_fp4(a: torch.Tensor, a1_gscale: torch.Tensor,
== m), ("topk must be provided for each row of a")
assert (m <= MAX_TOKENS_PER_EXPERT), (
f"m must be less than MAX_TOKENS_PER_EXPERT({MAX_TOKENS_PER_EXPERT})"
f" for cutlass_moe_fp4, observed m = {m}")
f" for cutlass_moe_fp4, observed m = {m}. Use"
f" VLLM_MODELOPT_MAX_TOKENS_PER_EXPERT to set this value.")
out_dtype = a.dtype
num_topk = topk_ids.shape[1]

View File

@ -401,6 +401,7 @@ class ModelOptNvFp4LinearMethod(LinearMethodBase):
layer.weight_scale_swizzled = Parameter(swizzled_weight_scale,
requires_grad=False)
layer.weight = Parameter(layer.weight.data, requires_grad=False)
if self.use_marlin:
prepare_fp4_layer_for_marlin(layer)
@ -426,11 +427,7 @@ class ModelOptNvFp4LinearMethod(LinearMethodBase):
bias=bias)
output_dtype = x.dtype
# for input only the contracting dimension has a constraint.
x_m, _ = x.shape
w_n, _ = layer.weight.shape
output_shape = [x_m, w_n]
output_shape = [x.shape[0], layer.weight.shape[0]]
# quantize BF16 or FP16 to (FP4 and interleaved block scale)
s_quant = 1 / layer.input_scale
@ -586,11 +583,11 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
if scale_ndim == 2 else swizzled_scale.reshape(B, M, K))
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
# GEMM 1
# GEMM 1
assert torch.allclose(
layer.w13_weight_scale_2[:, 0], layer.w13_weight_scale_2[:, 1]), (
"Expected w1_weight_scale_2 to equal w3_weight_scale_2")
"w1_weight_scale_2 must match w3_weight_scale_2")
w13_weight_scale_2 = layer.w13_weight_scale_2[:, 0]
layer.w13_weight_scale_2 = Parameter(w13_weight_scale_2,
@ -616,6 +613,9 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
layer.w13_input_scale_quant = Parameter(
(1 / w13_input_scale).to(torch.float32), requires_grad=False)
layer.w13_weight = Parameter(layer.w13_weight.data,
requires_grad=False)
# GEMM 2
layer.g2_alphas = Parameter(
(layer.w2_input_scale * layer.w2_weight_scale_2).to(torch.float32),
@ -633,6 +633,7 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
layer.w2_blockscale_swizzled = Parameter(w2_blockscale_swizzled,
requires_grad=False)
layer.w2_weight = Parameter(layer.w2_weight.data, requires_grad=False)
if self.use_marlin:
prepare_moe_fp4_layer_for_marlin(layer)
@ -694,7 +695,7 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
assert not apply_router_weight_on_input, (
"Router weight on input is not "
"supported for ModelOptNvFp4FusedMoE.")
assert expert_map is None, ("Expert Parallelism /expert_map "
assert expert_map is None, ("Expert Parallelism / expert_map "
"is currently not supported for "
"ModelOptNvFp4FusedMoE.")