[ModelOpt] Introduce VLLM_MAX_TOKENS_PER_EXPERT_FP4_MOE env var to control blockscale tensor allocation (#18160)

Signed-off-by: Pavani Majety <pmajety@nvidia.com>
Pavani Majety 2025-05-23 15:52:20 -07:00 committed by GitHub
parent 7d9216495c
commit f2036734fb
3 changed files with 20 additions and 18 deletions

vllm/_custom_ops.py

@@ -1085,7 +1085,6 @@ def scaled_fp4_experts_quant(
blockscale_offsets: torch.Tensor,
topk: int,
expert_map: Optional[torch.Tensor] = None,
MAX_TOKENS_PER_EXPERT: int = 163840,
) -> tuple[torch.Tensor, torch.Tensor]:
"""
Quantize input tensor to FP4 and return quantized tensor and scale, for
@@ -1107,9 +1106,16 @@ def scaled_fp4_experts_quant(
input_tensor = input_tensor[
expert_map] if expert_map is not None else input_tensor
m_numtopk, k = input_tensor.shape
# Control the maximum number of tokens per expert supported by the
# NVFP4 MoE Expert Quantization. This is used to prevent the kernel
# from running out of memory. This value can also be increased to support
# larger models.
MAX_TOKENS_PER_EXPERT = envs.VLLM_MAX_TOKENS_PER_EXPERT_FP4_MOE
assert (m_numtopk <= MAX_TOKENS_PER_EXPERT * topk), (
f"m_numtopk must be less than MAX_TOKENS_PER_EXPERT * topk for"
f" scaled_fp4_experts_quant kernel, observed m_numtopk = {m_numtopk}")
f"m_numtopk must be less than MAX_TOKENS_PER_EXPERT("
f"{MAX_TOKENS_PER_EXPERT})"
f" for cutlass_moe_fp4, observed m_numtopk = {m_numtopk}. Use"
f" VLLM_MAX_TOKENS_PER_EXPERT_FP4_MOE to set this value.")
scales_k = k // 16
padded_k = (scales_k + (4 - 1)) // 4
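Taken together, this hunk moves the cap out of the function signature and into the environment. Below is a minimal standalone sketch of the new guard plus the blockscale padding arithmetic, assuming the 163840 default; check_fp4_quant_capacity is an illustrative helper, not code from this commit:

import os

import torch

def check_fp4_quant_capacity(input_tensor: torch.Tensor, topk: int) -> int:
    # The cap is now resolved from the environment (via vllm.envs in the
    # real code) instead of a MAX_TOKENS_PER_EXPERT keyword argument.
    max_tokens_per_expert = int(
        os.getenv("VLLM_MAX_TOKENS_PER_EXPERT_FP4_MOE", "163840"))
    m_numtopk, k = input_tensor.shape
    assert m_numtopk <= max_tokens_per_expert * topk, (
        f"m_numtopk must be less than MAX_TOKENS_PER_EXPERT"
        f"({max_tokens_per_expert}), observed m_numtopk = {m_numtopk}. Use"
        f" VLLM_MAX_TOKENS_PER_EXPERT_FP4_MOE to set this value.")
    scales_k = k // 16                    # one blockscale per 16 elements
    padded_k = (scales_k + (4 - 1)) // 4  # same rounding as the hunk above
    return padded_k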

vllm/envs.py

@@ -117,6 +117,7 @@ if TYPE_CHECKING:
VLLM_NIXL_SIDE_CHANNEL_HOST: str = "localhost"
VLLM_NIXL_SIDE_CHANNEL_PORT: int = 5557
VLLM_ALL2ALL_BACKEND: str = "naive"
VLLM_MAX_TOKENS_PER_EXPERT_FP4_MOE: int = 163840
def get_default_cache_root():
@@ -814,6 +815,13 @@ environment_variables: dict[str, Callable[[], Any]] = {
# - "pplx": use pplx kernels
"VLLM_ALL2ALL_BACKEND":
lambda: os.getenv("VLLM_ALL2ALL_BACKEND", "naive"),
# Control the maximum number of tokens per expert supported by the
# NVFP4 MoE CUTLASS Kernel. This value is used to create a buffer for
# the blockscale tensor of activations NVFP4 Quantization.
# This is used to prevent the kernel from running out of memory.
"VLLM_MAX_TOKENS_PER_EXPERT_FP4_MOE":
lambda: int(os.getenv("VLLM_MAX_TOKENS_PER_EXPERT_FP4_MOE", "163840")),
}
# --8<-- [end:env-vars-definition]
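A quick standalone check of how the new entry resolves; the resolver below mirrors the registered lambda, and the asserted values are just the default plus an arbitrary override:

import os

def resolve() -> int:
    # Mirrors the lambda registered in environment_variables above.
    return int(os.getenv("VLLM_MAX_TOKENS_PER_EXPERT_FP4_MOE", "163840"))

os.environ.pop("VLLM_MAX_TOKENS_PER_EXPERT_FP4_MOE", None)
assert resolve() == 163840  # unset: falls back to the new default
os.environ["VLLM_MAX_TOKENS_PER_EXPERT_FP4_MOE"] = "327680"
assert resolve() == 327680  # override raises the per-expert token cap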

vllm/model_executor/layers/fused_moe/cutlass_moe.py

@@ -1,6 +1,5 @@
# SPDX-License-Identifier: Apache-2.0
""" CUTLASS based Fused MoE kernels."""
import os
from typing import Optional
import torch
@@ -271,8 +270,6 @@ def cutlass_moe_fp8(
FLOAT4_E2M1_MAX = scalar_types.float4_e2m1f.max()
FLOAT8_E4M3_MAX = torch.finfo(torch.float8_e4m3fn).max
MAX_TOKENS_PER_EXPERT = int(
os.environ.get('VLLM_MODELOPT_MAX_TOKENS_PER_EXPERT', '65536'))
def cutlass_moe_fp4(a: torch.Tensor, a1_gscale: torch.Tensor,
@@ -330,10 +327,7 @@ def cutlass_moe_fp4(a: torch.Tensor, a1_gscale: torch.Tensor,
assert a.dtype in [torch.half, torch.bfloat16], "Invalid input dtype"
assert (topk_weights.shape[0] == m and topk_ids.shape[0]
== m), ("topk must be provided for each row of a")
assert (m <= MAX_TOKENS_PER_EXPERT), (
f"m must be less than MAX_TOKENS_PER_EXPERT({MAX_TOKENS_PER_EXPERT})"
f" for cutlass_moe_fp4, observed m = {m}. Use"
f" VLLM_MODELOPT_MAX_TOKENS_PER_EXPERT to set this value.")
out_dtype = a.dtype
num_topk = topk_ids.shape[1]
@@ -362,8 +356,7 @@ def cutlass_moe_fp4(a: torch.Tensor, a1_gscale: torch.Tensor,
expert_offsets,
blockscale_offsets,
num_topk,
expert_map=a_map,
MAX_TOKENS_PER_EXPERT=MAX_TOKENS_PER_EXPERT)
expert_map=a_map)
c1 = ops.cutlass_fp4_moe_mm(rep_a_fp4, w1_fp4, rep_a_blockscale,
w1_blockscale, w1_alphas, problem_sizes1,
@@ -378,12 +371,7 @@ def cutlass_moe_fp4(a: torch.Tensor, a1_gscale: torch.Tensor,
torch.ops._C.silu_and_mul(intermediate, c1)
int_fp4, int_blockscale = ops.scaled_fp4_experts_quant(
intermediate,
a2_gscale,
expert_offsets,
blockscale_offsets,
num_topk,
MAX_TOKENS_PER_EXPERT=MAX_TOKENS_PER_EXPERT)
intermediate, a2_gscale, expert_offsets, blockscale_offsets, num_topk)
c2 = ops.cutlass_fp4_moe_mm(int_fp4, w2_fp4, int_blockscale, w2_blockscale,
w2_alphas, problem_sizes2, expert_offsets[:-1],
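With the keyword argument gone from both quantization call sites, the cap is controlled entirely through the process environment. A hedged usage sketch: the value 327680 is an arbitrary example, and vllm.envs resolves variables lazily, so setting it any time before the MoE layers run is what matters:

import os

# Raise the per-expert token cap before running an NVFP4 MoE model; the
# blockscale tensor allocation in cutlass_moe_fp4 is sized from this value.
os.environ["VLLM_MAX_TOKENS_PER_EXPERT_FP4_MOE"] = "327680"

# from vllm import LLM  # later cutlass_moe_fp4 calls pick up the new cap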