mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-01-24 00:14:34 +08:00
[ModelOpt] Introduce VLLM_MAX_TOKENS_PER_EXPERT_FP4_MOE env var to control blockscale tensor allocation (#18160)
Signed-off-by: Pavani Majety <pmajety@nvidia.com>
This commit is contained in:
parent
7d9216495c
commit
f2036734fb
@ -1085,7 +1085,6 @@ def scaled_fp4_experts_quant(
|
||||
blockscale_offsets: torch.Tensor,
|
||||
topk: int,
|
||||
expert_map: Optional[torch.Tensor] = None,
|
||||
MAX_TOKENS_PER_EXPERT: int = 163840,
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
Quantize input tensor to FP4 and return quantized tensor and scale, for
|
||||
@ -1107,9 +1106,16 @@ def scaled_fp4_experts_quant(
|
||||
input_tensor = input_tensor[
|
||||
expert_map] if expert_map is not None else input_tensor
|
||||
m_numtopk, k = input_tensor.shape
|
||||
# Control the maximum number of tokens per expert supported by the
|
||||
# NVFP4 MoE Expert Quantization. This is used to prevent the kernel
|
||||
# from running out of memory. This value can also be increased to support
|
||||
# larger models.
|
||||
MAX_TOKENS_PER_EXPERT = envs.VLLM_MAX_TOKENS_PER_EXPERT_FP4_MOE
|
||||
assert (m_numtopk <= MAX_TOKENS_PER_EXPERT * topk), (
|
||||
f"m_numtopk must be less than MAX_TOKENS_PER_EXPERT * topk for"
|
||||
f" scaled_fp4_experts_quant kernel, observed m_numtopk = {m_numtopk}")
|
||||
f"m_numtopk must be less than MAX_TOKENS_PER_EXPERT("
|
||||
f"{MAX_TOKENS_PER_EXPERT})"
|
||||
f" for cutlass_moe_fp4, observed m_numtopk = {m_numtopk}. Use"
|
||||
f" VLLM_MAX_TOKENS_PER_EXPERT_FP4_MOE to set this value.")
|
||||
scales_k = k // 16
|
||||
padded_k = (scales_k + (4 - 1)) // 4
|
||||
|
||||
|
||||
@ -117,6 +117,7 @@ if TYPE_CHECKING:
|
||||
VLLM_NIXL_SIDE_CHANNEL_HOST: str = "localhost"
|
||||
VLLM_NIXL_SIDE_CHANNEL_PORT: int = 5557
|
||||
VLLM_ALL2ALL_BACKEND: str = "naive"
|
||||
VLLM_MAX_TOKENS_PER_EXPERT_FP4_MOE: int = 163840
|
||||
|
||||
|
||||
def get_default_cache_root():
|
||||
@ -814,6 +815,13 @@ environment_variables: dict[str, Callable[[], Any]] = {
|
||||
# - "pplx": use pplx kernels
|
||||
"VLLM_ALL2ALL_BACKEND":
|
||||
lambda: os.getenv("VLLM_ALL2ALL_BACKEND", "naive"),
|
||||
|
||||
# Control the maximum number of tokens per expert supported by the
|
||||
# NVFP4 MoE CUTLASS Kernel. This value is used to create a buffer for
|
||||
# the blockscale tensor of activations NVFP4 Quantization.
|
||||
# This is used to prevent the kernel from running out of memory.
|
||||
"VLLM_MAX_TOKENS_PER_EXPERT_FP4_MOE":
|
||||
lambda: int(os.getenv("VLLM_MAX_TOKENS_PER_EXPERT_FP4_MOE", "163840")),
|
||||
}
|
||||
|
||||
# --8<-- [end:env-vars-definition]
|
||||
|
||||
@ -1,6 +1,5 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
""" CUTLASS based Fused MoE kernels."""
|
||||
import os
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
@ -271,8 +270,6 @@ def cutlass_moe_fp8(
|
||||
|
||||
FLOAT4_E2M1_MAX = scalar_types.float4_e2m1f.max()
|
||||
FLOAT8_E4M3_MAX = torch.finfo(torch.float8_e4m3fn).max
|
||||
MAX_TOKENS_PER_EXPERT = int(
|
||||
os.environ.get('VLLM_MODELOPT_MAX_TOKENS_PER_EXPERT', '65536'))
|
||||
|
||||
|
||||
def cutlass_moe_fp4(a: torch.Tensor, a1_gscale: torch.Tensor,
|
||||
@ -330,10 +327,7 @@ def cutlass_moe_fp4(a: torch.Tensor, a1_gscale: torch.Tensor,
|
||||
assert a.dtype in [torch.half, torch.bfloat16], "Invalid input dtype"
|
||||
assert (topk_weights.shape[0] == m and topk_ids.shape[0]
|
||||
== m), ("topk must be provided for each row of a")
|
||||
assert (m <= MAX_TOKENS_PER_EXPERT), (
|
||||
f"m must be less than MAX_TOKENS_PER_EXPERT({MAX_TOKENS_PER_EXPERT})"
|
||||
f" for cutlass_moe_fp4, observed m = {m}. Use"
|
||||
f" VLLM_MODELOPT_MAX_TOKENS_PER_EXPERT to set this value.")
|
||||
|
||||
out_dtype = a.dtype
|
||||
num_topk = topk_ids.shape[1]
|
||||
|
||||
@ -362,8 +356,7 @@ def cutlass_moe_fp4(a: torch.Tensor, a1_gscale: torch.Tensor,
|
||||
expert_offsets,
|
||||
blockscale_offsets,
|
||||
num_topk,
|
||||
expert_map=a_map,
|
||||
MAX_TOKENS_PER_EXPERT=MAX_TOKENS_PER_EXPERT)
|
||||
expert_map=a_map)
|
||||
|
||||
c1 = ops.cutlass_fp4_moe_mm(rep_a_fp4, w1_fp4, rep_a_blockscale,
|
||||
w1_blockscale, w1_alphas, problem_sizes1,
|
||||
@ -378,12 +371,7 @@ def cutlass_moe_fp4(a: torch.Tensor, a1_gscale: torch.Tensor,
|
||||
torch.ops._C.silu_and_mul(intermediate, c1)
|
||||
|
||||
int_fp4, int_blockscale = ops.scaled_fp4_experts_quant(
|
||||
intermediate,
|
||||
a2_gscale,
|
||||
expert_offsets,
|
||||
blockscale_offsets,
|
||||
num_topk,
|
||||
MAX_TOKENS_PER_EXPERT=MAX_TOKENS_PER_EXPERT)
|
||||
intermediate, a2_gscale, expert_offsets, blockscale_offsets, num_topk)
|
||||
|
||||
c2 = ops.cutlass_fp4_moe_mm(int_fp4, w2_fp4, int_blockscale, w2_blockscale,
|
||||
w2_alphas, problem_sizes2, expert_offsets[:-1],
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user