[Core] FlashInfer CUTLASS fused MoE backend (NVFP4) (#20037)
Signed-off-by: shuw <shuw@nvidia.com>
Signed-off-by: mgoin <mgoin64@gmail.com>
Co-authored-by: mgoin <mgoin64@gmail.com>

parent: b38baabcf9
commit: c7d8724e78
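
For orientation, the new backend is opt-in: it is only used when the VLLM_USE_FLASHINFER_MOE environment variable is set and the FlashInfer CUTLASS fused-MoE kernel is importable. A minimal sketch of that gate (the availability helper is passed in so the snippet stands alone; in the diff it comes from vllm.utils.flashinfer):

    import os

    def flashinfer_cutlass_moe_enabled(has_flashinfer_cutlass_fused_moe) -> bool:
        # Mirrors FusedMoEParallelConfig.use_flashinfer_cutlass_kernels below:
        # the env flag must be "1" and the FlashInfer kernel must be available.
        return (bool(int(os.getenv("VLLM_USE_FLASHINFER_MOE", "0")))
                and has_flashinfer_cutlass_fused_moe())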
@@ -956,11 +956,11 @@ def cutlass_moe_mm(out_tensors: torch.Tensor, a_tensors: torch.Tensor,
                                  c_strides, per_act_token, per_out_ch)


-def cutlass_fp4_moe_mm(a_tensors: torch.Tensor, b_tensors: torch.Tensor,
-                       a_scales: torch.Tensor, b_scales: torch.Tensor,
-                       alphas: torch.Tensor, problem_sizes: torch.Tensor,
-                       expert_offsets: torch.Tensor, sf_offsets: torch.Tensor,
-                       out_dtype: torch.dtype, device: torch.device):
+def cutlass_fp4_moe_mm(out_tensors: torch.Tensor, a_tensors: torch.Tensor,
+                       b_tensors: torch.Tensor, a_scales: torch.Tensor,
+                       b_scales: torch.Tensor, alphas: torch.Tensor,
+                       problem_sizes: torch.Tensor,
+                       expert_offsets: torch.Tensor, sf_offsets: torch.Tensor):
     """
     An FP4 Blockscaled Group Gemm that takes in a_tensors, b_tensors and runs
     the gemms for each combination based on the specified problem sizes.
@@ -977,14 +977,10 @@ def cutlass_fp4_moe_mm(a_tensors: torch.Tensor, b_tensors: torch.Tensor,
     - problem_sizes: MxNxK sizes of each expert's multiplication in two grouped
       MMs used in the fused MoE operation.
     """
-    m_topk = a_tensors.shape[0]
-    n = b_tensors.shape[1]
-    c_shape = (m_topk, n)
-    c = torch.empty(c_shape, device=device, dtype=out_dtype)
-    torch.ops._C.cutlass_fp4_group_mm(c, a_tensors, b_tensors, a_scales,
-                                      b_scales, alphas, problem_sizes,
-                                      expert_offsets, sf_offsets)
-    return c.to(out_dtype)
+    return torch.ops._C.cutlass_fp4_group_mm(out_tensors, a_tensors, b_tensors,
+                                             a_scales, b_scales, alphas,
+                                             problem_sizes, expert_offsets,
+                                             sf_offsets)


 # aqlm
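
The hunks above switch cutlass_fp4_moe_mm from allocating and returning its result to writing into a caller-provided out_tensors buffer, so the modular MoE kernel can reuse preallocated workspace. A framework-free sketch of that out-parameter pattern (the names and matmul stand-in are illustrative, not the actual kernel):

    import torch

    def gemm_into(out: torch.Tensor, a: torch.Tensor, b: torch.Tensor) -> None:
        # Stand-in for the new out-parameter style: results land in a
        # caller-owned buffer instead of a freshly allocated tensor.
        torch.matmul(a, b, out=out)

    workspace = torch.empty(8, 8)    # reused across calls, like workspace13
    a, b = torch.randn(8, 4), torch.randn(4, 8)
    gemm_into(workspace, a, b)       # no per-call allocation of the output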
@@ -119,6 +119,7 @@ if TYPE_CHECKING:
     VLLM_TPU_BUCKET_PADDING_GAP: int = 0
     VLLM_TPU_MOST_MODEL_LEN: Optional[int] = None
     VLLM_USE_DEEP_GEMM: bool = False
+    VLLM_USE_FLASHINFER_MOE: bool = False
     VLLM_XGRAMMAR_CACHE_MB: int = 0
     VLLM_MSGPACK_ZERO_COPY_THRESHOLD: int = 256
     VLLM_ALLOW_INSECURE_SERIALIZATION: bool = False
@@ -853,6 +854,10 @@ environment_variables: dict[str, Callable[[], Any]] = {
     "VLLM_USE_DEEP_GEMM":
     lambda: bool(int(os.getenv("VLLM_USE_DEEP_GEMM", "0"))),

+    # Allow use of FlashInfer CUTLASS kernels for fused moe ops.
+    "VLLM_USE_FLASHINFER_MOE":
+    lambda: bool(int(os.getenv("VLLM_USE_FLASHINFER_MOE", "0"))),
+
     # Control the cache sized used by the xgrammar compiler. The default
     # of 512 MB should be enough for roughly 1000 JSON schemas.
     # It can be changed with this variable if needed for some reason.
@@ -1,6 +1,6 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
-from typing import Optional
+from typing import Any, Optional

 import torch

@@ -255,28 +255,18 @@ class BatchedDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
         output = (num_experts, max_num_tokens * num_dispatchers, K)
         return (workspace13, workspace2, output, a.dtype)

-    def apply(
-        self,
-        output: torch.Tensor,
-        hidden_states: torch.Tensor,
-        w1: torch.Tensor,
-        w2: torch.Tensor,
-        topk_weights: torch.Tensor,
-        topk_ids: torch.Tensor,
-        activation: str,
-        global_num_experts: int,
-        expert_map: Optional[torch.Tensor],
-        w1_scale: Optional[torch.Tensor],
-        w2_scale: Optional[torch.Tensor],
-        w1_zp: Optional[torch.Tensor],
-        w2_zp: Optional[torch.Tensor],
-        a1q_scale: Optional[torch.Tensor],
-        a2_scale: Optional[torch.Tensor],
-        workspace13: torch.Tensor,
-        workspace2: torch.Tensor,
-        expert_tokens_meta: Optional[mk.ExpertTokensMetadata],
-        apply_router_weight_on_input: bool,
-    ):
+    def apply(self, output: torch.Tensor, hidden_states: torch.Tensor,
+              w1: torch.Tensor, w2: torch.Tensor, topk_weights: torch.Tensor,
+              topk_ids: torch.Tensor, activation: str, global_num_experts: int,
+              expert_map: Optional[torch.Tensor],
+              w1_scale: Optional[torch.Tensor],
+              w2_scale: Optional[torch.Tensor], w1_zp: Optional[torch.Tensor],
+              w2_zp: Optional[torch.Tensor], a1q_scale: Optional[torch.Tensor],
+              a2_scale: Optional[torch.Tensor], workspace13: torch.Tensor,
+              workspace2: torch.Tensor,
+              expert_tokens_meta: Optional[mk.ExpertTokensMetadata],
+              apply_router_weight_on_input: bool,
+              extra_expert_args: Optional[dict[str, Any]]):
         assert expert_tokens_meta is not None
         expert_num_tokens = expert_tokens_meta.expert_num_tokens

@@ -1,6 +1,6 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
-from typing import Optional
+from typing import Any, Optional

 import torch

@@ -142,7 +142,8 @@ class BatchedTritonOrDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
               a2_scale: Optional[torch.Tensor], workspace13: torch.Tensor,
               workspace2: torch.Tensor,
               expert_tokens_meta: Optional[mk.ExpertTokensMetadata],
-              apply_router_weight_on_input: bool):
+              apply_router_weight_on_input: bool,
+              extra_expert_args: Optional[dict[str, Any]]):
         experts = (self.batched_deep_gemm_experts
                    if self.allow_deep_gemm else self.batched_triton_experts)
         assert experts is not None
@@ -150,4 +151,4 @@ class BatchedTritonOrDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
                       activation, global_num_experts, expert_map, w1_scale,
                       w2_scale, w1_zp, w2_zp, a1q_scale, a2_scale, workspace13,
                       workspace2, expert_tokens_meta,
-                      apply_router_weight_on_input)
+                      apply_router_weight_on_input, extra_expert_args)
@@ -15,6 +15,7 @@ from vllm.logger import init_logger
 from vllm.model_executor.layers.quantization.base_config import (
     QuantizationConfig)
 from vllm.utils import cdiv
+from vllm.utils.flashinfer import has_flashinfer_cutlass_fused_moe

 logger = init_logger(__name__)

@@ -188,6 +189,11 @@ class FusedMoEParallelConfig:
         return (self.use_all2all_kernels
                 and envs.VLLM_ALL2ALL_BACKEND == "deepep_low_latency")

+    @property
+    def use_flashinfer_cutlass_kernels(self):
+        return (envs.VLLM_USE_FLASHINFER_MOE
+                and has_flashinfer_cutlass_fused_moe())
+
     @staticmethod
     def make(tp_size_: int, dp_size_: int,
              vllm_parallel_config: ParallelConfig) -> "FusedMoEParallelConfig":
@@ -392,6 +398,10 @@ class FusedMoEConfig:
     def use_deepep_ll_kernels(self):
         return self.moe_parallel_config.use_deepep_ll_kernels

+    @property
+    def use_flashinfer_cutlass_kernels(self):
+        return self.moe_parallel_config.use_flashinfer_cutlass_kernels
+
     @staticmethod
     def make(
         num_experts: int,
@@ -435,6 +445,12 @@ class FusedMoEConfig:
         if quant_dtype is None and isinstance(quant_config, Fp8Config):
             quant_dtype = torch.float8_e4m3fn

+        from vllm.model_executor.layers.quantization.modelopt import (
+            ModelOptNvFp4Config)
+        if quant_dtype is None and isinstance(quant_config,
+                                              ModelOptNvFp4Config):
+            quant_dtype = torch.uint8
+
         if weight_quant is not None:
             per_out_ch_quant = (
                 weight_quant.strategy == QuantizationStrategy.CHANNEL)
@@ -1,7 +1,7 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 """ CUTLASS based Fused MoE kernels."""
-from typing import Callable, Optional
+from typing import Any, Callable, Optional

 import torch

@@ -14,7 +14,8 @@ from vllm.model_executor.layers.fused_moe.prepare_finalize import (
 from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import (
     TopKWeightAndReduceDelegate)
 from vllm.model_executor.layers.fused_moe.utils import (_fp8_quantize,
-                                                        _resize_cache)
+                                                        _resize_cache,
+                                                        extract_required_args)
 from vllm.scalar_type import scalar_types

 logger = init_logger(__name__)
@@ -298,7 +299,8 @@ class CutlassExpertsFp8(mk.FusedMoEPermuteExpertsUnpermute):
               a2_scale: Optional[torch.Tensor], workspace13: torch.Tensor,
               workspace2: torch.Tensor,
               expert_tokens_meta: Optional[mk.ExpertTokensMetadata],
-              apply_router_weight_on_input: bool):
+              apply_router_weight_on_input: bool,
+              extra_expert_args: Optional[dict[str, Any]]):
         assert w1_zp is None, "w1_zp is not supported in CUTLASS MoE"
         assert w2_zp is None, "w2_zp is not supported in CUTLASS MoE"

@@ -431,23 +433,28 @@ FLOAT4_E2M1_MAX = scalar_types.float4_e2m1f.max()
 FLOAT8_E4M3_MAX = torch.finfo(torch.float8_e4m3fn).max


-def cutlass_moe_fp4(a: torch.Tensor,
-                    a1_gscale: torch.Tensor,
-                    w1_fp4: torch.Tensor,
-                    w1_blockscale: torch.Tensor,
-                    w1_alphas: torch.Tensor,
-                    a2_gscale: torch.Tensor,
-                    w2_fp4: torch.Tensor,
-                    w2_blockscale: torch.Tensor,
-                    w2_alphas: torch.Tensor,
-                    topk_weights: torch.Tensor,
-                    topk_ids: torch.Tensor,
-                    m: int,
-                    n: int,
-                    k: int,
-                    e: int,
-                    device: torch.device,
-                    apply_router_weight_on_input: bool = False):
+def run_cutlass_moe_fp4(
+    output: torch.Tensor,
+    a: torch.Tensor,
+    a1_gscale: torch.Tensor,
+    w1_fp4: torch.Tensor,
+    w1_blockscale: torch.Tensor,
+    w1_alphas: torch.Tensor,
+    a2_gscale: torch.Tensor,
+    w2_fp4: torch.Tensor,
+    w2_blockscale: torch.Tensor,
+    w2_alphas: torch.Tensor,
+    topk_weights: torch.Tensor,
+    topk_ids: torch.Tensor,
+    workspace13: torch.Tensor,
+    workspace2: torch.Tensor,
+    m: int,
+    n: int,
+    k: int,
+    e: int,
+    device: torch.device,
+    apply_router_weight_on_input: bool = False,
+) -> None:
     """
     MoE implementation for FP4 Inputs

@@ -487,16 +494,16 @@ def cutlass_moe_fp4(a: torch.Tensor,

     assert (e_w1 == e_w2 and e_w1 == e), ("Number of experts must match",
                                           " between weights.")
-    assert (k_a // 2 == half_k_w1
+    assert (k_a == half_k_w1 * 2
             and k == k_w2), ("Hidden size mismatch between a, w1 and w2")
-    assert (nx2_w1 == n * 2 and half_n_w2 == n // 2), ("mismatch in "
+    assert (nx2_w1 == n * 2 and half_n_w2 * 2 == n), ("mismatch in "
                                                        "expected `n`")
     assert (m == m_a), "input shape mismatch"
     assert 2 * half_k_w1 == k_w2, "Hidden size mismatch w2 and w1"
     assert a.dtype in [torch.half, torch.bfloat16], "Invalid input dtype"
     assert (topk_weights.size(0) == m and topk_ids.size(0)
             == m), ("topk must be provided for each row of a")
+    topk = topk_ids.size(1)
     out_dtype = a.dtype
     num_topk = topk_ids.size(1)

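The reworked shape checks above follow NVFP4's packed layout: two 4-bit e2m1 values share one byte, so a tensor holding K FP4 elements along the hidden dimension stores K // 2 uint8 entries, which is exactly what k_a == half_k_w1 * 2 and half_n_w2 * 2 == n express. A tiny illustration with assumed sizes:

    # Assumed sizes for illustration only.
    k = 4096                   # logical hidden size in FP4 elements
    half_k_w1 = k // 2         # uint8 bytes actually stored per row of w1
    assert k == half_k_w1 * 2  # the relation the new assert checks
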
@@ -523,7 +530,6 @@ def cutlass_moe_fp4(a: torch.Tensor,
                                       blockscale_offsets)

     a = ops.shuffle_rows(a, a_map)
-
     rep_a_fp4, rep_a_blockscale = ops.scaled_fp4_experts_quant(
         a,
         a1_gscale,
@@ -531,34 +537,220 @@ def cutlass_moe_fp4(a: torch.Tensor,
         blockscale_offsets,
         num_topk,
     )
-    c1 = ops.cutlass_fp4_moe_mm(rep_a_fp4, w1_fp4, rep_a_blockscale,
-                                w1_blockscale, w1_alphas, problem_sizes1,
-                                expert_offsets[:-1], blockscale_offsets[:-1],
-                                out_dtype, device)
+    c1 = _resize_cache(workspace13, (m * topk, n * 2))
+    c2 = _resize_cache(workspace2, (m * topk, n))
+    c3 = _resize_cache(workspace13, (m * topk, k))
+    ops.cutlass_fp4_moe_mm(c1, rep_a_fp4, w1_fp4, rep_a_blockscale,
+                           w1_blockscale, w1_alphas, problem_sizes1,
+                           expert_offsets[:-1], blockscale_offsets[:-1])
     del rep_a_fp4, rep_a_blockscale
-    # hidden size dimension is split to one halfpytho sized tensor.
-    intermediate = torch.empty((m * num_topk, w1_fp4.size(1) // 2),
-                               device=device,
-                               dtype=out_dtype)
-
-    torch.ops._C.silu_and_mul(intermediate, c1)
+    torch.ops._C.silu_and_mul(c2, c1)

     int_fp4, int_blockscale = ops.scaled_fp4_experts_quant(
-        intermediate, a2_gscale, expert_offsets, blockscale_offsets, num_topk)
+        c2, a2_gscale, expert_offsets, blockscale_offsets, num_topk)

-    c2 = ops.cutlass_fp4_moe_mm(int_fp4, w2_fp4, int_blockscale, w2_blockscale,
-                           w2_alphas, problem_sizes2, expert_offsets[:-1],
-                           blockscale_offsets[:-1], out_dtype, device)
+    ops.cutlass_fp4_moe_mm(c3, int_fp4, w2_fp4, int_blockscale, w2_blockscale,
+                           w2_alphas, problem_sizes2, expert_offsets[:-1],
+                           blockscale_offsets[:-1])
     del int_fp4, int_blockscale

-    c2 = ops.shuffle_rows(c2, c_map)
+    c3 = ops.shuffle_rows(c3, c_map)

+    assert output.dtype == out_dtype
     if not apply_router_weight_on_input:
-        out = (c2.view(m, num_topk, k) *
-               topk_weights.view(m, num_topk, 1).to(out_dtype)).sum(dim=1)
+        output.copy_(
+            (c3.view(m, num_topk, k) *
+             topk_weights.view(m, num_topk, 1).to(out_dtype)).sum(dim=1),
+            non_blocking=True)
     else:
-        out = c2.view(m, num_topk, k).sum(dim=1)
-    return out.to(dtype=out_dtype)
+        output.copy_(c3.view(m, num_topk, k).sum(dim=1), non_blocking=True)
+    return
+
+
+class CutlassExpertsFp4(mk.FusedMoEPermuteExpertsUnpermute):
+
+    def __init__(
+        self,
+        max_experts_per_worker: int,
+        out_dtype: torch.dtype,
+        per_act_token_quant: bool,
+        per_out_ch_quant: bool,
+        block_shape: Optional[list[int]] = None,
+        use_batched_format: bool = False,
+    ):
+        super().__init__(
+            FusedMoEQuantConfig(
+                quant_dtype=torch.uint8,
+                per_act_token_quant=per_act_token_quant,
+                per_out_ch_quant=per_out_ch_quant,
+                block_shape=block_shape,
+            ))
+        self.max_experts_per_worker = max_experts_per_worker
+        self.out_dtype = out_dtype
+        self.use_batched_format = use_batched_format
+
+    @property
+    def activation_formats(
+        self
+    ) -> tuple[mk.FusedMoEActivationFormat, mk.FusedMoEActivationFormat]:
+        if self.use_batched_format:
+            return (mk.FusedMoEActivationFormat.BatchedExperts,
+                    mk.FusedMoEActivationFormat.BatchedExperts)
+        else:
+            return (mk.FusedMoEActivationFormat.Standard,
+                    mk.FusedMoEActivationFormat.Standard)
+
+    def supports_expert_map(self) -> bool:
+        return False
+
+    def supports_chunking(self) -> bool:
+        return True
+
+    def finalize_weight_and_reduce_impl(self) -> mk.TopKWeightAndReduce:
+        # Let PrepareAndFinalize::finalize() decide the impl.
+        return TopKWeightAndReduceDelegate()
+
+    def workspace_shapes(
+        self,
+        a: torch.Tensor,
+        aq: torch.Tensor,
+        M: int,
+        N: int,
+        K: int,
+        topk: int,
+        global_num_experts: int,
+        local_num_experts: int,
+        expert_tokens_meta: Optional[mk.ExpertTokensMetadata],
+    ) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...], torch.dtype]:
+        workspace1: tuple[int, ...] = ()
+        workspace2: tuple[int, ...] = ()
+        output: tuple[int, ...] = ()
+        if self.use_batched_format:
+            padded_M = aq.size(1)
+            workspace1 = (self.max_experts_per_worker, padded_M, max(N, K))
+            workspace2 = (self.max_experts_per_worker, padded_M, (N // 2))
+            output = (self.max_experts_per_worker, padded_M, K)
+        else:
+            workspace1 = (M * topk, max(2 * N, K))
+            workspace2 = (M * topk, N)
+            output = (M, K)
+        return (workspace1, workspace2, output,
+                self.out_dtype if self.out_dtype is not None else a.dtype)
+
+    def apply(self, output: torch.Tensor, hidden_states: torch.Tensor,
+              w1: torch.Tensor, w2: torch.Tensor, topk_weights: torch.Tensor,
+              topk_ids: torch.Tensor, activation: str, global_num_experts: int,
+              expert_map: Optional[torch.Tensor], w1_scale: torch.Tensor,
+              w2_scale: torch.Tensor, w1_zp: Optional[torch.Tensor],
+              w2_zp: Optional[torch.Tensor], a1q_scale: Optional[torch.Tensor],
+              a2_scale: torch.Tensor, workspace13: Optional[torch.Tensor],
+              workspace2: Optional[torch.Tensor],
+              expert_tokens_meta: Optional[mk.ExpertTokensMetadata],
+              apply_router_weight_on_input: bool,
+              extra_expert_args: Optional[dict[str, Any]]):
+        required_keys = [
+            "g1_alphas", "g2_alphas", "a1_gscale", "a2_gscale", "m", "n", "k",
+            "e", "device"
+        ]
+        (g1_alphas, g2_alphas, a1_gscale, a2_gscale, m, n, k, e,
+         device) = extract_required_args(extra_expert_args, required_keys)
+        run_cutlass_moe_fp4(
+            output=output,
+            a=hidden_states,
+            a1_gscale=a1_gscale,
+            w1_fp4=w1,
+            w1_blockscale=w1_scale,
+            w1_alphas=g1_alphas,
+            a2_gscale=a2_gscale,
+            w2_fp4=w2,
+            w2_blockscale=w2_scale,
+            w2_alphas=g2_alphas,
+            topk_weights=topk_weights,
+            topk_ids=topk_ids,
+            workspace13=workspace13,
+            workspace2=workspace2,
+            m=m,
+            n=n,
+            k=k,
+            e=e,
+            device=device,
+            apply_router_weight_on_input=apply_router_weight_on_input,
+        )
+
+
+def cutlass_moe_fp4(
+    a: torch.Tensor,
+    w1_fp4: torch.Tensor,
+    w2_fp4: torch.Tensor,
+    w1_blockscale: torch.Tensor,
+    w2_blockscale: torch.Tensor,
+    g1_alphas: torch.Tensor,
+    g2_alphas: torch.Tensor,
+    a1_gscale: torch.Tensor,
+    a2_gscale: torch.Tensor,
+    topk_weights: torch.Tensor,
+    topk_ids: torch.Tensor,
+    m: int,
+    n: int,
+    k: int,
+    e: int,
+    device: torch.device,
+    expert_map: Optional[torch.Tensor] = None,
+    apply_router_weight_on_input: bool = False) -> torch.Tensor:
+    assert expert_map is None, ("Expert Parallelism / expert_map "
+                                "is currently not supported for "
+                                "ModelOptNvFp4FusedMoE's cutlass_moe_fp4.")
+    fn = mk.FusedMoEModularKernel(
+        MoEPrepareAndFinalizeNoEP(),
+        CutlassExpertsFp4(
+            max_experts_per_worker=e,
+            out_dtype=a.dtype,
+            per_act_token_quant=False,
+            per_out_ch_quant=False,
+            use_batched_format=False,
+        ),
+    )
+    extra_expert_args = {
+        'g1_alphas': g1_alphas,
+        'g2_alphas': g2_alphas,
+        'a1_gscale': a1_gscale,
+        'a2_gscale': a2_gscale,
+        'm': m,
+        'n': n,
+        'k': k,
+        'e': e,
+        'device': device,
+    }
+
+    # NVFP4 requires two levels of quantization, which involves computing some
+    # scaling factors dynamically. This makes it incompatible with the typical
+    # prepare -> MoE -> finalize pipeline. Move the quantization logic into the
+    # MoE body.
+    extra_prepare_args = {
+        'skip_quant': True,
+    }
+    # Similar reason as above.
+    extra_finalize_args = {
+        'skip_weight_reduce': True,
+    }
+    return fn(
+        hidden_states=a,
+        w1=w1_fp4,
+        w2=w2_fp4,
+        topk_weights=topk_weights,
+        topk_ids=topk_ids,
+        inplace=False,
+        activation="silu",
+        global_num_experts=e,
+        expert_map=None,
+        w1_scale=w1_blockscale,
+        w2_scale=w2_blockscale,
+        a1_scale=None,
+        a2_scale=None,
+        apply_router_weight_on_input=apply_router_weight_on_input,
+        extra_expert_args=extra_expert_args,
+        extra_prepare_args=extra_prepare_args,
+        extra_finalize_args=extra_finalize_args,
+    )
+
+
 def _valid_cutlass_block_scaled_grouped_gemm(
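
Backend-specific values now ride along in the extra_expert_args / extra_prepare_args / extra_finalize_args dicts and are pulled back out with extract_required_args from fused_moe/utils.py. The diff does not show that helper; the sketch below is a plausible reading of its contract (ordered lookup of required keys), not the actual implementation:

    from typing import Any, Optional

    def extract_required_args(args: Optional[dict[str, Any]],
                              required_keys: list[str]) -> tuple[Any, ...]:
        # Presumed contract: every key must be present, and values come back
        # in the requested order so callers can tuple-unpack them.
        assert args is not None, "extra args dict must be provided"
        missing = [k for k in required_keys if k not in args]
        assert not missing, f"missing required args: {missing}"
        return tuple(args[k] for k in required_keys)

    g1_alphas, a1_gscale = extract_required_args(
        {"g1_alphas": 1.0, "a1_gscale": 2.0}, ["g1_alphas", "a1_gscale"])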
@@ -1,7 +1,7 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 import functools
-from typing import Optional
+from typing import Any, Optional

 import torch

@@ -152,6 +152,7 @@ class DeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
         workspace2: torch.Tensor,
         expert_tokens_meta: Optional[mk.ExpertTokensMetadata],
         apply_router_weight_on_input: bool,
+        extra_expert_args: Optional[dict[str, Any]],
     ):
         assert self.block_shape is not None
         assert a1q_scale is not None
@@ -1,6 +1,6 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
-from typing import Optional
+from typing import Any, Optional

 import deep_ep
 import torch
@@ -127,16 +127,12 @@ class DeepEPHTPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
                 expert_topk_weights)

     def prepare(
-        self,
-        a1: torch.Tensor,
-        a1_scale: Optional[torch.Tensor],
-        a2_scale: Optional[torch.Tensor],
-        topk_weights: torch.Tensor,
-        topk_ids: torch.Tensor,
-        num_experts: int,
-        expert_map: Optional[torch.Tensor],
-        apply_router_weight_on_input: bool,
+        self, a1: torch.Tensor, a1_scale: Optional[torch.Tensor],
+        a2_scale: Optional[torch.Tensor], topk_weights: torch.Tensor,
+        topk_ids: torch.Tensor, num_experts: int,
+        expert_map: Optional[torch.Tensor], apply_router_weight_on_input: bool,
         quant_config: FusedMoEQuantConfig,
+        extra_prepare_args: Optional[dict[str, Any]]
     ) -> tuple[torch.Tensor, Optional[torch.Tensor],
                Optional[mk.ExpertTokensMetadata], Optional[torch.Tensor],
                Optional[torch.Tensor]]:
@@ -191,7 +187,8 @@ class DeepEPHTPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
     def finalize(self, output: torch.Tensor, fused_expert_output: torch.Tensor,
                  topk_weights: torch.Tensor, topk_ids: torch.Tensor,
                  apply_router_weight_on_input: bool,
-                 weight_and_reduce_impl: mk.TopKWeightAndReduce) -> None:
+                 weight_and_reduce_impl: mk.TopKWeightAndReduce,
+                 extra_finalize_args: Optional[dict[str, Any]]) -> None:

         assert self.handle is not None

@@ -1,6 +1,6 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
-from typing import Optional, Union
+from typing import Any, Optional, Union

 import deep_ep
 import torch
@@ -111,16 +111,12 @@ class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
         return x, x_scales

     def prepare(
-        self,
-        a1: torch.Tensor,
-        a1_scale: Optional[torch.Tensor],
-        a2_scale: Optional[torch.Tensor],
-        topk_weights: torch.Tensor,
-        topk_ids: torch.Tensor,
-        num_experts: int,
-        expert_map: Optional[torch.Tensor],
-        apply_router_weight_on_input: bool,
+        self, a1: torch.Tensor, a1_scale: Optional[torch.Tensor],
+        a2_scale: Optional[torch.Tensor], topk_weights: torch.Tensor,
+        topk_ids: torch.Tensor, num_experts: int,
+        expert_map: Optional[torch.Tensor], apply_router_weight_on_input: bool,
         quant_config: FusedMoEQuantConfig,
+        extra_prepare_args: Optional[dict[str, Any]]
     ) -> tuple[torch.Tensor, Optional[torch.Tensor],
                Optional[mk.ExpertTokensMetadata], Optional[torch.Tensor],
                Optional[torch.Tensor]]:
@@ -169,7 +165,8 @@ class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
     def finalize(self, output: torch.Tensor, fused_expert_output: torch.Tensor,
                  topk_weights: torch.Tensor, topk_ids: torch.Tensor,
                  apply_router_weight_on_input: bool,
-                 weight_and_reduce_impl: mk.TopKWeightAndReduce) -> None:
+                 weight_and_reduce_impl: mk.TopKWeightAndReduce,
+                 extra_finalize_args: Optional[dict[str, Any]]) -> None:
         assert isinstance(
             weight_and_reduce_impl, TopKWeightAndReduceDelegate
         ), ("Weight application and reduction happens in the combine kernel.")
vllm/model_executor/layers/fused_moe/flashinfer_cutlass_moe.py (new file, 198 lines)
@@ -0,0 +1,198 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+from typing import Any, Optional
+
+import torch
+
+import vllm.model_executor.layers.fused_moe.modular_kernel as mk
+from vllm.logger import init_logger
+from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig
+from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import (
+    TopKWeightAndReduceDelegate)
+from vllm.model_executor.layers.fused_moe.utils import extract_required_args
+from vllm.utils.flashinfer import (flashinfer_cutlass_fused_moe,
+                                   has_flashinfer_cutlass_fused_moe)
+
+logger = init_logger(__name__)
+
+
+def is_valid_flashinfer_cutlass_fused_moe(hidden_states: torch.Tensor,
+                                          w1: torch.Tensor,
+                                          w2: torch.Tensor) -> bool:
+    """
+    Check if the given problem size is supported by the FlashInfer CUTLASS MoE
+    kernel.
+    """
+    if not has_flashinfer_cutlass_fused_moe():
+        logger.debug_once("FlashInferExperts disabled: "
+                          "flashinfer_cutlass_fused_moe not available.")
+        return False
+    # Data type checks
+    if (w1.dtype != torch.uint8 or w2.dtype != torch.uint8
+            or hidden_states.dtype
+            not in [torch.float32, torch.float16, torch.bfloat16]):
+        logger.debug_once(
+            "FlashInferExperts disabled: w1/w2 must be torch.uint8 "
+            f"(got w1={w1.dtype}, w2={w2.dtype}), hidden_states must be "
+            f"float32, float16, or bfloat16 (got {hidden_states.dtype}).")
+        return False
+    return True
+
+
+class FlashInferExperts(mk.FusedMoEPermuteExpertsUnpermute):
+
+    def __init__(
+        self,
+        use_nvfp4_w4a4: bool = False,
+        use_fp8_w8a8: bool = False,
+        use_dp: bool = False,
+        ep_rank: int = 0,
+        ep_size: int = 1,
+        tp_rank: int = 0,
+        tp_size: int = 1,
+        num_dispatchers: Optional[int] = None,
+        use_batched_format: bool = False,
+    ):
+        super().__init__(
+            FusedMoEQuantConfig(
+                quant_dtype=torch.uint8,
+                per_act_token_quant=False,
+                block_shape=None,
+            ))
+        self.use_nvfp4_w4a4 = use_nvfp4_w4a4
+        self.use_fp8_w8a8 = use_fp8_w8a8
+        self.ep_rank = ep_rank
+        self.ep_size = ep_size
+        self.tp_rank = tp_rank
+        self.tp_size = tp_size
+        self.use_dp = use_dp
+        assert not use_batched_format or num_dispatchers is not None
+        self.num_dispatchers = num_dispatchers
+
+    @property
+    def activation_formats(
+        self
+    ) -> tuple[mk.FusedMoEActivationFormat, mk.FusedMoEActivationFormat]:
+        return (mk.FusedMoEActivationFormat.Standard,
+                mk.FusedMoEActivationFormat.Standard)
+
+    def supports_expert_map(self) -> bool:
+        return False
+
+    def supports_chunking(self) -> bool:
+        # This refers to TP chunking; DP chunking is handled separately.
+        return True
+
+    def finalize_weight_and_reduce_impl(self) -> mk.TopKWeightAndReduce:
+        # Let PrepareAndFinalize::finalize() decide the impl.
+        return TopKWeightAndReduceDelegate()
+
+    def workspace_shapes(
+        self,
+        a: torch.Tensor,
+        aq: torch.Tensor,
+        M: int,
+        N: int,
+        K: int,
+        topk: int,
+        global_num_experts: int,
+        local_num_experts: int,
+        expert_tokens_meta: Optional[mk.ExpertTokensMetadata],
+    ) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...], torch.dtype]:
+        # We use global_num_experts due to how moe_align_block_size handles
+        # expert_maps.
+        """
+        Compute the shapes for the temporary and final outputs of the two gemms
+        and activation in the fused expert function. Since the gemms are
+        independent, the workspace for the first gemm can be shared with the
+        workspace for the last gemm.
+
+        Returns a tuple of:
+        - workspace13 shape tuple: must be large enough to hold the
+          result of either expert gemm.
+        - workspace2 shape tuple: must be large enough to hold the
+          result of the activation function.
+        - output shape tuple: must be exact size of the final gemm output.
+        - Workspace type: The dtype to use for the workspace tensors.
+        - Note: in order for activation chunking to work, the first dimension
+          of each tuple must be the number of tokens.
+        """
+        assert self.use_nvfp4_w4a4 is True, ("Only nvfp4 quantization is "
+                                             "currently supported.")
+        aq_m, aq_n = aq.shape
+        workspace2 = ()
+        output_shape = (aq_m, aq_n * 2)
+        workspace_dtype = a.dtype
+        workspace1 = output_shape
+        # The workspace is determined by `aq`, since it comes after any
+        # potential communication op and is involved in the expert computation.
+        return (workspace1, workspace2, output_shape, workspace_dtype)
+
+    def apply(
+        self,
+        output: torch.Tensor,
+        hidden_states: torch.Tensor,
+        w1: torch.Tensor,
+        w2: torch.Tensor,
+        topk_weights: torch.Tensor,
+        topk_ids: torch.Tensor,
+        activation: str,
+        global_num_experts: int,
+        expert_map: Optional[torch.Tensor],
+        w1_scale: Optional[torch.Tensor],
+        w2_scale: Optional[torch.Tensor],
+        w1_zp: Optional[torch.Tensor],
+        w2_zp: Optional[torch.Tensor],
+        a1q_scale: Optional[torch.Tensor],
+        a2_scale: Optional[torch.Tensor],  # Not used
+        workspace13: Optional[torch.Tensor],
+        workspace2: Optional[torch.Tensor],
+        expert_tokens_meta: Optional[mk.ExpertTokensMetadata],
+        apply_router_weight_on_input: Optional[bool],
+        extra_expert_args: Optional[dict[str, Any]],
+    ):
+        assert extra_expert_args is not None, \
+            "extra_expert_args must be provided"
+        required_keys = [
+            'g1_alphas', 'g2_alphas', 'a1_gscale', 'a2_gscale', 'out_dtype'
+        ]
+
+        g1_alphas, g2_alphas, a1_gscale, a2_gscale, out_dtype = (
+            extract_required_args(extra_expert_args, required_keys))
+
+        # Flashinfer CUTLASS kernel takes scalar global scales,
+        # min because inv_scale.
+        assert self.use_nvfp4_w4a4 is True, ("Only nvfp4 quantization is "
+                                             "currently supported.")
+
+        # Ensure w1_scale and w2_scale are not None before calling view
+        assert w1_scale is not None and w2_scale is not None, (
+            "w1_scale and w2_scale must not "
+            "be None for FlashInferExperts")
+
+        assert not apply_router_weight_on_input
+
+        quant_scales = [
+            a1_gscale,
+            w1_scale.view(torch.int32),
+            g1_alphas,
+            a2_gscale,
+            w2_scale.view(torch.int32),
+            g2_alphas,
+        ]
+        _ = flashinfer_cutlass_fused_moe(
+            hidden_states,
+            topk_ids.to(torch.int),
+            topk_weights,
+            # FlashInfer API requires weight to be long for nvfp4
+            w1.view(torch.long),
+            w2.view(torch.long),
+            output_dtype=out_dtype,
+            quant_scales=quant_scales,
+            input_sf=a1q_scale,
+            tp_size=self.tp_size,
+            tp_rank=self.tp_rank,
+            ep_size=self.ep_size,
+            ep_rank=self.ep_rank,
+            output=output,
+        )
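
A usage sketch for the validity check defined in the new file above; the tensors are stand-ins (the check only inspects dtypes, so shapes and values here are arbitrary assumptions):

    import torch

    from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe import (
        is_valid_flashinfer_cutlass_fused_moe)

    hidden = torch.randn(8, 64, dtype=torch.bfloat16)
    w1 = torch.zeros(4, 16, 8, dtype=torch.uint8)  # packed FP4 weights
    w2 = torch.zeros(4, 8, 8, dtype=torch.uint8)

    if is_valid_flashinfer_cutlass_fused_moe(hidden, w1, w2):
        pass  # safe to route this layer to FlashInferExperts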
@@ -0,0 +1,114 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+from typing import Any, Optional
+
+import torch
+
+import vllm.envs as envs
+import vllm.model_executor.layers.fused_moe.modular_kernel as mk
+from vllm.distributed import get_dp_group
+from vllm.forward_context import get_forward_context
+from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig
+from vllm.model_executor.layers.fused_moe.utils import (
+    extract_required_args, moe_kernel_quantize_input)
+from vllm.utils.flashinfer import fp4_swizzle_blockscale
+
+
+def get_local_sizes(local_tokens):
+    cu_sizes = get_forward_context().dp_metadata.cu_tokens_across_dp_cpu
+    sizes = [cu_sizes[0].item()]
+    for i in range(1, len(cu_sizes)):
+        sizes.append((cu_sizes[i] - cu_sizes[i - 1]).item())
+    max_num_tokens = envs.VLLM_MOE_DP_CHUNK_SIZE
+    sizes_chunked = [max_num_tokens] * len(sizes)
+    if local_tokens < max_num_tokens:
+        # When the number of local tokens is less than max_num_tokens, all other
+        # ranks will also have fewer than max_num_tokens. The remaining tokens
+        # are accounted for as residual.
+        sizes_chunked = [x % max_num_tokens for x in sizes]
+
+    return sizes_chunked
+
+
+class FlashInferCutlassMoEPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
+
+    def __init__(
+        self,
+        quant_dtype: Optional[torch.dtype] = None,
+        per_channel_quant: bool = False,
+        block_shape: Optional[list[int]] = None,
+        num_dispatchers: int = 1,
+    ):
+        super().__init__()
+        self.per_channel_quant = per_channel_quant
+        self.block_shape = block_shape
+        self.quant_dtype = quant_dtype
+        self.num_dispatchers_ = num_dispatchers
+
+    @property
+    def activation_format(self) -> mk.FusedMoEActivationFormat:
+        return mk.FusedMoEActivationFormat.Standard
+
+    def max_num_tokens_per_rank(self) -> Optional[int]:
+        return None
+
+    def topk_indices_dtype(self) -> Optional[torch.dtype]:
+        return None
+
+    def num_dispatchers(self) -> int:
+        return self.num_dispatchers_
+
+    def prepare(
+        self,
+        a1: torch.Tensor,
+        a1_scale: Optional[torch.Tensor],  # Not used
+        a2_scale: Optional[torch.Tensor],  # Not used
+        topk_weights: torch.Tensor,
+        topk_ids: torch.Tensor,
+        num_experts: int,
+        expert_map: Optional[torch.Tensor],
+        apply_router_weight_on_input: bool,
+        quant_config: FusedMoEQuantConfig,
+        extra_prepare_args: Optional[dict[str, Any]]
+    ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor],
+               Optional[torch.Tensor], Optional[torch.Tensor]]:
+
+        assert not apply_router_weight_on_input
+
+        (a1_gscale, use_dp, local_tokens) = extract_required_args(
+            extra_prepare_args, ['a1_gscale', 'use_dp', 'local_tokens'])
+
+        a1q, a1q_scale = moe_kernel_quantize_input(
+            a1,
+            a1_gscale,
+            quant_config.quant_dtype,
+            self.per_channel_quant,
+            self.block_shape,
+            is_fp4_scale_swizzled=not use_dp,  # Swizzling after communication
+        )
+        if use_dp:
+            topk_weights, topk_ids, a1q, a1q_scale = \
+                get_dp_group().all_gatherv([topk_weights, topk_ids, a1q, a1q_scale],  # noqa: E501
+                                           dim=0,
+                                           sizes=get_local_sizes(local_tokens))
+            a1_m, a1_n = a1q.shape
+            a1q_scale = fp4_swizzle_blockscale(a1q_scale, a1_m, a1_n * 2)
+
+        return a1q, a1q_scale, None, topk_ids, topk_weights
+
+    def finalize(self, output: torch.Tensor, fused_expert_output: torch.Tensor,
+                 topk_weights: torch.Tensor, topk_ids: torch.Tensor,
+                 apply_router_weight_on_input: bool,
+                 weight_and_reduce_impl: mk.TopKWeightAndReduce,
+                 extra_finalize_args: Optional[dict[str, Any]]) -> None:
+
+        (use_dp,
+         local_tokens) = extract_required_args(extra_finalize_args,
+                                               ['use_dp', 'local_tokens'])
+        if use_dp:
+            fused_expert_output = get_dp_group().reduce_scatterv(
+                fused_expert_output,
+                dim=0,
+                sizes=get_local_sizes(local_tokens),
+            )
+        output.copy_(fused_expert_output)
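
A worked example of the per-rank size computation in get_local_sizes above, with made-up numbers: cumulative DP token counts [3, 7, 12] unfold into per-rank sizes [3, 4, 5]; when the local rank has fewer tokens than VLLM_MOE_DP_CHUNK_SIZE (assumed to be 8 here), the residual sizes are what gets passed to all_gatherv / reduce_scatterv:

    cu_sizes = [3, 7, 12]            # cumulative tokens across DP ranks
    sizes = [cu_sizes[0]] + [cu_sizes[i] - cu_sizes[i - 1]
                             for i in range(1, len(cu_sizes))]
    max_num_tokens = 8               # stand-in for VLLM_MOE_DP_CHUNK_SIZE
    local_tokens = 3
    sizes_chunked = [max_num_tokens] * len(sizes)
    if local_tokens < max_num_tokens:
        sizes_chunked = [x % max_num_tokens for x in sizes]
    assert sizes_chunked == [3, 4, 5]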
@@ -1,7 +1,7 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 """Fused batched MoE kernel."""
-from typing import Optional
+from typing import Any, Optional

 import torch

@@ -496,16 +496,12 @@ class BatchedPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
         return self.num_dispatchers_

     def prepare(
-        self,
-        a1: torch.Tensor,
-        a1_scale: Optional[torch.Tensor],
-        a2_scale: Optional[torch.Tensor],
-        topk_weights: torch.Tensor,
-        topk_ids: torch.Tensor,
-        num_experts: int,
-        expert_map: Optional[torch.Tensor],
-        apply_router_weight_on_input: bool,
+        self, a1: torch.Tensor, a1_scale: Optional[torch.Tensor],
+        a2_scale: Optional[torch.Tensor], topk_weights: torch.Tensor,
+        topk_ids: torch.Tensor, num_experts: int,
+        expert_map: Optional[torch.Tensor], apply_router_weight_on_input: bool,
         quant_config: FusedMoEQuantConfig,
+        extra_prepare_args: Optional[dict[str, Any]]
     ) -> tuple[torch.Tensor, Optional[torch.Tensor],
                Optional[mk.ExpertTokensMetadata], Optional[torch.Tensor],
                Optional[torch.Tensor]]:
@@ -594,15 +590,11 @@ class BatchedPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):

         return b_a1, b_a1_scale, expert_tokens_meta, None, None

-    def finalize(
-        self,
-        output: torch.Tensor,
-        fused_expert_output: torch.Tensor,
-        topk_weights: torch.Tensor,
-        topk_ids: torch.Tensor,
-        apply_router_weight_on_input: bool,
-        weight_and_reduce_impl: mk.TopKWeightAndReduce,
-    ) -> None:
+    def finalize(self, output: torch.Tensor, fused_expert_output: torch.Tensor,
+                 topk_weights: torch.Tensor, topk_ids: torch.Tensor,
+                 apply_router_weight_on_input: bool,
+                 weight_and_reduce_impl: mk.TopKWeightAndReduce,
+                 extra_finalize_args: Optional[dict[str, Any]]) -> None:
         if isinstance(weight_and_reduce_impl, TopKWeightAndReduceDelegate):
             weight_and_reduce_impl = TopKWeightAndReduceNaiveBatched(self.rank)
         weight_and_reduce_impl.apply(
@@ -706,7 +698,8 @@ class NaiveBatchedExperts(mk.FusedMoEPermuteExpertsUnpermute):
              a2_scale: Optional[torch.Tensor], workspace13: torch.Tensor,
              workspace2: torch.Tensor,
              expert_tokens_meta: Optional[mk.ExpertTokensMetadata],
-              apply_router_weight_on_input: bool):
+              apply_router_weight_on_input: bool,
+              extra_expert_args: Optional[dict[str, Any]]):
         assert hidden_states.dim() == 3
         assert expert_tokens_meta is not None
         expert_num_tokens = expert_tokens_meta.expert_num_tokens
@@ -911,7 +904,8 @@ class BatchedTritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
              a2_scale: Optional[torch.Tensor], workspace13: torch.Tensor,
              workspace2: torch.Tensor,
              expert_tokens_meta: Optional[mk.ExpertTokensMetadata],
-              apply_router_weight_on_input: bool):
+              apply_router_weight_on_input: bool,
+              extra_expert_args: Optional[dict[str, Any]]):
         # Check constraints.
         if self.use_int4_w4a16:
             assert hidden_states.size(-1) // 2 == w1.size(2), (
@@ -1646,6 +1646,7 @@ class TritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
         workspace2: torch.Tensor,
         expert_tokens_meta: Optional[mk.ExpertTokensMetadata],
         apply_router_weight_on_input: bool,
+        extra_expert_args: Optional[dict[str, Any]],
     ):
         # Check constraints.
         if self.use_int4_w4a16:
@@ -34,6 +34,7 @@ from vllm.model_executor.utils import set_weight_attrs
 from vllm.platforms import current_platform
 from vllm.platforms.interface import CpuArchEnum
 from vllm.utils import direct_register_custom_op, has_deep_ep, has_pplx
+from vllm.utils.flashinfer import has_flashinfer

 if current_platform.is_cuda_alike():
     from .fused_batched_moe import BatchedTritonExperts
@@ -45,6 +46,9 @@ if current_platform.is_cuda_alike():
     from .deepep_ht_prepare_finalize import DeepEPHTPrepareAndFinalize
     from .deepep_ll_prepare_finalize import (DEEPEP_QUANT_BLOCK_SHAPE,
                                              DeepEPLLPrepareAndFinalize)
+    if has_flashinfer():
+        from .flashinfer_cutlass_prepare_finalize import (
+            FlashInferCutlassMoEPrepareAndFinalize)
 else:
     fused_experts = None  # type: ignore
     FusedMoEPermuteExpertsUnpermute = None  # type: ignore
@@ -99,6 +103,9 @@ class FusedMoEMethodBase(QuantizeMethodBase):

         prepare_finalize: Optional[FusedMoEPrepareAndFinalize] = None

+        if moe.use_flashinfer_cutlass_kernels:
+            prepare_finalize = FlashInferCutlassMoEPrepareAndFinalize(
+                quant_dtype=moe.quant_dtype, )
         if moe.use_pplx_kernels:
             hidden_dim_bytes, hidden_scale_bytes = pplx_hidden_dim_scale_bytes(
                 moe.max_num_tokens,
@@ -204,6 +211,12 @@ class FusedMoEMethodBase(QuantizeMethodBase):
             f"{self.__class__.__name__} must select appropriate gemm "
             "implementation based on the prepare_finalize")

+    def maybe_swap_experts_impl(
+        self,
+        moe_parallel_config: FusedMoEParallelConfig,
+    ):
+        pass
+
     @abstractmethod
     def apply(
         self,
@@ -744,12 +757,15 @@ class FusedMoE(torch.nn.Module):
             moe_quant_params["intermediate_size_full"] = intermediate_size

         self.quant_method.create_weights(layer=self, **moe_quant_params)
+        if isinstance(self.quant_method, FusedMoEMethodBase):
+            self.quant_method.maybe_swap_experts_impl(self.moe_parallel_config)

         # Chunked all2all staging tensor
         self.batched_hidden_states: Optional[torch.Tensor] = None
         self.batched_router_logits: Optional[torch.Tensor] = None
         if (self.moe_parallel_config.use_pplx_kernels
-                or self.moe_parallel_config.use_deepep_ll_kernels):
+                or self.moe_parallel_config.use_deepep_ll_kernels
+                or self.moe_parallel_config.use_flashinfer_cutlass_kernels):
             self.batched_hidden_states = torch.zeros(
                 (moe.max_num_tokens, self.hidden_size),
                 dtype=moe.in_dtype,
@@ -801,6 +817,10 @@ class FusedMoE(torch.nn.Module):
     def use_deepep_ll_kernels(self):
         return self.moe_parallel_config.use_deepep_ll_kernels

+    @property
+    def use_flashinfer_cutlass_kernels(self):
+        return self.moe_parallel_config.use_flashinfer_cutlass_kernels
+
     def _load_per_tensor_weight_scale(self, shard_id: str,
                                       param: torch.nn.Parameter,
                                       loaded_weight: torch.Tensor,
@@ -1402,9 +1422,9 @@ class FusedMoE(torch.nn.Module):
                     final_hidden_states, non_blocking=True)

         ctx = get_forward_context()
+        # flashinfer_cutlass_kernels can handle: optional DP + TP/EP
         max_tokens_across_dp = ctx.dp_metadata.max_tokens_across_dp_cpu
         moe_dp_chunk_size_per_rank = self.moe_config.max_num_tokens
-
         num_tokens = full_hidden_states.size(0)
         for chunk_start_ in range(0, max_tokens_across_dp,
                                   moe_dp_chunk_size_per_rank):
@@ -1424,13 +1444,20 @@ class FusedMoE(torch.nn.Module):
     def forward_impl(self, hidden_states: torch.Tensor,
                      router_logits: torch.Tensor):
         assert self.quant_method is not None
+        # Route to the chunked forward path using the FlashInfer Cutlass kernel
+        # only when data parallelism (DP) is enabled.
+        use_flashinfer_cutlass_kernels = (
+            self.dp_size > 1
+            and self.moe_parallel_config.use_flashinfer_cutlass_kernels)
         if (self.moe_parallel_config.use_pplx_kernels
-                or self.moe_parallel_config.use_deepep_ll_kernels):
+                or self.moe_parallel_config.use_deepep_ll_kernels
+                or use_flashinfer_cutlass_kernels):
             return self.forward_impl_chunked(hidden_states, router_logits)

         do_naive_dispatch_combine: bool = (
             self.dp_size > 1
-            and not self.moe_parallel_config.use_deepep_ht_kernels)
+            and not self.moe_parallel_config.use_deepep_ht_kernels
+            and not self.moe_parallel_config.use_flashinfer_cutlass_kernels)
         if do_naive_dispatch_combine:
             hidden_states, router_logits = get_ep_group().dispatch(
                 hidden_states, router_logits)
@@ -1460,7 +1487,6 @@ class FusedMoE(torch.nn.Module):

         if do_naive_dispatch_combine:
             final_hidden_states = get_ep_group().combine(final_hidden_states)
-
         if self.reduce_results and (self.tp_size > 1 or self.ep_size > 1):
             # Default set to False. (May have to add shared expert outputs.
             final_hidden_states = self.maybe_all_reduce_tensor_model_parallel(
@ -4,7 +4,7 @@ from abc import ABC, abstractmethod
|
|||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
from math import prod
|
from math import prod
|
||||||
from typing import Optional, final
|
from typing import Any, Optional, final
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
@ -150,16 +150,12 @@ class FusedMoEPrepareAndFinalize(ABC):
|
|||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def prepare(
|
def prepare(
|
||||||
self,
|
self, a1: torch.Tensor, a1_scale: Optional[torch.Tensor],
|
||||||
a1: torch.Tensor,
|
a2_scale: Optional[torch.Tensor], topk_weights: torch.Tensor,
|
||||||
a1_scale: Optional[torch.Tensor],
|
topk_ids: torch.Tensor, num_experts: int,
|
||||||
a2_scale: Optional[torch.Tensor],
|
expert_map: Optional[torch.Tensor], apply_router_weight_on_input: bool,
|
||||||
topk_weights: torch.Tensor,
|
|
||||||
topk_ids: torch.Tensor,
|
|
||||||
num_experts: int,
|
|
||||||
expert_map: Optional[torch.Tensor],
|
|
||||||
apply_router_weight_on_input: bool,
|
|
||||||
quant_config: FusedMoEQuantConfig,
|
quant_config: FusedMoEQuantConfig,
|
||||||
|
extra_prepare_args: Optional[dict[str, Any]]
|
||||||
) -> tuple[torch.Tensor, Optional[torch.Tensor],
|
) -> tuple[torch.Tensor, Optional[torch.Tensor],
|
||||||
Optional[ExpertTokensMetadata], Optional[torch.Tensor],
|
Optional[ExpertTokensMetadata], Optional[torch.Tensor],
|
||||||
Optional[torch.Tensor]]:
|
Optional[torch.Tensor]]:
|
||||||
@ -190,15 +186,11 @@ class FusedMoEPrepareAndFinalize(ABC):
|
|||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def finalize(
|
def finalize(self, output: torch.Tensor, fused_expert_output: torch.Tensor,
|
||||||
self,
|
topk_weights: torch.Tensor, topk_ids: torch.Tensor,
|
||||||
output: torch.Tensor,
|
apply_router_weight_on_input: bool,
|
||||||
fused_expert_output: torch.Tensor,
|
weight_and_reduce_impl: TopKWeightAndReduce,
|
||||||
topk_weights: torch.Tensor,
|
extra_finalize_args: Optional[dict[str, Any]]) -> None:
|
||||||
topk_ids: torch.Tensor,
|
|
||||||
apply_router_weight_on_input: bool,
|
|
||||||
weight_and_reduce_impl: TopKWeightAndReduce,
|
|
||||||
) -> None:
|
|
||||||
"""
|
"""
|
||||||
Perform any combine plus apply weights and perform a reduction on the
|
Perform any combine plus apply weights and perform a reduction on the
|
||||||
fused experts output.
|
fused experts output.
|
||||||
@ -376,6 +368,7 @@ class FusedMoEPermuteExpertsUnpermute(ABC):
|
|||||||
workspace2: torch.Tensor,
|
workspace2: torch.Tensor,
|
||||||
expert_tokens_meta: Optional[ExpertTokensMetadata],
|
expert_tokens_meta: Optional[ExpertTokensMetadata],
|
||||||
apply_router_weight_on_input: bool,
|
apply_router_weight_on_input: bool,
|
||||||
|
extra_expert_args: Optional[dict[str, Any]],
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
This function computes the intermediate result of a Mixture of Experts
|
This function computes the intermediate result of a Mixture of Experts
|
||||||
@ -460,21 +453,19 @@ class FusedMoEModularKernel(torch.nn.Module):
|
|||||||
f"{fused_experts.__class__.__name__}."
|
f"{fused_experts.__class__.__name__}."
|
||||||
f"{fused_experts.activation_formats[0]}")
|
f"{fused_experts.activation_formats[0]}")
|
||||||
|
|
||||||
def _do_fused_experts(self, fused_out: Optional[torch.Tensor],
|
def _do_fused_experts(
|
||||||
a1: torch.Tensor, a1q: torch.Tensor,
|
self, fused_out: Optional[torch.Tensor], a1: torch.Tensor,
|
||||||
w1: torch.Tensor, w2: torch.Tensor,
|
a1q: torch.Tensor, w1: torch.Tensor, w2: torch.Tensor,
|
||||||
topk_weights: torch.Tensor, topk_ids: torch.Tensor,
|
topk_weights: torch.Tensor, topk_ids: torch.Tensor,
|
||||||
activation: str, global_num_experts: int,
|
activation: str, global_num_experts: int, local_num_experts: int,
|
||||||
local_num_experts: int,
|
expert_map: Optional[torch.Tensor],
|
||||||
expert_map: Optional[torch.Tensor],
|
w1_scale: Optional[torch.Tensor], w2_scale: Optional[torch.Tensor],
|
||||||
w1_scale: Optional[torch.Tensor],
|
w1_zp: Optional[torch.Tensor], w2_zp: Optional[torch.Tensor],
|
||||||
w2_scale: Optional[torch.Tensor],
|
a1q_scale: Optional[torch.Tensor],
|
||||||
w1_zp: Optional[torch.Tensor],
|
a2_scale: Optional[torch.Tensor],
|
||||||
w2_zp: Optional[torch.Tensor],
|
expert_tokens_meta: Optional[ExpertTokensMetadata],
|
||||||
a1q_scale: Optional[torch.Tensor],
|
apply_router_weight_on_input: bool,
|
||||||
a2_scale: Optional[torch.Tensor],
|
extra_expert_args: Optional[dict[str, Any]]) -> torch.Tensor:
|
||||||
expert_tokens_meta: Optional[ExpertTokensMetadata],
|
|
||||||
apply_router_weight_on_input: bool) -> torch.Tensor:
|
|
||||||
|
|
||||||
_, M, N, K, top_k = _moe_problem_size(a1q, w1, w2, topk_ids)
|
_, M, N, K, top_k = _moe_problem_size(a1q, w1, w2, topk_ids)
|
||||||
|
|
||||||
@ -517,7 +508,8 @@ class FusedMoEModularKernel(torch.nn.Module):
|
|||||||
workspace13=workspace13,
|
workspace13=workspace13,
|
||||||
workspace2=workspace2,
|
workspace2=workspace2,
|
||||||
expert_tokens_meta=expert_tokens_meta,
|
expert_tokens_meta=expert_tokens_meta,
|
||||||
apply_router_weight_on_input=apply_router_weight_on_input)
|
apply_router_weight_on_input=apply_router_weight_on_input,
|
||||||
|
extra_expert_args=extra_expert_args)
|
||||||
|
|
||||||
return fused_out
|
return fused_out
|
||||||
|
|
||||||
@ -541,6 +533,7 @@ class FusedMoEModularKernel(torch.nn.Module):
|
|||||||
a2_scale: Optional[torch.Tensor],
|
a2_scale: Optional[torch.Tensor],
|
||||||
expert_tokens_meta: Optional[ExpertTokensMetadata],
|
expert_tokens_meta: Optional[ExpertTokensMetadata],
|
||||||
apply_router_weight_on_input: bool,
|
apply_router_weight_on_input: bool,
|
||||||
|
extra_expert_args: Optional[dict[str, Any]],
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
|
|
||||||
_, M, N, K, top_k = _moe_problem_size(a1q, w1, w2, topk_ids)
|
_, M, N, K, top_k = _moe_problem_size(a1q, w1, w2, topk_ids)
|
||||||
@ -568,7 +561,8 @@ class FusedMoEModularKernel(torch.nn.Module):
|
|||||||
a1q_scale=a1q_scale,
|
a1q_scale=a1q_scale,
|
||||||
a2_scale=a2_scale,
|
a2_scale=a2_scale,
|
||||||
expert_tokens_meta=expert_tokens_meta,
|
expert_tokens_meta=expert_tokens_meta,
|
||||||
apply_router_weight_on_input=apply_router_weight_on_input)
|
apply_router_weight_on_input=apply_router_weight_on_input,
|
||||||
|
extra_expert_args=extra_expert_args)
|
||||||
|
|
||||||
# Chunking required case
|
# Chunking required case
|
||||||
assert num_chunks > 1
|
assert num_chunks > 1
|
||||||
@ -624,6 +618,15 @@ class FusedMoEModularKernel(torch.nn.Module):
|
|||||||
expert_num_tokens=c_expert_num_tokens,
|
expert_num_tokens=c_expert_num_tokens,
|
||||||
expert_num_tokens_cpu=c_expert_num_tokens_cpu)
|
expert_num_tokens_cpu=c_expert_num_tokens_cpu)
|
||||||
|
|
||||||
|
m = None
|
||||||
|
if extra_expert_args is not None and 'm' in extra_expert_args:
|
||||||
|
m = extra_expert_args.get('m')
|
||||||
|
|
||||||
|
if extra_expert_args is not None:
|
||||||
|
chunked_extra_expert_args = extra_expert_args
|
||||||
|
else:
|
||||||
|
chunked_extra_expert_args = {}
|
||||||
|
|
||||||
for chunk_idx in range(num_chunks):
|
for chunk_idx in range(num_chunks):
|
||||||
c_a1q, c_a1q_scale, c_a2_scale, c_topk_ids, c_topk_weights = (
|
c_a1q, c_a1q_scale, c_a2_scale, c_topk_ids, c_topk_weights = (
|
||||||
slice_input_tensors(chunk_idx))
|
slice_input_tensors(chunk_idx))
|
||||||
@ -634,6 +637,11 @@ class FusedMoEModularKernel(torch.nn.Module):
|
|||||||
expert_tokens_meta, c_topk_ids, local_num_experts,
|
expert_tokens_meta, c_topk_ids, local_num_experts,
|
||||||
expert_map)
|
expert_map)
|
||||||
|
|
||||||
|
s = chunk_idx * CHUNK_SIZE
|
||||||
|
e = min(s + CHUNK_SIZE, M)
|
||||||
|
|
||||||
|
if m is not None:
|
||||||
|
chunked_extra_expert_args['m'] = e - s
|
||||||
self._do_fused_experts(
|
self._do_fused_experts(
|
||||||
fused_out=slice_output_tensor(chunk_idx),
|
fused_out=slice_output_tensor(chunk_idx),
|
||||||
a1=a1,
|
a1=a1,
|
||||||
@ -653,7 +661,8 @@ class FusedMoEModularKernel(torch.nn.Module):
|
|||||||
a1q_scale=c_a1q_scale,
|
a1q_scale=c_a1q_scale,
|
||||||
a2_scale=c_a2_scale,
|
a2_scale=c_a2_scale,
|
||||||
expert_tokens_meta=c_expert_tokens_meta,
|
expert_tokens_meta=c_expert_tokens_meta,
|
||||||
apply_router_weight_on_input=apply_router_weight_on_input)
|
apply_router_weight_on_input=apply_router_weight_on_input,
|
||||||
|
extra_expert_args=chunked_extra_expert_args)
|
||||||
|
|
||||||
return fused_out
|
return fused_out
|
||||||
|
|
||||||
@ -675,6 +684,9 @@ class FusedMoEModularKernel(torch.nn.Module):
|
|||||||
a1_scale: Optional[torch.Tensor] = None,
|
a1_scale: Optional[torch.Tensor] = None,
|
||||||
a2_scale: Optional[torch.Tensor] = None,
|
a2_scale: Optional[torch.Tensor] = None,
|
||||||
apply_router_weight_on_input: bool = False,
|
apply_router_weight_on_input: bool = False,
|
||||||
|
extra_expert_args: Optional[dict] = None,
|
||||||
|
extra_prepare_args: Optional[dict] = None,
|
||||||
|
extra_finalize_args: Optional[dict] = None,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
"""
|
"""
|
||||||
This function computes a Mixture of Experts (MoE) layer using two sets
|
This function computes a Mixture of Experts (MoE) layer using two sets
|
||||||
@ -707,6 +719,12 @@ class FusedMoEModularKernel(torch.nn.Module):
|
|||||||
- apply_router_weight_on_input (bool): When true, the topk weights are
|
- apply_router_weight_on_input (bool): When true, the topk weights are
|
||||||
applied directly on the inputs. This is only applicable when topk is
|
applied directly on the inputs. This is only applicable when topk is
|
||||||
1.
|
1.
|
||||||
|
- extra_expert_args (Optional[dict]): Extra keyword arguments to pass to
|
||||||
|
fused_experts.apply.
|
||||||
|
- extra_prepare_args (Optional[dict]): Extra keyword arguments to pass
|
||||||
|
to prepare.
|
||||||
|
- extra_finalize_args (Optional[dict]): Extra keyword arguments to pass
|
||||||
|
to finalize.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
- torch.Tensor: The output tensor after applying the MoE layer.
|
- torch.Tensor: The output tensor after applying the MoE layer.
|
||||||
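Note on the three new forward() arguments documented above: each dict is handed through unchanged — extra_prepare_args to prepare_finalize.prepare, extra_expert_args to fused_experts.apply, and extra_finalize_args to prepare_finalize.finalize. A minimal, hypothetical call-site sketch (the kernel object and tensors are assumed to already exist; the dict keys shown are the ones this commit's FlashInfer path uses, and other experts implementations may expect different keys):

```python
def run_moe(kernel, hidden_states, w1, w2, topk_weights, topk_ids):
    """Illustrative only: `kernel` is an already-built FusedMoEModularKernel."""
    n_tokens = hidden_states.shape[0]
    return kernel(
        hidden_states, w1, w2, topk_weights, topk_ids,
        inplace=False,
        # forwarded verbatim to prepare_finalize.prepare(...)
        extra_prepare_args={"use_dp": False, "local_tokens": n_tokens},
        # forwarded verbatim to fused_experts.apply(...)
        extra_expert_args={"out_dtype": hidden_states.dtype},
        # forwarded verbatim to prepare_finalize.finalize(...)
        extra_finalize_args={"use_dp": False, "local_tokens": n_tokens},
    )
```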
@@ -730,6 +748,7 @@ class FusedMoEModularKernel(torch.nn.Module):
             expert_map,
             apply_router_weight_on_input,
             self.fused_experts.quant_config,
+            extra_prepare_args,
         )
 
         # Maybe prepare gathered topk_ids and topk_weights from other EP ranks.
@@ -766,11 +785,13 @@ class FusedMoEModularKernel(torch.nn.Module):
                 a1q_scale=a1q_scale,
                 a2_scale=a2_scale,
                 expert_tokens_meta=expert_tokens_meta,
-                apply_router_weight_on_input=apply_router_weight_on_input)
+                apply_router_weight_on_input=apply_router_weight_on_input,
+                extra_expert_args=extra_expert_args)
 
         self.prepare_finalize.finalize(
             output, fused_out, topk_weights, topk_ids,
             apply_router_weight_on_input,
-            self.fused_experts.finalize_weight_and_reduce_impl())
+            self.fused_experts.finalize_weight_and_reduce_impl(),
+            extra_finalize_args)
 
         return output
@@ -1,6 +1,6 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
-from typing import Optional
+from typing import Any, Optional
 
 import pplx_kernels as pplx
 import torch
@@ -89,16 +89,12 @@ class PplxPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
         return self.num_dispatchers_
 
     def prepare(
-        self,
-        a1: torch.Tensor,
-        a1_scale: Optional[torch.Tensor],
-        a2_scale: Optional[torch.Tensor],
-        topk_weights: torch.Tensor,
-        topk_ids: torch.Tensor,
-        num_experts: int,
-        expert_map: Optional[torch.Tensor],
-        apply_router_weight_on_input: bool,
+        self, a1: torch.Tensor, a1_scale: Optional[torch.Tensor],
+        a2_scale: Optional[torch.Tensor], topk_weights: torch.Tensor,
+        topk_ids: torch.Tensor, num_experts: int,
+        expert_map: Optional[torch.Tensor], apply_router_weight_on_input: bool,
         quant_config: FusedMoEQuantConfig,
+        extra_prepare_args: Optional[dict[str, Any]]
     ) -> tuple[torch.Tensor, Optional[torch.Tensor],
                Optional[mk.ExpertTokensMetadata], Optional[torch.Tensor],
                Optional[torch.Tensor]]:
@@ -217,15 +213,11 @@ class PplxPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
 
         return expert_x, expert_x_scale, expert_tokens_meta, None, None
 
-    def finalize(
-        self,
-        output: torch.Tensor,
-        fused_expert_output: torch.Tensor,
-        topk_weights: torch.Tensor,
-        topk_ids: torch.Tensor,
-        apply_router_weight_on_input: bool,
-        weight_and_reduce_impl: mk.TopKWeightAndReduce,
-    ) -> None:
+    def finalize(self, output: torch.Tensor, fused_expert_output: torch.Tensor,
+                 topk_weights: torch.Tensor, topk_ids: torch.Tensor,
+                 apply_router_weight_on_input: bool,
+                 weight_and_reduce_impl: mk.TopKWeightAndReduce,
+                 extra_finalize_args: Optional[dict[str, Any]]) -> None:
         assert isinstance(
             weight_and_reduce_impl, TopKWeightAndReduceDelegate
         ), ("Weight application and reduction happens in the combine kernel.")
@@ -1,6 +1,6 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
-from typing import Optional
+from typing import Any, Optional
 
 import torch
 
@@ -38,6 +38,7 @@ class MoEPrepareAndFinalizeNoEP(mk.FusedMoEPrepareAndFinalize):
         expert_map: Optional[torch.Tensor],
         apply_router_weight_on_input: bool,
         quant_config: FusedMoEQuantConfig,
+        extra_prepare_args: Optional[dict[str, Any]],
    ) -> tuple[torch.Tensor, Optional[torch.Tensor],
                Optional[mk.ExpertTokensMetadata], Optional[torch.Tensor],
                Optional[torch.Tensor]]:
@@ -48,26 +49,33 @@ class MoEPrepareAndFinalizeNoEP(mk.FusedMoEPrepareAndFinalize):
             assert topk == 1, \
                 "apply_router_weight_on_input is only implemented for topk=1"
             a1.mul_(topk_weights.to(a1.dtype))
 
+        if (extra_prepare_args is not None
+                and extra_prepare_args.get("skip_quant", True)):
+            # Skip quantization if explicitly requested
+            return a1, None, None, None, None
+
         a1q, a1q_scale = moe_kernel_quantize_input(
             a1, a1_scale, quant_config.quant_dtype,
             quant_config.per_act_token_quant, quant_config.block_shape)
 
         return a1q, a1q_scale, None, None, None
 
-    def finalize(
-        self,
-        output: torch.Tensor,
-        fused_expert_output: torch.Tensor,
-        topk_weights: torch.Tensor,
-        topk_ids: torch.Tensor,
-        apply_router_weight_on_input: bool,
-        weight_and_reduce_impl: mk.TopKWeightAndReduce,
-    ) -> None:
-        if isinstance(weight_and_reduce_impl, TopKWeightAndReduceDelegate):
-            weight_and_reduce_impl = TopKWeightAndReduceContiguous()
-        weight_and_reduce_impl.apply(
-            output=output,
-            fused_expert_output=fused_expert_output,
-            topk_weights=topk_weights,
-            topk_ids=topk_ids,
-            apply_router_weight_on_input=apply_router_weight_on_input)
+    def finalize(self, output: torch.Tensor, fused_expert_output: torch.Tensor,
+                 topk_weights: torch.Tensor, topk_ids: torch.Tensor,
+                 apply_router_weight_on_input: bool,
+                 weight_and_reduce_impl: mk.TopKWeightAndReduce,
+                 extra_finalize_args: Optional[dict[str, Any]]) -> None:
+        if (extra_finalize_args is not None
+                and extra_finalize_args.get("skip_weight_reduce", True)):
+            assert output.shape == fused_expert_output.shape
+            output.copy_(fused_expert_output)
+        else:
+            if isinstance(weight_and_reduce_impl, TopKWeightAndReduceDelegate):
+                weight_and_reduce_impl = TopKWeightAndReduceContiguous()
+            weight_and_reduce_impl.apply(
+                output=output,
+                fused_expert_output=fused_expert_output,
+                topk_weights=topk_weights,
+                topk_ids=topk_ids,
+                apply_router_weight_on_input=apply_router_weight_on_input)
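The guards added to MoEPrepareAndFinalizeNoEP read their flags with .get(..., True), so once a dict is supplied the skip behaviour is the default unless it is explicitly turned off. A small self-contained sketch of that truth table (not vLLM code, just the guard logic mirrored for illustration; finalize applies the same pattern with the "skip_weight_reduce" key):

```python
def should_skip_quant(extra_prepare_args):
    # Mirrors: extra_prepare_args is not None
    #          and extra_prepare_args.get("skip_quant", True)
    return (extra_prepare_args is not None
            and extra_prepare_args.get("skip_quant", True))


assert should_skip_quant(None) is False               # no dict: normal quantization
assert should_skip_quant({}) is True                  # dict given: skip by default
assert should_skip_quant({"skip_quant": False}) is False
```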
@@ -1,6 +1,6 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
-from typing import Optional
+from typing import Any, Optional
 
 import torch
 
@@ -119,28 +119,18 @@ class TritonOrDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
             local_num_experts,
             expert_tokens_meta)
 
-    def apply(
-        self,
-        output: torch.Tensor,
-        hidden_states: torch.Tensor,
-        w1: torch.Tensor,
-        w2: torch.Tensor,
-        topk_weights: torch.Tensor,
-        topk_ids: torch.Tensor,
-        activation: str,
-        global_num_experts: int,
-        expert_map: Optional[torch.Tensor],
-        w1_scale: Optional[torch.Tensor],
-        w2_scale: Optional[torch.Tensor],
-        w1_zp: Optional[torch.Tensor],
-        w2_zp: Optional[torch.Tensor],
-        a1q_scale: Optional[torch.Tensor],
-        a2_scale: Optional[torch.Tensor],
-        workspace13: torch.Tensor,
-        workspace2: torch.Tensor,
-        expert_tokens_meta: Optional[mk.ExpertTokensMetadata],
-        apply_router_weight_on_input: bool,
-    ):
+    def apply(self, output: torch.Tensor, hidden_states: torch.Tensor,
+              w1: torch.Tensor, w2: torch.Tensor, topk_weights: torch.Tensor,
+              topk_ids: torch.Tensor, activation: str, global_num_experts: int,
+              expert_map: Optional[torch.Tensor],
+              w1_scale: Optional[torch.Tensor],
+              w2_scale: Optional[torch.Tensor], w1_zp: Optional[torch.Tensor],
+              w2_zp: Optional[torch.Tensor], a1q_scale: Optional[torch.Tensor],
+              a2_scale: Optional[torch.Tensor], workspace13: torch.Tensor,
+              workspace2: torch.Tensor,
+              expert_tokens_meta: Optional[mk.ExpertTokensMetadata],
+              apply_router_weight_on_input: bool,
+              extra_expert_args: Optional[dict[str, Any]]):
         use_deep_gemm = (self.allow_deep_gemm
                          and (_valid_deep_gemm(hidden_states, w1, w2)
                               or is_blackwell_deep_gemm_used()))
@@ -168,4 +158,5 @@ class TritonOrDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
             workspace2,
             expert_tokens_meta,
             apply_router_weight_on_input,
+            extra_expert_args,
         )
@@ -1,7 +1,7 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 from math import prod
-from typing import Optional, Union
+from typing import Any, Optional, Union
 
 import torch
 
@@ -15,6 +15,7 @@ from vllm.model_executor.layers.quantization.utils.mxfp4_utils import (
 from vllm.platforms import current_platform
 from vllm.triton_utils import tl, triton
 from vllm.utils import cdiv
+from vllm.utils.flashinfer import fp4_quantize
 
 
 @triton.jit
@@ -98,6 +99,16 @@ def _resize_cache(x: torch.Tensor, v: tuple[int, ...]) -> torch.Tensor:
     return x.flatten()[:prod(v)].view(*v)
 
 
+def _fp4_quantize(
+    A: torch.Tensor,
+    A_scale: Optional[torch.Tensor],
+    is_sf_swizzled_layout: bool,
+) -> tuple[torch.Tensor, torch.Tensor]:
+    return fp4_quantize(A,
+                        A_scale,
+                        is_sf_swizzled_layout=is_sf_swizzled_layout)
+
+
 def _fp8_quantize(
     A: torch.Tensor,
     A_scale: Optional[torch.Tensor],
@@ -172,11 +183,16 @@ def moe_kernel_quantize_input(
     quant_dtype: Union[None, torch.dtype, str],
     per_act_token_quant: bool,
     block_shape: Optional[list[int]] = None,
+    is_fp4_scale_swizzled: bool = True,
 ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
     if quant_dtype == torch.float8_e4m3fn:
         return _fp8_quantize(A, A_scale, per_act_token_quant, block_shape)
     elif quant_dtype == torch.int8:
         return _int8_quantize(A, A_scale, per_act_token_quant, block_shape)
+    elif quant_dtype == torch.uint8:  # nvfp4
+        return _fp4_quantize(A,
+                             A_scale,
+                             is_sf_swizzled_layout=is_fp4_scale_swizzled)
     elif quant_dtype == "mxfp4":
         return _mxfp4_quantize(A, A_scale, per_act_token_quant, block_shape)
     else:
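With the branch above, torch.uint8 acts as the marker dtype for nvfp4 activations (two e2m1 values packed per byte, as noted elsewhere in this commit), and the new is_fp4_scale_swizzled flag is forwarded to FlashInfer's fp4_quantize. A hedged call sketch — it assumes a CUDA device, an installed FlashInfer build, and that the helper keeps living in the fused_moe utils module this hunk edits; the scale tensor here is only a placeholder:

```python
import torch

from vllm.model_executor.layers.fused_moe.utils import moe_kernel_quantize_input

a = torch.randn(128, 4096, dtype=torch.bfloat16, device="cuda")
a_gscale = torch.ones(1, dtype=torch.float32, device="cuda")  # placeholder scale

a_q, a_sf = moe_kernel_quantize_input(
    a,
    a_gscale,
    quant_dtype=torch.uint8,      # routed to _fp4_quantize (nvfp4)
    per_act_token_quant=False,
    block_shape=None,
    is_fp4_scale_swizzled=False,  # request the non-swizzled scale layout
)
```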
@@ -236,3 +252,17 @@ def _validate_scale_shape(
         assert block_shape is not None
         expected = (a.shape[0], cdiv(a.shape[1], block_shape[1]))
         assert a_scale.shape == expected, f"{a_scale.shape} == {expected}"
+
+
+def extract_required_args(
+    extra_args: Optional[dict[str, Any]],
+    required_keys: list[str],
+) -> tuple[Any, ...]:
+    if extra_args is None:
+        raise ValueError("`extra_args` must be provided.")
+
+    missing_keys = [k for k in required_keys if k not in extra_args]
+    if missing_keys:
+        raise ValueError(f"Missing keys in `extra_args`: {missing_keys}")
+
+    return tuple(extra_args[k] for k in required_keys)
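Usage sketch for extract_required_args, the helper added above for pulling mandatory keys out of the pass-through dicts (the import path is an assumption based on the fused_moe utils module this hunk extends):

```python
from vllm.model_executor.layers.fused_moe.utils import extract_required_args

extra = {"g1_alphas": 1.25, "g2_alphas": 0.75, "out_dtype": "bfloat16"}

g1, g2 = extract_required_args(extra, ["g1_alphas", "g2_alphas"])
assert (g1, g2) == (1.25, 0.75)

try:
    extract_required_args(extra, ["a1_gscale"])   # key not present
except ValueError as err:
    print(err)  # Missing keys in `extra_args`: ['a1_gscale']
```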
@@ -339,19 +339,19 @@ class CompressedTensorsW4A4MoeMethod(CompressedTensorsMoEMethod):
         return cutlass_moe_fp4(
             a=x,
             w1_fp4=layer.w13_weight,
-            w1_blockscale=layer.w13_blockscale_swizzled,
-            w1_alphas=layer.g1_alphas,
             w2_fp4=layer.w2_weight,
+            w1_blockscale=layer.w13_blockscale_swizzled,
             w2_blockscale=layer.w2_blockscale_swizzled,
-            w2_alphas=layer.g2_alphas,
+            g1_alphas=layer.g1_alphas,
+            g2_alphas=layer.g2_alphas,
+            a1_gscale=layer.w13_input_scale_quant,
+            a2_gscale=layer.w2_input_scale_quant,
             topk_weights=topk_weights,
             topk_ids=topk_ids,
             m=x.shape[0],
             n=layer.w2_weight.shape[2] * 2,
             k=x.shape[1],
             e=layer.w13_weight.shape[0],
-            a1_gscale=layer.w13_input_scale_quant,
-            a2_gscale=layer.w2_input_scale_quant,
             device=x.device,
             apply_router_weight_on_input=apply_router_weight_on_input).to(
                 x.dtype)
@@ -7,9 +7,15 @@ import torch
 from torch.nn import Module
 from torch.nn.parameter import Parameter
 
+import vllm.envs as envs
+import vllm.model_executor.layers.fused_moe.modular_kernel as mk
 from vllm._custom_ops import (cutlass_scaled_fp4_mm,
                               cutlass_scaled_mm_supports_fp4, scaled_fp4_quant)
+from vllm.distributed import get_ep_group
 from vllm.logger import init_logger
+from vllm.model_executor.layers.fused_moe.config import FusedMoEParallelConfig
+from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_prepare_finalize import (  # noqa: E501
+    FlashInferCutlassMoEPrepareAndFinalize)
 from vllm.model_executor.layers.fused_moe.layer import (
     FusedMoE, FusedMoEMethodBase, FusedMoeWeightScaleSupported)
 from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase,
@@ -713,6 +719,18 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
         self.quant_config = quant_config
         self.cutlass_nvfp4_supported = cutlass_fp4_supported()
         self.use_marlin = False
+        self.allow_flashinfer_cutlass = False
+
+        if envs.VLLM_USE_FLASHINFER_MOE:
+            if self.cutlass_nvfp4_supported and current_platform.is_cuda() \
+                and current_platform.is_device_capability(100):
+                logger.info_once(
+                    "Using FlashInfer kernels for ModelOptNvFp4FusedMoE.")
+                self.allow_flashinfer_cutlass = True
+            else:
+                logger.warning_once(
+                    "Flashinfer CUTLASS Fused MoE not supported "
+                    "or found on the current platform.")
+
         if not self.cutlass_nvfp4_supported:
             if is_fp4_marlin_supported():
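The flag above is only flipped when VLLM_USE_FLASHINFER_MOE is set and the platform checks pass (CUDA, device capability 100, CUTLASS FP4 support); otherwise the constructor logs a warning and keeps the existing paths. A hedged sketch of opting in from Python — the checkpoint name is a placeholder and the environment variable must be set before vLLM reads it:

```python
import os

# Opt in to the FlashInfer CUTLASS fused MoE backend (Blackwell/SM100 only).
os.environ["VLLM_USE_FLASHINFER_MOE"] = "1"

from vllm import LLM  # noqa: E402

llm = LLM(model="nvidia/some-nvfp4-quantized-moe")  # placeholder model name
```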
@@ -722,6 +740,73 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
                     " quantization. Please use Blackwell and"
                     " above.")
 
+        self.fused_experts = None  # type: ignore
+
+    def maybe_swap_experts_impl(
+        self,
+        moe_parallel_config: FusedMoEParallelConfig,
+    ):
+        if not self.allow_flashinfer_cutlass:
+            return
+
+        logger.debug_once("FlashInferExperts")
+        # default to TP/EP case only
+
+        experts_kwargs: dict[str, Any] = {
+            "use_nvfp4_w4a4": True,
+            "use_dp": moe_parallel_config.dp_size > 1,
+            "ep_rank": moe_parallel_config.ep_rank,
+            "ep_size": moe_parallel_config.ep_size,
+            "tp_rank": moe_parallel_config.tp_rank,
+            "tp_size": moe_parallel_config.tp_size,
+        }
+
+        from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe import (  # noqa: E501
+            FlashInferExperts)
+        experts = FlashInferExperts(**experts_kwargs)
+        self.fused_experts = mk.FusedMoEModularKernel(
+            FlashInferCutlassMoEPrepareAndFinalize(
+                quant_dtype=torch.uint8,
+                # meaning 2x e2m1 packed in one, kernel requirement
+            ),
+            experts,
+        )
+
+    # This method updates self.fused_experts.
+    # select_gemm_impl is only called when prepare_finalize is not None,
+    # so in the native cutlass fp4 (TP) case, where it is not called,
+    # fused_experts stays the fused_experts from fused_moe.py and both
+    # kernels remain available.
+    def select_gemm_impl(self, prepare_finalize,
+                         moe) -> mk.FusedMoEPermuteExpertsUnpermute:
+
+        assert moe is not None
+        assert prepare_finalize is not None
+        experts = None
+        all2all_manager = get_ep_group().device_communicator.all2all_manager
+        assert all2all_manager is not None
+        if self.allow_flashinfer_cutlass:
+            from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe import (  # noqa: E501
+                FlashInferExperts)
+            logger.debug_once("Using FlashInferExperts")
+            experts = FlashInferExperts(
+                use_nvfp4_w4a4=True,
+                use_dp=moe.moe_parallel_config.dp_size > 1,
+                ep_rank=moe.moe_parallel_config.ep_rank,
+                ep_size=moe.moe_parallel_config.ep_size,
+                tp_rank=moe.moe_parallel_config.tp_rank,
+                tp_size=moe.moe_parallel_config.tp_size,
+            )
+        else:
+            assert moe.dp_size > 1
+            logger.debug_once("Using CutlassExpertsFp4")
+            # Currently CutlassExpertsFp4 doesn't support DP
+            raise ValueError(
+                "CutlassExpertsFp4 doesn't support DP. "
+                "Use flashinfer CUTLASS FusedMoE(VLLM_USE_FLASHINFER_MOE)"
+                " backend instead.")
+
+        return experts
+
     def uses_weight_scale_2_pattern(self) -> bool:
         """
         FP4 variants use 'weight_scale_2' pattern for per-tensor weight scales.
@@ -842,8 +927,30 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
                            if scale_ndim == 2 else swizzled_scale.reshape(B, M, K))
 
     def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
 
         # GEMM 1
+        # The FlashInfer Cutlass fused MoE kernel expects the combined weights
+        # to be ordered as [w3, w1], unlike the standard [w1, w3] layout.
+        gemm1_weight = layer.w13_weight.data
+        gemm1_weight_scale = layer.w13_weight_scale.data
+
+        if self.allow_flashinfer_cutlass:
+            dim = -2
+            size = gemm1_weight.size(dim)
+            assert size % 2 == 0, f"Expected even size in dim {dim}, got {size}"
+            half = size // 2
+
+            # Reorder weight
+            w1, w3 = gemm1_weight.split(half, dim=dim)
+            gemm1_weight = torch.cat([w3, w1], dim=dim).contiguous()
+
+            # Reorder scale
+            s1, s3 = gemm1_weight_scale.split(half, dim=dim)
+            gemm1_weight_scale = torch.cat([s3, s1], dim=dim).contiguous()
+
+        layer.w13_weight = Parameter(gemm1_weight, requires_grad=False)
+        layer.w13_weight_scale = Parameter(gemm1_weight_scale,
+                                           requires_grad=False)
+
         if not torch.allclose(layer.w13_weight_scale_2[:, 0],
                               layer.w13_weight_scale_2[:, 1]):
             logger.warning_once(
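A tiny self-contained illustration of the reorder performed above: the stacked [w1; w3] halves swap places along dim=-2, and the block scales are permuted identically so they stay aligned with their weights (toy shapes, not real model sizes):

```python
import torch

E, two_n, k = 2, 8, 4                       # toy expert count and dims
w13 = torch.arange(E * two_n * k, dtype=torch.float32).reshape(E, two_n, k)

half = w13.size(-2) // 2
w1, w3 = w13.split(half, dim=-2)            # standard [w1, w3] layout
w13_reordered = torch.cat([w3, w1], dim=-2).contiguous()

assert torch.equal(w13_reordered[:, :half], w3)
assert torch.equal(w13_reordered[:, half:], w1)
```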
@@ -874,9 +981,6 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
         layer.w13_input_scale_quant = Parameter(
             (1 / w13_input_scale).to(torch.float32), requires_grad=False)
 
-        layer.w13_weight = Parameter(layer.w13_weight.data,
-                                     requires_grad=False)
-
         # GEMM 2
         layer.g2_alphas = Parameter(
             (layer.w2_input_scale * layer.w2_weight_scale_2).to(torch.float32),
@@ -961,31 +1065,74 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
             global_num_experts=global_num_experts,
             expert_map=expert_map)
 
-        assert expert_map is None, ("Expert Parallelism / expert_map "
-                                    "is currently not supported for "
-                                    "ModelOptNvFp4FusedMoE.")
-
-        from vllm.model_executor.layers.fused_moe.cutlass_moe import (
-            cutlass_moe_fp4)
-
-        # Cutlass moe takes in activations in BF16/Half precision
-        # and fp4 quantized weights loaded from the checkpoint
-        return cutlass_moe_fp4(
-            a=x,
-            w1_fp4=layer.w13_weight,
-            w1_blockscale=layer.w13_blockscale_swizzled,
-            w1_alphas=layer.g1_alphas,
-            w2_fp4=layer.w2_weight,
-            w2_blockscale=layer.w2_blockscale_swizzled,
-            w2_alphas=layer.g2_alphas,
-            topk_weights=topk_weights,
-            topk_ids=topk_ids,
-            m=x.shape[0],
-            n=layer.w2_weight.shape[2] * 2,
-            k=x.shape[1],
-            e=layer.w13_weight.shape[0],
-            a1_gscale=layer.w13_input_scale_quant,
-            a2_gscale=layer.w2_input_scale_quant,
-            device=x.device,
-            apply_router_weight_on_input=apply_router_weight_on_input).to(
-                x.dtype)
+        if self.fused_experts is None:
+            # If no modular kernel is provided, use cutlass_moe_fp4 for the
+            # TP case only (no EP).
+            from vllm.model_executor.layers.fused_moe.cutlass_moe import (
+                cutlass_moe_fp4)
+            out = cutlass_moe_fp4(
+                a=x,
+                w1_fp4=layer.w13_weight,
+                w2_fp4=layer.w2_weight,
+                w1_blockscale=layer.w13_blockscale_swizzled,
+                w2_blockscale=layer.w2_blockscale_swizzled,
+                g1_alphas=layer.g1_alphas,
+                g2_alphas=layer.g2_alphas,
+                a1_gscale=layer.w13_input_scale_quant,
+                a2_gscale=layer.w2_input_scale_quant,
+                topk_weights=topk_weights,
+                topk_ids=topk_ids,
+                m=x.shape[0],
+                n=layer.w2_weight.shape[2] * 2,
+                k=x.shape[1],
+                e=layer.w13_weight.shape[0],
+                device=x.device,
+                expert_map=expert_map,
+                apply_router_weight_on_input=apply_router_weight_on_input)
+        else:
+            # TP or DP case
+            from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe import (  # noqa: E501
+                is_valid_flashinfer_cutlass_fused_moe)
+            assert is_valid_flashinfer_cutlass_fused_moe(
+                x, layer.w13_weight, layer.w2_weight), (
+                    "Flashinfer CUTLASS Fused MoE not applicable!")
+
+            a1_gscale = torch.min(layer.w13_input_scale_quant)
+            a2_gscale = torch.min(layer.w2_input_scale_quant)
+            extra_expert_args = {
+                'g1_alphas': layer.g1_alphas,
+                'g2_alphas': layer.g2_alphas,
+                'out_dtype': x.dtype,
+                # Named *_gscale to avoid confusion with a1_scale and a2_scale,
+                # which are batch-size related.
+                'a1_gscale': a1_gscale,
+                'a2_gscale': a2_gscale,
+            }
+            extra_prepare_args = {
+                'use_dp': layer.dp_size > 1,
+                'local_tokens': x.shape[0],
+                'a1_gscale': a1_gscale,
+            }
+            extra_finalize_args = {
+                'use_dp': layer.dp_size > 1,
+                'local_tokens': x.shape[0],
+            }
+
+            out = self.fused_experts(
+                hidden_states=x,
+                w1=layer.w13_weight,
+                w2=layer.w2_weight,
+                topk_weights=topk_weights,
+                topk_ids=topk_ids,
+                inplace=False,  # TODO(shuw): fix later, now output is high prec
+                activation=activation,
+                global_num_experts=global_num_experts,
+                expert_map=expert_map,
+                w1_scale=layer.w13_blockscale_swizzled,
+                w2_scale=layer.w2_blockscale_swizzled,
+                apply_router_weight_on_input=apply_router_weight_on_input,
+                extra_expert_args=extra_expert_args,
+                extra_prepare_args=extra_prepare_args,
+                extra_finalize_args=extra_finalize_args,
+            )
+        return out
vllm/utils/flashinfer.py (new file, 107 lines)
@@ -0,0 +1,107 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+"""Compatibility wrapper for FlashInfer API changes.
+
+Users of vLLM should always import **only** these wrappers.
+"""
+from __future__ import annotations
+
+import contextlib
+import functools
+import importlib
+import importlib.util
+from typing import Any, Callable, NoReturn
+
+from vllm.logger import init_logger
+
+logger = init_logger(__name__)
+
+
+@functools.cache
+def has_flashinfer() -> bool:
+    """Return ``True`` if FlashInfer is available."""
+    # Use find_spec to check if the module exists without importing it
+    # This avoids potential CUDA initialization side effects
+    return importlib.util.find_spec("flashinfer") is not None
+
+
+def _missing(*_: Any, **__: Any) -> NoReturn:
+    """Placeholder for unavailable FlashInfer backend."""
+    raise RuntimeError(
+        "FlashInfer backend is not available. Please install the package "
+        "to enable FlashInfer kernels: "
+        "https://github.com/flashinfer-ai/flashinfer")
+
+
+def _get_submodule(module_name: str) -> Any | None:
+    """Safely import a submodule and return it, or None if not available."""
+    try:
+        return importlib.import_module(module_name)
+    except (ImportError, ModuleNotFoundError):
+        return None
+
+
+# General lazy import wrapper
+def _lazy_import_wrapper(module_name: str,
+                         attr_name: str,
+                         fallback_fn: Callable[..., Any] = _missing):
+    """Create a lazy import wrapper for a specific function."""
+
+    @functools.cache
+    def _get_impl():
+        if not has_flashinfer():
+            return None
+        mod = _get_submodule(module_name)
+        return getattr(mod, attr_name, None) if mod else None
+
+    def wrapper(*args, **kwargs):
+        impl = _get_impl()
+        if impl is None:
+            return fallback_fn(*args, **kwargs)
+        return impl(*args, **kwargs)
+
+    return wrapper
+
+
+# Create lazy wrappers for each function
+flashinfer_cutlass_fused_moe = _lazy_import_wrapper("flashinfer.fused_moe",
+                                                    "cutlass_fused_moe")
+fp4_quantize = _lazy_import_wrapper("flashinfer", "fp4_quantize")
+fp4_swizzle_blockscale = _lazy_import_wrapper("flashinfer",
+                                              "fp4_swizzle_blockscale")
+
+# Special case for autotune since it returns a context manager
+autotune = _lazy_import_wrapper(
+    "flashinfer.autotuner",
+    "autotune",
+    fallback_fn=lambda *args, **kwargs: contextlib.nullcontext())
+
+
+@functools.cache
+def has_flashinfer_cutlass_fused_moe() -> bool:
+    """Return ``True`` if FlashInfer CUTLASS fused MoE is available."""
+    if not has_flashinfer():
+        return False
+
+    # Check if all required functions are available
+    required_functions = [
+        ("flashinfer.fused_moe", "cutlass_fused_moe"),
+        ("flashinfer", "fp4_quantize"),
+        ("flashinfer", "fp4_swizzle_blockscale"),
+    ]
+
+    for module_name, attr_name in required_functions:
+        mod = _get_submodule(module_name)
+        if not mod or not hasattr(mod, attr_name):
+            return False
+    return True
+
+
+__all__ = [
+    "has_flashinfer",
+    "has_flashinfer_cutlass_fused_moe",
+    "flashinfer_cutlass_fused_moe",
+    "fp4_quantize",
+    "fp4_swizzle_blockscale",
+    "autotune",
+]
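Usage sketch for the new wrappers (module path taken from the file header above). Every symbol degrades gracefully when FlashInfer is absent: the feature probes return False, autotune() falls back to a no-op context manager, and calling a kernel wrapper raises the RuntimeError from _missing:

```python
from vllm.utils.flashinfer import (autotune, has_flashinfer,
                                   has_flashinfer_cutlass_fused_moe)

print(has_flashinfer())                    # True only if the package is importable
print(has_flashinfer_cutlass_fused_moe())  # also checks the individual symbols

# Safe with or without FlashInfer installed: falls back to contextlib.nullcontext().
with autotune():
    pass  # FlashInfer-backed kernels would run here
```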