[Core] FlashInfer CUTLASS fused MoE backend (NVFP4) (#20037)

Signed-off-by: shuw <shuw@nvidia.com>
Signed-off-by: mgoin <mgoin64@gmail.com>
Co-authored-by: mgoin <mgoin64@gmail.com>
Shu Wang 2025-07-17 23:32:45 -05:00 committed by GitHub
parent b38baabcf9
commit c7d8724e78
22 changed files with 1093 additions and 269 deletions

View File

@ -956,11 +956,11 @@ def cutlass_moe_mm(out_tensors: torch.Tensor, a_tensors: torch.Tensor,
c_strides, per_act_token, per_out_ch)
def cutlass_fp4_moe_mm(a_tensors: torch.Tensor, b_tensors: torch.Tensor,
a_scales: torch.Tensor, b_scales: torch.Tensor,
alphas: torch.Tensor, problem_sizes: torch.Tensor,
expert_offsets: torch.Tensor, sf_offsets: torch.Tensor,
out_dtype: torch.dtype, device: torch.device):
def cutlass_fp4_moe_mm(out_tensors: torch.Tensor, a_tensors: torch.Tensor,
b_tensors: torch.Tensor, a_scales: torch.Tensor,
b_scales: torch.Tensor, alphas: torch.Tensor,
problem_sizes: torch.Tensor,
expert_offsets: torch.Tensor, sf_offsets: torch.Tensor):
"""
An FP4 Blockscaled Group Gemm that takes in a_tensors, b_tensors and runs
the gemms for each combination based on the specified problem sizes.
@ -977,14 +977,10 @@ def cutlass_fp4_moe_mm(a_tensors: torch.Tensor, b_tensors: torch.Tensor,
- problem_sizes: MxNxK sizes of each expert's multiplication in two grouped
MMs used in the fused MoE operation.
"""
m_topk = a_tensors.shape[0]
n = b_tensors.shape[1]
c_shape = (m_topk, n)
c = torch.empty(c_shape, device=device, dtype=out_dtype)
torch.ops._C.cutlass_fp4_group_mm(c, a_tensors, b_tensors, a_scales,
b_scales, alphas, problem_sizes,
expert_offsets, sf_offsets)
return c.to(out_dtype)
return torch.ops._C.cutlass_fp4_group_mm(out_tensors, a_tensors, b_tensors,
a_scales, b_scales, alphas,
problem_sizes, expert_offsets,
sf_offsets)
# aqlm

View File

@ -119,6 +119,7 @@ if TYPE_CHECKING:
VLLM_TPU_BUCKET_PADDING_GAP: int = 0
VLLM_TPU_MOST_MODEL_LEN: Optional[int] = None
VLLM_USE_DEEP_GEMM: bool = False
VLLM_USE_FLASHINFER_MOE: bool = False
VLLM_XGRAMMAR_CACHE_MB: int = 0
VLLM_MSGPACK_ZERO_COPY_THRESHOLD: int = 256
VLLM_ALLOW_INSECURE_SERIALIZATION: bool = False
@ -853,6 +854,10 @@ environment_variables: dict[str, Callable[[], Any]] = {
"VLLM_USE_DEEP_GEMM":
lambda: bool(int(os.getenv("VLLM_USE_DEEP_GEMM", "0"))),
# Allow use of FlashInfer CUTLASS kernels for fused moe ops.
"VLLM_USE_FLASHINFER_MOE":
lambda: bool(int(os.getenv("VLLM_USE_FLASHINFER_MOE", "0"))),
# Control the cache sized used by the xgrammar compiler. The default
# of 512 MB should be enough for roughly 1000 JSON schemas.
# It can be changed with this variable if needed for some reason.
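A minimal sketch of how the new flag behaves: like the other boolean knobs in this table it is parsed with bool(int(...)), so it must be set to a nonzero integer string (typically "1") before vLLM reads its environment table; whether the backend is actually used additionally depends on the FlashInfer kernels being importable (see use_flashinfer_cutlass_kernels elsewhere in this change).
import os

# Opt in before importing vLLM; unset or "0" leaves the FlashInfer
# CUTLASS MoE backend disabled.
os.environ["VLLM_USE_FLASHINFER_MOE"] = "1"

# Mirrors the lambda registered above.
use_flashinfer_moe = bool(int(os.getenv("VLLM_USE_FLASHINFER_MOE", "0")))
print(use_flashinfer_moe)  # True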

View File

@ -1,6 +1,6 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Optional
from typing import Any, Optional
import torch
@ -255,28 +255,18 @@ class BatchedDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
output = (num_experts, max_num_tokens * num_dispatchers, K)
return (workspace13, workspace2, output, a.dtype)
def apply(
self,
output: torch.Tensor,
hidden_states: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
activation: str,
global_num_experts: int,
expert_map: Optional[torch.Tensor],
w1_scale: Optional[torch.Tensor],
w2_scale: Optional[torch.Tensor],
w1_zp: Optional[torch.Tensor],
w2_zp: Optional[torch.Tensor],
a1q_scale: Optional[torch.Tensor],
a2_scale: Optional[torch.Tensor],
workspace13: torch.Tensor,
workspace2: torch.Tensor,
expert_tokens_meta: Optional[mk.ExpertTokensMetadata],
apply_router_weight_on_input: bool,
):
def apply(self, output: torch.Tensor, hidden_states: torch.Tensor,
w1: torch.Tensor, w2: torch.Tensor, topk_weights: torch.Tensor,
topk_ids: torch.Tensor, activation: str, global_num_experts: int,
expert_map: Optional[torch.Tensor],
w1_scale: Optional[torch.Tensor],
w2_scale: Optional[torch.Tensor], w1_zp: Optional[torch.Tensor],
w2_zp: Optional[torch.Tensor], a1q_scale: Optional[torch.Tensor],
a2_scale: Optional[torch.Tensor], workspace13: torch.Tensor,
workspace2: torch.Tensor,
expert_tokens_meta: Optional[mk.ExpertTokensMetadata],
apply_router_weight_on_input: bool,
extra_expert_args: Optional[dict[str, Any]]):
assert expert_tokens_meta is not None
expert_num_tokens = expert_tokens_meta.expert_num_tokens

View File

@ -1,6 +1,6 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Optional
from typing import Any, Optional
import torch
@ -142,7 +142,8 @@ class BatchedTritonOrDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
a2_scale: Optional[torch.Tensor], workspace13: torch.Tensor,
workspace2: torch.Tensor,
expert_tokens_meta: Optional[mk.ExpertTokensMetadata],
apply_router_weight_on_input: bool):
apply_router_weight_on_input: bool,
extra_expert_args: Optional[dict[str, Any]]):
experts = (self.batched_deep_gemm_experts
if self.allow_deep_gemm else self.batched_triton_experts)
assert experts is not None
@ -150,4 +151,4 @@ class BatchedTritonOrDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
activation, global_num_experts, expert_map, w1_scale,
w2_scale, w1_zp, w2_zp, a1q_scale, a2_scale, workspace13,
workspace2, expert_tokens_meta,
apply_router_weight_on_input)
apply_router_weight_on_input, extra_expert_args)

View File

@ -15,6 +15,7 @@ from vllm.logger import init_logger
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
from vllm.utils import cdiv
from vllm.utils.flashinfer import has_flashinfer_cutlass_fused_moe
logger = init_logger(__name__)
@ -188,6 +189,11 @@ class FusedMoEParallelConfig:
return (self.use_all2all_kernels
and envs.VLLM_ALL2ALL_BACKEND == "deepep_low_latency")
@property
def use_flashinfer_cutlass_kernels(self):
return (envs.VLLM_USE_FLASHINFER_MOE
and has_flashinfer_cutlass_fused_moe())
@staticmethod
def make(tp_size_: int, dp_size_: int,
vllm_parallel_config: ParallelConfig) -> "FusedMoEParallelConfig":
@ -392,6 +398,10 @@ class FusedMoEConfig:
def use_deepep_ll_kernels(self):
return self.moe_parallel_config.use_deepep_ll_kernels
@property
def use_flashinfer_cutlass_kernels(self):
return self.moe_parallel_config.use_flashinfer_cutlass_kernels
@staticmethod
def make(
num_experts: int,
@ -435,6 +445,12 @@ class FusedMoEConfig:
if quant_dtype is None and isinstance(quant_config, Fp8Config):
quant_dtype = torch.float8_e4m3fn
from vllm.model_executor.layers.quantization.modelopt import (
ModelOptNvFp4Config)
if quant_dtype is None and isinstance(quant_config,
ModelOptNvFp4Config):
quant_dtype = torch.uint8
if weight_quant is not None:
per_out_ch_quant = (
weight_quant.strategy == QuantizationStrategy.CHANNEL)

View File

@ -1,7 +1,7 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
""" CUTLASS based Fused MoE kernels."""
from typing import Callable, Optional
from typing import Any, Callable, Optional
import torch
@ -14,7 +14,8 @@ from vllm.model_executor.layers.fused_moe.prepare_finalize import (
from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import (
TopKWeightAndReduceDelegate)
from vllm.model_executor.layers.fused_moe.utils import (_fp8_quantize,
_resize_cache)
_resize_cache,
extract_required_args)
from vllm.scalar_type import scalar_types
logger = init_logger(__name__)
@ -298,7 +299,8 @@ class CutlassExpertsFp8(mk.FusedMoEPermuteExpertsUnpermute):
a2_scale: Optional[torch.Tensor], workspace13: torch.Tensor,
workspace2: torch.Tensor,
expert_tokens_meta: Optional[mk.ExpertTokensMetadata],
apply_router_weight_on_input: bool):
apply_router_weight_on_input: bool,
extra_expert_args: Optional[dict[str, Any]]):
assert w1_zp is None, "w1_zp is not supported in CUTLASS MoE"
assert w2_zp is None, "w2_zp is not supported in CUTLASS MoE"
@ -431,23 +433,28 @@ FLOAT4_E2M1_MAX = scalar_types.float4_e2m1f.max()
FLOAT8_E4M3_MAX = torch.finfo(torch.float8_e4m3fn).max
def cutlass_moe_fp4(a: torch.Tensor,
a1_gscale: torch.Tensor,
w1_fp4: torch.Tensor,
w1_blockscale: torch.Tensor,
w1_alphas: torch.Tensor,
a2_gscale: torch.Tensor,
w2_fp4: torch.Tensor,
w2_blockscale: torch.Tensor,
w2_alphas: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
m: int,
n: int,
k: int,
e: int,
device: torch.device,
apply_router_weight_on_input: bool = False):
def run_cutlass_moe_fp4(
output: torch.Tensor,
a: torch.Tensor,
a1_gscale: torch.Tensor,
w1_fp4: torch.Tensor,
w1_blockscale: torch.Tensor,
w1_alphas: torch.Tensor,
a2_gscale: torch.Tensor,
w2_fp4: torch.Tensor,
w2_blockscale: torch.Tensor,
w2_alphas: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
workspace13: torch.Tensor,
workspace2: torch.Tensor,
m: int,
n: int,
k: int,
e: int,
device: torch.device,
apply_router_weight_on_input: bool = False,
) -> None:
"""
MoE implementation for FP4 Inputs
@ -487,16 +494,16 @@ def cutlass_moe_fp4(a: torch.Tensor,
assert (e_w1 == e_w2 and e_w1 == e), ("Number of experts must match",
" between weights.")
assert (k_a // 2 == half_k_w1
assert (k_a == half_k_w1 * 2
and k == k_w2), ("Hidden size mismatch between a, w1 and w2")
assert (nx2_w1 == n * 2 and half_n_w2 == n // 2), ("mismatch in "
"expected `n`")
assert (nx2_w1 == n * 2 and half_n_w2 * 2 == n), ("mismatch in "
"expected `n`")
assert (m == m_a), "input shape mismatch"
assert 2 * half_k_w1 == k_w2, "Hidden size mismatch w2 and w1"
assert a.dtype in [torch.half, torch.bfloat16], "Invalid input dtype"
assert (topk_weights.size(0) == m and topk_ids.size(0)
== m), ("topk must be provided for each row of a")
topk = topk_ids.size(1)
out_dtype = a.dtype
num_topk = topk_ids.size(1)
@ -523,7 +530,6 @@ def cutlass_moe_fp4(a: torch.Tensor,
blockscale_offsets)
a = ops.shuffle_rows(a, a_map)
rep_a_fp4, rep_a_blockscale = ops.scaled_fp4_experts_quant(
a,
a1_gscale,
@ -531,34 +537,220 @@ def cutlass_moe_fp4(a: torch.Tensor,
blockscale_offsets,
num_topk,
)
c1 = ops.cutlass_fp4_moe_mm(rep_a_fp4, w1_fp4, rep_a_blockscale,
w1_blockscale, w1_alphas, problem_sizes1,
expert_offsets[:-1], blockscale_offsets[:-1],
out_dtype, device)
c1 = _resize_cache(workspace13, (m * topk, n * 2))
c2 = _resize_cache(workspace2, (m * topk, n))
c3 = _resize_cache(workspace13, (m * topk, k))
ops.cutlass_fp4_moe_mm(c1, rep_a_fp4, w1_fp4, rep_a_blockscale,
w1_blockscale, w1_alphas, problem_sizes1,
expert_offsets[:-1], blockscale_offsets[:-1])
del rep_a_fp4, rep_a_blockscale
# hidden size dimension is split to one half sized tensor.
intermediate = torch.empty((m * num_topk, w1_fp4.size(1) // 2),
device=device,
dtype=out_dtype)
torch.ops._C.silu_and_mul(intermediate, c1)
torch.ops._C.silu_and_mul(c2, c1)
int_fp4, int_blockscale = ops.scaled_fp4_experts_quant(
intermediate, a2_gscale, expert_offsets, blockscale_offsets, num_topk)
c2, a2_gscale, expert_offsets, blockscale_offsets, num_topk)
c2 = ops.cutlass_fp4_moe_mm(int_fp4, w2_fp4, int_blockscale, w2_blockscale,
w2_alphas, problem_sizes2, expert_offsets[:-1],
blockscale_offsets[:-1], out_dtype, device)
ops.cutlass_fp4_moe_mm(c3, int_fp4, w2_fp4, int_blockscale, w2_blockscale,
w2_alphas, problem_sizes2, expert_offsets[:-1],
blockscale_offsets[:-1])
del int_fp4, int_blockscale
c2 = ops.shuffle_rows(c2, c_map)
c3 = ops.shuffle_rows(c3, c_map)
assert output.dtype == out_dtype
if not apply_router_weight_on_input:
out = (c2.view(m, num_topk, k) *
topk_weights.view(m, num_topk, 1).to(out_dtype)).sum(dim=1)
output.copy_(
(c3.view(m, num_topk, k) *
topk_weights.view(m, num_topk, 1).to(out_dtype)).sum(dim=1),
non_blocking=True)
else:
out = c2.view(m, num_topk, k).sum(dim=1)
return out.to(dtype=out_dtype)
output.copy_(c3.view(m, num_topk, k).sum(dim=1), non_blocking=True)
return
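Note how the intermediate results are now carved out of the caller-provided workspaces rather than freshly allocated: c1 (M*topk, 2N) and c3 (M*topk, K) both alias workspace13, which is why that buffer must hold max(2N, K) columns, while c2 lives in workspace2. A CPU-only sketch of that aliasing, using a local stand-in for vLLM's _resize_cache helper and assumed toy sizes:
import torch
from math import prod

def resize_cache(x: torch.Tensor, shape: tuple[int, ...]) -> torch.Tensor:
    # Local stand-in for _resize_cache: view the leading elements of a flat
    # workspace buffer with the requested shape (no copy).
    return x.flatten()[:prod(shape)].view(*shape)

m_topk, n, k = 8, 128, 192                        # assumed M*topk, N, K
workspace13 = torch.empty((m_topk, max(2 * n, k)), dtype=torch.bfloat16)
workspace2 = torch.empty((m_topk, n), dtype=torch.bfloat16)

c1 = resize_cache(workspace13, (m_topk, 2 * n))   # first grouped GEMM output
c2 = resize_cache(workspace2, (m_topk, n))        # silu_and_mul output
c3 = resize_cache(workspace13, (m_topk, k))       # second grouped GEMM output
assert c1.data_ptr() == c3.data_ptr()             # c3 reuses c1's storage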
class CutlassExpertsFp4(mk.FusedMoEPermuteExpertsUnpermute):
def __init__(
self,
max_experts_per_worker: int,
out_dtype: torch.dtype,
per_act_token_quant: bool,
per_out_ch_quant: bool,
block_shape: Optional[list[int]] = None,
use_batched_format: bool = False,
):
super().__init__(
FusedMoEQuantConfig(
quant_dtype=torch.uint8,
per_act_token_quant=per_act_token_quant,
per_out_ch_quant=per_out_ch_quant,
block_shape=block_shape,
))
self.max_experts_per_worker = max_experts_per_worker
self.out_dtype = out_dtype
self.use_batched_format = use_batched_format
@property
def activation_formats(
self
) -> tuple[mk.FusedMoEActivationFormat, mk.FusedMoEActivationFormat]:
if self.use_batched_format:
return (mk.FusedMoEActivationFormat.BatchedExperts,
mk.FusedMoEActivationFormat.BatchedExperts)
else:
return (mk.FusedMoEActivationFormat.Standard,
mk.FusedMoEActivationFormat.Standard)
def supports_expert_map(self) -> bool:
return False
def supports_chunking(self) -> bool:
return True
def finalize_weight_and_reduce_impl(self) -> mk.TopKWeightAndReduce:
# Let PrepareAndFinalize::finalize() decide the impl.
return TopKWeightAndReduceDelegate()
def workspace_shapes(
self,
a: torch.Tensor,
aq: torch.Tensor,
M: int,
N: int,
K: int,
topk: int,
global_num_experts: int,
local_num_experts: int,
expert_tokens_meta: Optional[mk.ExpertTokensMetadata],
) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...], torch.dtype]:
workspace1: tuple[int, ...] = ()
workspace2: tuple[int, ...] = ()
output: tuple[int, ...] = ()
if self.use_batched_format:
padded_M = aq.size(1)
workspace1 = (self.max_experts_per_worker, padded_M, max(N, K))
workspace2 = (self.max_experts_per_worker, padded_M, (N // 2))
output = (self.max_experts_per_worker, padded_M, K)
else:
workspace1 = (M * topk, max(2 * N, K))
workspace2 = (M * topk, N)
output = (M, K)
return (workspace1, workspace2, output,
self.out_dtype if self.out_dtype is not None else a.dtype)
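A worked instantiation of these expressions for the standard (non-batched) layout, under assumed toy sizes; note the output is (M, K) rather than (M*topk, K) because run_cutlass_moe_fp4 already performs the top-k weighted reduction before writing its result.
M, topk, N, K = 4, 2, 128, 192              # illustrative sizes only

workspace1 = (M * topk, max(2 * N, K))      # shared by both grouped GEMMs
workspace2 = (M * topk, N)                  # holds the activation output
output = (M, K)                             # already reduced over top-k
print(workspace1, workspace2, output)       # (8, 256) (8, 128) (4, 192)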
def apply(self, output: torch.Tensor, hidden_states: torch.Tensor,
w1: torch.Tensor, w2: torch.Tensor, topk_weights: torch.Tensor,
topk_ids: torch.Tensor, activation: str, global_num_experts: int,
expert_map: Optional[torch.Tensor], w1_scale: torch.Tensor,
w2_scale: torch.Tensor, w1_zp: Optional[torch.Tensor],
w2_zp: Optional[torch.Tensor], a1q_scale: Optional[torch.Tensor],
a2_scale: torch.Tensor, workspace13: Optional[torch.Tensor],
workspace2: Optional[torch.Tensor],
expert_tokens_meta: Optional[mk.ExpertTokensMetadata],
apply_router_weight_on_input: bool,
extra_expert_args: Optional[dict[str, Any]]):
required_keys = [
"g1_alphas", "g2_alphas", "a1_gscale", "a2_gscale", "m", "n", "k",
"e", "device"
]
(g1_alphas, g2_alphas, a1_gscale, a2_gscale, m, n, k, e,
device) = extract_required_args(extra_expert_args, required_keys)
run_cutlass_moe_fp4(
output=output,
a=hidden_states,
a1_gscale=a1_gscale,
w1_fp4=w1,
w1_blockscale=w1_scale,
w1_alphas=g1_alphas,
a2_gscale=a2_gscale,
w2_fp4=w2,
w2_blockscale=w2_scale,
w2_alphas=g2_alphas,
topk_weights=topk_weights,
topk_ids=topk_ids,
workspace13=workspace13,
workspace2=workspace2,
m=m,
n=n,
k=k,
e=e,
device=device,
apply_router_weight_on_input=apply_router_weight_on_input,
)
def cutlass_moe_fp4(
a: torch.Tensor,
w1_fp4: torch.Tensor,
w2_fp4: torch.Tensor,
w1_blockscale: torch.Tensor,
w2_blockscale: torch.Tensor,
g1_alphas: torch.Tensor,
g2_alphas: torch.Tensor,
a1_gscale: torch.Tensor,
a2_gscale: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
m: int,
n: int,
k: int,
e: int,
device: torch.device,
expert_map: Optional[torch.Tensor] = None,
apply_router_weight_on_input: bool = False) -> torch.Tensor:
assert expert_map is None, ("Expert Parallelism / expert_map "
"is currently not supported for "
"ModelOptNvFp4FusedMoE's cutlass_moe_fp4.")
fn = mk.FusedMoEModularKernel(
MoEPrepareAndFinalizeNoEP(),
CutlassExpertsFp4(
max_experts_per_worker=e,
out_dtype=a.dtype,
per_act_token_quant=False,
per_out_ch_quant=False,
use_batched_format=False,
),
)
extra_expert_args = {
'g1_alphas': g1_alphas,
'g2_alphas': g2_alphas,
'a1_gscale': a1_gscale,
'a2_gscale': a2_gscale,
'm': m,
'n': n,
'k': k,
'e': e,
'device': device,
}
# NVFP4 requires two levels of quantization, which involves computing some
# scaling factors dynamically. This makes it incompatible with the typical
# prepare -> MoE -> finalize pipeline. Move the quantization logic into the
# MoE body.
extra_prepare_args = {
'skip_quant': True,
}
# Similar reason as above.
extra_finalize_args = {
'skip_weight_reduce': True,
}
return fn(
hidden_states=a,
w1=w1_fp4,
w2=w2_fp4,
topk_weights=topk_weights,
topk_ids=topk_ids,
inplace=False,
activation="silu",
global_num_experts=e,
expert_map=None,
w1_scale=w1_blockscale,
w2_scale=w2_blockscale,
a1_scale=None,
a2_scale=None,
apply_router_weight_on_input=apply_router_weight_on_input,
extra_expert_args=extra_expert_args,
extra_prepare_args=extra_prepare_args,
extra_finalize_args=extra_finalize_args,
)
def _valid_cutlass_block_scaled_grouped_gemm(

View File

@ -1,7 +1,7 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import functools
from typing import Optional
from typing import Any, Optional
import torch
@ -152,6 +152,7 @@ class DeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
workspace2: torch.Tensor,
expert_tokens_meta: Optional[mk.ExpertTokensMetadata],
apply_router_weight_on_input: bool,
extra_expert_args: Optional[dict[str, Any]],
):
assert self.block_shape is not None
assert a1q_scale is not None

View File

@ -1,6 +1,6 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Optional
from typing import Any, Optional
import deep_ep
import torch
@ -127,16 +127,12 @@ class DeepEPHTPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
expert_topk_weights)
def prepare(
self,
a1: torch.Tensor,
a1_scale: Optional[torch.Tensor],
a2_scale: Optional[torch.Tensor],
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
num_experts: int,
expert_map: Optional[torch.Tensor],
apply_router_weight_on_input: bool,
self, a1: torch.Tensor, a1_scale: Optional[torch.Tensor],
a2_scale: Optional[torch.Tensor], topk_weights: torch.Tensor,
topk_ids: torch.Tensor, num_experts: int,
expert_map: Optional[torch.Tensor], apply_router_weight_on_input: bool,
quant_config: FusedMoEQuantConfig,
extra_prepare_args: Optional[dict[str, Any]]
) -> tuple[torch.Tensor, Optional[torch.Tensor],
Optional[mk.ExpertTokensMetadata], Optional[torch.Tensor],
Optional[torch.Tensor]]:
@ -191,7 +187,8 @@ class DeepEPHTPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
def finalize(self, output: torch.Tensor, fused_expert_output: torch.Tensor,
topk_weights: torch.Tensor, topk_ids: torch.Tensor,
apply_router_weight_on_input: bool,
weight_and_reduce_impl: mk.TopKWeightAndReduce) -> None:
weight_and_reduce_impl: mk.TopKWeightAndReduce,
extra_finalize_args: Optional[dict[str, Any]]) -> None:
assert self.handle is not None

View File

@ -1,6 +1,6 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Optional, Union
from typing import Any, Optional, Union
import deep_ep
import torch
@ -111,16 +111,12 @@ class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
return x, x_scales
def prepare(
self,
a1: torch.Tensor,
a1_scale: Optional[torch.Tensor],
a2_scale: Optional[torch.Tensor],
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
num_experts: int,
expert_map: Optional[torch.Tensor],
apply_router_weight_on_input: bool,
self, a1: torch.Tensor, a1_scale: Optional[torch.Tensor],
a2_scale: Optional[torch.Tensor], topk_weights: torch.Tensor,
topk_ids: torch.Tensor, num_experts: int,
expert_map: Optional[torch.Tensor], apply_router_weight_on_input: bool,
quant_config: FusedMoEQuantConfig,
extra_prepare_args: Optional[dict[str, Any]]
) -> tuple[torch.Tensor, Optional[torch.Tensor],
Optional[mk.ExpertTokensMetadata], Optional[torch.Tensor],
Optional[torch.Tensor]]:
@ -169,7 +165,8 @@ class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
def finalize(self, output: torch.Tensor, fused_expert_output: torch.Tensor,
topk_weights: torch.Tensor, topk_ids: torch.Tensor,
apply_router_weight_on_input: bool,
weight_and_reduce_impl: mk.TopKWeightAndReduce) -> None:
weight_and_reduce_impl: mk.TopKWeightAndReduce,
extra_finalize_args: Optional[dict[str, Any]]) -> None:
assert isinstance(
weight_and_reduce_impl, TopKWeightAndReduceDelegate
), ("Weight application and reduction happens in the combine kernel.")

View File

@ -0,0 +1,198 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Any, Optional
import torch
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig
from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import (
TopKWeightAndReduceDelegate)
from vllm.model_executor.layers.fused_moe.utils import extract_required_args
from vllm.utils.flashinfer import (flashinfer_cutlass_fused_moe,
has_flashinfer_cutlass_fused_moe)
logger = init_logger(__name__)
def is_valid_flashinfer_cutlass_fused_moe(hidden_states: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor) -> bool:
"""
Check if the given problem size is supported by the FlashInfer CUTLASS MoE
kernel.
"""
if not has_flashinfer_cutlass_fused_moe():
logger.debug_once("FlashInferExperts disabled: "
"flashinfer_cutlass_fused_moe not available.")
return False
# Data type checks
if (w1.dtype != torch.uint8 or w2.dtype != torch.uint8
or hidden_states.dtype
not in [torch.float32, torch.float16, torch.bfloat16]):
logger.debug_once(
"FlashInferExperts disabled: w1/w2 must be torch.uint8 "
f"(got w1={w1.dtype}, w2={w2.dtype}), hidden_states must be "
f"float32, float16, or bfloat16 (got {hidden_states.dtype}).")
return False
return True
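A small usage sketch of this validity check, assuming a vLLM build that includes this change. It is runnable on CPU: with FlashInfer absent it simply logs and returns False, and the tensor contents are irrelevant because only dtypes and kernel availability are inspected. The import path matches the one used by ModelOptNvFp4FusedMoE later in this change.
import torch
from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe import (
    is_valid_flashinfer_cutlass_fused_moe)

# Assumed toy shapes: E experts, hidden size K, intermediate size N, M tokens.
E, K, N, M = 4, 256, 512, 8
hidden_states = torch.zeros((M, K), dtype=torch.bfloat16)
w1 = torch.zeros((E, 2 * N, K // 2), dtype=torch.uint8)  # packed nvfp4 weights
w2 = torch.zeros((E, K, N // 2), dtype=torch.uint8)

print(is_valid_flashinfer_cutlass_fused_moe(hidden_states, w1, w2))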
class FlashInferExperts(mk.FusedMoEPermuteExpertsUnpermute):
def __init__(
self,
use_nvfp4_w4a4: bool = False,
use_fp8_w8a8: bool = False,
use_dp: bool = False,
ep_rank: int = 0,
ep_size: int = 1,
tp_rank: int = 0,
tp_size: int = 1,
num_dispatchers: Optional[int] = None,
use_batched_format: bool = False,
):
super().__init__(
FusedMoEQuantConfig(
quant_dtype=torch.uint8,
per_act_token_quant=False,
block_shape=None,
))
self.use_nvfp4_w4a4 = use_nvfp4_w4a4
self.use_fp8_w8a8 = use_fp8_w8a8
self.ep_rank = ep_rank
self.ep_size = ep_size
self.tp_rank = tp_rank
self.tp_size = tp_size
self.use_dp = use_dp
assert not use_batched_format or num_dispatchers is not None
self.num_dispatchers = num_dispatchers
@property
def activation_formats(
self
) -> tuple[mk.FusedMoEActivationFormat, mk.FusedMoEActivationFormat]:
return (mk.FusedMoEActivationFormat.Standard,
mk.FusedMoEActivationFormat.Standard)
def supports_expert_map(self) -> bool:
return False
def supports_chunking(self) -> bool:
# This refers to TP chunking; DP chunking is handled separately.
return True
def finalize_weight_and_reduce_impl(self) -> mk.TopKWeightAndReduce:
# Let PrepareAndFinalize::finalize() decide the impl.
return TopKWeightAndReduceDelegate()
def workspace_shapes(
self,
a: torch.Tensor,
aq: torch.Tensor,
M: int,
N: int,
K: int,
topk: int,
global_num_experts: int,
local_num_experts: int,
expert_tokens_meta: Optional[mk.ExpertTokensMetadata],
) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...], torch.dtype]:
# We use global_num_experts due to how moe_align_block_size handles
# expert_maps.
"""
Compute the shapes for the temporary and final outputs of the two gemms
and activation in the fused expert function. Since the gemms are
independent, the workspace for the first gemm can be shared with the
workspace for the last gemm.
Returns a tuple of:
- workspace13 shape tuple: must be large enough to hold the
result of either expert gemm.
- workspace2 shape tuple: must be large enough to hold the
result of the activation function.
- output shape tuple: must be exact size of the final gemm output.
- Workspace type: The dtype to use for the workspace tensors.
- Note: in order for activation chunking to work, the first dimension
of each tuple must be the number of tokens.
"""
assert self.use_nvfp4_w4a4 is True, ("Only nvfp4 quantization is "
"currently supported.")
aq_m, aq_n = aq.shape
workspace2 = ()
output_shape = (aq_m, aq_n * 2)
workspace_dtype = a.dtype
workspace1 = output_shape
# The workspace is determined by `aq`, since it comes after any
# potential communication op and is involved in the expert computation.
return (workspace1, workspace2, output_shape, workspace_dtype)
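The aq_n * 2 in the output shape reflects the nvfp4 packing: the quantized activation aq stores two e2m1 values per uint8 byte, so the unpacked hidden size is twice its last dimension. A tiny arithmetic sketch under assumed sizes:
# Assumed sizes: M tokens with hidden size K, quantized to packed nvfp4.
M, K = 8, 256
aq_shape = (M, K // 2)                 # two e2m1 values per uint8 byte
output_shape = (aq_shape[0], aq_shape[1] * 2)
assert output_shape == (M, K)          # workspace1/output use the unpacked width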
def apply(
self,
output: torch.Tensor,
hidden_states: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
activation: str,
global_num_experts: int,
expert_map: Optional[torch.Tensor],
w1_scale: Optional[torch.Tensor],
w2_scale: Optional[torch.Tensor],
w1_zp: Optional[torch.Tensor],
w2_zp: Optional[torch.Tensor],
a1q_scale: Optional[torch.Tensor],
a2_scale: Optional[torch.Tensor], # Not used
workspace13: Optional[torch.Tensor],
workspace2: Optional[torch.Tensor],
expert_tokens_meta: Optional[mk.ExpertTokensMetadata],
apply_router_weight_on_input: Optional[bool],
extra_expert_args: Optional[dict[str, Any]],
):
assert extra_expert_args is not None, \
"extra_expert_args must be provided"
required_keys = [
'g1_alphas', 'g2_alphas', 'a1_gscale', 'a2_gscale', 'out_dtype'
]
g1_alphas, g2_alphas, a1_gscale, a2_gscale, out_dtype = (
extract_required_args(extra_expert_args, required_keys))
# The FlashInfer CUTLASS kernel takes scalar global scales;
# take the min because these are inverse scales.
assert self.use_nvfp4_w4a4 is True, ("Only nvfp4 quantization is "
"currently supported.")
# Ensure w1_scale and w2_scale are not None before calling view
assert w1_scale is not None and w2_scale is not None, (
"w1_scale and w2_scale must not "
"be None for FlashInferExperts")
assert not apply_router_weight_on_input
quant_scales = [
a1_gscale,
w1_scale.view(torch.int32),
g1_alphas,
a2_gscale,
w2_scale.view(torch.int32),
g2_alphas,
]
_ = flashinfer_cutlass_fused_moe(
hidden_states,
topk_ids.to(torch.int),
topk_weights,
# FlashInfer API requires weight to be long for nvfp4
w1.view(torch.long),
w2.view(torch.long),
output_dtype=out_dtype,
quant_scales=quant_scales,
input_sf=a1q_scale,
tp_size=self.tp_size,
tp_rank=self.tp_rank,
ep_size=self.ep_size,
ep_rank=self.ep_rank,
output=output,
)

View File

@ -0,0 +1,114 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Any, Optional
import torch
import vllm.envs as envs
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from vllm.distributed import get_dp_group
from vllm.forward_context import get_forward_context
from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig
from vllm.model_executor.layers.fused_moe.utils import (
extract_required_args, moe_kernel_quantize_input)
from vllm.utils.flashinfer import fp4_swizzle_blockscale
def get_local_sizes(local_tokens):
cu_sizes = get_forward_context().dp_metadata.cu_tokens_across_dp_cpu
sizes = [cu_sizes[0].item()]
for i in range(1, len(cu_sizes)):
sizes.append((cu_sizes[i] - cu_sizes[i - 1]).item())
max_num_tokens = envs.VLLM_MOE_DP_CHUNK_SIZE
sizes_chunked = [max_num_tokens] * len(sizes)
if local_tokens < max_num_tokens:
# When the number of local tokens is less than max_num_tokens, all other
# ranks will also have fewer than max_num_tokens. The remaining tokens
# are accounted for as residual.
sizes_chunked = [x % max_num_tokens for x in sizes]
return sizes_chunked
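A worked example of this chunk-size computation, mirroring the logic above with assumed numbers (no forward context required): with cumulative DP token counts [300, 700, 900] the per-rank sizes are [300, 400, 200]; while every rank still has at least VLLM_MOE_DP_CHUNK_SIZE tokens left, each chunk gathers that many from every rank, and once the local rank drops below the chunk size the remainders are gathered instead.
# Hedged mirror of get_local_sizes, with cu_sizes and the chunk size assumed.
def local_sizes(cu_sizes: list[int], local_tokens: int,
                max_num_tokens: int) -> list[int]:
    sizes = [cu_sizes[0]] + [
        cu_sizes[i] - cu_sizes[i - 1] for i in range(1, len(cu_sizes))
    ]
    if local_tokens < max_num_tokens:
        # Residual chunk: every rank has fewer than max_num_tokens left.
        return [x % max_num_tokens for x in sizes]
    return [max_num_tokens] * len(sizes)

print(local_sizes([300, 700, 900], local_tokens=256, max_num_tokens=256))
# [256, 256, 256]  -> a full chunk from each rank
print(local_sizes([300, 700, 900], local_tokens=44, max_num_tokens=256))
# [44, 144, 200]   -> remainders after the full chunks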
class FlashInferCutlassMoEPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
def __init__(
self,
quant_dtype: Optional[torch.dtype] = None,
per_channel_quant: bool = False,
block_shape: Optional[list[int]] = None,
num_dispatchers: int = 1,
):
super().__init__()
self.per_channel_quant = per_channel_quant
self.block_shape = block_shape
self.quant_dtype = quant_dtype
self.num_dispatchers_ = num_dispatchers
@property
def activation_format(self) -> mk.FusedMoEActivationFormat:
return mk.FusedMoEActivationFormat.Standard
def max_num_tokens_per_rank(self) -> Optional[int]:
return None
def topk_indices_dtype(self) -> Optional[torch.dtype]:
return None
def num_dispatchers(self) -> int:
return self.num_dispatchers_
def prepare(
self,
a1: torch.Tensor,
a1_scale: Optional[torch.Tensor], # Not used
a2_scale: Optional[torch.Tensor], # Not used
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
num_experts: int,
expert_map: Optional[torch.Tensor],
apply_router_weight_on_input: bool,
quant_config: FusedMoEQuantConfig,
extra_prepare_args: Optional[dict[str, Any]]
) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor],
Optional[torch.Tensor], Optional[torch.Tensor]]:
assert not apply_router_weight_on_input
(a1_gscale, use_dp, local_tokens) = extract_required_args(
extra_prepare_args, ['a1_gscale', 'use_dp', 'local_tokens'])
a1q, a1q_scale = moe_kernel_quantize_input(
a1,
a1_gscale,
quant_config.quant_dtype,
self.per_channel_quant,
self.block_shape,
is_fp4_scale_swizzled=not use_dp, # Swizzling after communication
)
if use_dp:
topk_weights, topk_ids, a1q, a1q_scale = \
get_dp_group().all_gatherv([topk_weights, topk_ids, a1q, a1q_scale], # noqa: E501
dim=0,
sizes=get_local_sizes(local_tokens))
a1_m, a1_n = a1q.shape
a1q_scale = fp4_swizzle_blockscale(a1q_scale, a1_m, a1_n * 2)
return a1q, a1q_scale, None, topk_ids, topk_weights
def finalize(self, output: torch.Tensor, fused_expert_output: torch.Tensor,
topk_weights: torch.Tensor, topk_ids: torch.Tensor,
apply_router_weight_on_input: bool,
weight_and_reduce_impl: mk.TopKWeightAndReduce,
extra_finalize_args: Optional[dict[str, Any]]) -> None:
(use_dp,
local_tokens) = extract_required_args(extra_finalize_args,
['use_dp', 'local_tokens'])
if use_dp:
fused_expert_output = get_dp_group().reduce_scatterv(
fused_expert_output,
dim=0,
sizes=get_local_sizes(local_tokens),
)
output.copy_(fused_expert_output)

View File

@ -1,7 +1,7 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Fused batched MoE kernel."""
from typing import Optional
from typing import Any, Optional
import torch
@ -496,16 +496,12 @@ class BatchedPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
return self.num_dispatchers_
def prepare(
self,
a1: torch.Tensor,
a1_scale: Optional[torch.Tensor],
a2_scale: Optional[torch.Tensor],
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
num_experts: int,
expert_map: Optional[torch.Tensor],
apply_router_weight_on_input: bool,
self, a1: torch.Tensor, a1_scale: Optional[torch.Tensor],
a2_scale: Optional[torch.Tensor], topk_weights: torch.Tensor,
topk_ids: torch.Tensor, num_experts: int,
expert_map: Optional[torch.Tensor], apply_router_weight_on_input: bool,
quant_config: FusedMoEQuantConfig,
extra_prepare_args: Optional[dict[str, Any]]
) -> tuple[torch.Tensor, Optional[torch.Tensor],
Optional[mk.ExpertTokensMetadata], Optional[torch.Tensor],
Optional[torch.Tensor]]:
@ -594,15 +590,11 @@ class BatchedPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
return b_a1, b_a1_scale, expert_tokens_meta, None, None
def finalize(
self,
output: torch.Tensor,
fused_expert_output: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
apply_router_weight_on_input: bool,
weight_and_reduce_impl: mk.TopKWeightAndReduce,
) -> None:
def finalize(self, output: torch.Tensor, fused_expert_output: torch.Tensor,
topk_weights: torch.Tensor, topk_ids: torch.Tensor,
apply_router_weight_on_input: bool,
weight_and_reduce_impl: mk.TopKWeightAndReduce,
extra_finalize_args: Optional[dict[str, Any]]) -> None:
if isinstance(weight_and_reduce_impl, TopKWeightAndReduceDelegate):
weight_and_reduce_impl = TopKWeightAndReduceNaiveBatched(self.rank)
weight_and_reduce_impl.apply(
@ -706,7 +698,8 @@ class NaiveBatchedExperts(mk.FusedMoEPermuteExpertsUnpermute):
a2_scale: Optional[torch.Tensor], workspace13: torch.Tensor,
workspace2: torch.Tensor,
expert_tokens_meta: Optional[mk.ExpertTokensMetadata],
apply_router_weight_on_input: bool):
apply_router_weight_on_input: bool,
extra_expert_args: Optional[dict[str, Any]]):
assert hidden_states.dim() == 3
assert expert_tokens_meta is not None
expert_num_tokens = expert_tokens_meta.expert_num_tokens
@ -911,7 +904,8 @@ class BatchedTritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
a2_scale: Optional[torch.Tensor], workspace13: torch.Tensor,
workspace2: torch.Tensor,
expert_tokens_meta: Optional[mk.ExpertTokensMetadata],
apply_router_weight_on_input: bool):
apply_router_weight_on_input: bool,
extra_expert_args: Optional[dict[str, Any]]):
# Check constraints.
if self.use_int4_w4a16:
assert hidden_states.size(-1) // 2 == w1.size(2), (

View File

@ -1646,6 +1646,7 @@ class TritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
workspace2: torch.Tensor,
expert_tokens_meta: Optional[mk.ExpertTokensMetadata],
apply_router_weight_on_input: bool,
extra_expert_args: Optional[dict[str, Any]],
):
# Check constraints.
if self.use_int4_w4a16:

View File

@ -34,6 +34,7 @@ from vllm.model_executor.utils import set_weight_attrs
from vllm.platforms import current_platform
from vllm.platforms.interface import CpuArchEnum
from vllm.utils import direct_register_custom_op, has_deep_ep, has_pplx
from vllm.utils.flashinfer import has_flashinfer
if current_platform.is_cuda_alike():
from .fused_batched_moe import BatchedTritonExperts
@ -45,6 +46,9 @@ if current_platform.is_cuda_alike():
from .deepep_ht_prepare_finalize import DeepEPHTPrepareAndFinalize
from .deepep_ll_prepare_finalize import (DEEPEP_QUANT_BLOCK_SHAPE,
DeepEPLLPrepareAndFinalize)
if has_flashinfer():
from .flashinfer_cutlass_prepare_finalize import (
FlashInferCutlassMoEPrepareAndFinalize)
else:
fused_experts = None # type: ignore
FusedMoEPermuteExpertsUnpermute = None # type: ignore
@ -99,6 +103,9 @@ class FusedMoEMethodBase(QuantizeMethodBase):
prepare_finalize: Optional[FusedMoEPrepareAndFinalize] = None
if moe.use_flashinfer_cutlass_kernels:
prepare_finalize = FlashInferCutlassMoEPrepareAndFinalize(
quant_dtype=moe.quant_dtype, )
if moe.use_pplx_kernels:
hidden_dim_bytes, hidden_scale_bytes = pplx_hidden_dim_scale_bytes(
moe.max_num_tokens,
@ -204,6 +211,12 @@ class FusedMoEMethodBase(QuantizeMethodBase):
f"{self.__class__.__name__} must select appropriate gemm "
"implementation based on the prepare_finalize")
def maybe_swap_experts_impl(
self,
moe_parallel_config: FusedMoEParallelConfig,
):
pass
@abstractmethod
def apply(
self,
@ -744,12 +757,15 @@ class FusedMoE(torch.nn.Module):
moe_quant_params["intermediate_size_full"] = intermediate_size
self.quant_method.create_weights(layer=self, **moe_quant_params)
if isinstance(self.quant_method, FusedMoEMethodBase):
self.quant_method.maybe_swap_experts_impl(self.moe_parallel_config)
# Chunked all2all staging tensor
self.batched_hidden_states: Optional[torch.Tensor] = None
self.batched_router_logits: Optional[torch.Tensor] = None
if (self.moe_parallel_config.use_pplx_kernels
or self.moe_parallel_config.use_deepep_ll_kernels):
or self.moe_parallel_config.use_deepep_ll_kernels
or self.moe_parallel_config.use_flashinfer_cutlass_kernels):
self.batched_hidden_states = torch.zeros(
(moe.max_num_tokens, self.hidden_size),
dtype=moe.in_dtype,
@ -801,6 +817,10 @@ class FusedMoE(torch.nn.Module):
def use_deepep_ll_kernels(self):
return self.moe_parallel_config.use_deepep_ll_kernels
@property
def use_flashinfer_cutlass_kernels(self):
return self.moe_parallel_config.use_flashinfer_cutlass_kernels
def _load_per_tensor_weight_scale(self, shard_id: str,
param: torch.nn.Parameter,
loaded_weight: torch.Tensor,
@ -1402,9 +1422,9 @@ class FusedMoE(torch.nn.Module):
final_hidden_states, non_blocking=True)
ctx = get_forward_context()
# flashinfer_cutlass_kernels can handle: optional DP + TP/EP
max_tokens_across_dp = ctx.dp_metadata.max_tokens_across_dp_cpu
moe_dp_chunk_size_per_rank = self.moe_config.max_num_tokens
num_tokens = full_hidden_states.size(0)
for chunk_start_ in range(0, max_tokens_across_dp,
moe_dp_chunk_size_per_rank):
@ -1424,13 +1444,20 @@ class FusedMoE(torch.nn.Module):
def forward_impl(self, hidden_states: torch.Tensor,
router_logits: torch.Tensor):
assert self.quant_method is not None
# Route to the chunked forward path using the FlashInfer Cutlass kernel
# only when data parallelism (DP) is enabled.
use_flashinfer_cutlass_kernels = (
self.dp_size > 1
and self.moe_parallel_config.use_flashinfer_cutlass_kernels)
if (self.moe_parallel_config.use_pplx_kernels
or self.moe_parallel_config.use_deepep_ll_kernels):
or self.moe_parallel_config.use_deepep_ll_kernels
or use_flashinfer_cutlass_kernels):
return self.forward_impl_chunked(hidden_states, router_logits)
do_naive_dispatch_combine: bool = (
self.dp_size > 1
and not self.moe_parallel_config.use_deepep_ht_kernels)
and not self.moe_parallel_config.use_deepep_ht_kernels
and not self.moe_parallel_config.use_flashinfer_cutlass_kernels)
if do_naive_dispatch_combine:
hidden_states, router_logits = get_ep_group().dispatch(
hidden_states, router_logits)
@ -1460,7 +1487,6 @@ class FusedMoE(torch.nn.Module):
if do_naive_dispatch_combine:
final_hidden_states = get_ep_group().combine(final_hidden_states)
if self.reduce_results and (self.tp_size > 1 or self.ep_size > 1):
# Default set to False. (May have to add shared expert outputs.)
final_hidden_states = self.maybe_all_reduce_tensor_model_parallel(

View File

@ -4,7 +4,7 @@ from abc import ABC, abstractmethod
from dataclasses import dataclass
from enum import Enum
from math import prod
from typing import Optional, final
from typing import Any, Optional, final
import torch
@ -150,16 +150,12 @@ class FusedMoEPrepareAndFinalize(ABC):
@abstractmethod
def prepare(
self,
a1: torch.Tensor,
a1_scale: Optional[torch.Tensor],
a2_scale: Optional[torch.Tensor],
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
num_experts: int,
expert_map: Optional[torch.Tensor],
apply_router_weight_on_input: bool,
self, a1: torch.Tensor, a1_scale: Optional[torch.Tensor],
a2_scale: Optional[torch.Tensor], topk_weights: torch.Tensor,
topk_ids: torch.Tensor, num_experts: int,
expert_map: Optional[torch.Tensor], apply_router_weight_on_input: bool,
quant_config: FusedMoEQuantConfig,
extra_prepare_args: Optional[dict[str, Any]]
) -> tuple[torch.Tensor, Optional[torch.Tensor],
Optional[ExpertTokensMetadata], Optional[torch.Tensor],
Optional[torch.Tensor]]:
@ -190,15 +186,11 @@ class FusedMoEPrepareAndFinalize(ABC):
raise NotImplementedError
@abstractmethod
def finalize(
self,
output: torch.Tensor,
fused_expert_output: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
apply_router_weight_on_input: bool,
weight_and_reduce_impl: TopKWeightAndReduce,
) -> None:
def finalize(self, output: torch.Tensor, fused_expert_output: torch.Tensor,
topk_weights: torch.Tensor, topk_ids: torch.Tensor,
apply_router_weight_on_input: bool,
weight_and_reduce_impl: TopKWeightAndReduce,
extra_finalize_args: Optional[dict[str, Any]]) -> None:
"""
Perform any combine plus apply weights and perform a reduction on the
fused experts output.
@ -376,6 +368,7 @@ class FusedMoEPermuteExpertsUnpermute(ABC):
workspace2: torch.Tensor,
expert_tokens_meta: Optional[ExpertTokensMetadata],
apply_router_weight_on_input: bool,
extra_expert_args: Optional[dict[str, Any]],
):
"""
This function computes the intermediate result of a Mixture of Experts
@ -460,21 +453,19 @@ class FusedMoEModularKernel(torch.nn.Module):
f"{fused_experts.__class__.__name__}."
f"{fused_experts.activation_formats[0]}")
def _do_fused_experts(self, fused_out: Optional[torch.Tensor],
a1: torch.Tensor, a1q: torch.Tensor,
w1: torch.Tensor, w2: torch.Tensor,
topk_weights: torch.Tensor, topk_ids: torch.Tensor,
activation: str, global_num_experts: int,
local_num_experts: int,
expert_map: Optional[torch.Tensor],
w1_scale: Optional[torch.Tensor],
w2_scale: Optional[torch.Tensor],
w1_zp: Optional[torch.Tensor],
w2_zp: Optional[torch.Tensor],
a1q_scale: Optional[torch.Tensor],
a2_scale: Optional[torch.Tensor],
expert_tokens_meta: Optional[ExpertTokensMetadata],
apply_router_weight_on_input: bool) -> torch.Tensor:
def _do_fused_experts(
self, fused_out: Optional[torch.Tensor], a1: torch.Tensor,
a1q: torch.Tensor, w1: torch.Tensor, w2: torch.Tensor,
topk_weights: torch.Tensor, topk_ids: torch.Tensor,
activation: str, global_num_experts: int, local_num_experts: int,
expert_map: Optional[torch.Tensor],
w1_scale: Optional[torch.Tensor], w2_scale: Optional[torch.Tensor],
w1_zp: Optional[torch.Tensor], w2_zp: Optional[torch.Tensor],
a1q_scale: Optional[torch.Tensor],
a2_scale: Optional[torch.Tensor],
expert_tokens_meta: Optional[ExpertTokensMetadata],
apply_router_weight_on_input: bool,
extra_expert_args: Optional[dict[str, Any]]) -> torch.Tensor:
_, M, N, K, top_k = _moe_problem_size(a1q, w1, w2, topk_ids)
@ -517,7 +508,8 @@ class FusedMoEModularKernel(torch.nn.Module):
workspace13=workspace13,
workspace2=workspace2,
expert_tokens_meta=expert_tokens_meta,
apply_router_weight_on_input=apply_router_weight_on_input)
apply_router_weight_on_input=apply_router_weight_on_input,
extra_expert_args=extra_expert_args)
return fused_out
@ -541,6 +533,7 @@ class FusedMoEModularKernel(torch.nn.Module):
a2_scale: Optional[torch.Tensor],
expert_tokens_meta: Optional[ExpertTokensMetadata],
apply_router_weight_on_input: bool,
extra_expert_args: Optional[dict[str, Any]],
) -> torch.Tensor:
_, M, N, K, top_k = _moe_problem_size(a1q, w1, w2, topk_ids)
@ -568,7 +561,8 @@ class FusedMoEModularKernel(torch.nn.Module):
a1q_scale=a1q_scale,
a2_scale=a2_scale,
expert_tokens_meta=expert_tokens_meta,
apply_router_weight_on_input=apply_router_weight_on_input)
apply_router_weight_on_input=apply_router_weight_on_input,
extra_expert_args=extra_expert_args)
# Chunking required case
assert num_chunks > 1
@ -624,6 +618,15 @@ class FusedMoEModularKernel(torch.nn.Module):
expert_num_tokens=c_expert_num_tokens,
expert_num_tokens_cpu=c_expert_num_tokens_cpu)
m = None
if extra_expert_args is not None and 'm' in extra_expert_args:
m = extra_expert_args.get('m')
if extra_expert_args is not None:
chunked_extra_expert_args = extra_expert_args
else:
chunked_extra_expert_args = {}
for chunk_idx in range(num_chunks):
c_a1q, c_a1q_scale, c_a2_scale, c_topk_ids, c_topk_weights = (
slice_input_tensors(chunk_idx))
@ -634,6 +637,11 @@ class FusedMoEModularKernel(torch.nn.Module):
expert_tokens_meta, c_topk_ids, local_num_experts,
expert_map)
s = chunk_idx * CHUNK_SIZE
e = min(s + CHUNK_SIZE, M)
if m is not None:
chunked_extra_expert_args['m'] = e - s
self._do_fused_experts(
fused_out=slice_output_tensor(chunk_idx),
a1=a1,
@ -653,7 +661,8 @@ class FusedMoEModularKernel(torch.nn.Module):
a1q_scale=c_a1q_scale,
a2_scale=c_a2_scale,
expert_tokens_meta=c_expert_tokens_meta,
apply_router_weight_on_input=apply_router_weight_on_input)
apply_router_weight_on_input=apply_router_weight_on_input,
extra_expert_args=chunked_extra_expert_args)
return fused_out
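The only per-chunk adjustment made to extra_expert_args is 'm', which must track the number of tokens in the current slice rather than the full batch. A short sketch of that bookkeeping with assumed numbers:
# Assumed: M tokens split into CHUNK_SIZE-sized slices, as in the loop above.
M, CHUNK_SIZE = 10, 4
for chunk_idx in range((M + CHUNK_SIZE - 1) // CHUNK_SIZE):
    s = chunk_idx * CHUNK_SIZE
    e = min(s + CHUNK_SIZE, M)
    print(chunk_idx, (s, e), "m =", e - s)   # m = 4, 4, 2 for the three chunks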
@ -675,6 +684,9 @@ class FusedMoEModularKernel(torch.nn.Module):
a1_scale: Optional[torch.Tensor] = None,
a2_scale: Optional[torch.Tensor] = None,
apply_router_weight_on_input: bool = False,
extra_expert_args: Optional[dict] = None,
extra_prepare_args: Optional[dict] = None,
extra_finalize_args: Optional[dict] = None,
) -> torch.Tensor:
"""
This function computes a Mixture of Experts (MoE) layer using two sets
@ -707,6 +719,12 @@ class FusedMoEModularKernel(torch.nn.Module):
- apply_router_weight_on_input (bool): When true, the topk weights are
applied directly on the inputs. This is only applicable when topk is
1.
- extra_expert_args (Optional[dict]): Extra keyword arguments to pass to
fused_experts.apply.
- extra_prepare_args (Optional[dict]): Extra keyword arguments to pass
to prepare.
- extra_finalize_args (Optional[dict]): Extra keyword arguments to pass
to finalize.
Returns:
- torch.Tensor: The output tensor after applying the MoE layer.
@ -730,6 +748,7 @@ class FusedMoEModularKernel(torch.nn.Module):
expert_map,
apply_router_weight_on_input,
self.fused_experts.quant_config,
extra_prepare_args,
)
# Maybe prepare gathered topk_ids and topk_weights from other EP ranks.
@ -766,11 +785,13 @@ class FusedMoEModularKernel(torch.nn.Module):
a1q_scale=a1q_scale,
a2_scale=a2_scale,
expert_tokens_meta=expert_tokens_meta,
apply_router_weight_on_input=apply_router_weight_on_input)
apply_router_weight_on_input=apply_router_weight_on_input,
extra_expert_args=extra_expert_args)
self.prepare_finalize.finalize(
output, fused_out, topk_weights, topk_ids,
apply_router_weight_on_input,
self.fused_experts.finalize_weight_and_reduce_impl())
self.fused_experts.finalize_weight_and_reduce_impl(),
extra_finalize_args)
return output

View File

@ -1,6 +1,6 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Optional
from typing import Any, Optional
import pplx_kernels as pplx
import torch
@ -89,16 +89,12 @@ class PplxPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
return self.num_dispatchers_
def prepare(
self,
a1: torch.Tensor,
a1_scale: Optional[torch.Tensor],
a2_scale: Optional[torch.Tensor],
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
num_experts: int,
expert_map: Optional[torch.Tensor],
apply_router_weight_on_input: bool,
self, a1: torch.Tensor, a1_scale: Optional[torch.Tensor],
a2_scale: Optional[torch.Tensor], topk_weights: torch.Tensor,
topk_ids: torch.Tensor, num_experts: int,
expert_map: Optional[torch.Tensor], apply_router_weight_on_input: bool,
quant_config: FusedMoEQuantConfig,
extra_prepare_args: Optional[dict[str, Any]]
) -> tuple[torch.Tensor, Optional[torch.Tensor],
Optional[mk.ExpertTokensMetadata], Optional[torch.Tensor],
Optional[torch.Tensor]]:
@ -217,15 +213,11 @@ class PplxPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
return expert_x, expert_x_scale, expert_tokens_meta, None, None
def finalize(
self,
output: torch.Tensor,
fused_expert_output: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
apply_router_weight_on_input: bool,
weight_and_reduce_impl: mk.TopKWeightAndReduce,
) -> None:
def finalize(self, output: torch.Tensor, fused_expert_output: torch.Tensor,
topk_weights: torch.Tensor, topk_ids: torch.Tensor,
apply_router_weight_on_input: bool,
weight_and_reduce_impl: mk.TopKWeightAndReduce,
extra_finalize_args: Optional[dict[str, Any]]) -> None:
assert isinstance(
weight_and_reduce_impl, TopKWeightAndReduceDelegate
), ("Weight application and reduction happens in the combine kernel.")

View File

@ -1,6 +1,6 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Optional
from typing import Any, Optional
import torch
@ -38,6 +38,7 @@ class MoEPrepareAndFinalizeNoEP(mk.FusedMoEPrepareAndFinalize):
expert_map: Optional[torch.Tensor],
apply_router_weight_on_input: bool,
quant_config: FusedMoEQuantConfig,
extra_prepare_args: Optional[dict[str, Any]],
) -> tuple[torch.Tensor, Optional[torch.Tensor],
Optional[mk.ExpertTokensMetadata], Optional[torch.Tensor],
Optional[torch.Tensor]]:
@ -48,26 +49,33 @@ class MoEPrepareAndFinalizeNoEP(mk.FusedMoEPrepareAndFinalize):
assert topk == 1, \
"apply_router_weight_on_input is only implemented for topk=1"
a1.mul_(topk_weights.to(a1.dtype))
if (extra_prepare_args is not None
and extra_prepare_args.get("skip_quant", True)):
# Skip quantization if explicitly requested
return a1, None, None, None, None
a1q, a1q_scale = moe_kernel_quantize_input(
a1, a1_scale, quant_config.quant_dtype,
quant_config.per_act_token_quant, quant_config.block_shape)
return a1q, a1q_scale, None, None, None
def finalize(
self,
output: torch.Tensor,
fused_expert_output: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
apply_router_weight_on_input: bool,
weight_and_reduce_impl: mk.TopKWeightAndReduce,
) -> None:
if isinstance(weight_and_reduce_impl, TopKWeightAndReduceDelegate):
weight_and_reduce_impl = TopKWeightAndReduceContiguous()
weight_and_reduce_impl.apply(
output=output,
fused_expert_output=fused_expert_output,
topk_weights=topk_weights,
topk_ids=topk_ids,
apply_router_weight_on_input=apply_router_weight_on_input)
def finalize(self, output: torch.Tensor, fused_expert_output: torch.Tensor,
topk_weights: torch.Tensor, topk_ids: torch.Tensor,
apply_router_weight_on_input: bool,
weight_and_reduce_impl: mk.TopKWeightAndReduce,
extra_finalize_args: Optional[dict[str, Any]]) -> None:
if (extra_finalize_args is not None
and extra_finalize_args.get("skip_weight_reduce", True)):
assert output.shape == fused_expert_output.shape
output.copy_(fused_expert_output)
else:
if isinstance(weight_and_reduce_impl, TopKWeightAndReduceDelegate):
weight_and_reduce_impl = TopKWeightAndReduceContiguous()
weight_and_reduce_impl.apply(
output=output,
fused_expert_output=fused_expert_output,
topk_weights=topk_weights,
topk_ids=topk_ids,
apply_router_weight_on_input=apply_router_weight_on_input)

View File

@ -1,6 +1,6 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Optional
from typing import Any, Optional
import torch
@ -119,28 +119,18 @@ class TritonOrDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
local_num_experts,
expert_tokens_meta)
def apply(
self,
output: torch.Tensor,
hidden_states: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
activation: str,
global_num_experts: int,
expert_map: Optional[torch.Tensor],
w1_scale: Optional[torch.Tensor],
w2_scale: Optional[torch.Tensor],
w1_zp: Optional[torch.Tensor],
w2_zp: Optional[torch.Tensor],
a1q_scale: Optional[torch.Tensor],
a2_scale: Optional[torch.Tensor],
workspace13: torch.Tensor,
workspace2: torch.Tensor,
expert_tokens_meta: Optional[mk.ExpertTokensMetadata],
apply_router_weight_on_input: bool,
):
def apply(self, output: torch.Tensor, hidden_states: torch.Tensor,
w1: torch.Tensor, w2: torch.Tensor, topk_weights: torch.Tensor,
topk_ids: torch.Tensor, activation: str, global_num_experts: int,
expert_map: Optional[torch.Tensor],
w1_scale: Optional[torch.Tensor],
w2_scale: Optional[torch.Tensor], w1_zp: Optional[torch.Tensor],
w2_zp: Optional[torch.Tensor], a1q_scale: Optional[torch.Tensor],
a2_scale: Optional[torch.Tensor], workspace13: torch.Tensor,
workspace2: torch.Tensor,
expert_tokens_meta: Optional[mk.ExpertTokensMetadata],
apply_router_weight_on_input: bool,
extra_expert_args: Optional[dict[str, Any]]):
use_deep_gemm = (self.allow_deep_gemm
and (_valid_deep_gemm(hidden_states, w1, w2)
or is_blackwell_deep_gemm_used()))
@ -168,4 +158,5 @@ class TritonOrDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
workspace2,
expert_tokens_meta,
apply_router_weight_on_input,
extra_expert_args,
)

View File

@ -1,7 +1,7 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from math import prod
from typing import Optional, Union
from typing import Any, Optional, Union
import torch
@ -15,6 +15,7 @@ from vllm.model_executor.layers.quantization.utils.mxfp4_utils import (
from vllm.platforms import current_platform
from vllm.triton_utils import tl, triton
from vllm.utils import cdiv
from vllm.utils.flashinfer import fp4_quantize
@triton.jit
@ -98,6 +99,16 @@ def _resize_cache(x: torch.Tensor, v: tuple[int, ...]) -> torch.Tensor:
return x.flatten()[:prod(v)].view(*v)
def _fp4_quantize(
A: torch.Tensor,
A_scale: Optional[torch.Tensor],
is_sf_swizzled_layout: bool,
) -> tuple[torch.Tensor, torch.Tensor]:
return fp4_quantize(A,
A_scale,
is_sf_swizzled_layout=is_sf_swizzled_layout)
def _fp8_quantize(
A: torch.Tensor,
A_scale: Optional[torch.Tensor],
@ -172,11 +183,16 @@ def moe_kernel_quantize_input(
quant_dtype: Union[None, torch.dtype, str],
per_act_token_quant: bool,
block_shape: Optional[list[int]] = None,
is_fp4_scale_swizzled: bool = True,
) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
if quant_dtype == torch.float8_e4m3fn:
return _fp8_quantize(A, A_scale, per_act_token_quant, block_shape)
elif quant_dtype == torch.int8:
return _int8_quantize(A, A_scale, per_act_token_quant, block_shape)
elif quant_dtype == torch.uint8: # nvfp4
return _fp4_quantize(A,
A_scale,
is_sf_swizzled_layout=is_fp4_scale_swizzled)
elif quant_dtype == "mxfp4":
return _mxfp4_quantize(A, A_scale, per_act_token_quant, block_shape)
else:
@ -236,3 +252,17 @@ def _validate_scale_shape(
assert block_shape is not None
expected = (a.shape[0], cdiv(a.shape[1], block_shape[1]))
assert a_scale.shape == expected, f"{a_scale.shape} == {expected}"
def extract_required_args(
extra_args: Optional[dict[str, Any]],
required_keys: list[str],
) -> tuple[Any, ...]:
if extra_args is None:
raise ValueError("`extra_args` must be provided.")
missing_keys = [k for k in required_keys if k not in extra_args]
if missing_keys:
raise ValueError(f"Missing keys in `extra_args`: {missing_keys}")
return tuple(extra_args[k] for k in required_keys)
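A small usage example of this helper (assuming a vLLM checkout that includes this change): it returns the values in the requested order and raises a ValueError naming any missing keys.
from vllm.model_executor.layers.fused_moe.utils import extract_required_args

args = {"m": 8, "n": 128, "k": 256}
m, n = extract_required_args(args, ["m", "n"])
assert (m, n) == (8, 128)

try:
    extract_required_args(args, ["m", "missing_key"])
except ValueError as exc:
    print(exc)  # Missing keys in `extra_args`: ['missing_key']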

View File

@ -339,19 +339,19 @@ class CompressedTensorsW4A4MoeMethod(CompressedTensorsMoEMethod):
return cutlass_moe_fp4(
a=x,
w1_fp4=layer.w13_weight,
w1_blockscale=layer.w13_blockscale_swizzled,
w1_alphas=layer.g1_alphas,
w2_fp4=layer.w2_weight,
w1_blockscale=layer.w13_blockscale_swizzled,
w2_blockscale=layer.w2_blockscale_swizzled,
w2_alphas=layer.g2_alphas,
g1_alphas=layer.g1_alphas,
g2_alphas=layer.g2_alphas,
a1_gscale=layer.w13_input_scale_quant,
a2_gscale=layer.w2_input_scale_quant,
topk_weights=topk_weights,
topk_ids=topk_ids,
m=x.shape[0],
n=layer.w2_weight.shape[2] * 2,
k=x.shape[1],
e=layer.w13_weight.shape[0],
a1_gscale=layer.w13_input_scale_quant,
a2_gscale=layer.w2_input_scale_quant,
device=x.device,
apply_router_weight_on_input=apply_router_weight_on_input).to(
x.dtype)

View File

@ -7,9 +7,15 @@ import torch
from torch.nn import Module
from torch.nn.parameter import Parameter
import vllm.envs as envs
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from vllm._custom_ops import (cutlass_scaled_fp4_mm,
cutlass_scaled_mm_supports_fp4, scaled_fp4_quant)
from vllm.distributed import get_ep_group
from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe.config import FusedMoEParallelConfig
from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_prepare_finalize import ( # noqa: E501
FlashInferCutlassMoEPrepareAndFinalize)
from vllm.model_executor.layers.fused_moe.layer import (
FusedMoE, FusedMoEMethodBase, FusedMoeWeightScaleSupported)
from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase,
@ -713,6 +719,18 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
self.quant_config = quant_config
self.cutlass_nvfp4_supported = cutlass_fp4_supported()
self.use_marlin = False
self.allow_flashinfer_cutlass = False
if envs.VLLM_USE_FLASHINFER_MOE:
if self.cutlass_nvfp4_supported and current_platform.is_cuda() \
and current_platform.is_device_capability(100):
logger.info_once(
"Using FlashInfer kernels for ModelOptNvFp4FusedMoE.")
self.allow_flashinfer_cutlass = True
else:
logger.warning_once(
"Flashinfer CUTLASS Fused MoE not supported "
"or found on the current platform.")
if not self.cutlass_nvfp4_supported:
if is_fp4_marlin_supported():
@ -722,6 +740,73 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
" quantization. Please use Blackwell and"
" above.")
self.fused_experts = None # type: ignore
def maybe_swap_experts_impl(
self,
moe_parallel_config: FusedMoEParallelConfig,
):
if not self.allow_flashinfer_cutlass:
return
logger.debug_once("FlashInferExperts")
# default to TP/EP case only
experts_kwargs: dict[str, Any] = {
"use_nvfp4_w4a4": True,
"use_dp": moe_parallel_config.dp_size > 1,
"ep_rank": moe_parallel_config.ep_rank,
"ep_size": moe_parallel_config.ep_size,
"tp_rank": moe_parallel_config.tp_rank,
"tp_size": moe_parallel_config.tp_size,
}
from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe import ( # noqa: E501
FlashInferExperts)
experts = FlashInferExperts(**experts_kwargs)
self.fused_experts = mk.FusedMoEModularKernel(
FlashInferCutlassMoEPrepareAndFinalize(
quant_dtype=torch.uint8,
# uint8 means two e2m1 values packed into one byte (kernel requirement)
),
experts,
)
# This method updates self.fused_experts.
# select_gemm_impl is only called when prepare_finalize is not None, so for
# the native cutlass fp4 path, fused_experts is the fused_experts in
# fused_moe.py; when select_gemm_impl is not called (TP case), we still have
# both kernels to use.
def select_gemm_impl(self, prepare_finalize,
moe) -> mk.FusedMoEPermuteExpertsUnpermute:
assert moe is not None
assert prepare_finalize is not None
experts = None
all2all_manager = get_ep_group().device_communicator.all2all_manager
assert all2all_manager is not None
if self.allow_flashinfer_cutlass:
from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe import ( # noqa: E501
FlashInferExperts)
logger.debug_once("Using FlashInferExperts")
experts = FlashInferExperts(
use_nvfp4_w4a4=True,
use_dp=moe.moe_parallel_config.dp_size > 1,
ep_rank=moe.moe_parallel_config.ep_rank,
ep_size=moe.moe_parallel_config.ep_size,
tp_rank=moe.moe_parallel_config.tp_rank,
tp_size=moe.moe_parallel_config.tp_size,
)
        else:
            assert moe.dp_size > 1
            # CutlassExpertsFp4 does not currently support DP.
            raise ValueError(
                "CutlassExpertsFp4 doesn't support DP. "
                "Use the FlashInfer CUTLASS fused MoE backend instead "
                "(VLLM_USE_FLASHINFER_MOE=1).")
return experts
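    # A minimal sketch of how this path is reached from the caller side; the
    # checkpoint name and parallel sizes are illustrative assumptions:
    #
    #   VLLM_USE_FLASHINFER_MOE=1 vllm serve <nvfp4-moe-checkpoint> \
    #       --data-parallel-size 2 --enable-expert-parallel
    #
    # With DP/EP enabled a prepare_finalize object exists, select_gemm_impl
    # is called, and FlashInferExperts is plugged into the modular kernel.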
def uses_weight_scale_2_pattern(self) -> bool:
"""
FP4 variants use 'weight_scale_2' pattern for per-tensor weight scales.
@ -842,8 +927,30 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
if scale_ndim == 2 else swizzled_scale.reshape(B, M, K))
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
# GEMM 1
# The FlashInfer Cutlass fused MoE kernel expects the combined weights
# to be ordered as [w3, w1], unlike the standard [w1, w3] layout.
gemm1_weight = layer.w13_weight.data
gemm1_weight_scale = layer.w13_weight_scale.data
if self.allow_flashinfer_cutlass:
dim = -2
size = gemm1_weight.size(dim)
assert size % 2 == 0, f"Expected even size in dim {dim}, got {size}"
half = size // 2
# Reorder weight
w1, w3 = gemm1_weight.split(half, dim=dim)
gemm1_weight = torch.cat([w3, w1], dim=dim).contiguous()
# Reorder scale
s1, s3 = gemm1_weight_scale.split(half, dim=dim)
gemm1_weight_scale = torch.cat([s3, s1], dim=dim).contiguous()
layer.w13_weight = Parameter(gemm1_weight, requires_grad=False)
layer.w13_weight_scale = Parameter(gemm1_weight_scale,
requires_grad=False)
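        # Toy illustration of the reorder above (shapes are assumed, not read
        # from a real checkpoint): for w13 of shape (E, 2 * N, K // 2), split
        # along dim=-2 into w1 = w13[:, :N, :] and w3 = w13[:, N:, :], then
        # store torch.cat([w3, w1], dim=-2) so the kernel reads the two
        # halves in its expected [w3, w1] order.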
if not torch.allclose(layer.w13_weight_scale_2[:, 0],
layer.w13_weight_scale_2[:, 1]):
logger.warning_once(
@ -874,9 +981,6 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
layer.w13_input_scale_quant = Parameter(
(1 / w13_input_scale).to(torch.float32), requires_grad=False)
layer.w13_weight = Parameter(layer.w13_weight.data,
requires_grad=False)
# GEMM 2
layer.g2_alphas = Parameter(
(layer.w2_input_scale * layer.w2_weight_scale_2).to(torch.float32),
@ -961,31 +1065,74 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
global_num_experts=global_num_experts,
expert_map=expert_map)
assert expert_map is None, ("Expert Parallelism / expert_map "
"is currently not supported for "
"ModelOptNvFp4FusedMoE.")
if self.fused_experts is None:
# If no modular kernel is provided, use cutlass_moe_fp4 for TP case
# only (no EP).
from vllm.model_executor.layers.fused_moe.cutlass_moe import (
cutlass_moe_fp4)
out = cutlass_moe_fp4(
a=x,
w1_fp4=layer.w13_weight,
w2_fp4=layer.w2_weight,
w1_blockscale=layer.w13_blockscale_swizzled,
w2_blockscale=layer.w2_blockscale_swizzled,
g1_alphas=layer.g1_alphas,
g2_alphas=layer.g2_alphas,
a1_gscale=layer.w13_input_scale_quant,
a2_gscale=layer.w2_input_scale_quant,
topk_weights=topk_weights,
topk_ids=topk_ids,
m=x.shape[0],
n=layer.w2_weight.shape[2] * 2,
k=x.shape[1],
e=layer.w13_weight.shape[0],
device=x.device,
expert_map=expert_map,
apply_router_weight_on_input=apply_router_weight_on_input)
else:
# TP or DP case
from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe import ( # noqa: E501
is_valid_flashinfer_cutlass_fused_moe)
assert is_valid_flashinfer_cutlass_fused_moe(
x, layer.w13_weight, layer.w2_weight), (
"Flashinfer CUTLASS Fused MoE not applicable!")
from vllm.model_executor.layers.fused_moe.cutlass_moe import (
cutlass_moe_fp4)
a1_gscale = torch.min(layer.w13_input_scale_quant)
a2_gscale = torch.min(layer.w2_input_scale_quant)
extra_expert_args = {
'g1_alphas': layer.g1_alphas,
'g2_alphas': layer.g2_alphas,
'out_dtype': x.dtype,
                # Named *_gscale to avoid confusion with a1_scale and
                # a2_scale, which are batch-size related.
'a1_gscale': a1_gscale,
'a2_gscale': a2_gscale,
}
extra_prepare_args = {
'use_dp': layer.dp_size > 1,
'local_tokens': x.shape[0],
'a1_gscale': a1_gscale,
}
extra_finalize_args = {
'use_dp': layer.dp_size > 1,
'local_tokens': x.shape[0],
}
# Cutlass moe takes in activations in BF16/Half precision
# and fp4 quantized weights loaded from the checkpoint
return cutlass_moe_fp4(
a=x,
w1_fp4=layer.w13_weight,
w1_blockscale=layer.w13_blockscale_swizzled,
w1_alphas=layer.g1_alphas,
w2_fp4=layer.w2_weight,
w2_blockscale=layer.w2_blockscale_swizzled,
w2_alphas=layer.g2_alphas,
topk_weights=topk_weights,
topk_ids=topk_ids,
m=x.shape[0],
n=layer.w2_weight.shape[2] * 2,
k=x.shape[1],
e=layer.w13_weight.shape[0],
a1_gscale=layer.w13_input_scale_quant,
a2_gscale=layer.w2_input_scale_quant,
device=x.device,
apply_router_weight_on_input=apply_router_weight_on_input).to(
x.dtype)
out = self.fused_experts(
hidden_states=x,
w1=layer.w13_weight,
w2=layer.w2_weight,
topk_weights=topk_weights,
topk_ids=topk_ids,
                inplace=False,  # TODO(shuw): fix later; output is currently high precision
activation=activation,
global_num_experts=global_num_experts,
expert_map=expert_map,
w1_scale=layer.w13_blockscale_swizzled,
w2_scale=layer.w2_blockscale_swizzled,
apply_router_weight_on_input=apply_router_weight_on_input,
extra_expert_args=extra_expert_args,
extra_prepare_args=extra_prepare_args,
extra_finalize_args=extra_finalize_args,
)
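        # Recap of the dispatch above: fused_experts is None -> plain TP path
        # via cutlass_moe_fp4; fused_experts set (FlashInfer CUTLASS modular
        # kernel from maybe_swap_experts_impl / select_gemm_impl) -> TP/DP
        # path via self.fused_experts(...).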
return out

107
vllm/utils/flashinfer.py Normal file
View File

@ -0,0 +1,107 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Compatibility wrapper for FlashInfer API changes.
Users of vLLM should always import **only** these wrappers.
"""
from __future__ import annotations
import contextlib
import functools
import importlib
import importlib.util
from typing import Any, Callable, NoReturn
from vllm.logger import init_logger
logger = init_logger(__name__)
@functools.cache
def has_flashinfer() -> bool:
"""Return ``True`` if FlashInfer is available."""
# Use find_spec to check if the module exists without importing it
# This avoids potential CUDA initialization side effects
return importlib.util.find_spec("flashinfer") is not None
def _missing(*_: Any, **__: Any) -> NoReturn:
"""Placeholder for unavailable FlashInfer backend."""
raise RuntimeError(
"FlashInfer backend is not available. Please install the package "
"to enable FlashInfer kernels: "
"https://github.com/flashinfer-ai/flashinfer")
def _get_submodule(module_name: str) -> Any | None:
"""Safely import a submodule and return it, or None if not available."""
try:
return importlib.import_module(module_name)
except (ImportError, ModuleNotFoundError):
return None
# General lazy import wrapper
def _lazy_import_wrapper(module_name: str,
attr_name: str,
fallback_fn: Callable[..., Any] = _missing):
"""Create a lazy import wrapper for a specific function."""
@functools.cache
def _get_impl():
if not has_flashinfer():
return None
mod = _get_submodule(module_name)
return getattr(mod, attr_name, None) if mod else None
def wrapper(*args, **kwargs):
impl = _get_impl()
if impl is None:
return fallback_fn(*args, **kwargs)
return impl(*args, **kwargs)
return wrapper
# Create lazy wrappers for each function
flashinfer_cutlass_fused_moe = _lazy_import_wrapper("flashinfer.fused_moe",
"cutlass_fused_moe")
fp4_quantize = _lazy_import_wrapper("flashinfer", "fp4_quantize")
fp4_swizzle_blockscale = _lazy_import_wrapper("flashinfer",
"fp4_swizzle_blockscale")
# Special case for autotune since it returns a context manager
autotune = _lazy_import_wrapper(
"flashinfer.autotuner",
"autotune",
fallback_fn=lambda *args, **kwargs: contextlib.nullcontext())
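# With this fallback, callers can unconditionally wrap tuned regions in the
# `autotune` context manager; when FlashInfer is not installed they simply
# get a no-op `contextlib.nullcontext()` instead of an import error.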
@functools.cache
def has_flashinfer_cutlass_fused_moe() -> bool:
"""Return ``True`` if FlashInfer CUTLASS fused MoE is available."""
if not has_flashinfer():
return False
# Check if all required functions are available
required_functions = [
("flashinfer.fused_moe", "cutlass_fused_moe"),
("flashinfer", "fp4_quantize"),
("flashinfer", "fp4_swizzle_blockscale"),
]
for module_name, attr_name in required_functions:
mod = _get_submodule(module_name)
if not mod or not hasattr(mod, attr_name):
return False
return True
__all__ = [
"has_flashinfer",
"has_flashinfer_cutlass_fused_moe",
"flashinfer_cutlass_fused_moe",
"fp4_quantize",
"fp4_swizzle_blockscale",
"autotune",
]
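# A minimal usage sketch of these wrappers (the call site below is an
# illustrative assumption, not part of this module; kernel arguments are
# elided):
#
#   from vllm.utils.flashinfer import (autotune,
#                                      flashinfer_cutlass_fused_moe,
#                                      has_flashinfer_cutlass_fused_moe)
#
#   if has_flashinfer_cutlass_fused_moe():
#       with autotune():
#           out = flashinfer_cutlass_fused_moe(...)
#
# If FlashInfer is missing, calling flashinfer_cutlass_fused_moe raises a
# RuntimeError pointing at the installation instructions rather than failing
# at import time.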