[Core] FlashInfer CUTLASS fused MoE backend (NVFP4) (#20037)
Signed-off-by: shuw <shuw@nvidia.com>
Signed-off-by: mgoin <mgoin64@gmail.com>
Co-authored-by: mgoin <mgoin64@gmail.com>
parent b38baabcf9
commit c7d8724e78
@@ -956,11 +956,11 @@ def cutlass_moe_mm(out_tensors: torch.Tensor, a_tensors: torch.Tensor,
                          c_strides, per_act_token, per_out_ch)


def cutlass_fp4_moe_mm(a_tensors: torch.Tensor, b_tensors: torch.Tensor,
                       a_scales: torch.Tensor, b_scales: torch.Tensor,
                       alphas: torch.Tensor, problem_sizes: torch.Tensor,
                       expert_offsets: torch.Tensor, sf_offsets: torch.Tensor,
                       out_dtype: torch.dtype, device: torch.device):
def cutlass_fp4_moe_mm(out_tensors: torch.Tensor, a_tensors: torch.Tensor,
                       b_tensors: torch.Tensor, a_scales: torch.Tensor,
                       b_scales: torch.Tensor, alphas: torch.Tensor,
                       problem_sizes: torch.Tensor,
                       expert_offsets: torch.Tensor, sf_offsets: torch.Tensor):
    """
    An FP4 Blockscaled Group Gemm that takes in a_tensors, b_tensors and runs
    the gemms for each combination based on the specified problem sizes.
@@ -977,14 +977,10 @@ def cutlass_fp4_moe_mm(a_tensors: torch.Tensor, b_tensors: torch.Tensor,
    - problem_sizes: MxNxK sizes of each expert's multiplication in two grouped
      MMs used in the fused MoE operation.
    """
    m_topk = a_tensors.shape[0]
    n = b_tensors.shape[1]
    c_shape = (m_topk, n)
    c = torch.empty(c_shape, device=device, dtype=out_dtype)
    torch.ops._C.cutlass_fp4_group_mm(c, a_tensors, b_tensors, a_scales,
                                      b_scales, alphas, problem_sizes,
                                      expert_offsets, sf_offsets)
    return c.to(out_dtype)
    return torch.ops._C.cutlass_fp4_group_mm(out_tensors, a_tensors, b_tensors,
                                              a_scales, b_scales, alphas,
                                              problem_sizes, expert_offsets,
                                              sf_offsets)
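Note: the signature change above moves allocation to the caller — `cutlass_fp4_moe_mm` now fills a preallocated `out_tensors` buffer instead of creating and returning one. A minimal sketch of the new call pattern follows; the tensor names and shapes are assumptions for illustration, not values taken from this diff.

    # Hypothetical call under the new in-place API (ops == vllm._custom_ops).
    out = torch.empty((m_topk, n), device=a_tensors.device, dtype=torch.bfloat16)
    ops.cutlass_fp4_moe_mm(out, a_tensors, b_tensors, a_scales, b_scales,
                           alphas, problem_sizes, expert_offsets, sf_offsets)
    # `out` now holds the grouped-GEMM results directly; the old version's
    # extra allocation and trailing dtype cast are gone.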
# aqlm

@@ -119,6 +119,7 @@ if TYPE_CHECKING:
    VLLM_TPU_BUCKET_PADDING_GAP: int = 0
    VLLM_TPU_MOST_MODEL_LEN: Optional[int] = None
    VLLM_USE_DEEP_GEMM: bool = False
    VLLM_USE_FLASHINFER_MOE: bool = False
    VLLM_XGRAMMAR_CACHE_MB: int = 0
    VLLM_MSGPACK_ZERO_COPY_THRESHOLD: int = 256
    VLLM_ALLOW_INSECURE_SERIALIZATION: bool = False
@@ -853,6 +854,10 @@ environment_variables: dict[str, Callable[[], Any]] = {
    "VLLM_USE_DEEP_GEMM":
    lambda: bool(int(os.getenv("VLLM_USE_DEEP_GEMM", "0"))),

    # Allow use of FlashInfer CUTLASS kernels for fused moe ops.
    "VLLM_USE_FLASHINFER_MOE":
    lambda: bool(int(os.getenv("VLLM_USE_FLASHINFER_MOE", "0"))),
    # Control the cache sized used by the xgrammar compiler. The default
    # of 512 MB should be enough for roughly 1000 JSON schemas.
    # It can be changed with this variable if needed for some reason.

@@ -1,6 +1,6 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Optional
from typing import Any, Optional

import torch

@@ -255,28 +255,18 @@ class BatchedDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
        output = (num_experts, max_num_tokens * num_dispatchers, K)
        return (workspace13, workspace2, output, a.dtype)

    def apply(
        self,
        output: torch.Tensor,
        hidden_states: torch.Tensor,
        w1: torch.Tensor,
        w2: torch.Tensor,
        topk_weights: torch.Tensor,
        topk_ids: torch.Tensor,
        activation: str,
        global_num_experts: int,
    def apply(self, output: torch.Tensor, hidden_states: torch.Tensor,
              w1: torch.Tensor, w2: torch.Tensor, topk_weights: torch.Tensor,
              topk_ids: torch.Tensor, activation: str, global_num_experts: int,
              expert_map: Optional[torch.Tensor],
              w1_scale: Optional[torch.Tensor],
              w2_scale: Optional[torch.Tensor],
              w1_zp: Optional[torch.Tensor],
              w2_zp: Optional[torch.Tensor],
              a1q_scale: Optional[torch.Tensor],
              a2_scale: Optional[torch.Tensor],
              workspace13: torch.Tensor,
              w2_scale: Optional[torch.Tensor], w1_zp: Optional[torch.Tensor],
              w2_zp: Optional[torch.Tensor], a1q_scale: Optional[torch.Tensor],
              a2_scale: Optional[torch.Tensor], workspace13: torch.Tensor,
              workspace2: torch.Tensor,
              expert_tokens_meta: Optional[mk.ExpertTokensMetadata],
              apply_router_weight_on_input: bool,
    ):
              extra_expert_args: Optional[dict[str, Any]]):
        assert expert_tokens_meta is not None
        expert_num_tokens = expert_tokens_meta.expert_num_tokens

@@ -1,6 +1,6 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Optional
from typing import Any, Optional

import torch

@@ -142,7 +142,8 @@ class BatchedTritonOrDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
              a2_scale: Optional[torch.Tensor], workspace13: torch.Tensor,
              workspace2: torch.Tensor,
              expert_tokens_meta: Optional[mk.ExpertTokensMetadata],
              apply_router_weight_on_input: bool):
              apply_router_weight_on_input: bool,
              extra_expert_args: Optional[dict[str, Any]]):
        experts = (self.batched_deep_gemm_experts
                   if self.allow_deep_gemm else self.batched_triton_experts)
        assert experts is not None
@@ -150,4 +151,4 @@ class BatchedTritonOrDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
                      activation, global_num_experts, expert_map, w1_scale,
                      w2_scale, w1_zp, w2_zp, a1q_scale, a2_scale, workspace13,
                      workspace2, expert_tokens_meta,
                      apply_router_weight_on_input)
                      apply_router_weight_on_input, extra_expert_args)

@@ -15,6 +15,7 @@ from vllm.logger import init_logger
from vllm.model_executor.layers.quantization.base_config import (
    QuantizationConfig)
from vllm.utils import cdiv
from vllm.utils.flashinfer import has_flashinfer_cutlass_fused_moe

logger = init_logger(__name__)

@@ -188,6 +189,11 @@ class FusedMoEParallelConfig:
        return (self.use_all2all_kernels
                and envs.VLLM_ALL2ALL_BACKEND == "deepep_low_latency")

    @property
    def use_flashinfer_cutlass_kernels(self):
        return (envs.VLLM_USE_FLASHINFER_MOE
                and has_flashinfer_cutlass_fused_moe())

    @staticmethod
    def make(tp_size_: int, dp_size_: int,
             vllm_parallel_config: ParallelConfig) -> "FusedMoEParallelConfig":
@@ -392,6 +398,10 @@ class FusedMoEConfig:
    def use_deepep_ll_kernels(self):
        return self.moe_parallel_config.use_deepep_ll_kernels

    @property
    def use_flashinfer_cutlass_kernels(self):
        return self.moe_parallel_config.use_flashinfer_cutlass_kernels

    @staticmethod
    def make(
        num_experts: int,
@@ -435,6 +445,12 @@ class FusedMoEConfig:
            if quant_dtype is None and isinstance(quant_config, Fp8Config):
                quant_dtype = torch.float8_e4m3fn

            from vllm.model_executor.layers.quantization.modelopt import (
                ModelOptNvFp4Config)
            if quant_dtype is None and isinstance(quant_config,
                                                  ModelOptNvFp4Config):
                quant_dtype = torch.uint8

            if weight_quant is not None:
                per_out_ch_quant = (
                    weight_quant.strategy == QuantizationStrategy.CHANNEL)

@@ -1,7 +1,7 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
""" CUTLASS based Fused MoE kernels."""
from typing import Callable, Optional
from typing import Any, Callable, Optional

import torch

@@ -14,7 +14,8 @@ from vllm.model_executor.layers.fused_moe.prepare_finalize import (
from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import (
    TopKWeightAndReduceDelegate)
from vllm.model_executor.layers.fused_moe.utils import (_fp8_quantize,
                                                        _resize_cache)
                                                        _resize_cache,
                                                        extract_required_args)
from vllm.scalar_type import scalar_types

logger = init_logger(__name__)
@@ -298,7 +299,8 @@ class CutlassExpertsFp8(mk.FusedMoEPermuteExpertsUnpermute):
              a2_scale: Optional[torch.Tensor], workspace13: torch.Tensor,
              workspace2: torch.Tensor,
              expert_tokens_meta: Optional[mk.ExpertTokensMetadata],
              apply_router_weight_on_input: bool):
              apply_router_weight_on_input: bool,
              extra_expert_args: Optional[dict[str, Any]]):
        assert w1_zp is None, "w1_zp is not supported in CUTLASS MoE"
        assert w2_zp is None, "w2_zp is not supported in CUTLASS MoE"

@@ -431,7 +433,9 @@ FLOAT4_E2M1_MAX = scalar_types.float4_e2m1f.max()
FLOAT8_E4M3_MAX = torch.finfo(torch.float8_e4m3fn).max


def cutlass_moe_fp4(a: torch.Tensor,
def run_cutlass_moe_fp4(
    output: torch.Tensor,
    a: torch.Tensor,
    a1_gscale: torch.Tensor,
    w1_fp4: torch.Tensor,
    w1_blockscale: torch.Tensor,
@@ -442,12 +446,15 @@ def cutlass_moe_fp4(a: torch.Tensor,
    w2_alphas: torch.Tensor,
    topk_weights: torch.Tensor,
    topk_ids: torch.Tensor,
    workspace13: torch.Tensor,
    workspace2: torch.Tensor,
    m: int,
    n: int,
    k: int,
    e: int,
    device: torch.device,
    apply_router_weight_on_input: bool = False):
    apply_router_weight_on_input: bool = False,
) -> None:
    """
    MoE implementation for FP4 Inputs

@@ -487,16 +494,16 @@ def cutlass_moe_fp4(a: torch.Tensor,

    assert (e_w1 == e_w2 and e_w1 == e), ("Number of experts must match",
                                          " between weights.")
    assert (k_a // 2 == half_k_w1
    assert (k_a == half_k_w1 * 2
            and k == k_w2), ("Hidden size mismatch between a, w1 and w2")
    assert (nx2_w1 == n * 2 and half_n_w2 == n // 2), ("mismatch in "
    assert (nx2_w1 == n * 2 and half_n_w2 * 2 == n), ("mismatch in "
                                                      "expected `n`")
    assert (m == m_a), "input shape mismatch"
    assert 2 * half_k_w1 == k_w2, "Hidden size mismatch w2 and w1"
    assert a.dtype in [torch.half, torch.bfloat16], "Invalid input dtype"
    assert (topk_weights.size(0) == m and topk_ids.size(0)
            == m), ("topk must be provided for each row of a")

    topk = topk_ids.size(1)
    out_dtype = a.dtype
    num_topk = topk_ids.size(1)

@@ -523,7 +530,6 @@ def cutlass_moe_fp4(a: torch.Tensor,
                                   blockscale_offsets)

    a = ops.shuffle_rows(a, a_map)

    rep_a_fp4, rep_a_blockscale = ops.scaled_fp4_experts_quant(
        a,
        a1_gscale,
@@ -531,34 +537,220 @@ def cutlass_moe_fp4(a: torch.Tensor,
        blockscale_offsets,
        num_topk,
    )

    c1 = ops.cutlass_fp4_moe_mm(rep_a_fp4, w1_fp4, rep_a_blockscale,
    c1 = _resize_cache(workspace13, (m * topk, n * 2))
    c2 = _resize_cache(workspace2, (m * topk, n))
    c3 = _resize_cache(workspace13, (m * topk, k))
    ops.cutlass_fp4_moe_mm(c1, rep_a_fp4, w1_fp4, rep_a_blockscale,
                           w1_blockscale, w1_alphas, problem_sizes1,
                           expert_offsets[:-1], blockscale_offsets[:-1],
                           out_dtype, device)
                           expert_offsets[:-1], blockscale_offsets[:-1])
    del rep_a_fp4, rep_a_blockscale
    # hidden size dimension is split to one half sized tensor.
    intermediate = torch.empty((m * num_topk, w1_fp4.size(1) // 2),
                               device=device,
                               dtype=out_dtype)

    torch.ops._C.silu_and_mul(intermediate, c1)

    torch.ops._C.silu_and_mul(c2, c1)
    int_fp4, int_blockscale = ops.scaled_fp4_experts_quant(
        intermediate, a2_gscale, expert_offsets, blockscale_offsets, num_topk)
        c2, a2_gscale, expert_offsets, blockscale_offsets, num_topk)

    c2 = ops.cutlass_fp4_moe_mm(int_fp4, w2_fp4, int_blockscale, w2_blockscale,
    ops.cutlass_fp4_moe_mm(c3, int_fp4, w2_fp4, int_blockscale, w2_blockscale,
                           w2_alphas, problem_sizes2, expert_offsets[:-1],
                           blockscale_offsets[:-1], out_dtype, device)
                           blockscale_offsets[:-1])
    del int_fp4, int_blockscale

    c2 = ops.shuffle_rows(c2, c_map)
    c3 = ops.shuffle_rows(c3, c_map)

    assert output.dtype == out_dtype
    if not apply_router_weight_on_input:
        out = (c2.view(m, num_topk, k) *
               topk_weights.view(m, num_topk, 1).to(out_dtype)).sum(dim=1)
        output.copy_(
            (c3.view(m, num_topk, k) *
             topk_weights.view(m, num_topk, 1).to(out_dtype)).sum(dim=1),
            non_blocking=True)
    else:
        out = c2.view(m, num_topk, k).sum(dim=1)
        return out.to(dtype=out_dtype)
        output.copy_(c3.view(m, num_topk, k).sum(dim=1), non_blocking=True)
    return
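The rewritten body no longer allocates `c1`/`intermediate`/result tensors per call; it carves views out of the caller-provided `workspace13`/`workspace2` buffers via `_resize_cache` and writes the final result into `output`. A small self-contained sketch of that aliasing pattern, with made-up sizes and a stand-in for `_resize_cache`:

    import torch

    def resize_cache(x: torch.Tensor, shape: tuple[int, ...]) -> torch.Tensor:
        # Same idea as _resize_cache: view a prefix of a flat workspace.
        numel = 1
        for d in shape:
            numel *= d
        return x.flatten()[:numel].view(*shape)

    m, topk, n, k = 4, 2, 256, 512
    workspace13 = torch.empty(m * topk * max(2 * n, k), dtype=torch.bfloat16)
    c1 = resize_cache(workspace13, (m * topk, 2 * n))  # first grouped GEMM out
    c3 = resize_cache(workspace13, (m * topk, k))      # second GEMM reuses it
    # c1 and c3 alias the same storage; this is safe because c1 is fully
    # consumed (silu_and_mul writes into workspace2) before c3 is written.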

class CutlassExpertsFp4(mk.FusedMoEPermuteExpertsUnpermute):

    def __init__(
        self,
        max_experts_per_worker: int,
        out_dtype: torch.dtype,
        per_act_token_quant: bool,
        per_out_ch_quant: bool,
        block_shape: Optional[list[int]] = None,
        use_batched_format: bool = False,
    ):
        super().__init__(
            FusedMoEQuantConfig(
                quant_dtype=torch.uint8,
                per_act_token_quant=per_act_token_quant,
                per_out_ch_quant=per_out_ch_quant,
                block_shape=block_shape,
            ))
        self.max_experts_per_worker = max_experts_per_worker
        self.out_dtype = out_dtype
        self.use_batched_format = use_batched_format

    @property
    def activation_formats(
        self
    ) -> tuple[mk.FusedMoEActivationFormat, mk.FusedMoEActivationFormat]:
        if self.use_batched_format:
            return (mk.FusedMoEActivationFormat.BatchedExperts,
                    mk.FusedMoEActivationFormat.BatchedExperts)
        else:
            return (mk.FusedMoEActivationFormat.Standard,
                    mk.FusedMoEActivationFormat.Standard)

    def supports_expert_map(self) -> bool:
        return False

    def supports_chunking(self) -> bool:
        return True

    def finalize_weight_and_reduce_impl(self) -> mk.TopKWeightAndReduce:
        # Let PrepareAndFinalize::finalize() decide the impl.
        return TopKWeightAndReduceDelegate()

    def workspace_shapes(
        self,
        a: torch.Tensor,
        aq: torch.Tensor,
        M: int,
        N: int,
        K: int,
        topk: int,
        global_num_experts: int,
        local_num_experts: int,
        expert_tokens_meta: Optional[mk.ExpertTokensMetadata],
    ) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...], torch.dtype]:
        workspace1: tuple[int, ...] = ()
        workspace2: tuple[int, ...] = ()
        output: tuple[int, ...] = ()
        if self.use_batched_format:
            padded_M = aq.size(1)
            workspace1 = (self.max_experts_per_worker, padded_M, max(N, K))
            workspace2 = (self.max_experts_per_worker, padded_M, (N // 2))
            output = (self.max_experts_per_worker, padded_M, K)
        else:
            workspace1 = (M * topk, max(2 * N, K))
            workspace2 = (M * topk, N)
            output = (M, K)
        return (workspace1, workspace2, output,
                self.out_dtype if self.out_dtype is not None else a.dtype)

    def apply(self, output: torch.Tensor, hidden_states: torch.Tensor,
              w1: torch.Tensor, w2: torch.Tensor, topk_weights: torch.Tensor,
              topk_ids: torch.Tensor, activation: str, global_num_experts: int,
              expert_map: Optional[torch.Tensor], w1_scale: torch.Tensor,
              w2_scale: torch.Tensor, w1_zp: Optional[torch.Tensor],
              w2_zp: Optional[torch.Tensor], a1q_scale: Optional[torch.Tensor],
              a2_scale: torch.Tensor, workspace13: Optional[torch.Tensor],
              workspace2: Optional[torch.Tensor],
              expert_tokens_meta: Optional[mk.ExpertTokensMetadata],
              apply_router_weight_on_input: bool,
              extra_expert_args: Optional[dict[str, Any]]):
        required_keys = [
            "g1_alphas", "g2_alphas", "a1_gscale", "a2_gscale", "m", "n", "k",
            "e", "device"
        ]
        (g1_alphas, g2_alphas, a1_gscale, a2_gscale, m, n, k, e,
         device) = extract_required_args(extra_expert_args, required_keys)
        run_cutlass_moe_fp4(
            output=output,
            a=hidden_states,
            a1_gscale=a1_gscale,
            w1_fp4=w1,
            w1_blockscale=w1_scale,
            w1_alphas=g1_alphas,
            a2_gscale=a2_gscale,
            w2_fp4=w2,
            w2_blockscale=w2_scale,
            w2_alphas=g2_alphas,
            topk_weights=topk_weights,
            topk_ids=topk_ids,
            workspace13=workspace13,
            workspace2=workspace2,
            m=m,
            n=n,
            k=k,
            e=e,
            device=device,
            apply_router_weight_on_input=apply_router_weight_on_input,
        )


def cutlass_moe_fp4(
    a: torch.Tensor,
    w1_fp4: torch.Tensor,
    w2_fp4: torch.Tensor,
    w1_blockscale: torch.Tensor,
    w2_blockscale: torch.Tensor,
    g1_alphas: torch.Tensor,
    g2_alphas: torch.Tensor,
    a1_gscale: torch.Tensor,
    a2_gscale: torch.Tensor,
    topk_weights: torch.Tensor,
    topk_ids: torch.Tensor,
    m: int,
    n: int,
    k: int,
    e: int,
    device: torch.device,
    expert_map: Optional[torch.Tensor] = None,
    apply_router_weight_on_input: bool = False) -> torch.Tensor:
    assert expert_map is None, ("Expert Parallelism / expert_map "
                                "is currently not supported for "
                                "ModelOptNvFp4FusedMoE's cutlass_moe_fp4.")
    fn = mk.FusedMoEModularKernel(
        MoEPrepareAndFinalizeNoEP(),
        CutlassExpertsFp4(
            max_experts_per_worker=e,
            out_dtype=a.dtype,
            per_act_token_quant=False,
            per_out_ch_quant=False,
            use_batched_format=False,
        ),
    )
    extra_expert_args = {
        'g1_alphas': g1_alphas,
        'g2_alphas': g2_alphas,
        'a1_gscale': a1_gscale,
        'a2_gscale': a2_gscale,
        'm': m,
        'n': n,
        'k': k,
        'e': e,
        'device': device,
    }

    # NVFP4 requires two levels of quantization, which involves computing some
    # scaling factors dynamically. This makes it incompatible with the typical
    # prepare -> MoE -> finalize pipeline. Move the quantization logic into the
    # MoE body.
    extra_prepare_args = {
        'skip_quant': True,
    }
    # Similar reason as above.
    extra_finalize_args = {
        'skip_weight_reduce': True,
    }
    return fn(
        hidden_states=a,
        w1=w1_fp4,
        w2=w2_fp4,
        topk_weights=topk_weights,
        topk_ids=topk_ids,
        inplace=False,
        activation="silu",
        global_num_experts=e,
        expert_map=None,
        w1_scale=w1_blockscale,
        w2_scale=w2_blockscale,
        a1_scale=None,
        a2_scale=None,
        apply_router_weight_on_input=apply_router_weight_on_input,
        extra_expert_args=extra_expert_args,
        extra_prepare_args=extra_prepare_args,
        extra_finalize_args=extra_finalize_args,
    )
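`cutlass_moe_fp4` now just assembles a `FusedMoEModularKernel` from `MoEPrepareAndFinalizeNoEP` and `CutlassExpertsFp4`, threading the NVFP4-specific scales and problem sizes through `extra_expert_args` instead of widening the shared interfaces. A hedged sketch of a call site (all tensors are assumed to come from the ModelOpt NVFP4 checkpoint-loading path, which is not reproduced here):

    out = cutlass_moe_fp4(
        a=hidden_states,               # [m, k] half/bf16 activations
        w1_fp4=w1, w2_fp4=w2,          # packed uint8 NVFP4 expert weights
        w1_blockscale=w1_bs, w2_blockscale=w2_bs,
        g1_alphas=g1, g2_alphas=g2,
        a1_gscale=a1_gs, a2_gscale=a2_gs,
        topk_weights=topk_weights, topk_ids=topk_ids,
        m=m, n=n, k=k, e=num_experts,
        device=hidden_states.device,
    )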

def _valid_cutlass_block_scaled_grouped_gemm(

@@ -1,7 +1,7 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import functools
from typing import Optional
from typing import Any, Optional

import torch

@@ -152,6 +152,7 @@ class DeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
        workspace2: torch.Tensor,
        expert_tokens_meta: Optional[mk.ExpertTokensMetadata],
        apply_router_weight_on_input: bool,
        extra_expert_args: Optional[dict[str, Any]],
    ):
        assert self.block_shape is not None
        assert a1q_scale is not None

@@ -1,6 +1,6 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Optional
from typing import Any, Optional

import deep_ep
import torch
@@ -127,16 +127,12 @@ class DeepEPHTPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
                expert_topk_weights)

    def prepare(
        self,
        a1: torch.Tensor,
        a1_scale: Optional[torch.Tensor],
        a2_scale: Optional[torch.Tensor],
        topk_weights: torch.Tensor,
        topk_ids: torch.Tensor,
        num_experts: int,
        expert_map: Optional[torch.Tensor],
        apply_router_weight_on_input: bool,
        self, a1: torch.Tensor, a1_scale: Optional[torch.Tensor],
        a2_scale: Optional[torch.Tensor], topk_weights: torch.Tensor,
        topk_ids: torch.Tensor, num_experts: int,
        expert_map: Optional[torch.Tensor], apply_router_weight_on_input: bool,
        quant_config: FusedMoEQuantConfig,
        extra_prepare_args: Optional[dict[str, Any]]
    ) -> tuple[torch.Tensor, Optional[torch.Tensor],
               Optional[mk.ExpertTokensMetadata], Optional[torch.Tensor],
               Optional[torch.Tensor]]:
@@ -191,7 +187,8 @@ class DeepEPHTPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
    def finalize(self, output: torch.Tensor, fused_expert_output: torch.Tensor,
                 topk_weights: torch.Tensor, topk_ids: torch.Tensor,
                 apply_router_weight_on_input: bool,
                 weight_and_reduce_impl: mk.TopKWeightAndReduce) -> None:
                 weight_and_reduce_impl: mk.TopKWeightAndReduce,
                 extra_finalize_args: Optional[dict[str, Any]]) -> None:

        assert self.handle is not None


@@ -1,6 +1,6 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Optional, Union
from typing import Any, Optional, Union

import deep_ep
import torch
@@ -111,16 +111,12 @@ class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
        return x, x_scales

    def prepare(
        self,
        a1: torch.Tensor,
        a1_scale: Optional[torch.Tensor],
        a2_scale: Optional[torch.Tensor],
        topk_weights: torch.Tensor,
        topk_ids: torch.Tensor,
        num_experts: int,
        expert_map: Optional[torch.Tensor],
        apply_router_weight_on_input: bool,
        self, a1: torch.Tensor, a1_scale: Optional[torch.Tensor],
        a2_scale: Optional[torch.Tensor], topk_weights: torch.Tensor,
        topk_ids: torch.Tensor, num_experts: int,
        expert_map: Optional[torch.Tensor], apply_router_weight_on_input: bool,
        quant_config: FusedMoEQuantConfig,
        extra_prepare_args: Optional[dict[str, Any]]
    ) -> tuple[torch.Tensor, Optional[torch.Tensor],
               Optional[mk.ExpertTokensMetadata], Optional[torch.Tensor],
               Optional[torch.Tensor]]:
@@ -169,7 +165,8 @@ class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
    def finalize(self, output: torch.Tensor, fused_expert_output: torch.Tensor,
                 topk_weights: torch.Tensor, topk_ids: torch.Tensor,
                 apply_router_weight_on_input: bool,
                 weight_and_reduce_impl: mk.TopKWeightAndReduce) -> None:
                 weight_and_reduce_impl: mk.TopKWeightAndReduce,
                 extra_finalize_args: Optional[dict[str, Any]]) -> None:
        assert isinstance(
            weight_and_reduce_impl, TopKWeightAndReduceDelegate
        ), ("Weight application and reduction happens in the combine kernel.")

vllm/model_executor/layers/fused_moe/flashinfer_cutlass_moe.py (new file, 198 lines)
@@ -0,0 +1,198 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Any, Optional

import torch

import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig
from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import (
    TopKWeightAndReduceDelegate)
from vllm.model_executor.layers.fused_moe.utils import extract_required_args
from vllm.utils.flashinfer import (flashinfer_cutlass_fused_moe,
                                   has_flashinfer_cutlass_fused_moe)

logger = init_logger(__name__)


def is_valid_flashinfer_cutlass_fused_moe(hidden_states: torch.Tensor,
                                          w1: torch.Tensor,
                                          w2: torch.Tensor) -> bool:
    """
    Check if the given problem size is supported by the FlashInfer CUTLASS MoE
    kernel.
    """
    if not has_flashinfer_cutlass_fused_moe():
        logger.debug_once("FlashInferExperts disabled: "
                          "flashinfer_cutlass_fused_moe not available.")
        return False
    # Data type checks
    if (w1.dtype != torch.uint8 or w2.dtype != torch.uint8
            or hidden_states.dtype
            not in [torch.float32, torch.float16, torch.bfloat16]):
        logger.debug_once(
            "FlashInferExperts disabled: w1/w2 must be torch.uint8 "
            f"(got w1={w1.dtype}, w2={w2.dtype}), hidden_states must be "
            f"float32, float16, or bfloat16 (got {hidden_states.dtype}).")
        return False
    return True
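Callers are expected to gate backend selection on this predicate. A hedged sketch (the shapes below are placeholders; only the dtype checks above matter to the function):

    import torch

    hidden_states = torch.randn(4, 256, dtype=torch.bfloat16)
    w1_packed = torch.empty(8, 2 * 512, 256 // 2, dtype=torch.uint8)
    w2_packed = torch.empty(8, 256, 512 // 2, dtype=torch.uint8)

    use_flashinfer = is_valid_flashinfer_cutlass_fused_moe(
        hidden_states, w1_packed, w2_packed)
    # False when the flashinfer package is missing or the dtypes do not match
    # what the CUTLASS NVFP4 kernel expects; callers then fall back to another
    # fused-MoE implementation.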

class FlashInferExperts(mk.FusedMoEPermuteExpertsUnpermute):

    def __init__(
        self,
        use_nvfp4_w4a4: bool = False,
        use_fp8_w8a8: bool = False,
        use_dp: bool = False,
        ep_rank: int = 0,
        ep_size: int = 1,
        tp_rank: int = 0,
        tp_size: int = 1,
        num_dispatchers: Optional[int] = None,
        use_batched_format: bool = False,
    ):
        super().__init__(
            FusedMoEQuantConfig(
                quant_dtype=torch.uint8,
                per_act_token_quant=False,
                block_shape=None,
            ))
        self.use_nvfp4_w4a4 = use_nvfp4_w4a4
        self.use_fp8_w8a8 = use_fp8_w8a8
        self.ep_rank = ep_rank
        self.ep_size = ep_size
        self.tp_rank = tp_rank
        self.tp_size = tp_size
        self.use_dp = use_dp
        assert not use_batched_format or num_dispatchers is not None
        self.num_dispatchers = num_dispatchers

    @property
    def activation_formats(
        self
    ) -> tuple[mk.FusedMoEActivationFormat, mk.FusedMoEActivationFormat]:
        return (mk.FusedMoEActivationFormat.Standard,
                mk.FusedMoEActivationFormat.Standard)

    def supports_expert_map(self) -> bool:
        return False

    def supports_chunking(self) -> bool:
        # This refers to TP chunking; DP chunking is handled separately.
        return True

    def finalize_weight_and_reduce_impl(self) -> mk.TopKWeightAndReduce:
        # Let PrepareAndFinalize::finalize() decide the impl.
        return TopKWeightAndReduceDelegate()

    def workspace_shapes(
        self,
        a: torch.Tensor,
        aq: torch.Tensor,
        M: int,
        N: int,
        K: int,
        topk: int,
        global_num_experts: int,
        local_num_experts: int,
        expert_tokens_meta: Optional[mk.ExpertTokensMetadata],
    ) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...], torch.dtype]:
        # We use global_num_experts due to how moe_align_block_size handles
        # expert_maps.
        """
        Compute the shapes for the temporary and final outputs of the two gemms
        and activation in the fused expert function. Since the gemms are
        independent, the workspace for the first gemm can be shared with the
        workspace for the last gemm.

        Returns a tuple of:
        - workspace13 shape tuple: must be large enough to hold the
          result of either expert gemm.
        - workspace2 shape tuple: must be large enough to hold the
          result of the activation function.
        - output shape tuple: must be exact size of the final gemm output.
        - Workspace type: The dtype to use for the workspace tensors.
        - Note: in order for activation chunking to work, the first dimension
          of each tuple must be the number of tokens.
        """
        assert self.use_nvfp4_w4a4 is True, ("Only nvfp4 quantization is "
                                             "currently supported.")
        aq_m, aq_n = aq.shape
        workspace2 = ()
        output_shape = (aq_m, aq_n * 2)
        workspace_dtype = a.dtype
        workspace1 = output_shape
        # The workspace is determined by `aq`, since it comes after any
        # potential communication op and is involved in the expert computation.
        return (workspace1, workspace2, output_shape, workspace_dtype)
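For NVFP4 the quantized activation `aq` packs two 4-bit values per uint8 element, so the unquantized hidden size is recovered as `aq_n * 2`. A worked example of the shape arithmetic with made-up numbers:

    # A bf16 activation of shape (M, K) = (16, 4096) quantizes to NVFP4 as a
    # uint8 tensor of shape (16, 2048): two FP4 values per byte.
    aq_m, aq_n = 16, 2048
    output_shape = (aq_m, aq_n * 2)   # -> (16, 4096), back to the model dim
    workspace1 = output_shape         # shared with the final GEMM output
    workspace2 = ()                    # no separate activation workspace needed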

    def apply(
        self,
        output: torch.Tensor,
        hidden_states: torch.Tensor,
        w1: torch.Tensor,
        w2: torch.Tensor,
        topk_weights: torch.Tensor,
        topk_ids: torch.Tensor,
        activation: str,
        global_num_experts: int,
        expert_map: Optional[torch.Tensor],
        w1_scale: Optional[torch.Tensor],
        w2_scale: Optional[torch.Tensor],
        w1_zp: Optional[torch.Tensor],
        w2_zp: Optional[torch.Tensor],
        a1q_scale: Optional[torch.Tensor],
        a2_scale: Optional[torch.Tensor],  # Not used
        workspace13: Optional[torch.Tensor],
        workspace2: Optional[torch.Tensor],
        expert_tokens_meta: Optional[mk.ExpertTokensMetadata],
        apply_router_weight_on_input: Optional[bool],
        extra_expert_args: Optional[dict[str, Any]],
    ):
        assert extra_expert_args is not None, \
            "extra_expert_args must be provided"
        required_keys = [
            'g1_alphas', 'g2_alphas', 'a1_gscale', 'a2_gscale', 'out_dtype'
        ]

        g1_alphas, g2_alphas, a1_gscale, a2_gscale, out_dtype = (
            extract_required_args(extra_expert_args, required_keys))

        # Flashinfer CUTLASS kernel takes scalar global scales,
        # min because inv_scale.
        assert self.use_nvfp4_w4a4 is True, ("Only nvfp4 quantization is "
                                             "currently supported.")

        # Ensure w1_scale and w2_scale are not None before calling view
        assert w1_scale is not None and w2_scale is not None, (
            "w1_scale and w2_scale must not "
            "be None for FlashInferExperts")

        assert not apply_router_weight_on_input

        quant_scales = [
            a1_gscale,
            w1_scale.view(torch.int32),
            g1_alphas,
            a2_gscale,
            w2_scale.view(torch.int32),
            g2_alphas,
        ]
        _ = flashinfer_cutlass_fused_moe(
            hidden_states,
            topk_ids.to(torch.int),
            topk_weights,
            # FlashInfer API requires weight to be long for nvfp4
            w1.view(torch.long),
            w2.view(torch.long),
            output_dtype=out_dtype,
            quant_scales=quant_scales,
            input_sf=a1q_scale,
            tp_size=self.tp_size,
            tp_rank=self.tp_rank,
            ep_size=self.ep_size,
            ep_rank=self.ep_rank,
            output=output,
        )
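The `view(torch.int32)` / `view(torch.long)` calls above reinterpret the packed NVFP4 buffers without copying, which is the form the FlashInfer entry point consumes. A small sketch of that reinterpretation on placeholder tensors:

    import torch

    # view(dtype) reinterprets the underlying bytes in place: 8 packed uint8
    # NVFP4 values become one int64 element, 4 scale bytes become one int32.
    w_packed = torch.zeros(64, dtype=torch.uint8)      # 64 bytes of FP4 data
    w_as_long = w_packed.view(torch.long)              # shape (8,), same memory
    scale_bytes = torch.zeros(16, dtype=torch.uint8)   # block-scale bytes
    scale_i32 = scale_bytes.view(torch.int32)          # shape (4,), same memory
    assert w_as_long.numel() * 8 == w_packed.numel()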
vllm/model_executor/layers/fused_moe/flashinfer_cutlass_prepare_finalize.py (new file, 114 lines)
@@ -0,0 +1,114 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Any, Optional

import torch

import vllm.envs as envs
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from vllm.distributed import get_dp_group
from vllm.forward_context import get_forward_context
from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig
from vllm.model_executor.layers.fused_moe.utils import (
    extract_required_args, moe_kernel_quantize_input)
from vllm.utils.flashinfer import fp4_swizzle_blockscale


def get_local_sizes(local_tokens):
    cu_sizes = get_forward_context().dp_metadata.cu_tokens_across_dp_cpu
    sizes = [cu_sizes[0].item()]
    for i in range(1, len(cu_sizes)):
        sizes.append((cu_sizes[i] - cu_sizes[i - 1]).item())
    max_num_tokens = envs.VLLM_MOE_DP_CHUNK_SIZE
    sizes_chunked = [max_num_tokens] * len(sizes)
    if local_tokens < max_num_tokens:
        # When the number of local tokens is less than max_num_tokens, all other
        # ranks will also have fewer than max_num_tokens. The remaining tokens
        # are accounted for as residual.
        sizes_chunked = [x % max_num_tokens for x in sizes]

    return sizes_chunked
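A worked example of the chunk-size logic above, with made-up numbers. Suppose four DP ranks report cumulative token counts [300, 700, 1100, 1450] and VLLM_MOE_DP_CHUNK_SIZE is 256:

    cu = [300, 700, 1100, 1450]
    sizes = [cu[0]] + [cu[i] - cu[i - 1] for i in range(1, len(cu))]
    # sizes == [300, 400, 400, 350] tokens per rank
    max_num_tokens = 256
    # While every rank still has a full chunk left, each step gathers 256:
    full_step = [max_num_tokens] * len(sizes)             # [256, 256, 256, 256]
    # On the residual step each rank contributes its remainder:
    residual_step = [x % max_num_tokens for x in sizes]   # [44, 144, 144, 94]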

class FlashInferCutlassMoEPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):

    def __init__(
        self,
        quant_dtype: Optional[torch.dtype] = None,
        per_channel_quant: bool = False,
        block_shape: Optional[list[int]] = None,
        num_dispatchers: int = 1,
    ):
        super().__init__()
        self.per_channel_quant = per_channel_quant
        self.block_shape = block_shape
        self.quant_dtype = quant_dtype
        self.num_dispatchers_ = num_dispatchers

    @property
    def activation_format(self) -> mk.FusedMoEActivationFormat:
        return mk.FusedMoEActivationFormat.Standard

    def max_num_tokens_per_rank(self) -> Optional[int]:
        return None

    def topk_indices_dtype(self) -> Optional[torch.dtype]:
        return None

    def num_dispatchers(self) -> int:
        return self.num_dispatchers_

    def prepare(
        self,
        a1: torch.Tensor,
        a1_scale: Optional[torch.Tensor],  # Not used
        a2_scale: Optional[torch.Tensor],  # Not used
        topk_weights: torch.Tensor,
        topk_ids: torch.Tensor,
        num_experts: int,
        expert_map: Optional[torch.Tensor],
        apply_router_weight_on_input: bool,
        quant_config: FusedMoEQuantConfig,
        extra_prepare_args: Optional[dict[str, Any]]
    ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor],
               Optional[torch.Tensor], Optional[torch.Tensor]]:

        assert not apply_router_weight_on_input

        (a1_gscale, use_dp, local_tokens) = extract_required_args(
            extra_prepare_args, ['a1_gscale', 'use_dp', 'local_tokens'])

        a1q, a1q_scale = moe_kernel_quantize_input(
            a1,
            a1_gscale,
            quant_config.quant_dtype,
            self.per_channel_quant,
            self.block_shape,
            is_fp4_scale_swizzled=not use_dp,  # Swizzling after communication
        )
        if use_dp:
            topk_weights, topk_ids, a1q, a1q_scale = \
                get_dp_group().all_gatherv([topk_weights, topk_ids, a1q, a1q_scale],  # noqa: E501
                                           dim=0,
                                           sizes=get_local_sizes(local_tokens))
            a1_m, a1_n = a1q.shape
            a1q_scale = fp4_swizzle_blockscale(a1q_scale, a1_m, a1_n * 2)

        return a1q, a1q_scale, None, topk_ids, topk_weights

    def finalize(self, output: torch.Tensor, fused_expert_output: torch.Tensor,
                 topk_weights: torch.Tensor, topk_ids: torch.Tensor,
                 apply_router_weight_on_input: bool,
                 weight_and_reduce_impl: mk.TopKWeightAndReduce,
                 extra_finalize_args: Optional[dict[str, Any]]) -> None:

        (use_dp,
         local_tokens) = extract_required_args(extra_finalize_args,
                                               ['use_dp', 'local_tokens'])
        if use_dp:
            fused_expert_output = get_dp_group().reduce_scatterv(
                fused_expert_output,
                dim=0,
                sizes=get_local_sizes(local_tokens),
            )
        output.copy_(fused_expert_output)
@@ -1,7 +1,7 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Fused batched MoE kernel."""
from typing import Optional
from typing import Any, Optional

import torch

@@ -496,16 +496,12 @@ class BatchedPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
        return self.num_dispatchers_

    def prepare(
        self,
        a1: torch.Tensor,
        a1_scale: Optional[torch.Tensor],
        a2_scale: Optional[torch.Tensor],
        topk_weights: torch.Tensor,
        topk_ids: torch.Tensor,
        num_experts: int,
        expert_map: Optional[torch.Tensor],
        apply_router_weight_on_input: bool,
        self, a1: torch.Tensor, a1_scale: Optional[torch.Tensor],
        a2_scale: Optional[torch.Tensor], topk_weights: torch.Tensor,
        topk_ids: torch.Tensor, num_experts: int,
        expert_map: Optional[torch.Tensor], apply_router_weight_on_input: bool,
        quant_config: FusedMoEQuantConfig,
        extra_prepare_args: Optional[dict[str, Any]]
    ) -> tuple[torch.Tensor, Optional[torch.Tensor],
               Optional[mk.ExpertTokensMetadata], Optional[torch.Tensor],
               Optional[torch.Tensor]]:
@@ -594,15 +590,11 @@ class BatchedPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):

        return b_a1, b_a1_scale, expert_tokens_meta, None, None

    def finalize(
        self,
        output: torch.Tensor,
        fused_expert_output: torch.Tensor,
        topk_weights: torch.Tensor,
        topk_ids: torch.Tensor,
    def finalize(self, output: torch.Tensor, fused_expert_output: torch.Tensor,
                 topk_weights: torch.Tensor, topk_ids: torch.Tensor,
                 apply_router_weight_on_input: bool,
                 weight_and_reduce_impl: mk.TopKWeightAndReduce,
    ) -> None:
                 extra_finalize_args: Optional[dict[str, Any]]) -> None:
        if isinstance(weight_and_reduce_impl, TopKWeightAndReduceDelegate):
            weight_and_reduce_impl = TopKWeightAndReduceNaiveBatched(self.rank)
        weight_and_reduce_impl.apply(
@@ -706,7 +698,8 @@ class NaiveBatchedExperts(mk.FusedMoEPermuteExpertsUnpermute):
              a2_scale: Optional[torch.Tensor], workspace13: torch.Tensor,
              workspace2: torch.Tensor,
              expert_tokens_meta: Optional[mk.ExpertTokensMetadata],
              apply_router_weight_on_input: bool):
              apply_router_weight_on_input: bool,
              extra_expert_args: Optional[dict[str, Any]]):
        assert hidden_states.dim() == 3
        assert expert_tokens_meta is not None
        expert_num_tokens = expert_tokens_meta.expert_num_tokens
@@ -911,7 +904,8 @@ class BatchedTritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
              a2_scale: Optional[torch.Tensor], workspace13: torch.Tensor,
              workspace2: torch.Tensor,
              expert_tokens_meta: Optional[mk.ExpertTokensMetadata],
              apply_router_weight_on_input: bool):
              apply_router_weight_on_input: bool,
              extra_expert_args: Optional[dict[str, Any]]):
        # Check constraints.
        if self.use_int4_w4a16:
            assert hidden_states.size(-1) // 2 == w1.size(2), (

@@ -1646,6 +1646,7 @@ class TritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
        workspace2: torch.Tensor,
        expert_tokens_meta: Optional[mk.ExpertTokensMetadata],
        apply_router_weight_on_input: bool,
        extra_expert_args: Optional[dict[str, Any]],
    ):
        # Check constraints.
        if self.use_int4_w4a16:

@@ -34,6 +34,7 @@ from vllm.model_executor.utils import set_weight_attrs
from vllm.platforms import current_platform
from vllm.platforms.interface import CpuArchEnum
from vllm.utils import direct_register_custom_op, has_deep_ep, has_pplx
from vllm.utils.flashinfer import has_flashinfer

if current_platform.is_cuda_alike():
    from .fused_batched_moe import BatchedTritonExperts
@@ -45,6 +46,9 @@ if current_platform.is_cuda_alike():
    from .deepep_ht_prepare_finalize import DeepEPHTPrepareAndFinalize
    from .deepep_ll_prepare_finalize import (DEEPEP_QUANT_BLOCK_SHAPE,
                                             DeepEPLLPrepareAndFinalize)
    if has_flashinfer():
        from .flashinfer_cutlass_prepare_finalize import (
            FlashInferCutlassMoEPrepareAndFinalize)
else:
    fused_experts = None  # type: ignore
    FusedMoEPermuteExpertsUnpermute = None  # type: ignore
@@ -99,6 +103,9 @@ class FusedMoEMethodBase(QuantizeMethodBase):

        prepare_finalize: Optional[FusedMoEPrepareAndFinalize] = None

        if moe.use_flashinfer_cutlass_kernels:
            prepare_finalize = FlashInferCutlassMoEPrepareAndFinalize(
                quant_dtype=moe.quant_dtype, )
        if moe.use_pplx_kernels:
            hidden_dim_bytes, hidden_scale_bytes = pplx_hidden_dim_scale_bytes(
                moe.max_num_tokens,
@@ -204,6 +211,12 @@ class FusedMoEMethodBase(QuantizeMethodBase):
            f"{self.__class__.__name__} must select appropriate gemm "
            "implementation based on the prepare_finalize")

    def maybe_swap_experts_impl(
        self,
        moe_parallel_config: FusedMoEParallelConfig,
    ):
        pass

    @abstractmethod
    def apply(
        self,
@@ -744,12 +757,15 @@ class FusedMoE(torch.nn.Module):
            moe_quant_params["intermediate_size_full"] = intermediate_size

        self.quant_method.create_weights(layer=self, **moe_quant_params)
        if isinstance(self.quant_method, FusedMoEMethodBase):
            self.quant_method.maybe_swap_experts_impl(self.moe_parallel_config)

        # Chunked all2all staging tensor
        self.batched_hidden_states: Optional[torch.Tensor] = None
        self.batched_router_logits: Optional[torch.Tensor] = None
        if (self.moe_parallel_config.use_pplx_kernels
                or self.moe_parallel_config.use_deepep_ll_kernels):
                or self.moe_parallel_config.use_deepep_ll_kernels
                or self.moe_parallel_config.use_flashinfer_cutlass_kernels):
            self.batched_hidden_states = torch.zeros(
                (moe.max_num_tokens, self.hidden_size),
                dtype=moe.in_dtype,
@@ -801,6 +817,10 @@ class FusedMoE(torch.nn.Module):
    def use_deepep_ll_kernels(self):
        return self.moe_parallel_config.use_deepep_ll_kernels

    @property
    def use_flashinfer_cutlass_kernels(self):
        return self.moe_parallel_config.use_flashinfer_cutlass_kernels

    def _load_per_tensor_weight_scale(self, shard_id: str,
                                      param: torch.nn.Parameter,
                                      loaded_weight: torch.Tensor,
@@ -1402,9 +1422,9 @@ class FusedMoE(torch.nn.Module):
                final_hidden_states, non_blocking=True)

        ctx = get_forward_context()
        # flashinfer_cutlass_kernels can handle: optional DP + TP/EP
        max_tokens_across_dp = ctx.dp_metadata.max_tokens_across_dp_cpu
        moe_dp_chunk_size_per_rank = self.moe_config.max_num_tokens

        num_tokens = full_hidden_states.size(0)
        for chunk_start_ in range(0, max_tokens_across_dp,
                                  moe_dp_chunk_size_per_rank):
@@ -1424,13 +1444,20 @@ class FusedMoE(torch.nn.Module):
    def forward_impl(self, hidden_states: torch.Tensor,
                     router_logits: torch.Tensor):
        assert self.quant_method is not None
        # Route to the chunked forward path using the FlashInfer Cutlass kernel
        # only when data parallelism (DP) is enabled.
        use_flashinfer_cutlass_kernels = (
            self.dp_size > 1
            and self.moe_parallel_config.use_flashinfer_cutlass_kernels)
        if (self.moe_parallel_config.use_pplx_kernels
                or self.moe_parallel_config.use_deepep_ll_kernels):
                or self.moe_parallel_config.use_deepep_ll_kernels
                or use_flashinfer_cutlass_kernels):
            return self.forward_impl_chunked(hidden_states, router_logits)

        do_naive_dispatch_combine: bool = (
            self.dp_size > 1
            and not self.moe_parallel_config.use_deepep_ht_kernels)
            and not self.moe_parallel_config.use_deepep_ht_kernels
            and not self.moe_parallel_config.use_flashinfer_cutlass_kernels)
        if do_naive_dispatch_combine:
            hidden_states, router_logits = get_ep_group().dispatch(
                hidden_states, router_logits)
@@ -1460,7 +1487,6 @@ class FusedMoE(torch.nn.Module):

        if do_naive_dispatch_combine:
            final_hidden_states = get_ep_group().combine(final_hidden_states)

        if self.reduce_results and (self.tp_size > 1 or self.ep_size > 1):
            # Default set to False. (May have to add shared expert outputs.
            final_hidden_states = self.maybe_all_reduce_tensor_model_parallel(

@@ -4,7 +4,7 @@ from abc import ABC, abstractmethod
from dataclasses import dataclass
from enum import Enum
from math import prod
from typing import Optional, final
from typing import Any, Optional, final

import torch

@@ -150,16 +150,12 @@ class FusedMoEPrepareAndFinalize(ABC):

    @abstractmethod
    def prepare(
        self,
        a1: torch.Tensor,
        a1_scale: Optional[torch.Tensor],
        a2_scale: Optional[torch.Tensor],
        topk_weights: torch.Tensor,
        topk_ids: torch.Tensor,
        num_experts: int,
        expert_map: Optional[torch.Tensor],
        apply_router_weight_on_input: bool,
        self, a1: torch.Tensor, a1_scale: Optional[torch.Tensor],
        a2_scale: Optional[torch.Tensor], topk_weights: torch.Tensor,
        topk_ids: torch.Tensor, num_experts: int,
        expert_map: Optional[torch.Tensor], apply_router_weight_on_input: bool,
        quant_config: FusedMoEQuantConfig,
        extra_prepare_args: Optional[dict[str, Any]]
    ) -> tuple[torch.Tensor, Optional[torch.Tensor],
               Optional[ExpertTokensMetadata], Optional[torch.Tensor],
               Optional[torch.Tensor]]:
@@ -190,15 +186,11 @@ class FusedMoEPrepareAndFinalize(ABC):
        raise NotImplementedError

    @abstractmethod
    def finalize(
        self,
        output: torch.Tensor,
        fused_expert_output: torch.Tensor,
        topk_weights: torch.Tensor,
        topk_ids: torch.Tensor,
    def finalize(self, output: torch.Tensor, fused_expert_output: torch.Tensor,
                 topk_weights: torch.Tensor, topk_ids: torch.Tensor,
                 apply_router_weight_on_input: bool,
                 weight_and_reduce_impl: TopKWeightAndReduce,
    ) -> None:
                 extra_finalize_args: Optional[dict[str, Any]]) -> None:
        """
        Perform any combine plus apply weights and perform a reduction on the
        fused experts output.
@@ -376,6 +368,7 @@ class FusedMoEPermuteExpertsUnpermute(ABC):
        workspace2: torch.Tensor,
        expert_tokens_meta: Optional[ExpertTokensMetadata],
        apply_router_weight_on_input: bool,
        extra_expert_args: Optional[dict[str, Any]],
    ):
        """
        This function computes the intermediate result of a Mixture of Experts
@@ -460,21 +453,19 @@ class FusedMoEModularKernel(torch.nn.Module):
                f"{fused_experts.__class__.__name__}."
                f"{fused_experts.activation_formats[0]}")

    def _do_fused_experts(self, fused_out: Optional[torch.Tensor],
                          a1: torch.Tensor, a1q: torch.Tensor,
                          w1: torch.Tensor, w2: torch.Tensor,
    def _do_fused_experts(
            self, fused_out: Optional[torch.Tensor], a1: torch.Tensor,
            a1q: torch.Tensor, w1: torch.Tensor, w2: torch.Tensor,
            topk_weights: torch.Tensor, topk_ids: torch.Tensor,
            activation: str, global_num_experts: int,
            local_num_experts: int,
            activation: str, global_num_experts: int, local_num_experts: int,
            expert_map: Optional[torch.Tensor],
            w1_scale: Optional[torch.Tensor],
            w2_scale: Optional[torch.Tensor],
            w1_zp: Optional[torch.Tensor],
            w2_zp: Optional[torch.Tensor],
            w1_scale: Optional[torch.Tensor], w2_scale: Optional[torch.Tensor],
            w1_zp: Optional[torch.Tensor], w2_zp: Optional[torch.Tensor],
            a1q_scale: Optional[torch.Tensor],
            a2_scale: Optional[torch.Tensor],
            expert_tokens_meta: Optional[ExpertTokensMetadata],
            apply_router_weight_on_input: bool) -> torch.Tensor:
            apply_router_weight_on_input: bool,
            extra_expert_args: Optional[dict[str, Any]]) -> torch.Tensor:

        _, M, N, K, top_k = _moe_problem_size(a1q, w1, w2, topk_ids)

@@ -517,7 +508,8 @@ class FusedMoEModularKernel(torch.nn.Module):
            workspace13=workspace13,
            workspace2=workspace2,
            expert_tokens_meta=expert_tokens_meta,
            apply_router_weight_on_input=apply_router_weight_on_input)
            apply_router_weight_on_input=apply_router_weight_on_input,
            extra_expert_args=extra_expert_args)

        return fused_out

@@ -541,6 +533,7 @@ class FusedMoEModularKernel(torch.nn.Module):
        a2_scale: Optional[torch.Tensor],
        expert_tokens_meta: Optional[ExpertTokensMetadata],
        apply_router_weight_on_input: bool,
        extra_expert_args: Optional[dict[str, Any]],
    ) -> torch.Tensor:

        _, M, N, K, top_k = _moe_problem_size(a1q, w1, w2, topk_ids)
@@ -568,7 +561,8 @@ class FusedMoEModularKernel(torch.nn.Module):
                a1q_scale=a1q_scale,
                a2_scale=a2_scale,
                expert_tokens_meta=expert_tokens_meta,
                apply_router_weight_on_input=apply_router_weight_on_input)
                apply_router_weight_on_input=apply_router_weight_on_input,
                extra_expert_args=extra_expert_args)

        # Chunking required case
        assert num_chunks > 1
@@ -624,6 +618,15 @@ class FusedMoEModularKernel(torch.nn.Module):
                expert_num_tokens=c_expert_num_tokens,
                expert_num_tokens_cpu=c_expert_num_tokens_cpu)

        m = None
        if extra_expert_args is not None and 'm' in extra_expert_args:
            m = extra_expert_args.get('m')

        if extra_expert_args is not None:
            chunked_extra_expert_args = extra_expert_args
        else:
            chunked_extra_expert_args = {}

        for chunk_idx in range(num_chunks):
            c_a1q, c_a1q_scale, c_a2_scale, c_topk_ids, c_topk_weights = (
                slice_input_tensors(chunk_idx))
@@ -634,6 +637,11 @@ class FusedMoEModularKernel(torch.nn.Module):
                    expert_tokens_meta, c_topk_ids, local_num_experts,
                    expert_map)

            s = chunk_idx * CHUNK_SIZE
            e = min(s + CHUNK_SIZE, M)

            if m is not None:
                chunked_extra_expert_args['m'] = e - s
            self._do_fused_experts(
                fused_out=slice_output_tensor(chunk_idx),
                a1=a1,
@@ -653,7 +661,8 @@ class FusedMoEModularKernel(torch.nn.Module):
                a1q_scale=c_a1q_scale,
                a2_scale=c_a2_scale,
                expert_tokens_meta=c_expert_tokens_meta,
                apply_router_weight_on_input=apply_router_weight_on_input)
                apply_router_weight_on_input=apply_router_weight_on_input,
                extra_expert_args=chunked_extra_expert_args)

        return fused_out

@@ -675,6 +684,9 @@ class FusedMoEModularKernel(torch.nn.Module):
        a1_scale: Optional[torch.Tensor] = None,
        a2_scale: Optional[torch.Tensor] = None,
        apply_router_weight_on_input: bool = False,
        extra_expert_args: Optional[dict] = None,
        extra_prepare_args: Optional[dict] = None,
        extra_finalize_args: Optional[dict] = None,
    ) -> torch.Tensor:
        """
        This function computes a Mixture of Experts (MoE) layer using two sets
@@ -707,6 +719,12 @@ class FusedMoEModularKernel(torch.nn.Module):
        - apply_router_weight_on_input (bool): When true, the topk weights are
          applied directly on the inputs. This is only applicable when topk is
          1.
        - extra_expert_args (Optional[dict]): Extra keyword arguments to pass to
          fused_experts.apply.
        - extra_prepare_args (Optional[dict]): Extra keyword arguments to pass
          to prepare.
        - extra_finalize_args (Optional[dict]): Extra keyword arguments to pass
          to finalize.

        Returns:
        - torch.Tensor: The output tensor after applying the MoE layer.
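The three new dicts let a specific PrepareAndFinalize/Experts pairing receive backend-only state without changing the shared signature for every other implementation. A hedged sketch of the calling convention, with key names taken from the FlashInfer NVFP4 path in this PR and everything else illustrative (`mk_kernel` is assumed to be a FusedMoEModularKernel built from a matching pair):

    out = mk_kernel(
        hidden_states=x,
        w1=w1, w2=w2,
        topk_weights=topk_weights, topk_ids=topk_ids,
        inplace=False,
        activation="silu",
        global_num_experts=num_experts,
        expert_map=None,
        extra_expert_args={"g1_alphas": g1, "g2_alphas": g2,
                           "a1_gscale": a1_gs, "a2_gscale": a2_gs,
                           "out_dtype": x.dtype},
        extra_prepare_args={"a1_gscale": a1_gs, "use_dp": False,
                            "local_tokens": x.size(0)},
        extra_finalize_args={"use_dp": False, "local_tokens": x.size(0)},
    )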
@@ -730,6 +748,7 @@ class FusedMoEModularKernel(torch.nn.Module):
            expert_map,
            apply_router_weight_on_input,
            self.fused_experts.quant_config,
            extra_prepare_args,
        )

        # Maybe prepare gathered topk_ids and topk_weights from other EP ranks.
@@ -766,11 +785,13 @@ class FusedMoEModularKernel(torch.nn.Module):
            a1q_scale=a1q_scale,
            a2_scale=a2_scale,
            expert_tokens_meta=expert_tokens_meta,
            apply_router_weight_on_input=apply_router_weight_on_input)
            apply_router_weight_on_input=apply_router_weight_on_input,
            extra_expert_args=extra_expert_args)

        self.prepare_finalize.finalize(
            output, fused_out, topk_weights, topk_ids,
            apply_router_weight_on_input,
            self.fused_experts.finalize_weight_and_reduce_impl())
            self.fused_experts.finalize_weight_and_reduce_impl(),
            extra_finalize_args)

        return output

@@ -1,6 +1,6 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Optional
from typing import Any, Optional

import pplx_kernels as pplx
import torch
@@ -89,16 +89,12 @@ class PplxPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
        return self.num_dispatchers_

    def prepare(
        self,
        a1: torch.Tensor,
        a1_scale: Optional[torch.Tensor],
        a2_scale: Optional[torch.Tensor],
        topk_weights: torch.Tensor,
        topk_ids: torch.Tensor,
        num_experts: int,
        expert_map: Optional[torch.Tensor],
        apply_router_weight_on_input: bool,
        self, a1: torch.Tensor, a1_scale: Optional[torch.Tensor],
        a2_scale: Optional[torch.Tensor], topk_weights: torch.Tensor,
        topk_ids: torch.Tensor, num_experts: int,
        expert_map: Optional[torch.Tensor], apply_router_weight_on_input: bool,
        quant_config: FusedMoEQuantConfig,
        extra_prepare_args: Optional[dict[str, Any]]
    ) -> tuple[torch.Tensor, Optional[torch.Tensor],
               Optional[mk.ExpertTokensMetadata], Optional[torch.Tensor],
               Optional[torch.Tensor]]:
@@ -217,15 +213,11 @@ class PplxPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):

        return expert_x, expert_x_scale, expert_tokens_meta, None, None

    def finalize(
        self,
        output: torch.Tensor,
        fused_expert_output: torch.Tensor,
        topk_weights: torch.Tensor,
        topk_ids: torch.Tensor,
    def finalize(self, output: torch.Tensor, fused_expert_output: torch.Tensor,
                 topk_weights: torch.Tensor, topk_ids: torch.Tensor,
                 apply_router_weight_on_input: bool,
                 weight_and_reduce_impl: mk.TopKWeightAndReduce,
    ) -> None:
                 extra_finalize_args: Optional[dict[str, Any]]) -> None:
        assert isinstance(
            weight_and_reduce_impl, TopKWeightAndReduceDelegate
        ), ("Weight application and reduction happens in the combine kernel.")

@@ -1,6 +1,6 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Optional
from typing import Any, Optional

import torch

@@ -38,6 +38,7 @@ class MoEPrepareAndFinalizeNoEP(mk.FusedMoEPrepareAndFinalize):
        expert_map: Optional[torch.Tensor],
        apply_router_weight_on_input: bool,
        quant_config: FusedMoEQuantConfig,
        extra_prepare_args: Optional[dict[str, Any]],
    ) -> tuple[torch.Tensor, Optional[torch.Tensor],
               Optional[mk.ExpertTokensMetadata], Optional[torch.Tensor],
               Optional[torch.Tensor]]:
@@ -48,21 +49,28 @@ class MoEPrepareAndFinalizeNoEP(mk.FusedMoEPrepareAndFinalize):
            assert topk == 1, \
                "apply_router_weight_on_input is only implemented for topk=1"
            a1.mul_(topk_weights.to(a1.dtype))

        if (extra_prepare_args is not None
                and extra_prepare_args.get("skip_quant", True)):
            # Skip quantization if explicitly requested
            return a1, None, None, None, None

        a1q, a1q_scale = moe_kernel_quantize_input(
            a1, a1_scale, quant_config.quant_dtype,
            quant_config.per_act_token_quant, quant_config.block_shape)

        return a1q, a1q_scale, None, None, None

    def finalize(
        self,
        output: torch.Tensor,
        fused_expert_output: torch.Tensor,
        topk_weights: torch.Tensor,
        topk_ids: torch.Tensor,
    def finalize(self, output: torch.Tensor, fused_expert_output: torch.Tensor,
                 topk_weights: torch.Tensor, topk_ids: torch.Tensor,
                 apply_router_weight_on_input: bool,
                 weight_and_reduce_impl: mk.TopKWeightAndReduce,
    ) -> None:
                 extra_finalize_args: Optional[dict[str, Any]]) -> None:
        if (extra_finalize_args is not None
                and extra_finalize_args.get("skip_weight_reduce", True)):
            assert output.shape == fused_expert_output.shape
            output.copy_(fused_expert_output)
        else:
            if isinstance(weight_and_reduce_impl, TopKWeightAndReduceDelegate):
                weight_and_reduce_impl = TopKWeightAndReduceContiguous()
            weight_and_reduce_impl.apply(

@@ -1,6 +1,6 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Optional
from typing import Any, Optional

import torch

@@ -119,28 +119,18 @@ class TritonOrDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
                                      local_num_experts,
                                      expert_tokens_meta)

    def apply(
        self,
        output: torch.Tensor,
        hidden_states: torch.Tensor,
        w1: torch.Tensor,
        w2: torch.Tensor,
        topk_weights: torch.Tensor,
        topk_ids: torch.Tensor,
        activation: str,
        global_num_experts: int,
    def apply(self, output: torch.Tensor, hidden_states: torch.Tensor,
              w1: torch.Tensor, w2: torch.Tensor, topk_weights: torch.Tensor,
              topk_ids: torch.Tensor, activation: str, global_num_experts: int,
              expert_map: Optional[torch.Tensor],
              w1_scale: Optional[torch.Tensor],
              w2_scale: Optional[torch.Tensor],
              w1_zp: Optional[torch.Tensor],
              w2_zp: Optional[torch.Tensor],
              a1q_scale: Optional[torch.Tensor],
              a2_scale: Optional[torch.Tensor],
              workspace13: torch.Tensor,
              w2_scale: Optional[torch.Tensor], w1_zp: Optional[torch.Tensor],
              w2_zp: Optional[torch.Tensor], a1q_scale: Optional[torch.Tensor],
              a2_scale: Optional[torch.Tensor], workspace13: torch.Tensor,
              workspace2: torch.Tensor,
              expert_tokens_meta: Optional[mk.ExpertTokensMetadata],
              apply_router_weight_on_input: bool,
    ):
              extra_expert_args: Optional[dict[str, Any]]):
        use_deep_gemm = (self.allow_deep_gemm
                         and (_valid_deep_gemm(hidden_states, w1, w2)
                              or is_blackwell_deep_gemm_used()))
@@ -168,4 +158,5 @@ class TritonOrDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
            workspace2,
            expert_tokens_meta,
            apply_router_weight_on_input,
            extra_expert_args,
        )

@@ -1,7 +1,7 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from math import prod
from typing import Optional, Union
from typing import Any, Optional, Union

import torch

@@ -15,6 +15,7 @@ from vllm.model_executor.layers.quantization.utils.mxfp4_utils import (
from vllm.platforms import current_platform
from vllm.triton_utils import tl, triton
from vllm.utils import cdiv
from vllm.utils.flashinfer import fp4_quantize


@triton.jit
@@ -98,6 +99,16 @@ def _resize_cache(x: torch.Tensor, v: tuple[int, ...]) -> torch.Tensor:
    return x.flatten()[:prod(v)].view(*v)


def _fp4_quantize(
    A: torch.Tensor,
    A_scale: Optional[torch.Tensor],
    is_sf_swizzled_layout: bool,
) -> tuple[torch.Tensor, torch.Tensor]:
    return fp4_quantize(A,
                        A_scale,
                        is_sf_swizzled_layout=is_sf_swizzled_layout)
|
||||
|
||||
|
||||
def _fp8_quantize(
|
||||
A: torch.Tensor,
|
||||
A_scale: Optional[torch.Tensor],
|
||||
@ -172,11 +183,16 @@ def moe_kernel_quantize_input(
|
||||
quant_dtype: Union[None, torch.dtype, str],
|
||||
per_act_token_quant: bool,
|
||||
block_shape: Optional[list[int]] = None,
|
||||
is_fp4_scale_swizzled: bool = True,
|
||||
) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
|
||||
if quant_dtype == torch.float8_e4m3fn:
|
||||
return _fp8_quantize(A, A_scale, per_act_token_quant, block_shape)
|
||||
elif quant_dtype == torch.int8:
|
||||
return _int8_quantize(A, A_scale, per_act_token_quant, block_shape)
|
||||
elif quant_dtype == torch.uint8: # nvfp4
|
||||
return _fp4_quantize(A,
|
||||
A_scale,
|
||||
is_sf_swizzled_layout=is_fp4_scale_swizzled)
|
||||
elif quant_dtype == "mxfp4":
|
||||
return _mxfp4_quantize(A, A_scale, per_act_token_quant, block_shape)
|
||||
else:
|
||||
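In the dispatcher above, quant_dtype == torch.uint8 marks the NVFP4 path: two e2m1 values are packed into each uint8 byte and block scale factors are produced alongside the payload. A small CPU-only shape sketch of that layout, assuming the usual NVFP4 convention of 16-element blocks with one e4m3 scale each (no FlashInfer call is made here):

import torch

# Shape bookkeeping for the nvfp4 (torch.uint8) branch above.
M, K = 128, 4096
a = torch.randn(M, K, dtype=torch.bfloat16)

packed_shape = (M, K // 2)        # e2m1 payload: two 4-bit values per byte
blockscale_shape = (M, K // 16)   # one e4m3 scale per 16-element block

print(packed_shape, blockscale_shape)  # (128, 2048) (128, 256)
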
@ -236,3 +252,17 @@ def _validate_scale_shape(
assert block_shape is not None
expected = (a.shape[0], cdiv(a.shape[1], block_shape[1]))
assert a_scale.shape == expected, f"{a_scale.shape} == {expected}"


def extract_required_args(
extra_args: Optional[dict[str, Any]],
required_keys: list[str],
) -> tuple[Any, ...]:
if extra_args is None:
raise ValueError("`extra_args` must be provided.")

missing_keys = [k for k in required_keys if k not in extra_args]
if missing_keys:
raise ValueError(f"Missing keys in `extra_args`: {missing_keys}")

return tuple(extra_args[k] for k in required_keys)

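extract_required_args gives the new backends a uniform way to pull mandatory entries out of the loosely typed extra_*_args dictionaries and to fail loudly when something is missing. A short usage sketch; the values are illustrative only and the helper is the one defined directly above:

extra_args = {"g1_alphas": 0.5, "g2_alphas": 0.25, "out_dtype": "bf16"}

g1_alphas, g2_alphas = extract_required_args(extra_args,
                                              ["g1_alphas", "g2_alphas"])

# Missing keys raise immediately and name the offenders:
#   extract_required_args(extra_args, ["a1_gscale"])
#   -> ValueError: Missing keys in `extra_args`: ['a1_gscale']
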
@ -339,19 +339,19 @@ class CompressedTensorsW4A4MoeMethod(CompressedTensorsMoEMethod):
return cutlass_moe_fp4(
a=x,
w1_fp4=layer.w13_weight,
w1_blockscale=layer.w13_blockscale_swizzled,
w1_alphas=layer.g1_alphas,
w2_fp4=layer.w2_weight,
w1_blockscale=layer.w13_blockscale_swizzled,
w2_blockscale=layer.w2_blockscale_swizzled,
w2_alphas=layer.g2_alphas,
g1_alphas=layer.g1_alphas,
g2_alphas=layer.g2_alphas,
a1_gscale=layer.w13_input_scale_quant,
a2_gscale=layer.w2_input_scale_quant,
topk_weights=topk_weights,
topk_ids=topk_ids,
m=x.shape[0],
n=layer.w2_weight.shape[2] * 2,
k=x.shape[1],
e=layer.w13_weight.shape[0],
a1_gscale=layer.w13_input_scale_quant,
a2_gscale=layer.w2_input_scale_quant,
device=x.device,
apply_router_weight_on_input=apply_router_weight_on_input).to(
x.dtype)

@ -7,9 +7,15 @@ import torch
from torch.nn import Module
from torch.nn.parameter import Parameter

import vllm.envs as envs
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from vllm._custom_ops import (cutlass_scaled_fp4_mm,
cutlass_scaled_mm_supports_fp4, scaled_fp4_quant)
from vllm.distributed import get_ep_group
from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe.config import FusedMoEParallelConfig
from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_prepare_finalize import (  # noqa: E501
FlashInferCutlassMoEPrepareAndFinalize)
from vllm.model_executor.layers.fused_moe.layer import (
FusedMoE, FusedMoEMethodBase, FusedMoeWeightScaleSupported)
from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase,
@ -713,6 +719,18 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
self.quant_config = quant_config
self.cutlass_nvfp4_supported = cutlass_fp4_supported()
self.use_marlin = False
self.allow_flashinfer_cutlass = False

if envs.VLLM_USE_FLASHINFER_MOE:
if self.cutlass_nvfp4_supported and current_platform.is_cuda() \
and current_platform.is_device_capability(100):
logger.info_once(
"Using FlashInfer kernels for ModelOptNvFp4FusedMoE.")
self.allow_flashinfer_cutlass = True
else:
logger.warning_once(
"Flashinfer CUTLASS Fused MoE not supported "
"or found on the current platform.")

if not self.cutlass_nvfp4_supported:
if is_fp4_marlin_supported():
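The constructor above only enables the FlashInfer path when the VLLM_USE_FLASHINFER_MOE environment variable is set and the platform qualifies (CUDA, compute capability 10.0, NVFP4 CUTLASS support). A minimal way to opt in from Python, assuming vLLM is installed; exporting the variable in the shell before launching the server works the same way:

import os

# The flag is parsed as bool(int(...)), so "1" enables the backend and the
# default "0" keeps it off. Set it before vLLM reads its environment.
os.environ["VLLM_USE_FLASHINFER_MOE"] = "1"

import vllm.envs as envs

print(envs.VLLM_USE_FLASHINFER_MOE)  # True
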
@ -722,6 +740,73 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
" quantization. Please use Blackwell and"
" above.")

self.fused_experts = None  # type: ignore

def maybe_swap_experts_impl(
self,
moe_parallel_config: FusedMoEParallelConfig,
):
if not self.allow_flashinfer_cutlass:
return

logger.debug_once("FlashInferExperts")
# default to TP/EP case only

experts_kwargs: dict[str, Any] = {
"use_nvfp4_w4a4": True,
"use_dp": moe_parallel_config.dp_size > 1,
"ep_rank": moe_parallel_config.ep_rank,
"ep_size": moe_parallel_config.ep_size,
"tp_rank": moe_parallel_config.tp_rank,
"tp_size": moe_parallel_config.tp_size,
}

from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe import (  # noqa: E501
FlashInferExperts)
experts = FlashInferExperts(**experts_kwargs)
self.fused_experts = mk.FusedMoEModularKernel(
FlashInferCutlassMoEPrepareAndFinalize(
quant_dtype=torch.uint8,
# uint8 here means two e2m1 values packed per byte,
# a kernel requirement.
),
experts,
)

# This method updates self.fused_experts.
# select_gemm_impl is only called when prepare_finalize is not None,
# so in the native cutlass fp4 (TP) case, where it is not called,
# fused_experts comes from fused_moe.py's fused_experts and both
# kernels remain available.
def select_gemm_impl(self, prepare_finalize,
moe) -> mk.FusedMoEPermuteExpertsUnpermute:

assert moe is not None
assert prepare_finalize is not None
experts = None
all2all_manager = get_ep_group().device_communicator.all2all_manager
assert all2all_manager is not None
if self.allow_flashinfer_cutlass:
from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe import (  # noqa: E501
FlashInferExperts)
logger.debug_once("Using FlashInferExperts")
experts = FlashInferExperts(
use_nvfp4_w4a4=True,
use_dp=moe.moe_parallel_config.dp_size > 1,
ep_rank=moe.moe_parallel_config.ep_rank,
ep_size=moe.moe_parallel_config.ep_size,
tp_rank=moe.moe_parallel_config.tp_rank,
tp_size=moe.moe_parallel_config.tp_size,
)
else:
assert moe.dp_size > 1
logger.debug_once("Using CutlassExpertsFp4")
# Currently CutlassExpertsFp4 doesn't support DP
raise ValueError(
"CutlassExpertsFp4 doesn't support DP. "
"Use flashinfer CUTLASS FusedMoE(VLLM_USE_FLASHINFER_MOE)"
" backend instead.")

return experts

def uses_weight_scale_2_pattern(self) -> bool:
"""
FP4 variants use 'weight_scale_2' pattern for per-tensor weight scales.
@ -842,8 +927,30 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
if scale_ndim == 2 else swizzled_scale.reshape(B, M, K))

def process_weights_after_loading(self, layer: torch.nn.Module) -> None:

# GEMM 1
# The FlashInfer Cutlass fused MoE kernel expects the combined weights
# to be ordered as [w3, w1], unlike the standard [w1, w3] layout.
gemm1_weight = layer.w13_weight.data
gemm1_weight_scale = layer.w13_weight_scale.data

if self.allow_flashinfer_cutlass:
dim = -2
size = gemm1_weight.size(dim)
assert size % 2 == 0, f"Expected even size in dim {dim}, got {size}"
half = size // 2

# Reorder weight
w1, w3 = gemm1_weight.split(half, dim=dim)
gemm1_weight = torch.cat([w3, w1], dim=dim).contiguous()

# Reorder scale
s1, s3 = gemm1_weight_scale.split(half, dim=dim)
gemm1_weight_scale = torch.cat([s3, s1], dim=dim).contiguous()

layer.w13_weight = Parameter(gemm1_weight, requires_grad=False)
layer.w13_weight_scale = Parameter(gemm1_weight_scale,
requires_grad=False)

if not torch.allclose(layer.w13_weight_scale_2[:, 0],
layer.w13_weight_scale_2[:, 1]):
logger.warning_once(
|
||||
layer.w13_input_scale_quant = Parameter(
|
||||
(1 / w13_input_scale).to(torch.float32), requires_grad=False)
|
||||
|
||||
layer.w13_weight = Parameter(layer.w13_weight.data,
|
||||
requires_grad=False)
|
||||
|
||||
# GEMM 2
|
||||
layer.g2_alphas = Parameter(
|
||||
(layer.w2_input_scale * layer.w2_weight_scale_2).to(torch.float32),
|
||||
@ -961,31 +1065,74 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
|
||||
global_num_experts=global_num_experts,
|
||||
expert_map=expert_map)
|
||||
|
||||
assert expert_map is None, ("Expert Parallelism / expert_map "
|
||||
"is currently not supported for "
|
||||
"ModelOptNvFp4FusedMoE.")
|
||||
|
||||
if self.fused_experts is None:
|
||||
# If no modular kernel is provided, use cutlass_moe_fp4 for TP case
|
||||
# only (no EP).
|
||||
from vllm.model_executor.layers.fused_moe.cutlass_moe import (
|
||||
cutlass_moe_fp4)
|
||||
|
||||
# Cutlass moe takes in activations in BF16/Half precision
|
||||
# and fp4 quantized weights loaded from the checkpoint
|
||||
return cutlass_moe_fp4(
|
||||
out = cutlass_moe_fp4(
|
||||
a=x,
|
||||
w1_fp4=layer.w13_weight,
|
||||
w1_blockscale=layer.w13_blockscale_swizzled,
|
||||
w1_alphas=layer.g1_alphas,
|
||||
w2_fp4=layer.w2_weight,
|
||||
w1_blockscale=layer.w13_blockscale_swizzled,
|
||||
w2_blockscale=layer.w2_blockscale_swizzled,
|
||||
w2_alphas=layer.g2_alphas,
|
||||
g1_alphas=layer.g1_alphas,
|
||||
g2_alphas=layer.g2_alphas,
|
||||
a1_gscale=layer.w13_input_scale_quant,
|
||||
a2_gscale=layer.w2_input_scale_quant,
|
||||
topk_weights=topk_weights,
|
||||
topk_ids=topk_ids,
|
||||
m=x.shape[0],
|
||||
n=layer.w2_weight.shape[2] * 2,
|
||||
k=x.shape[1],
|
||||
e=layer.w13_weight.shape[0],
|
||||
a1_gscale=layer.w13_input_scale_quant,
|
||||
a2_gscale=layer.w2_input_scale_quant,
|
||||
device=x.device,
|
||||
apply_router_weight_on_input=apply_router_weight_on_input).to(
|
||||
x.dtype)
|
||||
expert_map=expert_map,
|
||||
apply_router_weight_on_input=apply_router_weight_on_input)
|
||||
else:
|
||||
# TP or DP case
|
||||
from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe import ( # noqa: E501
|
||||
is_valid_flashinfer_cutlass_fused_moe)
|
||||
assert is_valid_flashinfer_cutlass_fused_moe(
|
||||
x, layer.w13_weight, layer.w2_weight), (
|
||||
"Flashinfer CUTLASS Fused MoE not applicable!")
|
||||
|
||||
a1_gscale = torch.min(layer.w13_input_scale_quant)
|
||||
a2_gscale = torch.min(layer.w2_input_scale_quant)
|
||||
extra_expert_args = {
|
||||
'g1_alphas': layer.g1_alphas,
|
||||
'g2_alphas': layer.g2_alphas,
|
||||
'out_dtype': x.dtype,
|
||||
# Avoid confusion with a1_scale and a2_scale
|
||||
# where are batch size related.
|
||||
'a1_gscale': a1_gscale,
|
||||
'a2_gscale': a2_gscale,
|
||||
}
|
||||
extra_prepare_args = {
|
||||
'use_dp': layer.dp_size > 1,
|
||||
'local_tokens': x.shape[0],
|
||||
'a1_gscale': a1_gscale,
|
||||
}
|
||||
extra_finalize_args = {
|
||||
'use_dp': layer.dp_size > 1,
|
||||
'local_tokens': x.shape[0],
|
||||
}
|
||||
|
||||
out = self.fused_experts(
|
||||
hidden_states=x,
|
||||
w1=layer.w13_weight,
|
||||
w2=layer.w2_weight,
|
||||
topk_weights=topk_weights,
|
||||
topk_ids=topk_ids,
|
||||
inplace=False, # TODO(shuw): fix later, now output is high prec
|
||||
activation=activation,
|
||||
global_num_experts=global_num_experts,
|
||||
expert_map=expert_map,
|
||||
w1_scale=layer.w13_blockscale_swizzled,
|
||||
w2_scale=layer.w2_blockscale_swizzled,
|
||||
apply_router_weight_on_input=apply_router_weight_on_input,
|
||||
extra_expert_args=extra_expert_args,
|
||||
extra_prepare_args=extra_prepare_args,
|
||||
extra_finalize_args=extra_finalize_args,
|
||||
)
|
||||
return out
|
||||
|
||||
vllm/utils/flashinfer.py (new file, 107 lines)
@ -0,0 +1,107 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Compatibility wrapper for FlashInfer API changes.

Users of vLLM should always import **only** these wrappers.
"""
from __future__ import annotations

import contextlib
import functools
import importlib
import importlib.util
from typing import Any, Callable, NoReturn

from vllm.logger import init_logger

logger = init_logger(__name__)


@functools.cache
def has_flashinfer() -> bool:
"""Return ``True`` if FlashInfer is available."""
# Use find_spec to check if the module exists without importing it
# This avoids potential CUDA initialization side effects
return importlib.util.find_spec("flashinfer") is not None


def _missing(*_: Any, **__: Any) -> NoReturn:
"""Placeholder for unavailable FlashInfer backend."""
raise RuntimeError(
"FlashInfer backend is not available. Please install the package "
"to enable FlashInfer kernels: "
"https://github.com/flashinfer-ai/flashinfer")


def _get_submodule(module_name: str) -> Any | None:
"""Safely import a submodule and return it, or None if not available."""
try:
return importlib.import_module(module_name)
except (ImportError, ModuleNotFoundError):
return None


# General lazy import wrapper
def _lazy_import_wrapper(module_name: str,
attr_name: str,
fallback_fn: Callable[..., Any] = _missing):
"""Create a lazy import wrapper for a specific function."""

@functools.cache
def _get_impl():
if not has_flashinfer():
return None
mod = _get_submodule(module_name)
return getattr(mod, attr_name, None) if mod else None

def wrapper(*args, **kwargs):
impl = _get_impl()
if impl is None:
return fallback_fn(*args, **kwargs)
return impl(*args, **kwargs)

return wrapper


# Create lazy wrappers for each function
flashinfer_cutlass_fused_moe = _lazy_import_wrapper("flashinfer.fused_moe",
"cutlass_fused_moe")
fp4_quantize = _lazy_import_wrapper("flashinfer", "fp4_quantize")
fp4_swizzle_blockscale = _lazy_import_wrapper("flashinfer",
"fp4_swizzle_blockscale")

# Special case for autotune since it returns a context manager
autotune = _lazy_import_wrapper(
"flashinfer.autotuner",
"autotune",
fallback_fn=lambda *args, **kwargs: contextlib.nullcontext())


@functools.cache
def has_flashinfer_cutlass_fused_moe() -> bool:
"""Return ``True`` if FlashInfer CUTLASS fused MoE is available."""
if not has_flashinfer():
return False

# Check if all required functions are available
required_functions = [
("flashinfer.fused_moe", "cutlass_fused_moe"),
("flashinfer", "fp4_quantize"),
("flashinfer", "fp4_swizzle_blockscale"),
]

for module_name, attr_name in required_functions:
mod = _get_submodule(module_name)
if not mod or not hasattr(mod, attr_name):
return False
return True


__all__ = [
"has_flashinfer",
"has_flashinfer_cutlass_fused_moe",
"flashinfer_cutlass_fused_moe",
"fp4_quantize",
"fp4_swizzle_blockscale",
"autotune",
]
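A brief usage sketch for the new wrapper module (the import path follows the file header above). FlashInfer is an optional dependency, so the wrappers resolve lazily and degrade gracefully when it is absent:

from vllm.utils.flashinfer import (flashinfer_cutlass_fused_moe, has_flashinfer,
                                   has_flashinfer_cutlass_fused_moe)

print(has_flashinfer())                    # False if flashinfer is not installed
print(has_flashinfer_cutlass_fused_moe())  # also verifies the required symbols

if not has_flashinfer_cutlass_fused_moe():
    # Calling a wrapper without FlashInfer raises the RuntimeError from
    # _missing() at call time rather than failing at import time.
    try:
        flashinfer_cutlass_fused_moe()
    except RuntimeError as e:
        print(e)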