[Hardware/NVIDIA/Modelopt] Fix modelopt forward method for v1 torch.compile (#18101)
Signed-off-by: Pavani Majety <pmajety@nvidia.com>
This commit is contained in: parent 176a95c670, commit 65f0f74b66
@@ -1,5 +1,6 @@
 # SPDX-License-Identifier: Apache-2.0
 """ CUTLASS based Fused MoE kernels."""
+import os
 from typing import Optional
 
 import torch
@@ -183,7 +184,8 @@ def cutlass_moe_fp8(
 
 FLOAT4_E2M1_MAX = scalar_types.float4_e2m1f.max()
 FLOAT8_E4M3_MAX = torch.finfo(torch.float8_e4m3fn).max
-MAX_TOKENS_PER_EXPERT = 65536
+MAX_TOKENS_PER_EXPERT = int(
+    os.environ.get('VLLM_MODELOPT_MAX_TOKENS_PER_EXPERT', '65536'))
 
 
 def cutlass_moe_fp4(a: torch.Tensor, a1_gscale: torch.Tensor,
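The hard-coded 65,536-token cap on m is now read from the environment at module import time. A minimal sketch of the override path, assuming only the lines above (the 131072 value is purely an example):

import os

# Must be set before the module is imported, since the cap is
# evaluated once at module load, not per call.
os.environ['VLLM_MODELOPT_MAX_TOKENS_PER_EXPERT'] = '131072'

MAX_TOKENS_PER_EXPERT = int(
    os.environ.get('VLLM_MODELOPT_MAX_TOKENS_PER_EXPERT', '65536'))
assert MAX_TOKENS_PER_EXPERT == 131072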
@@ -243,7 +245,8 @@ def cutlass_moe_fp4(a: torch.Tensor, a1_gscale: torch.Tensor,
             == m), ("topk must be provided for each row of a")
     assert (m <= MAX_TOKENS_PER_EXPERT), (
         f"m must be less than MAX_TOKENS_PER_EXPERT({MAX_TOKENS_PER_EXPERT})"
-        f" for cutlass_moe_fp4, observed m = {m}")
+        f" for cutlass_moe_fp4, observed m = {m}. Use"
+        f" VLLM_MODELOPT_MAX_TOKENS_PER_EXPERT to set this value.")
     out_dtype = a.dtype
     num_topk = topk_ids.shape[1]
 
@@ -401,6 +401,7 @@ class ModelOptNvFp4LinearMethod(LinearMethodBase):
 
         layer.weight_scale_swizzled = Parameter(swizzled_weight_scale,
                                                 requires_grad=False)
+        layer.weight = Parameter(layer.weight.data, requires_grad=False)
 
         if self.use_marlin:
             prepare_fp4_layer_for_marlin(layer)
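Re-registering the processed weight as a fresh Parameter is the heart of the v1 torch.compile fix: wrapping .data in a new Parameter(requires_grad=False) after post-processing leaves the module attribute as a plain, detached parameter that the compiler can treat as a stable constant. A toy sketch of the pattern under that assumption (the Toy module and shapes are illustrative, not vLLM's):

import torch
from torch.nn import Module, Parameter

class Toy(Module):
    # Illustrative stand-in for a quantized linear method's layer.
    def __init__(self):
        super().__init__()
        self.weight = Parameter(torch.randn(8, 8), requires_grad=False)

    def process_weights_after_loading(self):
        # Post-process the loaded tensor, then re-wrap it so `weight`
        # stays a registered Parameter with no autograd history.
        packed = self.weight.data.contiguous()
        self.weight = Parameter(packed, requires_grad=False)

    def forward(self, x):
        return x @ self.weight

m = Toy()
m.process_weights_after_loading()
out = torch.compile(m)(torch.randn(4, 8))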
@@ -426,11 +427,7 @@ class ModelOptNvFp4LinearMethod(LinearMethodBase):
                 bias=bias)
 
         output_dtype = x.dtype
-
-        # for input only the contracting dimension has a constraint.
-        x_m, _ = x.shape
-        w_n, _ = layer.weight.shape
-        output_shape = [x_m, w_n]
+        output_shape = [x.shape[0], layer.weight.shape[0]]
 
         # quantize BF16 or FP16 to (FP4 and interleaved block scale)
         s_quant = 1 / layer.input_scale
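The removed tuple-unpacking lines computed the same output shape as the new one-liner, which reads the two sizes directly from .shape. A quick equivalence check under assumed example shapes:

import torch

x = torch.randn(4, 16)        # activations: (m, k)
weight = torch.randn(32, 16)  # weight: (n, k)

x_m, _ = x.shape              # old form
w_n, _ = weight.shape
assert [x_m, w_n] == [x.shape[0], weight.shape[0]]  # new form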
@@ -586,11 +583,11 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
                 if scale_ndim == 2 else swizzled_scale.reshape(B, M, K))
 
     def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
-        # GEMM 1
 
+        # GEMM 1
         assert torch.allclose(
             layer.w13_weight_scale_2[:, 0], layer.w13_weight_scale_2[:, 1]), (
-                "Expected w1_weight_scale_2 to equal w3_weight_scale_2")
+                "w1_weight_scale_2 must match w3_weight_scale_2")
 
         w13_weight_scale_2 = layer.w13_weight_scale_2[:, 0]
         layer.w13_weight_scale_2 = Parameter(w13_weight_scale_2,
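For context on the assertion whose message changed here: w13_weight_scale_2 holds the per-expert global scales for w1 and w3 side by side, and the two columns must agree before being collapsed into a single scale per expert. A small sketch of that invariant (the shapes and 0.5 value are illustrative assumptions):

import torch

num_experts = 4
# Column 0: w1 global scale, column 1: w3 global scale, per expert.
w13_weight_scale_2 = torch.full((num_experts, 2), 0.5)
assert torch.allclose(w13_weight_scale_2[:, 0],
                      w13_weight_scale_2[:, 1]), (
    "w1_weight_scale_2 must match w3_weight_scale_2")
# Collapse the matching pair to one scale per expert: shape (num_experts,)
w13_weight_scale_2 = w13_weight_scale_2[:, 0]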
@@ -616,6 +613,9 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
         layer.w13_input_scale_quant = Parameter(
             (1 / w13_input_scale).to(torch.float32), requires_grad=False)
 
+        layer.w13_weight = Parameter(layer.w13_weight.data,
+                                     requires_grad=False)
+
         # GEMM 2
         layer.g2_alphas = Parameter(
             (layer.w2_input_scale * layer.w2_weight_scale_2).to(torch.float32),
@@ -633,6 +633,7 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
 
         layer.w2_blockscale_swizzled = Parameter(w2_blockscale_swizzled,
                                                  requires_grad=False)
+        layer.w2_weight = Parameter(layer.w2_weight.data, requires_grad=False)
 
         if self.use_marlin:
             prepare_moe_fp4_layer_for_marlin(layer)
@@ -694,7 +695,7 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
         assert not apply_router_weight_on_input, (
             "Router weight on input is not "
             "supported for ModelOptNvFp4FusedMoE.")
-        assert expert_map is None, ("Expert Parallelism /expert_map "
+        assert expert_map is None, ("Expert Parallelism / expert_map "
                                     "is currently not supported for "
                                     "ModelOptNvFp4FusedMoE.")
 