mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-06-08 21:15:46 +08:00
[Kernel] Delegate construction of FusedMoEQuantConfig to FusedMoEMethodBase subclasses (#22537)
Signed-off-by: Bill Nell <bnell@redhat.com>
This commit is contained in:
parent
e6585ddb45
commit
5963b98b46
@ -13,6 +13,10 @@ import torch.utils.benchmark as benchmark
|
|||||||
|
|
||||||
from vllm import _custom_ops as ops
|
from vllm import _custom_ops as ops
|
||||||
from vllm.config import ParallelConfig, VllmConfig, set_current_vllm_config
|
from vllm.config import ParallelConfig, VllmConfig, set_current_vllm_config
|
||||||
|
from vllm.model_executor.layers.fused_moe.config import (
|
||||||
|
fp8_w8a8_moe_quant_config,
|
||||||
|
nvfp4_moe_quant_config,
|
||||||
|
)
|
||||||
from vllm.model_executor.layers.fused_moe.cutlass_moe import cutlass_moe_fp4
|
from vllm.model_executor.layers.fused_moe.cutlass_moe import cutlass_moe_fp4
|
||||||
from vllm.model_executor.layers.fused_moe.fused_moe import fused_experts, fused_topk
|
from vllm.model_executor.layers.fused_moe.fused_moe import fused_experts, fused_topk
|
||||||
from vllm.scalar_type import scalar_types
|
from vllm.scalar_type import scalar_types
|
||||||
@ -140,6 +144,12 @@ def bench_run(
|
|||||||
a_fp8_scale: torch.Tensor,
|
a_fp8_scale: torch.Tensor,
|
||||||
num_repeats: int,
|
num_repeats: int,
|
||||||
):
|
):
|
||||||
|
quant_config = fp8_w8a8_moe_quant_config(
|
||||||
|
w1_scale=w1_scale,
|
||||||
|
w2_scale=w2_scale,
|
||||||
|
a1_scale=a_fp8_scale,
|
||||||
|
)
|
||||||
|
|
||||||
for _ in range(num_repeats):
|
for _ in range(num_repeats):
|
||||||
fused_experts(
|
fused_experts(
|
||||||
a,
|
a,
|
||||||
@ -147,10 +157,7 @@ def bench_run(
|
|||||||
w2,
|
w2,
|
||||||
topk_weights,
|
topk_weights,
|
||||||
topk_ids,
|
topk_ids,
|
||||||
use_fp8_w8a8=True,
|
quant_config=quant_config,
|
||||||
w1_scale=w1_scale,
|
|
||||||
w2_scale=w2_scale,
|
|
||||||
a1_scale=a_fp8_scale,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
def run_cutlass_moe_fp4(
|
def run_cutlass_moe_fp4(
|
||||||
@ -172,25 +179,27 @@ def bench_run(
|
|||||||
device: torch.device,
|
device: torch.device,
|
||||||
num_repeats: int,
|
num_repeats: int,
|
||||||
):
|
):
|
||||||
|
quant_config = nvfp4_moe_quant_config(
|
||||||
|
a1_gscale=a1_gs,
|
||||||
|
a2_gscale=a2_gs,
|
||||||
|
w1_scale=w1_blockscale,
|
||||||
|
w2_scale=w2_blockscale,
|
||||||
|
g1_alphas=w1_gs,
|
||||||
|
g2_alphas=w2_gs,
|
||||||
|
)
|
||||||
for _ in range(num_repeats):
|
for _ in range(num_repeats):
|
||||||
with nvtx.annotate("cutlass_moe_fp4", color="green"):
|
with nvtx.annotate("cutlass_moe_fp4", color="green"):
|
||||||
cutlass_moe_fp4(
|
cutlass_moe_fp4(
|
||||||
a=a,
|
a=a,
|
||||||
a1_gscale=a1_gs,
|
|
||||||
a2_gscale=a2_gs,
|
|
||||||
w1_fp4=w1_fp4,
|
w1_fp4=w1_fp4,
|
||||||
w1_blockscale=w1_blockscale,
|
|
||||||
w1_alphas=w1_gs,
|
|
||||||
w2_fp4=w2_fp4,
|
w2_fp4=w2_fp4,
|
||||||
w2_blockscale=w2_blockscale,
|
|
||||||
w2_alphas=w2_gs,
|
|
||||||
topk_weights=topk_weights,
|
topk_weights=topk_weights,
|
||||||
topk_ids=topk_ids,
|
topk_ids=topk_ids,
|
||||||
m=m,
|
m=m,
|
||||||
n=n,
|
n=n,
|
||||||
k=k,
|
k=k,
|
||||||
e=num_experts,
|
e=num_experts,
|
||||||
device=device,
|
quant_config=quant_config,
|
||||||
)
|
)
|
||||||
|
|
||||||
def run_cutlass_from_graph(
|
def run_cutlass_from_graph(
|
||||||
@ -211,26 +220,29 @@ def bench_run(
|
|||||||
e: int,
|
e: int,
|
||||||
device: torch.device,
|
device: torch.device,
|
||||||
):
|
):
|
||||||
|
quant_config = nvfp4_moe_quant_config(
|
||||||
|
a1_gscale=a1_gs,
|
||||||
|
a2_gscale=a2_gs,
|
||||||
|
w1_scale=w1_blockscale,
|
||||||
|
w2_scale=w2_blockscale,
|
||||||
|
g1_alphas=w1_gs,
|
||||||
|
g2_alphas=w2_gs,
|
||||||
|
)
|
||||||
|
|
||||||
with set_current_vllm_config(
|
with set_current_vllm_config(
|
||||||
VllmConfig(parallel_config=ParallelConfig(pipeline_parallel_size=1))
|
VllmConfig(parallel_config=ParallelConfig(pipeline_parallel_size=1))
|
||||||
):
|
):
|
||||||
return cutlass_moe_fp4(
|
return cutlass_moe_fp4(
|
||||||
a=a,
|
a=a,
|
||||||
a1_gscale=a1_gs,
|
|
||||||
w1_fp4=w1_fp4,
|
w1_fp4=w1_fp4,
|
||||||
w1_blockscale=w1_blockscale,
|
|
||||||
w1_alphas=w1_alphas,
|
|
||||||
a2_gscale=a2_gs,
|
|
||||||
w2_fp4=w2_fp4,
|
w2_fp4=w2_fp4,
|
||||||
w2_blockscale=w2_blockscale,
|
|
||||||
w2_alphas=w2_alphas,
|
|
||||||
topk_weights=topk_weights,
|
topk_weights=topk_weights,
|
||||||
topk_ids=topk_ids,
|
topk_ids=topk_ids,
|
||||||
m=m,
|
m=m,
|
||||||
n=n,
|
n=n,
|
||||||
k=k,
|
k=k,
|
||||||
e=num_experts,
|
e=num_experts,
|
||||||
device=device,
|
quant_config=quant_config,
|
||||||
)
|
)
|
||||||
|
|
||||||
def run_triton_from_graph(
|
def run_triton_from_graph(
|
||||||
@ -246,16 +258,18 @@ def bench_run(
|
|||||||
with set_current_vllm_config(
|
with set_current_vllm_config(
|
||||||
VllmConfig(parallel_config=ParallelConfig(pipeline_parallel_size=1))
|
VllmConfig(parallel_config=ParallelConfig(pipeline_parallel_size=1))
|
||||||
):
|
):
|
||||||
|
quant_config = fp8_w8a8_moe_quant_config(
|
||||||
|
w1_scale=w1_scale,
|
||||||
|
w2_scale=w2_scale,
|
||||||
|
a1_scale=a_fp8_scale,
|
||||||
|
)
|
||||||
return fused_experts(
|
return fused_experts(
|
||||||
a,
|
a,
|
||||||
w1,
|
w1,
|
||||||
w2,
|
w2,
|
||||||
topk_weights,
|
topk_weights,
|
||||||
topk_ids,
|
topk_ids,
|
||||||
use_fp8_w8a8=True,
|
quant_config=quant_config,
|
||||||
w1_scale=w1_scale,
|
|
||||||
w2_scale=w2_scale,
|
|
||||||
a1_scale=a_fp8_scale,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
def replay_graph(graph, num_repeats):
|
def replay_graph(graph, num_repeats):
|
||||||
|
|||||||
@ -7,6 +7,7 @@ from benchmark_shapes import WEIGHT_SHAPES_MOE
|
|||||||
|
|
||||||
from vllm import _custom_ops as ops
|
from vllm import _custom_ops as ops
|
||||||
from vllm.config import ParallelConfig, VllmConfig, set_current_vllm_config
|
from vllm.config import ParallelConfig, VllmConfig, set_current_vllm_config
|
||||||
|
from vllm.model_executor.layers.fused_moe.config import fp8_w8a8_moe_quant_config
|
||||||
from vllm.model_executor.layers.fused_moe.cutlass_moe import cutlass_moe_fp8
|
from vllm.model_executor.layers.fused_moe.cutlass_moe import cutlass_moe_fp8
|
||||||
from vllm.model_executor.layers.fused_moe.fused_moe import (
|
from vllm.model_executor.layers.fused_moe.fused_moe import (
|
||||||
fused_experts,
|
fused_experts,
|
||||||
@ -96,6 +97,11 @@ def bench_run(
|
|||||||
a_scale: torch.Tensor,
|
a_scale: torch.Tensor,
|
||||||
num_repeats: int,
|
num_repeats: int,
|
||||||
):
|
):
|
||||||
|
quant_config = fp8_w8a8_moe_quant_config(
|
||||||
|
w1_scale=w1_scale,
|
||||||
|
w2_scale=w2_scale,
|
||||||
|
a1_scale=a_scale,
|
||||||
|
)
|
||||||
for _ in range(num_repeats):
|
for _ in range(num_repeats):
|
||||||
fused_experts(
|
fused_experts(
|
||||||
a,
|
a,
|
||||||
@ -103,10 +109,7 @@ def bench_run(
|
|||||||
w2,
|
w2,
|
||||||
topk_weights,
|
topk_weights,
|
||||||
topk_ids,
|
topk_ids,
|
||||||
use_fp8_w8a8=True,
|
quant_config=quant_config,
|
||||||
w1_scale=w1_scale,
|
|
||||||
w2_scale=w2_scale,
|
|
||||||
a1_scale=a_scale,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
def run_cutlass_moe(
|
def run_cutlass_moe(
|
||||||
@ -125,6 +128,12 @@ def bench_run(
|
|||||||
per_act_token: bool,
|
per_act_token: bool,
|
||||||
num_repeats: int,
|
num_repeats: int,
|
||||||
):
|
):
|
||||||
|
quant_config = fp8_w8a8_moe_quant_config(
|
||||||
|
w1_scale=w1_scale,
|
||||||
|
w2_scale=w2_scale,
|
||||||
|
per_act_token_quant=per_act_token,
|
||||||
|
)
|
||||||
|
|
||||||
for _ in range(num_repeats):
|
for _ in range(num_repeats):
|
||||||
cutlass_moe_fp8(
|
cutlass_moe_fp8(
|
||||||
a,
|
a,
|
||||||
@ -132,14 +141,11 @@ def bench_run(
|
|||||||
w2,
|
w2,
|
||||||
topk_weights,
|
topk_weights,
|
||||||
topk_ids,
|
topk_ids,
|
||||||
w1_scale,
|
|
||||||
w2_scale,
|
|
||||||
ab_strides1,
|
ab_strides1,
|
||||||
ab_strides2,
|
ab_strides2,
|
||||||
c_strides1,
|
c_strides1,
|
||||||
c_strides2,
|
c_strides2,
|
||||||
per_act_token,
|
quant_config=quant_config,
|
||||||
a1_scale=None,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
def run_cutlass_from_graph(
|
def run_cutlass_from_graph(
|
||||||
@ -156,6 +162,12 @@ def bench_run(
|
|||||||
topk_weights: torch.Tensor,
|
topk_weights: torch.Tensor,
|
||||||
topk_ids: torch.Tensor,
|
topk_ids: torch.Tensor,
|
||||||
):
|
):
|
||||||
|
quant_config = fp8_w8a8_moe_quant_config(
|
||||||
|
w1_scale=w1_scale,
|
||||||
|
w2_scale=w2_scale,
|
||||||
|
per_act_token_quant=per_act_token,
|
||||||
|
)
|
||||||
|
|
||||||
with set_current_vllm_config(
|
with set_current_vllm_config(
|
||||||
VllmConfig(parallel_config=ParallelConfig(pipeline_parallel_size=1))
|
VllmConfig(parallel_config=ParallelConfig(pipeline_parallel_size=1))
|
||||||
):
|
):
|
||||||
@ -165,14 +177,11 @@ def bench_run(
|
|||||||
w2_q,
|
w2_q,
|
||||||
topk_weights,
|
topk_weights,
|
||||||
topk_ids,
|
topk_ids,
|
||||||
w1_scale,
|
|
||||||
w2_scale,
|
|
||||||
ab_strides1,
|
ab_strides1,
|
||||||
ab_strides2,
|
ab_strides2,
|
||||||
c_strides1,
|
c_strides1,
|
||||||
c_strides2,
|
c_strides2,
|
||||||
per_act_token,
|
quant_config=quant_config,
|
||||||
a1_scale=None,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
def run_triton_from_graph(
|
def run_triton_from_graph(
|
||||||
@ -185,6 +194,11 @@ def bench_run(
|
|||||||
w2_scale: torch.Tensor,
|
w2_scale: torch.Tensor,
|
||||||
a_scale: torch.Tensor,
|
a_scale: torch.Tensor,
|
||||||
):
|
):
|
||||||
|
quant_config = fp8_w8a8_moe_quant_config(
|
||||||
|
w1_scale=w1_scale,
|
||||||
|
w2_scale=w2_scale,
|
||||||
|
a1_scale=a_scale,
|
||||||
|
)
|
||||||
with set_current_vllm_config(
|
with set_current_vllm_config(
|
||||||
VllmConfig(parallel_config=ParallelConfig(pipeline_parallel_size=1))
|
VllmConfig(parallel_config=ParallelConfig(pipeline_parallel_size=1))
|
||||||
):
|
):
|
||||||
@ -194,10 +208,7 @@ def bench_run(
|
|||||||
w2,
|
w2,
|
||||||
topk_weights,
|
topk_weights,
|
||||||
topk_ids,
|
topk_ids,
|
||||||
use_fp8_w8a8=True,
|
quant_config=quant_config,
|
||||||
w1_scale=w1_scale,
|
|
||||||
w2_scale=w2_scale,
|
|
||||||
a1_scale=a_scale,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
def replay_graph(graph, num_repeats):
|
def replay_graph(graph, num_repeats):
|
||||||
|
|||||||
@ -14,6 +14,10 @@ import ray
|
|||||||
import torch
|
import torch
|
||||||
from ray.experimental.tqdm_ray import tqdm
|
from ray.experimental.tqdm_ray import tqdm
|
||||||
|
|
||||||
|
from vllm.model_executor.layers.fused_moe.config import (
|
||||||
|
FusedMoEQuantConfig,
|
||||||
|
_get_config_dtype_str,
|
||||||
|
)
|
||||||
from vllm.model_executor.layers.fused_moe.fused_moe import *
|
from vllm.model_executor.layers.fused_moe.fused_moe import *
|
||||||
from vllm.platforms import current_platform
|
from vllm.platforms import current_platform
|
||||||
from vllm.transformers_utils.config import get_config
|
from vllm.transformers_utils.config import get_config
|
||||||
@ -134,43 +138,36 @@ def benchmark_config(
|
|||||||
def run():
|
def run():
|
||||||
from vllm.model_executor.layers.fused_moe import override_config
|
from vllm.model_executor.layers.fused_moe import override_config
|
||||||
|
|
||||||
|
if use_fp8_w8a8:
|
||||||
|
quant_dtype = torch.float8_e4m3fn
|
||||||
|
elif use_int8_w8a16:
|
||||||
|
quant_dtype = torch.int8
|
||||||
|
else:
|
||||||
|
quant_dtype = None
|
||||||
|
|
||||||
|
quant_config = FusedMoEQuantConfig.make(
|
||||||
|
quant_dtype=quant_dtype,
|
||||||
|
w1_scale=w1_scale,
|
||||||
|
w2_scale=w2_scale,
|
||||||
|
a1_scale=a1_scale,
|
||||||
|
a2_scale=a2_scale,
|
||||||
|
block_shape=block_quant_shape,
|
||||||
|
)
|
||||||
|
|
||||||
with override_config(config):
|
with override_config(config):
|
||||||
if use_deep_gemm:
|
topk_weights, topk_ids, token_expert_indices = fused_topk(
|
||||||
topk_weights, topk_ids, token_expert_indices = fused_topk(
|
x, input_gating, topk, renormalize=not use_deep_gemm
|
||||||
x, input_gating, topk, False
|
)
|
||||||
)
|
return fused_experts(
|
||||||
return fused_experts(
|
x,
|
||||||
x,
|
w1,
|
||||||
w1,
|
w2,
|
||||||
w2,
|
topk_weights,
|
||||||
topk_weights,
|
topk_ids,
|
||||||
topk_ids,
|
inplace=True,
|
||||||
inplace=True,
|
quant_config=quant_config,
|
||||||
use_fp8_w8a8=use_fp8_w8a8,
|
allow_deep_gemm=use_deep_gemm,
|
||||||
w1_scale=w1_scale,
|
)
|
||||||
w2_scale=w2_scale,
|
|
||||||
a1_scale=a1_scale,
|
|
||||||
a2_scale=a2_scale,
|
|
||||||
block_shape=block_quant_shape,
|
|
||||||
allow_deep_gemm=True,
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
fused_moe(
|
|
||||||
x,
|
|
||||||
w1,
|
|
||||||
w2,
|
|
||||||
input_gating,
|
|
||||||
topk,
|
|
||||||
renormalize=True,
|
|
||||||
inplace=True,
|
|
||||||
use_fp8_w8a8=use_fp8_w8a8,
|
|
||||||
use_int8_w8a16=use_int8_w8a16,
|
|
||||||
w1_scale=w1_scale,
|
|
||||||
w2_scale=w2_scale,
|
|
||||||
a1_scale=a1_scale,
|
|
||||||
a2_scale=a2_scale,
|
|
||||||
block_shape=block_quant_shape,
|
|
||||||
)
|
|
||||||
|
|
||||||
# JIT compilation & warmup
|
# JIT compilation & warmup
|
||||||
run()
|
run()
|
||||||
@ -414,7 +411,7 @@ class BenchmarkWorker:
|
|||||||
use_deep_gemm: bool = False,
|
use_deep_gemm: bool = False,
|
||||||
) -> tuple[dict[str, int], float]:
|
) -> tuple[dict[str, int], float]:
|
||||||
current_platform.seed_everything(self.seed)
|
current_platform.seed_everything(self.seed)
|
||||||
dtype_str = get_config_dtype_str(
|
dtype_str = _get_config_dtype_str(
|
||||||
dtype, use_int8_w8a16=use_int8_w8a16, use_fp8_w8a8=use_fp8_w8a8
|
dtype, use_int8_w8a16=use_int8_w8a16, use_fp8_w8a8=use_fp8_w8a8
|
||||||
)
|
)
|
||||||
# NOTE(woosuk): The current naming convention uses w2.shape[2], which
|
# NOTE(woosuk): The current naming convention uses w2.shape[2], which
|
||||||
@ -547,7 +544,7 @@ def save_configs(
|
|||||||
block_quant_shape: list[int],
|
block_quant_shape: list[int],
|
||||||
save_dir: str,
|
save_dir: str,
|
||||||
) -> None:
|
) -> None:
|
||||||
dtype_str = get_config_dtype_str(
|
dtype_str = _get_config_dtype_str(
|
||||||
dtype, use_int8_w8a16=use_int8_w8a16, use_fp8_w8a8=use_fp8_w8a8
|
dtype, use_int8_w8a16=use_int8_w8a16, use_fp8_w8a8=use_fp8_w8a8
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@ -20,7 +20,7 @@ from vllm.model_executor.layers.fused_moe.config import (
|
|||||||
from vllm.model_executor.layers.fused_moe.fused_moe import fused_topk
|
from vllm.model_executor.layers.fused_moe.fused_moe import fused_topk
|
||||||
from vllm.utils import has_deep_ep, has_deep_gemm, has_pplx
|
from vllm.utils import has_deep_ep, has_deep_gemm, has_pplx
|
||||||
|
|
||||||
from .mk_objects import (expert_info, make_fused_experts,
|
from .mk_objects import (TestMoEQuantConfig, expert_info, make_fused_experts,
|
||||||
make_prepare_finalize, prepare_finalize_info)
|
make_prepare_finalize, prepare_finalize_info)
|
||||||
from .parallel_utils import ProcessGroupInfo
|
from .parallel_utils import ProcessGroupInfo
|
||||||
|
|
||||||
@ -40,7 +40,7 @@ class Config:
|
|||||||
E: int
|
E: int
|
||||||
topks: Union[list[int], int]
|
topks: Union[list[int], int]
|
||||||
dtype: torch.dtype
|
dtype: torch.dtype
|
||||||
quant_config: Optional[FusedMoEQuantConfig]
|
quant_config: Optional[TestMoEQuantConfig]
|
||||||
|
|
||||||
prepare_finalize_type: mk.FusedMoEPrepareAndFinalize
|
prepare_finalize_type: mk.FusedMoEPrepareAndFinalize
|
||||||
fused_experts_type: mk.FusedMoEPermuteExpertsUnpermute
|
fused_experts_type: mk.FusedMoEPermuteExpertsUnpermute
|
||||||
@ -52,7 +52,7 @@ class Config:
|
|||||||
|
|
||||||
def __post_init__(self):
|
def __post_init__(self):
|
||||||
if self.quant_config is None:
|
if self.quant_config is None:
|
||||||
self.quant_config = FusedMoEQuantConfig()
|
self.quant_config = TestMoEQuantConfig(None, False, False, None)
|
||||||
|
|
||||||
def describe(self) -> str:
|
def describe(self) -> str:
|
||||||
s = ""
|
s = ""
|
||||||
@ -275,21 +275,19 @@ class WeightTensors:
|
|||||||
or self.w1.dtype == torch.uint8 or self.w1.dtype == torch.int8)
|
or self.w1.dtype == torch.uint8 or self.w1.dtype == torch.int8)
|
||||||
|
|
||||||
def to_current_device(self):
|
def to_current_device(self):
|
||||||
self.w1 = self.w1.to(device=torch.cuda.current_device())
|
device = torch.cuda.current_device()
|
||||||
self.w2 = self.w2.to(device=torch.cuda.current_device())
|
self.w1 = self.w1.to(device=device)
|
||||||
|
self.w2 = self.w2.to(device=device)
|
||||||
|
|
||||||
if self.is_quantized():
|
if self.w1_scale is not None:
|
||||||
assert self.w1_scale is not None
|
self.w1_scale = self.w1_scale.to(device=device)
|
||||||
assert self.w2_scale is not None
|
if self.w2_scale is not None:
|
||||||
self.w1_scale = self.w1_scale.to(
|
self.w2_scale = self.w2_scale.to(device=device)
|
||||||
device=torch.cuda.current_device())
|
|
||||||
self.w2_scale = self.w2_scale.to(
|
|
||||||
device=torch.cuda.current_device())
|
|
||||||
|
|
||||||
if self.w1_gs is not None:
|
if self.w1_gs is not None:
|
||||||
assert self.w2_gs is not None
|
self.w1_gs = self.w1_gs.to(device=device)
|
||||||
self.w1_gs = self.w1_gs.to(device=torch.cuda.current_device())
|
if self.w2_gs is not None:
|
||||||
self.w2_gs = self.w2_gs.to(device=torch.cuda.current_device())
|
self.w2_gs = self.w2_gs.to(device=device)
|
||||||
|
|
||||||
def slice_weights(self, rank: int,
|
def slice_weights(self, rank: int,
|
||||||
num_local_experts: int) -> "WeightTensors":
|
num_local_experts: int) -> "WeightTensors":
|
||||||
@ -297,20 +295,12 @@ class WeightTensors:
|
|||||||
e = s + num_local_experts
|
e = s + num_local_experts
|
||||||
w1 = self.w1[s:e, :, :]
|
w1 = self.w1[s:e, :, :]
|
||||||
w2 = self.w2[s:e, :, :]
|
w2 = self.w2[s:e, :, :]
|
||||||
|
w1_scale = self.w1_scale[
|
||||||
w1_scale, w2_scale = (None, None)
|
s:e, :, :] if self.w1_scale is not None else None
|
||||||
if self.is_quantized():
|
w2_scale = self.w2_scale[
|
||||||
assert self.w1_scale is not None
|
s:e, :, :] if self.w2_scale is not None else None
|
||||||
assert self.w2_scale is not None
|
w1_gs = self.w1_gs[s:e] if self.w1_gs is not None else None
|
||||||
w1_scale = self.w1_scale[s:e, :, :]
|
w2_gs = self.w2_gs[s:e] if self.w2_gs is not None else None
|
||||||
w2_scale = self.w2_scale[s:e, :, :]
|
|
||||||
|
|
||||||
w1_gs = self.w1_gs
|
|
||||||
w2_gs = self.w2_gs
|
|
||||||
if w1_gs is not None:
|
|
||||||
assert w2_gs is not None
|
|
||||||
w1_gs = w1_gs[s:e]
|
|
||||||
w2_gs = w2_gs[s:e]
|
|
||||||
|
|
||||||
return WeightTensors(w1, w2, w1_scale, w2_scale, w1_gs, w2_gs)
|
return WeightTensors(w1, w2, w1_scale, w2_scale, w1_gs, w2_gs)
|
||||||
|
|
||||||
@ -323,7 +313,8 @@ class WeightTensors:
|
|||||||
in_dtype=config.dtype,
|
in_dtype=config.dtype,
|
||||||
quant_dtype=config.quant_dtype,
|
quant_dtype=config.quant_dtype,
|
||||||
block_shape=config.quant_block_shape,
|
block_shape=config.quant_block_shape,
|
||||||
per_act_token_quant=config.is_per_out_ch_quant,
|
per_out_ch_quant=config.
|
||||||
|
is_per_act_token_quant, # or config.is_per_out_ch_quant
|
||||||
)
|
)
|
||||||
return WeightTensors(w1=w1,
|
return WeightTensors(w1=w1,
|
||||||
w2=w2,
|
w2=w2,
|
||||||
@ -342,8 +333,6 @@ class RankTensors:
|
|||||||
topk_ids: torch.Tensor
|
topk_ids: torch.Tensor
|
||||||
expert_map: Optional[torch.Tensor]
|
expert_map: Optional[torch.Tensor]
|
||||||
|
|
||||||
quant_config: Optional[FusedMoEQuantConfig]
|
|
||||||
|
|
||||||
def describe(self):
|
def describe(self):
|
||||||
s = ""
|
s = ""
|
||||||
s += "== Rank Tensors: \n"
|
s += "== Rank Tensors: \n"
|
||||||
@ -426,7 +415,6 @@ class RankTensors:
|
|||||||
topk_weights=topk_weights,
|
topk_weights=topk_weights,
|
||||||
topk_ids=topk_ids,
|
topk_ids=topk_ids,
|
||||||
expert_map=expert_map,
|
expert_map=expert_map,
|
||||||
quant_config=config.quant_config,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@ -522,10 +510,16 @@ def reference_moe_impl(config: Config, weights: WeightTensors,
|
|||||||
and config.supports_apply_weight_on_input())
|
and config.supports_apply_weight_on_input())
|
||||||
|
|
||||||
|
|
||||||
|
def _make_gscale(num_experts: int) -> torch.Tensor:
|
||||||
|
return torch.ones((num_experts, ),
|
||||||
|
device=torch.cuda.current_device(),
|
||||||
|
dtype=torch.float32)
|
||||||
|
|
||||||
|
|
||||||
def make_modular_kernel(
|
def make_modular_kernel(
|
||||||
config: Config,
|
config: Config,
|
||||||
vllm_config: VllmConfig,
|
vllm_config: VllmConfig,
|
||||||
weights: WeightTensors,
|
quant_config: FusedMoEQuantConfig,
|
||||||
) -> mk.FusedMoEModularKernel:
|
) -> mk.FusedMoEModularKernel:
|
||||||
|
|
||||||
def next_power_of_2(x):
|
def next_power_of_2(x):
|
||||||
@ -548,20 +542,20 @@ def make_modular_kernel(
|
|||||||
num_local_experts=config.num_local_experts,
|
num_local_experts=config.num_local_experts,
|
||||||
moe_parallel_config=moe_parallel_config,
|
moe_parallel_config=moe_parallel_config,
|
||||||
in_dtype=config.dtype,
|
in_dtype=config.dtype,
|
||||||
quant_config=config.quant_config,
|
|
||||||
max_num_tokens=next_power_of_2(config.M),
|
max_num_tokens=next_power_of_2(config.M),
|
||||||
)
|
)
|
||||||
|
|
||||||
# make modular kernel
|
# make modular kernel
|
||||||
prepare_finalize = make_prepare_finalize(config.prepare_finalize_type,
|
prepare_finalize = make_prepare_finalize(config.prepare_finalize_type,
|
||||||
config.all2all_backend(), moe)
|
config.all2all_backend(), moe,
|
||||||
|
quant_config)
|
||||||
|
|
||||||
fused_experts = make_fused_experts(
|
fused_experts = make_fused_experts(
|
||||||
config.fused_experts_type,
|
config.fused_experts_type,
|
||||||
moe,
|
moe,
|
||||||
|
quant_config,
|
||||||
prepare_finalize.num_dispatchers(),
|
prepare_finalize.num_dispatchers(),
|
||||||
weights.w1_gs,
|
config.N,
|
||||||
weights.w2_gs,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
modular_kernel = mk.FusedMoEModularKernel(
|
modular_kernel = mk.FusedMoEModularKernel(
|
||||||
@ -583,12 +577,38 @@ def run_modular_kernel(
|
|||||||
# weights for rank
|
# weights for rank
|
||||||
rank_weights = weights.slice_weights(pgi.rank, config.num_local_experts)
|
rank_weights = weights.slice_weights(pgi.rank, config.num_local_experts)
|
||||||
|
|
||||||
mk = make_modular_kernel(config, vllm_config, weights)
|
if config.quant_dtype == "nvfp4":
|
||||||
|
gscale = _make_gscale(config.num_local_experts)
|
||||||
|
else:
|
||||||
|
gscale = None
|
||||||
|
|
||||||
|
quant_config = FusedMoEQuantConfig.make(
|
||||||
|
config.quant_dtype,
|
||||||
|
w1_scale=rank_weights.w1_scale,
|
||||||
|
w2_scale=rank_weights.w2_scale,
|
||||||
|
a1_scale=rank_tensors.hidden_states_scale,
|
||||||
|
g1_alphas=(1 / rank_weights.w1_gs)
|
||||||
|
if rank_weights.w1_gs is not None else None,
|
||||||
|
g2_alphas=(1 / rank_weights.w2_gs)
|
||||||
|
if rank_weights.w2_gs is not None else None,
|
||||||
|
a1_gscale=gscale,
|
||||||
|
a2_gscale=gscale,
|
||||||
|
block_shape=config.quant_block_shape,
|
||||||
|
per_act_token_quant=config.is_per_act_token_quant,
|
||||||
|
per_out_ch_quant=config.is_per_out_ch_quant,
|
||||||
|
)
|
||||||
|
|
||||||
|
mk = make_modular_kernel(config, vllm_config, quant_config)
|
||||||
|
|
||||||
|
# impls might update the tensor in place
|
||||||
|
hidden_states = rank_tensors.hidden_states.clone()
|
||||||
|
|
||||||
|
topk_ids = rank_tensors.topk_ids.to(
|
||||||
|
mk.prepare_finalize.topk_indices_dtype())
|
||||||
|
|
||||||
mk_kwargs = {
|
mk_kwargs = {
|
||||||
"hidden_states":
|
"hidden_states":
|
||||||
rank_tensors.hidden_states.clone(
|
hidden_states,
|
||||||
), # impls might update the tensor in place
|
|
||||||
"w1":
|
"w1":
|
||||||
rank_weights.w1,
|
rank_weights.w1,
|
||||||
"w2":
|
"w2":
|
||||||
@ -596,15 +616,9 @@ def run_modular_kernel(
|
|||||||
"topk_weights":
|
"topk_weights":
|
||||||
rank_tensors.topk_weights,
|
rank_tensors.topk_weights,
|
||||||
"topk_ids":
|
"topk_ids":
|
||||||
rank_tensors.topk_ids.to(mk.prepare_finalize.topk_indices_dtype()),
|
topk_ids,
|
||||||
"expert_map":
|
"expert_map":
|
||||||
rank_tensors.expert_map,
|
rank_tensors.expert_map,
|
||||||
"w1_scale":
|
|
||||||
rank_weights.w1_scale,
|
|
||||||
"w2_scale":
|
|
||||||
rank_weights.w2_scale,
|
|
||||||
"a1_scale":
|
|
||||||
rank_tensors.hidden_states_scale,
|
|
||||||
"global_num_experts":
|
"global_num_experts":
|
||||||
config.E,
|
config.E,
|
||||||
"apply_router_weight_on_input":
|
"apply_router_weight_on_input":
|
||||||
|
|||||||
@ -10,7 +10,8 @@ import torch
|
|||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
|
|
||||||
from vllm.config import VllmConfig, set_current_vllm_config
|
from vllm.config import VllmConfig, set_current_vllm_config
|
||||||
from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig
|
from vllm.model_executor.layers.fused_moe.config import (
|
||||||
|
FUSED_MOE_UNQUANTIZED_CONFIG)
|
||||||
from vllm.platforms import current_platform
|
from vllm.platforms import current_platform
|
||||||
|
|
||||||
from .common import (Config, RankTensors, WeightTensors, reference_moe_impl,
|
from .common import (Config, RankTensors, WeightTensors, reference_moe_impl,
|
||||||
@ -86,7 +87,7 @@ def make_feature_matrix(csv_file_path: str):
|
|||||||
quant_config_dict = config_dict['quant_config']
|
quant_config_dict = config_dict['quant_config']
|
||||||
del config_dict['quant_config']
|
del config_dict['quant_config']
|
||||||
if quant_config_dict is None:
|
if quant_config_dict is None:
|
||||||
quant_config = FusedMoEQuantConfig(None)
|
quant_config = FUSED_MOE_UNQUANTIZED_CONFIG
|
||||||
quant_config_dict = asdict(quant_config)
|
quant_config_dict = asdict(quant_config)
|
||||||
|
|
||||||
config_dict |= quant_config_dict
|
config_dict |= quant_config_dict
|
||||||
|
|||||||
@ -32,6 +32,14 @@ from vllm.utils.deep_gemm import is_deep_gemm_supported
|
|||||||
from vllm.utils.flashinfer import has_flashinfer_cutlass_fused_moe
|
from vllm.utils.flashinfer import has_flashinfer_cutlass_fused_moe
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class TestMoEQuantConfig:
|
||||||
|
quant_dtype: Union[torch.dtype, str, None]
|
||||||
|
per_out_ch_quant: bool
|
||||||
|
per_act_token_quant: bool
|
||||||
|
block_shape: Optional[list[int]]
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class PrepareFinalizeInfo:
|
class PrepareFinalizeInfo:
|
||||||
activation_format: mk.FusedMoEActivationFormat
|
activation_format: mk.FusedMoEActivationFormat
|
||||||
@ -66,7 +74,7 @@ common_float_types: list[Union[torch.dtype, str]] = [
|
|||||||
torch.float8_e4m3fn, torch.bfloat16, torch.float16, torch.float32
|
torch.float8_e4m3fn, torch.bfloat16, torch.float16, torch.float32
|
||||||
]
|
]
|
||||||
common_float_and_int_types = common_float_types + [torch.int8]
|
common_float_and_int_types = common_float_types + [torch.int8]
|
||||||
nv_fp4_types = ["nvfp4"]
|
nvfp4_types = ["nvfp4"]
|
||||||
fp8_types = [torch.float8_e4m3fn]
|
fp8_types = [torch.float8_e4m3fn]
|
||||||
|
|
||||||
|
|
||||||
@ -219,7 +227,7 @@ if (has_flashinfer_cutlass_fused_moe()
|
|||||||
register_prepare_and_finalize(
|
register_prepare_and_finalize(
|
||||||
FlashInferCutlassMoEPrepareAndFinalize,
|
FlashInferCutlassMoEPrepareAndFinalize,
|
||||||
standard_format,
|
standard_format,
|
||||||
nv_fp4_types,
|
nvfp4_types,
|
||||||
blocked_quantization_support=True,
|
blocked_quantization_support=True,
|
||||||
backend=None,
|
backend=None,
|
||||||
force_multigpu=True,
|
force_multigpu=True,
|
||||||
@ -229,7 +237,7 @@ if (has_flashinfer_cutlass_fused_moe()
|
|||||||
register_experts(
|
register_experts(
|
||||||
FlashInferExperts,
|
FlashInferExperts,
|
||||||
standard_format,
|
standard_format,
|
||||||
nv_fp4_types,
|
nvfp4_types,
|
||||||
blocked_quantization_support=True,
|
blocked_quantization_support=True,
|
||||||
supports_chunking=True,
|
supports_chunking=True,
|
||||||
# Note: this is a hack to get it to run for now
|
# Note: this is a hack to get it to run for now
|
||||||
@ -306,39 +314,39 @@ if cutlass_fp4_supported():
|
|||||||
register_experts(
|
register_experts(
|
||||||
CutlassExpertsFp4,
|
CutlassExpertsFp4,
|
||||||
standard_format,
|
standard_format,
|
||||||
nv_fp4_types,
|
nvfp4_types,
|
||||||
blocked_quantization_support=True,
|
blocked_quantization_support=True,
|
||||||
supports_chunking=True,
|
supports_chunking=True,
|
||||||
supports_expert_map=False,
|
supports_expert_map=False,
|
||||||
)
|
)
|
||||||
|
|
||||||
MK_QUANT_CONFIGS = [
|
MK_QUANT_CONFIGS: list[Optional[TestMoEQuantConfig]] = [
|
||||||
None,
|
None,
|
||||||
# per-channel / per-column weights and per-tensor activations
|
# per-channel / per-column weights and per-tensor activations
|
||||||
FusedMoEQuantConfig(quant_dtype=torch.float8_e4m3fn,
|
TestMoEQuantConfig(quant_dtype=torch.float8_e4m3fn,
|
||||||
per_out_ch_quant=True,
|
per_out_ch_quant=True,
|
||||||
per_act_token_quant=False,
|
per_act_token_quant=False,
|
||||||
block_shape=None),
|
block_shape=None),
|
||||||
# per-channel / per-column weights and per-token activations
|
# per-channel / per-column weights and per-token activations
|
||||||
FusedMoEQuantConfig(quant_dtype=torch.float8_e4m3fn,
|
TestMoEQuantConfig(quant_dtype=torch.float8_e4m3fn,
|
||||||
per_out_ch_quant=True,
|
per_out_ch_quant=True,
|
||||||
per_act_token_quant=True,
|
per_act_token_quant=True,
|
||||||
block_shape=None),
|
block_shape=None),
|
||||||
# per-tensor weights and per-tensor activations
|
# per-tensor weights and per-tensor activations
|
||||||
FusedMoEQuantConfig(quant_dtype=torch.float8_e4m3fn,
|
TestMoEQuantConfig(quant_dtype=torch.float8_e4m3fn,
|
||||||
per_out_ch_quant=False,
|
per_out_ch_quant=False,
|
||||||
per_act_token_quant=False,
|
per_act_token_quant=False,
|
||||||
block_shape=None),
|
block_shape=None),
|
||||||
# per-tensor weights and per-token activations
|
# per-tensor weights and per-token activations
|
||||||
FusedMoEQuantConfig(quant_dtype=torch.float8_e4m3fn,
|
TestMoEQuantConfig(quant_dtype=torch.float8_e4m3fn,
|
||||||
per_out_ch_quant=False,
|
per_out_ch_quant=False,
|
||||||
per_act_token_quant=True,
|
per_act_token_quant=True,
|
||||||
block_shape=None),
|
block_shape=None),
|
||||||
# block-quantized weights and 128 block per-token activations
|
# block-quantized weights and 128 block per-token activations
|
||||||
FusedMoEQuantConfig(quant_dtype=torch.float8_e4m3fn,
|
TestMoEQuantConfig(quant_dtype=torch.float8_e4m3fn,
|
||||||
per_out_ch_quant=False,
|
per_out_ch_quant=False,
|
||||||
per_act_token_quant=False,
|
per_act_token_quant=False,
|
||||||
block_shape=[128, 128]),
|
block_shape=[128, 128]),
|
||||||
# TODO (varun) : Should we test the following combinations ?
|
# TODO (varun) : Should we test the following combinations ?
|
||||||
# block-quantized weights and per-token activations
|
# block-quantized weights and per-token activations
|
||||||
# block-quantized weights and per-tensor activations
|
# block-quantized weights and per-tensor activations
|
||||||
@ -346,33 +354,27 @@ MK_QUANT_CONFIGS = [
|
|||||||
|
|
||||||
if cutlass_fp4_supported() or has_flashinfer_cutlass_fused_moe():
|
if cutlass_fp4_supported() or has_flashinfer_cutlass_fused_moe():
|
||||||
MK_QUANT_CONFIGS += [
|
MK_QUANT_CONFIGS += [
|
||||||
FusedMoEQuantConfig(quant_dtype="nvfp4",
|
TestMoEQuantConfig(quant_dtype="nvfp4",
|
||||||
per_out_ch_quant=False,
|
per_out_ch_quant=False,
|
||||||
per_act_token_quant=False,
|
per_act_token_quant=False,
|
||||||
block_shape=None),
|
block_shape=None),
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
def _make_gscale(num_experts: int) -> torch.Tensor:
|
|
||||||
return torch.ones((num_experts, ),
|
|
||||||
device=torch.cuda.current_device(),
|
|
||||||
dtype=torch.float32)
|
|
||||||
|
|
||||||
|
|
||||||
def make_prepare_finalize(
|
def make_prepare_finalize(
|
||||||
prepare_finalize_type: mk.FusedMoEPrepareAndFinalize,
|
prepare_finalize_type: mk.FusedMoEPrepareAndFinalize,
|
||||||
backend: Optional[str],
|
backend: Optional[str],
|
||||||
moe: FusedMoEConfig,
|
moe: FusedMoEConfig,
|
||||||
|
quant_config: FusedMoEQuantConfig,
|
||||||
) -> mk.FusedMoEPrepareAndFinalize:
|
) -> mk.FusedMoEPrepareAndFinalize:
|
||||||
if backend != "naive" and backend is not None:
|
if backend != "naive" and backend is not None:
|
||||||
prepare_finalize = FusedMoEMethodBase._maybe_make_prepare_finalize(moe)
|
prepare_finalize = FusedMoEMethodBase._maybe_make_prepare_finalize(
|
||||||
|
moe, quant_config)
|
||||||
assert prepare_finalize is not None
|
assert prepare_finalize is not None
|
||||||
return prepare_finalize
|
return prepare_finalize
|
||||||
elif prepare_finalize_type == FlashInferCutlassMoEPrepareAndFinalize:
|
elif prepare_finalize_type == FlashInferCutlassMoEPrepareAndFinalize:
|
||||||
return FlashInferCutlassMoEPrepareAndFinalize(
|
return FlashInferCutlassMoEPrepareAndFinalize(
|
||||||
use_dp=moe.moe_parallel_config.dp_size > 1,
|
use_dp=moe.moe_parallel_config.dp_size > 1)
|
||||||
a1_gscale=_make_gscale(moe.num_local_experts),
|
|
||||||
)
|
|
||||||
else:
|
else:
|
||||||
return MoEPrepareAndFinalizeNoEP()
|
return MoEPrepareAndFinalizeNoEP()
|
||||||
|
|
||||||
@ -383,34 +385,39 @@ def _slice(rank: int, num_local_experts: int, t: torch.Tensor) -> torch.Tensor:
|
|||||||
return t[s:e]
|
return t[s:e]
|
||||||
|
|
||||||
|
|
||||||
|
def make_cutlass_strides(
|
||||||
|
e: int,
|
||||||
|
n: int,
|
||||||
|
k: int,
|
||||||
|
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||||
|
ab_strides1 = torch.full((e, ), k, device="cuda", dtype=torch.int64)
|
||||||
|
ab_strides2 = torch.full((e, ), n, device="cuda", dtype=torch.int64)
|
||||||
|
c_strides1 = torch.full((e, ), 2 * n, device="cuda", dtype=torch.int64)
|
||||||
|
c_strides2 = torch.full((e, ), k, device="cuda", dtype=torch.int64)
|
||||||
|
return ab_strides1, ab_strides2, c_strides1, c_strides2
|
||||||
|
|
||||||
|
|
||||||
def make_fused_experts(
|
def make_fused_experts(
|
||||||
fused_experts_type: mk.FusedMoEPermuteExpertsUnpermute,
|
fused_experts_type: mk.FusedMoEPermuteExpertsUnpermute,
|
||||||
moe: FusedMoEConfig,
|
moe: FusedMoEConfig,
|
||||||
|
quant_config: FusedMoEQuantConfig,
|
||||||
num_dispatchers: int,
|
num_dispatchers: int,
|
||||||
w1_gs: Optional[torch.Tensor],
|
N: int,
|
||||||
w2_gs: Optional[torch.Tensor],
|
|
||||||
) -> mk.FusedMoEPermuteExpertsUnpermute:
|
) -> mk.FusedMoEPermuteExpertsUnpermute:
|
||||||
|
|
||||||
use_fp8 = moe.quant_dtype == torch.float8_e4m3fn
|
|
||||||
batch_kwargs = {
|
batch_kwargs = {
|
||||||
"max_num_tokens": moe.max_num_tokens,
|
"max_num_tokens": moe.max_num_tokens,
|
||||||
"num_dispatchers": num_dispatchers,
|
"num_dispatchers": num_dispatchers,
|
||||||
}
|
}
|
||||||
quant_kwargs = {
|
quant_kwargs = {
|
||||||
"use_fp8_w8a8": use_fp8,
|
"quant_config": quant_config,
|
||||||
"use_int8_w8a8": False,
|
|
||||||
"use_int8_w8a16": False,
|
|
||||||
"use_int4_w4a16": False,
|
|
||||||
"block_shape": moe.block_shape,
|
|
||||||
"per_act_token_quant": moe.per_act_token_quant,
|
|
||||||
}
|
}
|
||||||
deepgemm_kwargs = {"allow_deep_gemm": has_deep_gemm()}
|
deepgemm_kwargs = {"allow_deep_gemm": has_deep_gemm()}
|
||||||
|
|
||||||
|
torch.set_printoptions(threshold=0, edgeitems=0, linewidth=10000)
|
||||||
|
|
||||||
if fused_experts_type == BatchedDeepGemmExperts:
|
if fused_experts_type == BatchedDeepGemmExperts:
|
||||||
kwargs = batch_kwargs | {
|
kwargs = batch_kwargs | quant_kwargs
|
||||||
"block_shape": moe.block_shape,
|
|
||||||
"per_act_token_quant": moe.per_act_token_quant,
|
|
||||||
}
|
|
||||||
print(f"Making BatchedDeepGemmExperts {kwargs} ...")
|
print(f"Making BatchedDeepGemmExperts {kwargs} ...")
|
||||||
experts = BatchedDeepGemmExperts(**kwargs)
|
experts = BatchedDeepGemmExperts(**kwargs)
|
||||||
elif fused_experts_type == BatchedTritonExperts:
|
elif fused_experts_type == BatchedTritonExperts:
|
||||||
@ -422,8 +429,8 @@ def make_fused_experts(
|
|||||||
print(f"Making BatchedTritonOrDeepGemmExperts {kwargs} ...")
|
print(f"Making BatchedTritonOrDeepGemmExperts {kwargs} ...")
|
||||||
experts = BatchedTritonOrDeepGemmExperts(**kwargs)
|
experts = BatchedTritonOrDeepGemmExperts(**kwargs)
|
||||||
elif fused_experts_type == DeepGemmExperts:
|
elif fused_experts_type == DeepGemmExperts:
|
||||||
print("Making DeepGemmExperts () ...")
|
print("Making DeepGemmExperts {quant_config} ...")
|
||||||
experts = DeepGemmExperts()
|
experts = DeepGemmExperts(quant_config)
|
||||||
elif fused_experts_type == TritonExperts:
|
elif fused_experts_type == TritonExperts:
|
||||||
kwargs = quant_kwargs
|
kwargs = quant_kwargs
|
||||||
print(f"Making TritonExperts {kwargs} ...")
|
print(f"Making TritonExperts {kwargs} ...")
|
||||||
@ -437,62 +444,50 @@ def make_fused_experts(
|
|||||||
print(f"Making NaiveBatchedExperts {kwargs} ...")
|
print(f"Making NaiveBatchedExperts {kwargs} ...")
|
||||||
experts = NaiveBatchedExperts(**kwargs)
|
experts = NaiveBatchedExperts(**kwargs)
|
||||||
elif fused_experts_type == CutlassExpertsFp8:
|
elif fused_experts_type == CutlassExpertsFp8:
|
||||||
|
strides = make_cutlass_strides(moe.num_experts, N, moe.hidden_dim)
|
||||||
kwargs = {
|
kwargs = {
|
||||||
"out_dtype": moe.in_dtype,
|
"out_dtype": moe.in_dtype,
|
||||||
"per_act_token_quant": moe.per_act_token_quant,
|
"ab_strides1": strides[0],
|
||||||
"per_out_ch_quant": moe.per_out_ch_quant,
|
"ab_strides2": strides[1],
|
||||||
"block_shape": moe.block_shape,
|
"c_strides1": strides[2],
|
||||||
}
|
"c_strides2": strides[3],
|
||||||
|
} | quant_kwargs
|
||||||
print(f"Making CutlassExpertsFp8 {kwargs} ...")
|
print(f"Making CutlassExpertsFp8 {kwargs} ...")
|
||||||
experts = CutlassExpertsFp8(**kwargs)
|
experts = CutlassExpertsFp8(**kwargs)
|
||||||
elif fused_experts_type == CutlassBatchedExpertsFp8:
|
elif fused_experts_type == CutlassBatchedExpertsFp8:
|
||||||
|
strides = make_cutlass_strides(moe.num_experts, N, moe.hidden_dim)
|
||||||
kwargs = {
|
kwargs = {
|
||||||
"max_experts_per_worker": moe.num_local_experts,
|
"max_experts_per_worker": moe.num_local_experts,
|
||||||
"num_dispatchers": num_dispatchers,
|
"num_dispatchers": num_dispatchers,
|
||||||
"out_dtype": moe.in_dtype,
|
"out_dtype": moe.in_dtype,
|
||||||
"per_act_token_quant": moe.per_act_token_quant,
|
"ab_strides1": strides[0],
|
||||||
"per_out_ch_quant": moe.per_out_ch_quant,
|
"ab_strides2": strides[1],
|
||||||
"block_shape": moe.block_shape,
|
"c_strides1": strides[2],
|
||||||
}
|
"c_strides2": strides[3],
|
||||||
|
} | quant_kwargs
|
||||||
print(f"Making CutlassBatchedExpertsFp8 {kwargs} ...")
|
print(f"Making CutlassBatchedExpertsFp8 {kwargs} ...")
|
||||||
experts = CutlassBatchedExpertsFp8(**kwargs)
|
experts = CutlassBatchedExpertsFp8(**kwargs)
|
||||||
elif fused_experts_type == CutlassExpertsFp4:
|
elif fused_experts_type == CutlassExpertsFp4:
|
||||||
assert w1_gs is not None and w2_gs is not None
|
|
||||||
num_experts = moe.num_local_experts
|
|
||||||
rank = moe.moe_parallel_config.dp_rank
|
|
||||||
kwargs = {
|
kwargs = {
|
||||||
"g1_alphas": _slice(rank, num_experts, (1 / w1_gs)),
|
"max_experts_per_worker": moe.num_local_experts,
|
||||||
"g2_alphas": _slice(rank, num_experts, (1 / w2_gs)),
|
|
||||||
"a1_gscale": _make_gscale(num_experts),
|
|
||||||
"a2_gscale": _make_gscale(num_experts),
|
|
||||||
"max_experts_per_worker": num_experts,
|
|
||||||
"out_dtype": moe.in_dtype,
|
|
||||||
"per_act_token_quant": moe.per_act_token_quant,
|
|
||||||
"per_out_ch_quant": moe.per_out_ch_quant,
|
|
||||||
"block_shape": moe.block_shape,
|
|
||||||
"num_dispatchers": num_dispatchers,
|
"num_dispatchers": num_dispatchers,
|
||||||
}
|
"out_dtype": moe.in_dtype,
|
||||||
|
} | quant_kwargs
|
||||||
print(f"Making CutlassExpertsFp4 {kwargs} ...")
|
print(f"Making CutlassExpertsFp4 {kwargs} ...")
|
||||||
experts = CutlassExpertsFp4(**kwargs)
|
experts = CutlassExpertsFp4(**kwargs)
|
||||||
elif fused_experts_type == FlashInferExperts:
|
elif fused_experts_type == FlashInferExperts:
|
||||||
assert w1_gs is not None and w2_gs is not None
|
|
||||||
num_experts = moe.num_local_experts
|
|
||||||
rank = moe.moe_parallel_config.dp_rank
|
|
||||||
kwargs = {
|
kwargs = {
|
||||||
"g1_alphas": _slice(rank, num_experts, (1 / w1_gs)),
|
|
||||||
"g2_alphas": _slice(rank, num_experts, (1 / w2_gs)),
|
|
||||||
"a1_gscale": _make_gscale(num_experts),
|
|
||||||
"a2_gscale": _make_gscale(num_experts),
|
|
||||||
"out_dtype": moe.in_dtype,
|
"out_dtype": moe.in_dtype,
|
||||||
"quant_dtype": "nvfp4",
|
|
||||||
"ep_rank": moe.ep_rank,
|
"ep_rank": moe.ep_rank,
|
||||||
"ep_size": moe.ep_size,
|
"ep_size": moe.ep_size,
|
||||||
"tp_rank": moe.tp_rank,
|
"tp_rank": moe.tp_rank,
|
||||||
"tp_size": moe.tp_size,
|
"tp_size": moe.tp_size,
|
||||||
}
|
} | quant_kwargs
|
||||||
print(f"Making FlashInferExperts {kwargs} ...")
|
print(f"Making FlashInferExperts {kwargs} ...")
|
||||||
experts = FlashInferExperts(**kwargs)
|
experts = FlashInferExperts(**kwargs)
|
||||||
else:
|
else:
|
||||||
raise RuntimeError(f"Unknown fused experts type: {fused_experts_type}")
|
raise RuntimeError(f"Unknown fused experts type: {fused_experts_type}")
|
||||||
|
|
||||||
|
torch.set_printoptions(threshold=1000, edgeitems=5, linewidth=80)
|
||||||
|
|
||||||
return experts
|
return experts
|
||||||
|
|||||||
@ -6,6 +6,8 @@ import torch
|
|||||||
|
|
||||||
from vllm.model_executor.layers.fused_moe.batched_deep_gemm_moe import (
|
from vllm.model_executor.layers.fused_moe.batched_deep_gemm_moe import (
|
||||||
BatchedDeepGemmExperts)
|
BatchedDeepGemmExperts)
|
||||||
|
from vllm.model_executor.layers.fused_moe.config import (
|
||||||
|
fp8_w8a8_moe_quant_config)
|
||||||
from vllm.model_executor.layers.fused_moe.fused_batched_moe import (
|
from vllm.model_executor.layers.fused_moe.fused_batched_moe import (
|
||||||
BatchedPrepareAndFinalize, BatchedTritonExperts)
|
BatchedPrepareAndFinalize, BatchedTritonExperts)
|
||||||
from vllm.model_executor.layers.fused_moe.modular_kernel import (
|
from vllm.model_executor.layers.fused_moe.modular_kernel import (
|
||||||
@ -56,13 +58,18 @@ def test_batched_deepgemm_vs_triton(E: int, T: int, K: int, N: int, topk: int,
|
|||||||
rank=0,
|
rank=0,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
quant_config = fp8_w8a8_moe_quant_config(
|
||||||
|
w1_scale=w1_s,
|
||||||
|
w2_scale=w2_s,
|
||||||
|
per_act_token_quant=False,
|
||||||
|
block_shape=BLOCK_SIZE,
|
||||||
|
)
|
||||||
|
|
||||||
# triton (reference)
|
# triton (reference)
|
||||||
triton_experts = BatchedTritonExperts(
|
triton_experts = BatchedTritonExperts(
|
||||||
max_num_tokens=max_num_tokens,
|
max_num_tokens=max_num_tokens,
|
||||||
num_dispatchers=1,
|
num_dispatchers=1,
|
||||||
use_fp8_w8a8=True,
|
quant_config=quant_config,
|
||||||
per_act_token_quant=False,
|
|
||||||
block_shape=BLOCK_SIZE,
|
|
||||||
)
|
)
|
||||||
mk_triton = FusedMoEModularKernel(prep_finalize, triton_experts)
|
mk_triton = FusedMoEModularKernel(prep_finalize, triton_experts)
|
||||||
|
|
||||||
@ -73,8 +80,6 @@ def test_batched_deepgemm_vs_triton(E: int, T: int, K: int, N: int, topk: int,
|
|||||||
topk_weights=topk_weights,
|
topk_weights=topk_weights,
|
||||||
topk_ids=topk_ids,
|
topk_ids=topk_ids,
|
||||||
inplace=False,
|
inplace=False,
|
||||||
w1_scale=w1_s,
|
|
||||||
w2_scale=w2_s,
|
|
||||||
global_num_experts=E,
|
global_num_experts=E,
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -82,8 +87,7 @@ def test_batched_deepgemm_vs_triton(E: int, T: int, K: int, N: int, topk: int,
|
|||||||
deepgemm_experts = BatchedDeepGemmExperts(
|
deepgemm_experts = BatchedDeepGemmExperts(
|
||||||
max_num_tokens=max_num_tokens,
|
max_num_tokens=max_num_tokens,
|
||||||
num_dispatchers=1,
|
num_dispatchers=1,
|
||||||
block_shape=BLOCK_SIZE,
|
quant_config=quant_config,
|
||||||
per_act_token_quant=False,
|
|
||||||
)
|
)
|
||||||
mk_deepgemm = FusedMoEModularKernel(prep_finalize, deepgemm_experts)
|
mk_deepgemm = FusedMoEModularKernel(prep_finalize, deepgemm_experts)
|
||||||
|
|
||||||
@ -94,8 +98,6 @@ def test_batched_deepgemm_vs_triton(E: int, T: int, K: int, N: int, topk: int,
|
|||||||
topk_weights=topk_weights,
|
topk_weights=topk_weights,
|
||||||
topk_ids=topk_ids,
|
topk_ids=topk_ids,
|
||||||
inplace=False,
|
inplace=False,
|
||||||
w1_scale=w1_s,
|
|
||||||
w2_scale=w2_s,
|
|
||||||
global_num_experts=E,
|
global_num_experts=E,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@ -140,7 +140,7 @@ def test_batched_mm(num_experts: int, max_tokens_per_expert: int, K: int,
|
|||||||
in_dtype=act_dtype,
|
in_dtype=act_dtype,
|
||||||
quant_dtype=quant_dtype,
|
quant_dtype=quant_dtype,
|
||||||
block_shape=block_shape,
|
block_shape=block_shape,
|
||||||
per_act_token_quant=per_act_token_quant,
|
per_out_ch_quant=per_act_token_quant,
|
||||||
)
|
)
|
||||||
|
|
||||||
out_shape = (num_experts, max_tokens_per_expert, N)
|
out_shape = (num_experts, max_tokens_per_expert, N)
|
||||||
@ -250,7 +250,7 @@ def test_fused_moe_batched_experts(
|
|||||||
block_shape=block_shape,
|
block_shape=block_shape,
|
||||||
in_dtype=act_dtype,
|
in_dtype=act_dtype,
|
||||||
quant_dtype=quant_dtype,
|
quant_dtype=quant_dtype,
|
||||||
per_act_token_quant=per_act_token_quant,
|
per_out_ch_quant=per_act_token_quant,
|
||||||
)
|
)
|
||||||
|
|
||||||
if input_scales and quant_dtype is not None:
|
if input_scales and quant_dtype is not None:
|
||||||
|
|||||||
@ -4,7 +4,7 @@
|
|||||||
import pytest
|
import pytest
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from tests.kernels.moe.utils import make_test_weights
|
from tests.kernels.moe.utils import make_test_quant_config, make_test_weights
|
||||||
from tests.kernels.quant_utils import (native_per_token_group_quant_fp8,
|
from tests.kernels.quant_utils import (native_per_token_group_quant_fp8,
|
||||||
native_w8a8_block_matmul)
|
native_w8a8_block_matmul)
|
||||||
from vllm.config import VllmConfig, set_current_vllm_config
|
from vllm.config import VllmConfig, set_current_vllm_config
|
||||||
@ -161,22 +161,17 @@ def test_w8a8_block_fp8_fused_moe(M, N, K, E, topk, block_size, dtype, seed,
|
|||||||
a = torch.randn((M, K), dtype=dtype) / 10
|
a = torch.randn((M, K), dtype=dtype) / 10
|
||||||
score = torch.randn((M, E), dtype=dtype)
|
score = torch.randn((M, E), dtype=dtype)
|
||||||
|
|
||||||
(_, w1, w1_s, _), (_, w2, w2_s,
|
w1, w2, quant_config = make_test_quant_config(
|
||||||
_) = make_test_weights(E,
|
E,
|
||||||
N,
|
N,
|
||||||
K,
|
K,
|
||||||
dtype,
|
dtype,
|
||||||
torch.float8_e4m3fn,
|
quant_dtype=torch.float8_e4m3fn,
|
||||||
per_act_token_quant=False,
|
per_act_token_quant=False,
|
||||||
block_shape=block_size)
|
block_shape=block_size,
|
||||||
|
)
|
||||||
|
|
||||||
m_fused_moe = modular_triton_fused_moe(use_fp8_w8a8=True,
|
m_fused_moe = modular_triton_fused_moe(quant_config)
|
||||||
use_int8_w8a8=False,
|
|
||||||
use_int8_w8a16=False,
|
|
||||||
use_int4_w4a16=False,
|
|
||||||
use_mxfp4_w4a4=False,
|
|
||||||
per_act_token_quant=False,
|
|
||||||
block_shape=block_size)
|
|
||||||
|
|
||||||
topk_weights, topk_ids, _ = fused_topk(a, score.float(), topk, False)
|
topk_weights, topk_ids, _ = fused_topk(a, score.float(), topk, False)
|
||||||
|
|
||||||
@ -186,37 +181,24 @@ def test_w8a8_block_fp8_fused_moe(M, N, K, E, topk, block_size, dtype, seed,
|
|||||||
a,
|
a,
|
||||||
w1,
|
w1,
|
||||||
w2,
|
w2,
|
||||||
w1_s,
|
quant_config.w1_scale,
|
||||||
w2_s,
|
quant_config.w2_scale,
|
||||||
topk_weights,
|
topk_weights,
|
||||||
topk_ids,
|
topk_ids,
|
||||||
block_size,
|
block_size,
|
||||||
)
|
)
|
||||||
|
|
||||||
out = fused_experts(
|
out = fused_experts(a,
|
||||||
a,
|
w1,
|
||||||
w1,
|
w2,
|
||||||
w2,
|
topk_weights,
|
||||||
topk_weights,
|
topk_ids,
|
||||||
topk_ids,
|
quant_config=quant_config)
|
||||||
use_fp8_w8a8=True,
|
|
||||||
w1_scale=w1_s,
|
|
||||||
w2_scale=w2_s,
|
|
||||||
block_shape=block_size,
|
|
||||||
)
|
|
||||||
|
|
||||||
m_out = m_fused_moe(
|
m_out = m_fused_moe(a, w1, w2, topk_weights, topk_ids)
|
||||||
a,
|
|
||||||
w1,
|
|
||||||
w2,
|
|
||||||
topk_weights,
|
|
||||||
topk_ids,
|
|
||||||
w1_scale=w1_s,
|
|
||||||
w2_scale=w2_s,
|
|
||||||
)
|
|
||||||
|
|
||||||
# 0.039 only needed for [40000-4608-7168-2-1-block_size852-dtype852-0]
|
# 0.039 only needed for M >= 8192
|
||||||
tol = 0.035 if M < 40000 else 0.039
|
tol = 0.035 if M < 8192 else 0.039
|
||||||
torch.testing.assert_close(out, ref_out, atol=tol, rtol=tol)
|
torch.testing.assert_close(out, ref_out, atol=tol, rtol=tol)
|
||||||
torch.testing.assert_close(m_out, ref_out, atol=tol, rtol=tol)
|
torch.testing.assert_close(m_out, ref_out, atol=tol, rtol=tol)
|
||||||
|
|
||||||
@ -248,14 +230,15 @@ def test_w8a8_block_fp8_deep_gemm_fused_moe(M, N, K, E, topk, seed,
|
|||||||
a = torch.randn((M, K), dtype=dtype) / 10
|
a = torch.randn((M, K), dtype=dtype) / 10
|
||||||
score = torch.randn((M, E), dtype=dtype)
|
score = torch.randn((M, E), dtype=dtype)
|
||||||
|
|
||||||
(_, w1, w1_s, _), (_, w2, w2_s,
|
(_, w1, w1_s, _), (_, w2, w2_s, _) = make_test_weights(
|
||||||
_) = make_test_weights(E,
|
E,
|
||||||
N,
|
N,
|
||||||
K,
|
K,
|
||||||
dtype,
|
dtype,
|
||||||
torch.float8_e4m3fn,
|
torch.float8_e4m3fn,
|
||||||
per_act_token_quant=False,
|
per_out_ch_quant=False,
|
||||||
block_shape=block_size)
|
block_shape=block_size,
|
||||||
|
)
|
||||||
|
|
||||||
# Note: for now use_compile will error out if the problem size is
|
# Note: for now use_compile will error out if the problem size is
|
||||||
# large enough to trigger chunking. I'm leaving the flag and
|
# large enough to trigger chunking. I'm leaving the flag and
|
||||||
|
|||||||
@ -4,12 +4,12 @@
|
|||||||
import pytest
|
import pytest
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from tests.kernels.moe.utils import make_test_weights
|
from tests.kernels.moe.utils import make_test_quant_config
|
||||||
from tests.kernels.quant_utils import (native_per_token_group_quant_int8,
|
from tests.kernels.quant_utils import (native_per_token_group_quant_int8,
|
||||||
native_w8a8_block_matmul)
|
native_w8a8_block_matmul)
|
||||||
from vllm.config import VllmConfig, set_current_vllm_config
|
from vllm.config import VllmConfig, set_current_vllm_config
|
||||||
from vllm.model_executor.layers.activation import SiluAndMul
|
from vllm.model_executor.layers.activation import SiluAndMul
|
||||||
from vllm.model_executor.layers.fused_moe import fused_moe
|
from vllm.model_executor.layers.fused_moe import fused_experts, fused_topk
|
||||||
from vllm.platforms import current_platform
|
from vllm.platforms import current_platform
|
||||||
|
|
||||||
if current_platform.get_device_capability() < (7, 0):
|
if current_platform.get_device_capability() < (7, 0):
|
||||||
@ -50,7 +50,7 @@ MNK_FACTORS = [
|
|||||||
(2048, 128, 128),
|
(2048, 128, 128),
|
||||||
(2048, 1024, 7168),
|
(2048, 1024, 7168),
|
||||||
(2048, 4096, 512),
|
(2048, 4096, 512),
|
||||||
(2048, 4096, 7168),
|
(2048, 4096, 4096),
|
||||||
]
|
]
|
||||||
|
|
||||||
E = [8, 24]
|
E = [8, 24]
|
||||||
@ -117,31 +117,28 @@ def test_w8a8_block_int8_fused_moe(M, N, K, E, topk, block_size, dtype, seed):
|
|||||||
|
|
||||||
a = torch.randn((M, K), dtype=dtype) / 10
|
a = torch.randn((M, K), dtype=dtype) / 10
|
||||||
score = torch.randn((M, E), dtype=dtype)
|
score = torch.randn((M, E), dtype=dtype)
|
||||||
|
topk_weights, topk_ids, _ = fused_topk(a, score.float(), topk, False)
|
||||||
|
|
||||||
(_, w1, w1_s, _), (_, w2, w2_s,
|
w1, w2, quant_config = make_test_quant_config(
|
||||||
_) = make_test_weights(E,
|
E,
|
||||||
N,
|
N,
|
||||||
K,
|
K,
|
||||||
dtype,
|
dtype,
|
||||||
torch.int8,
|
quant_dtype=torch.int8,
|
||||||
per_act_token_quant=False,
|
per_act_token_quant=False,
|
||||||
block_shape=block_size)
|
block_shape=block_size,
|
||||||
|
)
|
||||||
|
|
||||||
# Set the context to avoid lots of warning spam.
|
# Set the context to avoid lots of warning spam.
|
||||||
with set_current_vllm_config(vllm_config):
|
with set_current_vllm_config(vllm_config):
|
||||||
out = fused_moe(
|
out = fused_experts(a,
|
||||||
a,
|
w1,
|
||||||
w1,
|
w2,
|
||||||
w2,
|
topk_weights,
|
||||||
score,
|
topk_ids,
|
||||||
topk,
|
quant_config=quant_config)
|
||||||
renormalize=False,
|
ref_out = torch_w8a8_block_int8_moe(a, w1, w2, quant_config.w1_scale,
|
||||||
use_int8_w8a8=True,
|
quant_config.w2_scale, score, topk,
|
||||||
w1_scale=w1_s,
|
|
||||||
w2_scale=w2_s,
|
|
||||||
block_shape=block_size,
|
|
||||||
)
|
|
||||||
ref_out = torch_w8a8_block_int8_moe(a, w1, w2, w1_s, w2_s, score, topk,
|
|
||||||
block_size)
|
block_size)
|
||||||
|
|
||||||
# Check results
|
# Check results
|
||||||
|
|||||||
@ -1,5 +1,6 @@
|
|||||||
# SPDX-License-Identifier: Apache-2.0
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
import copy
|
||||||
import dataclasses
|
import dataclasses
|
||||||
from math import prod
|
from math import prod
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
@ -9,6 +10,8 @@ import torch
|
|||||||
|
|
||||||
from vllm import _custom_ops as ops
|
from vllm import _custom_ops as ops
|
||||||
from vllm.config import ParallelConfig, VllmConfig, set_current_vllm_config
|
from vllm.config import ParallelConfig, VllmConfig, set_current_vllm_config
|
||||||
|
from vllm.model_executor.layers.fused_moe.config import (
|
||||||
|
FUSED_MOE_UNQUANTIZED_CONFIG, fp8_w8a8_moe_quant_config)
|
||||||
from vllm.model_executor.layers.fused_moe.cutlass_moe import (
|
from vllm.model_executor.layers.fused_moe.cutlass_moe import (
|
||||||
cutlass_moe_fp8, run_cutlass_moe_fp8)
|
cutlass_moe_fp8, run_cutlass_moe_fp8)
|
||||||
from vllm.model_executor.layers.fused_moe.fused_moe import (fused_experts,
|
from vllm.model_executor.layers.fused_moe.fused_moe import (fused_experts,
|
||||||
@ -154,7 +157,7 @@ def run_with_expert_maps(num_experts: int, num_local_experts: int,
|
|||||||
def slice_experts():
|
def slice_experts():
|
||||||
slice_params = [
|
slice_params = [
|
||||||
"w1_q", "w2_q", "ab_strides1", "ab_strides2", "c_strides1",
|
"w1_q", "w2_q", "ab_strides1", "ab_strides2", "c_strides1",
|
||||||
"c_strides2", "w1_scale", "w2_scale"
|
"c_strides2"
|
||||||
]
|
]
|
||||||
full_tensors = {
|
full_tensors = {
|
||||||
k: v
|
k: v
|
||||||
@ -162,6 +165,8 @@ def run_with_expert_maps(num_experts: int, num_local_experts: int,
|
|||||||
if k in slice_params and k in cutlass_moe_kwargs
|
if k in slice_params and k in cutlass_moe_kwargs
|
||||||
}
|
}
|
||||||
|
|
||||||
|
quant_config = cutlass_moe_kwargs["quant_config"]
|
||||||
|
|
||||||
for i in range(0, num_experts, num_local_experts):
|
for i in range(0, num_experts, num_local_experts):
|
||||||
s, e = i, i + num_local_experts
|
s, e = i, i + num_local_experts
|
||||||
|
|
||||||
@ -178,6 +183,12 @@ def run_with_expert_maps(num_experts: int, num_local_experts: int,
|
|||||||
for k, t in full_tensors.items():
|
for k, t in full_tensors.items():
|
||||||
cutlass_moe_kwargs[k] = t[s:e]
|
cutlass_moe_kwargs[k] = t[s:e]
|
||||||
|
|
||||||
|
new_quant_config = copy.deepcopy(quant_config)
|
||||||
|
new_quant_config._w1.scale = quant_config.w1_scale[s:e]
|
||||||
|
new_quant_config._w2.scale = quant_config.w2_scale[s:e]
|
||||||
|
|
||||||
|
cutlass_moe_kwargs["quant_config"] = new_quant_config
|
||||||
|
|
||||||
yield cutlass_moe_kwargs
|
yield cutlass_moe_kwargs
|
||||||
|
|
||||||
out_tensor = torch.zeros_like(cutlass_moe_kwargs["a"])
|
out_tensor = torch.zeros_like(cutlass_moe_kwargs["a"])
|
||||||
@ -191,6 +202,7 @@ def run_8_bit(moe_tensors: MOETensors8Bit,
|
|||||||
topk_weights: torch.Tensor,
|
topk_weights: torch.Tensor,
|
||||||
topk_ids: torch.Tensor,
|
topk_ids: torch.Tensor,
|
||||||
per_act_token: bool,
|
per_act_token: bool,
|
||||||
|
per_out_ch: bool,
|
||||||
num_local_experts: Optional[int] = None) -> torch.Tensor:
|
num_local_experts: Optional[int] = None) -> torch.Tensor:
|
||||||
assert not any([
|
assert not any([
|
||||||
t is None for t in [
|
t is None for t in [
|
||||||
@ -199,20 +211,27 @@ def run_8_bit(moe_tensors: MOETensors8Bit,
|
|||||||
]
|
]
|
||||||
])
|
])
|
||||||
|
|
||||||
|
quant_config = fp8_w8a8_moe_quant_config(
|
||||||
|
w1_scale=moe_tensors.w1_scale,
|
||||||
|
w2_scale=moe_tensors.w2_scale,
|
||||||
|
per_act_token_quant=per_act_token,
|
||||||
|
per_out_ch_quant=per_out_ch,
|
||||||
|
# Set to moe_tensors.a_scale iff static scales + per tensor.
|
||||||
|
# This is not currently being tested.
|
||||||
|
a1_scale=None,
|
||||||
|
)
|
||||||
|
|
||||||
kwargs = {
|
kwargs = {
|
||||||
'a': moe_tensors.a,
|
'a': moe_tensors.a,
|
||||||
'w1_q': moe_tensors.w1_q, # type: ignore[union-attr]
|
'w1_q': moe_tensors.w1_q, # type: ignore[union-attr]
|
||||||
'w2_q': moe_tensors.w2_q, # type: ignore[union-attr]
|
'w2_q': moe_tensors.w2_q, # type: ignore[union-attr]
|
||||||
'topk_weights': topk_weights,
|
'topk_weights': topk_weights,
|
||||||
'topk_ids': topk_ids,
|
'topk_ids': topk_ids,
|
||||||
'w1_scale': moe_tensors.w1_scale,
|
|
||||||
'w2_scale': moe_tensors.w2_scale,
|
|
||||||
'ab_strides1': moe_tensors.ab_strides1,
|
'ab_strides1': moe_tensors.ab_strides1,
|
||||||
'ab_strides2': moe_tensors.ab_strides2,
|
'ab_strides2': moe_tensors.ab_strides2,
|
||||||
'c_strides1': moe_tensors.c_strides1,
|
'c_strides1': moe_tensors.c_strides1,
|
||||||
'c_strides2': moe_tensors.c_strides2,
|
'c_strides2': moe_tensors.c_strides2,
|
||||||
'per_act_token': per_act_token,
|
'quant_config': quant_config,
|
||||||
'a1_scale': None #moe_tensors.a_scale
|
|
||||||
}
|
}
|
||||||
|
|
||||||
num_experts = moe_tensors.w1.size(0)
|
num_experts = moe_tensors.w1.size(0)
|
||||||
@ -261,16 +280,23 @@ def test_cutlass_moe_8_bit_no_graph(
|
|||||||
|
|
||||||
# Note that we are using the dequantized versions of the tensors.
|
# Note that we are using the dequantized versions of the tensors.
|
||||||
# Using a, w1 and w2 directly results in minor output differences.
|
# Using a, w1 and w2 directly results in minor output differences.
|
||||||
triton_output = fused_experts(mt.a_d, mt.w1_d, mt.w2_d, topk_weights,
|
|
||||||
topk_ids)
|
quant_config = FUSED_MOE_UNQUANTIZED_CONFIG
|
||||||
|
triton_output = fused_experts(mt.a_d,
|
||||||
|
mt.w1_d,
|
||||||
|
mt.w2_d,
|
||||||
|
topk_weights,
|
||||||
|
topk_ids,
|
||||||
|
quant_config=quant_config)
|
||||||
|
|
||||||
if ep_size is not None:
|
if ep_size is not None:
|
||||||
assert e % ep_size == 0, "Cannot distribute experts evenly"
|
assert e % ep_size == 0, "Cannot distribute experts evenly"
|
||||||
number_local_experts = e // ep_size
|
number_local_experts = e // ep_size
|
||||||
else:
|
else:
|
||||||
number_local_experts = None
|
number_local_experts = None
|
||||||
|
|
||||||
cutlass_output = run_8_bit(mt, topk_weights, topk_ids, per_act_token,
|
cutlass_output = run_8_bit(mt, topk_weights, topk_ids, per_act_token,
|
||||||
number_local_experts)
|
per_out_ch, number_local_experts)
|
||||||
|
|
||||||
# Note 5.5 only needed for larger problem sizes, 5 works ok for
|
# Note 5.5 only needed for larger problem sizes, 5 works ok for
|
||||||
# the rest.
|
# the rest.
|
||||||
@ -315,14 +341,19 @@ def test_cutlass_moe_8_bit_cuda_graph(
|
|||||||
|
|
||||||
# Note that we are using the dequantized versions of the tensors.
|
# Note that we are using the dequantized versions of the tensors.
|
||||||
# Using a, w1 and w2 directly results in minor output differences.
|
# Using a, w1 and w2 directly results in minor output differences.
|
||||||
triton_output = fused_experts(mt.a_d, mt.w1_d, mt.w2_d, topk_weights,
|
quant_config = FUSED_MOE_UNQUANTIZED_CONFIG
|
||||||
topk_ids)
|
triton_output = fused_experts(mt.a_d,
|
||||||
|
mt.w1_d,
|
||||||
|
mt.w2_d,
|
||||||
|
topk_weights,
|
||||||
|
topk_ids,
|
||||||
|
quant_config=quant_config)
|
||||||
|
|
||||||
stream = torch.cuda.Stream()
|
stream = torch.cuda.Stream()
|
||||||
graph = torch.cuda.CUDAGraph()
|
graph = torch.cuda.CUDAGraph()
|
||||||
with torch.cuda.graph(graph, stream=stream):
|
with torch.cuda.graph(graph, stream=stream):
|
||||||
cutlass_output = run_8_bit(mt, topk_weights, topk_ids,
|
cutlass_output = run_8_bit(mt, topk_weights, topk_ids,
|
||||||
per_act_token)
|
per_act_token, per_out_ch)
|
||||||
|
|
||||||
torch.cuda.synchronize()
|
torch.cuda.synchronize()
|
||||||
graph.replay()
|
graph.replay()
|
||||||
|
|||||||
@ -15,6 +15,8 @@ from torch.distributed import ProcessGroup
|
|||||||
from typing_extensions import ParamSpec
|
from typing_extensions import ParamSpec
|
||||||
|
|
||||||
from vllm.config import VllmConfig, set_current_vllm_config
|
from vllm.config import VllmConfig, set_current_vllm_config
|
||||||
|
from vllm.model_executor.layers.fused_moe.config import (
|
||||||
|
FusedMoEQuantConfig, fp8_w8a8_moe_quant_config)
|
||||||
from vllm.model_executor.layers.fused_moe.fused_moe import fused_experts
|
from vllm.model_executor.layers.fused_moe.fused_moe import fused_experts
|
||||||
from vllm.model_executor.layers.fused_moe.modular_kernel import (
|
from vllm.model_executor.layers.fused_moe.modular_kernel import (
|
||||||
FusedMoEModularKernel)
|
FusedMoEModularKernel)
|
||||||
@ -71,9 +73,12 @@ def make_block_quant_fp8_weights(
|
|||||||
Return weights w1q, w2q, w1_scale, w2_scale
|
Return weights w1q, w2q, w1_scale, w2_scale
|
||||||
"""
|
"""
|
||||||
(_, w1q, w1_scale, _), (_, w2q, w2_scale,
|
(_, w1q, w1_scale, _), (_, w2q, w2_scale,
|
||||||
_) = make_test_weights(e, n, k, torch.bfloat16,
|
_) = make_test_weights(e,
|
||||||
|
n,
|
||||||
|
k,
|
||||||
|
torch.bfloat16,
|
||||||
torch.float8_e4m3fn,
|
torch.float8_e4m3fn,
|
||||||
block_size)
|
block_shape=block_size)
|
||||||
return w1q, w2q, w1_scale, w2_scale
|
return w1q, w2q, w1_scale, w2_scale
|
||||||
|
|
||||||
|
|
||||||
@ -130,10 +135,11 @@ class TestTensors:
|
|||||||
config=config)
|
config=config)
|
||||||
|
|
||||||
|
|
||||||
def make_ll_modular_kernel(pg: ProcessGroup, pgi: ProcessGroupInfo,
|
def make_ll_modular_kernel(
|
||||||
max_tokens_per_rank: int, dp_size: int,
|
pg: ProcessGroup, pgi: ProcessGroupInfo, max_tokens_per_rank: int,
|
||||||
hidden_size: int, q_dtype: Optional[torch.dtype],
|
dp_size: int, hidden_size: int, q_dtype: Optional[torch.dtype],
|
||||||
test_config: TestConfig) -> FusedMoEModularKernel:
|
test_config: TestConfig,
|
||||||
|
quant_config: FusedMoEQuantConfig) -> FusedMoEModularKernel:
|
||||||
|
|
||||||
assert test_config.low_latency
|
assert test_config.low_latency
|
||||||
assert test_config.use_fp8_dispatch is not None
|
assert test_config.use_fp8_dispatch is not None
|
||||||
@ -154,17 +160,18 @@ def make_ll_modular_kernel(pg: ProcessGroup, pgi: ProcessGroupInfo,
|
|||||||
fused_experts = BatchedDeepGemmExperts(
|
fused_experts = BatchedDeepGemmExperts(
|
||||||
max_num_tokens=max_tokens_per_rank,
|
max_num_tokens=max_tokens_per_rank,
|
||||||
num_dispatchers=pgi.world_size // dp_size,
|
num_dispatchers=pgi.world_size // dp_size,
|
||||||
block_shape=test_config.block_size,
|
quant_config=quant_config,
|
||||||
per_act_token_quant=test_config.per_act_token_quant)
|
)
|
||||||
mk = FusedMoEModularKernel(prepare_finalize=a2a,
|
mk = FusedMoEModularKernel(prepare_finalize=a2a,
|
||||||
fused_experts=fused_experts)
|
fused_experts=fused_experts)
|
||||||
return mk
|
return mk
|
||||||
|
|
||||||
|
|
||||||
def make_ht_modular_kernel(pg: ProcessGroup, pgi: ProcessGroupInfo,
|
def make_ht_modular_kernel(
|
||||||
dp_size: int, num_local_experts: int,
|
pg: ProcessGroup, pgi: ProcessGroupInfo, dp_size: int,
|
||||||
q_dtype: Optional[torch.dtype],
|
num_local_experts: int, q_dtype: Optional[torch.dtype],
|
||||||
test_config: TestConfig) -> FusedMoEModularKernel:
|
test_config: TestConfig,
|
||||||
|
quant_config: FusedMoEQuantConfig) -> FusedMoEModularKernel:
|
||||||
|
|
||||||
assert not test_config.low_latency
|
assert not test_config.low_latency
|
||||||
assert test_config.use_fp8_dispatch is None
|
assert test_config.use_fp8_dispatch is None
|
||||||
@ -178,15 +185,16 @@ def make_ht_modular_kernel(pg: ProcessGroup, pgi: ProcessGroupInfo,
|
|||||||
q_dtype=q_dtype,
|
q_dtype=q_dtype,
|
||||||
block_shape=test_config.block_size)
|
block_shape=test_config.block_size)
|
||||||
|
|
||||||
fused_experts = DeepGemmExperts()
|
fused_experts = DeepGemmExperts(quant_config)
|
||||||
mk = FusedMoEModularKernel(prepare_finalize=a2a,
|
mk = FusedMoEModularKernel(prepare_finalize=a2a,
|
||||||
fused_experts=fused_experts)
|
fused_experts=fused_experts)
|
||||||
return mk
|
return mk
|
||||||
|
|
||||||
|
|
||||||
def make_modular_kernel(pg: ProcessGroup, pgi: ProcessGroupInfo, dp_size: int,
|
def make_modular_kernel(
|
||||||
num_local_experts: int,
|
pg: ProcessGroup, pgi: ProcessGroupInfo, dp_size: int,
|
||||||
test_tensors: TestTensors) -> FusedMoEModularKernel:
|
num_local_experts: int, test_tensors: TestTensors,
|
||||||
|
quant_config: FusedMoEQuantConfig) -> FusedMoEModularKernel:
|
||||||
|
|
||||||
q_dtype = torch.float8_e4m3fn
|
q_dtype = torch.float8_e4m3fn
|
||||||
test_config = test_tensors.config
|
test_config = test_tensors.config
|
||||||
@ -204,10 +212,16 @@ def make_modular_kernel(pg: ProcessGroup, pgi: ProcessGroupInfo, dp_size: int,
|
|||||||
dp_size=dp_size,
|
dp_size=dp_size,
|
||||||
hidden_size=hidden_size,
|
hidden_size=hidden_size,
|
||||||
q_dtype=q_dtype,
|
q_dtype=q_dtype,
|
||||||
test_config=test_config)
|
test_config=test_config,
|
||||||
|
quant_config=quant_config)
|
||||||
else:
|
else:
|
||||||
mk = make_ht_modular_kernel(pg, pgi, dp_size, num_local_experts,
|
mk = make_ht_modular_kernel(pg,
|
||||||
q_dtype, test_config)
|
pgi,
|
||||||
|
dp_size,
|
||||||
|
num_local_experts,
|
||||||
|
q_dtype,
|
||||||
|
test_config,
|
||||||
|
quant_config=quant_config)
|
||||||
|
|
||||||
return mk
|
return mk
|
||||||
|
|
||||||
@ -233,17 +247,23 @@ def deepep_deepgemm_moe_impl(pg: ProcessGroup, pgi: ProcessGroupInfo,
|
|||||||
return expert_map.to(device=torch.cuda.current_device(),
|
return expert_map.to(device=torch.cuda.current_device(),
|
||||||
dtype=torch.int32)
|
dtype=torch.int32)
|
||||||
|
|
||||||
|
quant_config = fp8_w8a8_moe_quant_config(
|
||||||
|
w1_scale=w1_scale,
|
||||||
|
w2_scale=w2_scale,
|
||||||
|
# Low-Latency kernels can't dispatch scales.
|
||||||
|
a1_scale=(None if test_config.low_latency else
|
||||||
|
test_tensors.rank_token_scales),
|
||||||
|
block_shape=test_config.block_size,
|
||||||
|
)
|
||||||
|
|
||||||
# Make modular kernel
|
# Make modular kernel
|
||||||
mk: FusedMoEModularKernel = make_modular_kernel(
|
mk: FusedMoEModularKernel = make_modular_kernel(
|
||||||
pg=pg,
|
pg=pg,
|
||||||
pgi=pgi,
|
pgi=pgi,
|
||||||
dp_size=dp_size,
|
dp_size=dp_size,
|
||||||
num_local_experts=num_local_experts,
|
num_local_experts=num_local_experts,
|
||||||
test_tensors=test_tensors)
|
test_tensors=test_tensors,
|
||||||
|
quant_config=quant_config)
|
||||||
# Low-Latency kernels can't dispatch scales.
|
|
||||||
a1_scale = (None
|
|
||||||
if test_config.low_latency else test_tensors.rank_token_scales)
|
|
||||||
|
|
||||||
out = mk.forward(hidden_states=test_tensors.rank_tokens,
|
out = mk.forward(hidden_states=test_tensors.rank_tokens,
|
||||||
w1=w1,
|
w1=w1,
|
||||||
@ -254,12 +274,6 @@ def deepep_deepgemm_moe_impl(pg: ProcessGroup, pgi: ProcessGroupInfo,
|
|||||||
activation="silu",
|
activation="silu",
|
||||||
global_num_experts=num_experts,
|
global_num_experts=num_experts,
|
||||||
expert_map=build_expert_map(),
|
expert_map=build_expert_map(),
|
||||||
w1_scale=w1_scale,
|
|
||||||
w2_scale=w2_scale,
|
|
||||||
w1_zp=None,
|
|
||||||
w2_zp=None,
|
|
||||||
a1_scale=a1_scale,
|
|
||||||
a2_scale=None,
|
|
||||||
apply_router_weight_on_input=False)
|
apply_router_weight_on_input=False)
|
||||||
return out
|
return out
|
||||||
|
|
||||||
@ -269,6 +283,13 @@ def triton_impl(a: torch.Tensor, topk_ids: torch.Tensor,
|
|||||||
w1_scale: torch.Tensor, w2_scale: torch.Tensor,
|
w1_scale: torch.Tensor, w2_scale: torch.Tensor,
|
||||||
a1_scale: torch.Tensor, block_shape: list[int]):
|
a1_scale: torch.Tensor, block_shape: list[int]):
|
||||||
|
|
||||||
|
quant_config = fp8_w8a8_moe_quant_config(
|
||||||
|
w1_scale=w1_scale,
|
||||||
|
w2_scale=w2_scale,
|
||||||
|
a1_scale=a1_scale,
|
||||||
|
block_shape=block_shape,
|
||||||
|
)
|
||||||
|
|
||||||
return fused_experts(
|
return fused_experts(
|
||||||
hidden_states=a,
|
hidden_states=a,
|
||||||
w1=w1,
|
w1=w1,
|
||||||
@ -276,11 +297,7 @@ def triton_impl(a: torch.Tensor, topk_ids: torch.Tensor,
|
|||||||
topk_weights=topk_weights,
|
topk_weights=topk_weights,
|
||||||
topk_ids=topk_ids,
|
topk_ids=topk_ids,
|
||||||
inplace=False,
|
inplace=False,
|
||||||
use_fp8_w8a8=True,
|
quant_config=quant_config,
|
||||||
w1_scale=w1_scale,
|
|
||||||
w2_scale=w2_scale,
|
|
||||||
a1_scale=a1_scale,
|
|
||||||
block_shape=block_shape,
|
|
||||||
# Make sure this is set to False so we
|
# Make sure this is set to False so we
|
||||||
# don't end up comparing the same implementation.
|
# don't end up comparing the same implementation.
|
||||||
allow_deep_gemm=False)
|
allow_deep_gemm=False)
|
||||||
|
|||||||
@ -15,6 +15,7 @@ from vllm import _custom_ops as ops
|
|||||||
from vllm.config import VllmConfig, set_current_vllm_config
|
from vllm.config import VllmConfig, set_current_vllm_config
|
||||||
from vllm.model_executor.layers.activation import SiluAndMul
|
from vllm.model_executor.layers.activation import SiluAndMul
|
||||||
from vllm.model_executor.layers.fused_moe import TritonExperts
|
from vllm.model_executor.layers.fused_moe import TritonExperts
|
||||||
|
from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig
|
||||||
from vllm.model_executor.layers.fused_moe.fused_batched_moe import (
|
from vllm.model_executor.layers.fused_moe.fused_batched_moe import (
|
||||||
BatchedTritonExperts)
|
BatchedTritonExperts)
|
||||||
from vllm.model_executor.layers.fused_moe.modular_kernel import (
|
from vllm.model_executor.layers.fused_moe.modular_kernel import (
|
||||||
@ -129,11 +130,9 @@ def make_modular_kernel(
|
|||||||
num_local_experts: int,
|
num_local_experts: int,
|
||||||
q_dtype: Optional[torch.dtype],
|
q_dtype: Optional[torch.dtype],
|
||||||
use_fp8_dispatch: bool,
|
use_fp8_dispatch: bool,
|
||||||
per_act_token_quant: bool,
|
quant_config: FusedMoEQuantConfig,
|
||||||
) -> FusedMoEModularKernel:
|
) -> FusedMoEModularKernel:
|
||||||
|
|
||||||
is_quantized = q_dtype is not None
|
|
||||||
|
|
||||||
ht_args: Optional[DeepEPHTArgs] = None
|
ht_args: Optional[DeepEPHTArgs] = None
|
||||||
ll_args: Optional[DeepEPLLArgs] = None
|
ll_args: Optional[DeepEPLLArgs] = None
|
||||||
|
|
||||||
@ -159,24 +158,14 @@ def make_modular_kernel(
|
|||||||
num_dispatchers = pgi.world_size // dp_size
|
num_dispatchers = pgi.world_size // dp_size
|
||||||
|
|
||||||
if low_latency_mode:
|
if low_latency_mode:
|
||||||
assert not per_act_token_quant, "not supported in ll mode"
|
assert not quant_config.per_act_token_quant, "not supported in ll mode"
|
||||||
fused_experts = BatchedTritonExperts(
|
fused_experts = BatchedTritonExperts(
|
||||||
max_num_tokens=MAX_TOKENS_PER_RANK,
|
max_num_tokens=MAX_TOKENS_PER_RANK,
|
||||||
num_dispatchers=num_dispatchers,
|
num_dispatchers=num_dispatchers,
|
||||||
use_fp8_w8a8=is_quantized,
|
quant_config=quant_config,
|
||||||
use_int8_w8a8=False,
|
|
||||||
use_int8_w8a16=False,
|
|
||||||
use_int4_w4a16=False,
|
|
||||||
per_act_token_quant=False,
|
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
fused_experts = TritonExperts(
|
fused_experts = TritonExperts(quant_config=quant_config)
|
||||||
use_fp8_w8a8=is_quantized,
|
|
||||||
use_int8_w8a8=False,
|
|
||||||
use_int8_w8a16=False,
|
|
||||||
use_int4_w4a16=False,
|
|
||||||
per_act_token_quant=per_act_token_quant,
|
|
||||||
)
|
|
||||||
|
|
||||||
mk = FusedMoEModularKernel(prepare_finalize=a2a,
|
mk = FusedMoEModularKernel(prepare_finalize=a2a,
|
||||||
fused_experts=fused_experts)
|
fused_experts=fused_experts)
|
||||||
@ -217,11 +206,6 @@ def deep_ep_moe_impl(
|
|||||||
if is_quantized:
|
if is_quantized:
|
||||||
q_dtype = torch.float8_e4m3fn
|
q_dtype = torch.float8_e4m3fn
|
||||||
|
|
||||||
# Make modular kernel
|
|
||||||
mk: FusedMoEModularKernel = make_modular_kernel(
|
|
||||||
pg, pgi, low_latency_mode, hidden_size, dp_size, num_experts,
|
|
||||||
num_local_experts, q_dtype, use_fp8_dispatch, per_act_token_quant)
|
|
||||||
|
|
||||||
out_hidden_states = torch.empty_like(test_tensors.rank_tokens)
|
out_hidden_states = torch.empty_like(test_tensors.rank_tokens)
|
||||||
total_num_tokens = test_tensors.rank_tokens.size(0)
|
total_num_tokens = test_tensors.rank_tokens.size(0)
|
||||||
|
|
||||||
@ -236,6 +220,19 @@ def deep_ep_moe_impl(
|
|||||||
rank_token_scales_chunk = rank_token_scales_chunk[
|
rank_token_scales_chunk = rank_token_scales_chunk[
|
||||||
chunk_start:chunk_end]
|
chunk_start:chunk_end]
|
||||||
|
|
||||||
|
quant_config = FusedMoEQuantConfig.make(
|
||||||
|
q_dtype,
|
||||||
|
w1_scale=w1_scale,
|
||||||
|
w2_scale=w2_scale,
|
||||||
|
per_act_token_quant=per_act_token_quant,
|
||||||
|
a1_scale=rank_token_scales_chunk,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Make modular kernel
|
||||||
|
mk: FusedMoEModularKernel = make_modular_kernel(
|
||||||
|
pg, pgi, low_latency_mode, hidden_size, dp_size, num_experts,
|
||||||
|
num_local_experts, q_dtype, use_fp8_dispatch, quant_config)
|
||||||
|
|
||||||
out = mk.forward(hidden_states=rank_tokens_chunk,
|
out = mk.forward(hidden_states=rank_tokens_chunk,
|
||||||
w1=w1,
|
w1=w1,
|
||||||
w2=w2,
|
w2=w2,
|
||||||
@ -245,12 +242,6 @@ def deep_ep_moe_impl(
|
|||||||
activation="silu",
|
activation="silu",
|
||||||
global_num_experts=num_experts,
|
global_num_experts=num_experts,
|
||||||
expert_map=build_expert_map(),
|
expert_map=build_expert_map(),
|
||||||
w1_scale=w1_scale,
|
|
||||||
w2_scale=w2_scale,
|
|
||||||
w1_zp=None,
|
|
||||||
w2_zp=None,
|
|
||||||
a1_scale=rank_token_scales_chunk,
|
|
||||||
a2_scale=None,
|
|
||||||
apply_router_weight_on_input=False)
|
apply_router_weight_on_input=False)
|
||||||
|
|
||||||
if not skip_result_store:
|
if not skip_result_store:
|
||||||
@ -407,7 +398,7 @@ DTYPES = [torch.bfloat16, torch.float8_e4m3fn]
|
|||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("dtype", DTYPES)
|
@pytest.mark.parametrize("dtype", DTYPES)
|
||||||
@pytest.mark.parametrize("mnk", MNKs)
|
@pytest.mark.parametrize("m,n,k", MNKs)
|
||||||
@pytest.mark.parametrize("num_experts", [32])
|
@pytest.mark.parametrize("num_experts", [32])
|
||||||
@pytest.mark.parametrize("topk", [6])
|
@pytest.mark.parametrize("topk", [6])
|
||||||
@pytest.mark.parametrize("world_dp_size", [(2, 1)])
|
@pytest.mark.parametrize("world_dp_size", [(2, 1)])
|
||||||
@ -416,7 +407,9 @@ DTYPES = [torch.bfloat16, torch.float8_e4m3fn]
|
|||||||
@requires_deep_ep
|
@requires_deep_ep
|
||||||
def test_deep_ep_moe(
|
def test_deep_ep_moe(
|
||||||
dtype: torch.dtype,
|
dtype: torch.dtype,
|
||||||
mnk: tuple[int, int, int],
|
m: int,
|
||||||
|
n: int,
|
||||||
|
k: int,
|
||||||
num_experts: int,
|
num_experts: int,
|
||||||
topk: int,
|
topk: int,
|
||||||
world_dp_size: tuple[int, int],
|
world_dp_size: tuple[int, int],
|
||||||
@ -424,7 +417,6 @@ def test_deep_ep_moe(
|
|||||||
):
|
):
|
||||||
low_latency_mode = False
|
low_latency_mode = False
|
||||||
use_fp8_dispatch = False
|
use_fp8_dispatch = False
|
||||||
m, n, k = mnk
|
|
||||||
|
|
||||||
current_platform.seed_everything(7)
|
current_platform.seed_everything(7)
|
||||||
world_size, dp_size = world_dp_size
|
world_size, dp_size = world_dp_size
|
||||||
@ -456,20 +448,24 @@ USE_FP8_DISPATCH = [True, False]
|
|||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("dtype", DTYPES)
|
@pytest.mark.parametrize("dtype", DTYPES)
|
||||||
@pytest.mark.parametrize("mnk", MNKs)
|
@pytest.mark.parametrize("m,n,k", MNKs)
|
||||||
@pytest.mark.parametrize("num_experts", [32])
|
@pytest.mark.parametrize("num_experts", [32])
|
||||||
@pytest.mark.parametrize("topk", [6])
|
@pytest.mark.parametrize("topk", [6])
|
||||||
@pytest.mark.parametrize("world_dp_size", [(2, 1)])
|
@pytest.mark.parametrize("world_dp_size", [(2, 1)])
|
||||||
@pytest.mark.parametrize("use_fp8_dispatch", USE_FP8_DISPATCH)
|
@pytest.mark.parametrize("use_fp8_dispatch", USE_FP8_DISPATCH)
|
||||||
@multi_gpu_test(num_gpus=2)
|
@multi_gpu_test(num_gpus=2)
|
||||||
@requires_deep_ep
|
@requires_deep_ep
|
||||||
def test_low_latency_deep_ep_moe(dtype: torch.dtype, mnk: tuple[int, int, int],
|
def test_low_latency_deep_ep_moe(
|
||||||
num_experts: int, topk: int,
|
dtype: torch.dtype,
|
||||||
world_dp_size: tuple[int, int],
|
m: int,
|
||||||
use_fp8_dispatch: bool):
|
n: int,
|
||||||
|
k: int,
|
||||||
|
num_experts: int,
|
||||||
|
topk: int,
|
||||||
|
world_dp_size: tuple[int, int],
|
||||||
|
use_fp8_dispatch: bool,
|
||||||
|
):
|
||||||
low_latency_mode = True
|
low_latency_mode = True
|
||||||
m, n, k = mnk
|
|
||||||
|
|
||||||
if (low_latency_mode
|
if (low_latency_mode
|
||||||
and k not in DeepEPLLPrepareAndFinalize.SUPPORTED_HIDDEN_SIZES):
|
and k not in DeepEPLLPrepareAndFinalize.SUPPORTED_HIDDEN_SIZES):
|
||||||
|
|||||||
@ -11,6 +11,8 @@ import math
|
|||||||
import pytest
|
import pytest
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
|
from vllm.model_executor.layers.fused_moe.config import (
|
||||||
|
fp8_w8a8_moe_quant_config)
|
||||||
# vLLM fused-expert reference (Triton fallback + DeepGEMM option)
|
# vLLM fused-expert reference (Triton fallback + DeepGEMM option)
|
||||||
from vllm.model_executor.layers.fused_moe.fused_moe import fused_experts
|
from vllm.model_executor.layers.fused_moe.fused_moe import fused_experts
|
||||||
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
|
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
|
||||||
@ -94,6 +96,13 @@ def run_single_case(m, n, k, topk, num_experts, block_size):
|
|||||||
topk_weights, topk_ids = torch.topk(router_logits, k=topk, dim=-1)
|
topk_weights, topk_ids = torch.topk(router_logits, k=topk, dim=-1)
|
||||||
topk_weights = torch.nn.functional.softmax(topk_weights, dim=-1)
|
topk_weights = torch.nn.functional.softmax(topk_weights, dim=-1)
|
||||||
|
|
||||||
|
quant_config = fp8_w8a8_moe_quant_config(
|
||||||
|
w1_scale=w1_s,
|
||||||
|
w2_scale=w2_s,
|
||||||
|
a1_scale=a1_scale,
|
||||||
|
block_shape=block_size,
|
||||||
|
)
|
||||||
|
|
||||||
# triton reference
|
# triton reference
|
||||||
out_triton = fused_experts(
|
out_triton = fused_experts(
|
||||||
hidden_states=tokens_bf16,
|
hidden_states=tokens_bf16,
|
||||||
@ -102,11 +111,7 @@ def run_single_case(m, n, k, topk, num_experts, block_size):
|
|||||||
topk_weights=topk_weights,
|
topk_weights=topk_weights,
|
||||||
topk_ids=topk_ids,
|
topk_ids=topk_ids,
|
||||||
inplace=False,
|
inplace=False,
|
||||||
use_fp8_w8a8=True,
|
quant_config=quant_config,
|
||||||
w1_scale=w1_s,
|
|
||||||
w2_scale=w2_s,
|
|
||||||
a1_scale=a1_scale,
|
|
||||||
block_shape=block_size,
|
|
||||||
allow_deep_gemm=False,
|
allow_deep_gemm=False,
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -118,19 +123,14 @@ def run_single_case(m, n, k, topk, num_experts, block_size):
|
|||||||
topk_weights=topk_weights,
|
topk_weights=topk_weights,
|
||||||
topk_ids=topk_ids,
|
topk_ids=topk_ids,
|
||||||
inplace=False,
|
inplace=False,
|
||||||
use_fp8_w8a8=True,
|
quant_config=quant_config,
|
||||||
w1_scale=w1_s,
|
|
||||||
w2_scale=w2_s,
|
|
||||||
a1_scale=a1_scale,
|
|
||||||
block_shape=block_size,
|
|
||||||
allow_deep_gemm=True,
|
allow_deep_gemm=True,
|
||||||
)
|
)
|
||||||
diff = calc_diff(out_deepgemm, out_triton)
|
diff = calc_diff(out_deepgemm, out_triton)
|
||||||
assert diff < 0.001, f"Diff exceeded 1%: {diff}"
|
assert diff < 0.001, f"Diff exceeded 1%: {diff}"
|
||||||
|
|
||||||
|
|
||||||
# Note: W1 has shape (E, 2N, K), so N = 512
|
# Note: N <= 512 will disable the deepgemm path due to performance issues.
|
||||||
# can trigger the deepgemm path.
|
|
||||||
MNKs = [
|
MNKs = [
|
||||||
(1024, 768, 128),
|
(1024, 768, 128),
|
||||||
(1024, 768, 512),
|
(1024, 768, 512),
|
||||||
@ -144,15 +144,15 @@ TOPKS = [2, 6]
|
|||||||
NUM_EXPERTS = [32]
|
NUM_EXPERTS = [32]
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("mnk", MNKs)
|
@pytest.mark.parametrize(("m", "n", "k"), MNKs)
|
||||||
@pytest.mark.parametrize("topk", TOPKS)
|
@pytest.mark.parametrize("topk", TOPKS)
|
||||||
@pytest.mark.parametrize("num_experts", NUM_EXPERTS)
|
@pytest.mark.parametrize("num_experts", NUM_EXPERTS)
|
||||||
@pytest.mark.skipif(not is_deep_gemm_supported(),
|
@pytest.mark.skipif(not is_deep_gemm_supported(),
|
||||||
reason="Requires deep_gemm kernels")
|
reason="Requires deep_gemm kernels")
|
||||||
def test_deepgemm_vs_triton(mnk, topk, num_experts, monkeypatch):
|
def test_deepgemm_vs_triton(m, n, k, topk, num_experts, monkeypatch):
|
||||||
|
|
||||||
with monkeypatch.context() as m:
|
with monkeypatch.context() as mp:
|
||||||
m.setenv("VLLM_USE_DEEP_GEMM", "1")
|
mp.setenv("VLLM_USE_DEEP_GEMM", "1")
|
||||||
|
|
||||||
_fused_moe_mod = importlib.import_module(
|
_fused_moe_mod = importlib.import_module(
|
||||||
"vllm.model_executor.layers.fused_moe.fused_moe")
|
"vllm.model_executor.layers.fused_moe.fused_moe")
|
||||||
@ -168,8 +168,6 @@ def test_deepgemm_vs_triton(mnk, topk, num_experts, monkeypatch):
|
|||||||
monkeypatch.setattr(_fused_moe_mod, "deep_gemm_moe_fp8",
|
monkeypatch.setattr(_fused_moe_mod, "deep_gemm_moe_fp8",
|
||||||
_spy_deep_gemm_moe_fp8)
|
_spy_deep_gemm_moe_fp8)
|
||||||
|
|
||||||
m, n, k = mnk
|
|
||||||
|
|
||||||
if topk > num_experts:
|
if topk > num_experts:
|
||||||
pytest.skip(f"topk={topk} > num_experts={num_experts}")
|
pytest.skip(f"topk={topk} > num_experts={num_experts}")
|
||||||
|
|
||||||
|
|||||||
@ -6,6 +6,8 @@ import pytest
|
|||||||
import torch
|
import torch
|
||||||
|
|
||||||
from vllm.config import ParallelConfig, VllmConfig, set_current_vllm_config
|
from vllm.config import ParallelConfig, VllmConfig, set_current_vllm_config
|
||||||
|
from vllm.model_executor.layers.fused_moe.config import (
|
||||||
|
fp8_w8a8_moe_quant_config)
|
||||||
from vllm.model_executor.layers.fused_moe.fused_moe import fused_experts
|
from vllm.model_executor.layers.fused_moe.fused_moe import fused_experts
|
||||||
from vllm.model_executor.layers.fused_moe.layer import FusedMoE
|
from vllm.model_executor.layers.fused_moe.layer import FusedMoE
|
||||||
from vllm.model_executor.layers.quantization.utils.flashinfer_utils import (
|
from vllm.model_executor.layers.quantization.utils.flashinfer_utils import (
|
||||||
@ -145,6 +147,14 @@ def test_flashinfer_per_tensor_moe_fp8_no_graph(
|
|||||||
custom_routing_function=Llama4MoE.custom_routing_function,
|
custom_routing_function=Llama4MoE.custom_routing_function,
|
||||||
scoring_func="softmax")
|
scoring_func="softmax")
|
||||||
|
|
||||||
|
quant_config = fp8_w8a8_moe_quant_config(
|
||||||
|
w1_scale=td.w13_weight_scale,
|
||||||
|
w2_scale=td.w2_weight_scale,
|
||||||
|
a1_scale=td.a1_scale,
|
||||||
|
a2_scale=td.a2_scale,
|
||||||
|
per_act_token_quant=False,
|
||||||
|
)
|
||||||
|
|
||||||
output = fused_experts(
|
output = fused_experts(
|
||||||
td.hidden_states,
|
td.hidden_states,
|
||||||
td.w13_quantized,
|
td.w13_quantized,
|
||||||
@ -153,15 +163,10 @@ def test_flashinfer_per_tensor_moe_fp8_no_graph(
|
|||||||
topk_ids=topk_ids,
|
topk_ids=topk_ids,
|
||||||
inplace=False,
|
inplace=False,
|
||||||
activation="silu",
|
activation="silu",
|
||||||
use_fp8_w8a8=True,
|
|
||||||
per_channel_quant=False,
|
|
||||||
global_num_experts=e,
|
global_num_experts=e,
|
||||||
expert_map=None,
|
expert_map=None,
|
||||||
w1_scale=td.w13_weight_scale,
|
|
||||||
w2_scale=td.w2_weight_scale,
|
|
||||||
a1_scale=td.a1_scale,
|
|
||||||
a2_scale=td.a2_scale,
|
|
||||||
apply_router_weight_on_input=True,
|
apply_router_weight_on_input=True,
|
||||||
|
quant_config=quant_config,
|
||||||
)
|
)
|
||||||
|
|
||||||
flashinfer_output = apply_flashinfer_per_tensor_scale_fp8(
|
flashinfer_output = apply_flashinfer_per_tensor_scale_fp8(
|
||||||
@ -210,6 +215,14 @@ def test_flashinfer_cutlass_moe_fp8_no_graph(
|
|||||||
custom_routing_function=Llama4MoE.custom_routing_function,
|
custom_routing_function=Llama4MoE.custom_routing_function,
|
||||||
scoring_func="softmax")
|
scoring_func="softmax")
|
||||||
|
|
||||||
|
quant_config = fp8_w8a8_moe_quant_config(
|
||||||
|
w1_scale=td.w13_weight_scale,
|
||||||
|
w2_scale=td.w2_weight_scale,
|
||||||
|
a1_scale=td.a1_scale,
|
||||||
|
a2_scale=td.a2_scale,
|
||||||
|
per_act_token_quant=False,
|
||||||
|
)
|
||||||
|
|
||||||
output = fused_experts(
|
output = fused_experts(
|
||||||
td.hidden_states,
|
td.hidden_states,
|
||||||
td.w13_quantized,
|
td.w13_quantized,
|
||||||
@ -218,15 +231,10 @@ def test_flashinfer_cutlass_moe_fp8_no_graph(
|
|||||||
topk_ids=topk_ids,
|
topk_ids=topk_ids,
|
||||||
inplace=False,
|
inplace=False,
|
||||||
activation="silu",
|
activation="silu",
|
||||||
use_fp8_w8a8=True,
|
|
||||||
per_channel_quant=False,
|
|
||||||
global_num_experts=e,
|
global_num_experts=e,
|
||||||
expert_map=None,
|
expert_map=None,
|
||||||
w1_scale=td.w13_weight_scale,
|
|
||||||
w2_scale=td.w2_weight_scale,
|
|
||||||
a1_scale=td.a1_scale,
|
|
||||||
a2_scale=td.a2_scale,
|
|
||||||
apply_router_weight_on_input=True,
|
apply_router_weight_on_input=True,
|
||||||
|
quant_config=quant_config,
|
||||||
)
|
)
|
||||||
|
|
||||||
td.layer.dp_size = 1
|
td.layer.dp_size = 1
|
||||||
|
|||||||
@ -3,7 +3,7 @@
|
|||||||
import pytest
|
import pytest
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from tests.kernels.moe.utils import make_test_weights
|
from tests.kernels.moe.utils import make_test_quant_config
|
||||||
from tests.kernels.quantization.nvfp4_utils import (FLOAT4_E2M1_MAX,
|
from tests.kernels.quantization.nvfp4_utils import (FLOAT4_E2M1_MAX,
|
||||||
FLOAT8_E4M3_MAX,
|
FLOAT8_E4M3_MAX,
|
||||||
dequantize_nvfp4_to_dtype)
|
dequantize_nvfp4_to_dtype)
|
||||||
@ -41,7 +41,6 @@ MNK_FACTORS = [
|
|||||||
|
|
||||||
@pytest.mark.parametrize("m,n,k", MNK_FACTORS)
|
@pytest.mark.parametrize("m,n,k", MNK_FACTORS)
|
||||||
@pytest.mark.parametrize("e", [40, 64, 256])
|
@pytest.mark.parametrize("e", [40, 64, 256])
|
||||||
#@pytest.mark.parametrize("e", [128, 256])
|
|
||||||
@pytest.mark.parametrize("topk", [1, 6, 8])
|
@pytest.mark.parametrize("topk", [1, 6, 8])
|
||||||
@pytest.mark.parametrize("dtype", [torch.half, torch.bfloat16])
|
@pytest.mark.parametrize("dtype", [torch.half, torch.bfloat16])
|
||||||
@torch.inference_mode()
|
@torch.inference_mode()
|
||||||
@ -56,16 +55,15 @@ def test_flashinfer_fp4_moe_no_graph(m: int, n: int, k: int, e: int, topk: int,
|
|||||||
|
|
||||||
quant_blocksize = 16
|
quant_blocksize = 16
|
||||||
|
|
||||||
(_, w1_q, w1_blockscale,
|
w1_q, w2_q, quant_config = make_test_quant_config(
|
||||||
w1_gs), (_, w2_q, w2_blockscale, w2_gs) = make_test_weights(
|
e,
|
||||||
e,
|
n,
|
||||||
n,
|
k,
|
||||||
k,
|
in_dtype=dtype,
|
||||||
in_dtype=dtype,
|
quant_dtype="nvfp4",
|
||||||
quant_dtype="nvfp4",
|
block_shape=None,
|
||||||
block_shape=None, # use quant_blocksize?
|
per_act_token_quant=False,
|
||||||
per_act_token_quant=False,
|
)
|
||||||
)
|
|
||||||
|
|
||||||
score = torch.randn((m, e), device="cuda", dtype=dtype)
|
score = torch.randn((m, e), device="cuda", dtype=dtype)
|
||||||
topk_weights, topk_ids, _ = fused_topk(a,
|
topk_weights, topk_ids, _ = fused_topk(a,
|
||||||
@ -73,35 +71,17 @@ def test_flashinfer_fp4_moe_no_graph(m: int, n: int, k: int, e: int, topk: int,
|
|||||||
topk,
|
topk,
|
||||||
renormalize=False)
|
renormalize=False)
|
||||||
|
|
||||||
a1_gs = torch.ones((e, ), device="cuda", dtype=torch.float32)
|
|
||||||
a2_gs = torch.ones((e, ), device="cuda", dtype=torch.float32)
|
|
||||||
|
|
||||||
assert is_valid_flashinfer_cutlass_fused_moe(a, w1_q, w2_q)
|
assert is_valid_flashinfer_cutlass_fused_moe(a, w1_q, w2_q)
|
||||||
|
|
||||||
assert w1_gs is not None
|
|
||||||
assert w2_gs is not None
|
|
||||||
assert w1_blockscale is not None
|
|
||||||
assert w2_blockscale is not None
|
|
||||||
|
|
||||||
flashinfer_experts = FusedMoEModularKernel(
|
flashinfer_experts = FusedMoEModularKernel(
|
||||||
MoEPrepareAndFinalizeNoEP(),
|
MoEPrepareAndFinalizeNoEP(),
|
||||||
FlashInferExperts(
|
FlashInferExperts(out_dtype=dtype, quant_config=quant_config),
|
||||||
a1_gscale=a1_gs,
|
)
|
||||||
g1_alphas=(1 / w1_gs),
|
|
||||||
a2_gscale=a2_gs,
|
|
||||||
g2_alphas=(1 / w2_gs),
|
|
||||||
out_dtype=dtype,
|
|
||||||
quant_dtype="nvfp4",
|
|
||||||
))
|
|
||||||
|
|
||||||
flashinfer_output = flashinfer_experts(
|
flashinfer_output = flashinfer_experts(
|
||||||
hidden_states=a,
|
hidden_states=a,
|
||||||
w1=w1_q,
|
w1=w1_q,
|
||||||
w1_scale=w1_blockscale,
|
|
||||||
w2=w2_q,
|
w2=w2_q,
|
||||||
w2_scale=w2_blockscale,
|
|
||||||
a1_scale=a1_gs,
|
|
||||||
a2_scale=a2_gs,
|
|
||||||
topk_weights=topk_weights,
|
topk_weights=topk_weights,
|
||||||
topk_ids=topk_ids,
|
topk_ids=topk_ids,
|
||||||
)
|
)
|
||||||
@ -122,18 +102,18 @@ def test_flashinfer_fp4_moe_no_graph(m: int, n: int, k: int, e: int, topk: int,
|
|||||||
w2_d = torch.empty((e, k, n), device="cuda", dtype=dtype)
|
w2_d = torch.empty((e, k, n), device="cuda", dtype=dtype)
|
||||||
|
|
||||||
for idx in range(0, e):
|
for idx in range(0, e):
|
||||||
w1_d[idx] = dequantize_nvfp4_to_dtype(w1_q[idx],
|
w1_d[idx] = dequantize_nvfp4_to_dtype(
|
||||||
w1_blockscale[idx],
|
w1_q[idx],
|
||||||
w1_gs[idx],
|
quant_config.w1_scale[idx], (1 / quant_config.g1_alphas[idx]),
|
||||||
dtype=dtype,
|
dtype=dtype,
|
||||||
device=w1_q.device,
|
device=w1_q.device,
|
||||||
block_size=quant_blocksize)
|
block_size=quant_blocksize)
|
||||||
w2_d[idx] = dequantize_nvfp4_to_dtype(w2_q[idx],
|
w2_d[idx] = dequantize_nvfp4_to_dtype(
|
||||||
w2_blockscale[idx],
|
w2_q[idx],
|
||||||
w2_gs[idx],
|
quant_config.w2_scale[idx], (1 / quant_config.g2_alphas[idx]),
|
||||||
dtype=dtype,
|
dtype=dtype,
|
||||||
device=w2_q.device,
|
device=w2_q.device,
|
||||||
block_size=quant_blocksize)
|
block_size=quant_blocksize)
|
||||||
|
|
||||||
torch_output = torch_moe(a_in_dtype, w1_d, w2_d, score, topk)
|
torch_output = torch_moe(a_in_dtype, w1_d, w2_d, score, topk)
|
||||||
|
|
||||||
|
|||||||
@ -23,6 +23,7 @@ from triton_kernels.tensor import FP4, convert_layout, wrap_torch_tensor
|
|||||||
from triton_kernels.tensor_details import layout
|
from triton_kernels.tensor_details import layout
|
||||||
from triton_kernels.testing import assert_close
|
from triton_kernels.testing import assert_close
|
||||||
|
|
||||||
|
from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig
|
||||||
from vllm.model_executor.layers.fused_moe.fused_batched_moe import (
|
from vllm.model_executor.layers.fused_moe.fused_batched_moe import (
|
||||||
BatchedPrepareAndFinalize)
|
BatchedPrepareAndFinalize)
|
||||||
from vllm.model_executor.layers.fused_moe.fused_moe import fused_topk
|
from vllm.model_executor.layers.fused_moe.fused_moe import fused_topk
|
||||||
@ -293,6 +294,13 @@ def test_equiv(num_token, a_dtype, w_dtype, tp):
|
|||||||
pc2,
|
pc2,
|
||||||
) = init_compute_data(M, K, N, E, a_dtype, w_dtype, num_warps=8)
|
) = init_compute_data(M, K, N, E, a_dtype, w_dtype, num_warps=8)
|
||||||
|
|
||||||
|
quant_config = FusedMoEQuantConfig.make(
|
||||||
|
w1_bias=w1_bias_tri,
|
||||||
|
w2_bias=w2_bias_tri,
|
||||||
|
w1_precision=pc1,
|
||||||
|
w2_precision=pc2,
|
||||||
|
)
|
||||||
|
|
||||||
out_triton_monolithic = triton_kernel_moe_forward(
|
out_triton_monolithic = triton_kernel_moe_forward(
|
||||||
hidden_states=x_tri,
|
hidden_states=x_tri,
|
||||||
w1=w1_tri,
|
w1=w1_tri,
|
||||||
@ -300,10 +308,7 @@ def test_equiv(num_token, a_dtype, w_dtype, tp):
|
|||||||
gating_output=exp_data_tri,
|
gating_output=exp_data_tri,
|
||||||
topk=topk,
|
topk=topk,
|
||||||
renormalize=True,
|
renormalize=True,
|
||||||
w1_bias=w1_bias_tri,
|
quant_config=quant_config,
|
||||||
w2_bias=w2_bias_tri,
|
|
||||||
w1_precision=pc1,
|
|
||||||
w2_precision=pc2,
|
|
||||||
)
|
)
|
||||||
out_triton_monolithic = out_triton_monolithic[..., :K]
|
out_triton_monolithic = out_triton_monolithic[..., :K]
|
||||||
|
|
||||||
@ -336,6 +341,13 @@ def batched_moe(
|
|||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
max_num_tokens = round_up(a.shape[0], 64)
|
max_num_tokens = round_up(a.shape[0], 64)
|
||||||
|
|
||||||
|
quant_config = FusedMoEQuantConfig.make(
|
||||||
|
w1_precision=w1_precision,
|
||||||
|
w2_precision=w2_precision,
|
||||||
|
w1_bias=w1_bias,
|
||||||
|
w2_bias=w2_bias,
|
||||||
|
)
|
||||||
|
|
||||||
fused_experts = FusedMoEModularKernel(
|
fused_experts = FusedMoEModularKernel(
|
||||||
BatchedPrepareAndFinalize(
|
BatchedPrepareAndFinalize(
|
||||||
max_num_tokens,
|
max_num_tokens,
|
||||||
@ -344,19 +356,12 @@ def batched_moe(
|
|||||||
rank=0,
|
rank=0,
|
||||||
),
|
),
|
||||||
BatchedOAITritonExperts(
|
BatchedOAITritonExperts(
|
||||||
None,
|
|
||||||
max_num_tokens=max_num_tokens,
|
max_num_tokens=max_num_tokens,
|
||||||
num_dispatchers=1,
|
num_dispatchers=1,
|
||||||
w1_precision=w1_precision,
|
quant_config=quant_config,
|
||||||
w2_precision=w2_precision,
|
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
extra_expert_args = {
|
|
||||||
"w1_bias": w1_bias,
|
|
||||||
"w2_bias": w2_bias,
|
|
||||||
}
|
|
||||||
|
|
||||||
topk_weight, topk_ids, _ = fused_topk(a, gating_output, topk, renormalize)
|
topk_weight, topk_ids, _ = fused_topk(a, gating_output, topk, renormalize)
|
||||||
|
|
||||||
return fused_experts(
|
return fused_experts(
|
||||||
@ -365,7 +370,6 @@ def batched_moe(
|
|||||||
w2,
|
w2,
|
||||||
topk_weight,
|
topk_weight,
|
||||||
topk_ids,
|
topk_ids,
|
||||||
extra_expert_args=extra_expert_args,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@ -12,7 +12,6 @@ import torch
|
|||||||
|
|
||||||
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
|
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
|
||||||
from vllm.config import VllmConfig, current_platform, set_current_vllm_config
|
from vllm.config import VllmConfig, current_platform, set_current_vllm_config
|
||||||
from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig
|
|
||||||
from vllm.utils import has_deep_ep, has_deep_gemm, has_pplx
|
from vllm.utils import has_deep_ep, has_deep_gemm, has_pplx
|
||||||
from vllm.utils.flashinfer import has_flashinfer_cutlass_fused_moe
|
from vllm.utils.flashinfer import has_flashinfer_cutlass_fused_moe
|
||||||
|
|
||||||
@ -22,7 +21,8 @@ from .modular_kernel_tools.common import (Config, RankTensors, WeightTensors,
|
|||||||
run_modular_kernel)
|
run_modular_kernel)
|
||||||
from .modular_kernel_tools.mk_objects import (
|
from .modular_kernel_tools.mk_objects import (
|
||||||
MK_FUSED_EXPERT_TYPES, MK_MULTI_GPU_PREPARE_FINALIZE_TYPES,
|
MK_FUSED_EXPERT_TYPES, MK_MULTI_GPU_PREPARE_FINALIZE_TYPES,
|
||||||
MK_QUANT_CONFIGS, MK_SINGLE_GPU_PREPARE_FINALIZE_TYPES, expert_info)
|
MK_QUANT_CONFIGS, MK_SINGLE_GPU_PREPARE_FINALIZE_TYPES, TestMoEQuantConfig,
|
||||||
|
expert_info)
|
||||||
from .modular_kernel_tools.parallel_utils import (ProcessGroupInfo,
|
from .modular_kernel_tools.parallel_utils import (ProcessGroupInfo,
|
||||||
parallel_launch_with_config)
|
parallel_launch_with_config)
|
||||||
|
|
||||||
@ -55,7 +55,7 @@ def rank_worker(
|
|||||||
pgi: ProcessGroupInfo,
|
pgi: ProcessGroupInfo,
|
||||||
vllm_config: VllmConfig,
|
vllm_config: VllmConfig,
|
||||||
cpu_group,
|
cpu_group,
|
||||||
config: Config,
|
base_config: Config,
|
||||||
weights: WeightTensors,
|
weights: WeightTensors,
|
||||||
verbose: bool,
|
verbose: bool,
|
||||||
):
|
):
|
||||||
@ -63,42 +63,44 @@ def rank_worker(
|
|||||||
|
|
||||||
# sanity check
|
# sanity check
|
||||||
from vllm import envs
|
from vllm import envs
|
||||||
if config.fused_moe_chunk_size is not None:
|
if base_config.fused_moe_chunk_size is not None:
|
||||||
assert (config.fused_moe_chunk_size == envs.VLLM_FUSED_MOE_CHUNK_SIZE)
|
assert (
|
||||||
|
base_config.fused_moe_chunk_size == envs.VLLM_FUSED_MOE_CHUNK_SIZE)
|
||||||
|
|
||||||
# get weights to this device
|
# get weights to this device
|
||||||
weights.to_current_device()
|
weights.to_current_device()
|
||||||
|
|
||||||
Ms = config.Ms
|
Ms = base_config.Ms
|
||||||
assert isinstance(Ms, list)
|
assert isinstance(Ms, list)
|
||||||
TOPKs = config.topks
|
TOPKs = base_config.topks
|
||||||
assert isinstance(TOPKs, list)
|
assert isinstance(TOPKs, list)
|
||||||
|
|
||||||
exceptions = []
|
exceptions = []
|
||||||
count = 0
|
count = 0
|
||||||
|
|
||||||
for m, topk in product(Ms, TOPKs):
|
for m, topk in product(Ms, TOPKs):
|
||||||
|
# override m and topk
|
||||||
|
config = copy.deepcopy(base_config)
|
||||||
|
config.Ms = m
|
||||||
|
config.topks = topk
|
||||||
|
|
||||||
try:
|
try:
|
||||||
print(f"Running[{pgi.rank}]: m={m}, topk={topk} ...")
|
print(f"Running[{pgi.rank}]: m={m}, topk={topk} ...")
|
||||||
count = count + 1
|
count = count + 1
|
||||||
# override m and topk
|
|
||||||
cfgx = copy.deepcopy(config)
|
|
||||||
cfgx.Ms = m
|
|
||||||
cfgx.topks = topk
|
|
||||||
|
|
||||||
# inputs for rank
|
# inputs for rank
|
||||||
rank_tensors = RankTensors.make(cfgx, pgi)
|
rank_tensors = RankTensors.make(config, pgi)
|
||||||
|
|
||||||
# modular kernel out
|
# modular kernel out
|
||||||
mk_out = run_modular_kernel(pgi, vllm_config, cfgx, weights,
|
mk_out = run_modular_kernel(pgi, vllm_config, config, weights,
|
||||||
rank_tensors)
|
rank_tensors)
|
||||||
|
|
||||||
with set_current_vllm_config(vllm_config):
|
with set_current_vllm_config(vllm_config):
|
||||||
ref_out = reference_moe_impl(cfgx, weights, rank_tensors)
|
ref_out = reference_moe_impl(config, weights, rank_tensors)
|
||||||
|
|
||||||
if config.quant_dtype == "nvfp4":
|
if config.quant_dtype == "nvfp4":
|
||||||
atol = 1e-1
|
atol = 1e-1 if config.K < 4096 else 2e-1
|
||||||
rtol = 1e-1
|
rtol = 1e-1 if config.K < 4096 else 2e-1
|
||||||
else:
|
else:
|
||||||
atol = 3e-2
|
atol = 3e-2
|
||||||
rtol = 3e-2
|
rtol = 3e-2
|
||||||
@ -132,7 +134,7 @@ Ms = [32, 64]
|
|||||||
# hidden sizes, making this too large will cause fp4 tests to fail.
|
# hidden sizes, making this too large will cause fp4 tests to fail.
|
||||||
# Also needs to be a multiple of 1024 for deep_gemm.
|
# Also needs to be a multiple of 1024 for deep_gemm.
|
||||||
Ks = [2048]
|
Ks = [2048]
|
||||||
Ns = [2048]
|
Ns = [1024]
|
||||||
TOPKs = [4, 1]
|
TOPKs = [4, 1]
|
||||||
Es = [32]
|
Es = [32]
|
||||||
DTYPEs = [torch.bfloat16]
|
DTYPEs = [torch.bfloat16]
|
||||||
@ -167,7 +169,7 @@ def is_nyi_config(config: Config) -> bool:
|
|||||||
@meets_multi_gpu_requirements
|
@meets_multi_gpu_requirements
|
||||||
def test_modular_kernel_combinations_multigpu(
|
def test_modular_kernel_combinations_multigpu(
|
||||||
k: int, n: int, e: int, dtype: torch.dtype,
|
k: int, n: int, e: int, dtype: torch.dtype,
|
||||||
quant_config: Optional[FusedMoEQuantConfig],
|
quant_config: Optional[TestMoEQuantConfig],
|
||||||
combination: tuple[mk.FusedMoEPrepareAndFinalize,
|
combination: tuple[mk.FusedMoEPrepareAndFinalize,
|
||||||
mk.FusedMoEPermuteExpertsUnpermute],
|
mk.FusedMoEPermuteExpertsUnpermute],
|
||||||
fused_moe_chunk_size: Optional[int], world_size: int, pytestconfig):
|
fused_moe_chunk_size: Optional[int], world_size: int, pytestconfig):
|
||||||
@ -208,7 +210,7 @@ def test_modular_kernel_combinations_multigpu(
|
|||||||
@pytest.mark.parametrize("world_size", [1])
|
@pytest.mark.parametrize("world_size", [1])
|
||||||
def test_modular_kernel_combinations_singlegpu(
|
def test_modular_kernel_combinations_singlegpu(
|
||||||
k: int, n: int, e: int, dtype: torch.dtype,
|
k: int, n: int, e: int, dtype: torch.dtype,
|
||||||
quant_config: Optional[FusedMoEQuantConfig],
|
quant_config: Optional[TestMoEQuantConfig],
|
||||||
combination: tuple[mk.FusedMoEPrepareAndFinalize,
|
combination: tuple[mk.FusedMoEPrepareAndFinalize,
|
||||||
mk.FusedMoEPermuteExpertsUnpermute],
|
mk.FusedMoEPermuteExpertsUnpermute],
|
||||||
fused_moe_chunk_size: Optional[int], world_size: int, pytestconfig):
|
fused_moe_chunk_size: Optional[int], world_size: int, pytestconfig):
|
||||||
|
|||||||
@ -15,11 +15,14 @@ from transformers import MixtralConfig
|
|||||||
from transformers.models.mixtral.modeling_mixtral import MixtralSparseMoeBlock
|
from transformers.models.mixtral.modeling_mixtral import MixtralSparseMoeBlock
|
||||||
|
|
||||||
import vllm.model_executor.layers.fused_moe # noqa
|
import vllm.model_executor.layers.fused_moe # noqa
|
||||||
|
from tests.kernels.moe.utils import fused_moe
|
||||||
from tests.kernels.utils import opcheck, stack_and_dev, torch_moe
|
from tests.kernels.utils import opcheck, stack_and_dev, torch_moe
|
||||||
from vllm.config import VllmConfig, set_current_vllm_config
|
from vllm.config import VllmConfig, set_current_vllm_config
|
||||||
from vllm.distributed.parallel_state import init_distributed_environment
|
from vllm.distributed.parallel_state import init_distributed_environment
|
||||||
from vllm.forward_context import set_forward_context
|
from vllm.forward_context import set_forward_context
|
||||||
from vllm.model_executor.layers.fused_moe import fused_moe
|
from vllm.model_executor.layers.fused_moe.config import (
|
||||||
|
FUSED_MOE_UNQUANTIZED_CONFIG, int4_w4a16_moe_quant_config,
|
||||||
|
int8_w8a16_moe_quant_config)
|
||||||
from vllm.model_executor.layers.fused_moe.fused_moe import (
|
from vllm.model_executor.layers.fused_moe.fused_moe import (
|
||||||
fused_topk, modular_triton_fused_moe)
|
fused_topk, modular_triton_fused_moe)
|
||||||
from vllm.model_executor.layers.fused_moe.moe_torch_iterative import (
|
from vllm.model_executor.layers.fused_moe.moe_torch_iterative import (
|
||||||
@ -187,14 +190,9 @@ def test_fused_moe(
|
|||||||
#
|
#
|
||||||
# Setup test functions
|
# Setup test functions
|
||||||
#
|
#
|
||||||
|
quant_config = FUSED_MOE_UNQUANTIZED_CONFIG
|
||||||
|
|
||||||
m_fused_moe_fn = modular_triton_fused_moe(use_fp8_w8a8=False,
|
m_fused_moe_fn = modular_triton_fused_moe(quant_config)
|
||||||
use_int8_w8a8=False,
|
|
||||||
use_int8_w8a16=False,
|
|
||||||
use_int4_w4a16=False,
|
|
||||||
use_mxfp4_w4a4=False,
|
|
||||||
per_act_token_quant=False,
|
|
||||||
block_shape=None)
|
|
||||||
|
|
||||||
def m_fused_moe(
|
def m_fused_moe(
|
||||||
a: torch.Tensor,
|
a: torch.Tensor,
|
||||||
@ -340,6 +338,18 @@ def test_fused_moe_wn16(m: int, n: int, k: int, e: int, topk: int,
|
|||||||
else:
|
else:
|
||||||
e_map = None
|
e_map = None
|
||||||
|
|
||||||
|
if weight_bits == 4:
|
||||||
|
quant_config_builder = int4_w4a16_moe_quant_config
|
||||||
|
else:
|
||||||
|
assert weight_bits == 8
|
||||||
|
quant_config_builder = int8_w8a16_moe_quant_config
|
||||||
|
|
||||||
|
quant_config = quant_config_builder(w1_scale=w1_scales,
|
||||||
|
w2_scale=w2_scales,
|
||||||
|
w1_zp=w1_qzeros if has_zp else None,
|
||||||
|
w2_zp=w2_qzeros if has_zp else None,
|
||||||
|
block_shape=[0, group_size])
|
||||||
|
|
||||||
with set_current_vllm_config(vllm_config):
|
with set_current_vllm_config(vllm_config):
|
||||||
triton_output = fused_moe(a,
|
triton_output = fused_moe(a,
|
||||||
w1_qweight,
|
w1_qweight,
|
||||||
@ -347,15 +357,9 @@ def test_fused_moe_wn16(m: int, n: int, k: int, e: int, topk: int,
|
|||||||
score,
|
score,
|
||||||
topk,
|
topk,
|
||||||
renormalize=False,
|
renormalize=False,
|
||||||
use_int4_w4a16=weight_bits == 4,
|
|
||||||
use_int8_w8a16=weight_bits == 8,
|
|
||||||
global_num_experts=e,
|
global_num_experts=e,
|
||||||
expert_map=e_map,
|
expert_map=e_map,
|
||||||
w1_scale=w1_scales,
|
quant_config=quant_config)
|
||||||
w2_scale=w2_scales,
|
|
||||||
w1_zp=w1_qzeros if has_zp else None,
|
|
||||||
w2_zp=w2_qzeros if has_zp else None,
|
|
||||||
block_shape=[0, group_size])
|
|
||||||
torch_output = torch_moe(a,
|
torch_output = torch_moe(a,
|
||||||
w1_ref,
|
w1_ref,
|
||||||
w2_ref,
|
w2_ref,
|
||||||
|
|||||||
@ -10,6 +10,7 @@ from tests.kernels.quantization.nvfp4_utils import (FLOAT4_E2M1_MAX,
|
|||||||
from tests.kernels.utils import torch_moe
|
from tests.kernels.utils import torch_moe
|
||||||
from vllm import _custom_ops as ops
|
from vllm import _custom_ops as ops
|
||||||
from vllm.config import ParallelConfig, VllmConfig, set_current_vllm_config
|
from vllm.config import ParallelConfig, VllmConfig, set_current_vllm_config
|
||||||
|
from vllm.model_executor.layers.fused_moe.config import nvfp4_moe_quant_config
|
||||||
from vllm.model_executor.layers.fused_moe.cutlass_moe import cutlass_moe_fp4
|
from vllm.model_executor.layers.fused_moe.cutlass_moe import cutlass_moe_fp4
|
||||||
from vllm.model_executor.layers.fused_moe.fused_moe import fused_topk
|
from vllm.model_executor.layers.fused_moe.fused_moe import fused_topk
|
||||||
from vllm.platforms import current_platform
|
from vllm.platforms import current_platform
|
||||||
@ -56,7 +57,7 @@ def test_cutlass_fp4_moe_no_graph(m: int, n: int, k: int, e: int, topk: int,
|
|||||||
in_dtype=dtype,
|
in_dtype=dtype,
|
||||||
quant_dtype="nvfp4",
|
quant_dtype="nvfp4",
|
||||||
block_shape=None, # use quant_blocksize?
|
block_shape=None, # use quant_blocksize?
|
||||||
per_act_token_quant=False,
|
per_out_ch_quant=False,
|
||||||
)
|
)
|
||||||
|
|
||||||
score = torch.randn((m, e), device="cuda", dtype=dtype)
|
score = torch.randn((m, e), device="cuda", dtype=dtype)
|
||||||
@ -73,18 +74,22 @@ def test_cutlass_fp4_moe_no_graph(m: int, n: int, k: int, e: int, topk: int,
|
|||||||
assert w1_blockscale is not None
|
assert w1_blockscale is not None
|
||||||
assert w2_blockscale is not None
|
assert w2_blockscale is not None
|
||||||
|
|
||||||
|
quant_config = nvfp4_moe_quant_config(
|
||||||
|
g1_alphas=(1 / w1_gs),
|
||||||
|
g2_alphas=(1 / w2_gs),
|
||||||
|
a1_gscale=a1_gs,
|
||||||
|
a2_gscale=a2_gs,
|
||||||
|
w1_scale=w1_blockscale,
|
||||||
|
w2_scale=w2_blockscale,
|
||||||
|
)
|
||||||
|
|
||||||
cutlass_output = cutlass_moe_fp4(
|
cutlass_output = cutlass_moe_fp4(
|
||||||
a=a,
|
a=a,
|
||||||
a1_gscale=a1_gs,
|
|
||||||
w1_fp4=w1_q,
|
w1_fp4=w1_q,
|
||||||
w1_blockscale=w1_blockscale,
|
|
||||||
g1_alphas=(1 / w1_gs),
|
|
||||||
a2_gscale=a2_gs,
|
|
||||||
w2_fp4=w2_q,
|
w2_fp4=w2_q,
|
||||||
w2_blockscale=w2_blockscale,
|
|
||||||
g2_alphas=(1 / w2_gs),
|
|
||||||
topk_weights=topk_weights,
|
topk_weights=topk_weights,
|
||||||
topk_ids=topk_ids,
|
topk_ids=topk_ids,
|
||||||
|
quant_config=quant_config,
|
||||||
m=m,
|
m=m,
|
||||||
n=n,
|
n=n,
|
||||||
k=k,
|
k=k,
|
||||||
|
|||||||
@ -9,6 +9,8 @@ import torch
|
|||||||
from tests.kernels.utils import torch_experts
|
from tests.kernels.utils import torch_experts
|
||||||
from vllm import _custom_ops as ops
|
from vllm import _custom_ops as ops
|
||||||
from vllm.config import VllmConfig, set_current_vllm_config
|
from vllm.config import VllmConfig, set_current_vllm_config
|
||||||
|
from vllm.model_executor.layers.fused_moe.config import (
|
||||||
|
fp8_w8a8_moe_quant_config)
|
||||||
from vllm.model_executor.layers.fused_moe.cutlass_moe import (
|
from vllm.model_executor.layers.fused_moe.cutlass_moe import (
|
||||||
CutlassBatchedExpertsFp8)
|
CutlassBatchedExpertsFp8)
|
||||||
from vllm.model_executor.layers.fused_moe.fused_moe import fused_topk
|
from vllm.model_executor.layers.fused_moe.fused_moe import fused_topk
|
||||||
@ -143,10 +145,16 @@ def pplx_cutlass_moe(
|
|||||||
device="cuda",
|
device="cuda",
|
||||||
dtype=torch.int64)
|
dtype=torch.int64)
|
||||||
|
|
||||||
experts = CutlassBatchedExpertsFp8(num_local_experts, num_dispatchers,
|
experts = CutlassBatchedExpertsFp8(
|
||||||
out_dtype, per_act_token, per_out_ch,
|
num_local_experts, num_dispatchers, out_dtype, ab_strides1,
|
||||||
ab_strides1, ab_strides2, c_strides1,
|
ab_strides2, c_strides1, c_strides2,
|
||||||
c_strides2)
|
fp8_w8a8_moe_quant_config(
|
||||||
|
per_act_token_quant=per_act_token,
|
||||||
|
per_out_ch_quant=per_out_ch,
|
||||||
|
w1_scale=chunk_by_rank(w1_scale, rank, world_size),
|
||||||
|
w2_scale=chunk_by_rank(w2_scale, rank, world_size),
|
||||||
|
a1_scale=chunk_by_rank(a1_scale, rank, world_size)
|
||||||
|
if per_act_token else a1_scale[rank]))
|
||||||
|
|
||||||
fused_cutlass_experts = FusedMoEModularKernel(
|
fused_cutlass_experts = FusedMoEModularKernel(
|
||||||
prepare_finalize,
|
prepare_finalize,
|
||||||
@ -167,10 +175,7 @@ def pplx_cutlass_moe(
|
|||||||
chunk_topk_ids,
|
chunk_topk_ids,
|
||||||
global_num_experts=num_experts,
|
global_num_experts=num_experts,
|
||||||
expert_map=None, #TODO
|
expert_map=None, #TODO
|
||||||
w1_scale=chunk_by_rank(w1_scale, rank, world_size),
|
)
|
||||||
w2_scale=chunk_by_rank(w2_scale, rank, world_size),
|
|
||||||
a1_scale=chunk_by_rank(a1_scale, rank, world_size)
|
|
||||||
if per_act_token else a1_scale[rank])
|
|
||||||
|
|
||||||
torch.cuda.synchronize()
|
torch.cuda.synchronize()
|
||||||
|
|
||||||
|
|||||||
@ -58,7 +58,7 @@ BATCHED_MOE_MNK_FACTORS = [
|
|||||||
]
|
]
|
||||||
|
|
||||||
PPLX_COMBOS = [
|
PPLX_COMBOS = [
|
||||||
# TODO: figure out why this fails, seems to be test problem
|
# TODO(bnell): figure out why this fails, seems to be test problem
|
||||||
#(1, 128, 128),
|
#(1, 128, 128),
|
||||||
(2, 128, 512),
|
(2, 128, 512),
|
||||||
(3, 1024, 2048),
|
(3, 1024, 2048),
|
||||||
@ -360,18 +360,18 @@ def pplx_prepare_finalize(
|
|||||||
|
|
||||||
b_a, b_a_scale, expert_num_tokens, _, _ = prepare_finalize.prepare(
|
b_a, b_a_scale, expert_num_tokens, _, _ = prepare_finalize.prepare(
|
||||||
a_chunk,
|
a_chunk,
|
||||||
a1_scale,
|
|
||||||
a2_scale,
|
|
||||||
chunk_topk_weight,
|
chunk_topk_weight,
|
||||||
chunk_topk_ids,
|
chunk_topk_ids,
|
||||||
num_experts,
|
num_experts,
|
||||||
None,
|
None,
|
||||||
False,
|
False,
|
||||||
FusedMoEQuantConfig(
|
FusedMoEQuantConfig.make(
|
||||||
quant_dtype,
|
quant_dtype,
|
||||||
per_act_token_quant,
|
per_act_token_quant=per_act_token_quant,
|
||||||
False,
|
per_out_ch_quant=False,
|
||||||
block_shape,
|
block_shape=block_shape,
|
||||||
|
a1_scale=a1_scale,
|
||||||
|
a2_scale=a2_scale,
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -540,20 +540,6 @@ def pplx_moe(
|
|||||||
|
|
||||||
topk_ids = topk_ids.to(dtype=torch.uint32)
|
topk_ids = topk_ids.to(dtype=torch.uint32)
|
||||||
|
|
||||||
experts = BatchedTritonExperts(
|
|
||||||
max_num_tokens=max_num_tokens,
|
|
||||||
num_dispatchers=prepare_finalize.num_dispatchers(),
|
|
||||||
use_fp8_w8a8=quant_dtype == torch.float8_e4m3fn,
|
|
||||||
block_shape=block_shape,
|
|
||||||
per_act_token_quant=per_act_token_quant,
|
|
||||||
)
|
|
||||||
|
|
||||||
fused_experts = FusedMoEModularKernel(
|
|
||||||
prepare_finalize,
|
|
||||||
experts,
|
|
||||||
shared_experts,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Note: workers with the same dp_rank must use the exact same inputs.
|
# Note: workers with the same dp_rank must use the exact same inputs.
|
||||||
a_chunk = chunk_by_rank(a, rank, world_size)
|
a_chunk = chunk_by_rank(a, rank, world_size)
|
||||||
chunk_topk_weight = chunk_by_rank(topk_weight, rank, world_size)
|
chunk_topk_weight = chunk_by_rank(topk_weight, rank, world_size)
|
||||||
@ -567,6 +553,28 @@ def pplx_moe(
|
|||||||
a1_scale_chunk = chunk_scales_by_rank(a1_scale, rank, world_size)
|
a1_scale_chunk = chunk_scales_by_rank(a1_scale, rank, world_size)
|
||||||
a2_scale_chunk = chunk_scales_by_rank(a2_scale, rank, world_size)
|
a2_scale_chunk = chunk_scales_by_rank(a2_scale, rank, world_size)
|
||||||
|
|
||||||
|
quant_config = FusedMoEQuantConfig.make(
|
||||||
|
quant_dtype,
|
||||||
|
block_shape=block_shape,
|
||||||
|
per_act_token_quant=per_act_token_quant,
|
||||||
|
w1_scale=w1_scale_chunk,
|
||||||
|
w2_scale=w2_scale_chunk,
|
||||||
|
a1_scale=a1_scale_chunk,
|
||||||
|
a2_scale=a2_scale_chunk,
|
||||||
|
)
|
||||||
|
|
||||||
|
experts = BatchedTritonExperts(
|
||||||
|
max_num_tokens=max_num_tokens,
|
||||||
|
num_dispatchers=prepare_finalize.num_dispatchers(),
|
||||||
|
quant_config=quant_config,
|
||||||
|
)
|
||||||
|
|
||||||
|
fused_experts = FusedMoEModularKernel(
|
||||||
|
prepare_finalize,
|
||||||
|
experts,
|
||||||
|
shared_experts,
|
||||||
|
)
|
||||||
|
|
||||||
# Note: for now use_compile will error out if the problem size is
|
# Note: for now use_compile will error out if the problem size is
|
||||||
# large enough to trigger chunking. I'm leaving the flag and
|
# large enough to trigger chunking. I'm leaving the flag and
|
||||||
# setup code in case we are able to revisit this later.
|
# setup code in case we are able to revisit this later.
|
||||||
@ -585,10 +593,6 @@ def pplx_moe(
|
|||||||
w2_chunk,
|
w2_chunk,
|
||||||
chunk_topk_weight,
|
chunk_topk_weight,
|
||||||
chunk_topk_ids,
|
chunk_topk_ids,
|
||||||
w1_scale=w1_scale_chunk,
|
|
||||||
w2_scale=w2_scale_chunk,
|
|
||||||
a1_scale=a1_scale_chunk,
|
|
||||||
a2_scale=a2_scale_chunk,
|
|
||||||
global_num_experts=num_experts)
|
global_num_experts=num_experts)
|
||||||
|
|
||||||
if use_cudagraphs:
|
if use_cudagraphs:
|
||||||
@ -605,10 +609,6 @@ def pplx_moe(
|
|||||||
w2_chunk,
|
w2_chunk,
|
||||||
chunk_topk_weight,
|
chunk_topk_weight,
|
||||||
chunk_topk_ids,
|
chunk_topk_ids,
|
||||||
w1_scale=w1_scale_chunk,
|
|
||||||
w2_scale=w2_scale_chunk,
|
|
||||||
a1_scale=a1_scale_chunk,
|
|
||||||
a2_scale=a2_scale_chunk,
|
|
||||||
global_num_experts=num_experts)
|
global_num_experts=num_experts)
|
||||||
|
|
||||||
torch.cuda.synchronize()
|
torch.cuda.synchronize()
|
||||||
@ -820,7 +820,7 @@ def test_pplx_moe_slow(
|
|||||||
k,
|
k,
|
||||||
quant_dtype=quant_dtype,
|
quant_dtype=quant_dtype,
|
||||||
block_shape=block_shape,
|
block_shape=block_shape,
|
||||||
per_act_token_quant=per_act_token_quant,
|
per_out_ch_quant=per_act_token_quant,
|
||||||
)
|
)
|
||||||
|
|
||||||
parallel_launch(world_size, _pplx_moe, dp_size, a, w1, w2, score, topk, e,
|
parallel_launch(world_size, _pplx_moe, dp_size, a, w1, w2, score, topk, e,
|
||||||
@ -897,7 +897,7 @@ def _pplx_test_loop(pgi: ProcessGroupInfo, dp_size: int, use_internode: bool,
|
|||||||
k,
|
k,
|
||||||
quant_dtype=quant_dtype,
|
quant_dtype=quant_dtype,
|
||||||
block_shape=block_shape,
|
block_shape=block_shape,
|
||||||
per_act_token_quant=per_act_token_quant,
|
per_out_ch_quant=per_act_token_quant,
|
||||||
)
|
)
|
||||||
args["w1"] = w1
|
args["w1"] = w1
|
||||||
args["w2"] = w2
|
args["w2"] = w2
|
||||||
|
|||||||
@ -7,10 +7,12 @@ import itertools
|
|||||||
import pytest
|
import pytest
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
|
from tests.kernels.moe.utils import fused_moe
|
||||||
from vllm import _custom_ops as ops
|
from vllm import _custom_ops as ops
|
||||||
from vllm.config import VllmConfig, set_current_vllm_config
|
from vllm.config import VllmConfig, set_current_vllm_config
|
||||||
from vllm.model_executor.layers.activation import SiluAndMul
|
from vllm.model_executor.layers.activation import SiluAndMul
|
||||||
from vllm.model_executor.layers.fused_moe import fused_moe
|
from vllm.model_executor.layers.fused_moe.config import (
|
||||||
|
fp8_w8a8_moe_quant_config)
|
||||||
from vllm.platforms import current_platform
|
from vllm.platforms import current_platform
|
||||||
|
|
||||||
if current_platform.get_device_capability() < (9, 0):
|
if current_platform.get_device_capability() < (9, 0):
|
||||||
@ -152,11 +154,12 @@ def test_w8a8_fp8_fused_moe(M, N, K, E, topk, dtype, seed):
|
|||||||
score,
|
score,
|
||||||
topk,
|
topk,
|
||||||
renormalize=False,
|
renormalize=False,
|
||||||
use_fp8_w8a8=True, # using fp8
|
quant_config=fp8_w8a8_moe_quant_config(
|
||||||
per_channel_quant=True,
|
per_act_token_quant=True,
|
||||||
w1_scale=w1_s,
|
w1_scale=w1_s,
|
||||||
w2_scale=w2_s,
|
w2_scale=w2_s,
|
||||||
block_shape=None, # Not using block quantization
|
block_shape=None, # Not using block quantization
|
||||||
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
# Check results
|
# Check results
|
||||||
|
|||||||
@ -9,7 +9,8 @@ from tests.kernels.quant_utils import per_block_cast_to_int8
|
|||||||
from tests.kernels.quantization.nvfp4_utils import (FLOAT4_E2M1_MAX,
|
from tests.kernels.quantization.nvfp4_utils import (FLOAT4_E2M1_MAX,
|
||||||
FLOAT8_E4M3_MAX)
|
FLOAT8_E4M3_MAX)
|
||||||
from vllm.model_executor.layers.activation import SiluAndMul
|
from vllm.model_executor.layers.activation import SiluAndMul
|
||||||
from vllm.model_executor.layers.fused_moe import fused_experts
|
from vllm.model_executor.layers.fused_moe import fused_experts, fused_topk
|
||||||
|
from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig
|
||||||
from vllm.model_executor.layers.fused_moe.fused_batched_moe import (
|
from vllm.model_executor.layers.fused_moe.fused_batched_moe import (
|
||||||
BatchedPrepareAndFinalize, BatchedTritonExperts, NaiveBatchedExperts)
|
BatchedPrepareAndFinalize, BatchedTritonExperts, NaiveBatchedExperts)
|
||||||
from vllm.model_executor.layers.fused_moe.modular_kernel import (
|
from vllm.model_executor.layers.fused_moe.modular_kernel import (
|
||||||
@ -34,18 +35,22 @@ def triton_moe(
|
|||||||
per_act_token_quant=False,
|
per_act_token_quant=False,
|
||||||
block_shape: Optional[list[int]] = None,
|
block_shape: Optional[list[int]] = None,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
|
quant_config = FusedMoEQuantConfig.make(
|
||||||
|
quant_dtype,
|
||||||
|
per_act_token_quant=per_act_token_quant,
|
||||||
|
block_shape=block_shape,
|
||||||
|
w1_scale=w1_scale,
|
||||||
|
w2_scale=w2_scale,
|
||||||
|
a1_scale=a1_scale,
|
||||||
|
a2_scale=a2_scale,
|
||||||
|
)
|
||||||
|
|
||||||
return fused_experts(a,
|
return fused_experts(a,
|
||||||
w1,
|
w1,
|
||||||
w2,
|
w2,
|
||||||
topk_weight,
|
topk_weight,
|
||||||
topk_ids,
|
topk_ids,
|
||||||
w1_scale=w1_scale,
|
quant_config=quant_config)
|
||||||
w2_scale=w2_scale,
|
|
||||||
a1_scale=a1_scale,
|
|
||||||
a2_scale=a2_scale,
|
|
||||||
per_channel_quant=per_act_token_quant,
|
|
||||||
use_fp8_w8a8=quant_dtype == torch.float8_e4m3fn,
|
|
||||||
block_shape=block_shape)
|
|
||||||
|
|
||||||
|
|
||||||
def batched_moe(
|
def batched_moe(
|
||||||
@ -64,6 +69,16 @@ def batched_moe(
|
|||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
max_num_tokens = round_up(a.shape[0], 64)
|
max_num_tokens = round_up(a.shape[0], 64)
|
||||||
|
|
||||||
|
quant_config = FusedMoEQuantConfig.make(
|
||||||
|
quant_dtype,
|
||||||
|
per_act_token_quant=per_act_token_quant,
|
||||||
|
block_shape=block_shape,
|
||||||
|
w1_scale=w1_scale,
|
||||||
|
w2_scale=w2_scale,
|
||||||
|
a1_scale=a1_scale,
|
||||||
|
a2_scale=a2_scale,
|
||||||
|
)
|
||||||
|
|
||||||
fused_experts = FusedMoEModularKernel(
|
fused_experts = FusedMoEModularKernel(
|
||||||
BatchedPrepareAndFinalize(max_num_tokens,
|
BatchedPrepareAndFinalize(max_num_tokens,
|
||||||
num_dispatchers=1,
|
num_dispatchers=1,
|
||||||
@ -72,21 +87,11 @@ def batched_moe(
|
|||||||
BatchedTritonExperts(
|
BatchedTritonExperts(
|
||||||
max_num_tokens=max_num_tokens,
|
max_num_tokens=max_num_tokens,
|
||||||
num_dispatchers=1,
|
num_dispatchers=1,
|
||||||
use_fp8_w8a8=quant_dtype == torch.float8_e4m3fn,
|
quant_config=quant_config,
|
||||||
per_act_token_quant=per_act_token_quant,
|
|
||||||
block_shape=block_shape,
|
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
return fused_experts(a,
|
return fused_experts(a, w1, w2, topk_weight, topk_ids)
|
||||||
w1,
|
|
||||||
w2,
|
|
||||||
topk_weight,
|
|
||||||
topk_ids,
|
|
||||||
w1_scale=w1_scale,
|
|
||||||
w2_scale=w2_scale,
|
|
||||||
a1_scale=a1_scale,
|
|
||||||
a2_scale=a2_scale)
|
|
||||||
|
|
||||||
|
|
||||||
def naive_batched_moe(
|
def naive_batched_moe(
|
||||||
@ -105,6 +110,16 @@ def naive_batched_moe(
|
|||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
max_num_tokens = round_up(a.shape[0], 64)
|
max_num_tokens = round_up(a.shape[0], 64)
|
||||||
|
|
||||||
|
quant_config = FusedMoEQuantConfig.make(
|
||||||
|
quant_dtype,
|
||||||
|
per_act_token_quant=per_act_token_quant,
|
||||||
|
block_shape=block_shape,
|
||||||
|
w1_scale=w1_scale,
|
||||||
|
w2_scale=w2_scale,
|
||||||
|
a1_scale=a1_scale,
|
||||||
|
a2_scale=a2_scale,
|
||||||
|
)
|
||||||
|
|
||||||
fused_experts = FusedMoEModularKernel(
|
fused_experts = FusedMoEModularKernel(
|
||||||
BatchedPrepareAndFinalize(max_num_tokens,
|
BatchedPrepareAndFinalize(max_num_tokens,
|
||||||
num_dispatchers=1,
|
num_dispatchers=1,
|
||||||
@ -113,21 +128,11 @@ def naive_batched_moe(
|
|||||||
NaiveBatchedExperts(
|
NaiveBatchedExperts(
|
||||||
max_num_tokens=max_num_tokens,
|
max_num_tokens=max_num_tokens,
|
||||||
num_dispatchers=1,
|
num_dispatchers=1,
|
||||||
use_fp8_w8a8=quant_dtype == torch.float8_e4m3fn,
|
quant_config=quant_config,
|
||||||
per_act_token_quant=per_act_token_quant,
|
|
||||||
block_shape=block_shape,
|
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
return fused_experts(a,
|
return fused_experts(a, w1, w2, topk_weight, topk_ids)
|
||||||
w1,
|
|
||||||
w2,
|
|
||||||
topk_weight,
|
|
||||||
topk_ids,
|
|
||||||
w1_scale=w1_scale,
|
|
||||||
w2_scale=w2_scale,
|
|
||||||
a1_scale=a1_scale,
|
|
||||||
a2_scale=a2_scale)
|
|
||||||
|
|
||||||
|
|
||||||
def chunk_scales(scales: Optional[torch.Tensor], start: int,
|
def chunk_scales(scales: Optional[torch.Tensor], start: int,
|
||||||
@ -216,7 +221,7 @@ def make_test_weight(
|
|||||||
in_dtype: torch.dtype = torch.bfloat16,
|
in_dtype: torch.dtype = torch.bfloat16,
|
||||||
quant_dtype: Union[torch.dtype, str, None] = None,
|
quant_dtype: Union[torch.dtype, str, None] = None,
|
||||||
block_shape: Optional[list[int]] = None,
|
block_shape: Optional[list[int]] = None,
|
||||||
per_act_token_quant: bool = False,
|
per_out_ch_quant: bool = False,
|
||||||
) -> tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor],
|
) -> tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor],
|
||||||
Optional[torch.Tensor]]:
|
Optional[torch.Tensor]]:
|
||||||
w_16 = torch.randn((e, rows, cols), device="cuda", dtype=in_dtype) / 15
|
w_16 = torch.randn((e, rows, cols), device="cuda", dtype=in_dtype) / 15
|
||||||
@ -228,7 +233,7 @@ def make_test_weight(
|
|||||||
w_gs_l = [None] * e
|
w_gs_l = [None] * e
|
||||||
for idx in range(e):
|
for idx in range(e):
|
||||||
w_l[idx], w_s_l[idx], w_gs_l[idx] = moe_quantize_weights(
|
w_l[idx], w_s_l[idx], w_gs_l[idx] = moe_quantize_weights(
|
||||||
w_16[idx], None, quant_dtype, per_act_token_quant, block_shape)
|
w_16[idx], None, quant_dtype, per_out_ch_quant, block_shape)
|
||||||
|
|
||||||
w = torch.stack(w_l)
|
w = torch.stack(w_l)
|
||||||
w_s = torch.stack(w_s_l)
|
w_s = torch.stack(w_s_l)
|
||||||
@ -258,16 +263,16 @@ def make_test_weights(
|
|||||||
in_dtype: torch.dtype = torch.bfloat16,
|
in_dtype: torch.dtype = torch.bfloat16,
|
||||||
quant_dtype: Union[torch.dtype, str, None] = None,
|
quant_dtype: Union[torch.dtype, str, None] = None,
|
||||||
block_shape: Optional[list[int]] = None,
|
block_shape: Optional[list[int]] = None,
|
||||||
per_act_token_quant: bool = False,
|
per_out_ch_quant: bool = False,
|
||||||
) -> tuple[tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor],
|
) -> tuple[tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor],
|
||||||
Optional[torch.Tensor]],
|
Optional[torch.Tensor]],
|
||||||
tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor],
|
tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor],
|
||||||
Optional[torch.Tensor]]]:
|
Optional[torch.Tensor]]]:
|
||||||
return (
|
return (
|
||||||
make_test_weight(e, 2 * n, k, in_dtype, quant_dtype, block_shape,
|
make_test_weight(e, 2 * n, k, in_dtype, quant_dtype, block_shape,
|
||||||
per_act_token_quant),
|
per_out_ch_quant),
|
||||||
make_test_weight(e, k, n, in_dtype, quant_dtype, block_shape,
|
make_test_weight(e, k, n, in_dtype, quant_dtype, block_shape,
|
||||||
per_act_token_quant),
|
per_out_ch_quant),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@ -285,6 +290,76 @@ def per_token_cast_to_fp8(
|
|||||||
return fp8_data.view(m, n + pad_size)[:, :n], (x_amax / 448.0).view(m, -1)
|
return fp8_data.view(m, n + pad_size)[:, :n], (x_amax / 448.0).view(m, -1)
|
||||||
|
|
||||||
|
|
||||||
|
def make_test_quant_config(
|
||||||
|
e: int,
|
||||||
|
n: int,
|
||||||
|
k: int,
|
||||||
|
in_dtype: torch.dtype,
|
||||||
|
quant_dtype: Union[torch.dtype, str, None] = None,
|
||||||
|
per_act_token_quant: bool = False,
|
||||||
|
block_shape: Optional[list[int]] = None,
|
||||||
|
) -> tuple[torch.Tensor, torch.Tensor, FusedMoEQuantConfig]:
|
||||||
|
(_, w1, w1_s, w1_gs), (_, w2, w2_s, w2_gs) = make_test_weights(
|
||||||
|
e,
|
||||||
|
n,
|
||||||
|
k,
|
||||||
|
in_dtype,
|
||||||
|
quant_dtype,
|
||||||
|
per_out_ch_quant=per_act_token_quant,
|
||||||
|
block_shape=block_shape,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Hacky/trivial scales for nvfp4.
|
||||||
|
a1_gscale: Optional[torch.Tensor] = None
|
||||||
|
a2_gscale: Optional[torch.Tensor] = None
|
||||||
|
if quant_dtype == "nvfp4":
|
||||||
|
a1_gscale = torch.ones((e, ), device="cuda", dtype=torch.float32)
|
||||||
|
a2_gscale = torch.ones((e, ), device="cuda", dtype=torch.float32)
|
||||||
|
a1_scale = a1_gscale
|
||||||
|
a2_scale = a2_gscale
|
||||||
|
else:
|
||||||
|
a1_scale = None
|
||||||
|
a2_scale = None
|
||||||
|
|
||||||
|
return w1, w2, FusedMoEQuantConfig.make(
|
||||||
|
quant_dtype,
|
||||||
|
per_act_token_quant=per_act_token_quant,
|
||||||
|
block_shape=block_shape,
|
||||||
|
w1_scale=w1_s,
|
||||||
|
w2_scale=w2_s,
|
||||||
|
a1_gscale=a1_gscale,
|
||||||
|
a2_gscale=a2_gscale,
|
||||||
|
a1_scale=a1_scale,
|
||||||
|
a2_scale=a2_scale,
|
||||||
|
# TODO: make sure this is handled properly
|
||||||
|
g1_alphas=(1 / w1_gs) if w1_gs is not None else None,
|
||||||
|
g2_alphas=(1 / w2_gs) if w2_gs is not None else None,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def fused_moe(
|
||||||
|
hidden_states: torch.Tensor,
|
||||||
|
w1: torch.Tensor,
|
||||||
|
w2: torch.Tensor,
|
||||||
|
score: torch.Tensor,
|
||||||
|
topk: int,
|
||||||
|
renormalize: bool = False,
|
||||||
|
quant_config: Optional[FusedMoEQuantConfig] = None,
|
||||||
|
global_num_experts: int = -1,
|
||||||
|
expert_map: Optional[torch.Tensor] = None,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
topk_weights, topk_ids, _ = fused_topk(hidden_states, score.float(), topk,
|
||||||
|
renormalize)
|
||||||
|
return fused_experts(hidden_states,
|
||||||
|
w1,
|
||||||
|
w2,
|
||||||
|
topk_weights,
|
||||||
|
topk_ids,
|
||||||
|
global_num_experts=global_num_experts,
|
||||||
|
expert_map=expert_map,
|
||||||
|
quant_config=quant_config)
|
||||||
|
|
||||||
|
|
||||||
# CustomOp?
|
# CustomOp?
|
||||||
class BaselineMM(torch.nn.Module):
|
class BaselineMM(torch.nn.Module):
|
||||||
|
|
||||||
|
|||||||
@ -8,7 +8,8 @@ import pytest
|
|||||||
import torch
|
import torch
|
||||||
|
|
||||||
from vllm.model_executor.layers.activation import SiluAndMul
|
from vllm.model_executor.layers.activation import SiluAndMul
|
||||||
from vllm.model_executor.layers.fused_moe import fused_moe
|
from vllm.model_executor.layers.fused_moe import fused_experts
|
||||||
|
from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig
|
||||||
from vllm.model_executor.layers.quantization.utils.int8_utils import (
|
from vllm.model_executor.layers.quantization.utils.int8_utils import (
|
||||||
per_token_quant_int8)
|
per_token_quant_int8)
|
||||||
from vllm.platforms import current_platform
|
from vllm.platforms import current_platform
|
||||||
@ -42,7 +43,8 @@ def native_w8a8_per_token_matmul(A, B, As, Bs, output_dtype=torch.float16):
|
|||||||
return C.reshape(origin_C_shape).to(output_dtype)
|
return C.reshape(origin_C_shape).to(output_dtype)
|
||||||
|
|
||||||
|
|
||||||
def torch_w8a8_per_column_moe(a, w1, w2, w1_s, w2_s, score, topk):
|
def torch_w8a8_per_column_moe(a, w1, w2, w1_s, w2_s, topk, topk_weight,
|
||||||
|
topk_ids):
|
||||||
"""This function performs fused moe with per-column int8 quantization
|
"""This function performs fused moe with per-column int8 quantization
|
||||||
using native torch."""
|
using native torch."""
|
||||||
|
|
||||||
@ -57,8 +59,6 @@ def torch_w8a8_per_column_moe(a, w1, w2, w1_s, w2_s, score, topk):
|
|||||||
out = torch.zeros(B * topk, w2.shape[1], dtype=a.dtype, device=a.device)
|
out = torch.zeros(B * topk, w2.shape[1], dtype=a.dtype, device=a.device)
|
||||||
|
|
||||||
# Calculate routing
|
# Calculate routing
|
||||||
score = torch.softmax(score, dim=-1, dtype=torch.float32)
|
|
||||||
topk_weight, topk_ids = torch.topk(score, topk)
|
|
||||||
topk_weight = topk_weight.view(-1)
|
topk_weight = topk_weight.view(-1)
|
||||||
topk_ids = topk_ids.view(-1)
|
topk_ids = topk_ids.view(-1)
|
||||||
# Process each expert
|
# Process each expert
|
||||||
@ -127,20 +127,27 @@ def test_w8a8_fp8_fused_moe(M, N, K, E, topk, dtype, seed):
|
|||||||
w1_s = torch.rand(E, 2 * N, device=w1_fp32.device) * factor_for_scale
|
w1_s = torch.rand(E, 2 * N, device=w1_fp32.device) * factor_for_scale
|
||||||
w2_s = torch.rand(E, K, device=w2_fp32.device) * factor_for_scale
|
w2_s = torch.rand(E, K, device=w2_fp32.device) * factor_for_scale
|
||||||
score = torch.randn((M, E), dtype=dtype)
|
score = torch.randn((M, E), dtype=dtype)
|
||||||
|
score = torch.softmax(score, dim=-1, dtype=torch.float32)
|
||||||
|
topk_weights, topk_ids = torch.topk(score, topk)
|
||||||
|
|
||||||
ref_out = torch_w8a8_per_column_moe(a, w1, w2, w1_s, w2_s, score, topk)
|
ref_out = torch_w8a8_per_column_moe(a, w1, w2, w1_s, w2_s, topk,
|
||||||
out = fused_moe(
|
topk_weights, topk_ids)
|
||||||
|
|
||||||
|
quant_config = FusedMoEQuantConfig.make(
|
||||||
|
torch.int8,
|
||||||
|
per_act_token_quant=True,
|
||||||
|
block_shape=None,
|
||||||
|
w1_scale=w1_s,
|
||||||
|
w2_scale=w2_s,
|
||||||
|
)
|
||||||
|
|
||||||
|
out = fused_experts(
|
||||||
a,
|
a,
|
||||||
w1,
|
w1,
|
||||||
w2,
|
w2,
|
||||||
score,
|
topk_weights,
|
||||||
topk,
|
topk_ids,
|
||||||
renormalize=False,
|
quant_config=quant_config,
|
||||||
use_int8_w8a8=True, # Using int8-w8a8
|
|
||||||
per_channel_quant=True,
|
|
||||||
w1_scale=w1_s,
|
|
||||||
w2_scale=w2_s,
|
|
||||||
block_shape=None, # Not using block quantization
|
|
||||||
)
|
)
|
||||||
|
|
||||||
# Check results
|
# Check results
|
||||||
|
|||||||
@ -10,6 +10,7 @@ from vllm.model_executor.layers.fused_moe.layer import (
|
|||||||
from vllm.model_executor.layers.fused_moe.modular_kernel import (
|
from vllm.model_executor.layers.fused_moe.modular_kernel import (
|
||||||
FusedMoEActivationFormat, FusedMoEPermuteExpertsUnpermute,
|
FusedMoEActivationFormat, FusedMoEPermuteExpertsUnpermute,
|
||||||
FusedMoEPrepareAndFinalize)
|
FusedMoEPrepareAndFinalize)
|
||||||
|
from vllm.model_executor.layers.fused_moe.utils import activation_without_mul
|
||||||
from vllm.triton_utils import HAS_TRITON
|
from vllm.triton_utils import HAS_TRITON
|
||||||
|
|
||||||
_config: Optional[dict[str, Any]] = None
|
_config: Optional[dict[str, Any]] = None
|
||||||
@ -36,6 +37,7 @@ __all__ = [
|
|||||||
"FusedMoEPermuteExpertsUnpermute",
|
"FusedMoEPermuteExpertsUnpermute",
|
||||||
"FusedMoEActivationFormat",
|
"FusedMoEActivationFormat",
|
||||||
"FusedMoEPrepareAndFinalize",
|
"FusedMoEPrepareAndFinalize",
|
||||||
|
"activation_without_mul",
|
||||||
"override_config",
|
"override_config",
|
||||||
"get_config",
|
"get_config",
|
||||||
]
|
]
|
||||||
@ -43,7 +45,6 @@ __all__ = [
|
|||||||
if HAS_TRITON:
|
if HAS_TRITON:
|
||||||
# import to register the custom ops
|
# import to register the custom ops
|
||||||
import vllm.model_executor.layers.fused_moe.fused_marlin_moe # noqa
|
import vllm.model_executor.layers.fused_moe.fused_marlin_moe # noqa
|
||||||
import vllm.model_executor.layers.fused_moe.fused_moe # noqa
|
|
||||||
from vllm.model_executor.layers.fused_moe.batched_deep_gemm_moe import (
|
from vllm.model_executor.layers.fused_moe.batched_deep_gemm_moe import (
|
||||||
BatchedDeepGemmExperts)
|
BatchedDeepGemmExperts)
|
||||||
from vllm.model_executor.layers.fused_moe.batched_triton_or_deep_gemm_moe import ( # noqa: E501
|
from vllm.model_executor.layers.fused_moe.batched_triton_or_deep_gemm_moe import ( # noqa: E501
|
||||||
@ -56,13 +57,12 @@ if HAS_TRITON:
|
|||||||
from vllm.model_executor.layers.fused_moe.fused_batched_moe import (
|
from vllm.model_executor.layers.fused_moe.fused_batched_moe import (
|
||||||
BatchedTritonExperts)
|
BatchedTritonExperts)
|
||||||
from vllm.model_executor.layers.fused_moe.fused_moe import (
|
from vllm.model_executor.layers.fused_moe.fused_moe import (
|
||||||
TritonExperts, fused_experts, fused_moe, fused_topk,
|
TritonExperts, fused_experts, fused_topk, get_config_file_name,
|
||||||
get_config_file_name, grouped_topk)
|
grouped_topk)
|
||||||
from vllm.model_executor.layers.fused_moe.triton_deep_gemm_moe import (
|
from vllm.model_executor.layers.fused_moe.triton_deep_gemm_moe import (
|
||||||
TritonOrDeepGemmExperts)
|
TritonOrDeepGemmExperts)
|
||||||
|
|
||||||
__all__ += [
|
__all__ += [
|
||||||
"fused_moe",
|
|
||||||
"fused_topk",
|
"fused_topk",
|
||||||
"fused_experts",
|
"fused_experts",
|
||||||
"get_config_file_name",
|
"get_config_file_name",
|
||||||
|
|||||||
@ -8,6 +8,8 @@ import torch
|
|||||||
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
|
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
|
||||||
from vllm.logger import init_logger
|
from vllm.logger import init_logger
|
||||||
from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig
|
from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig
|
||||||
|
from vllm.model_executor.layers.fused_moe.deep_gemm_utils import (
|
||||||
|
deep_gemm_block_shape)
|
||||||
from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import (
|
from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import (
|
||||||
TopKWeightAndReduceDelegate)
|
TopKWeightAndReduceDelegate)
|
||||||
from vllm.model_executor.layers.fused_moe.utils import _resize_cache
|
from vllm.model_executor.layers.fused_moe.utils import _resize_cache
|
||||||
@ -212,27 +214,20 @@ def silu_mul_fp8_quant_deep_gemm_cuda(
|
|||||||
|
|
||||||
|
|
||||||
class BatchedDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
|
class BatchedDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
|
||||||
# The Deep Gemm kernels only support block size of 128
|
|
||||||
DEEPGEMM_BLOCK_SHAPE: list[int] = [128, 128]
|
|
||||||
|
|
||||||
def __init__(self,
|
def __init__(
|
||||||
max_num_tokens: int,
|
self,
|
||||||
num_dispatchers: int,
|
max_num_tokens: int,
|
||||||
block_shape: list[int],
|
num_dispatchers: int,
|
||||||
per_act_token_quant=False):
|
quant_config: FusedMoEQuantConfig,
|
||||||
|
):
|
||||||
"""
|
"""
|
||||||
max_num_tokens: Maximum number of tokens from a DP Rank
|
max_num_tokens: Maximum number of tokens from a DP Rank
|
||||||
num_dispatchers: The number of DP dispatchers.
|
num_dispatchers: The number of DP dispatchers.
|
||||||
block_shape: Block quantization block shape.
|
quant_config: Quantization configuration
|
||||||
per_act_token_quant: Per activation token quantization flag.
|
|
||||||
"""
|
"""
|
||||||
super().__init__(
|
super().__init__(quant_config)
|
||||||
FusedMoEQuantConfig(
|
assert self.block_shape == deep_gemm_block_shape()
|
||||||
quant_dtype=torch.float8_e4m3fn,
|
|
||||||
per_act_token_quant=per_act_token_quant,
|
|
||||||
block_shape=block_shape,
|
|
||||||
))
|
|
||||||
assert self.block_shape == self.DEEPGEMM_BLOCK_SHAPE
|
|
||||||
self.max_num_tokens = max_num_tokens
|
self.max_num_tokens = max_num_tokens
|
||||||
self.num_dispatchers = num_dispatchers
|
self.num_dispatchers = num_dispatchers
|
||||||
|
|
||||||
@ -290,12 +285,7 @@ class BatchedDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
|
|||||||
activation: str,
|
activation: str,
|
||||||
global_num_experts: int,
|
global_num_experts: int,
|
||||||
expert_map: Optional[torch.Tensor],
|
expert_map: Optional[torch.Tensor],
|
||||||
w1_scale: Optional[torch.Tensor],
|
|
||||||
w2_scale: Optional[torch.Tensor],
|
|
||||||
w1_zp: Optional[torch.Tensor],
|
|
||||||
w2_zp: Optional[torch.Tensor],
|
|
||||||
a1q_scale: Optional[torch.Tensor],
|
a1q_scale: Optional[torch.Tensor],
|
||||||
a2_scale: Optional[torch.Tensor],
|
|
||||||
workspace13: torch.Tensor,
|
workspace13: torch.Tensor,
|
||||||
workspace2: torch.Tensor,
|
workspace2: torch.Tensor,
|
||||||
expert_tokens_meta: Optional[mk.ExpertTokensMetadata],
|
expert_tokens_meta: Optional[mk.ExpertTokensMetadata],
|
||||||
@ -321,11 +311,11 @@ class BatchedDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
|
|||||||
# for the M expectation of each batch, correctly setting this value
|
# for the M expectation of each batch, correctly setting this value
|
||||||
# may lead to better performance.
|
# may lead to better performance.
|
||||||
expected_m = max_num_tokens
|
expected_m = max_num_tokens
|
||||||
fp8_m_grouped_gemm_nt_masked((a1q, a1q_scale), (w1, w1_scale),
|
fp8_m_grouped_gemm_nt_masked((a1q, a1q_scale), (w1, self.w1_scale),
|
||||||
workspace1, expert_num_tokens, expected_m)
|
workspace1, expert_num_tokens, expected_m)
|
||||||
|
|
||||||
a2q, a2q_scale = silu_mul_fp8_quant_deep_gemm_cuda(
|
a2q, a2q_scale = silu_mul_fp8_quant_deep_gemm_cuda(
|
||||||
workspace1, expert_num_tokens)
|
workspace1, expert_num_tokens)
|
||||||
|
|
||||||
fp8_m_grouped_gemm_nt_masked((a2q, a2q_scale), (w2, w2_scale), output,
|
fp8_m_grouped_gemm_nt_masked((a2q, a2q_scale), (w2, self.w2_scale),
|
||||||
expert_num_tokens, expected_m)
|
output, expert_num_tokens, expected_m)
|
||||||
|
|||||||
@ -8,55 +8,37 @@ import vllm.model_executor.layers.fused_moe.modular_kernel as mk
|
|||||||
from vllm.model_executor.layers.fused_moe.batched_deep_gemm_moe import (
|
from vllm.model_executor.layers.fused_moe.batched_deep_gemm_moe import (
|
||||||
BatchedDeepGemmExperts)
|
BatchedDeepGemmExperts)
|
||||||
from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig
|
from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig
|
||||||
|
from vllm.model_executor.layers.fused_moe.deep_gemm_utils import (
|
||||||
|
deep_gemm_block_shape)
|
||||||
from vllm.model_executor.layers.fused_moe.fused_batched_moe import (
|
from vllm.model_executor.layers.fused_moe.fused_batched_moe import (
|
||||||
BatchedTritonExperts)
|
BatchedTritonExperts)
|
||||||
|
|
||||||
|
|
||||||
class BatchedTritonOrDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
|
class BatchedTritonOrDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
|
||||||
|
|
||||||
def __init__(self,
|
def __init__(
|
||||||
max_num_tokens: int,
|
self,
|
||||||
num_dispatchers: int,
|
max_num_tokens: int,
|
||||||
use_fp8_w8a8: bool = False,
|
num_dispatchers: int,
|
||||||
use_int8_w8a8: bool = False,
|
quant_config: FusedMoEQuantConfig,
|
||||||
use_int8_w8a16: bool = False,
|
allow_deep_gemm: bool = False,
|
||||||
use_int4_w4a16: bool = False,
|
):
|
||||||
block_shape: Optional[list[int]] = None,
|
super().__init__(quant_config)
|
||||||
per_act_token_quant: bool = False,
|
|
||||||
allow_deep_gemm: bool = False):
|
|
||||||
assert not use_int8_w8a8, "NYI"
|
|
||||||
assert not use_int8_w8a16, "NYI"
|
|
||||||
assert not use_int4_w4a16, "NYI"
|
|
||||||
|
|
||||||
super().__init__(
|
|
||||||
FusedMoEQuantConfig.make(
|
|
||||||
use_fp8_w8a8=use_fp8_w8a8,
|
|
||||||
use_int8_w8a8=use_int8_w8a8,
|
|
||||||
use_int8_w8a16=use_int8_w8a16,
|
|
||||||
use_int4_w4a16=use_int4_w4a16,
|
|
||||||
block_shape=block_shape,
|
|
||||||
per_act_token_quant=per_act_token_quant,
|
|
||||||
))
|
|
||||||
|
|
||||||
self.batched_triton_experts = BatchedTritonExperts(
|
self.batched_triton_experts = BatchedTritonExperts(
|
||||||
max_num_tokens=max_num_tokens,
|
max_num_tokens=max_num_tokens,
|
||||||
num_dispatchers=num_dispatchers,
|
num_dispatchers=num_dispatchers,
|
||||||
use_fp8_w8a8=use_fp8_w8a8,
|
quant_config=self.quant_config,
|
||||||
use_int8_w8a8=use_int8_w8a8,
|
|
||||||
use_int8_w8a16=use_int8_w8a16,
|
|
||||||
use_int4_w4a16=use_int4_w4a16,
|
|
||||||
per_act_token_quant=self.per_act_token_quant,
|
|
||||||
block_shape=self.block_shape,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
self.allow_deep_gemm = (allow_deep_gemm and use_fp8_w8a8
|
self.allow_deep_gemm = (allow_deep_gemm
|
||||||
and self.block_shape
|
and self.quant_config.use_fp8_w8a8 and
|
||||||
== BatchedDeepGemmExperts.DEEPGEMM_BLOCK_SHAPE)
|
self.block_shape == deep_gemm_block_shape())
|
||||||
|
|
||||||
self.batched_deep_gemm_experts = BatchedDeepGemmExperts(
|
self.batched_deep_gemm_experts = BatchedDeepGemmExperts(
|
||||||
max_num_tokens=max_num_tokens,
|
max_num_tokens=max_num_tokens,
|
||||||
num_dispatchers=num_dispatchers,
|
num_dispatchers=num_dispatchers,
|
||||||
block_shape=self.block_shape, # type: ignore[arg-type]
|
quant_config=self.quant_config,
|
||||||
) if self.allow_deep_gemm else None
|
) if self.allow_deep_gemm else None
|
||||||
|
|
||||||
assert (self.batched_deep_gemm_experts is not None
|
assert (self.batched_deep_gemm_experts is not None
|
||||||
@ -143,12 +125,7 @@ class BatchedTritonOrDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
|
|||||||
activation: str,
|
activation: str,
|
||||||
global_num_experts: int,
|
global_num_experts: int,
|
||||||
expert_map: Optional[torch.Tensor],
|
expert_map: Optional[torch.Tensor],
|
||||||
w1_scale: Optional[torch.Tensor],
|
|
||||||
w2_scale: Optional[torch.Tensor],
|
|
||||||
w1_zp: Optional[torch.Tensor],
|
|
||||||
w2_zp: Optional[torch.Tensor],
|
|
||||||
a1q_scale: Optional[torch.Tensor],
|
a1q_scale: Optional[torch.Tensor],
|
||||||
a2_scale: Optional[torch.Tensor],
|
|
||||||
workspace13: torch.Tensor,
|
workspace13: torch.Tensor,
|
||||||
workspace2: torch.Tensor,
|
workspace2: torch.Tensor,
|
||||||
expert_tokens_meta: Optional[mk.ExpertTokensMetadata],
|
expert_tokens_meta: Optional[mk.ExpertTokensMetadata],
|
||||||
@ -158,7 +135,6 @@ class BatchedTritonOrDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
|
|||||||
if self.allow_deep_gemm else self.batched_triton_experts)
|
if self.allow_deep_gemm else self.batched_triton_experts)
|
||||||
assert experts is not None
|
assert experts is not None
|
||||||
experts.apply(output, hidden_states, w1, w2, topk_weights, topk_ids,
|
experts.apply(output, hidden_states, w1, w2, topk_weights, topk_ids,
|
||||||
activation, global_num_experts, expert_map, w1_scale,
|
activation, global_num_experts, expert_map, a1q_scale,
|
||||||
w2_scale, w1_zp, w2_zp, a1q_scale, a2_scale, workspace13,
|
workspace13, workspace2, expert_tokens_meta,
|
||||||
workspace2, expert_tokens_meta,
|
|
||||||
apply_router_weight_on_input)
|
apply_router_weight_on_input)
|
||||||
|
|||||||
@ -1,103 +1,322 @@
|
|||||||
# SPDX-License-Identifier: Apache-2.0
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import Optional, Union
|
from typing import TYPE_CHECKING, Optional, Union
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from compressed_tensors.quantization import (QuantizationArgs,
|
|
||||||
QuantizationStrategy,
|
|
||||||
QuantizationType)
|
|
||||||
|
|
||||||
import vllm.envs as envs
|
import vllm.envs as envs
|
||||||
from vllm.config import ParallelConfig
|
from vllm.config import ParallelConfig
|
||||||
from vllm.distributed import get_dp_group, get_tensor_model_parallel_rank
|
from vllm.distributed import get_dp_group, get_tensor_model_parallel_rank
|
||||||
from vllm.logger import init_logger
|
from vllm.logger import init_logger
|
||||||
from vllm.model_executor.layers.quantization.base_config import (
|
from vllm.model_executor.layers.quantization.utils.quant_utils import (
|
||||||
QuantizationConfig)
|
GroupShape)
|
||||||
from vllm.utils import cdiv
|
from vllm.utils import cdiv, has_triton_kernels
|
||||||
from vllm.utils.flashinfer import has_flashinfer_cutlass_fused_moe
|
from vllm.utils.flashinfer import has_flashinfer_cutlass_fused_moe
|
||||||
|
|
||||||
|
if TYPE_CHECKING and has_triton_kernels:
|
||||||
|
from triton_kernels.matmul_ogs import PrecisionConfig
|
||||||
|
|
||||||
logger = init_logger(__name__)
|
logger = init_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
def _get_quant_config_quantization_args(
|
def _get_config_dtype_str(
|
||||||
quant_config: Optional[QuantizationConfig],
|
dtype: torch.dtype,
|
||||||
prop_name: str,
|
use_fp8_w8a8: bool = False,
|
||||||
) -> Optional[QuantizationArgs]:
|
use_int8_w8a16: bool = False,
|
||||||
if (quant_config is not None and hasattr(quant_config, 'target_scheme_map')
|
use_int4_w4a16: bool = False,
|
||||||
and "Linear" in quant_config.target_scheme_map and
|
use_mxfp4_w4a4: bool = False,
|
||||||
"input_activations" in quant_config.target_scheme_map["Linear"]):
|
) -> Optional[str]:
|
||||||
return quant_config.target_scheme_map["Linear"].get(prop_name)
|
"""
|
||||||
else:
|
Return a string used to construct the filename that contains the
|
||||||
return None
|
tuning info for a particular quantization scheme. See
|
||||||
|
try_get_optimal_moe_config in fused_moe.py.
|
||||||
|
"""
|
||||||
def get_quant_config_input_quant(
|
|
||||||
quant_config: Optional[QuantizationConfig]
|
|
||||||
) -> Optional[QuantizationArgs]:
|
|
||||||
return _get_quant_config_quantization_args(quant_config,
|
|
||||||
"input_activations")
|
|
||||||
|
|
||||||
|
|
||||||
def get_quant_config_weight_quant(
|
|
||||||
quant_config: Optional[QuantizationConfig]
|
|
||||||
) -> Optional[QuantizationArgs]:
|
|
||||||
return _get_quant_config_quantization_args(quant_config, "weights")
|
|
||||||
|
|
||||||
|
|
||||||
def get_config_quant_dtype(
|
|
||||||
use_fp8_w8a8: bool,
|
|
||||||
use_int8_w8a8: bool,
|
|
||||||
use_int8_w8a16: bool,
|
|
||||||
use_int4_w4a16: bool,
|
|
||||||
use_mxfp4_w4a4: bool,
|
|
||||||
) -> Union[None, torch.dtype, str]:
|
|
||||||
if use_fp8_w8a8:
|
if use_fp8_w8a8:
|
||||||
return torch.float8_e4m3fn
|
return "fp8_w8a8"
|
||||||
elif use_int8_w8a8:
|
elif use_int8_w8a16:
|
||||||
return torch.int8
|
return "int8_w8a16"
|
||||||
|
elif use_int4_w4a16:
|
||||||
|
return "int4_w4a16"
|
||||||
elif use_mxfp4_w4a4:
|
elif use_mxfp4_w4a4:
|
||||||
return "mxfp4"
|
return "mxfp4_w4a4"
|
||||||
|
elif dtype == torch.float:
|
||||||
|
# avoiding cases where kernel fails when float32 MoE
|
||||||
|
# use fp16/bfloat16 configs
|
||||||
|
return "float32"
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def _quant_flags_to_group_shape(
|
||||||
|
quant_dtype: Union[torch.dtype, str, None],
|
||||||
|
per_act_token_quant: bool,
|
||||||
|
per_out_ch_quant: bool,
|
||||||
|
block_shape: Optional[list[int]],
|
||||||
|
) -> tuple[Optional[GroupShape], Optional[GroupShape]]:
|
||||||
|
"""
|
||||||
|
Convert MoE quantization flags into more generic GroupShapes.
|
||||||
|
"""
|
||||||
|
a_shape: Optional[GroupShape]
|
||||||
|
w_shape: Optional[GroupShape]
|
||||||
|
if block_shape is not None:
|
||||||
|
assert not per_act_token_quant
|
||||||
|
assert not per_out_ch_quant
|
||||||
|
# TODO(bnell): this is not quite right for activations since first
|
||||||
|
# dim should be 1.
|
||||||
|
a_shape = GroupShape(row=block_shape[0], col=block_shape[1])
|
||||||
|
w_shape = GroupShape(row=block_shape[0], col=block_shape[1])
|
||||||
|
else:
|
||||||
|
w_shape = None
|
||||||
|
a_shape = None if quant_dtype is None else GroupShape.PER_TENSOR
|
||||||
|
|
||||||
|
if per_act_token_quant:
|
||||||
|
a_shape = GroupShape.PER_TOKEN
|
||||||
|
|
||||||
|
if per_out_ch_quant:
|
||||||
|
w_shape = GroupShape.PER_TOKEN
|
||||||
|
|
||||||
|
return a_shape, w_shape
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class FusedMoEQuantDesc:
|
||||||
|
"""
|
||||||
|
A quantization descriptor for fused MoE ops. This class can describe
|
||||||
|
either activations or weights.
|
||||||
|
"""
|
||||||
|
|
||||||
|
# The quantized type of this parameters. None means unquantized or
|
||||||
|
# already quantized.
|
||||||
|
# TODO (bnell): use scalar_type instead of Union.
|
||||||
|
dtype: Union[torch.dtype, str, None] = None
|
||||||
|
|
||||||
|
# A field that describes the quantization group shape, from quant_utils.py.
|
||||||
|
# * (-1, -1) for per-tensor quantization
|
||||||
|
# * (1, -1) for per-row quantization
|
||||||
|
# * (-1, 1) for per-column quantization
|
||||||
|
# * (128, 128) for 128x128 deepseek style block quantization
|
||||||
|
# * (1, 128) for deepseek style activation quantization
|
||||||
|
# (i.e. per-token-per-group)
|
||||||
|
shape: Optional[GroupShape] = None
|
||||||
|
|
||||||
|
# Quantization scales.
|
||||||
|
# TODO(bnell): maybe put PrecisionConfigs in subclass of QuantDesc?
|
||||||
|
scale: Union[torch.Tensor, "PrecisionConfig", None] = None
|
||||||
|
|
||||||
|
# Quantization alphas or gscales, used for nvfp4 types.
|
||||||
|
# TODO(bnell): put some of these in subclasses
|
||||||
|
alpha_or_gscale: Optional[torch.Tensor] = None
|
||||||
|
|
||||||
|
# Zero points for int4/int8 types
|
||||||
|
zp: Optional[torch.Tensor] = None
|
||||||
|
|
||||||
|
# Biases for GPT triton MoE
|
||||||
|
bias: Optional[torch.Tensor] = None
|
||||||
|
|
||||||
|
|
||||||
|
# TODO(bnell): have subclasses for specific moe methods?
|
||||||
|
# e.g. for specific arguments bias, precision, etc.
|
||||||
@dataclass
|
@dataclass
|
||||||
class FusedMoEQuantConfig:
|
class FusedMoEQuantConfig:
|
||||||
# The post quantization activation type.
|
"""
|
||||||
# TODO (bnell): use scalar_type instead of Union.
|
The FusedMoEQuantConfig contains all the quantization parameters for
|
||||||
quant_dtype: Union[torch.dtype, str, None] = None
|
a single FusedMoEMethodBase operation. It consists of four
|
||||||
per_act_token_quant: bool = False
|
FusedMoEQuantDescs, one for each activation and set of weights.
|
||||||
per_out_ch_quant: bool = False
|
|
||||||
block_shape: Optional[list[int]] = None
|
|
||||||
|
|
||||||
# TODO: add col major flag?
|
Each FusedMoEMethodBase must implement a get_fused_moe_quant_config
|
||||||
# add detailed quant info for input, intermediates, weights, etc?
|
method to construct a FusedMoEQuantConfig for use with that class.
|
||||||
|
|
||||||
|
FusedMoEQuant configs are only used for modular kernels, fused_experts
|
||||||
|
(from fused_moe.py), cutlass_moe_fp[48], rocm_aiter_fused_experts and
|
||||||
|
triton_kernel_moe_forward. Other MoE methods can ignore the
|
||||||
|
FusedMoEQuantConfig (for now) and hardcode it to None.
|
||||||
|
|
||||||
|
There are currently some restrictions on what can be expressed:
|
||||||
|
- Most MoE ops only support similar quantization strategies for
|
||||||
|
each parameter, e.g. both weights must have the same GroupShape
|
||||||
|
and both activations must share the same GroupShape. One exception to
|
||||||
|
this is the cutlass moe which allows per channel quantization on the
|
||||||
|
outputs. Note: this restrictions are not always rigorously checked.
|
||||||
|
- Not all fused MoE functions support all the parameters, e.g. zero points,
|
||||||
|
global scales, alphas and biases are not universally supported.
|
||||||
|
- Fully general GroupShapes are not allowed. Activations only support
|
||||||
|
per token, per tensor or K-blocked.
|
||||||
|
- Weights are not required to have a GroupShape since they have already
|
||||||
|
been quantized.
|
||||||
|
|
||||||
|
Other notes:
|
||||||
|
- PrecisionConfigs are specific to GPT OSS Triton.
|
||||||
|
- As a follow up it would probably make sense to subclass FusedMoEQuantDesc
|
||||||
|
or FusedMoEQuantConfig for particular FusedMoEMethodBase subclasses
|
||||||
|
so that only the required quantization parameters are used/stored.
|
||||||
|
"""
|
||||||
|
|
||||||
|
# TODO(bnell) make sure a1_scales/a2_scales don't interfere with chunking
|
||||||
|
_a1: FusedMoEQuantDesc
|
||||||
|
_a2: FusedMoEQuantDesc
|
||||||
|
_w1: FusedMoEQuantDesc
|
||||||
|
_w2: FusedMoEQuantDesc
|
||||||
|
|
||||||
def __post_init__(self):
|
def __post_init__(self):
|
||||||
assert (not self.per_act_token_quant
|
assert (not self.per_act_token_quant
|
||||||
or self.block_shape is None), "illegal quantization"
|
or self.block_shape is None), "illegal quantization"
|
||||||
|
|
||||||
|
#
|
||||||
|
# Convenience accessors for various properties.
|
||||||
|
#
|
||||||
|
|
||||||
|
@property
|
||||||
|
def quant_dtype(self) -> Union[torch.dtype, str, None]:
|
||||||
|
return self._a1.dtype
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def is_quantized(self) -> bool:
|
def is_quantized(self) -> bool:
|
||||||
return self.quant_dtype is not None
|
return self.quant_dtype is not None
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def is_per_act_token(self) -> bool:
|
def is_per_act_token(self) -> bool:
|
||||||
return self.per_act_token_quant
|
return self._a1.shape == GroupShape.PER_TOKEN
|
||||||
|
|
||||||
|
@property
|
||||||
|
def per_act_token_quant(self) -> bool:
|
||||||
|
return self._a1.shape == GroupShape.PER_TOKEN
|
||||||
|
|
||||||
|
@property
|
||||||
|
def per_out_ch_quant(self) -> bool:
|
||||||
|
return self._w1.shape == GroupShape.PER_TOKEN
|
||||||
|
|
||||||
|
@property
|
||||||
|
def is_per_tensor(self) -> bool:
|
||||||
|
return self._a1.shape == GroupShape.PER_TENSOR
|
||||||
|
|
||||||
|
@property
|
||||||
|
def block_shape(self) -> Optional[list[int]]:
|
||||||
|
if (self._a1.shape is not None
|
||||||
|
and self._a1.shape != GroupShape.PER_TENSOR
|
||||||
|
and self._a1.shape != GroupShape.PER_TOKEN):
|
||||||
|
return [self._a1.shape.row, self._a1.shape.col]
|
||||||
|
else:
|
||||||
|
return None
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def is_block_quantized(self) -> bool:
|
def is_block_quantized(self) -> bool:
|
||||||
return self.block_shape is not None
|
return self.block_shape is not None
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def is_per_tensor(self) -> bool:
|
def a1_scale(self) -> Optional[torch.Tensor]:
|
||||||
return not self.per_act_token_quant and self.block_shape is None
|
assert self._a1.scale is None or isinstance(self._a1.scale,
|
||||||
|
torch.Tensor)
|
||||||
|
return self._a1.scale
|
||||||
|
|
||||||
|
@property
|
||||||
|
def a1_gscale(self) -> Optional[torch.Tensor]:
|
||||||
|
return self._a1.alpha_or_gscale
|
||||||
|
|
||||||
|
@property
|
||||||
|
def a2_scale(self) -> Optional[torch.Tensor]:
|
||||||
|
assert self._a2.scale is None or isinstance(self._a2.scale,
|
||||||
|
torch.Tensor)
|
||||||
|
return self._a2.scale
|
||||||
|
|
||||||
|
@property
|
||||||
|
def a2_gscale(self) -> Optional[torch.Tensor]:
|
||||||
|
return self._a2.alpha_or_gscale
|
||||||
|
|
||||||
|
@property
|
||||||
|
def w1_scale(self) -> Optional[torch.Tensor]:
|
||||||
|
assert self._w1.scale is None or isinstance(self._w1.scale,
|
||||||
|
torch.Tensor)
|
||||||
|
return self._w1.scale
|
||||||
|
|
||||||
|
@property
|
||||||
|
def w1_zp(self) -> Optional[torch.Tensor]:
|
||||||
|
return self._w1.zp
|
||||||
|
|
||||||
|
@property
|
||||||
|
def w1_bias(self) -> Optional[torch.Tensor]:
|
||||||
|
return self._w1.bias
|
||||||
|
|
||||||
|
@property
|
||||||
|
def w1_precision(self) -> Optional["PrecisionConfig"]:
|
||||||
|
assert self._w1.scale is None or isinstance(self._w1.scale,
|
||||||
|
PrecisionConfig)
|
||||||
|
return self._w1.scale
|
||||||
|
|
||||||
|
@property
|
||||||
|
def g1_alphas(self) -> Optional[torch.Tensor]:
|
||||||
|
return self._w1.alpha_or_gscale
|
||||||
|
|
||||||
|
@property
|
||||||
|
def w2_scale(self) -> Optional[torch.Tensor]:
|
||||||
|
assert self._w2.scale is None or isinstance(self._w2.scale,
|
||||||
|
torch.Tensor)
|
||||||
|
return self._w2.scale
|
||||||
|
|
||||||
|
@property
|
||||||
|
def w2_zp(self) -> Optional[torch.Tensor]:
|
||||||
|
return self._w2.zp
|
||||||
|
|
||||||
|
@property
|
||||||
|
def w2_bias(self) -> Optional[torch.Tensor]:
|
||||||
|
return self._w2.bias
|
||||||
|
|
||||||
|
@property
|
||||||
|
def w2_precision(self) -> Optional["PrecisionConfig"]:
|
||||||
|
assert self._w2.scale is None or isinstance(self._w2.scale,
|
||||||
|
PrecisionConfig)
|
||||||
|
return self._w2.scale
|
||||||
|
|
||||||
|
@property
|
||||||
|
def g2_alphas(self) -> Optional[torch.Tensor]:
|
||||||
|
return self._w2.alpha_or_gscale
|
||||||
|
|
||||||
|
@property
|
||||||
|
def use_fp8_w8a8(self) -> bool:
|
||||||
|
return self.quant_dtype == torch.float8_e4m3fn
|
||||||
|
|
||||||
|
@property
|
||||||
|
def use_int8_w8a8(self) -> bool:
|
||||||
|
return self.quant_dtype == torch.int8
|
||||||
|
|
||||||
|
@property
|
||||||
|
def use_int8_w8a16(self) -> bool:
|
||||||
|
return (self._a1.dtype is None and self._w1.dtype == torch.int8)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def use_int4_w4a16(self) -> bool:
|
||||||
|
return (self._a1.dtype is None and self._w1.dtype == "int4")
|
||||||
|
|
||||||
|
@property
|
||||||
|
def use_mxfp4_w4a4(self) -> bool:
|
||||||
|
return self.quant_dtype == "mxfp4"
|
||||||
|
|
||||||
|
@property
|
||||||
|
def use_nvfp4_w4a4(self) -> bool:
|
||||||
|
return self.quant_dtype == "nvfp4"
|
||||||
|
|
||||||
|
def config_name(self, dtype: torch.dtype) -> Optional[str]:
|
||||||
|
"""
|
||||||
|
Return a string used to construct the filename that contains the
|
||||||
|
tuning info for a particular quantization scheme. See
|
||||||
|
try_get_optimal_moe_config in fused_moe.py.
|
||||||
|
"""
|
||||||
|
return _get_config_dtype_str(
|
||||||
|
use_fp8_w8a8=self.use_fp8_w8a8,
|
||||||
|
use_int8_w8a16=self.use_int8_w8a16,
|
||||||
|
use_int4_w4a16=self.use_int4_w4a16,
|
||||||
|
use_mxfp4_w4a4=self.use_mxfp4_w4a4,
|
||||||
|
dtype=dtype,
|
||||||
|
)
|
||||||
|
|
||||||
def scale_shape(
|
def scale_shape(
|
||||||
self,
|
self,
|
||||||
max_tokens: int,
|
max_tokens: int,
|
||||||
hidden_dim: int,
|
hidden_dim: int,
|
||||||
) -> Optional[tuple[int, int]]:
|
) -> Optional[tuple[int, int]]:
|
||||||
|
"""
|
||||||
|
Construct the proper activation scale shape for this
|
||||||
|
config.
|
||||||
|
"""
|
||||||
if self.is_quantized:
|
if self.is_quantized:
|
||||||
if self.is_block_quantized:
|
if self.is_block_quantized:
|
||||||
assert self.block_shape is not None
|
assert self.block_shape is not None
|
||||||
@ -117,6 +336,10 @@ class FusedMoEQuantConfig:
|
|||||||
max_tokens: int,
|
max_tokens: int,
|
||||||
hidden_dim: int,
|
hidden_dim: int,
|
||||||
) -> Optional[tuple[int, int, int]]:
|
) -> Optional[tuple[int, int, int]]:
|
||||||
|
"""
|
||||||
|
Construct the proper activation batched scale shape for this
|
||||||
|
config, e.g. (num experts, *scale_shape).
|
||||||
|
"""
|
||||||
if self.is_quantized:
|
if self.is_quantized:
|
||||||
scale_shape = self.scale_shape(max_tokens, hidden_dim)
|
scale_shape = self.scale_shape(max_tokens, hidden_dim)
|
||||||
assert scale_shape is not None
|
assert scale_shape is not None
|
||||||
@ -126,38 +349,218 @@ class FusedMoEQuantConfig:
|
|||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def make(
|
def make(
|
||||||
use_fp8_w8a8: bool = False,
|
quant_dtype: Union[torch.dtype, str, None] = None,
|
||||||
use_int8_w8a8: bool = False,
|
|
||||||
use_int8_w8a16: bool = False,
|
|
||||||
use_int4_w4a16: bool = False,
|
|
||||||
use_mxfp4_w4a4: bool = False,
|
|
||||||
per_act_token_quant: bool = False,
|
per_act_token_quant: bool = False,
|
||||||
per_out_ch_quant: bool = False,
|
per_out_ch_quant: bool = False,
|
||||||
block_shape: Optional[list[int]] = None,
|
block_shape: Optional[list[int]] = None,
|
||||||
|
w1_scale: Union[torch.Tensor, "PrecisionConfig", None] = None,
|
||||||
|
w2_scale: Union[torch.Tensor, "PrecisionConfig", None] = None,
|
||||||
|
a1_scale: Optional[torch.Tensor] = None,
|
||||||
|
a2_scale: Optional[torch.Tensor] = None,
|
||||||
|
g1_alphas: Optional[torch.Tensor] = None,
|
||||||
|
g2_alphas: Optional[torch.Tensor] = None,
|
||||||
|
a1_gscale: Optional[torch.Tensor] = None,
|
||||||
|
a2_gscale: Optional[torch.Tensor] = None,
|
||||||
|
w1_bias: Optional[torch.Tensor] = None,
|
||||||
|
w2_bias: Optional[torch.Tensor] = None,
|
||||||
|
w1_zp: Optional[torch.Tensor] = None,
|
||||||
|
w2_zp: Optional[torch.Tensor] = None,
|
||||||
) -> "FusedMoEQuantConfig":
|
) -> "FusedMoEQuantConfig":
|
||||||
assert sum([
|
"""
|
||||||
int(flag) for flag in [
|
General builder function for a FusedMoEQuantConfig.
|
||||||
use_fp8_w8a8,
|
- quant_dtype: Optional quantization type. None if activations are
|
||||||
use_int8_w8a8,
|
unquantized or quantized prior to calling. Note: "nvfp4" and
|
||||||
use_int8_w8a16,
|
"mxfp4" are the only valid string values for quant_dtype.
|
||||||
use_int4_w4a16,
|
- per_act_token_quant: Activations have per token quantization.
|
||||||
use_mxfp4_w4a4,
|
- per_out_ch_quant: Outputs have per channel quantization. (only
|
||||||
]
|
for cutlass).
|
||||||
]) <= 1, "Quantization flags are mutually exclusive."
|
- block_shape: Optional block size for block-wise quantization.
|
||||||
|
Incompatible with per_act_token and per_out_ch quant.
|
||||||
|
- w1_scale: Optional scale to be used for w1.
|
||||||
|
- w2_scale: Optional scale to be used for w2.
|
||||||
|
- a1_scale: Optional scale to be used for a1.
|
||||||
|
- a2_scale: Optional scale to be used for a2.
|
||||||
|
- g1_alphas: Optional global quantization scales for w1 (for nvfp4).
|
||||||
|
- g2_alphas: Optional global quantization scales for w2 (for nvfp4).
|
||||||
|
- a1_gscale: Optional global quantization scales for a1 (for nvfp4).
|
||||||
|
- a2_gscale: Optional global quantization scales for a2 (for nvfp4).
|
||||||
|
- w1_bias: Optional biases for w1 (GPT OSS Triton).
|
||||||
|
- w2_bias: Optional biases for w1 (GPT OSS Triton).
|
||||||
|
- w1_zp: Optional w1 zero points for int4/int8 quantization.
|
||||||
|
- w2_zp: Optional w2 zero points for int4/int8 quantization.
|
||||||
|
"""
|
||||||
|
assert (not isinstance(quant_dtype, str) or quant_dtype == "nvfp4"
|
||||||
|
or quant_dtype == "mxfp4")
|
||||||
|
a_shape, w_shape = _quant_flags_to_group_shape(quant_dtype,
|
||||||
|
per_act_token_quant,
|
||||||
|
per_out_ch_quant,
|
||||||
|
block_shape)
|
||||||
|
quant_config = FusedMoEQuantConfig(
|
||||||
|
_a1=FusedMoEQuantDesc(quant_dtype, a_shape, a1_scale, a1_gscale),
|
||||||
|
_a2=FusedMoEQuantDesc(quant_dtype, a_shape, a2_scale, a2_gscale),
|
||||||
|
_w1=FusedMoEQuantDesc(quant_dtype, w_shape, w1_scale, g1_alphas,
|
||||||
|
w1_zp, w1_bias),
|
||||||
|
_w2=FusedMoEQuantDesc(quant_dtype, w_shape, w2_scale, g2_alphas,
|
||||||
|
w2_zp, w2_bias),
|
||||||
|
)
|
||||||
|
assert quant_config.per_act_token_quant == per_act_token_quant
|
||||||
|
assert quant_config.per_out_ch_quant == per_out_ch_quant
|
||||||
|
assert quant_config.block_shape == block_shape
|
||||||
|
return quant_config
|
||||||
|
|
||||||
quant_dtype = get_config_quant_dtype(
|
|
||||||
use_fp8_w8a8=use_fp8_w8a8,
|
def fp8_w8a8_moe_quant_config(
|
||||||
use_int8_w8a8=use_int8_w8a8,
|
w1_scale: torch.Tensor,
|
||||||
use_int8_w8a16=use_int8_w8a16,
|
w2_scale: torch.Tensor,
|
||||||
use_int4_w4a16=use_int4_w4a16,
|
a1_scale: Optional[torch.Tensor] = None,
|
||||||
use_mxfp4_w4a4=use_mxfp4_w4a4,
|
a2_scale: Optional[torch.Tensor] = None,
|
||||||
)
|
per_act_token_quant: bool = False,
|
||||||
return FusedMoEQuantConfig(
|
per_out_ch_quant: bool = False,
|
||||||
quant_dtype,
|
block_shape: Optional[list[int]] = None,
|
||||||
per_act_token_quant,
|
) -> FusedMoEQuantConfig:
|
||||||
per_out_ch_quant,
|
"""
|
||||||
block_shape,
|
Construct a quant config for fp8 activations and fp8 weights.
|
||||||
)
|
"""
|
||||||
|
return FusedMoEQuantConfig.make(torch.float8_e4m3fn,
|
||||||
|
w1_scale=w1_scale,
|
||||||
|
w2_scale=w2_scale,
|
||||||
|
a1_scale=a1_scale,
|
||||||
|
a2_scale=a2_scale,
|
||||||
|
per_act_token_quant=per_act_token_quant,
|
||||||
|
per_out_ch_quant=per_out_ch_quant,
|
||||||
|
block_shape=block_shape)
|
||||||
|
|
||||||
|
|
||||||
|
def int8_w8a8_moe_quant_config(
|
||||||
|
w1_scale: torch.Tensor,
|
||||||
|
w2_scale: torch.Tensor,
|
||||||
|
a1_scale: Optional[torch.Tensor],
|
||||||
|
a2_scale: Optional[torch.Tensor],
|
||||||
|
per_act_token_quant: bool = False,
|
||||||
|
) -> FusedMoEQuantConfig:
|
||||||
|
"""
|
||||||
|
Construct a quant config for int8 activations and int8 weights.
|
||||||
|
"""
|
||||||
|
return FusedMoEQuantConfig.make(
|
||||||
|
torch.int8,
|
||||||
|
w1_scale=w1_scale,
|
||||||
|
w2_scale=w2_scale,
|
||||||
|
a1_scale=a1_scale,
|
||||||
|
a2_scale=a2_scale,
|
||||||
|
per_act_token_quant=per_act_token_quant,
|
||||||
|
per_out_ch_quant=False,
|
||||||
|
block_shape=None,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def mxfp4_w4a4_moe_quant_config(
|
||||||
|
w1_scale: Union[torch.Tensor, "PrecisionConfig"],
|
||||||
|
w2_scale: Union[torch.Tensor, "PrecisionConfig"],
|
||||||
|
a1_scale: Optional[torch.Tensor] = None,
|
||||||
|
a2_scale: Optional[torch.Tensor] = None,
|
||||||
|
w1_bias: Optional[torch.Tensor] = None,
|
||||||
|
w2_bias: Optional[torch.Tensor] = None,
|
||||||
|
block_shape: Optional[list[int]] = None,
|
||||||
|
) -> FusedMoEQuantConfig:
|
||||||
|
"""
|
||||||
|
Construct a quant config for mxfp4 activations and mxfp4 weights.
|
||||||
|
"""
|
||||||
|
return FusedMoEQuantConfig.make(
|
||||||
|
"mxfp4",
|
||||||
|
w1_scale=w1_scale,
|
||||||
|
w2_scale=w2_scale,
|
||||||
|
a1_scale=a1_scale,
|
||||||
|
a2_scale=a2_scale,
|
||||||
|
w1_bias=w1_bias,
|
||||||
|
w2_bias=w2_bias,
|
||||||
|
per_act_token_quant=False,
|
||||||
|
per_out_ch_quant=False,
|
||||||
|
block_shape=block_shape,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def nvfp4_moe_quant_config(
|
||||||
|
g1_alphas: torch.Tensor,
|
||||||
|
g2_alphas: torch.Tensor,
|
||||||
|
a1_gscale: torch.Tensor,
|
||||||
|
a2_gscale: torch.Tensor,
|
||||||
|
w1_scale: torch.Tensor,
|
||||||
|
w2_scale: torch.Tensor,
|
||||||
|
) -> FusedMoEQuantConfig:
|
||||||
|
"""
|
||||||
|
Construct a quant config for mxfp4 activations and nvp4 weights.
|
||||||
|
"""
|
||||||
|
return FusedMoEQuantConfig.make(
|
||||||
|
"nvfp4",
|
||||||
|
w1_scale=w1_scale,
|
||||||
|
w2_scale=w2_scale,
|
||||||
|
a1_gscale=a1_gscale,
|
||||||
|
a2_gscale=a2_gscale,
|
||||||
|
g1_alphas=g1_alphas,
|
||||||
|
g2_alphas=g2_alphas,
|
||||||
|
per_act_token_quant=False,
|
||||||
|
per_out_ch_quant=False,
|
||||||
|
block_shape=None,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def int4_w4a16_moe_quant_config(
|
||||||
|
w1_scale: torch.Tensor,
|
||||||
|
w2_scale: torch.Tensor,
|
||||||
|
w1_zp: Optional[torch.Tensor],
|
||||||
|
w2_zp: Optional[torch.Tensor],
|
||||||
|
block_shape: Optional[list[int]] = None,
|
||||||
|
) -> FusedMoEQuantConfig:
|
||||||
|
"""
|
||||||
|
Construct a quant config for 16-bit float activations and int4 weights.
|
||||||
|
Note: Activations are pre-quantized.
|
||||||
|
"""
|
||||||
|
group_shape = GroupShape(*block_shape) if block_shape is not None else None
|
||||||
|
return FusedMoEQuantConfig(
|
||||||
|
_a1=FusedMoEQuantDesc(shape=group_shape),
|
||||||
|
_a2=FusedMoEQuantDesc(shape=group_shape),
|
||||||
|
_w1=FusedMoEQuantDesc("int4", group_shape, w1_scale, None, w1_zp),
|
||||||
|
_w2=FusedMoEQuantDesc("int4", group_shape, w2_scale, None, w2_zp),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def int8_w8a16_moe_quant_config(
|
||||||
|
w1_scale: torch.Tensor,
|
||||||
|
w2_scale: torch.Tensor,
|
||||||
|
w1_zp: Optional[torch.Tensor],
|
||||||
|
w2_zp: Optional[torch.Tensor],
|
||||||
|
block_shape: Optional[list[int]] = None,
|
||||||
|
) -> FusedMoEQuantConfig:
|
||||||
|
"""
|
||||||
|
Construct a quant config for 16-bit float activations and int8 weights.
|
||||||
|
Note: Activations are pre-quantized.
|
||||||
|
"""
|
||||||
|
group_shape = GroupShape(*block_shape) if block_shape is not None else None
|
||||||
|
return FusedMoEQuantConfig(
|
||||||
|
_a1=FusedMoEQuantDesc(shape=group_shape),
|
||||||
|
_a2=FusedMoEQuantDesc(shape=group_shape),
|
||||||
|
_w1=FusedMoEQuantDesc(torch.int8, group_shape, w1_scale, None, w1_zp),
|
||||||
|
_w2=FusedMoEQuantDesc(torch.int8, group_shape, w2_scale, None, w2_zp),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def biased_moe_quant_config(
|
||||||
|
w1_bias: Optional[torch.Tensor],
|
||||||
|
w2_bias: Optional[torch.Tensor],
|
||||||
|
) -> FusedMoEQuantConfig:
|
||||||
|
"""
|
||||||
|
Construct a quant config for unquantized activations with biases.
|
||||||
|
"""
|
||||||
|
return FusedMoEQuantConfig(
|
||||||
|
_a1=FusedMoEQuantDesc(),
|
||||||
|
_a2=FusedMoEQuantDesc(),
|
||||||
|
_w1=FusedMoEQuantDesc(bias=w1_bias),
|
||||||
|
_w2=FusedMoEQuantDesc(bias=w2_bias),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# A FusedMoEQuantConfig constant for an unquantized MoE op.
|
||||||
|
FUSED_MOE_UNQUANTIZED_CONFIG: FusedMoEQuantConfig = FusedMoEQuantConfig.make()
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
@ -315,8 +718,6 @@ class FusedMoEConfig:
|
|||||||
# The activation type.
|
# The activation type.
|
||||||
in_dtype: torch.dtype
|
in_dtype: torch.dtype
|
||||||
|
|
||||||
quant_config: Optional[FusedMoEQuantConfig] = None
|
|
||||||
|
|
||||||
max_num_tokens: int = envs.VLLM_MOE_DP_CHUNK_SIZE
|
max_num_tokens: int = envs.VLLM_MOE_DP_CHUNK_SIZE
|
||||||
|
|
||||||
has_bias: bool = False
|
has_bias: bool = False
|
||||||
@ -328,34 +729,6 @@ class FusedMoEConfig:
|
|||||||
|
|
||||||
assert self.max_num_tokens > 0
|
assert self.max_num_tokens > 0
|
||||||
|
|
||||||
@property
|
|
||||||
def quant_dtype(self) -> Union[torch.dtype, str, None]:
|
|
||||||
if self.quant_config is not None:
|
|
||||||
return self.quant_config.quant_dtype
|
|
||||||
else:
|
|
||||||
return None
|
|
||||||
|
|
||||||
@property
|
|
||||||
def block_shape(self) -> Optional[list[int]]:
|
|
||||||
if self.quant_config is not None:
|
|
||||||
return self.quant_config.block_shape
|
|
||||||
else:
|
|
||||||
return None
|
|
||||||
|
|
||||||
@property
|
|
||||||
def per_act_token_quant(self) -> bool:
|
|
||||||
if self.quant_config is not None:
|
|
||||||
return self.quant_config.per_act_token_quant
|
|
||||||
else:
|
|
||||||
return False
|
|
||||||
|
|
||||||
@property
|
|
||||||
def per_out_ch_quant(self) -> bool:
|
|
||||||
if self.quant_config is not None:
|
|
||||||
return self.quant_config.per_out_ch_quant
|
|
||||||
else:
|
|
||||||
return False
|
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def tp_size(self):
|
def tp_size(self):
|
||||||
return self.moe_parallel_config.tp_size
|
return self.moe_parallel_config.tp_size
|
||||||
@ -401,97 +774,6 @@ class FusedMoEConfig:
|
|||||||
"""
|
"""
|
||||||
Whether to use FlashInfer cutlass kernels for NVFP4 MoE.
|
Whether to use FlashInfer cutlass kernels for NVFP4 MoE.
|
||||||
"""
|
"""
|
||||||
return (self.quant_config is not None
|
return (envs.VLLM_USE_FLASHINFER_MOE_FP4
|
||||||
and self.quant_config.quant_dtype == "nvfp4"
|
|
||||||
and envs.VLLM_USE_FLASHINFER_MOE_FP4
|
|
||||||
and has_flashinfer_cutlass_fused_moe()
|
and has_flashinfer_cutlass_fused_moe()
|
||||||
and envs.VLLM_FLASHINFER_MOE_BACKEND == "throughput")
|
and envs.VLLM_FLASHINFER_MOE_BACKEND == "throughput")
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def make(
|
|
||||||
num_experts: int,
|
|
||||||
experts_per_token: int,
|
|
||||||
hidden_dim: int,
|
|
||||||
num_local_experts: int,
|
|
||||||
moe_parallel_config: FusedMoEParallelConfig,
|
|
||||||
in_dtype: torch.dtype,
|
|
||||||
max_num_tokens: int = envs.VLLM_MOE_DP_CHUNK_SIZE,
|
|
||||||
quant_config: Optional[Union[FusedMoEQuantConfig,
|
|
||||||
QuantizationConfig]] = None,
|
|
||||||
has_bias: bool = False,
|
|
||||||
) -> "FusedMoEConfig":
|
|
||||||
|
|
||||||
_quant_config: Optional[FusedMoEQuantConfig] = None
|
|
||||||
|
|
||||||
if quant_config is not None and isinstance(quant_config,
|
|
||||||
QuantizationConfig):
|
|
||||||
if hasattr(quant_config, 'weight_block_size'):
|
|
||||||
block_shape = quant_config.weight_block_size
|
|
||||||
else:
|
|
||||||
block_shape = None
|
|
||||||
per_act_token_quant = False
|
|
||||||
per_out_ch_quant = False
|
|
||||||
quant_dtype: Union[torch.dtype, str, None] = None
|
|
||||||
|
|
||||||
input_quant = get_quant_config_input_quant(quant_config)
|
|
||||||
weight_quant = get_quant_config_weight_quant(quant_config)
|
|
||||||
|
|
||||||
if input_quant is not None:
|
|
||||||
per_act_token_quant = (input_quant.strategy
|
|
||||||
== QuantizationStrategy.TOKEN
|
|
||||||
if input_quant is not None else False)
|
|
||||||
|
|
||||||
if input_quant.num_bits == 8:
|
|
||||||
if input_quant.type == QuantizationType.FLOAT:
|
|
||||||
quant_dtype = torch.float8_e4m3fn
|
|
||||||
elif input_quant.type == QuantizationType.INT:
|
|
||||||
quant_dtype = torch.int8
|
|
||||||
|
|
||||||
from vllm.model_executor.layers.quantization.fp8 import Fp8Config
|
|
||||||
if quant_dtype is None and isinstance(quant_config, Fp8Config):
|
|
||||||
quant_dtype = torch.float8_e4m3fn
|
|
||||||
|
|
||||||
from vllm.model_executor.layers.quantization.mxfp4 import (
|
|
||||||
Mxfp4Config)
|
|
||||||
if (quant_dtype is None and isinstance(quant_config, Mxfp4Config)
|
|
||||||
and envs.VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8):
|
|
||||||
quant_dtype = "mxfp8"
|
|
||||||
|
|
||||||
from vllm.model_executor.layers.quantization.modelopt import (
|
|
||||||
ModelOptNvFp4Config)
|
|
||||||
if quant_dtype is None and isinstance(quant_config,
|
|
||||||
ModelOptNvFp4Config):
|
|
||||||
quant_dtype = "nvfp4"
|
|
||||||
|
|
||||||
if weight_quant is not None:
|
|
||||||
per_out_ch_quant = (
|
|
||||||
weight_quant.strategy == QuantizationStrategy.CHANNEL)
|
|
||||||
|
|
||||||
if quant_dtype is not None:
|
|
||||||
_quant_config = FusedMoEQuantConfig(
|
|
||||||
quant_dtype=quant_dtype,
|
|
||||||
per_act_token_quant=per_act_token_quant,
|
|
||||||
per_out_ch_quant=per_out_ch_quant,
|
|
||||||
block_shape=block_shape,
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
_quant_config = FusedMoEQuantConfig()
|
|
||||||
if moe_parallel_config.dp_size > 1:
|
|
||||||
logger.warning_once("MoE DP setup unable to determine "
|
|
||||||
"quantization scheme or unsupported "
|
|
||||||
"quantization type. This model will "
|
|
||||||
"not run with DP enabled.")
|
|
||||||
else:
|
|
||||||
_quant_config = quant_config
|
|
||||||
|
|
||||||
return FusedMoEConfig(
|
|
||||||
num_experts=num_experts,
|
|
||||||
experts_per_token=experts_per_token,
|
|
||||||
hidden_dim=hidden_dim,
|
|
||||||
num_local_experts=num_local_experts,
|
|
||||||
moe_parallel_config=moe_parallel_config,
|
|
||||||
in_dtype=in_dtype,
|
|
||||||
quant_config=_quant_config,
|
|
||||||
max_num_tokens=max_num_tokens,
|
|
||||||
has_bias=has_bias,
|
|
||||||
)
|
|
||||||
|
|||||||
@ -211,21 +211,14 @@ class CutlassExpertsFp8Base(mk.FusedMoEPermuteExpertsUnpermute):
|
|||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
out_dtype: Optional[torch.dtype],
|
out_dtype: Optional[torch.dtype],
|
||||||
per_act_token_quant: bool,
|
|
||||||
per_out_ch_quant: bool,
|
|
||||||
ab_strides1: torch.Tensor,
|
ab_strides1: torch.Tensor,
|
||||||
ab_strides2: torch.Tensor,
|
ab_strides2: torch.Tensor,
|
||||||
c_strides1: torch.Tensor,
|
c_strides1: torch.Tensor,
|
||||||
c_strides2: torch.Tensor,
|
c_strides2: torch.Tensor,
|
||||||
block_shape: Optional[list[int]] = None,
|
quant_config: FusedMoEQuantConfig,
|
||||||
):
|
):
|
||||||
super().__init__(
|
assert quant_config.use_fp8_w8a8
|
||||||
FusedMoEQuantConfig(
|
super().__init__(quant_config)
|
||||||
quant_dtype=torch.float8_e4m3fn,
|
|
||||||
per_act_token_quant=per_act_token_quant,
|
|
||||||
per_out_ch_quant=per_out_ch_quant,
|
|
||||||
block_shape=block_shape,
|
|
||||||
))
|
|
||||||
self.out_dtype = out_dtype
|
self.out_dtype = out_dtype
|
||||||
self.ab_strides1 = ab_strides1
|
self.ab_strides1 = ab_strides1
|
||||||
self.ab_strides2 = ab_strides2
|
self.ab_strides2 = ab_strides2
|
||||||
@ -247,19 +240,14 @@ class CutlassExpertsFp8Base(mk.FusedMoEPermuteExpertsUnpermute):
|
|||||||
activation: str,
|
activation: str,
|
||||||
global_num_experts: int,
|
global_num_experts: int,
|
||||||
expert_map: Optional[torch.Tensor],
|
expert_map: Optional[torch.Tensor],
|
||||||
w1_scale: Optional[torch.Tensor],
|
|
||||||
w2_scale: Optional[torch.Tensor],
|
|
||||||
w1_zp: Optional[torch.Tensor],
|
|
||||||
w2_zp: Optional[torch.Tensor],
|
|
||||||
a1q_scale: Optional[torch.Tensor],
|
a1q_scale: Optional[torch.Tensor],
|
||||||
a2_scale: Optional[torch.Tensor],
|
|
||||||
workspace13: torch.Tensor,
|
workspace13: torch.Tensor,
|
||||||
workspace2: torch.Tensor,
|
workspace2: torch.Tensor,
|
||||||
expert_tokens_meta: Optional[mk.ExpertTokensMetadata],
|
expert_tokens_meta: Optional[mk.ExpertTokensMetadata],
|
||||||
apply_router_weight_on_input: bool,
|
apply_router_weight_on_input: bool,
|
||||||
):
|
):
|
||||||
assert w1_zp is None, "w1_zp is not supported in CUTLASS MoE"
|
assert self.w1_zp is None, "w1_zp is not supported in CUTLASS MoE"
|
||||||
assert w2_zp is None, "w2_zp is not supported in CUTLASS MoE"
|
assert self.w2_zp is None, "w2_zp is not supported in CUTLASS MoE"
|
||||||
|
|
||||||
expert_num_tokens = None
|
expert_num_tokens = None
|
||||||
if expert_tokens_meta is not None:
|
if expert_tokens_meta is not None:
|
||||||
@ -273,9 +261,10 @@ class CutlassExpertsFp8Base(mk.FusedMoEPermuteExpertsUnpermute):
|
|||||||
in_dtype = hidden_states.dtype
|
in_dtype = hidden_states.dtype
|
||||||
run_cutlass_moe_fp8(
|
run_cutlass_moe_fp8(
|
||||||
output, hidden_states, w1, w2, topk_ids, activation_callable,
|
output, hidden_states, w1, w2, topk_ids, activation_callable,
|
||||||
global_num_experts, expert_map, w1_scale, w2_scale, a1q_scale,
|
global_num_experts, expert_map, self.w1_scale, self.w2_scale,
|
||||||
a2_scale, self.ab_strides1, self.ab_strides2, self.c_strides1,
|
a1q_scale, self.a2_scale, self.ab_strides1, self.ab_strides2,
|
||||||
self.c_strides2, workspace13, workspace2, expert_num_tokens,
|
self.c_strides1, self.c_strides2, workspace13, workspace2,
|
||||||
|
expert_num_tokens,
|
||||||
self.out_dtype if self.out_dtype is not None else in_dtype,
|
self.out_dtype if self.out_dtype is not None else in_dtype,
|
||||||
self.per_act_token_quant, self.per_out_ch_quant,
|
self.per_act_token_quant, self.per_out_ch_quant,
|
||||||
use_batched_format, topk_weights)
|
use_batched_format, topk_weights)
|
||||||
@ -286,23 +275,19 @@ class CutlassExpertsFp8(CutlassExpertsFp8Base):
|
|||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
out_dtype: Optional[torch.dtype],
|
out_dtype: Optional[torch.dtype],
|
||||||
per_act_token_quant: bool,
|
|
||||||
per_out_ch_quant: bool,
|
|
||||||
ab_strides1: torch.Tensor,
|
ab_strides1: torch.Tensor,
|
||||||
ab_strides2: torch.Tensor,
|
ab_strides2: torch.Tensor,
|
||||||
c_strides1: torch.Tensor,
|
c_strides1: torch.Tensor,
|
||||||
c_strides2: torch.Tensor,
|
c_strides2: torch.Tensor,
|
||||||
block_shape: Optional[list[int]] = None,
|
quant_config: FusedMoEQuantConfig,
|
||||||
):
|
):
|
||||||
super().__init__(
|
super().__init__(
|
||||||
out_dtype,
|
out_dtype,
|
||||||
per_act_token_quant,
|
|
||||||
per_out_ch_quant,
|
|
||||||
ab_strides1,
|
ab_strides1,
|
||||||
ab_strides2,
|
ab_strides2,
|
||||||
c_strides1,
|
c_strides1,
|
||||||
c_strides2,
|
c_strides2,
|
||||||
block_shape,
|
quant_config,
|
||||||
)
|
)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
@ -348,23 +333,19 @@ class CutlassBatchedExpertsFp8(CutlassExpertsFp8Base):
|
|||||||
max_experts_per_worker: int,
|
max_experts_per_worker: int,
|
||||||
num_dispatchers: int,
|
num_dispatchers: int,
|
||||||
out_dtype: Optional[torch.dtype],
|
out_dtype: Optional[torch.dtype],
|
||||||
per_act_token_quant: bool,
|
|
||||||
per_out_ch_quant: bool,
|
|
||||||
ab_strides1: torch.Tensor,
|
ab_strides1: torch.Tensor,
|
||||||
ab_strides2: torch.Tensor,
|
ab_strides2: torch.Tensor,
|
||||||
c_strides1: torch.Tensor,
|
c_strides1: torch.Tensor,
|
||||||
c_strides2: torch.Tensor,
|
c_strides2: torch.Tensor,
|
||||||
block_shape: Optional[list[int]] = None,
|
quant_config: FusedMoEQuantConfig,
|
||||||
):
|
):
|
||||||
super().__init__(
|
super().__init__(
|
||||||
out_dtype,
|
out_dtype,
|
||||||
per_act_token_quant,
|
|
||||||
per_out_ch_quant,
|
|
||||||
ab_strides1,
|
ab_strides1,
|
||||||
ab_strides2,
|
ab_strides2,
|
||||||
c_strides1,
|
c_strides1,
|
||||||
c_strides2,
|
c_strides2,
|
||||||
block_shape,
|
quant_config,
|
||||||
)
|
)
|
||||||
assert max_experts_per_worker > 0
|
assert max_experts_per_worker > 0
|
||||||
self.max_experts_per_worker = max_experts_per_worker
|
self.max_experts_per_worker = max_experts_per_worker
|
||||||
@ -414,16 +395,12 @@ def cutlass_moe_fp8(
|
|||||||
w2_q: torch.Tensor,
|
w2_q: torch.Tensor,
|
||||||
topk_weights: torch.Tensor,
|
topk_weights: torch.Tensor,
|
||||||
topk_ids: torch.Tensor,
|
topk_ids: torch.Tensor,
|
||||||
w1_scale: torch.Tensor,
|
|
||||||
w2_scale: torch.Tensor,
|
|
||||||
ab_strides1: torch.Tensor,
|
ab_strides1: torch.Tensor,
|
||||||
ab_strides2: torch.Tensor,
|
ab_strides2: torch.Tensor,
|
||||||
c_strides1: torch.Tensor,
|
c_strides1: torch.Tensor,
|
||||||
c_strides2: torch.Tensor,
|
c_strides2: torch.Tensor,
|
||||||
per_act_token: Optional[bool] = None,
|
quant_config: FusedMoEQuantConfig,
|
||||||
activation: str = "silu",
|
activation: str = "silu",
|
||||||
a1_scale: Optional[torch.Tensor] = None,
|
|
||||||
a2_scale: Optional[torch.Tensor] = None,
|
|
||||||
expert_map: Optional[torch.Tensor] = None,
|
expert_map: Optional[torch.Tensor] = None,
|
||||||
apply_router_weight_on_input: bool = False,
|
apply_router_weight_on_input: bool = False,
|
||||||
global_num_experts: int = -1,
|
global_num_experts: int = -1,
|
||||||
@ -475,10 +452,18 @@ def cutlass_moe_fp8(
|
|||||||
Returns:
|
Returns:
|
||||||
- torch.Tensor: The fp16 output tensor after applying the MoE layer.
|
- torch.Tensor: The fp16 output tensor after applying the MoE layer.
|
||||||
"""
|
"""
|
||||||
if per_act_token is None:
|
assert quant_config is not None
|
||||||
per_act_token = a1_scale.numel() != 1 if a1_scale is not None else (
|
|
||||||
a2_scale.numel() != 1 if a2_scale is not None else False)
|
if quant_config.a1_scale is not None:
|
||||||
per_out_ch = w1_scale.numel() != w1_q.size(0)
|
assert (quant_config.per_act_token_quant ==
|
||||||
|
quant_config.a1_scale.numel() != 1)
|
||||||
|
if quant_config.a2_scale is not None:
|
||||||
|
assert (quant_config.per_act_token_quant ==
|
||||||
|
quant_config.a2_scale.numel() != 1)
|
||||||
|
|
||||||
|
assert (quant_config.w1_scale is None
|
||||||
|
or (quant_config.per_out_ch_quant == (quant_config.w1_scale.size(1)
|
||||||
|
== w1_q.size(1))))
|
||||||
|
|
||||||
num_experts = global_num_experts if global_num_experts != -1 else w1_q.size(
|
num_experts = global_num_experts if global_num_experts != -1 else w1_q.size(
|
||||||
0)
|
0)
|
||||||
@ -487,12 +472,11 @@ def cutlass_moe_fp8(
|
|||||||
MoEPrepareAndFinalizeNoEP(),
|
MoEPrepareAndFinalizeNoEP(),
|
||||||
CutlassExpertsFp8(
|
CutlassExpertsFp8(
|
||||||
out_dtype=a.dtype,
|
out_dtype=a.dtype,
|
||||||
per_act_token_quant=per_act_token,
|
|
||||||
per_out_ch_quant=per_out_ch,
|
|
||||||
ab_strides1=ab_strides1,
|
ab_strides1=ab_strides1,
|
||||||
ab_strides2=ab_strides2,
|
ab_strides2=ab_strides2,
|
||||||
c_strides1=c_strides1,
|
c_strides1=c_strides1,
|
||||||
c_strides2=c_strides2,
|
c_strides2=c_strides2,
|
||||||
|
quant_config=quant_config,
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -502,14 +486,9 @@ def cutlass_moe_fp8(
|
|||||||
w2_q,
|
w2_q,
|
||||||
topk_weights,
|
topk_weights,
|
||||||
topk_ids,
|
topk_ids,
|
||||||
False,
|
activation=activation,
|
||||||
activation,
|
global_num_experts=num_experts,
|
||||||
num_experts,
|
expert_map=expert_map,
|
||||||
expert_map,
|
|
||||||
w1_scale,
|
|
||||||
w2_scale,
|
|
||||||
a1_scale=a1_scale,
|
|
||||||
a2_scale=a2_scale,
|
|
||||||
apply_router_weight_on_input=apply_router_weight_on_input,
|
apply_router_weight_on_input=apply_router_weight_on_input,
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -652,42 +631,21 @@ def run_cutlass_moe_fp4(
|
|||||||
return
|
return
|
||||||
|
|
||||||
|
|
||||||
|
# Split into batched and non-batched
|
||||||
class CutlassExpertsFp4(mk.FusedMoEPermuteExpertsUnpermute):
|
class CutlassExpertsFp4(mk.FusedMoEPermuteExpertsUnpermute):
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
g1_alphas: torch.Tensor,
|
|
||||||
g2_alphas: torch.Tensor,
|
|
||||||
a1_gscale: torch.Tensor,
|
|
||||||
a2_gscale: torch.Tensor,
|
|
||||||
max_experts_per_worker: int,
|
max_experts_per_worker: int,
|
||||||
out_dtype: torch.dtype,
|
out_dtype: torch.dtype,
|
||||||
per_act_token_quant: bool,
|
quant_config: FusedMoEQuantConfig,
|
||||||
per_out_ch_quant: bool,
|
|
||||||
block_shape: Optional[list[int]] = None,
|
|
||||||
use_batched_format: bool = False,
|
use_batched_format: bool = False,
|
||||||
):
|
):
|
||||||
super().__init__(
|
super().__init__(quant_config)
|
||||||
# NVFP4 requires two levels of quantization, which involves
|
|
||||||
# computing some scaling factors dynamically. This makes it
|
|
||||||
# incompatible with the typical prepare -> MoE -> finalize
|
|
||||||
# pipeline. Move the quantization logic into the MoE body.
|
|
||||||
FusedMoEQuantConfig(
|
|
||||||
quant_dtype=None, # skip quantization in prepare/finalize
|
|
||||||
per_act_token_quant=per_act_token_quant,
|
|
||||||
per_out_ch_quant=per_out_ch_quant,
|
|
||||||
block_shape=block_shape,
|
|
||||||
))
|
|
||||||
self.max_experts_per_worker = max_experts_per_worker
|
self.max_experts_per_worker = max_experts_per_worker
|
||||||
self.out_dtype = out_dtype
|
self.out_dtype = out_dtype
|
||||||
self.use_batched_format = use_batched_format
|
self.use_batched_format = use_batched_format
|
||||||
|
|
||||||
# TODO(bnell): put this stuff into quant config?
|
|
||||||
self.g1_alphas = g1_alphas
|
|
||||||
self.g2_alphas = g2_alphas
|
|
||||||
self.a1_gscale = a1_gscale
|
|
||||||
self.a2_gscale = a2_gscale
|
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def activation_formats(
|
def activation_formats(
|
||||||
self
|
self
|
||||||
@ -746,12 +704,7 @@ class CutlassExpertsFp4(mk.FusedMoEPermuteExpertsUnpermute):
|
|||||||
activation: str,
|
activation: str,
|
||||||
global_num_experts: int,
|
global_num_experts: int,
|
||||||
expert_map: Optional[torch.Tensor],
|
expert_map: Optional[torch.Tensor],
|
||||||
w1_scale: torch.Tensor,
|
a1q_scale: Optional[torch.Tensor], # unused
|
||||||
w2_scale: torch.Tensor,
|
|
||||||
w1_zp: Optional[torch.Tensor],
|
|
||||||
w2_zp: Optional[torch.Tensor],
|
|
||||||
a1q_scale: Optional[torch.Tensor],
|
|
||||||
a2_scale: torch.Tensor,
|
|
||||||
workspace13: Optional[torch.Tensor],
|
workspace13: Optional[torch.Tensor],
|
||||||
workspace2: Optional[torch.Tensor],
|
workspace2: Optional[torch.Tensor],
|
||||||
expert_tokens_meta: Optional[mk.ExpertTokensMetadata],
|
expert_tokens_meta: Optional[mk.ExpertTokensMetadata],
|
||||||
@ -765,11 +718,11 @@ class CutlassExpertsFp4(mk.FusedMoEPermuteExpertsUnpermute):
|
|||||||
a=hidden_states,
|
a=hidden_states,
|
||||||
a1_gscale=self.a1_gscale,
|
a1_gscale=self.a1_gscale,
|
||||||
w1_fp4=w1,
|
w1_fp4=w1,
|
||||||
w1_blockscale=w1_scale,
|
w1_blockscale=self.w1_scale,
|
||||||
w1_alphas=self.g1_alphas,
|
w1_alphas=self.g1_alphas,
|
||||||
a2_gscale=self.a2_gscale,
|
a2_gscale=self.a2_gscale,
|
||||||
w2_fp4=w2,
|
w2_fp4=w2,
|
||||||
w2_blockscale=w2_scale,
|
w2_blockscale=self.w2_scale,
|
||||||
w2_alphas=self.g2_alphas,
|
w2_alphas=self.g2_alphas,
|
||||||
topk_weights=topk_weights,
|
topk_weights=topk_weights,
|
||||||
topk_ids=topk_ids,
|
topk_ids=topk_ids,
|
||||||
@ -788,14 +741,9 @@ def cutlass_moe_fp4(
|
|||||||
a: torch.Tensor,
|
a: torch.Tensor,
|
||||||
w1_fp4: torch.Tensor,
|
w1_fp4: torch.Tensor,
|
||||||
w2_fp4: torch.Tensor,
|
w2_fp4: torch.Tensor,
|
||||||
w1_blockscale: torch.Tensor,
|
|
||||||
w2_blockscale: torch.Tensor,
|
|
||||||
g1_alphas: torch.Tensor,
|
|
||||||
g2_alphas: torch.Tensor,
|
|
||||||
a1_gscale: torch.Tensor,
|
|
||||||
a2_gscale: torch.Tensor,
|
|
||||||
topk_weights: torch.Tensor,
|
topk_weights: torch.Tensor,
|
||||||
topk_ids: torch.Tensor,
|
topk_ids: torch.Tensor,
|
||||||
|
quant_config: FusedMoEQuantConfig,
|
||||||
m: int,
|
m: int,
|
||||||
n: int,
|
n: int,
|
||||||
k: int,
|
k: int,
|
||||||
@ -805,17 +753,31 @@ def cutlass_moe_fp4(
|
|||||||
assert expert_map is None, ("Expert Parallelism / expert_map "
|
assert expert_map is None, ("Expert Parallelism / expert_map "
|
||||||
"is currently not supported for "
|
"is currently not supported for "
|
||||||
"ModelOptNvFp4FusedMoE's cutlass_moe_fp4.")
|
"ModelOptNvFp4FusedMoE's cutlass_moe_fp4.")
|
||||||
|
|
||||||
|
# TODO(bnell): this feels a bit hacky
|
||||||
|
# NVFP4 requires two levels of quantization, which involves
|
||||||
|
# computing some scaling factors dynamically. This makes it
|
||||||
|
# incompatible with the typical prepare -> MoE -> finalize
|
||||||
|
# pipeline. Move the quantization logic into the MoE body.
|
||||||
|
quant_config = FusedMoEQuantConfig.make(
|
||||||
|
quant_dtype=None, # skip quantization in prepare/finalize
|
||||||
|
per_act_token_quant=quant_config.per_act_token_quant,
|
||||||
|
per_out_ch_quant=quant_config.per_out_ch_quant,
|
||||||
|
block_shape=quant_config.block_shape,
|
||||||
|
g1_alphas=quant_config.g1_alphas,
|
||||||
|
g2_alphas=quant_config.g2_alphas,
|
||||||
|
a1_gscale=quant_config.a1_gscale,
|
||||||
|
a2_gscale=quant_config.a2_gscale,
|
||||||
|
w1_scale=quant_config.w1_scale,
|
||||||
|
w2_scale=quant_config.w2_scale,
|
||||||
|
)
|
||||||
|
|
||||||
fn = mk.FusedMoEModularKernel(
|
fn = mk.FusedMoEModularKernel(
|
||||||
MoEPrepareAndFinalizeNoEP(),
|
MoEPrepareAndFinalizeNoEP(),
|
||||||
CutlassExpertsFp4(
|
CutlassExpertsFp4(
|
||||||
g1_alphas,
|
|
||||||
g2_alphas,
|
|
||||||
a1_gscale,
|
|
||||||
a2_gscale,
|
|
||||||
max_experts_per_worker=e,
|
max_experts_per_worker=e,
|
||||||
out_dtype=a.dtype,
|
out_dtype=a.dtype,
|
||||||
per_act_token_quant=False,
|
quant_config=quant_config,
|
||||||
per_out_ch_quant=False,
|
|
||||||
use_batched_format=False,
|
use_batched_format=False,
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
@ -830,10 +792,6 @@ def cutlass_moe_fp4(
|
|||||||
activation="silu",
|
activation="silu",
|
||||||
global_num_experts=e,
|
global_num_experts=e,
|
||||||
expert_map=None,
|
expert_map=None,
|
||||||
w1_scale=w1_blockscale,
|
|
||||||
w2_scale=w2_blockscale,
|
|
||||||
a1_scale=None,
|
|
||||||
a2_scale=None,
|
|
||||||
apply_router_weight_on_input=apply_router_weight_on_input,
|
apply_router_weight_on_input=apply_router_weight_on_input,
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -891,6 +849,7 @@ def _valid_cutlass_block_scaled_grouped_gemm(
|
|||||||
return True
|
return True
|
||||||
|
|
||||||
|
|
||||||
|
# TODO(bnell): would be nice combine/integrate with regular cutlass_fp8.
|
||||||
def run_cutlass_block_scaled_fused_experts(
|
def run_cutlass_block_scaled_fused_experts(
|
||||||
a: torch.Tensor,
|
a: torch.Tensor,
|
||||||
w1: torch.Tensor,
|
w1: torch.Tensor,
|
||||||
|
|||||||
@ -1,6 +1,5 @@
|
|||||||
# SPDX-License-Identifier: Apache-2.0
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
import functools
|
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
@ -9,9 +8,11 @@ from tqdm import tqdm
|
|||||||
import vllm.envs as env
|
import vllm.envs as env
|
||||||
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
|
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
|
||||||
from vllm.logger import init_logger
|
from vllm.logger import init_logger
|
||||||
from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig
|
from vllm.model_executor.layers.fused_moe.config import (
|
||||||
|
FusedMoEQuantConfig, fp8_w8a8_moe_quant_config)
|
||||||
from vllm.model_executor.layers.fused_moe.deep_gemm_utils import (
|
from vllm.model_executor.layers.fused_moe.deep_gemm_utils import (
|
||||||
compute_aligned_M, deepgemm_moe_permute, deepgemm_unpermute_and_reduce)
|
compute_aligned_M, deep_gemm_block_shape, deepgemm_moe_permute,
|
||||||
|
deepgemm_unpermute_and_reduce)
|
||||||
from vllm.model_executor.layers.fused_moe.prepare_finalize import (
|
from vllm.model_executor.layers.fused_moe.prepare_finalize import (
|
||||||
MoEPrepareAndFinalizeNoEP)
|
MoEPrepareAndFinalizeNoEP)
|
||||||
from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import (
|
from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import (
|
||||||
@ -25,14 +26,6 @@ from vllm.utils.deep_gemm import m_grouped_fp8_gemm_nt_contiguous
|
|||||||
logger = init_logger(__name__)
|
logger = init_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
@functools.cache
|
|
||||||
def deep_gemm_block_shape() -> list[int]:
|
|
||||||
# Lazy import to avoid CUDA initialization problems.
|
|
||||||
import deep_gemm as dg
|
|
||||||
block = dg.get_m_alignment_for_contiguous_layout()
|
|
||||||
return [block, block]
|
|
||||||
|
|
||||||
|
|
||||||
def _valid_deep_gemm_shape(M: int, N: int, K: int) -> bool:
|
def _valid_deep_gemm_shape(M: int, N: int, K: int) -> bool:
|
||||||
align = deep_gemm_block_shape()[0]
|
align = deep_gemm_block_shape()[0]
|
||||||
return align <= M and N % align == 0 and K % align == 0
|
return align <= M and N % align == 0 and K % align == 0
|
||||||
@ -163,13 +156,12 @@ def warmup_deepgemm_gg_contiguous_kernels(w1: torch.Tensor, w2: torch.Tensor,
|
|||||||
|
|
||||||
class DeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
|
class DeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self, quant_config: FusedMoEQuantConfig):
|
||||||
super().__init__(
|
super().__init__(quant_config)
|
||||||
FusedMoEQuantConfig(
|
assert quant_config.block_shape == deep_gemm_block_shape()
|
||||||
quant_dtype=torch.float8_e4m3fn,
|
assert quant_config.quant_dtype == torch.float8_e4m3fn
|
||||||
per_act_token_quant=False,
|
assert not quant_config.per_act_token_quant
|
||||||
block_shape=deep_gemm_block_shape(),
|
assert not quant_config.per_out_ch_quant
|
||||||
))
|
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def activation_formats(
|
def activation_formats(
|
||||||
@ -221,21 +213,17 @@ class DeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
|
|||||||
activation: str,
|
activation: str,
|
||||||
global_num_experts: int,
|
global_num_experts: int,
|
||||||
expert_map: Optional[torch.Tensor],
|
expert_map: Optional[torch.Tensor],
|
||||||
w1_scale: Optional[torch.Tensor],
|
|
||||||
w2_scale: Optional[torch.Tensor],
|
|
||||||
w1_zp: Optional[torch.Tensor],
|
|
||||||
w2_zp: Optional[torch.Tensor],
|
|
||||||
a1q_scale: Optional[torch.Tensor],
|
a1q_scale: Optional[torch.Tensor],
|
||||||
a2_scale: Optional[torch.Tensor],
|
|
||||||
workspace13: torch.Tensor,
|
workspace13: torch.Tensor,
|
||||||
workspace2: torch.Tensor,
|
workspace2: torch.Tensor,
|
||||||
expert_tokens_meta: Optional[mk.ExpertTokensMetadata],
|
expert_tokens_meta: Optional[mk.ExpertTokensMetadata],
|
||||||
apply_router_weight_on_input: bool,
|
apply_router_weight_on_input: bool,
|
||||||
):
|
):
|
||||||
assert self.block_shape is not None
|
|
||||||
assert a1q_scale is not None
|
assert a1q_scale is not None
|
||||||
assert w1_scale is not None
|
assert self.a2_scale is None
|
||||||
assert w2_scale is not None
|
assert self.block_shape is not None
|
||||||
|
assert self.w1_scale is not None
|
||||||
|
assert self.w2_scale is not None
|
||||||
|
|
||||||
a1q = hidden_states
|
a1q = hidden_states
|
||||||
_, N, K = w1.size()
|
_, N, K = w1.size()
|
||||||
@ -270,7 +258,7 @@ class DeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
|
|||||||
aq_out=a1q_perm)
|
aq_out=a1q_perm)
|
||||||
assert a1q.size(0) == M_sum
|
assert a1q.size(0) == M_sum
|
||||||
|
|
||||||
m_grouped_fp8_gemm_nt_contiguous((a1q, a1q_scale), (w1, w1_scale),
|
m_grouped_fp8_gemm_nt_contiguous((a1q, a1q_scale), (w1, self.w1_scale),
|
||||||
mm1_out, expert_ids)
|
mm1_out, expert_ids)
|
||||||
|
|
||||||
self.activation(activation, act_out, mm1_out.view(-1, N))
|
self.activation(activation, act_out, mm1_out.view(-1, N))
|
||||||
@ -281,7 +269,7 @@ class DeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
|
|||||||
column_major_scales=True,
|
column_major_scales=True,
|
||||||
out_q=quant_out)
|
out_q=quant_out)
|
||||||
|
|
||||||
m_grouped_fp8_gemm_nt_contiguous((a2q, a2q_scale), (w2, w2_scale),
|
m_grouped_fp8_gemm_nt_contiguous((a2q, a2q_scale), (w2, self.w2_scale),
|
||||||
mm2_out, expert_ids)
|
mm2_out, expert_ids)
|
||||||
|
|
||||||
if apply_router_weight_on_input:
|
if apply_router_weight_on_input:
|
||||||
@ -348,9 +336,16 @@ def deep_gemm_moe_fp8(
|
|||||||
Returns:
|
Returns:
|
||||||
- torch.Tensor: The bfloat16 output tensor after applying the MoE layer.
|
- torch.Tensor: The bfloat16 output tensor after applying the MoE layer.
|
||||||
"""
|
"""
|
||||||
|
quant_config = fp8_w8a8_moe_quant_config(
|
||||||
|
w1_scale=w1_scale,
|
||||||
|
w2_scale=w2_scale,
|
||||||
|
a1_scale=a1_scale,
|
||||||
|
a2_scale=a2_scale,
|
||||||
|
block_shape=deep_gemm_block_shape())
|
||||||
|
|
||||||
fn = mk.FusedMoEModularKernel(
|
fn = mk.FusedMoEModularKernel(
|
||||||
MoEPrepareAndFinalizeNoEP(),
|
MoEPrepareAndFinalizeNoEP(),
|
||||||
DeepGemmExperts(),
|
DeepGemmExperts(quant_config),
|
||||||
)
|
)
|
||||||
return fn(
|
return fn(
|
||||||
hidden_states,
|
hidden_states,
|
||||||
@ -358,13 +353,9 @@ def deep_gemm_moe_fp8(
|
|||||||
w2,
|
w2,
|
||||||
topk_weights,
|
topk_weights,
|
||||||
topk_ids,
|
topk_ids,
|
||||||
inplace,
|
inplace=inplace,
|
||||||
activation,
|
activation=activation,
|
||||||
global_num_experts,
|
global_num_experts=global_num_experts,
|
||||||
expert_map,
|
expert_map=expert_map,
|
||||||
w1_scale=w1_scale,
|
|
||||||
w2_scale=w2_scale,
|
|
||||||
a1_scale=a1_scale,
|
|
||||||
a2_scale=a2_scale,
|
|
||||||
apply_router_weight_on_input=apply_router_weight_on_input,
|
apply_router_weight_on_input=apply_router_weight_on_input,
|
||||||
)
|
)
|
||||||
|
|||||||
@ -183,8 +183,6 @@ class DeepEPHTPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
|
|||||||
def prepare_async(
|
def prepare_async(
|
||||||
self,
|
self,
|
||||||
a1: torch.Tensor,
|
a1: torch.Tensor,
|
||||||
a1_scale: Optional[torch.Tensor],
|
|
||||||
a2_scale: Optional[torch.Tensor],
|
|
||||||
topk_weights: torch.Tensor,
|
topk_weights: torch.Tensor,
|
||||||
topk_ids: torch.Tensor,
|
topk_ids: torch.Tensor,
|
||||||
num_experts: int,
|
num_experts: int,
|
||||||
@ -204,7 +202,7 @@ class DeepEPHTPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
|
|||||||
# Quant and Dispatch
|
# Quant and Dispatch
|
||||||
a1q, a1q_scale = moe_kernel_quantize_input(
|
a1q, a1q_scale = moe_kernel_quantize_input(
|
||||||
a1,
|
a1,
|
||||||
a1_scale,
|
quant_config.a1_scale,
|
||||||
quant_dtype=quant_config.quant_dtype,
|
quant_dtype=quant_config.quant_dtype,
|
||||||
per_act_token_quant=quant_config.per_act_token_quant,
|
per_act_token_quant=quant_config.per_act_token_quant,
|
||||||
block_shape=quant_config.block_shape,
|
block_shape=quant_config.block_shape,
|
||||||
@ -215,7 +213,7 @@ class DeepEPHTPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
|
|||||||
else:
|
else:
|
||||||
a1q = a1
|
a1q = a1
|
||||||
a1q_scale = None
|
a1q_scale = None
|
||||||
a1_post_scale = a1_scale
|
a1_post_scale = quant_config.a1_scale
|
||||||
|
|
||||||
return (lambda *args: None,
|
return (lambda *args: None,
|
||||||
self._do_dispatch(tokens=a1q,
|
self._do_dispatch(tokens=a1q,
|
||||||
@ -229,8 +227,6 @@ class DeepEPHTPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
|
|||||||
def prepare(
|
def prepare(
|
||||||
self,
|
self,
|
||||||
a1: torch.Tensor,
|
a1: torch.Tensor,
|
||||||
a1_scale: Optional[torch.Tensor],
|
|
||||||
a2_scale: Optional[torch.Tensor],
|
|
||||||
topk_weights: torch.Tensor,
|
topk_weights: torch.Tensor,
|
||||||
topk_ids: torch.Tensor,
|
topk_ids: torch.Tensor,
|
||||||
num_experts: int,
|
num_experts: int,
|
||||||
@ -238,9 +234,8 @@ class DeepEPHTPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
|
|||||||
apply_router_weight_on_input: bool,
|
apply_router_weight_on_input: bool,
|
||||||
quant_config: FusedMoEQuantConfig,
|
quant_config: FusedMoEQuantConfig,
|
||||||
) -> mk.PrepareResultType:
|
) -> mk.PrepareResultType:
|
||||||
(_, receiver) = self.prepare_async(a1, a1_scale, a2_scale,
|
(_, receiver) = self.prepare_async(a1, topk_weights, topk_ids,
|
||||||
topk_weights, topk_ids, num_experts,
|
num_experts, expert_map,
|
||||||
expert_map,
|
|
||||||
apply_router_weight_on_input,
|
apply_router_weight_on_input,
|
||||||
quant_config)
|
quant_config)
|
||||||
return receiver()
|
return receiver()
|
||||||
|
|||||||
@ -77,15 +77,13 @@ class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
|
|||||||
def _do_quant(
|
def _do_quant(
|
||||||
self,
|
self,
|
||||||
x: Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]],
|
x: Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]],
|
||||||
a1_scale: Optional[torch.Tensor],
|
|
||||||
a1_dtype: torch.dtype,
|
a1_dtype: torch.dtype,
|
||||||
quant_dtype: Union[torch.dtype, str, None],
|
quant_config: FusedMoEQuantConfig,
|
||||||
per_act_token_quant: bool,
|
|
||||||
block_shape: Optional[list[int]],
|
|
||||||
) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
|
) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
|
||||||
|
|
||||||
block_k = block_shape[1] if block_shape is not None else None
|
|
||||||
if self.use_fp8_dispatch:
|
if self.use_fp8_dispatch:
|
||||||
|
block_k = quant_config.block_shape[
|
||||||
|
1] if quant_config.block_shape is not None else None
|
||||||
if block_k == DEEPEP_QUANT_BLOCK_SIZE:
|
if block_k == DEEPEP_QUANT_BLOCK_SIZE:
|
||||||
# DeepEP kernels did the quantization for us.
|
# DeepEP kernels did the quantization for us.
|
||||||
x, x_scales = x
|
x, x_scales = x
|
||||||
@ -101,12 +99,12 @@ class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
|
|||||||
|
|
||||||
# TODO (varun): Optimization - Use a batched version of quant
|
# TODO (varun): Optimization - Use a batched version of quant
|
||||||
x = x.view((-1, hidden_dim))
|
x = x.view((-1, hidden_dim))
|
||||||
x, x_scales = moe_kernel_quantize_input(x, a1_scale, quant_dtype,
|
x, x_scales = moe_kernel_quantize_input(
|
||||||
per_act_token_quant,
|
x, quant_config.a1_scale, quant_config.quant_dtype,
|
||||||
block_shape)
|
quant_config.per_act_token_quant, quant_config.block_shape)
|
||||||
x = x.view((num_experts, -1, hidden_dim))
|
x = x.view((num_experts, -1, hidden_dim))
|
||||||
|
|
||||||
if quant_dtype is not None:
|
if quant_config.quant_dtype is not None:
|
||||||
assert x_scales is not None
|
assert x_scales is not None
|
||||||
x_scales = normalize_batched_scales_shape(x_scales, num_experts)
|
x_scales = normalize_batched_scales_shape(x_scales, num_experts)
|
||||||
|
|
||||||
@ -118,8 +116,6 @@ class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
|
|||||||
def prepare_async(
|
def prepare_async(
|
||||||
self,
|
self,
|
||||||
a1: torch.Tensor,
|
a1: torch.Tensor,
|
||||||
a1_scale: Optional[torch.Tensor],
|
|
||||||
a2_scale: Optional[torch.Tensor],
|
|
||||||
topk_weights: torch.Tensor,
|
topk_weights: torch.Tensor,
|
||||||
topk_ids: torch.Tensor,
|
topk_ids: torch.Tensor,
|
||||||
num_experts: int,
|
num_experts: int,
|
||||||
@ -139,9 +135,10 @@ class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
|
|||||||
assert hidden_size % 128 == 0, \
|
assert hidden_size % 128 == 0, \
|
||||||
"DeepEP kernels quantize the inputs in blocks of shape 128"
|
"DeepEP kernels quantize the inputs in blocks of shape 128"
|
||||||
|
|
||||||
has_per_token_scales = a1_scale.numel(
|
has_per_token_scales = quant_config.a1_scale.numel(
|
||||||
) != 1 if a1_scale is not None else (
|
) != 1 if quant_config.a1_scale is not None else (
|
||||||
a2_scale.numel() != 1 if a2_scale is not None else False)
|
quant_config.a2_scale.numel() != 1
|
||||||
|
if quant_config.a2_scale is not None else False)
|
||||||
assert not has_per_token_scales, (
|
assert not has_per_token_scales, (
|
||||||
"low_latency kernels doesn't support dispatching per-token scales")
|
"low_latency kernels doesn't support dispatching per-token scales")
|
||||||
|
|
||||||
@ -163,20 +160,21 @@ class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
|
|||||||
return_recv_hook=True)
|
return_recv_hook=True)
|
||||||
self.handles[a2a_idx] = handle
|
self.handles[a2a_idx] = handle
|
||||||
|
|
||||||
return (hook, lambda: self._receiver(expert_x, expert_num_tokens,
|
return (
|
||||||
a1_scale, a1.dtype, quant_config))
|
hook,
|
||||||
|
lambda: self._receiver(expert_x, expert_num_tokens, quant_config.
|
||||||
|
a1_scale, a1.dtype, quant_config))
|
||||||
|
|
||||||
def _receiver(
|
def _receiver(
|
||||||
self,
|
self,
|
||||||
expert_x: Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]],
|
expert_x: Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]],
|
||||||
expert_num_tokens: torch.Tensor,
|
expert_num_tokens: torch.Tensor,
|
||||||
a1_scale,
|
a1_scale: Optional[torch.Tensor],
|
||||||
a1_dtype,
|
a1_dtype: torch.dtype,
|
||||||
quant_config: FusedMoEQuantConfig,
|
quant_config: FusedMoEQuantConfig,
|
||||||
) -> mk.PrepareResultType:
|
) -> mk.PrepareResultType:
|
||||||
expert_x, expert_x_scale = self._do_quant(
|
expert_x, expert_x_scale = self._do_quant(expert_x, a1_dtype,
|
||||||
expert_x, a1_scale, a1_dtype, quant_config.quant_dtype,
|
quant_config)
|
||||||
quant_config.per_act_token_quant, quant_config.block_shape)
|
|
||||||
|
|
||||||
expert_tokens_meta = mk.ExpertTokensMetadata(
|
expert_tokens_meta = mk.ExpertTokensMetadata(
|
||||||
expert_num_tokens=expert_num_tokens, expert_num_tokens_cpu=None)
|
expert_num_tokens=expert_num_tokens, expert_num_tokens_cpu=None)
|
||||||
@ -186,8 +184,6 @@ class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
|
|||||||
def prepare(
|
def prepare(
|
||||||
self,
|
self,
|
||||||
a1: torch.Tensor,
|
a1: torch.Tensor,
|
||||||
a1_scale: Optional[torch.Tensor],
|
|
||||||
a2_scale: Optional[torch.Tensor],
|
|
||||||
topk_weights: torch.Tensor,
|
topk_weights: torch.Tensor,
|
||||||
topk_ids: torch.Tensor,
|
topk_ids: torch.Tensor,
|
||||||
num_experts: int,
|
num_experts: int,
|
||||||
@ -195,8 +191,7 @@ class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
|
|||||||
apply_router_weight_on_input: bool,
|
apply_router_weight_on_input: bool,
|
||||||
quant_config: FusedMoEQuantConfig,
|
quant_config: FusedMoEQuantConfig,
|
||||||
) -> mk.PrepareResultType:
|
) -> mk.PrepareResultType:
|
||||||
hook, receiver = self.prepare_async(a1, a1_scale, a2_scale,
|
hook, receiver = self.prepare_async(a1, topk_weights, topk_ids,
|
||||||
topk_weights, topk_ids,
|
|
||||||
num_experts, expert_map,
|
num_experts, expert_map,
|
||||||
apply_router_weight_on_input,
|
apply_router_weight_on_input,
|
||||||
quant_config)
|
quant_config)
|
||||||
|
|||||||
@ -1,6 +1,6 @@
|
|||||||
# SPDX-License-Identifier: Apache-2.0
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
from typing import Optional, Union
|
from typing import Optional
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
@ -44,33 +44,20 @@ class FlashInferExperts(mk.FusedMoEPermuteExpertsUnpermute):
|
|||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
g1_alphas: torch.Tensor,
|
|
||||||
g2_alphas: torch.Tensor,
|
|
||||||
a1_gscale: torch.Tensor,
|
|
||||||
a2_gscale: torch.Tensor,
|
|
||||||
out_dtype: torch.dtype,
|
out_dtype: torch.dtype,
|
||||||
quant_dtype: Union[torch.dtype, str, None],
|
quant_config: FusedMoEQuantConfig,
|
||||||
ep_rank: int = 0,
|
ep_rank: int = 0,
|
||||||
ep_size: int = 1,
|
ep_size: int = 1,
|
||||||
tp_rank: int = 0,
|
tp_rank: int = 0,
|
||||||
tp_size: int = 1,
|
tp_size: int = 1,
|
||||||
):
|
):
|
||||||
super().__init__(
|
super().__init__(quant_config)
|
||||||
FusedMoEQuantConfig(
|
assert quant_config.quant_dtype in ("nvfp4", torch.float8_e4m3fn), (
|
||||||
quant_dtype=quant_dtype,
|
|
||||||
per_act_token_quant=False,
|
|
||||||
block_shape=None,
|
|
||||||
))
|
|
||||||
assert quant_dtype in ("nvfp4", torch.float8_e4m3fn), (
|
|
||||||
"Only nvfp4,fp8 quantization are currently supported.")
|
"Only nvfp4,fp8 quantization are currently supported.")
|
||||||
self.ep_rank = ep_rank
|
self.ep_rank = ep_rank
|
||||||
self.ep_size = ep_size
|
self.ep_size = ep_size
|
||||||
self.tp_rank = tp_rank
|
self.tp_rank = tp_rank
|
||||||
self.tp_size = tp_size
|
self.tp_size = tp_size
|
||||||
self.g1_alphas = g1_alphas
|
|
||||||
self.g2_alphas = g2_alphas
|
|
||||||
self.a1_gscale = a1_gscale
|
|
||||||
self.a2_gscale = a2_gscale
|
|
||||||
self.out_dtype = out_dtype
|
self.out_dtype = out_dtype
|
||||||
|
|
||||||
@property
|
@property
|
||||||
@ -141,12 +128,7 @@ class FlashInferExperts(mk.FusedMoEPermuteExpertsUnpermute):
|
|||||||
activation: str,
|
activation: str,
|
||||||
global_num_experts: int,
|
global_num_experts: int,
|
||||||
expert_map: Optional[torch.Tensor],
|
expert_map: Optional[torch.Tensor],
|
||||||
w1_scale: Optional[torch.Tensor],
|
|
||||||
w2_scale: Optional[torch.Tensor],
|
|
||||||
w1_zp: Optional[torch.Tensor],
|
|
||||||
w2_zp: Optional[torch.Tensor],
|
|
||||||
a1q_scale: Optional[torch.Tensor],
|
a1q_scale: Optional[torch.Tensor],
|
||||||
a2_scale: Optional[torch.Tensor], # Not used
|
|
||||||
workspace13: Optional[torch.Tensor],
|
workspace13: Optional[torch.Tensor],
|
||||||
workspace2: Optional[torch.Tensor],
|
workspace2: Optional[torch.Tensor],
|
||||||
expert_tokens_meta: Optional[mk.ExpertTokensMetadata],
|
expert_tokens_meta: Optional[mk.ExpertTokensMetadata],
|
||||||
@ -162,17 +144,17 @@ class FlashInferExperts(mk.FusedMoEPermuteExpertsUnpermute):
|
|||||||
fc2_expert_weights = w2
|
fc2_expert_weights = w2
|
||||||
else:
|
else:
|
||||||
# Ensure w1_scale and w2_scale are not None before calling view
|
# Ensure w1_scale and w2_scale are not None before calling view
|
||||||
assert w1_scale is not None and w2_scale is not None, (
|
assert self.w1_scale is not None and self.w2_scale is not None, (
|
||||||
"w1_scale and w2_scale must not "
|
"w1_scale and w2_scale must not "
|
||||||
"be None for FlashInferExperts")
|
"be None for FlashInferExperts")
|
||||||
# Flashinfer CUTLASS kernel takes scalar global scales,
|
# Flashinfer CUTLASS kernel takes scalar global scales,
|
||||||
# min because inv_scale.
|
# min because inv_scale.
|
||||||
quant_scales = [
|
quant_scales = [
|
||||||
self.a1_gscale,
|
self.a1_gscale,
|
||||||
w1_scale.view(torch.int32),
|
self.w1_scale.view(torch.int32),
|
||||||
self.g1_alphas,
|
self.g1_alphas,
|
||||||
self.a2_gscale,
|
self.a2_gscale,
|
||||||
w2_scale.view(torch.int32),
|
self.w2_scale.view(torch.int32),
|
||||||
self.g2_alphas,
|
self.g2_alphas,
|
||||||
]
|
]
|
||||||
# FlashInfer API requires weight to be long for nvfp4
|
# FlashInfer API requires weight to be long for nvfp4
|
||||||
@ -202,12 +184,7 @@ def flashinfer_cutlass_moe_fp4(
|
|||||||
w2: torch.Tensor,
|
w2: torch.Tensor,
|
||||||
topk_weights: torch.Tensor,
|
topk_weights: torch.Tensor,
|
||||||
topk_ids: torch.Tensor,
|
topk_ids: torch.Tensor,
|
||||||
w1_scale: torch.Tensor,
|
quant_config: FusedMoEQuantConfig,
|
||||||
w2_scale: torch.Tensor,
|
|
||||||
g1_alphas: torch.Tensor,
|
|
||||||
g2_alphas: torch.Tensor,
|
|
||||||
a1_gscale: torch.Tensor,
|
|
||||||
a2_gscale: torch.Tensor,
|
|
||||||
inplace: bool = False,
|
inplace: bool = False,
|
||||||
activation: str = "silu",
|
activation: str = "silu",
|
||||||
global_num_experts: int = -1,
|
global_num_experts: int = -1,
|
||||||
@ -216,15 +193,10 @@ def flashinfer_cutlass_moe_fp4(
|
|||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
|
|
||||||
fused_experts = mk.FusedMoEModularKernel(
|
fused_experts = mk.FusedMoEModularKernel(
|
||||||
FlashInferCutlassMoEPrepareAndFinalize(use_dp=False,
|
FlashInferCutlassMoEPrepareAndFinalize(use_dp=False),
|
||||||
a1_gscale=a1_gscale),
|
|
||||||
FlashInferExperts(
|
FlashInferExperts(
|
||||||
g1_alphas=g1_alphas,
|
|
||||||
g2_alphas=g2_alphas,
|
|
||||||
a1_gscale=a1_gscale,
|
|
||||||
a2_gscale=a2_gscale,
|
|
||||||
out_dtype=hidden_states.dtype,
|
out_dtype=hidden_states.dtype,
|
||||||
quant_dtype="nvfp4",
|
quant_config=quant_config,
|
||||||
))
|
))
|
||||||
|
|
||||||
return fused_experts(
|
return fused_experts(
|
||||||
@ -237,7 +209,5 @@ def flashinfer_cutlass_moe_fp4(
|
|||||||
activation=activation,
|
activation=activation,
|
||||||
global_num_experts=global_num_experts,
|
global_num_experts=global_num_experts,
|
||||||
expert_map=expert_map,
|
expert_map=expert_map,
|
||||||
w1_scale=w1_scale,
|
|
||||||
w2_scale=w2_scale,
|
|
||||||
apply_router_weight_on_input=apply_router_weight_on_input,
|
apply_router_weight_on_input=apply_router_weight_on_input,
|
||||||
)
|
)
|
||||||
|
|||||||
@ -22,13 +22,11 @@ class FlashInferCutlassMoEPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
|
|||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
use_dp: bool,
|
use_dp: bool,
|
||||||
a1_gscale: Optional[torch.Tensor],
|
|
||||||
num_dispatchers: int = 1,
|
num_dispatchers: int = 1,
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.num_dispatchers_ = num_dispatchers
|
self.num_dispatchers_ = num_dispatchers
|
||||||
self.use_dp = use_dp
|
self.use_dp = use_dp
|
||||||
self.a1_gscale = a1_gscale
|
|
||||||
self.local_tokens = None
|
self.local_tokens = None
|
||||||
|
|
||||||
@property
|
@property
|
||||||
@ -47,14 +45,11 @@ class FlashInferCutlassMoEPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
|
|||||||
def prepare(
|
def prepare(
|
||||||
self,
|
self,
|
||||||
a1: torch.Tensor,
|
a1: torch.Tensor,
|
||||||
a1_scale: Optional[torch.Tensor], # Not used
|
|
||||||
a2_scale: Optional[torch.Tensor], # Not used
|
|
||||||
topk_weights: torch.Tensor,
|
topk_weights: torch.Tensor,
|
||||||
topk_ids: torch.Tensor,
|
topk_ids: torch.Tensor,
|
||||||
num_experts: int,
|
num_experts: int,
|
||||||
expert_map: Optional[torch.Tensor],
|
expert_map: Optional[torch.Tensor],
|
||||||
apply_router_weight_on_input: bool,
|
apply_router_weight_on_input: bool,
|
||||||
# TODO(bnell): use quant_config + scales instead of ctor args
|
|
||||||
quant_config: FusedMoEQuantConfig,
|
quant_config: FusedMoEQuantConfig,
|
||||||
) -> mk.PrepareResultType:
|
) -> mk.PrepareResultType:
|
||||||
|
|
||||||
@ -67,7 +62,7 @@ class FlashInferCutlassMoEPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
|
|||||||
|
|
||||||
a1q, a1q_scale = moe_kernel_quantize_input(
|
a1q, a1q_scale = moe_kernel_quantize_input(
|
||||||
a1,
|
a1,
|
||||||
self.a1_gscale,
|
quant_config.a1_gscale,
|
||||||
quant_config.quant_dtype,
|
quant_config.quant_dtype,
|
||||||
quant_config.per_act_token_quant,
|
quant_config.per_act_token_quant,
|
||||||
quant_config.block_shape,
|
quant_config.block_shape,
|
||||||
|
|||||||
185
vllm/model_executor/layers/fused_moe/flashinfer_trtllm_moe.py
Normal file
185
vllm/model_executor/layers/fused_moe/flashinfer_trtllm_moe.py
Normal file
@ -0,0 +1,185 @@
|
|||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
from typing import List # noqa: UP035
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from vllm.model_executor.layers.fused_moe.utils import (
|
||||||
|
moe_kernel_quantize_input)
|
||||||
|
from vllm.model_executor.layers.quantization.utils.flashinfer_utils import (
|
||||||
|
calculate_tile_tokens_dim)
|
||||||
|
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
|
||||||
|
per_token_group_quant_fp8)
|
||||||
|
from vllm.utils import direct_register_custom_op
|
||||||
|
|
||||||
|
|
||||||
|
def flashinfer_fused_moe_blockscale_fp8(
|
||||||
|
routing_logits: torch.Tensor,
|
||||||
|
routing_bias: torch.Tensor,
|
||||||
|
x: torch.Tensor,
|
||||||
|
w13_weight: torch.Tensor,
|
||||||
|
w13_weight_scale_inv: torch.Tensor,
|
||||||
|
w2_weight: torch.Tensor,
|
||||||
|
w2_weight_scale_inv: torch.Tensor,
|
||||||
|
global_num_experts: int,
|
||||||
|
top_k: int,
|
||||||
|
num_expert_group: int,
|
||||||
|
topk_group: int,
|
||||||
|
intermediate_size: int,
|
||||||
|
expert_offset: int,
|
||||||
|
local_num_experts: int,
|
||||||
|
block_shape: List[int], #noqa: UP006
|
||||||
|
routed_scaling: float = 1.0) -> torch.Tensor:
|
||||||
|
from vllm.utils.flashinfer import flashinfer_trtllm_fp8_block_scale_moe
|
||||||
|
assert top_k <= global_num_experts
|
||||||
|
assert top_k <= 8
|
||||||
|
assert topk_group <= 4
|
||||||
|
assert global_num_experts > num_expert_group
|
||||||
|
assert global_num_experts % num_expert_group == 0
|
||||||
|
assert global_num_experts % 4 == 0
|
||||||
|
assert top_k < (topk_group * global_num_experts / num_expert_group)
|
||||||
|
assert block_shape == [128, 128]
|
||||||
|
|
||||||
|
a_q, a_sf = per_token_group_quant_fp8(x, block_shape[1])
|
||||||
|
# NOTE: scales of hidden states have to be transposed!
|
||||||
|
a_sf_t = a_sf.t().contiguous()
|
||||||
|
return flashinfer_trtllm_fp8_block_scale_moe(
|
||||||
|
routing_logits=routing_logits,
|
||||||
|
routing_bias=routing_bias,
|
||||||
|
hidden_states=a_q,
|
||||||
|
hidden_states_scale=a_sf_t,
|
||||||
|
gemm1_weights=w13_weight,
|
||||||
|
gemm1_weights_scale=w13_weight_scale_inv,
|
||||||
|
gemm2_weights=w2_weight,
|
||||||
|
gemm2_weights_scale=w2_weight_scale_inv,
|
||||||
|
num_experts=global_num_experts,
|
||||||
|
top_k=top_k,
|
||||||
|
n_group=num_expert_group,
|
||||||
|
topk_group=topk_group,
|
||||||
|
intermediate_size=intermediate_size,
|
||||||
|
local_expert_offset=expert_offset,
|
||||||
|
local_num_experts=local_num_experts,
|
||||||
|
routed_scaling_factor=routed_scaling,
|
||||||
|
tile_tokens_dim=calculate_tile_tokens_dim(x.shape[0], top_k,
|
||||||
|
global_num_experts),
|
||||||
|
routing_method_type=2, # DeepSeek-styled routing method
|
||||||
|
use_shuffled_weight=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def flashinfer_fused_moe_blockscale_fp8_fake(
|
||||||
|
routing_logits: torch.Tensor,
|
||||||
|
routing_bias: torch.Tensor,
|
||||||
|
x: torch.Tensor,
|
||||||
|
w13_weight: torch.Tensor,
|
||||||
|
w13_weight_scale_inv: torch.Tensor,
|
||||||
|
w2_weight: torch.Tensor,
|
||||||
|
w2_weight_scale_inv: torch.Tensor,
|
||||||
|
global_num_experts: int,
|
||||||
|
top_k: int,
|
||||||
|
num_expert_group: int,
|
||||||
|
topk_group: int,
|
||||||
|
intermediate_size: int,
|
||||||
|
expert_offset: int,
|
||||||
|
local_num_experts: int,
|
||||||
|
block_shape: list[int],
|
||||||
|
routed_scaling: float = 1.0) -> torch.Tensor:
|
||||||
|
return torch.empty_like(x)
|
||||||
|
|
||||||
|
|
||||||
|
# TODO(bnell): Does this really need to be a torch.op?
|
||||||
|
direct_register_custom_op(
|
||||||
|
op_name="flashinfer_fused_moe_blockscale_fp8",
|
||||||
|
op_func=flashinfer_fused_moe_blockscale_fp8,
|
||||||
|
mutates_args=[],
|
||||||
|
fake_impl=flashinfer_fused_moe_blockscale_fp8_fake,
|
||||||
|
tags=(torch.Tag.needs_fixed_stride_order, ),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def flashinfer_fused_moe_per_tensor_scale_fp8(
|
||||||
|
routing_logits: torch.Tensor,
|
||||||
|
routing_bias: Optional[torch.Tensor],
|
||||||
|
hidden_states: torch.Tensor,
|
||||||
|
input_scale: torch.Tensor,
|
||||||
|
gemm1_weights: torch.Tensor,
|
||||||
|
gemm2_weights: torch.Tensor,
|
||||||
|
output1_scales_scalar: torch.Tensor,
|
||||||
|
output1_scales_gate_scalar: torch.Tensor,
|
||||||
|
output2_scales_scalar: torch.Tensor,
|
||||||
|
num_experts: int,
|
||||||
|
top_k: int,
|
||||||
|
num_expert_group: Optional[int],
|
||||||
|
topk_group: Optional[int],
|
||||||
|
intermediate_size: int,
|
||||||
|
local_expert_offset: int,
|
||||||
|
local_num_experts: int,
|
||||||
|
use_routing_scales_on_input: bool,
|
||||||
|
routing_method_type: int,
|
||||||
|
routed_scaling_factor: float = 1.0) -> torch.Tensor:
|
||||||
|
num_expert_group = num_expert_group if num_expert_group is not None else 0
|
||||||
|
topk_group = topk_group if topk_group is not None else 0
|
||||||
|
|
||||||
|
quant_hidden_states, _ = moe_kernel_quantize_input(
|
||||||
|
hidden_states,
|
||||||
|
input_scale,
|
||||||
|
quant_dtype=torch.float8_e4m3fn,
|
||||||
|
per_act_token_quant=False)
|
||||||
|
|
||||||
|
from vllm.utils.flashinfer import (
|
||||||
|
flashinfer_trtllm_fp8_per_tensor_scale_moe)
|
||||||
|
return flashinfer_trtllm_fp8_per_tensor_scale_moe(
|
||||||
|
routing_logits=routing_logits,
|
||||||
|
routing_bias=routing_bias,
|
||||||
|
hidden_states=quant_hidden_states,
|
||||||
|
gemm1_weights=gemm1_weights,
|
||||||
|
output1_scales_scalar=output1_scales_scalar,
|
||||||
|
output1_scales_gate_scalar=output1_scales_gate_scalar,
|
||||||
|
gemm2_weights=gemm2_weights,
|
||||||
|
output2_scales_scalar=output2_scales_scalar,
|
||||||
|
num_experts=num_experts,
|
||||||
|
top_k=top_k,
|
||||||
|
n_group=num_expert_group,
|
||||||
|
topk_group=topk_group,
|
||||||
|
intermediate_size=intermediate_size,
|
||||||
|
local_expert_offset=local_expert_offset,
|
||||||
|
local_num_experts=local_num_experts,
|
||||||
|
routed_scaling_factor=routed_scaling_factor,
|
||||||
|
use_routing_scales_on_input=use_routing_scales_on_input,
|
||||||
|
tile_tokens_dim=calculate_tile_tokens_dim(hidden_states.shape[0],
|
||||||
|
top_k, num_experts),
|
||||||
|
routing_method_type=routing_method_type)
|
||||||
|
|
||||||
|
|
||||||
|
def flashinfer_fused_moe_per_tensor_scale_fp8_fake(
|
||||||
|
routing_logits: torch.Tensor,
|
||||||
|
routing_bias: Optional[torch.Tensor],
|
||||||
|
hidden_states: torch.Tensor,
|
||||||
|
input_scale: torch.Tensor,
|
||||||
|
gemm1_weights: torch.Tensor,
|
||||||
|
gemm2_weights: torch.Tensor,
|
||||||
|
output1_scales_scalar: torch.Tensor,
|
||||||
|
output1_scales_gate_scalar: torch.Tensor,
|
||||||
|
output2_scales_scalar: torch.Tensor,
|
||||||
|
num_experts: int,
|
||||||
|
top_k: int,
|
||||||
|
num_expert_group: Optional[int],
|
||||||
|
topk_group: Optional[int],
|
||||||
|
intermediate_size: int,
|
||||||
|
local_expert_offset: int,
|
||||||
|
local_num_experts: int,
|
||||||
|
use_routing_scales_on_input: bool,
|
||||||
|
routing_method_type: int,
|
||||||
|
routed_scaling_factor: float = 1.0) -> torch.Tensor:
|
||||||
|
return torch.empty_like(hidden_states)
|
||||||
|
|
||||||
|
|
||||||
|
# TODO(bnell): Does this really need to be a torch.op?
|
||||||
|
direct_register_custom_op(
|
||||||
|
op_name="flashinfer_fused_moe_per_tensor_scale_fp8",
|
||||||
|
op_func=flashinfer_fused_moe_per_tensor_scale_fp8,
|
||||||
|
mutates_args=["hidden_states"],
|
||||||
|
fake_impl=flashinfer_fused_moe_per_tensor_scale_fp8_fake,
|
||||||
|
tags=(torch.Tag.needs_fixed_stride_order, ),
|
||||||
|
)
|
||||||
@ -8,7 +8,7 @@ import torch
|
|||||||
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
|
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
|
||||||
from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig
|
from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig
|
||||||
from vllm.model_executor.layers.fused_moe.fused_moe import (
|
from vllm.model_executor.layers.fused_moe.fused_moe import (
|
||||||
get_config_dtype_str, try_get_optimal_moe_config)
|
try_get_optimal_moe_config)
|
||||||
from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import (
|
from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import (
|
||||||
TopKWeightAndReduceDelegate, TopKWeightAndReduceNaiveBatched)
|
TopKWeightAndReduceDelegate, TopKWeightAndReduceNaiveBatched)
|
||||||
from vllm.model_executor.layers.fused_moe.utils import (
|
from vllm.model_executor.layers.fused_moe.utils import (
|
||||||
@ -498,8 +498,6 @@ class BatchedPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
|
|||||||
def prepare(
|
def prepare(
|
||||||
self,
|
self,
|
||||||
a1: torch.Tensor,
|
a1: torch.Tensor,
|
||||||
a1_scale: Optional[torch.Tensor],
|
|
||||||
a2_scale: Optional[torch.Tensor],
|
|
||||||
topk_weights: torch.Tensor,
|
topk_weights: torch.Tensor,
|
||||||
topk_ids: torch.Tensor,
|
topk_ids: torch.Tensor,
|
||||||
num_experts: int,
|
num_experts: int,
|
||||||
@ -545,14 +543,13 @@ class BatchedPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
|
|||||||
dtype=torch.float32,
|
dtype=torch.float32,
|
||||||
device=a1.device)
|
device=a1.device)
|
||||||
else:
|
else:
|
||||||
assert a1_scale is None
|
assert quant_config.a1_scale is None
|
||||||
b_a1_scale = None
|
b_a1_scale = None
|
||||||
|
|
||||||
first_expert = num_local_experts * self.rank
|
first_expert = num_local_experts * self.rank
|
||||||
last_expert = first_expert + num_local_experts
|
last_expert = first_expert + num_local_experts
|
||||||
|
|
||||||
a1_scale = normalize_scales_shape(a1_scale)
|
a1_scale = normalize_scales_shape(quant_config.a1_scale)
|
||||||
a2_scale = normalize_scales_shape(a2_scale)
|
|
||||||
|
|
||||||
for expert_id in range(first_expert, last_expert):
|
for expert_id in range(first_expert, last_expert):
|
||||||
topks = torch.any(topk_ids == expert_id, dim=1).flatten()
|
topks = torch.any(topk_ids == expert_id, dim=1).flatten()
|
||||||
@ -623,28 +620,13 @@ class NaiveBatchedExperts(mk.FusedMoEPermuteExpertsUnpermute):
|
|||||||
self,
|
self,
|
||||||
max_num_tokens: int,
|
max_num_tokens: int,
|
||||||
num_dispatchers: int,
|
num_dispatchers: int,
|
||||||
use_fp8_w8a8: bool = False,
|
quant_config: FusedMoEQuantConfig,
|
||||||
use_int8_w8a8: bool = False,
|
|
||||||
use_int8_w8a16: bool = False,
|
|
||||||
use_int4_w4a16: bool = False,
|
|
||||||
use_mxfp4_w4a4: bool = False,
|
|
||||||
block_shape: Optional[list[int]] = None,
|
|
||||||
per_act_token_quant: bool = False,
|
|
||||||
):
|
):
|
||||||
super().__init__(
|
super().__init__(quant_config)
|
||||||
FusedMoEQuantConfig.make(
|
assert not self.quant_config.use_int8_w8a8, "NYI"
|
||||||
use_fp8_w8a8=use_fp8_w8a8,
|
assert not self.quant_config.use_int8_w8a16, "NYI"
|
||||||
use_int8_w8a8=use_int8_w8a8,
|
assert not self.quant_config.use_int4_w4a16, "NYI"
|
||||||
use_int8_w8a16=use_int8_w8a16,
|
assert not self.quant_config.use_mxfp4_w4a4, "NYI"
|
||||||
use_int4_w4a16=use_int4_w4a16,
|
|
||||||
use_mxfp4_w4a4=use_mxfp4_w4a4,
|
|
||||||
per_act_token_quant=per_act_token_quant,
|
|
||||||
block_shape=block_shape,
|
|
||||||
))
|
|
||||||
assert not use_int8_w8a8, "NYI"
|
|
||||||
assert not use_int8_w8a16, "NYI"
|
|
||||||
assert not use_int4_w4a16, "NYI"
|
|
||||||
assert not use_mxfp4_w4a4, "NYI"
|
|
||||||
self.max_num_tokens = max_num_tokens
|
self.max_num_tokens = max_num_tokens
|
||||||
self.num_dispatchers = num_dispatchers
|
self.num_dispatchers = num_dispatchers
|
||||||
|
|
||||||
@ -705,12 +687,7 @@ class NaiveBatchedExperts(mk.FusedMoEPermuteExpertsUnpermute):
|
|||||||
activation: str,
|
activation: str,
|
||||||
global_num_experts: int,
|
global_num_experts: int,
|
||||||
expert_map: Optional[torch.Tensor],
|
expert_map: Optional[torch.Tensor],
|
||||||
w1_scale: Optional[torch.Tensor],
|
|
||||||
w2_scale: Optional[torch.Tensor],
|
|
||||||
w1_zp: Optional[torch.Tensor],
|
|
||||||
w2_zp: Optional[torch.Tensor],
|
|
||||||
a1q_scale: Optional[torch.Tensor],
|
a1q_scale: Optional[torch.Tensor],
|
||||||
a2_scale: Optional[torch.Tensor],
|
|
||||||
workspace13: torch.Tensor,
|
workspace13: torch.Tensor,
|
||||||
workspace2: torch.Tensor,
|
workspace2: torch.Tensor,
|
||||||
expert_tokens_meta: Optional[mk.ExpertTokensMetadata],
|
expert_tokens_meta: Optional[mk.ExpertTokensMetadata],
|
||||||
@ -740,10 +717,10 @@ class NaiveBatchedExperts(mk.FusedMoEPermuteExpertsUnpermute):
|
|||||||
tmp = _resize_cache(workspace2, (num, N))
|
tmp = _resize_cache(workspace2, (num, N))
|
||||||
|
|
||||||
if self.quant_config.is_quantized:
|
if self.quant_config.is_quantized:
|
||||||
assert a1q_scale is not None and w1_scale is not None
|
assert a1q_scale is not None and self.w1_scale is not None
|
||||||
input = self.dequant(hidden_states[expert, :, :],
|
input = self.dequant(hidden_states[expert, :, :],
|
||||||
a1q_scale[expert])
|
a1q_scale[expert])
|
||||||
w1_dq = self.dequant(w1[expert], w1_scale[expert])
|
w1_dq = self.dequant(w1[expert], self.w1_scale[expert])
|
||||||
input = input[:num] @ w1_dq.transpose(0, 1)
|
input = input[:num] @ w1_dq.transpose(0, 1)
|
||||||
else:
|
else:
|
||||||
input = hidden_states[expert, :num, :] @ w1[expert].transpose(
|
input = hidden_states[expert, :num, :] @ w1[expert].transpose(
|
||||||
@ -752,8 +729,8 @@ class NaiveBatchedExperts(mk.FusedMoEPermuteExpertsUnpermute):
|
|||||||
self.activation(activation, tmp, input.to(tmp.dtype))
|
self.activation(activation, tmp, input.to(tmp.dtype))
|
||||||
|
|
||||||
if self.quant_config.is_quantized:
|
if self.quant_config.is_quantized:
|
||||||
assert w2_scale is not None
|
assert self.w2_scale is not None
|
||||||
w2_dq = self.dequant(w2[expert], w2_scale[expert])
|
w2_dq = self.dequant(w2[expert], self.w2_scale[expert])
|
||||||
else:
|
else:
|
||||||
w2_dq = w2[expert]
|
w2_dq = w2[expert]
|
||||||
|
|
||||||
@ -840,35 +817,15 @@ class BatchedTritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
|
|||||||
self,
|
self,
|
||||||
max_num_tokens: int,
|
max_num_tokens: int,
|
||||||
num_dispatchers: int,
|
num_dispatchers: int,
|
||||||
use_fp8_w8a8: bool = False,
|
quant_config: FusedMoEQuantConfig,
|
||||||
use_int8_w8a8: bool = False,
|
|
||||||
use_int8_w8a16: bool = False,
|
|
||||||
use_int4_w4a16: bool = False,
|
|
||||||
use_mxfp4_w4a4: bool = False,
|
|
||||||
per_act_token_quant: bool = False,
|
|
||||||
block_shape: Optional[list[int]] = None,
|
|
||||||
):
|
):
|
||||||
super().__init__(
|
super().__init__(quant_config)
|
||||||
FusedMoEQuantConfig.make(
|
assert not self.quant_config.use_int8_w8a8, "NYI"
|
||||||
use_fp8_w8a8=use_fp8_w8a8,
|
assert not self.quant_config.use_int8_w8a16, "NYI"
|
||||||
use_int8_w8a8=use_int8_w8a8,
|
assert not self.quant_config.use_int4_w4a16, "NYI"
|
||||||
use_int8_w8a16=use_int8_w8a16,
|
assert not self.quant_config.use_mxfp4_w4a4, "NYI"
|
||||||
use_int4_w4a16=use_int4_w4a16,
|
|
||||||
use_mxfp4_w4a4=use_mxfp4_w4a4,
|
|
||||||
per_act_token_quant=per_act_token_quant,
|
|
||||||
block_shape=block_shape,
|
|
||||||
))
|
|
||||||
assert not use_int8_w8a8, "NYI"
|
|
||||||
assert not use_int8_w8a16, "NYI"
|
|
||||||
assert not use_int4_w4a16, "NYI"
|
|
||||||
assert not use_mxfp4_w4a4, "NYI"
|
|
||||||
assert max_num_tokens > 0
|
assert max_num_tokens > 0
|
||||||
assert num_dispatchers > 0
|
assert num_dispatchers > 0
|
||||||
self.use_fp8_w8a8 = use_fp8_w8a8
|
|
||||||
self.use_int8_w8a8 = use_int8_w8a8
|
|
||||||
self.use_int4_w4a16 = use_int4_w4a16
|
|
||||||
self.use_int8_w8a16 = use_int8_w8a16
|
|
||||||
self.use_mxfp4_w4a4 = use_mxfp4_w4a4
|
|
||||||
self.max_num_tokens = max_num_tokens
|
self.max_num_tokens = max_num_tokens
|
||||||
self.num_dispatchers = num_dispatchers
|
self.num_dispatchers = num_dispatchers
|
||||||
|
|
||||||
@ -921,19 +878,14 @@ class BatchedTritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
|
|||||||
activation: str,
|
activation: str,
|
||||||
global_num_experts: int,
|
global_num_experts: int,
|
||||||
expert_map: Optional[torch.Tensor],
|
expert_map: Optional[torch.Tensor],
|
||||||
w1_scale: Optional[torch.Tensor],
|
|
||||||
w2_scale: Optional[torch.Tensor],
|
|
||||||
w1_zp: Optional[torch.Tensor],
|
|
||||||
w2_zp: Optional[torch.Tensor],
|
|
||||||
a1q_scale: Optional[torch.Tensor],
|
a1q_scale: Optional[torch.Tensor],
|
||||||
a2_scale: Optional[torch.Tensor],
|
|
||||||
workspace13: torch.Tensor,
|
workspace13: torch.Tensor,
|
||||||
workspace2: torch.Tensor,
|
workspace2: torch.Tensor,
|
||||||
expert_tokens_meta: Optional[mk.ExpertTokensMetadata],
|
expert_tokens_meta: Optional[mk.ExpertTokensMetadata],
|
||||||
apply_router_weight_on_input: bool,
|
apply_router_weight_on_input: bool,
|
||||||
):
|
):
|
||||||
# Check constraints.
|
# Check constraints.
|
||||||
if self.use_int4_w4a16:
|
if self.quant_config.use_int4_w4a16:
|
||||||
assert hidden_states.size(-1) // 2 == w1.size(2), (
|
assert hidden_states.size(-1) // 2 == w1.size(2), (
|
||||||
"Hidden size mismatch")
|
"Hidden size mismatch")
|
||||||
else:
|
else:
|
||||||
@ -958,11 +910,7 @@ class BatchedTritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
|
|||||||
assert w1.size(0) == E
|
assert w1.size(0) == E
|
||||||
assert w2.size(0) == E
|
assert w2.size(0) == E
|
||||||
|
|
||||||
config_dtype = get_config_dtype_str(use_fp8_w8a8=self.use_fp8_w8a8,
|
config_dtype = self.quant_config.config_name(hidden_states.dtype)
|
||||||
use_int8_w8a16=self.use_int8_w8a16,
|
|
||||||
use_int4_w4a16=self.use_int4_w4a16,
|
|
||||||
use_mxfp4_w4a4=self.use_mxfp4_w4a4,
|
|
||||||
dtype=hidden_states.dtype)
|
|
||||||
|
|
||||||
config = try_get_optimal_moe_config(
|
config = try_get_optimal_moe_config(
|
||||||
w1.size(),
|
w1.size(),
|
||||||
@ -992,7 +940,8 @@ class BatchedTritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
|
|||||||
intermediate_cache2 = _resize_cache(workspace2,
|
intermediate_cache2 = _resize_cache(workspace2,
|
||||||
(E, max_num_tokens, N // 2))
|
(E, max_num_tokens, N // 2))
|
||||||
|
|
||||||
if self.use_fp8_w8a8:
|
# TODO(bnell): should this be done for any quantized type?
|
||||||
|
if self.quant_config.use_fp8_w8a8:
|
||||||
intermediate_cache1.fill_(0)
|
intermediate_cache1.fill_(0)
|
||||||
|
|
||||||
a1q_scale = normalize_batched_scales_shape(a1q_scale, E)
|
a1q_scale = normalize_batched_scales_shape(a1q_scale, E)
|
||||||
@ -1005,11 +954,11 @@ class BatchedTritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
|
|||||||
expert_num_tokens=expert_num_tokens,
|
expert_num_tokens=expert_num_tokens,
|
||||||
compute_type=compute_type,
|
compute_type=compute_type,
|
||||||
A_scale=a1q_scale,
|
A_scale=a1q_scale,
|
||||||
B_scale=w1_scale,
|
B_scale=self.w1_scale,
|
||||||
B_zp=w1_zp,
|
B_zp=self.w1_zp,
|
||||||
use_fp8_w8a8=self.use_fp8_w8a8,
|
use_fp8_w8a8=self.quant_config.use_fp8_w8a8,
|
||||||
use_int8_w8a16=self.use_int8_w8a16,
|
use_int8_w8a16=self.quant_config.use_int8_w8a16,
|
||||||
use_int4_w4a16=self.use_int4_w4a16,
|
use_int4_w4a16=self.quant_config.use_int4_w4a16,
|
||||||
config=config,
|
config=config,
|
||||||
per_act_token_quant=self.per_act_token_quant,
|
per_act_token_quant=self.per_act_token_quant,
|
||||||
block_shape=self.block_shape)
|
block_shape=self.block_shape)
|
||||||
@ -1021,7 +970,7 @@ class BatchedTritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
|
|||||||
intermediate_cache1.view(-1, N))
|
intermediate_cache1.view(-1, N))
|
||||||
|
|
||||||
qintermediate_cache2, a2q_scale = batched_moe_kernel_quantize_input(
|
qintermediate_cache2, a2q_scale = batched_moe_kernel_quantize_input(
|
||||||
intermediate_cache2, a2_scale, max_num_tokens, E, N,
|
intermediate_cache2, self.a2_scale, max_num_tokens, E, N,
|
||||||
expert_num_tokens, self.quant_dtype, self.per_act_token_quant,
|
expert_num_tokens, self.quant_dtype, self.per_act_token_quant,
|
||||||
self.block_shape)
|
self.block_shape)
|
||||||
|
|
||||||
@ -1032,11 +981,11 @@ class BatchedTritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
|
|||||||
expert_num_tokens=expert_num_tokens,
|
expert_num_tokens=expert_num_tokens,
|
||||||
compute_type=compute_type,
|
compute_type=compute_type,
|
||||||
A_scale=a2q_scale,
|
A_scale=a2q_scale,
|
||||||
B_scale=w2_scale,
|
B_scale=self.w2_scale,
|
||||||
B_zp=w2_zp,
|
B_zp=self.w2_zp,
|
||||||
use_fp8_w8a8=self.use_fp8_w8a8,
|
use_fp8_w8a8=self.quant_config.use_fp8_w8a8,
|
||||||
use_int8_w8a16=self.use_int8_w8a16,
|
use_int8_w8a16=self.quant_config.use_int8_w8a16,
|
||||||
use_int4_w4a16=self.use_int4_w4a16,
|
use_int4_w4a16=self.quant_config.use_int4_w4a16,
|
||||||
config=config,
|
config=config,
|
||||||
per_act_token_quant=self.per_act_token_quant,
|
per_act_token_quant=self.per_act_token_quant,
|
||||||
block_shape=self.block_shape)
|
block_shape=self.block_shape)
|
||||||
|
|||||||
@ -1,13 +1,13 @@
|
|||||||
# SPDX-License-Identifier: Apache-2.0
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
"""Fused MoE kernel."""
|
"""Fused MoE Triton kernels."""
|
||||||
import functools
|
import functools
|
||||||
import json
|
import json
|
||||||
import os
|
import os
|
||||||
# torch.compile needs typing.List. It will fail torch.library.infer_schema
|
# torch.compile needs typing.List. It will fail torch.library.infer_schema
|
||||||
# otherwise
|
# otherwise
|
||||||
from typing import List # noqa: UP035
|
from typing import List # noqa: UP035
|
||||||
from typing import Any, Callable, Optional
|
from typing import Any, Callable, Optional, Union
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
@ -18,7 +18,7 @@ from vllm import _custom_ops as ops
|
|||||||
from vllm.logger import init_logger
|
from vllm.logger import init_logger
|
||||||
# yapf: disable
|
# yapf: disable
|
||||||
from vllm.model_executor.layers.fused_moe.config import (
|
from vllm.model_executor.layers.fused_moe.config import (
|
||||||
FusedMoEQuantConfig, get_config_quant_dtype)
|
FUSED_MOE_UNQUANTIZED_CONFIG, FusedMoEQuantConfig, _get_config_dtype_str)
|
||||||
from vllm.model_executor.layers.fused_moe.cutlass_moe import (
|
from vllm.model_executor.layers.fused_moe.cutlass_moe import (
|
||||||
_valid_cutlass_block_scaled_grouped_gemm,
|
_valid_cutlass_block_scaled_grouped_gemm,
|
||||||
run_cutlass_block_scaled_fused_experts)
|
run_cutlass_block_scaled_fused_experts)
|
||||||
@ -32,11 +32,7 @@ from vllm.model_executor.layers.fused_moe.prepare_finalize import (
|
|||||||
from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import (
|
from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import (
|
||||||
TopKWeightAndReduceNoOP)
|
TopKWeightAndReduceNoOP)
|
||||||
from vllm.model_executor.layers.fused_moe.utils import (
|
from vllm.model_executor.layers.fused_moe.utils import (
|
||||||
_resize_cache, moe_kernel_quantize_input)
|
_resize_cache, activation_without_mul, moe_kernel_quantize_input)
|
||||||
from vllm.model_executor.layers.quantization.utils.flashinfer_utils import (
|
|
||||||
calculate_tile_tokens_dim)
|
|
||||||
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
|
|
||||||
per_token_group_quant_fp8)
|
|
||||||
from vllm.model_executor.layers.quantization.utils.mxfp4_utils import (
|
from vllm.model_executor.layers.quantization.utils.mxfp4_utils import (
|
||||||
dequant_mxfp4)
|
dequant_mxfp4)
|
||||||
from vllm.platforms import current_platform
|
from vllm.platforms import current_platform
|
||||||
@ -1049,87 +1045,66 @@ def fused_grouped_topk(
|
|||||||
return topk_values.to(torch.float32), topk_indices.to(torch.int32)
|
return topk_values.to(torch.float32), topk_indices.to(torch.int32)
|
||||||
|
|
||||||
|
|
||||||
def get_config_dtype_str(
|
|
||||||
dtype: torch.dtype,
|
|
||||||
use_int4_w4a16: Optional[bool] = False,
|
|
||||||
use_int8_w8a16: Optional[bool] = False,
|
|
||||||
use_fp8_w8a8: Optional[bool] = False,
|
|
||||||
use_mxfp4_w4a4: Optional[bool] = False) -> Optional[str]:
|
|
||||||
if use_fp8_w8a8:
|
|
||||||
return "fp8_w8a8"
|
|
||||||
elif use_int8_w8a16:
|
|
||||||
return "int8_w8a16"
|
|
||||||
elif use_int4_w4a16:
|
|
||||||
return "int4_w4a16"
|
|
||||||
elif use_mxfp4_w4a4:
|
|
||||||
return "mxfp4_w4a4"
|
|
||||||
elif dtype == torch.float:
|
|
||||||
# avoiding cases where kernel fails when float32 MoE
|
|
||||||
# use fp16/bfloat16 configs
|
|
||||||
return "float32"
|
|
||||||
return None
|
|
||||||
|
|
||||||
|
|
||||||
def inplace_fused_experts(
|
def inplace_fused_experts(
|
||||||
hidden_states: torch.Tensor,
|
hidden_states: torch.Tensor,
|
||||||
w1: torch.Tensor,
|
w1: torch.Tensor,
|
||||||
w2: torch.Tensor,
|
w2: torch.Tensor,
|
||||||
topk_weights: torch.Tensor,
|
topk_weights: torch.Tensor,
|
||||||
topk_ids: torch.Tensor,
|
topk_ids: torch.Tensor,
|
||||||
activation: str = "silu",
|
activation: str = "silu",
|
||||||
is_act_and_mul: bool = True,
|
apply_router_weight_on_input: bool = False,
|
||||||
apply_router_weight_on_input: bool = False,
|
use_fp8_w8a8: bool = False,
|
||||||
use_fp8_w8a8: bool = False,
|
use_int8_w8a8: bool = False,
|
||||||
use_int8_w8a8: bool = False,
|
use_int8_w8a16: bool = False,
|
||||||
use_int8_w8a16: bool = False,
|
use_int4_w4a16: bool = False,
|
||||||
use_int4_w4a16: bool = False,
|
use_mxfp4_w4a4: bool = False,
|
||||||
use_mxfp4_w4a4: bool = False,
|
per_channel_quant: bool = False,
|
||||||
per_channel_quant: bool = False,
|
global_num_experts: int = -1,
|
||||||
global_num_experts: int = -1,
|
expert_map: Optional[torch.Tensor] = None,
|
||||||
expert_map: Optional[torch.Tensor] = None,
|
w1_scale: Optional[torch.Tensor] = None,
|
||||||
w1_scale: Optional[torch.Tensor] = None,
|
w2_scale: Optional[torch.Tensor] = None,
|
||||||
w2_scale: Optional[torch.Tensor] = None,
|
w1_zp: Optional[torch.Tensor] = None,
|
||||||
w1_zp: Optional[torch.Tensor] = None,
|
w2_zp: Optional[torch.Tensor] = None,
|
||||||
w2_zp: Optional[torch.Tensor] = None,
|
a1_scale: Optional[torch.Tensor] = None,
|
||||||
a1_scale: Optional[torch.Tensor] = None,
|
a2_scale: Optional[torch.Tensor] = None,
|
||||||
a2_scale: Optional[torch.Tensor] = None,
|
block_shape: Optional[List[int]] = None, #noqa: UP006
|
||||||
block_shape: Optional[List[int]] = None, #noqa: UP006
|
w1_bias: Optional[torch.Tensor] = None,
|
||||||
w1_bias: Optional[torch.Tensor] = None,
|
w2_bias: Optional[torch.Tensor] = None,
|
||||||
w2_bias: Optional[torch.Tensor] = None) -> None:
|
) -> None:
|
||||||
fused_experts_impl(hidden_states, w1, w2, topk_weights, topk_ids, True,
|
fused_experts_impl(hidden_states, w1, w2, topk_weights, topk_ids, True,
|
||||||
activation, is_act_and_mul,
|
activation, apply_router_weight_on_input, use_fp8_w8a8,
|
||||||
apply_router_weight_on_input, use_fp8_w8a8,
|
|
||||||
use_int8_w8a8, use_int8_w8a16, use_int4_w4a16,
|
use_int8_w8a8, use_int8_w8a16, use_int4_w4a16,
|
||||||
use_mxfp4_w4a4, per_channel_quant, global_num_experts,
|
use_mxfp4_w4a4, per_channel_quant, global_num_experts,
|
||||||
expert_map, w1_scale, w2_scale, w1_zp, w2_zp, a1_scale,
|
expert_map, w1_scale, w2_scale, w1_zp, w2_zp, a1_scale,
|
||||||
a2_scale, block_shape, w1_bias, w2_bias)
|
a2_scale, block_shape, w1_bias, w2_bias)
|
||||||
|
|
||||||
|
|
||||||
def inplace_fused_experts_fake(hidden_states: torch.Tensor,
|
def inplace_fused_experts_fake(
|
||||||
w1: torch.Tensor,
|
hidden_states: torch.Tensor,
|
||||||
w2: torch.Tensor,
|
w1: torch.Tensor,
|
||||||
topk_weights: torch.Tensor,
|
w2: torch.Tensor,
|
||||||
topk_ids: torch.Tensor,
|
topk_weights: torch.Tensor,
|
||||||
activation: str = "silu",
|
topk_ids: torch.Tensor,
|
||||||
is_act_and_mul: bool = True,
|
activation: str = "silu",
|
||||||
apply_router_weight_on_input: bool = False,
|
apply_router_weight_on_input: bool = False,
|
||||||
use_fp8_w8a8: bool = False,
|
use_fp8_w8a8: bool = False,
|
||||||
use_int8_w8a8: bool = False,
|
use_int8_w8a8: bool = False,
|
||||||
use_int8_w8a16: bool = False,
|
use_int8_w8a16: bool = False,
|
||||||
use_int4_w4a16: bool = False,
|
use_int4_w4a16: bool = False,
|
||||||
use_mxfp4_w4a4: bool = False,
|
use_mxfp4_w4a4: bool = False,
|
||||||
per_channel_quant: bool = False,
|
per_channel_quant: bool = False,
|
||||||
global_num_experts: int = -1,
|
global_num_experts: int = -1,
|
||||||
expert_map: Optional[torch.Tensor] = None,
|
expert_map: Optional[torch.Tensor] = None,
|
||||||
w1_scale: Optional[torch.Tensor] = None,
|
w1_scale: Optional[torch.Tensor] = None,
|
||||||
w2_scale: Optional[torch.Tensor] = None,
|
w2_scale: Optional[torch.Tensor] = None,
|
||||||
w1_zp: Optional[torch.Tensor] = None,
|
w1_zp: Optional[torch.Tensor] = None,
|
||||||
w2_zp: Optional[torch.Tensor] = None,
|
w2_zp: Optional[torch.Tensor] = None,
|
||||||
a1_scale: Optional[torch.Tensor] = None,
|
a1_scale: Optional[torch.Tensor] = None,
|
||||||
a2_scale: Optional[torch.Tensor] = None,
|
a2_scale: Optional[torch.Tensor] = None,
|
||||||
block_shape: Optional[list[int]] = None,
|
block_shape: Optional[List[int]] = None, #noqa: UP006
|
||||||
w1_bias: Optional[torch.Tensor] = None,
|
w1_bias: Optional[torch.Tensor] = None,
|
||||||
w2_bias: Optional[torch.Tensor] = None) -> None:
|
w2_bias: Optional[torch.Tensor] = None,
|
||||||
|
) -> None:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
@ -1143,175 +1118,6 @@ direct_register_custom_op(
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def flashinfer_fused_moe_blockscale_fp8(
|
|
||||||
routing_logits: torch.Tensor,
|
|
||||||
routing_bias: torch.Tensor,
|
|
||||||
x: torch.Tensor,
|
|
||||||
w13_weight: torch.Tensor,
|
|
||||||
w13_weight_scale_inv: torch.Tensor,
|
|
||||||
w2_weight: torch.Tensor,
|
|
||||||
w2_weight_scale_inv: torch.Tensor,
|
|
||||||
global_num_experts: int,
|
|
||||||
top_k: int,
|
|
||||||
num_expert_group: int,
|
|
||||||
topk_group: int,
|
|
||||||
intermediate_size: int,
|
|
||||||
expert_offset: int,
|
|
||||||
local_num_experts: int,
|
|
||||||
block_shape: List[int], #noqa: UP006
|
|
||||||
routed_scaling: float = 1.0) -> torch.Tensor:
|
|
||||||
from vllm.utils.flashinfer import flashinfer_trtllm_fp8_block_scale_moe
|
|
||||||
assert top_k <= global_num_experts
|
|
||||||
assert top_k <= 8
|
|
||||||
assert topk_group <= 4
|
|
||||||
assert global_num_experts > num_expert_group
|
|
||||||
assert global_num_experts % num_expert_group == 0
|
|
||||||
assert global_num_experts % 4 == 0
|
|
||||||
assert top_k < (topk_group * global_num_experts / num_expert_group)
|
|
||||||
assert block_shape == [128, 128]
|
|
||||||
|
|
||||||
a_q, a_sf = per_token_group_quant_fp8(x, block_shape[1])
|
|
||||||
# NOTE: scales of hidden states have to be transposed!
|
|
||||||
a_sf_t = a_sf.t().contiguous()
|
|
||||||
return flashinfer_trtllm_fp8_block_scale_moe(
|
|
||||||
routing_logits=routing_logits,
|
|
||||||
routing_bias=routing_bias,
|
|
||||||
hidden_states=a_q,
|
|
||||||
hidden_states_scale=a_sf_t,
|
|
||||||
gemm1_weights=w13_weight,
|
|
||||||
gemm1_weights_scale=w13_weight_scale_inv,
|
|
||||||
gemm2_weights=w2_weight,
|
|
||||||
gemm2_weights_scale=w2_weight_scale_inv,
|
|
||||||
num_experts=global_num_experts,
|
|
||||||
top_k=top_k,
|
|
||||||
n_group=num_expert_group,
|
|
||||||
topk_group=topk_group,
|
|
||||||
intermediate_size=intermediate_size,
|
|
||||||
local_expert_offset=expert_offset,
|
|
||||||
local_num_experts=local_num_experts,
|
|
||||||
routed_scaling_factor=routed_scaling,
|
|
||||||
tile_tokens_dim=calculate_tile_tokens_dim(x.shape[0], top_k,
|
|
||||||
global_num_experts),
|
|
||||||
routing_method_type=2, # DeepSeek-styled routing method
|
|
||||||
use_shuffled_weight=False,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def flashinfer_fused_moe_blockscale_fp8_fake(
|
|
||||||
routing_logits: torch.Tensor,
|
|
||||||
routing_bias: torch.Tensor,
|
|
||||||
x: torch.Tensor,
|
|
||||||
w13_weight: torch.Tensor,
|
|
||||||
w13_weight_scale_inv: torch.Tensor,
|
|
||||||
w2_weight: torch.Tensor,
|
|
||||||
w2_weight_scale_inv: torch.Tensor,
|
|
||||||
global_num_experts: int,
|
|
||||||
top_k: int,
|
|
||||||
num_expert_group: int,
|
|
||||||
topk_group: int,
|
|
||||||
intermediate_size: int,
|
|
||||||
expert_offset: int,
|
|
||||||
local_num_experts: int,
|
|
||||||
block_shape: list[int],
|
|
||||||
routed_scaling: float = 1.0) -> torch.Tensor:
|
|
||||||
return torch.empty_like(x)
|
|
||||||
|
|
||||||
|
|
||||||
direct_register_custom_op(
|
|
||||||
op_name="flashinfer_fused_moe_blockscale_fp8",
|
|
||||||
op_func=flashinfer_fused_moe_blockscale_fp8,
|
|
||||||
mutates_args=[],
|
|
||||||
fake_impl=flashinfer_fused_moe_blockscale_fp8_fake,
|
|
||||||
tags=(torch.Tag.needs_fixed_stride_order, ),
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def flashinfer_fused_moe_per_tensor_scale_fp8(
|
|
||||||
routing_logits: torch.Tensor,
|
|
||||||
routing_bias: Optional[torch.Tensor],
|
|
||||||
hidden_states: torch.Tensor,
|
|
||||||
input_scale: torch.Tensor,
|
|
||||||
gemm1_weights: torch.Tensor,
|
|
||||||
gemm2_weights: torch.Tensor,
|
|
||||||
output1_scales_scalar: torch.Tensor,
|
|
||||||
output1_scales_gate_scalar: torch.Tensor,
|
|
||||||
output2_scales_scalar: torch.Tensor,
|
|
||||||
num_experts: int,
|
|
||||||
top_k: int,
|
|
||||||
num_expert_group: Optional[int],
|
|
||||||
topk_group: Optional[int],
|
|
||||||
intermediate_size: int,
|
|
||||||
local_expert_offset: int,
|
|
||||||
local_num_experts: int,
|
|
||||||
use_routing_scales_on_input: bool,
|
|
||||||
routing_method_type: int,
|
|
||||||
routed_scaling_factor: float = 1.0) -> torch.Tensor:
|
|
||||||
num_expert_group = num_expert_group if num_expert_group is not None else 0
|
|
||||||
topk_group = topk_group if topk_group is not None else 0
|
|
||||||
|
|
||||||
quant_hidden_states, _ = moe_kernel_quantize_input(
|
|
||||||
hidden_states,
|
|
||||||
input_scale,
|
|
||||||
quant_dtype=torch.float8_e4m3fn,
|
|
||||||
per_act_token_quant=False)
|
|
||||||
|
|
||||||
from vllm.utils.flashinfer import (
|
|
||||||
flashinfer_trtllm_fp8_per_tensor_scale_moe)
|
|
||||||
return flashinfer_trtllm_fp8_per_tensor_scale_moe(
|
|
||||||
routing_logits=routing_logits,
|
|
||||||
routing_bias=routing_bias,
|
|
||||||
hidden_states=quant_hidden_states,
|
|
||||||
gemm1_weights=gemm1_weights,
|
|
||||||
output1_scales_scalar=output1_scales_scalar,
|
|
||||||
output1_scales_gate_scalar=output1_scales_gate_scalar,
|
|
||||||
gemm2_weights=gemm2_weights,
|
|
||||||
output2_scales_scalar=output2_scales_scalar,
|
|
||||||
num_experts=num_experts,
|
|
||||||
top_k=top_k,
|
|
||||||
n_group=num_expert_group,
|
|
||||||
topk_group=topk_group,
|
|
||||||
intermediate_size=intermediate_size,
|
|
||||||
local_expert_offset=local_expert_offset,
|
|
||||||
local_num_experts=local_num_experts,
|
|
||||||
routed_scaling_factor=routed_scaling_factor,
|
|
||||||
use_routing_scales_on_input=use_routing_scales_on_input,
|
|
||||||
tile_tokens_dim=calculate_tile_tokens_dim(hidden_states.shape[0],
|
|
||||||
top_k, num_experts),
|
|
||||||
routing_method_type=routing_method_type)
|
|
||||||
|
|
||||||
|
|
||||||
def flashinfer_fused_moe_per_tensor_scale_fp8_fake(
|
|
||||||
routing_logits: torch.Tensor,
|
|
||||||
routing_bias: Optional[torch.Tensor],
|
|
||||||
hidden_states: torch.Tensor,
|
|
||||||
input_scale: torch.Tensor,
|
|
||||||
gemm1_weights: torch.Tensor,
|
|
||||||
gemm2_weights: torch.Tensor,
|
|
||||||
output1_scales_scalar: torch.Tensor,
|
|
||||||
output1_scales_gate_scalar: torch.Tensor,
|
|
||||||
output2_scales_scalar: torch.Tensor,
|
|
||||||
num_experts: int,
|
|
||||||
top_k: int,
|
|
||||||
num_expert_group: Optional[int],
|
|
||||||
topk_group: Optional[int],
|
|
||||||
intermediate_size: int,
|
|
||||||
local_expert_offset: int,
|
|
||||||
local_num_experts: int,
|
|
||||||
use_routing_scales_on_input: bool,
|
|
||||||
routing_method_type: int,
|
|
||||||
routed_scaling_factor: float = 1.0) -> torch.Tensor:
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
direct_register_custom_op(
|
|
||||||
op_name="flashinfer_fused_moe_per_tensor_scale_fp8",
|
|
||||||
op_func=flashinfer_fused_moe_per_tensor_scale_fp8,
|
|
||||||
mutates_args=["hidden_states"],
|
|
||||||
fake_impl=flashinfer_fused_moe_per_tensor_scale_fp8_fake,
|
|
||||||
tags=(torch.Tag.needs_fixed_stride_order, ),
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def outplace_fused_experts(
|
def outplace_fused_experts(
|
||||||
hidden_states: torch.Tensor,
|
hidden_states: torch.Tensor,
|
||||||
w1: torch.Tensor,
|
w1: torch.Tensor,
|
||||||
@ -1319,7 +1125,6 @@ def outplace_fused_experts(
|
|||||||
topk_weights: torch.Tensor,
|
topk_weights: torch.Tensor,
|
||||||
topk_ids: torch.Tensor,
|
topk_ids: torch.Tensor,
|
||||||
activation: str = "silu",
|
activation: str = "silu",
|
||||||
is_act_and_mul: bool = True,
|
|
||||||
apply_router_weight_on_input: bool = False,
|
apply_router_weight_on_input: bool = False,
|
||||||
use_fp8_w8a8: bool = False,
|
use_fp8_w8a8: bool = False,
|
||||||
use_int8_w8a8: bool = False,
|
use_int8_w8a8: bool = False,
|
||||||
@ -1341,37 +1146,37 @@ def outplace_fused_experts(
|
|||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
return fused_experts_impl(
|
return fused_experts_impl(
|
||||||
hidden_states, w1, w2, topk_weights, topk_ids, False, activation,
|
hidden_states, w1, w2, topk_weights, topk_ids, False, activation,
|
||||||
is_act_and_mul, apply_router_weight_on_input, use_fp8_w8a8,
|
apply_router_weight_on_input, use_fp8_w8a8, use_int8_w8a8,
|
||||||
use_int8_w8a8, use_int8_w8a16, use_int4_w4a16, use_mxfp4_w4a4,
|
use_int8_w8a16, use_int4_w4a16, use_mxfp4_w4a4, per_channel_quant,
|
||||||
per_channel_quant, global_num_experts, expert_map, w1_scale, w2_scale,
|
global_num_experts, expert_map, w1_scale, w2_scale, w1_zp, w2_zp,
|
||||||
w1_zp, w2_zp, a1_scale, a2_scale, block_shape, w1_bias, w2_bias)
|
a1_scale, a2_scale, block_shape, w1_bias, w2_bias)
|
||||||
|
|
||||||
|
|
||||||
def outplace_fused_experts_fake(
|
def outplace_fused_experts_fake(
|
||||||
hidden_states: torch.Tensor,
|
hidden_states: torch.Tensor,
|
||||||
w1: torch.Tensor,
|
w1: torch.Tensor,
|
||||||
w2: torch.Tensor,
|
w2: torch.Tensor,
|
||||||
topk_weights: torch.Tensor,
|
topk_weights: torch.Tensor,
|
||||||
topk_ids: torch.Tensor,
|
topk_ids: torch.Tensor,
|
||||||
activation: str = "silu",
|
activation: str = "silu",
|
||||||
is_act_and_mul: bool = True,
|
use_fp8_w8a8: bool = False,
|
||||||
use_fp8_w8a8: bool = False,
|
use_int8_w8a8: bool = False,
|
||||||
use_int8_w8a8: bool = False,
|
use_int8_w8a16: bool = False,
|
||||||
use_int8_w8a16: bool = False,
|
use_int4_w4a16: bool = False,
|
||||||
use_int4_w4a16: bool = False,
|
use_mxfp4_w4a4: bool = False,
|
||||||
use_mxfp4_w4a4: bool = False,
|
per_channel_quant: bool = False,
|
||||||
per_channel_quant: bool = False,
|
global_num_experts: int = -1,
|
||||||
global_num_experts: int = -1,
|
expert_map: Optional[torch.Tensor] = None,
|
||||||
expert_map: Optional[torch.Tensor] = None,
|
w1_scale: Optional[torch.Tensor] = None,
|
||||||
w1_scale: Optional[torch.Tensor] = None,
|
w2_scale: Optional[torch.Tensor] = None,
|
||||||
w2_scale: Optional[torch.Tensor] = None,
|
w1_zp: Optional[torch.Tensor] = None,
|
||||||
w1_zp: Optional[torch.Tensor] = None,
|
w2_zp: Optional[torch.Tensor] = None,
|
||||||
w2_zp: Optional[torch.Tensor] = None,
|
a1_scale: Optional[torch.Tensor] = None,
|
||||||
a1_scale: Optional[torch.Tensor] = None,
|
a2_scale: Optional[torch.Tensor] = None,
|
||||||
a2_scale: Optional[torch.Tensor] = None,
|
block_shape: Optional[list[int]] = None,
|
||||||
block_shape: Optional[list[int]] = None,
|
w1_bias: Optional[torch.Tensor] = None,
|
||||||
w1_bias: Optional[torch.Tensor] = None,
|
w2_bias: Optional[torch.Tensor] = None,
|
||||||
w2_bias: Optional[torch.Tensor] = None) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
return torch.empty_like(hidden_states)
|
return torch.empty_like(hidden_states)
|
||||||
|
|
||||||
|
|
||||||
@ -1403,45 +1208,36 @@ def dispatch_fused_experts_func(inplace: bool) -> Callable[..., torch.Tensor]:
|
|||||||
|
|
||||||
# TODO (bnell): replace this with modular op. Can get rid of inplace/outplace
|
# TODO (bnell): replace this with modular op. Can get rid of inplace/outplace
|
||||||
# torch ops.
|
# torch ops.
|
||||||
def fused_experts(hidden_states: torch.Tensor,
|
def fused_experts(
|
||||||
w1: torch.Tensor,
|
hidden_states: torch.Tensor,
|
||||||
w2: torch.Tensor,
|
w1: torch.Tensor,
|
||||||
topk_weights: torch.Tensor,
|
w2: torch.Tensor,
|
||||||
topk_ids: torch.Tensor,
|
topk_weights: torch.Tensor,
|
||||||
inplace: bool = False,
|
topk_ids: torch.Tensor,
|
||||||
activation: str = "silu",
|
inplace: bool = False,
|
||||||
is_act_and_mul: bool = True,
|
activation: str = "silu",
|
||||||
apply_router_weight_on_input: bool = False,
|
apply_router_weight_on_input: bool = False,
|
||||||
use_fp8_w8a8: bool = False,
|
global_num_experts: int = -1,
|
||||||
use_int8_w8a8: bool = False,
|
expert_map: Optional[torch.Tensor] = None,
|
||||||
use_int8_w8a16: bool = False,
|
quant_config: Optional[FusedMoEQuantConfig] = None,
|
||||||
use_int4_w4a16: bool = False,
|
allow_deep_gemm: bool = False,
|
||||||
use_mxfp4_w4a4: bool = False,
|
allow_cutlass_block_scaled_grouped_gemm: bool = False,
|
||||||
per_channel_quant: bool = False,
|
) -> torch.Tensor:
|
||||||
global_num_experts: int = -1,
|
|
||||||
expert_map: Optional[torch.Tensor] = None,
|
if quant_config is None:
|
||||||
w1_scale: Optional[torch.Tensor] = None,
|
quant_config = FUSED_MOE_UNQUANTIZED_CONFIG
|
||||||
w2_scale: Optional[torch.Tensor] = None,
|
use_fp8_w8a8 = quant_config.use_fp8_w8a8
|
||||||
w1_zp: Optional[torch.Tensor] = None,
|
|
||||||
w2_zp: Optional[torch.Tensor] = None,
|
|
||||||
a1_scale: Optional[torch.Tensor] = None,
|
|
||||||
a2_scale: Optional[torch.Tensor] = None,
|
|
||||||
block_shape: Optional[list[int]] = None,
|
|
||||||
allow_deep_gemm: bool = False,
|
|
||||||
allow_cutlass_block_scaled_grouped_gemm: bool = False,
|
|
||||||
w1_bias: Optional[torch.Tensor] = None,
|
|
||||||
w2_bias: Optional[torch.Tensor] = None) -> torch.Tensor:
|
|
||||||
# For now, disable DeepGemm for small N (<= 512) until better
|
# For now, disable DeepGemm for small N (<= 512) until better
|
||||||
# permute/unpermute ops are available.
|
# permute/unpermute ops are available.
|
||||||
# However, on B200, we use DeepGemm for all cases because they only support
|
# However, on B200, we use DeepGemm for all cases because they only support
|
||||||
# E8M0 scale, which means we requantize the weight and input to the specific
|
# E8M0 scale, which means we requantize the weight and input to the specific
|
||||||
# scale. Fallen back to cutlass or triton for some cases would cause
|
# scale. Fallen back to cutlass or triton for some cases would cause
|
||||||
# accuracy issue.
|
# accuracy issue.
|
||||||
if (allow_deep_gemm and use_fp8_w8a8 and
|
if (allow_deep_gemm and quant_config.use_fp8_w8a8 and
|
||||||
(is_deep_gemm_e8m0_used() or _valid_deep_gemm(hidden_states, w1, w2))):
|
(is_deep_gemm_e8m0_used() or _valid_deep_gemm(hidden_states, w1, w2))):
|
||||||
|
assert quant_config is not None
|
||||||
assert apply_router_weight_on_input is False
|
assert apply_router_weight_on_input is False
|
||||||
assert is_act_and_mul, (
|
|
||||||
"DeepGemm only supports is_act_and_mul=True for now.")
|
|
||||||
return deep_gemm_moe_fp8(
|
return deep_gemm_moe_fp8(
|
||||||
hidden_states=hidden_states,
|
hidden_states=hidden_states,
|
||||||
w1=w1,
|
w1=w1,
|
||||||
@ -1452,22 +1248,23 @@ def fused_experts(hidden_states: torch.Tensor,
|
|||||||
activation=activation,
|
activation=activation,
|
||||||
global_num_experts=global_num_experts,
|
global_num_experts=global_num_experts,
|
||||||
expert_map=expert_map,
|
expert_map=expert_map,
|
||||||
w1_scale=w1_scale,
|
w1_scale=quant_config.w1_scale,
|
||||||
w2_scale=w2_scale,
|
w2_scale=quant_config.w2_scale,
|
||||||
a1_scale=a1_scale,
|
a1_scale=quant_config.a1_scale,
|
||||||
a2_scale=a2_scale,
|
a2_scale=quant_config.a2_scale,
|
||||||
apply_router_weight_on_input=apply_router_weight_on_input,
|
apply_router_weight_on_input=apply_router_weight_on_input,
|
||||||
)
|
)
|
||||||
elif (allow_cutlass_block_scaled_grouped_gemm and use_fp8_w8a8
|
elif (allow_cutlass_block_scaled_grouped_gemm and use_fp8_w8a8
|
||||||
and _valid_cutlass_block_scaled_grouped_gemm(
|
and _valid_cutlass_block_scaled_grouped_gemm(
|
||||||
w1, w2, inplace, activation, apply_router_weight_on_input,
|
w1, w2, inplace, activation, apply_router_weight_on_input,
|
||||||
expert_map)):
|
expert_map)):
|
||||||
|
assert quant_config is not None
|
||||||
return run_cutlass_block_scaled_fused_experts(
|
return run_cutlass_block_scaled_fused_experts(
|
||||||
a=hidden_states,
|
a=hidden_states,
|
||||||
w1=w1,
|
w1=w1,
|
||||||
w2=w2,
|
w2=w2,
|
||||||
w1_scale=w1_scale,
|
w1_scale=quant_config.w1_scale,
|
||||||
w2_scale=w2_scale,
|
w2_scale=quant_config.w2_scale,
|
||||||
topk_weights=topk_weights,
|
topk_weights=topk_weights,
|
||||||
topk_ids=topk_ids)
|
topk_ids=topk_ids)
|
||||||
else:
|
else:
|
||||||
@ -1478,26 +1275,49 @@ def fused_experts(hidden_states: torch.Tensor,
|
|||||||
topk_weights=topk_weights,
|
topk_weights=topk_weights,
|
||||||
topk_ids=topk_ids,
|
topk_ids=topk_ids,
|
||||||
activation=activation,
|
activation=activation,
|
||||||
is_act_and_mul=is_act_and_mul,
|
|
||||||
apply_router_weight_on_input=apply_router_weight_on_input,
|
apply_router_weight_on_input=apply_router_weight_on_input,
|
||||||
use_fp8_w8a8=use_fp8_w8a8,
|
use_fp8_w8a8=quant_config.use_fp8_w8a8,
|
||||||
use_int8_w8a8=use_int8_w8a8,
|
use_int8_w8a8=quant_config.use_int8_w8a8,
|
||||||
use_int8_w8a16=use_int8_w8a16,
|
use_int8_w8a16=quant_config.use_int8_w8a16,
|
||||||
use_int4_w4a16=use_int4_w4a16,
|
use_int4_w4a16=quant_config.use_int4_w4a16,
|
||||||
use_mxfp4_w4a4=use_mxfp4_w4a4,
|
use_mxfp4_w4a4=quant_config.use_mxfp4_w4a4,
|
||||||
per_channel_quant=per_channel_quant,
|
per_channel_quant=quant_config.per_act_token_quant,
|
||||||
global_num_experts=global_num_experts,
|
global_num_experts=global_num_experts,
|
||||||
expert_map=expert_map,
|
expert_map=expert_map,
|
||||||
w1_scale=w1_scale,
|
w1_scale=quant_config.w1_scale,
|
||||||
w2_scale=w2_scale,
|
w2_scale=quant_config.w2_scale,
|
||||||
w1_zp=w1_zp,
|
w1_zp=quant_config.w1_zp,
|
||||||
w2_zp=w2_zp,
|
w2_zp=quant_config.w2_zp,
|
||||||
a1_scale=a1_scale,
|
a1_scale=quant_config.a1_scale,
|
||||||
a2_scale=a2_scale,
|
a2_scale=quant_config.a2_scale,
|
||||||
block_shape=block_shape,
|
block_shape=quant_config.block_shape,
|
||||||
w1_bias=w1_bias,
|
w1_bias=quant_config.w1_bias,
|
||||||
w2_bias=w2_bias,
|
w2_bias=quant_config.w2_bias)
|
||||||
)
|
|
||||||
|
|
||||||
|
SILU_NO_MUL: str = activation_without_mul("silu")
|
||||||
|
GELU_NO_MUL: str = activation_without_mul("gelu")
|
||||||
|
|
||||||
|
|
||||||
|
def _get_config_quant_dtype(
|
||||||
|
use_fp8_w8a8: bool,
|
||||||
|
use_int8_w8a8: bool,
|
||||||
|
use_mxfp4_w4a4: bool,
|
||||||
|
) -> Union[None, torch.dtype, str]:
|
||||||
|
"""
|
||||||
|
Get the quantization type based on the quantization strategy flags.
|
||||||
|
We don't have a quant_config at this point so we need to work backwards.
|
||||||
|
A return type of None means no quantization is required because the
|
||||||
|
input is unquantized or has been quantized prior to calling
|
||||||
|
fused_experts_impl.
|
||||||
|
"""
|
||||||
|
if use_fp8_w8a8:
|
||||||
|
return torch.float8_e4m3fn
|
||||||
|
elif use_int8_w8a8:
|
||||||
|
return torch.int8
|
||||||
|
elif use_mxfp4_w4a4:
|
||||||
|
return "mxfp4"
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
def fused_experts_impl(
|
def fused_experts_impl(
|
||||||
@ -1508,7 +1328,6 @@ def fused_experts_impl(
|
|||||||
topk_ids: torch.Tensor,
|
topk_ids: torch.Tensor,
|
||||||
inplace: bool = False,
|
inplace: bool = False,
|
||||||
activation: str = "silu",
|
activation: str = "silu",
|
||||||
is_act_and_mul: bool = True,
|
|
||||||
apply_router_weight_on_input: bool = False,
|
apply_router_weight_on_input: bool = False,
|
||||||
use_fp8_w8a8: bool = False,
|
use_fp8_w8a8: bool = False,
|
||||||
use_int8_w8a8: bool = False,
|
use_int8_w8a8: bool = False,
|
||||||
@ -1557,17 +1376,18 @@ def fused_experts_impl(
|
|||||||
# https://github.com/vllm-project/vllm/issues/5938
|
# https://github.com/vllm-project/vllm/issues/5938
|
||||||
CHUNK_SIZE = envs.VLLM_FUSED_MOE_CHUNK_SIZE
|
CHUNK_SIZE = envs.VLLM_FUSED_MOE_CHUNK_SIZE
|
||||||
M = min(num_tokens, CHUNK_SIZE)
|
M = min(num_tokens, CHUNK_SIZE)
|
||||||
config_dtype = get_config_dtype_str(use_fp8_w8a8=use_fp8_w8a8,
|
|
||||||
use_int8_w8a16=use_int8_w8a16,
|
|
||||||
use_int4_w4a16=use_int4_w4a16,
|
|
||||||
use_mxfp4_w4a4=use_mxfp4_w4a4,
|
|
||||||
dtype=hidden_states.dtype)
|
|
||||||
|
|
||||||
qtype = get_config_quant_dtype(use_fp8_w8a8=use_fp8_w8a8,
|
config_dtype = _get_config_dtype_str(use_fp8_w8a8=use_fp8_w8a8,
|
||||||
use_int8_w8a8=use_int8_w8a8,
|
use_int8_w8a16=use_int8_w8a16,
|
||||||
use_int8_w8a16=use_int8_w8a16,
|
use_int4_w4a16=use_int4_w4a16,
|
||||||
use_int4_w4a16=use_int4_w4a16,
|
use_mxfp4_w4a4=use_mxfp4_w4a4,
|
||||||
use_mxfp4_w4a4=use_mxfp4_w4a4)
|
dtype=hidden_states.dtype)
|
||||||
|
|
||||||
|
# Note: for use_int8_w8a16 or use_int4_w4a16, the activations are
|
||||||
|
# quantized prior to calling fused_experts.
|
||||||
|
quant_dtype = _get_config_quant_dtype(use_fp8_w8a8=use_fp8_w8a8,
|
||||||
|
use_int8_w8a8=use_int8_w8a8,
|
||||||
|
use_mxfp4_w4a4=use_mxfp4_w4a4)
|
||||||
|
|
||||||
get_config_func = functools.partial(
|
get_config_func = functools.partial(
|
||||||
try_get_optimal_moe_config,
|
try_get_optimal_moe_config,
|
||||||
@ -1640,7 +1460,7 @@ def fused_experts_impl(
|
|||||||
qcurr_hidden_states, a1q_scale = moe_kernel_quantize_input(
|
qcurr_hidden_states, a1q_scale = moe_kernel_quantize_input(
|
||||||
A=curr_hidden_states,
|
A=curr_hidden_states,
|
||||||
A_scale=a1_scale,
|
A_scale=a1_scale,
|
||||||
quant_dtype=qtype,
|
quant_dtype=quant_dtype,
|
||||||
per_act_token_quant=per_channel_quant,
|
per_act_token_quant=per_channel_quant,
|
||||||
block_shape=block_shape)
|
block_shape=block_shape)
|
||||||
|
|
||||||
@ -1671,30 +1491,29 @@ def fused_experts_impl(
|
|||||||
B_bias=w1_bias)
|
B_bias=w1_bias)
|
||||||
|
|
||||||
# Activation function with multiplication
|
# Activation function with multiplication
|
||||||
if activation == "silu" and is_act_and_mul:
|
if activation == "silu":
|
||||||
torch.ops._C.silu_and_mul(intermediate_cache2,
|
torch.ops._C.silu_and_mul(intermediate_cache2,
|
||||||
intermediate_cache1.view(-1, N))
|
intermediate_cache1.view(-1, N))
|
||||||
elif activation == "gelu" and is_act_and_mul:
|
elif activation == "gelu":
|
||||||
torch.ops._C.gelu_and_mul(intermediate_cache2,
|
torch.ops._C.gelu_and_mul(intermediate_cache2,
|
||||||
intermediate_cache1.view(-1, N))
|
intermediate_cache1.view(-1, N))
|
||||||
elif activation == "swigluoai" and is_act_and_mul:
|
elif activation == "swigluoai":
|
||||||
# alpha = 1.702, limit = 7.0
|
# alpha = 1.702, limit = 7.0
|
||||||
torch.ops._C.swigluoai_and_mul(intermediate_cache2,
|
torch.ops._C.swigluoai_and_mul(intermediate_cache2,
|
||||||
intermediate_cache1.view(-1, N))
|
intermediate_cache1.view(-1, N))
|
||||||
# Activation function without multiplication
|
# Activation function without multiplication
|
||||||
elif activation == "silu":
|
elif activation == SILU_NO_MUL:
|
||||||
intermediate_cache2 = F.silu(intermediate_cache1.view(-1, N))
|
intermediate_cache2 = F.silu(intermediate_cache1.view(-1, N))
|
||||||
elif activation == "gelu":
|
elif activation == GELU_NO_MUL:
|
||||||
intermediate_cache2 = F.gelu(intermediate_cache1.view(-1, N))
|
intermediate_cache2 = F.gelu(intermediate_cache1.view(-1, N))
|
||||||
|
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"Unsupported FusedMoe activation: {activation}, "
|
raise ValueError(f"Unsupported FusedMoe activation: {activation}.")
|
||||||
f"with is_act_and_mul={is_act_and_mul}.")
|
|
||||||
|
|
||||||
qintermediate_cache2, a2q_scale = moe_kernel_quantize_input(
|
qintermediate_cache2, a2q_scale = moe_kernel_quantize_input(
|
||||||
A=intermediate_cache2,
|
A=intermediate_cache2,
|
||||||
A_scale=a2_scale,
|
A_scale=a2_scale,
|
||||||
quant_dtype=qtype,
|
quant_dtype=quant_dtype,
|
||||||
per_act_token_quant=per_channel_quant,
|
per_act_token_quant=per_channel_quant,
|
||||||
block_shape=block_shape)
|
block_shape=block_shape)
|
||||||
|
|
||||||
@ -1726,164 +1545,13 @@ def fused_experts_impl(
|
|||||||
return out_hidden_states
|
return out_hidden_states
|
||||||
|
|
||||||
|
|
||||||
def fused_moe(
|
|
||||||
hidden_states: torch.Tensor,
|
|
||||||
w1: torch.Tensor,
|
|
||||||
w2: torch.Tensor,
|
|
||||||
gating_output: torch.Tensor,
|
|
||||||
topk: int,
|
|
||||||
renormalize: bool,
|
|
||||||
inplace: bool = False,
|
|
||||||
activation: str = "silu",
|
|
||||||
is_act_and_mul: bool = True,
|
|
||||||
use_grouped_topk: bool = False,
|
|
||||||
num_expert_group: Optional[int] = None,
|
|
||||||
topk_group: Optional[int] = None,
|
|
||||||
custom_routing_function: Optional[Callable] = None,
|
|
||||||
use_fp8_w8a8: bool = False,
|
|
||||||
use_int8_w8a8: bool = False,
|
|
||||||
use_int8_w8a16: bool = False,
|
|
||||||
use_int4_w4a16: bool = False,
|
|
||||||
use_mxfp4_w4a4: bool = False,
|
|
||||||
per_channel_quant: bool = False,
|
|
||||||
global_num_experts: int = -1,
|
|
||||||
expert_map: Optional[torch.Tensor] = None,
|
|
||||||
w1_scale: Optional[torch.Tensor] = None,
|
|
||||||
w2_scale: Optional[torch.Tensor] = None,
|
|
||||||
w1_zp: Optional[torch.Tensor] = None,
|
|
||||||
w2_zp: Optional[torch.Tensor] = None,
|
|
||||||
a1_scale: Optional[torch.Tensor] = None,
|
|
||||||
a2_scale: Optional[torch.Tensor] = None,
|
|
||||||
block_shape: Optional[list[int]] = None,
|
|
||||||
w1_bias: Optional[torch.Tensor] = None,
|
|
||||||
w2_bias: Optional[torch.Tensor] = None,
|
|
||||||
) -> torch.Tensor:
|
|
||||||
"""
|
|
||||||
This function computes a Mixture of Experts (MoE) layer using two sets of
|
|
||||||
weights, w1 and w2, and top-k gating mechanism.
|
|
||||||
|
|
||||||
Parameters:
|
|
||||||
- hidden_states (torch.Tensor): The input tensor to the MoE layer.
|
|
||||||
- w1 (torch.Tensor): The first set of expert weights.
|
|
||||||
- w2 (torch.Tensor): The second set of expert weights.
|
|
||||||
- gating_output (torch.Tensor): The output of the gating operation
|
|
||||||
(before softmax).
|
|
||||||
- topk (int): The number of top-k experts to select.
|
|
||||||
- renormalize (bool): If True, renormalize the top-k weights to sum to 1.
|
|
||||||
- inplace (bool): If True, perform the operation in-place.
|
|
||||||
Defaults to False.
|
|
||||||
- activation (str): The activation function to apply after the first
|
|
||||||
MoE layer.
|
|
||||||
- is_act_and_mul (bool): If True, use activation-and-mul function for
|
|
||||||
activation (self-gated activation), otherwise use activation function
|
|
||||||
for activation (ungated activation).
|
|
||||||
- num_expert_group: Optional[int]: additional parameter for grouped_topk
|
|
||||||
- topk_group: Optional[int]: additional parameter for grouped_topk
|
|
||||||
- use_grouped_topk: If True, use grouped_topk instead of fused_topk
|
|
||||||
note: Deepseekv2 model uses grouped_topk
|
|
||||||
- use_fp8_w8a8 (bool): If True, use fp8 arithmetic to compute the inner
|
|
||||||
products for w1 and w2. Defaults to False.
|
|
||||||
- use_int8_w8a8 (bool): If True, use int8 arithmetic to compute the inner
|
|
||||||
products for w1 and w2. Defaults to False.
|
|
||||||
- use_int8_w8a16 (bool): If True, use matmul of int8 weight and bf16/fp16
|
|
||||||
activation to compute the inner products for w1 and w2.
|
|
||||||
Defaults to False.
|
|
||||||
- use_int4_w4a16 (bool): If True, use matmul of int4 weight and bf16/fp16
|
|
||||||
activation to compute the inner products for w1 and w2.
|
|
||||||
Defaults to False.
|
|
||||||
- use_mxfp4_w4a4 (bool): If True, use matmul of OCP MXFP4 weight and
|
|
||||||
OCP MXFP4 activation to compute the inner products for w1 and w2.
|
|
||||||
Defaults to False.
|
|
||||||
- global_num_experts (int): The total number of experts in the global
|
|
||||||
expert space.
|
|
||||||
- expert_map (Optional[torch.Tensor]): A tensor mapping expert indices
|
|
||||||
from the global expert space to the local expert space of the expert
|
|
||||||
parallel shard.
|
|
||||||
- w1_scale (Optional[torch.Tensor]): Optional scale to be used for
|
|
||||||
w1.
|
|
||||||
- w2_scale (Optional[torch.Tensor]): Optional scale to be used for
|
|
||||||
w2.
|
|
||||||
- a1_scale (Optional[torch.Tensor]): Optional scale to be used for
|
|
||||||
a1.
|
|
||||||
- a2_scale (Optional[torch.Tensor]): Optional scale to be used for
|
|
||||||
a2.
|
|
||||||
- block_shape: (Optional[list[int]]): Optional block size for block-wise
|
|
||||||
quantization.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
- torch.Tensor: The output tensor after applying the MoE layer.
|
|
||||||
"""
|
|
||||||
if not is_act_and_mul:
|
|
||||||
assert inplace is False, (
|
|
||||||
"is_act_and_mul=False is not supported with inplace=True")
|
|
||||||
|
|
||||||
if use_grouped_topk:
|
|
||||||
assert num_expert_group is not None and topk_group is not None
|
|
||||||
topk_weights, topk_ids = grouped_topk(hidden_states, gating_output,
|
|
||||||
topk, renormalize,
|
|
||||||
num_expert_group, topk_group)
|
|
||||||
elif custom_routing_function is None:
|
|
||||||
topk_weights, topk_ids, token_expert_indices = fused_topk(
|
|
||||||
hidden_states, gating_output, topk, renormalize)
|
|
||||||
else:
|
|
||||||
topk_weights, topk_ids = custom_routing_function(
|
|
||||||
hidden_states, gating_output, topk, renormalize)
|
|
||||||
|
|
||||||
return fused_experts(hidden_states,
|
|
||||||
w1,
|
|
||||||
w2,
|
|
||||||
topk_weights,
|
|
||||||
topk_ids,
|
|
||||||
inplace=inplace,
|
|
||||||
activation=activation,
|
|
||||||
is_act_and_mul=is_act_and_mul,
|
|
||||||
use_fp8_w8a8=use_fp8_w8a8,
|
|
||||||
use_int8_w8a8=use_int8_w8a8,
|
|
||||||
use_int8_w8a16=use_int8_w8a16,
|
|
||||||
use_int4_w4a16=use_int4_w4a16,
|
|
||||||
use_mxfp4_w4a4=use_mxfp4_w4a4,
|
|
||||||
per_channel_quant=per_channel_quant,
|
|
||||||
global_num_experts=global_num_experts,
|
|
||||||
expert_map=expert_map,
|
|
||||||
w1_scale=w1_scale,
|
|
||||||
w2_scale=w2_scale,
|
|
||||||
w1_zp=w1_zp,
|
|
||||||
w2_zp=w2_zp,
|
|
||||||
a1_scale=a1_scale,
|
|
||||||
a2_scale=a2_scale,
|
|
||||||
block_shape=block_shape,
|
|
||||||
w1_bias=w1_bias,
|
|
||||||
w2_bias=w2_bias)
|
|
||||||
|
|
||||||
|
|
||||||
class TritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
|
class TritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
use_fp8_w8a8: bool = False,
|
quant_config: FusedMoEQuantConfig,
|
||||||
use_int8_w8a8: bool = False,
|
|
||||||
use_int8_w8a16: bool = False,
|
|
||||||
use_int4_w4a16: bool = False,
|
|
||||||
use_mxfp4_w4a4: bool = False,
|
|
||||||
per_act_token_quant: bool = False,
|
|
||||||
block_shape: Optional[list[int]] = None,
|
|
||||||
):
|
):
|
||||||
super().__init__(
|
super().__init__(quant_config)
|
||||||
FusedMoEQuantConfig.make(
|
|
||||||
use_fp8_w8a8=use_fp8_w8a8,
|
|
||||||
use_int8_w8a8=use_int8_w8a8,
|
|
||||||
use_int8_w8a16=use_int8_w8a16,
|
|
||||||
use_int4_w4a16=use_int4_w4a16,
|
|
||||||
use_mxfp4_w4a4=use_mxfp4_w4a4,
|
|
||||||
per_act_token_quant=per_act_token_quant,
|
|
||||||
block_shape=block_shape,
|
|
||||||
))
|
|
||||||
|
|
||||||
self.use_fp8_w8a8 = use_fp8_w8a8
|
|
||||||
self.use_int4_w4a16 = use_int4_w4a16
|
|
||||||
self.use_int8_w8a8 = use_int8_w8a8
|
|
||||||
self.use_int8_w8a16 = use_int8_w8a16
|
|
||||||
self.use_mxfp4_w4a4 = use_mxfp4_w4a4
|
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def activation_formats(
|
def activation_formats(
|
||||||
@ -1929,19 +1597,14 @@ class TritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
|
|||||||
activation: str,
|
activation: str,
|
||||||
global_num_experts: int,
|
global_num_experts: int,
|
||||||
expert_map: Optional[torch.Tensor],
|
expert_map: Optional[torch.Tensor],
|
||||||
w1_scale: Optional[torch.Tensor],
|
|
||||||
w2_scale: Optional[torch.Tensor],
|
|
||||||
w1_zp: Optional[torch.Tensor],
|
|
||||||
w2_zp: Optional[torch.Tensor],
|
|
||||||
a1q_scale: Optional[torch.Tensor],
|
a1q_scale: Optional[torch.Tensor],
|
||||||
a2_scale: Optional[torch.Tensor],
|
|
||||||
workspace13: torch.Tensor,
|
workspace13: torch.Tensor,
|
||||||
workspace2: torch.Tensor,
|
workspace2: torch.Tensor,
|
||||||
expert_tokens_meta: Optional[mk.ExpertTokensMetadata],
|
expert_tokens_meta: Optional[mk.ExpertTokensMetadata],
|
||||||
apply_router_weight_on_input: bool,
|
apply_router_weight_on_input: bool,
|
||||||
):
|
):
|
||||||
# Check constraints.
|
# Check constraints.
|
||||||
if self.use_int4_w4a16:
|
if self.quant_config.use_int4_w4a16:
|
||||||
assert hidden_states.size(-1) // 2 == w1.size(2), (
|
assert hidden_states.size(-1) // 2 == w1.size(2), (
|
||||||
"Hidden size mismatch")
|
"Hidden size mismatch")
|
||||||
else:
|
else:
|
||||||
@ -1964,17 +1627,11 @@ class TritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
|
|||||||
if global_num_experts == -1:
|
if global_num_experts == -1:
|
||||||
global_num_experts = E
|
global_num_experts = E
|
||||||
|
|
||||||
config_dtype = get_config_dtype_str(use_fp8_w8a8=self.use_fp8_w8a8,
|
|
||||||
use_int8_w8a16=self.use_int8_w8a16,
|
|
||||||
use_int4_w4a16=self.use_int4_w4a16,
|
|
||||||
use_mxfp4_w4a4=self.use_mxfp4_w4a4,
|
|
||||||
dtype=hidden_states.dtype)
|
|
||||||
|
|
||||||
config = try_get_optimal_moe_config(
|
config = try_get_optimal_moe_config(
|
||||||
w1.size(),
|
w1.size(),
|
||||||
w2.size(),
|
w2.size(),
|
||||||
top_k_num,
|
top_k_num,
|
||||||
config_dtype,
|
self.quant_config.config_name(hidden_states.dtype),
|
||||||
num_tokens,
|
num_tokens,
|
||||||
block_shape=self.block_shape,
|
block_shape=self.block_shape,
|
||||||
)
|
)
|
||||||
@ -2008,8 +1665,8 @@ class TritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
|
|||||||
w1,
|
w1,
|
||||||
intermediate_cache1,
|
intermediate_cache1,
|
||||||
a1q_scale,
|
a1q_scale,
|
||||||
w1_scale,
|
self.w1_scale,
|
||||||
w1_zp,
|
self.w1_zp,
|
||||||
None, # topk_weights
|
None, # topk_weights
|
||||||
sorted_token_ids,
|
sorted_token_ids,
|
||||||
expert_ids,
|
expert_ids,
|
||||||
@ -2018,13 +1675,13 @@ class TritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
|
|||||||
top_k_num,
|
top_k_num,
|
||||||
config,
|
config,
|
||||||
compute_type=compute_type,
|
compute_type=compute_type,
|
||||||
use_fp8_w8a8=self.use_fp8_w8a8,
|
use_fp8_w8a8=self.quant_config.use_fp8_w8a8,
|
||||||
use_int8_w8a8=self.use_int8_w8a8,
|
use_int8_w8a8=self.quant_config.use_int8_w8a8,
|
||||||
use_int8_w8a16=self.use_int8_w8a16,
|
use_int8_w8a16=self.quant_config.use_int8_w8a16,
|
||||||
use_int4_w4a16=self.use_int4_w4a16,
|
use_int4_w4a16=self.quant_config.use_int4_w4a16,
|
||||||
per_channel_quant=self.per_act_token_quant,
|
per_channel_quant=self.per_act_token_quant,
|
||||||
block_shape=self.block_shape,
|
block_shape=self.block_shape,
|
||||||
B_bias=None # TODO support B_bias
|
B_bias=self.w1_bias,
|
||||||
)
|
)
|
||||||
|
|
||||||
self.activation(activation, intermediate_cache2,
|
self.activation(activation, intermediate_cache2,
|
||||||
@ -2033,7 +1690,7 @@ class TritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
|
|||||||
a2q_scale: Optional[torch.Tensor] = None
|
a2q_scale: Optional[torch.Tensor] = None
|
||||||
|
|
||||||
qintermediate_cache2, a2q_scale = moe_kernel_quantize_input(
|
qintermediate_cache2, a2q_scale = moe_kernel_quantize_input(
|
||||||
intermediate_cache2, a2_scale, self.quant_dtype,
|
intermediate_cache2, self.a2_scale, self.quant_dtype,
|
||||||
self.per_act_token_quant, self.block_shape)
|
self.per_act_token_quant, self.block_shape)
|
||||||
|
|
||||||
invoke_fused_moe_kernel(
|
invoke_fused_moe_kernel(
|
||||||
@ -2041,8 +1698,8 @@ class TritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
|
|||||||
w2,
|
w2,
|
||||||
intermediate_cache3,
|
intermediate_cache3,
|
||||||
a2q_scale,
|
a2q_scale,
|
||||||
w2_scale,
|
self.w2_scale,
|
||||||
w2_zp,
|
self.w2_zp,
|
||||||
topk_weights,
|
topk_weights,
|
||||||
sorted_token_ids,
|
sorted_token_ids,
|
||||||
expert_ids,
|
expert_ids,
|
||||||
@ -2051,36 +1708,21 @@ class TritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
|
|||||||
1,
|
1,
|
||||||
config,
|
config,
|
||||||
compute_type=compute_type,
|
compute_type=compute_type,
|
||||||
use_fp8_w8a8=self.use_fp8_w8a8,
|
use_fp8_w8a8=self.quant_config.use_fp8_w8a8,
|
||||||
use_int8_w8a8=self.use_int8_w8a8,
|
use_int8_w8a8=self.quant_config.use_int8_w8a8,
|
||||||
use_int8_w8a16=self.use_int8_w8a16,
|
use_int8_w8a16=self.quant_config.use_int8_w8a16,
|
||||||
use_int4_w4a16=self.use_int4_w4a16,
|
use_int4_w4a16=self.quant_config.use_int4_w4a16,
|
||||||
per_channel_quant=self.per_act_token_quant,
|
per_channel_quant=self.per_act_token_quant,
|
||||||
block_shape=self.block_shape,
|
block_shape=self.block_shape,
|
||||||
B_bias=None # TODO support B_bias
|
B_bias=self.w2_bias,
|
||||||
)
|
)
|
||||||
|
|
||||||
ops.moe_sum(intermediate_cache3, output)
|
ops.moe_sum(intermediate_cache3, output)
|
||||||
|
|
||||||
|
|
||||||
def modular_triton_fused_moe(
|
def modular_triton_fused_moe(
|
||||||
use_fp8_w8a8: bool,
|
quant_config: FusedMoEQuantConfig) -> mk.FusedMoEModularKernel:
|
||||||
use_int8_w8a8: bool,
|
|
||||||
use_int8_w8a16: bool,
|
|
||||||
use_int4_w4a16: bool,
|
|
||||||
use_mxfp4_w4a4: bool,
|
|
||||||
per_act_token_quant: bool,
|
|
||||||
block_shape: Optional[list[int]] = None,
|
|
||||||
) -> mk.FusedMoEModularKernel:
|
|
||||||
return mk.FusedMoEModularKernel(
|
return mk.FusedMoEModularKernel(
|
||||||
MoEPrepareAndFinalizeNoEP(),
|
MoEPrepareAndFinalizeNoEP(),
|
||||||
TritonExperts(
|
TritonExperts(quant_config),
|
||||||
use_fp8_w8a8=use_fp8_w8a8,
|
|
||||||
use_int8_w8a8=use_int8_w8a8,
|
|
||||||
use_int8_w8a16=use_int8_w8a16,
|
|
||||||
use_int4_w4a16=use_int4_w4a16,
|
|
||||||
use_mxfp4_w4a4=use_mxfp4_w4a4,
|
|
||||||
per_act_token_quant=per_act_token_quant,
|
|
||||||
block_shape=block_shape,
|
|
||||||
),
|
|
||||||
)
|
)
|
||||||
|
|||||||
@ -1,11 +1,13 @@
|
|||||||
# SPDX-License-Identifier: Apache-2.0
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
from typing import TYPE_CHECKING, Optional
|
from typing import Optional
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
|
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
|
||||||
from vllm.logger import init_logger
|
from vllm.logger import init_logger
|
||||||
|
from vllm.model_executor.layers.fused_moe.config import (
|
||||||
|
FUSED_MOE_UNQUANTIZED_CONFIG, FusedMoEQuantConfig)
|
||||||
from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import (
|
from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import (
|
||||||
TopKWeightAndReduceDelegate)
|
TopKWeightAndReduceDelegate)
|
||||||
from vllm.utils import has_triton_kernels
|
from vllm.utils import has_triton_kernels
|
||||||
@ -23,9 +25,6 @@ if has_triton_kernels():
|
|||||||
"Failed to import Triton kernels. Please make sure your triton "
|
"Failed to import Triton kernels. Please make sure your triton "
|
||||||
"version is compatible.")
|
"version is compatible.")
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
|
||||||
from triton_kernels.matmul_ogs import PrecisionConfig
|
|
||||||
|
|
||||||
|
|
||||||
def triton_kernel_moe_forward(
|
def triton_kernel_moe_forward(
|
||||||
hidden_states: torch.Tensor,
|
hidden_states: torch.Tensor,
|
||||||
@ -35,20 +34,10 @@ def triton_kernel_moe_forward(
|
|||||||
topk: int,
|
topk: int,
|
||||||
renormalize: bool,
|
renormalize: bool,
|
||||||
activation: str = "silu",
|
activation: str = "silu",
|
||||||
|
quant_config: Optional[FusedMoEQuantConfig] = None,
|
||||||
apply_router_weight_on_input: bool = False,
|
apply_router_weight_on_input: bool = False,
|
||||||
use_fp8_w8a8: bool = False,
|
|
||||||
per_channel_quant: bool = False,
|
|
||||||
global_num_experts: int = -1,
|
global_num_experts: int = -1,
|
||||||
expert_map: Optional[torch.Tensor] = None,
|
expert_map: Optional[torch.Tensor] = None,
|
||||||
w1_scale: Optional[torch.Tensor] = None,
|
|
||||||
w2_scale: Optional[torch.Tensor] = None,
|
|
||||||
w1_bias: Optional[torch.Tensor] = None,
|
|
||||||
w2_bias: Optional[torch.Tensor] = None,
|
|
||||||
w1_precision: Optional["PrecisionConfig"] = None,
|
|
||||||
w2_precision: Optional["PrecisionConfig"] = None,
|
|
||||||
a1_scale: Optional[torch.Tensor] = None,
|
|
||||||
a2_scale: Optional[torch.Tensor] = None,
|
|
||||||
block_shape: Optional[list[int]] = None,
|
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
|
|
||||||
routing_data, gather_idx, scatter_idx = routing(gating_output,
|
routing_data, gather_idx, scatter_idx = routing(gating_output,
|
||||||
@ -64,20 +53,10 @@ def triton_kernel_moe_forward(
|
|||||||
gather_idx,
|
gather_idx,
|
||||||
scatter_idx,
|
scatter_idx,
|
||||||
activation=activation,
|
activation=activation,
|
||||||
|
quant_config=quant_config,
|
||||||
apply_router_weight_on_input=apply_router_weight_on_input,
|
apply_router_weight_on_input=apply_router_weight_on_input,
|
||||||
use_fp8_w8a8=use_fp8_w8a8,
|
|
||||||
per_channel_quant=per_channel_quant,
|
|
||||||
global_num_experts=global_num_experts,
|
global_num_experts=global_num_experts,
|
||||||
expert_map=expert_map,
|
expert_map=expert_map)
|
||||||
w1_scale=w1_scale,
|
|
||||||
w2_scale=w2_scale,
|
|
||||||
w1_bias=w1_bias,
|
|
||||||
w2_bias=w2_bias,
|
|
||||||
w1_precision=w1_precision,
|
|
||||||
w2_precision=w2_precision,
|
|
||||||
a1_scale=a1_scale,
|
|
||||||
a2_scale=a2_scale,
|
|
||||||
block_shape=block_shape)
|
|
||||||
|
|
||||||
|
|
||||||
# This is a triton implementation of the fused_experts function
|
# This is a triton implementation of the fused_experts function
|
||||||
@ -90,28 +69,23 @@ def triton_kernel_fused_experts(
|
|||||||
gather_indx, # GatherIndx
|
gather_indx, # GatherIndx
|
||||||
scatter_indx, # ScatterIndx
|
scatter_indx, # ScatterIndx
|
||||||
activation: str = "silu",
|
activation: str = "silu",
|
||||||
|
quant_config: Optional[FusedMoEQuantConfig] = None,
|
||||||
swiglu_alpha: float = 1.702,
|
swiglu_alpha: float = 1.702,
|
||||||
swiglu_limit: float = 7.0,
|
swiglu_limit: float = 7.0,
|
||||||
apply_router_weight_on_input: bool = False,
|
apply_router_weight_on_input: bool = False,
|
||||||
use_fp8_w8a8: bool = False,
|
|
||||||
per_channel_quant: bool = False,
|
|
||||||
global_num_experts: int = -1,
|
global_num_experts: int = -1,
|
||||||
expert_map: Optional[torch.Tensor] = None,
|
expert_map: Optional[torch.Tensor] = None,
|
||||||
w1_scale: Optional[torch.Tensor] = None,
|
a1q_scale: Optional[torch.Tensor] = None,
|
||||||
w2_scale: Optional[torch.Tensor] = None,
|
|
||||||
w1_bias: Optional[torch.Tensor] = None,
|
|
||||||
w2_bias: Optional[torch.Tensor] = None,
|
|
||||||
w1_precision: Optional["PrecisionConfig"] = None,
|
|
||||||
w2_precision: Optional["PrecisionConfig"] = None,
|
|
||||||
a1_scale: Optional[torch.Tensor] = None,
|
|
||||||
a2_scale: Optional[torch.Tensor] = None,
|
|
||||||
block_shape: Optional[list[int]] = None,
|
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
|
if quant_config is None:
|
||||||
|
quant_config = FUSED_MOE_UNQUANTIZED_CONFIG
|
||||||
|
|
||||||
# type check, uint8 means mxfp4
|
# type check, uint8 means mxfp4
|
||||||
assert hidden_states.dtype == torch.bfloat16
|
assert hidden_states.dtype == torch.bfloat16
|
||||||
assert w1_bias is None or w1_bias.dtype == torch.float32
|
assert (quant_config.w1_bias is None
|
||||||
assert w2_bias is None or w2_bias.dtype == torch.float32
|
or quant_config.w1_bias.dtype == torch.float32)
|
||||||
|
assert (quant_config.w2_bias is None
|
||||||
|
or quant_config.w2_bias.dtype == torch.float32)
|
||||||
|
|
||||||
# Shape check, only check non-mxfp4
|
# Shape check, only check non-mxfp4
|
||||||
assert hidden_states.shape[-1] == w1.shape[-2]
|
assert hidden_states.shape[-1] == w1.shape[-2]
|
||||||
@ -130,20 +104,20 @@ def triton_kernel_fused_experts(
|
|||||||
intermediate_cache1 = matmul_ogs(
|
intermediate_cache1 = matmul_ogs(
|
||||||
hidden_states,
|
hidden_states,
|
||||||
w1,
|
w1,
|
||||||
w1_bias,
|
quant_config.w1_bias,
|
||||||
routing_data,
|
routing_data,
|
||||||
gather_indx=gather_indx,
|
gather_indx=gather_indx,
|
||||||
precision_config=w1_precision,
|
precision_config=quant_config.w1_precision,
|
||||||
gammas=gammas if apply_router_weight_on_input else None,
|
gammas=gammas if apply_router_weight_on_input else None,
|
||||||
fused_activation=act)
|
fused_activation=act)
|
||||||
|
|
||||||
intermediate_cache3 = matmul_ogs(
|
intermediate_cache3 = matmul_ogs(
|
||||||
intermediate_cache1,
|
intermediate_cache1,
|
||||||
w2,
|
w2,
|
||||||
w2_bias,
|
quant_config.w2_bias,
|
||||||
routing_data,
|
routing_data,
|
||||||
scatter_indx=scatter_indx,
|
scatter_indx=scatter_indx,
|
||||||
precision_config=w2_precision,
|
precision_config=quant_config.w2_precision,
|
||||||
gammas=None if apply_router_weight_on_input else gammas,
|
gammas=None if apply_router_weight_on_input else gammas,
|
||||||
y=output_tensor,
|
y=output_tensor,
|
||||||
)
|
)
|
||||||
@ -154,21 +128,13 @@ class BatchedOAITritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
|
|||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
quant_config,
|
|
||||||
max_num_tokens: int,
|
max_num_tokens: int,
|
||||||
num_dispatchers: int,
|
num_dispatchers: int,
|
||||||
w1_precision: "PrecisionConfig",
|
quant_config: FusedMoEQuantConfig,
|
||||||
w2_precision: "PrecisionConfig",
|
|
||||||
w1_bias: Optional[torch.Tensor],
|
|
||||||
w2_bias: Optional[torch.Tensor],
|
|
||||||
):
|
):
|
||||||
super().__init__(quant_config)
|
super().__init__(quant_config)
|
||||||
self.max_num_tokens = max_num_tokens
|
self.max_num_tokens = max_num_tokens
|
||||||
self.num_dispatchers = num_dispatchers
|
self.num_dispatchers = num_dispatchers
|
||||||
self.w1_precision = w1_precision
|
|
||||||
self.w2_precision = w2_precision
|
|
||||||
self.w1_bias = w1_bias
|
|
||||||
self.w2_bias = w2_bias
|
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def activation_formats(
|
def activation_formats(
|
||||||
@ -212,12 +178,7 @@ class BatchedOAITritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
|
|||||||
activation: str,
|
activation: str,
|
||||||
global_num_experts: int,
|
global_num_experts: int,
|
||||||
expert_map: Optional[torch.Tensor],
|
expert_map: Optional[torch.Tensor],
|
||||||
w1_scale: Optional[torch.Tensor],
|
|
||||||
w2_scale: Optional[torch.Tensor],
|
|
||||||
w1_zp: Optional[torch.Tensor],
|
|
||||||
w2_zp: Optional[torch.Tensor],
|
|
||||||
a1q_scale: Optional[torch.Tensor],
|
a1q_scale: Optional[torch.Tensor],
|
||||||
a2_scale: Optional[torch.Tensor],
|
|
||||||
workspace13: torch.Tensor,
|
workspace13: torch.Tensor,
|
||||||
workspace2: torch.Tensor,
|
workspace2: torch.Tensor,
|
||||||
expert_tokens_meta: Optional[mk.ExpertTokensMetadata],
|
expert_tokens_meta: Optional[mk.ExpertTokensMetadata],
|
||||||
@ -228,20 +189,12 @@ class BatchedOAITritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
|
|||||||
hidden_states,
|
hidden_states,
|
||||||
w1,
|
w1,
|
||||||
w2,
|
w2,
|
||||||
None,
|
routing_data=None,
|
||||||
None,
|
gather_indx=None,
|
||||||
None,
|
scatter_indx=None,
|
||||||
activation=activation,
|
activation=activation,
|
||||||
|
quant_config=self.quant_config,
|
||||||
apply_router_weight_on_input=False,
|
apply_router_weight_on_input=False,
|
||||||
use_fp8_w8a8=False,
|
|
||||||
per_channel_quant=False,
|
|
||||||
global_num_experts=global_num_experts,
|
global_num_experts=global_num_experts,
|
||||||
expert_map=expert_map,
|
expert_map=expert_map,
|
||||||
w1_scale=w1_scale,
|
a1q_scale=a1q_scale)
|
||||||
w2_scale=w2_scale,
|
|
||||||
w1_bias=self.w1_bias,
|
|
||||||
w2_bias=self.w2_bias,
|
|
||||||
w1_precision=self.w1_precision,
|
|
||||||
w2_precision=self.w2_precision,
|
|
||||||
a1_scale=a1q_scale,
|
|
||||||
a2_scale=a2_scale)
|
|
||||||
|
|||||||
@ -22,7 +22,8 @@ from vllm.logger import init_logger
|
|||||||
from vllm.model_executor.custom_op import CustomOp
|
from vllm.model_executor.custom_op import CustomOp
|
||||||
# yapf: disable
|
# yapf: disable
|
||||||
from vllm.model_executor.layers.fused_moe.config import (
|
from vllm.model_executor.layers.fused_moe.config import (
|
||||||
FusedMoEConfig, FusedMoEParallelConfig)
|
FUSED_MOE_UNQUANTIZED_CONFIG, FusedMoEConfig, FusedMoEParallelConfig,
|
||||||
|
FusedMoEQuantConfig, biased_moe_quant_config)
|
||||||
# yapf: enable
|
# yapf: enable
|
||||||
from vllm.model_executor.layers.fused_moe.modular_kernel import (
|
from vllm.model_executor.layers.fused_moe.modular_kernel import (
|
||||||
FusedMoEActivationFormat, FusedMoEModularKernel,
|
FusedMoEActivationFormat, FusedMoEModularKernel,
|
||||||
@ -78,11 +79,11 @@ class FusedMoeWeightScaleSupported(Enum):
|
|||||||
|
|
||||||
class FusedMoEMethodBase(QuantizeMethodBase):
|
class FusedMoEMethodBase(QuantizeMethodBase):
|
||||||
|
|
||||||
# TODO(bnell): also pass quant_config?
|
|
||||||
def __init__(self, moe: FusedMoEConfig):
|
def __init__(self, moe: FusedMoEConfig):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.moe = moe
|
self.moe = moe
|
||||||
self.fused_experts: Optional[Callable] = None
|
self.moe_quant_config: Optional[FusedMoEQuantConfig] = None
|
||||||
|
self.fused_experts: Optional[FusedMoEModularKernel] = None
|
||||||
self.topk_indices_dtype = None
|
self.topk_indices_dtype = None
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
@ -103,23 +104,28 @@ class FusedMoEMethodBase(QuantizeMethodBase):
|
|||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _maybe_make_prepare_finalize(
|
def _maybe_make_prepare_finalize(
|
||||||
moe: FusedMoEConfig, ) -> Optional[FusedMoEPrepareAndFinalize]:
|
moe: FusedMoEConfig,
|
||||||
|
quant_config: Optional[FusedMoEQuantConfig],
|
||||||
|
) -> Optional[FusedMoEPrepareAndFinalize]:
|
||||||
all2all_manager = get_ep_group().device_communicator.all2all_manager
|
all2all_manager = get_ep_group().device_communicator.all2all_manager
|
||||||
assert all2all_manager is not None
|
assert all2all_manager is not None
|
||||||
|
|
||||||
prepare_finalize: Optional[FusedMoEPrepareAndFinalize] = None
|
prepare_finalize: Optional[FusedMoEPrepareAndFinalize] = None
|
||||||
|
|
||||||
|
# TODO: could allow this now
|
||||||
assert not moe.use_flashinfer_cutlass_kernels, \
|
assert not moe.use_flashinfer_cutlass_kernels, \
|
||||||
"Must be created in modelopt.py"
|
"Must be created in modelopt.py"
|
||||||
|
|
||||||
if moe.use_pplx_kernels:
|
if moe.use_pplx_kernels:
|
||||||
|
assert quant_config is not None
|
||||||
|
|
||||||
hidden_dim_bytes, hidden_scale_bytes = pplx_hidden_dim_scale_bytes(
|
hidden_dim_bytes, hidden_scale_bytes = pplx_hidden_dim_scale_bytes(
|
||||||
moe.max_num_tokens,
|
moe.max_num_tokens,
|
||||||
moe.hidden_dim,
|
moe.hidden_dim,
|
||||||
moe.in_dtype,
|
moe.in_dtype,
|
||||||
moe.quant_dtype,
|
quant_config.quant_dtype,
|
||||||
per_act_token_quant=moe.per_act_token_quant,
|
per_act_token_quant=quant_config.per_act_token_quant,
|
||||||
block_shape=moe.block_shape,
|
block_shape=quant_config.block_shape,
|
||||||
)
|
)
|
||||||
|
|
||||||
all_to_all_args = dict(
|
all_to_all_args = dict(
|
||||||
@ -165,6 +171,7 @@ class FusedMoEMethodBase(QuantizeMethodBase):
|
|||||||
)
|
)
|
||||||
|
|
||||||
elif moe.use_deepep_ll_kernels:
|
elif moe.use_deepep_ll_kernels:
|
||||||
|
assert quant_config is not None
|
||||||
all_to_all_args = dict(
|
all_to_all_args = dict(
|
||||||
max_num_tokens_per_dp_rank=moe.max_num_tokens,
|
max_num_tokens_per_dp_rank=moe.max_num_tokens,
|
||||||
token_hidden_size=moe.hidden_dim,
|
token_hidden_size=moe.hidden_dim,
|
||||||
@ -174,13 +181,11 @@ class FusedMoEMethodBase(QuantizeMethodBase):
|
|||||||
all2all_manager.world_size)
|
all2all_manager.world_size)
|
||||||
handle = all2all_manager.get_handle(all_to_all_args)
|
handle = all2all_manager.get_handle(all_to_all_args)
|
||||||
|
|
||||||
# Note : We may want to use FP8 dispatch even otherwise just to
|
# Note: We may want to use FP8 dispatch just to reduce
|
||||||
# reduce datamovement
|
# data movement.
|
||||||
use_fp8_dispatch = (moe.quant_config is not None
|
use_fp8_dispatch = (
|
||||||
and moe.quant_config.quant_dtype
|
quant_config.quant_dtype == current_platform.fp8_dtype()
|
||||||
== current_platform.fp8_dtype()
|
and quant_config.block_shape == DEEPEP_QUANT_BLOCK_SHAPE)
|
||||||
and moe.quant_config.block_shape
|
|
||||||
== DEEPEP_QUANT_BLOCK_SHAPE)
|
|
||||||
|
|
||||||
prepare_finalize = DeepEPLLPrepareAndFinalize(
|
prepare_finalize = DeepEPLLPrepareAndFinalize(
|
||||||
handle,
|
handle,
|
||||||
@ -192,11 +197,10 @@ class FusedMoEMethodBase(QuantizeMethodBase):
|
|||||||
return prepare_finalize
|
return prepare_finalize
|
||||||
|
|
||||||
def maybe_make_prepare_finalize(
|
def maybe_make_prepare_finalize(
|
||||||
self,
|
self) -> Optional[FusedMoEPrepareAndFinalize]:
|
||||||
moe: FusedMoEConfig,
|
if self.moe.moe_parallel_config.use_all2all_kernels:
|
||||||
) -> Optional[FusedMoEPrepareAndFinalize]:
|
return FusedMoEMethodBase._maybe_make_prepare_finalize(
|
||||||
if moe.moe_parallel_config.use_all2all_kernels:
|
self.moe, self.moe_quant_config)
|
||||||
return FusedMoEMethodBase._maybe_make_prepare_finalize(moe)
|
|
||||||
else:
|
else:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
@ -204,7 +208,13 @@ class FusedMoEMethodBase(QuantizeMethodBase):
|
|||||||
# prepare_communication_buffer_for_model.
|
# prepare_communication_buffer_for_model.
|
||||||
def init_prepare_finalize(self, layer: torch.nn.Module):
|
def init_prepare_finalize(self, layer: torch.nn.Module):
|
||||||
assert self.moe is not None
|
assert self.moe is not None
|
||||||
prepare_finalize = self.maybe_make_prepare_finalize(self.moe)
|
|
||||||
|
# We must get the quant config here so that the layer is
|
||||||
|
# completely initialized, i.e. all weights loaded and post
|
||||||
|
# processed.
|
||||||
|
self.moe_quant_config = self.get_fused_moe_quant_config(layer)
|
||||||
|
|
||||||
|
prepare_finalize = self.maybe_make_prepare_finalize()
|
||||||
|
|
||||||
if prepare_finalize is not None:
|
if prepare_finalize is not None:
|
||||||
logger.debug("%s for %s(%s)", prepare_finalize.__class__.__name__,
|
logger.debug("%s for %s(%s)", prepare_finalize.__class__.__name__,
|
||||||
@ -213,7 +223,7 @@ class FusedMoEMethodBase(QuantizeMethodBase):
|
|||||||
assert self.fused_experts is None, \
|
assert self.fused_experts is None, \
|
||||||
f"Attempt to override experts for {id(self)}!"
|
f"Attempt to override experts for {id(self)}!"
|
||||||
self.topk_indices_dtype = prepare_finalize.topk_indices_dtype()
|
self.topk_indices_dtype = prepare_finalize.topk_indices_dtype()
|
||||||
experts = self.select_gemm_impl(prepare_finalize, self.moe, layer)
|
experts = self.select_gemm_impl(prepare_finalize, layer)
|
||||||
self.fused_experts = FusedMoEModularKernel(
|
self.fused_experts = FusedMoEModularKernel(
|
||||||
prepare_finalize,
|
prepare_finalize,
|
||||||
experts,
|
experts,
|
||||||
@ -223,7 +233,6 @@ class FusedMoEMethodBase(QuantizeMethodBase):
|
|||||||
def select_gemm_impl(
|
def select_gemm_impl(
|
||||||
self,
|
self,
|
||||||
prepare_finalize: FusedMoEPrepareAndFinalize,
|
prepare_finalize: FusedMoEPrepareAndFinalize,
|
||||||
moe: FusedMoEConfig,
|
|
||||||
layer: torch.nn.Module,
|
layer: torch.nn.Module,
|
||||||
) -> FusedMoEPermuteExpertsUnpermute:
|
) -> FusedMoEPermuteExpertsUnpermute:
|
||||||
# based on the all2all implementation, select the appropriate
|
# based on the all2all implementation, select the appropriate
|
||||||
@ -232,6 +241,11 @@ class FusedMoEMethodBase(QuantizeMethodBase):
|
|||||||
f"{self.__class__.__name__} must select appropriate gemm "
|
f"{self.__class__.__name__} must select appropriate gemm "
|
||||||
"implementation based on the prepare_finalize")
|
"implementation based on the prepare_finalize")
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def get_fused_moe_quant_config(
|
||||||
|
self, layer: torch.nn.Module) -> Optional[FusedMoEQuantConfig]:
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def apply(
|
def apply(
|
||||||
self,
|
self,
|
||||||
@ -265,7 +279,6 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
|
|||||||
|
|
||||||
def __init__(self, moe: FusedMoEConfig):
|
def __init__(self, moe: FusedMoEConfig):
|
||||||
super().__init__(moe)
|
super().__init__(moe)
|
||||||
self.has_bias = self.moe.has_bias
|
|
||||||
self.rocm_aiter_moe_enabled = is_rocm_aiter_moe_enabled()
|
self.rocm_aiter_moe_enabled = is_rocm_aiter_moe_enabled()
|
||||||
if self.rocm_aiter_moe_enabled:
|
if self.rocm_aiter_moe_enabled:
|
||||||
from .rocm_aiter_fused_moe import rocm_aiter_fused_experts
|
from .rocm_aiter_fused_moe import rocm_aiter_fused_experts
|
||||||
@ -273,23 +286,30 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
|
|||||||
else:
|
else:
|
||||||
self.rocm_aiter_fused_experts = None # type: ignore
|
self.rocm_aiter_fused_experts = None # type: ignore
|
||||||
|
|
||||||
|
def maybe_make_prepare_finalize(
|
||||||
|
self) -> Optional[FusedMoEPrepareAndFinalize]:
|
||||||
|
if self.rocm_aiter_moe_enabled:
|
||||||
|
return None
|
||||||
|
else:
|
||||||
|
return super().maybe_make_prepare_finalize()
|
||||||
|
|
||||||
def select_gemm_impl(
|
def select_gemm_impl(
|
||||||
self,
|
self,
|
||||||
prepare_finalize: FusedMoEPrepareAndFinalize,
|
prepare_finalize: FusedMoEPrepareAndFinalize,
|
||||||
# TODO(bnell): Remove. Every layer should have an moe config object.
|
|
||||||
moe: FusedMoEConfig,
|
|
||||||
layer: torch.nn.Module,
|
layer: torch.nn.Module,
|
||||||
) -> FusedMoEPermuteExpertsUnpermute:
|
) -> FusedMoEPermuteExpertsUnpermute:
|
||||||
|
assert self.moe_quant_config is not None
|
||||||
if (prepare_finalize.activation_format ==
|
if (prepare_finalize.activation_format ==
|
||||||
FusedMoEActivationFormat.BatchedExperts):
|
FusedMoEActivationFormat.BatchedExperts):
|
||||||
logger.debug("BatchedTritonExperts %s", self.moe)
|
logger.debug("BatchedTritonExperts %s", self.moe)
|
||||||
return BatchedTritonExperts(
|
return BatchedTritonExperts(
|
||||||
max_num_tokens=self.moe.max_num_tokens,
|
max_num_tokens=self.moe.max_num_tokens,
|
||||||
num_dispatchers=prepare_finalize.num_dispatchers(),
|
num_dispatchers=prepare_finalize.num_dispatchers(),
|
||||||
|
quant_config=self.moe_quant_config,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
logger.debug("TritonExperts %s", self.moe)
|
logger.debug("TritonExperts %s", self.moe)
|
||||||
return TritonExperts()
|
return TritonExperts(self.moe_quant_config)
|
||||||
|
|
||||||
def create_weights(self, layer: torch.nn.Module, num_experts: int,
|
def create_weights(self, layer: torch.nn.Module, num_experts: int,
|
||||||
hidden_size: int, intermediate_size_per_partition: int,
|
hidden_size: int, intermediate_size_per_partition: int,
|
||||||
@ -303,7 +323,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
|
|||||||
requires_grad=False)
|
requires_grad=False)
|
||||||
layer.register_parameter("w13_weight", w13_weight)
|
layer.register_parameter("w13_weight", w13_weight)
|
||||||
set_weight_attrs(w13_weight, extra_weight_attrs)
|
set_weight_attrs(w13_weight, extra_weight_attrs)
|
||||||
if self.has_bias:
|
if self.moe.has_bias:
|
||||||
w13_bias = torch.nn.Parameter(torch.zeros(
|
w13_bias = torch.nn.Parameter(torch.zeros(
|
||||||
num_experts,
|
num_experts,
|
||||||
2 * intermediate_size_per_partition,
|
2 * intermediate_size_per_partition,
|
||||||
@ -320,7 +340,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
|
|||||||
requires_grad=False)
|
requires_grad=False)
|
||||||
layer.register_parameter("w2_weight", w2_weight)
|
layer.register_parameter("w2_weight", w2_weight)
|
||||||
set_weight_attrs(w2_weight, extra_weight_attrs)
|
set_weight_attrs(w2_weight, extra_weight_attrs)
|
||||||
if self.has_bias:
|
if self.moe.has_bias:
|
||||||
w2_bias = torch.nn.Parameter(torch.zeros(num_experts,
|
w2_bias = torch.nn.Parameter(torch.zeros(num_experts,
|
||||||
hidden_size,
|
hidden_size,
|
||||||
dtype=params_dtype),
|
dtype=params_dtype),
|
||||||
@ -442,6 +462,16 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
|
|||||||
logical_replica_count=logical_replica_count,
|
logical_replica_count=logical_replica_count,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def get_fused_moe_quant_config(
|
||||||
|
self, layer: torch.nn.Module) -> Optional[FusedMoEQuantConfig]:
|
||||||
|
if self.moe.has_bias:
|
||||||
|
return biased_moe_quant_config(
|
||||||
|
layer.w13_bias,
|
||||||
|
layer.w2_bias,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
return FUSED_MOE_UNQUANTIZED_CONFIG
|
||||||
|
|
||||||
def forward_cuda(
|
def forward_cuda(
|
||||||
self,
|
self,
|
||||||
layer: torch.nn.Module,
|
layer: torch.nn.Module,
|
||||||
@ -486,6 +516,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
|
|||||||
logical_replica_count=logical_replica_count)
|
logical_replica_count=logical_replica_count)
|
||||||
|
|
||||||
if self.rocm_aiter_moe_enabled:
|
if self.rocm_aiter_moe_enabled:
|
||||||
|
assert self.fused_experts is None
|
||||||
return self.rocm_aiter_fused_experts(
|
return self.rocm_aiter_fused_experts(
|
||||||
hidden_states=x,
|
hidden_states=x,
|
||||||
w1=layer.w13_weight,
|
w1=layer.w13_weight,
|
||||||
@ -496,7 +527,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
|
|||||||
activation=activation,
|
activation=activation,
|
||||||
apply_router_weight_on_input=apply_router_weight_on_input)
|
apply_router_weight_on_input=apply_router_weight_on_input)
|
||||||
elif self.fused_experts is not None:
|
elif self.fused_experts is not None:
|
||||||
if self.has_bias:
|
if self.moe.has_bias:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"FusedMoEModularKernel does not support bias.")
|
"FusedMoEModularKernel does not support bias.")
|
||||||
return self.fused_experts(
|
return self.fused_experts(
|
||||||
@ -517,12 +548,11 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
|
|||||||
hidden_states=x,
|
hidden_states=x,
|
||||||
w1=layer.w13_weight,
|
w1=layer.w13_weight,
|
||||||
w2=layer.w2_weight,
|
w2=layer.w2_weight,
|
||||||
w1_bias=layer.w13_bias if self.has_bias else None,
|
|
||||||
w2_bias=layer.w2_bias if self.has_bias else None,
|
|
||||||
topk_weights=topk_weights,
|
topk_weights=topk_weights,
|
||||||
topk_ids=topk_ids,
|
topk_ids=topk_ids,
|
||||||
inplace=True,
|
inplace=True,
|
||||||
activation=activation,
|
activation=activation,
|
||||||
|
quant_config=self.moe_quant_config,
|
||||||
apply_router_weight_on_input=apply_router_weight_on_input,
|
apply_router_weight_on_input=apply_router_weight_on_input,
|
||||||
global_num_experts=global_num_experts,
|
global_num_experts=global_num_experts,
|
||||||
expert_map=expert_map,
|
expert_map=expert_map,
|
||||||
@ -933,16 +963,18 @@ class FusedMoE(CustomOp):
|
|||||||
# since model_config is not set in the pytest test.
|
# since model_config is not set in the pytest test.
|
||||||
model_dtype = params_dtype
|
model_dtype = params_dtype
|
||||||
|
|
||||||
moe = FusedMoEConfig.make(num_experts=self.global_num_experts,
|
moe = FusedMoEConfig(
|
||||||
experts_per_token=top_k,
|
num_experts=self.global_num_experts,
|
||||||
hidden_dim=hidden_size,
|
experts_per_token=top_k,
|
||||||
num_local_experts=self.local_num_experts,
|
hidden_dim=hidden_size,
|
||||||
moe_parallel_config=self.moe_parallel_config,
|
num_local_experts=self.local_num_experts,
|
||||||
in_dtype=model_dtype,
|
moe_parallel_config=self.moe_parallel_config,
|
||||||
max_num_tokens=envs.VLLM_MOE_DP_CHUNK_SIZE,
|
in_dtype=model_dtype,
|
||||||
quant_config=quant_config,
|
max_num_tokens=envs.VLLM_MOE_DP_CHUNK_SIZE,
|
||||||
has_bias=has_bias)
|
has_bias=has_bias,
|
||||||
|
)
|
||||||
self.moe_config = moe
|
self.moe_config = moe
|
||||||
|
self.moe_quant_config: Optional[FusedMoEQuantConfig] = None
|
||||||
self.quant_config = quant_config
|
self.quant_config = quant_config
|
||||||
|
|
||||||
# Note: get_quant_method will look at the layer's local_num_experts
|
# Note: get_quant_method will look at the layer's local_num_experts
|
||||||
@ -990,6 +1022,9 @@ class FusedMoE(CustomOp):
|
|||||||
# Chunked all2all staging tensor
|
# Chunked all2all staging tensor
|
||||||
self.batched_hidden_states: Optional[torch.Tensor] = None
|
self.batched_hidden_states: Optional[torch.Tensor] = None
|
||||||
self.batched_router_logits: Optional[torch.Tensor] = None
|
self.batched_router_logits: Optional[torch.Tensor] = None
|
||||||
|
|
||||||
|
# TODO(bnell): flashinfer uses non-batched format.
|
||||||
|
# Does it really need a batched buffer?
|
||||||
if (self.moe_parallel_config.use_pplx_kernels
|
if (self.moe_parallel_config.use_pplx_kernels
|
||||||
or self.moe_parallel_config.use_deepep_ll_kernels
|
or self.moe_parallel_config.use_deepep_ll_kernels
|
||||||
or self.moe_config.use_flashinfer_cutlass_kernels):
|
or self.moe_config.use_flashinfer_cutlass_kernels):
|
||||||
@ -1062,7 +1097,9 @@ class FusedMoE(CustomOp):
|
|||||||
|
|
||||||
@property
|
@property
|
||||||
def use_flashinfer_cutlass_kernels(self):
|
def use_flashinfer_cutlass_kernels(self):
|
||||||
return self.moe_config.use_flashinfer_cutlass_kernels
|
return (self.moe_quant_config is not None
|
||||||
|
and self.moe_quant_config.quant_dtype == "nvfp4"
|
||||||
|
and self.moe_config.use_flashinfer_cutlass_kernels)
|
||||||
|
|
||||||
def update_expert_map(self):
|
def update_expert_map(self):
|
||||||
# ep_size and ep_rank should already be updated
|
# ep_size and ep_rank should already be updated
|
||||||
@ -1492,6 +1529,11 @@ class FusedMoE(CustomOp):
|
|||||||
self.logical_to_physical_map = logical_to_physical_map[moe_layer_idx]
|
self.logical_to_physical_map = logical_to_physical_map[moe_layer_idx]
|
||||||
self.logical_replica_count = logical_replica_count[moe_layer_idx]
|
self.logical_replica_count = logical_replica_count[moe_layer_idx]
|
||||||
|
|
||||||
|
def ensure_moe_quant_config(self):
|
||||||
|
if self.quant_method.moe_quant_config is None:
|
||||||
|
self.quant_method.moe_quant_config = (
|
||||||
|
self.quant_method.get_fused_moe_quant_config(self))
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def select_experts(
|
def select_experts(
|
||||||
hidden_states: torch.Tensor,
|
hidden_states: torch.Tensor,
|
||||||
@ -1711,6 +1753,8 @@ class FusedMoE(CustomOp):
|
|||||||
assert (
|
assert (
|
||||||
self.batched_router_logits.size(-1) == full_router_logits.size(-1))
|
self.batched_router_logits.size(-1) == full_router_logits.size(-1))
|
||||||
|
|
||||||
|
self.ensure_moe_quant_config()
|
||||||
|
|
||||||
full_fused_final_hidden_states = torch.empty_like(full_hidden_states)
|
full_fused_final_hidden_states = torch.empty_like(full_hidden_states)
|
||||||
if self.shared_experts is not None:
|
if self.shared_experts is not None:
|
||||||
full_shared_final_hidden_states = torch.empty_like(
|
full_shared_final_hidden_states = torch.empty_like(
|
||||||
@ -1825,14 +1869,17 @@ class FusedMoE(CustomOp):
|
|||||||
router_logits: torch.Tensor,
|
router_logits: torch.Tensor,
|
||||||
) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
|
) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
|
||||||
assert self.quant_method is not None
|
assert self.quant_method is not None
|
||||||
|
|
||||||
|
self.ensure_moe_quant_config()
|
||||||
|
|
||||||
# Route to the chunked forward path using the FlashInfer Cutlass kernel
|
# Route to the chunked forward path using the FlashInfer Cutlass kernel
|
||||||
# only when data parallelism (DP) is enabled.
|
# only when data parallelism (DP) is enabled.
|
||||||
use_flashinfer_cutlass_kernels = (
|
_use_flashinfer_cutlass_kernels = (self.dp_size > 1 and
|
||||||
self.dp_size > 1
|
self.use_flashinfer_cutlass_kernels)
|
||||||
and self.moe_config.use_flashinfer_cutlass_kernels)
|
|
||||||
if (self.moe_parallel_config.use_pplx_kernels
|
if (self.moe_parallel_config.use_pplx_kernels
|
||||||
or self.moe_parallel_config.use_deepep_ll_kernels
|
or self.moe_parallel_config.use_deepep_ll_kernels
|
||||||
or use_flashinfer_cutlass_kernels):
|
or _use_flashinfer_cutlass_kernels):
|
||||||
return self.forward_impl_chunked(hidden_states, router_logits)
|
return self.forward_impl_chunked(hidden_states, router_logits)
|
||||||
|
|
||||||
do_naive_dispatch_combine: bool = (
|
do_naive_dispatch_combine: bool = (
|
||||||
|
|||||||
@ -177,8 +177,6 @@ class FusedMoEPrepareAndFinalize(ABC):
|
|||||||
def prepare(
|
def prepare(
|
||||||
self,
|
self,
|
||||||
a1: torch.Tensor,
|
a1: torch.Tensor,
|
||||||
a1_scale: Optional[torch.Tensor],
|
|
||||||
a2_scale: Optional[torch.Tensor],
|
|
||||||
topk_weights: torch.Tensor,
|
topk_weights: torch.Tensor,
|
||||||
topk_ids: torch.Tensor,
|
topk_ids: torch.Tensor,
|
||||||
num_experts: int,
|
num_experts: int,
|
||||||
@ -189,9 +187,6 @@ class FusedMoEPrepareAndFinalize(ABC):
|
|||||||
"""
|
"""
|
||||||
Perform any quantization (and/or) dispatching needed for this kernel.
|
Perform any quantization (and/or) dispatching needed for this kernel.
|
||||||
- a1: The (unquantized) input to the MoE layer.
|
- a1: The (unquantized) input to the MoE layer.
|
||||||
- a1_scale: Optional scales for a1
|
|
||||||
- a2_scale: Optional scales for the second MoE gemm. Required to make
|
|
||||||
sure the quantization is consistent for both gemms.
|
|
||||||
- topk_ids: The topk ids.
|
- topk_ids: The topk ids.
|
||||||
- topk_weights: The topk weights.
|
- topk_weights: The topk weights.
|
||||||
- num_experts: The total number of experts in the global expert space.
|
- num_experts: The total number of experts in the global expert space.
|
||||||
@ -199,10 +194,11 @@ class FusedMoEPrepareAndFinalize(ABC):
|
|||||||
space to the local expert space of the expert parallel shard.
|
space to the local expert space of the expert parallel shard.
|
||||||
- apply_router_weight_on_input: When True, apply the weights to the
|
- apply_router_weight_on_input: When True, apply the weights to the
|
||||||
activations, before quantization + dispatching.
|
activations, before quantization + dispatching.
|
||||||
|
- quant_config: Quantization info provided by the fused experts.
|
||||||
|
|
||||||
Returns a tuple of:
|
Returns a tuple of:
|
||||||
- quantized + dispatched a.
|
- quantized + dispatched a.
|
||||||
- quantized + dispatched a1_scales.
|
- Optional quantized + dispatched a1_scales.
|
||||||
- Optional ExpertTokensMetadata containing gpu/cpu tensors
|
- Optional ExpertTokensMetadata containing gpu/cpu tensors
|
||||||
as big as the number of local experts with the information about the
|
as big as the number of local experts with the information about the
|
||||||
number of tokens assigned to each local expert.
|
number of tokens assigned to each local expert.
|
||||||
@ -220,8 +216,6 @@ class FusedMoEPrepareAndFinalize(ABC):
|
|||||||
def prepare_async(
|
def prepare_async(
|
||||||
self,
|
self,
|
||||||
a1: torch.Tensor,
|
a1: torch.Tensor,
|
||||||
a1_scale: Optional[torch.Tensor],
|
|
||||||
a2_scale: Optional[torch.Tensor],
|
|
||||||
topk_weights: torch.Tensor,
|
topk_weights: torch.Tensor,
|
||||||
topk_ids: torch.Tensor,
|
topk_ids: torch.Tensor,
|
||||||
num_experts: int,
|
num_experts: int,
|
||||||
@ -316,6 +310,7 @@ class FusedMoEPrepareAndFinalize(ABC):
|
|||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
|
|
||||||
|
# TODO: add supported activations method (return string)
|
||||||
class FusedMoEPermuteExpertsUnpermute(ABC):
|
class FusedMoEPermuteExpertsUnpermute(ABC):
|
||||||
"""
|
"""
|
||||||
An abstract base class for the [Permute-Experts-Unpermute] step described
|
An abstract base class for the [Permute-Experts-Unpermute] step described
|
||||||
@ -324,12 +319,12 @@ class FusedMoEPermuteExpertsUnpermute(ABC):
|
|||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
quant_config: Optional[FusedMoEQuantConfig],
|
quant_config: FusedMoEQuantConfig,
|
||||||
):
|
):
|
||||||
if quant_config is not None:
|
"""
|
||||||
self.quant_config = quant_config
|
quant_config: Quantization parameters for this experts instance.
|
||||||
else:
|
"""
|
||||||
self.quant_config = FusedMoEQuantConfig()
|
self.quant_config = quant_config
|
||||||
|
|
||||||
@property
|
@property
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
@ -341,6 +336,11 @@ class FusedMoEPermuteExpertsUnpermute(ABC):
|
|||||||
"""
|
"""
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
|
#
|
||||||
|
# Various helpers for accessing quantization parameters from the
|
||||||
|
# quant_config.
|
||||||
|
#
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def quant_dtype(self) -> Optional[torch.dtype]:
|
def quant_dtype(self) -> Optional[torch.dtype]:
|
||||||
return self.quant_config.quant_dtype
|
return self.quant_config.quant_dtype
|
||||||
@ -357,6 +357,54 @@ class FusedMoEPermuteExpertsUnpermute(ABC):
|
|||||||
def per_out_ch_quant(self) -> bool:
|
def per_out_ch_quant(self) -> bool:
|
||||||
return self.quant_config.per_out_ch_quant
|
return self.quant_config.per_out_ch_quant
|
||||||
|
|
||||||
|
@property
|
||||||
|
def a1_scale(self) -> Optional[torch.Tensor]:
|
||||||
|
return self.quant_config.a1_scale
|
||||||
|
|
||||||
|
@property
|
||||||
|
def a2_scale(self) -> Optional[torch.Tensor]:
|
||||||
|
return self.quant_config.a2_scale
|
||||||
|
|
||||||
|
@property
|
||||||
|
def a1_gscale(self) -> Optional[torch.Tensor]:
|
||||||
|
return self.quant_config.a1_gscale
|
||||||
|
|
||||||
|
@property
|
||||||
|
def a2_gscale(self) -> Optional[torch.Tensor]:
|
||||||
|
return self.quant_config.a2_gscale
|
||||||
|
|
||||||
|
@property
|
||||||
|
def w1_scale(self) -> Optional[torch.Tensor]:
|
||||||
|
return self.quant_config.w1_scale
|
||||||
|
|
||||||
|
@property
|
||||||
|
def w2_scale(self) -> Optional[torch.Tensor]:
|
||||||
|
return self.quant_config.w2_scale
|
||||||
|
|
||||||
|
@property
|
||||||
|
def w1_zp(self) -> Optional[torch.Tensor]:
|
||||||
|
return self.quant_config.w1_zp
|
||||||
|
|
||||||
|
@property
|
||||||
|
def w2_zp(self) -> Optional[torch.Tensor]:
|
||||||
|
return self.quant_config.w2_zp
|
||||||
|
|
||||||
|
@property
|
||||||
|
def w1_bias(self) -> Optional[torch.Tensor]:
|
||||||
|
return self.quant_config.w1_bias
|
||||||
|
|
||||||
|
@property
|
||||||
|
def w2_bias(self) -> Optional[torch.Tensor]:
|
||||||
|
return self.quant_config.w2_bias
|
||||||
|
|
||||||
|
@property
|
||||||
|
def g1_alphas(self) -> Optional[torch.Tensor]:
|
||||||
|
return self.quant_config.g1_alphas
|
||||||
|
|
||||||
|
@property
|
||||||
|
def g2_alphas(self) -> Optional[torch.Tensor]:
|
||||||
|
return self.quant_config.g2_alphas
|
||||||
|
|
||||||
# TODO (bnell): make this return a CHUNK_SIZE or None instead?
|
# TODO (bnell): make this return a CHUNK_SIZE or None instead?
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def supports_chunking(self) -> bool:
|
def supports_chunking(self) -> bool:
|
||||||
@ -433,12 +481,7 @@ class FusedMoEPermuteExpertsUnpermute(ABC):
|
|||||||
activation: str,
|
activation: str,
|
||||||
global_num_experts: int,
|
global_num_experts: int,
|
||||||
expert_map: Optional[torch.Tensor],
|
expert_map: Optional[torch.Tensor],
|
||||||
w1_scale: Optional[torch.Tensor],
|
|
||||||
w2_scale: Optional[torch.Tensor],
|
|
||||||
w1_zp: Optional[torch.Tensor],
|
|
||||||
w2_zp: Optional[torch.Tensor],
|
|
||||||
a1q_scale: Optional[torch.Tensor],
|
a1q_scale: Optional[torch.Tensor],
|
||||||
a2_scale: Optional[torch.Tensor],
|
|
||||||
workspace13: torch.Tensor,
|
workspace13: torch.Tensor,
|
||||||
workspace2: torch.Tensor,
|
workspace2: torch.Tensor,
|
||||||
expert_tokens_meta: Optional[ExpertTokensMetadata],
|
expert_tokens_meta: Optional[ExpertTokensMetadata],
|
||||||
@ -464,15 +507,9 @@ class FusedMoEPermuteExpertsUnpermute(ABC):
|
|||||||
- expert_map (Optional[torch.Tensor]): A tensor mapping expert indices
|
- expert_map (Optional[torch.Tensor]): A tensor mapping expert indices
|
||||||
from the global expert space to the local expert space of the expert
|
from the global expert space to the local expert space of the expert
|
||||||
parallel shard.
|
parallel shard.
|
||||||
- w1_scale (Optional[torch.Tensor]): Optional scale to be used for w1.
|
|
||||||
- w2_scale (Optional[torch.Tensor]): Optional scale to be used for w2.
|
|
||||||
- w1_zp (Optional[torch.Tensor]): Optional zero points to be used for
|
|
||||||
w1.
|
|
||||||
- w2_zp (Optional[torch.Tensor]): Optional zero points to be used for
|
|
||||||
w2.
|
|
||||||
- a1q_scale (Optional[torch.Tensor]): Optional quantized scale to be
|
- a1q_scale (Optional[torch.Tensor]): Optional quantized scale to be
|
||||||
used for a1.
|
used for a1. Result of quantization from prepare/finalize and not
|
||||||
- a2_scale (Optional[torch.Tensor]): Optional scale to be used for a2.
|
from the FusedMoEQuantConfig.
|
||||||
- workspace13 (torch.Tensor): A scratch tensor used for gemm outputs
|
- workspace13 (torch.Tensor): A scratch tensor used for gemm outputs
|
||||||
must be large enough to hold output of either MoE gemm.
|
must be large enough to hold output of either MoE gemm.
|
||||||
- workspace2 (torch.Tensor): A scratch tensor used for the activation
|
- workspace2 (torch.Tensor): A scratch tensor used for the activation
|
||||||
@ -559,12 +596,7 @@ class FusedMoEModularKernel(torch.nn.Module):
|
|||||||
global_num_experts: int,
|
global_num_experts: int,
|
||||||
local_num_experts: int,
|
local_num_experts: int,
|
||||||
expert_map: Optional[torch.Tensor],
|
expert_map: Optional[torch.Tensor],
|
||||||
w1_scale: Optional[torch.Tensor],
|
|
||||||
w2_scale: Optional[torch.Tensor],
|
|
||||||
w1_zp: Optional[torch.Tensor],
|
|
||||||
w2_zp: Optional[torch.Tensor],
|
|
||||||
a1q_scale: Optional[torch.Tensor],
|
a1q_scale: Optional[torch.Tensor],
|
||||||
a2_scale: Optional[torch.Tensor],
|
|
||||||
expert_tokens_meta: Optional[ExpertTokensMetadata],
|
expert_tokens_meta: Optional[ExpertTokensMetadata],
|
||||||
apply_router_weight_on_input: bool,
|
apply_router_weight_on_input: bool,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
@ -601,12 +633,7 @@ class FusedMoEModularKernel(torch.nn.Module):
|
|||||||
activation=activation,
|
activation=activation,
|
||||||
global_num_experts=global_num_experts,
|
global_num_experts=global_num_experts,
|
||||||
expert_map=expert_map,
|
expert_map=expert_map,
|
||||||
w1_scale=w1_scale,
|
|
||||||
w2_scale=w2_scale,
|
|
||||||
w1_zp=w1_zp,
|
|
||||||
w2_zp=w2_zp,
|
|
||||||
a1q_scale=a1q_scale,
|
a1q_scale=a1q_scale,
|
||||||
a2_scale=a2_scale,
|
|
||||||
workspace13=workspace13,
|
workspace13=workspace13,
|
||||||
workspace2=workspace2,
|
workspace2=workspace2,
|
||||||
expert_tokens_meta=expert_tokens_meta,
|
expert_tokens_meta=expert_tokens_meta,
|
||||||
@ -627,12 +654,7 @@ class FusedMoEModularKernel(torch.nn.Module):
|
|||||||
global_num_experts: int,
|
global_num_experts: int,
|
||||||
local_num_experts: int,
|
local_num_experts: int,
|
||||||
expert_map: Optional[torch.Tensor],
|
expert_map: Optional[torch.Tensor],
|
||||||
w1_scale: Optional[torch.Tensor],
|
|
||||||
w2_scale: Optional[torch.Tensor],
|
|
||||||
w1_zp: Optional[torch.Tensor],
|
|
||||||
w2_zp: Optional[torch.Tensor],
|
|
||||||
a1q_scale: Optional[torch.Tensor],
|
a1q_scale: Optional[torch.Tensor],
|
||||||
a2_scale: Optional[torch.Tensor],
|
|
||||||
expert_tokens_meta: Optional[ExpertTokensMetadata],
|
expert_tokens_meta: Optional[ExpertTokensMetadata],
|
||||||
apply_router_weight_on_input: bool,
|
apply_router_weight_on_input: bool,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
@ -658,12 +680,7 @@ class FusedMoEModularKernel(torch.nn.Module):
|
|||||||
global_num_experts=global_num_experts,
|
global_num_experts=global_num_experts,
|
||||||
local_num_experts=local_num_experts,
|
local_num_experts=local_num_experts,
|
||||||
expert_map=expert_map,
|
expert_map=expert_map,
|
||||||
w1_scale=w1_scale,
|
|
||||||
w2_scale=w2_scale,
|
|
||||||
w1_zp=w1_zp,
|
|
||||||
w2_zp=w2_zp,
|
|
||||||
a1q_scale=a1q_scale,
|
a1q_scale=a1q_scale,
|
||||||
a2_scale=a2_scale,
|
|
||||||
expert_tokens_meta=expert_tokens_meta,
|
expert_tokens_meta=expert_tokens_meta,
|
||||||
apply_router_weight_on_input=apply_router_weight_on_input,
|
apply_router_weight_on_input=apply_router_weight_on_input,
|
||||||
)
|
)
|
||||||
@ -685,9 +702,13 @@ class FusedMoEModularKernel(torch.nn.Module):
|
|||||||
Optional[torch.Tensor], torch.Tensor, torch.Tensor]:
|
Optional[torch.Tensor], torch.Tensor, torch.Tensor]:
|
||||||
s = chunk_idx * CHUNK_SIZE
|
s = chunk_idx * CHUNK_SIZE
|
||||||
e = min(s + CHUNK_SIZE, M)
|
e = min(s + CHUNK_SIZE, M)
|
||||||
return (a1q[s:e], _chunk_scales(a1q_scale, s, e),
|
return (
|
||||||
_chunk_scales(a2_scale, s,
|
a1q[s:e],
|
||||||
e), topk_ids[s:e], topk_weights[s:e])
|
_chunk_scales(a1q_scale, s, e),
|
||||||
|
_chunk_scales(self.fused_experts.a2_scale, s, e),
|
||||||
|
topk_ids[s:e],
|
||||||
|
topk_weights[s:e],
|
||||||
|
)
|
||||||
|
|
||||||
def slice_output_tensor(chunk_idx: int) -> torch.Tensor:
|
def slice_output_tensor(chunk_idx: int) -> torch.Tensor:
|
||||||
assert fused_out.size(0) % M == 0, (
|
assert fused_out.size(0) % M == 0, (
|
||||||
@ -744,12 +765,7 @@ class FusedMoEModularKernel(torch.nn.Module):
|
|||||||
global_num_experts=global_num_experts,
|
global_num_experts=global_num_experts,
|
||||||
local_num_experts=local_num_experts,
|
local_num_experts=local_num_experts,
|
||||||
expert_map=expert_map,
|
expert_map=expert_map,
|
||||||
w1_scale=w1_scale,
|
|
||||||
w2_scale=w2_scale,
|
|
||||||
w1_zp=w1_zp,
|
|
||||||
w2_zp=w2_zp,
|
|
||||||
a1q_scale=c_a1q_scale,
|
a1q_scale=c_a1q_scale,
|
||||||
a2_scale=c_a2_scale,
|
|
||||||
expert_tokens_meta=c_expert_tokens_meta,
|
expert_tokens_meta=c_expert_tokens_meta,
|
||||||
apply_router_weight_on_input=apply_router_weight_on_input,
|
apply_router_weight_on_input=apply_router_weight_on_input,
|
||||||
)
|
)
|
||||||
@ -767,12 +783,6 @@ class FusedMoEModularKernel(torch.nn.Module):
|
|||||||
activation: str = "silu",
|
activation: str = "silu",
|
||||||
global_num_experts: int = -1,
|
global_num_experts: int = -1,
|
||||||
expert_map: Optional[torch.Tensor] = None,
|
expert_map: Optional[torch.Tensor] = None,
|
||||||
w1_scale: Optional[torch.Tensor] = None,
|
|
||||||
w2_scale: Optional[torch.Tensor] = None,
|
|
||||||
w1_zp: Optional[torch.Tensor] = None,
|
|
||||||
w2_zp: Optional[torch.Tensor] = None,
|
|
||||||
a1_scale: Optional[torch.Tensor] = None,
|
|
||||||
a2_scale: Optional[torch.Tensor] = None,
|
|
||||||
apply_router_weight_on_input: bool = False,
|
apply_router_weight_on_input: bool = False,
|
||||||
) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
|
) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
|
||||||
"""
|
"""
|
||||||
@ -795,14 +805,6 @@ class FusedMoEModularKernel(torch.nn.Module):
|
|||||||
- expert_map (Optional[torch.Tensor]): A tensor mapping expert indices
|
- expert_map (Optional[torch.Tensor]): A tensor mapping expert indices
|
||||||
from the global expert space to the local expert space of the expert
|
from the global expert space to the local expert space of the expert
|
||||||
parallel shard.
|
parallel shard.
|
||||||
- w1_scale (Optional[torch.Tensor]): Optional scale to be used for w1.
|
|
||||||
- w2_scale (Optional[torch.Tensor]): Optional scale to be used for w2.
|
|
||||||
- w1_zp (Optional[torch.Tensor]): Optional zero points to be used for
|
|
||||||
w1.
|
|
||||||
- w2_zp (Optional[torch.Tensor]): Optional zero points to be used for
|
|
||||||
w2.
|
|
||||||
- a1_scale (Optional[torch.Tensor]): Optional scale to be used for a1.
|
|
||||||
- a2_scale (Optional[torch.Tensor]): Optional scale to be used for a2.
|
|
||||||
- apply_router_weight_on_input (bool): When true, the topk weights are
|
- apply_router_weight_on_input (bool): When true, the topk weights are
|
||||||
applied directly on the inputs. This is only applicable when topk is
|
applied directly on the inputs. This is only applicable when topk is
|
||||||
1.
|
1.
|
||||||
@ -832,8 +834,6 @@ class FusedMoEModularKernel(torch.nn.Module):
|
|||||||
(a1q, a1q_scale, expert_tokens_meta, _expert_topk_ids,
|
(a1q, a1q_scale, expert_tokens_meta, _expert_topk_ids,
|
||||||
_expert_topk_weights) = self.prepare_finalize.prepare(
|
_expert_topk_weights) = self.prepare_finalize.prepare(
|
||||||
a1,
|
a1,
|
||||||
a1_scale,
|
|
||||||
a2_scale,
|
|
||||||
topk_weights,
|
topk_weights,
|
||||||
topk_ids,
|
topk_ids,
|
||||||
global_num_experts,
|
global_num_experts,
|
||||||
@ -846,8 +846,6 @@ class FusedMoEModularKernel(torch.nn.Module):
|
|||||||
dbo_maybe_run_recv_hook()
|
dbo_maybe_run_recv_hook()
|
||||||
hook, receiver = self.prepare_finalize.prepare_async(
|
hook, receiver = self.prepare_finalize.prepare_async(
|
||||||
a1,
|
a1,
|
||||||
a1_scale,
|
|
||||||
a2_scale,
|
|
||||||
topk_weights,
|
topk_weights,
|
||||||
topk_ids,
|
topk_ids,
|
||||||
global_num_experts,
|
global_num_experts,
|
||||||
@ -897,12 +895,7 @@ class FusedMoEModularKernel(torch.nn.Module):
|
|||||||
global_num_experts=global_num_experts,
|
global_num_experts=global_num_experts,
|
||||||
local_num_experts=local_num_experts,
|
local_num_experts=local_num_experts,
|
||||||
expert_map=expert_map,
|
expert_map=expert_map,
|
||||||
w1_scale=w1_scale,
|
|
||||||
w2_scale=w2_scale,
|
|
||||||
w1_zp=w1_zp,
|
|
||||||
w2_zp=w2_zp,
|
|
||||||
a1q_scale=a1q_scale,
|
a1q_scale=a1q_scale,
|
||||||
a2_scale=a2_scale,
|
|
||||||
expert_tokens_meta=expert_tokens_meta,
|
expert_tokens_meta=expert_tokens_meta,
|
||||||
apply_router_weight_on_input=apply_router_weight_on_input,
|
apply_router_weight_on_input=apply_router_weight_on_input,
|
||||||
)
|
)
|
||||||
|
|||||||
@ -95,8 +95,6 @@ class PplxPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
|
|||||||
def prepare_async(
|
def prepare_async(
|
||||||
self,
|
self,
|
||||||
a1: torch.Tensor,
|
a1: torch.Tensor,
|
||||||
a1_scale: Optional[torch.Tensor],
|
|
||||||
a2_scale: Optional[torch.Tensor],
|
|
||||||
topk_weights: torch.Tensor,
|
topk_weights: torch.Tensor,
|
||||||
topk_ids: torch.Tensor,
|
topk_ids: torch.Tensor,
|
||||||
num_experts: int,
|
num_experts: int,
|
||||||
@ -130,8 +128,10 @@ class PplxPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
|
|||||||
|
|
||||||
repeat_cols = 4
|
repeat_cols = 4
|
||||||
repeat_rows = 1 if quant_config.per_act_token_quant else a1.size(0)
|
repeat_rows = 1 if quant_config.per_act_token_quant else a1.size(0)
|
||||||
|
# TODO(bnell): always pass quant_config.a1_scale?
|
||||||
a1q, a1q_scale = moe_kernel_quantize_input(
|
a1q, a1q_scale = moe_kernel_quantize_input(
|
||||||
a1, (None if quant_config.per_act_token_quant else a1_scale),
|
a1, (None if quant_config.per_act_token_quant else
|
||||||
|
quant_config.a1_scale),
|
||||||
quant_dtype=quant_config.quant_dtype,
|
quant_dtype=quant_config.quant_dtype,
|
||||||
per_act_token_quant=quant_config.per_act_token_quant,
|
per_act_token_quant=quant_config.per_act_token_quant,
|
||||||
block_shape=quant_config.block_shape)
|
block_shape=quant_config.block_shape)
|
||||||
@ -253,8 +253,6 @@ class PplxPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
|
|||||||
def prepare(
|
def prepare(
|
||||||
self,
|
self,
|
||||||
a1: torch.Tensor,
|
a1: torch.Tensor,
|
||||||
a1_scale: Optional[torch.Tensor],
|
|
||||||
a2_scale: Optional[torch.Tensor],
|
|
||||||
topk_weights: torch.Tensor,
|
topk_weights: torch.Tensor,
|
||||||
topk_ids: torch.Tensor,
|
topk_ids: torch.Tensor,
|
||||||
num_experts: int,
|
num_experts: int,
|
||||||
@ -264,8 +262,6 @@ class PplxPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
|
|||||||
) -> mk.PrepareResultType:
|
) -> mk.PrepareResultType:
|
||||||
hook, receiver = self.prepare_async(
|
hook, receiver = self.prepare_async(
|
||||||
a1,
|
a1,
|
||||||
a1_scale,
|
|
||||||
a2_scale,
|
|
||||||
topk_weights,
|
topk_weights,
|
||||||
topk_ids,
|
topk_ids,
|
||||||
num_experts,
|
num_experts,
|
||||||
|
|||||||
@ -30,8 +30,6 @@ class MoEPrepareAndFinalizeNoEP(mk.FusedMoEPrepareAndFinalize):
|
|||||||
def prepare(
|
def prepare(
|
||||||
self,
|
self,
|
||||||
a1: torch.Tensor,
|
a1: torch.Tensor,
|
||||||
a1_scale: Optional[torch.Tensor],
|
|
||||||
a2_scale: Optional[torch.Tensor],
|
|
||||||
topk_weights: torch.Tensor,
|
topk_weights: torch.Tensor,
|
||||||
topk_ids: torch.Tensor,
|
topk_ids: torch.Tensor,
|
||||||
num_experts: int,
|
num_experts: int,
|
||||||
@ -48,7 +46,7 @@ class MoEPrepareAndFinalizeNoEP(mk.FusedMoEPrepareAndFinalize):
|
|||||||
a1.mul_(topk_weights.to(a1.dtype))
|
a1.mul_(topk_weights.to(a1.dtype))
|
||||||
|
|
||||||
a1q, a1q_scale = moe_kernel_quantize_input(
|
a1q, a1q_scale = moe_kernel_quantize_input(
|
||||||
a1, a1_scale, quant_config.quant_dtype,
|
a1, quant_config.a1_scale, quant_config.quant_dtype,
|
||||||
quant_config.per_act_token_quant, quant_config.block_shape)
|
quant_config.per_act_token_quant, quant_config.block_shape)
|
||||||
|
|
||||||
return a1q, a1q_scale, None, None, None
|
return a1q, a1q_scale, None, None, None
|
||||||
|
|||||||
@ -7,6 +7,8 @@ from typing import Optional
|
|||||||
import torch
|
import torch
|
||||||
|
|
||||||
from vllm import envs
|
from vllm import envs
|
||||||
|
from vllm.model_executor.layers.fused_moe.config import (
|
||||||
|
FUSED_MOE_UNQUANTIZED_CONFIG, FusedMoEQuantConfig)
|
||||||
from vllm.platforms import current_platform
|
from vllm.platforms import current_platform
|
||||||
from vllm.utils import direct_register_custom_op
|
from vllm.utils import direct_register_custom_op
|
||||||
|
|
||||||
@ -305,21 +307,18 @@ def rocm_aiter_grouped_topk(
|
|||||||
|
|
||||||
|
|
||||||
def rocm_aiter_fused_experts(
|
def rocm_aiter_fused_experts(
|
||||||
hidden_states: torch.Tensor,
|
hidden_states: torch.Tensor,
|
||||||
w1: torch.Tensor,
|
w1: torch.Tensor,
|
||||||
w2: torch.Tensor,
|
w2: torch.Tensor,
|
||||||
topk_weights: torch.Tensor,
|
topk_weights: torch.Tensor,
|
||||||
topk_ids: torch.Tensor,
|
topk_ids: torch.Tensor,
|
||||||
activation: str = "silu",
|
activation: str = "silu",
|
||||||
apply_router_weight_on_input: bool = False,
|
apply_router_weight_on_input: bool = False,
|
||||||
use_fp8_w8a8: bool = False,
|
expert_map: Optional[torch.Tensor] = None,
|
||||||
per_channel_quant: bool = False,
|
quant_config: Optional[FusedMoEQuantConfig] = None,
|
||||||
w1_scale: Optional[torch.Tensor] = None,
|
) -> torch.Tensor:
|
||||||
w2_scale: Optional[torch.Tensor] = None,
|
if quant_config is None:
|
||||||
a1_scale: Optional[torch.Tensor] = None,
|
quant_config = FUSED_MOE_UNQUANTIZED_CONFIG
|
||||||
a2_scale: Optional[torch.Tensor] = None,
|
|
||||||
block_shape: Optional[list[int]] = None,
|
|
||||||
expert_map: Optional[torch.Tensor] = None) -> torch.Tensor:
|
|
||||||
|
|
||||||
activation_method = (ActivationMethod.SILU
|
activation_method = (ActivationMethod.SILU
|
||||||
if activation == "silu" else ActivationMethod.GELU)
|
if activation == "silu" else ActivationMethod.GELU)
|
||||||
@ -333,7 +332,8 @@ def rocm_aiter_fused_experts(
|
|||||||
expert_mask = None
|
expert_mask = None
|
||||||
|
|
||||||
# w8a8 per-channel quantization
|
# w8a8 per-channel quantization
|
||||||
if per_channel_quant and apply_router_weight_on_input and use_fp8_w8a8:
|
if (quant_config.per_act_token_quant and apply_router_weight_on_input
|
||||||
|
and quant_config.use_fp8_w8a8):
|
||||||
# AITER tkw1 kernel for FP8 models with `apply_router_weight_on_input`
|
# AITER tkw1 kernel for FP8 models with `apply_router_weight_on_input`
|
||||||
# This applies topk_weights on the GEMM output of the first FC layer
|
# This applies topk_weights on the GEMM output of the first FC layer
|
||||||
# rather than the second FC.
|
# rather than the second FC.
|
||||||
@ -349,8 +349,8 @@ def rocm_aiter_fused_experts(
|
|||||||
w2,
|
w2,
|
||||||
topk_weights,
|
topk_weights,
|
||||||
topk_ids,
|
topk_ids,
|
||||||
fc1_scale=w1_scale,
|
fc1_scale=quant_config.w1_scale,
|
||||||
fc2_scale=w2_scale,
|
fc2_scale=quant_config.w2_scale,
|
||||||
fc1_smooth_scale=None,
|
fc1_smooth_scale=None,
|
||||||
fc2_smooth_scale=None,
|
fc2_smooth_scale=None,
|
||||||
a16=False,
|
a16=False,
|
||||||
@ -362,14 +362,14 @@ def rocm_aiter_fused_experts(
|
|||||||
quant_method = QuantMethod.NO.value
|
quant_method = QuantMethod.NO.value
|
||||||
|
|
||||||
# w8a8 block-scaled
|
# w8a8 block-scaled
|
||||||
if block_shape is not None and use_fp8_w8a8:
|
if quant_config.block_shape is not None and quant_config.use_fp8_w8a8:
|
||||||
assert not apply_router_weight_on_input, (
|
assert not apply_router_weight_on_input, (
|
||||||
"apply_router_weight_on_input is\
|
"apply_router_weight_on_input is\
|
||||||
not supported for block scaled moe")
|
not supported for block scaled moe")
|
||||||
assert w1_scale is not None
|
assert quant_config.w1_scale is not None
|
||||||
assert w2_scale is not None
|
assert quant_config.w2_scale is not None
|
||||||
quant_method = QuantMethod.BLOCK_128x128.value
|
quant_method = QuantMethod.BLOCK_128x128.value
|
||||||
elif use_fp8_w8a8:
|
elif quant_config.use_fp8_w8a8:
|
||||||
# Currently only per tensor quantization method is enabled.
|
# Currently only per tensor quantization method is enabled.
|
||||||
quant_method = QuantMethod.PER_TENSOR.value
|
quant_method = QuantMethod.PER_TENSOR.value
|
||||||
|
|
||||||
@ -390,10 +390,10 @@ def rocm_aiter_fused_experts(
|
|||||||
expert_mask=expert_mask,
|
expert_mask=expert_mask,
|
||||||
quant_method=quant_method,
|
quant_method=quant_method,
|
||||||
activation_method=activation_method,
|
activation_method=activation_method,
|
||||||
w1_scale=w1_scale,
|
w1_scale=quant_config.w1_scale,
|
||||||
w2_scale=w2_scale,
|
w2_scale=quant_config.w2_scale,
|
||||||
a1_scale=a1_scale,
|
a1_scale=quant_config.a1_scale,
|
||||||
a2_scale=a2_scale,
|
a2_scale=quant_config.a2_scale,
|
||||||
doweight_stage1=apply_router_weight_on_input)
|
doweight_stage1=apply_router_weight_on_input)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@ -7,7 +7,8 @@ import torch
|
|||||||
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
|
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
|
||||||
from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig
|
from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig
|
||||||
from vllm.model_executor.layers.fused_moe.deep_gemm_moe import (
|
from vllm.model_executor.layers.fused_moe.deep_gemm_moe import (
|
||||||
DeepGemmExperts, _valid_deep_gemm, _valid_deep_gemm_shape,
|
DeepGemmExperts, _valid_deep_gemm, _valid_deep_gemm_shape)
|
||||||
|
from vllm.model_executor.layers.fused_moe.deep_gemm_utils import (
|
||||||
deep_gemm_block_shape)
|
deep_gemm_block_shape)
|
||||||
from vllm.model_executor.layers.fused_moe.fused_moe import TritonExperts
|
from vllm.model_executor.layers.fused_moe.fused_moe import TritonExperts
|
||||||
from vllm.utils.deep_gemm import is_deep_gemm_e8m0_used
|
from vllm.utils.deep_gemm import is_deep_gemm_e8m0_used
|
||||||
@ -17,40 +18,19 @@ class TritonOrDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
|
|||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
use_fp8_w8a8: bool = False,
|
quant_config: FusedMoEQuantConfig,
|
||||||
use_int8_w8a8: bool = False,
|
|
||||||
use_int8_w8a16: bool = False,
|
|
||||||
use_int4_w4a16: bool = False,
|
|
||||||
use_mxfp4_w4a4: bool = False,
|
|
||||||
per_act_token_quant: bool = False,
|
|
||||||
block_shape: Optional[list[int]] = None,
|
|
||||||
allow_deep_gemm: bool = False,
|
allow_deep_gemm: bool = False,
|
||||||
):
|
):
|
||||||
super().__init__(
|
super().__init__(quant_config)
|
||||||
FusedMoEQuantConfig.make(
|
|
||||||
use_fp8_w8a8=use_fp8_w8a8,
|
|
||||||
use_int8_w8a8=use_int8_w8a8,
|
|
||||||
use_int8_w8a16=use_int8_w8a16,
|
|
||||||
use_int4_w4a16=use_int4_w4a16,
|
|
||||||
use_mxfp4_w4a4=use_mxfp4_w4a4,
|
|
||||||
per_act_token_quant=per_act_token_quant,
|
|
||||||
block_shape=block_shape,
|
|
||||||
))
|
|
||||||
self.triton_expert = TritonExperts(
|
|
||||||
use_fp8_w8a8=use_fp8_w8a8,
|
|
||||||
use_int8_w8a8=use_int8_w8a8,
|
|
||||||
use_int4_w4a16=use_int4_w4a16,
|
|
||||||
use_int8_w8a16=use_int8_w8a16,
|
|
||||||
use_mxfp4_w4a4=use_mxfp4_w4a4,
|
|
||||||
per_act_token_quant=per_act_token_quant,
|
|
||||||
block_shape=block_shape,
|
|
||||||
)
|
|
||||||
|
|
||||||
self.allow_deep_gemm = (allow_deep_gemm and use_fp8_w8a8 and
|
self.triton_expert = TritonExperts(quant_config)
|
||||||
|
|
||||||
|
self.allow_deep_gemm = (allow_deep_gemm
|
||||||
|
and self.quant_config.use_fp8_w8a8 and
|
||||||
self.block_shape == deep_gemm_block_shape())
|
self.block_shape == deep_gemm_block_shape())
|
||||||
|
|
||||||
self.deep_gemm_expert = DeepGemmExperts(
|
self.deep_gemm_expert = DeepGemmExperts(
|
||||||
) if self.allow_deep_gemm else None
|
self.quant_config) if self.allow_deep_gemm else None
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def activation_formats(
|
def activation_formats(
|
||||||
@ -130,12 +110,7 @@ class TritonOrDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
|
|||||||
activation: str,
|
activation: str,
|
||||||
global_num_experts: int,
|
global_num_experts: int,
|
||||||
expert_map: Optional[torch.Tensor],
|
expert_map: Optional[torch.Tensor],
|
||||||
w1_scale: Optional[torch.Tensor],
|
|
||||||
w2_scale: Optional[torch.Tensor],
|
|
||||||
w1_zp: Optional[torch.Tensor],
|
|
||||||
w2_zp: Optional[torch.Tensor],
|
|
||||||
a1q_scale: Optional[torch.Tensor],
|
a1q_scale: Optional[torch.Tensor],
|
||||||
a2_scale: Optional[torch.Tensor],
|
|
||||||
workspace13: torch.Tensor,
|
workspace13: torch.Tensor,
|
||||||
workspace2: torch.Tensor,
|
workspace2: torch.Tensor,
|
||||||
expert_tokens_meta: Optional[mk.ExpertTokensMetadata],
|
expert_tokens_meta: Optional[mk.ExpertTokensMetadata],
|
||||||
@ -158,12 +133,7 @@ class TritonOrDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
|
|||||||
activation,
|
activation,
|
||||||
global_num_experts,
|
global_num_experts,
|
||||||
expert_map,
|
expert_map,
|
||||||
w1_scale,
|
|
||||||
w2_scale,
|
|
||||||
w1_zp,
|
|
||||||
w2_zp,
|
|
||||||
a1q_scale,
|
a1q_scale,
|
||||||
a2_scale,
|
|
||||||
workspace13,
|
workspace13,
|
||||||
workspace2,
|
workspace2,
|
||||||
expert_tokens_meta,
|
expert_tokens_meta,
|
||||||
|
|||||||
@ -5,7 +5,8 @@ from typing import Optional
|
|||||||
import torch
|
import torch
|
||||||
|
|
||||||
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
|
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
|
||||||
from vllm.model_executor.layers.fused_moe.config import FusedMoEConfig
|
from vllm.model_executor.layers.fused_moe.config import (FusedMoEConfig,
|
||||||
|
FusedMoEQuantConfig)
|
||||||
from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import (
|
from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import (
|
||||||
TopKWeightAndReduceNoOP)
|
TopKWeightAndReduceNoOP)
|
||||||
from vllm.utils import next_power_of_2
|
from vllm.utils import next_power_of_2
|
||||||
@ -16,20 +17,17 @@ class TrtLlmGenExperts(mk.FusedMoEPermuteExpertsUnpermute):
|
|||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
moe: FusedMoEConfig,
|
moe: FusedMoEConfig,
|
||||||
|
quant_config: FusedMoEQuantConfig,
|
||||||
gemm1_alpha,
|
gemm1_alpha,
|
||||||
gemm1_beta,
|
gemm1_beta,
|
||||||
gemm1_clamp_limit,
|
gemm1_clamp_limit,
|
||||||
w13_bias,
|
|
||||||
w2_bias,
|
|
||||||
max_capture_size,
|
max_capture_size,
|
||||||
):
|
):
|
||||||
super().__init__(moe.quant_config)
|
super().__init__(quant_config)
|
||||||
self.moe = moe
|
self.moe = moe
|
||||||
self.gemm1_alpha = gemm1_alpha
|
self.gemm1_alpha = gemm1_alpha
|
||||||
self.gemm1_beta = gemm1_beta
|
self.gemm1_beta = gemm1_beta
|
||||||
self.gemm1_clamp_limit = gemm1_clamp_limit
|
self.gemm1_clamp_limit = gemm1_clamp_limit
|
||||||
self.w13_bias = w13_bias
|
|
||||||
self.w2_bias = w2_bias
|
|
||||||
self.max_capture_size = max_capture_size
|
self.max_capture_size = max_capture_size
|
||||||
|
|
||||||
@property
|
@property
|
||||||
@ -104,12 +102,7 @@ class TrtLlmGenExperts(mk.FusedMoEPermuteExpertsUnpermute):
|
|||||||
activation: str,
|
activation: str,
|
||||||
global_num_experts: int,
|
global_num_experts: int,
|
||||||
expert_map: Optional[torch.Tensor],
|
expert_map: Optional[torch.Tensor],
|
||||||
w1_scale: Optional[torch.Tensor],
|
|
||||||
w2_scale: Optional[torch.Tensor],
|
|
||||||
w1_zp: Optional[torch.Tensor],
|
|
||||||
w2_zp: Optional[torch.Tensor],
|
|
||||||
a1q_scale: Optional[torch.Tensor],
|
a1q_scale: Optional[torch.Tensor],
|
||||||
a2_scale: Optional[torch.Tensor],
|
|
||||||
workspace13: torch.Tensor,
|
workspace13: torch.Tensor,
|
||||||
workspace2: torch.Tensor,
|
workspace2: torch.Tensor,
|
||||||
expert_tokens_meta: Optional[mk.ExpertTokensMetadata],
|
expert_tokens_meta: Optional[mk.ExpertTokensMetadata],
|
||||||
@ -129,8 +122,8 @@ class TrtLlmGenExperts(mk.FusedMoEPermuteExpertsUnpermute):
|
|||||||
packed_tensor = (topk_ids.to(torch.int32) << 16) | topk_weights.to(
|
packed_tensor = (topk_ids.to(torch.int32) << 16) | topk_weights.to(
|
||||||
torch.bfloat16).view(torch.int16)
|
torch.bfloat16).view(torch.int16)
|
||||||
|
|
||||||
assert w1_scale is not None
|
assert self.w1_scale is not None
|
||||||
assert w2_scale is not None
|
assert self.w2_scale is not None
|
||||||
kwargs = {
|
kwargs = {
|
||||||
"topk_ids":
|
"topk_ids":
|
||||||
packed_tensor,
|
packed_tensor,
|
||||||
@ -143,9 +136,9 @@ class TrtLlmGenExperts(mk.FusedMoEPermuteExpertsUnpermute):
|
|||||||
"gemm1_weights":
|
"gemm1_weights":
|
||||||
w1,
|
w1,
|
||||||
"gemm1_weights_scale":
|
"gemm1_weights_scale":
|
||||||
w1_scale,
|
self.w1_scale,
|
||||||
"gemm1_bias":
|
"gemm1_bias":
|
||||||
self.w13_bias,
|
self.w1_bias,
|
||||||
"gemm1_alpha":
|
"gemm1_alpha":
|
||||||
self.gemm1_alpha,
|
self.gemm1_alpha,
|
||||||
"gemm1_beta":
|
"gemm1_beta":
|
||||||
@ -155,7 +148,7 @@ class TrtLlmGenExperts(mk.FusedMoEPermuteExpertsUnpermute):
|
|||||||
"gemm2_weights":
|
"gemm2_weights":
|
||||||
w2,
|
w2,
|
||||||
"gemm2_weights_scale":
|
"gemm2_weights_scale":
|
||||||
w2_scale,
|
self.w2_scale,
|
||||||
"gemm2_bias":
|
"gemm2_bias":
|
||||||
self.w2_bias,
|
self.w2_bias,
|
||||||
"output1_scale_scalar":
|
"output1_scale_scalar":
|
||||||
|
|||||||
@ -268,3 +268,7 @@ def _validate_scale_shape(
|
|||||||
assert block_shape is not None
|
assert block_shape is not None
|
||||||
expected = (a.shape[0], cdiv(a.shape[1], block_shape[1]))
|
expected = (a.shape[0], cdiv(a.shape[1], block_shape[1]))
|
||||||
assert a_scale.shape == expected, f"{a_scale.shape} == {expected}"
|
assert a_scale.shape == expected, f"{a_scale.shape} == {expected}"
|
||||||
|
|
||||||
|
|
||||||
|
def activation_without_mul(activation: str) -> str:
|
||||||
|
return activation + "_no_mul"
|
||||||
|
|||||||
@ -9,8 +9,10 @@ from torch.nn import Parameter
|
|||||||
import vllm.model_executor.layers.fused_moe # noqa
|
import vllm.model_executor.layers.fused_moe # noqa
|
||||||
from vllm import _custom_ops as ops
|
from vllm import _custom_ops as ops
|
||||||
from vllm.logger import init_logger
|
from vllm.logger import init_logger
|
||||||
|
from vllm.model_executor.layers.fused_moe.config import (FusedMoEConfig,
|
||||||
|
FusedMoEQuantConfig)
|
||||||
from vllm.model_executor.layers.fused_moe.layer import (
|
from vllm.model_executor.layers.fused_moe.layer import (
|
||||||
FusedMoE, FusedMoEConfig, FusedMoEMethodBase, FusedMoeWeightScaleSupported,
|
FusedMoE, FusedMoEMethodBase, FusedMoeWeightScaleSupported,
|
||||||
UnquantizedFusedMoEMethod)
|
UnquantizedFusedMoEMethod)
|
||||||
from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase,
|
from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase,
|
||||||
UnquantizedLinearMethod,
|
UnquantizedLinearMethod,
|
||||||
@ -483,6 +485,10 @@ class AWQMoEMethod(FusedMoEMethodBase):
|
|||||||
if hasattr(layer, "w2_bias") and layer.w2_bias is not None:
|
if hasattr(layer, "w2_bias") and layer.w2_bias is not None:
|
||||||
layer.w2_bias.data = marlin_permute_bias(layer.w2_bias)
|
layer.w2_bias.data = marlin_permute_bias(layer.w2_bias)
|
||||||
|
|
||||||
|
def get_fused_moe_quant_config(
|
||||||
|
self, layer: torch.nn.Module) -> Optional[FusedMoEQuantConfig]:
|
||||||
|
return None
|
||||||
|
|
||||||
def apply(
|
def apply(
|
||||||
self,
|
self,
|
||||||
layer: torch.nn.Module,
|
layer: torch.nn.Module,
|
||||||
|
|||||||
@ -6,8 +6,9 @@ from typing import Any, Callable, Optional, Union
|
|||||||
import torch
|
import torch
|
||||||
from packaging import version
|
from packaging import version
|
||||||
|
|
||||||
|
from vllm.model_executor.layers.fused_moe.config import (FusedMoEConfig,
|
||||||
|
FusedMoEQuantConfig)
|
||||||
from vllm.model_executor.layers.fused_moe.layer import (FusedMoE,
|
from vllm.model_executor.layers.fused_moe.layer import (FusedMoE,
|
||||||
FusedMoEConfig,
|
|
||||||
FusedMoEMethodBase)
|
FusedMoEMethodBase)
|
||||||
from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase,
|
from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase,
|
||||||
UnquantizedLinearMethod,
|
UnquantizedLinearMethod,
|
||||||
@ -452,6 +453,10 @@ class BitsAndBytesMoEMethod(FusedMoEMethodBase):
|
|||||||
**extra_weight_attrs,
|
**extra_weight_attrs,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def get_fused_moe_quant_config(
|
||||||
|
self, layer: torch.nn.Module) -> Optional[FusedMoEQuantConfig]:
|
||||||
|
return None
|
||||||
|
|
||||||
def apply(
|
def apply(
|
||||||
self,
|
self,
|
||||||
layer: torch.nn.Module,
|
layer: torch.nn.Module,
|
||||||
@ -509,6 +514,7 @@ class BitsAndBytesMoEMethod(FusedMoEMethodBase):
|
|||||||
apply_router_weight_on_input=apply_router_weight_on_input,
|
apply_router_weight_on_input=apply_router_weight_on_input,
|
||||||
global_num_experts=global_num_experts,
|
global_num_experts=global_num_experts,
|
||||||
expert_map=expert_map,
|
expert_map=expert_map,
|
||||||
|
quant_config=self.moe_quant_config,
|
||||||
)
|
)
|
||||||
|
|
||||||
def _create_weights_4bit(
|
def _create_weights_4bit(
|
||||||
|
|||||||
@ -16,8 +16,11 @@ from vllm import _custom_ops as ops
|
|||||||
from vllm.logger import init_logger
|
from vllm.logger import init_logger
|
||||||
from vllm.model_executor.layers.fused_moe import (
|
from vllm.model_executor.layers.fused_moe import (
|
||||||
FusedMoE, FusedMoEActivationFormat, FusedMoEConfig, FusedMoEMethodBase,
|
FusedMoE, FusedMoEActivationFormat, FusedMoEConfig, FusedMoEMethodBase,
|
||||||
FusedMoEPermuteExpertsUnpermute, FusedMoEPrepareAndFinalize,
|
FusedMoEPermuteExpertsUnpermute, FusedMoeWeightScaleSupported)
|
||||||
FusedMoeWeightScaleSupported)
|
from vllm.model_executor.layers.fused_moe.config import (
|
||||||
|
FusedMoEQuantConfig, fp8_w8a8_moe_quant_config,
|
||||||
|
int4_w4a16_moe_quant_config, int8_w8a8_moe_quant_config,
|
||||||
|
int8_w8a16_moe_quant_config, nvfp4_moe_quant_config)
|
||||||
from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe import (
|
from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe import (
|
||||||
is_valid_flashinfer_cutlass_fused_moe)
|
is_valid_flashinfer_cutlass_fused_moe)
|
||||||
from vllm.model_executor.layers.quantization.compressed_tensors.schemes.compressed_tensors_wNa16 import ( # noqa
|
from vllm.model_executor.layers.quantization.compressed_tensors.schemes.compressed_tensors_wNa16 import ( # noqa
|
||||||
@ -122,7 +125,7 @@ class CompressedTensorsMoEMethod(FusedMoEMethodBase):
|
|||||||
return CompressedTensorsWNA16MarlinMoEMethod(
|
return CompressedTensorsWNA16MarlinMoEMethod(
|
||||||
quant_config, layer.moe_config)
|
quant_config, layer.moe_config)
|
||||||
elif quant_config._is_fp4a4_nvfp4(weight_quant, input_quant):
|
elif quant_config._is_fp4a4_nvfp4(weight_quant, input_quant):
|
||||||
return CompressedTensorsW4A4MoeMethod(layer.moe_config, layer)
|
return CompressedTensorsW4A4MoeMethod(layer.moe_config)
|
||||||
elif (quant_config._is_fp8_w8a8_sm90(weight_quant, input_quant)
|
elif (quant_config._is_fp8_w8a8_sm90(weight_quant, input_quant)
|
||||||
or quant_config._is_fp8_w8a8_sm100(weight_quant, input_quant)
|
or quant_config._is_fp8_w8a8_sm100(weight_quant, input_quant)
|
||||||
or quant_config._is_fp8_w8a8(weight_quant, input_quant)):
|
or quant_config._is_fp8_w8a8(weight_quant, input_quant)):
|
||||||
@ -138,7 +141,7 @@ class CompressedTensorsMoEMethod(FusedMoEMethodBase):
|
|||||||
|
|
||||||
class CompressedTensorsW4A4MoeMethod(CompressedTensorsMoEMethod):
|
class CompressedTensorsW4A4MoeMethod(CompressedTensorsMoEMethod):
|
||||||
|
|
||||||
def __init__(self, moe: FusedMoEConfig, layer: torch.nn.Module):
|
def __init__(self, moe: FusedMoEConfig):
|
||||||
from vllm.model_executor.layers.quantization.utils.nvfp4_moe_support import ( # noqa: E501
|
from vllm.model_executor.layers.quantization.utils.nvfp4_moe_support import ( # noqa: E501
|
||||||
detect_nvfp4_moe_support)
|
detect_nvfp4_moe_support)
|
||||||
super().__init__(moe)
|
super().__init__(moe)
|
||||||
@ -147,7 +150,6 @@ class CompressedTensorsW4A4MoeMethod(CompressedTensorsMoEMethod):
|
|||||||
self.allow_flashinfer = _nvfp4.allow_flashinfer
|
self.allow_flashinfer = _nvfp4.allow_flashinfer
|
||||||
self.use_marlin = _nvfp4.use_marlin
|
self.use_marlin = _nvfp4.use_marlin
|
||||||
self.group_size = 16
|
self.group_size = 16
|
||||||
self.layer = layer
|
|
||||||
|
|
||||||
def create_weights(self, layer: torch.nn.Module, num_experts: int,
|
def create_weights(self, layer: torch.nn.Module, num_experts: int,
|
||||||
hidden_size: int, intermediate_size_per_partition: int,
|
hidden_size: int, intermediate_size_per_partition: int,
|
||||||
@ -305,37 +307,46 @@ class CompressedTensorsW4A4MoeMethod(CompressedTensorsMoEMethod):
|
|||||||
(layer.w2_input_global_scale), requires_grad=False)
|
(layer.w2_input_global_scale), requires_grad=False)
|
||||||
|
|
||||||
def maybe_make_prepare_finalize(
|
def maybe_make_prepare_finalize(
|
||||||
self,
|
self) -> Optional[mk.FusedMoEPrepareAndFinalize]:
|
||||||
moe: FusedMoEConfig,
|
if self.use_marlin:
|
||||||
) -> Optional[mk.FusedMoEPrepareAndFinalize]:
|
return None
|
||||||
if not self.allow_flashinfer:
|
elif not self.allow_flashinfer:
|
||||||
return super().maybe_make_prepare_finalize(moe)
|
return super().maybe_make_prepare_finalize()
|
||||||
|
|
||||||
prepare_finalize = build_flashinfer_fp4_cutlass_moe_prepare_finalize(
|
prepare_finalize = build_flashinfer_fp4_cutlass_moe_prepare_finalize(
|
||||||
moe,
|
self.moe)
|
||||||
a1_gscale=self.layer.w13_input_scale_quant,
|
|
||||||
)
|
|
||||||
logger.debug_once("%s", prepare_finalize.__class__.__name__)
|
logger.debug_once("%s", prepare_finalize.__class__.__name__)
|
||||||
return prepare_finalize
|
return prepare_finalize
|
||||||
|
|
||||||
def select_gemm_impl(
|
def select_gemm_impl(
|
||||||
self,
|
self,
|
||||||
prepare_finalize: mk.FusedMoEPrepareAndFinalize,
|
prepare_finalize: mk.FusedMoEPrepareAndFinalize,
|
||||||
moe: FusedMoEConfig,
|
|
||||||
layer: torch.nn.Module,
|
layer: torch.nn.Module,
|
||||||
) -> mk.FusedMoEPermuteExpertsUnpermute:
|
) -> mk.FusedMoEPermuteExpertsUnpermute:
|
||||||
|
assert self.moe_quant_config is not None
|
||||||
"""Return the appropriate GEMM experts implementation."""
|
"""Return the appropriate GEMM experts implementation."""
|
||||||
experts = select_nvfp4_gemm_impl(
|
experts = select_nvfp4_gemm_impl(
|
||||||
moe,
|
self.moe,
|
||||||
g1_alphas=self.layer.g1_alphas,
|
self.moe_quant_config,
|
||||||
g2_alphas=self.layer.g2_alphas,
|
|
||||||
a1_gscale=self.layer.w13_input_scale_quant,
|
|
||||||
a2_gscale=self.layer.w2_input_scale_quant,
|
|
||||||
allow_flashinfer=self.allow_flashinfer,
|
allow_flashinfer=self.allow_flashinfer,
|
||||||
)
|
)
|
||||||
logger.debug_once("Using %s", experts.__class__.__name__)
|
logger.debug_once("Using %s", experts.__class__.__name__)
|
||||||
return experts
|
return experts
|
||||||
|
|
||||||
|
def get_fused_moe_quant_config(
|
||||||
|
self, layer: torch.nn.Module) -> Optional[FusedMoEQuantConfig]:
|
||||||
|
if self.use_marlin:
|
||||||
|
return None
|
||||||
|
|
||||||
|
return nvfp4_moe_quant_config(
|
||||||
|
g1_alphas=layer.g1_alphas,
|
||||||
|
g2_alphas=layer.g2_alphas,
|
||||||
|
a1_gscale=layer.w13_input_scale_quant,
|
||||||
|
a2_gscale=layer.w2_input_scale_quant,
|
||||||
|
w1_scale=layer.w13_weight_scale,
|
||||||
|
w2_scale=layer.w2_weight_scale,
|
||||||
|
)
|
||||||
|
|
||||||
def apply(
|
def apply(
|
||||||
self,
|
self,
|
||||||
layer: torch.nn.Module,
|
layer: torch.nn.Module,
|
||||||
@ -359,8 +370,6 @@ class CompressedTensorsW4A4MoeMethod(CompressedTensorsMoEMethod):
|
|||||||
logical_to_physical_map: Optional[torch.Tensor] = None,
|
logical_to_physical_map: Optional[torch.Tensor] = None,
|
||||||
logical_replica_count: Optional[torch.Tensor] = None,
|
logical_replica_count: Optional[torch.Tensor] = None,
|
||||||
) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
|
) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
|
||||||
assert self.fused_experts is None
|
|
||||||
|
|
||||||
if enable_eplb:
|
if enable_eplb:
|
||||||
raise NotImplementedError("EPLB not supported for "
|
raise NotImplementedError("EPLB not supported for "
|
||||||
"`CompressedTensorsW4A4MoeMethod` yet.")
|
"`CompressedTensorsW4A4MoeMethod` yet.")
|
||||||
@ -381,7 +390,12 @@ class CompressedTensorsW4A4MoeMethod(CompressedTensorsMoEMethod):
|
|||||||
indices_type=self.topk_indices_dtype,
|
indices_type=self.topk_indices_dtype,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
#
|
||||||
|
# Note: the order here is important. self.fused_experts can override
|
||||||
|
# flashinfer cutlass, cutlass fp4 or fused_experts but not marlin.
|
||||||
|
#
|
||||||
if self.use_marlin:
|
if self.use_marlin:
|
||||||
|
assert self.fused_experts is None
|
||||||
return torch.ops.vllm.fused_marlin_moe(
|
return torch.ops.vllm.fused_marlin_moe(
|
||||||
x,
|
x,
|
||||||
layer.w13_weight,
|
layer.w13_weight,
|
||||||
@ -401,8 +415,7 @@ class CompressedTensorsW4A4MoeMethod(CompressedTensorsMoEMethod):
|
|||||||
expert_map=expert_map,
|
expert_map=expert_map,
|
||||||
workspace=layer.workspace)
|
workspace=layer.workspace)
|
||||||
|
|
||||||
# FlashInfer fused experts path
|
elif self.fused_experts is not None:
|
||||||
if self.fused_experts is not None:
|
|
||||||
assert is_valid_flashinfer_cutlass_fused_moe(
|
assert is_valid_flashinfer_cutlass_fused_moe(
|
||||||
x, layer.w13_weight, layer.w2_weight), (
|
x, layer.w13_weight, layer.w2_weight), (
|
||||||
"Flashinfer CUTLASS Fused MoE not applicable!")
|
"Flashinfer CUTLASS Fused MoE not applicable!")
|
||||||
@ -417,11 +430,10 @@ class CompressedTensorsW4A4MoeMethod(CompressedTensorsMoEMethod):
|
|||||||
activation=activation,
|
activation=activation,
|
||||||
global_num_experts=global_num_experts,
|
global_num_experts=global_num_experts,
|
||||||
expert_map=expert_map,
|
expert_map=expert_map,
|
||||||
w1_scale=layer.w13_weight_scale,
|
|
||||||
w2_scale=layer.w2_weight_scale,
|
|
||||||
apply_router_weight_on_input=apply_router_weight_on_input,
|
apply_router_weight_on_input=apply_router_weight_on_input,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# FlashInfer fused experts path
|
||||||
elif self.allow_flashinfer:
|
elif self.allow_flashinfer:
|
||||||
from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe import ( # noqa: E501
|
from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe import ( # noqa: E501
|
||||||
flashinfer_cutlass_moe_fp4)
|
flashinfer_cutlass_moe_fp4)
|
||||||
@ -430,51 +442,46 @@ class CompressedTensorsW4A4MoeMethod(CompressedTensorsMoEMethod):
|
|||||||
x, layer.w13_weight, layer.w2_weight), (
|
x, layer.w13_weight, layer.w2_weight), (
|
||||||
"Flashinfer CUTLASS Fused MoE not applicable!")
|
"Flashinfer CUTLASS Fused MoE not applicable!")
|
||||||
|
|
||||||
|
assert self.moe_quant_config is not None
|
||||||
|
|
||||||
return flashinfer_cutlass_moe_fp4(
|
return flashinfer_cutlass_moe_fp4(
|
||||||
hidden_states=x,
|
hidden_states=x,
|
||||||
w1=layer.w13_weight,
|
w1=layer.w13_weight,
|
||||||
w2=layer.w2_weight,
|
w2=layer.w2_weight,
|
||||||
topk_weights=topk_weights,
|
topk_weights=topk_weights,
|
||||||
topk_ids=topk_ids,
|
topk_ids=topk_ids,
|
||||||
|
quant_config=self.moe_quant_config,
|
||||||
inplace=False, # TODO(shuw): fix later, now output is high prec
|
inplace=False, # TODO(shuw): fix later, now output is high prec
|
||||||
activation=activation,
|
activation=activation,
|
||||||
global_num_experts=global_num_experts,
|
global_num_experts=global_num_experts,
|
||||||
expert_map=expert_map,
|
expert_map=expert_map,
|
||||||
w1_scale=layer.w13_weight_scale,
|
|
||||||
w2_scale=layer.w2_weight_scale,
|
|
||||||
g1_alphas=layer.g1_alphas,
|
|
||||||
g2_alphas=layer.g2_alphas,
|
|
||||||
a1_gscale=layer.w13_input_scale_quant,
|
|
||||||
a2_gscale=layer.w2_input_scale_quant,
|
|
||||||
apply_router_weight_on_input=apply_router_weight_on_input,
|
apply_router_weight_on_input=apply_router_weight_on_input,
|
||||||
)
|
)
|
||||||
|
else:
|
||||||
|
from vllm.model_executor.layers.fused_moe.cutlass_moe import (
|
||||||
|
cutlass_moe_fp4)
|
||||||
|
|
||||||
assert expert_map is None, ("Expert Parallelism / expert_map "
|
assert expert_map is None, ("Expert Parallelism / expert_map "
|
||||||
"is currently not supported for "
|
"is currently not supported for "
|
||||||
"CompressedTensorsW4A4MoeMethod.")
|
"CompressedTensorsW4A4MoeMethod.")
|
||||||
from vllm.model_executor.layers.fused_moe.cutlass_moe import (
|
assert self.moe_quant_config is not None
|
||||||
cutlass_moe_fp4)
|
|
||||||
|
|
||||||
# Cutlass moe takes in activations in BF16/Half precision
|
# Cutlass moe takes in activations in BF16/Half precision
|
||||||
# and fp4 quantized weights loaded from the checkpoint
|
# and fp4 quantized weights loaded from the checkpoint
|
||||||
return cutlass_moe_fp4(
|
return cutlass_moe_fp4(
|
||||||
a=x,
|
a=x,
|
||||||
w1_fp4=layer.w13_weight,
|
w1_fp4=layer.w13_weight,
|
||||||
w2_fp4=layer.w2_weight,
|
w2_fp4=layer.w2_weight,
|
||||||
w1_blockscale=layer.w13_weight_scale,
|
topk_weights=topk_weights,
|
||||||
w2_blockscale=layer.w2_weight_scale,
|
topk_ids=topk_ids,
|
||||||
g1_alphas=layer.g1_alphas,
|
quant_config=self.moe_quant_config,
|
||||||
g2_alphas=layer.g2_alphas,
|
apply_router_weight_on_input=apply_router_weight_on_input,
|
||||||
a1_gscale=layer.w13_input_scale_quant,
|
# TODO(bnell): derive these from arguments
|
||||||
a2_gscale=layer.w2_input_scale_quant,
|
m=x.shape[0],
|
||||||
topk_weights=topk_weights,
|
n=layer.w2_weight.shape[2] * 2,
|
||||||
topk_ids=topk_ids,
|
k=x.shape[1],
|
||||||
m=x.shape[0],
|
e=layer.w13_weight.shape[0],
|
||||||
n=layer.w2_weight.shape[2] * 2,
|
).to(x.dtype)
|
||||||
k=x.shape[1],
|
|
||||||
e=layer.w13_weight.shape[0],
|
|
||||||
apply_router_weight_on_input=apply_router_weight_on_input).to(
|
|
||||||
x.dtype)
|
|
||||||
|
|
||||||
|
|
||||||
class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
|
class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
|
||||||
@ -692,16 +699,11 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
|
|||||||
layer.w2_weight = torch.nn.Parameter(shuffled_w2,
|
layer.w2_weight = torch.nn.Parameter(shuffled_w2,
|
||||||
requires_grad=False)
|
requires_grad=False)
|
||||||
|
|
||||||
self.rocm_aiter_fused_experts_func = rocm_aiter_fused_experts
|
|
||||||
elif self.use_marlin:
|
elif self.use_marlin:
|
||||||
prepare_moe_fp8_layer_for_marlin(layer, False)
|
prepare_moe_fp8_layer_for_marlin(layer, False)
|
||||||
# Activations not quantized for marlin.
|
# Activations not quantized for marlin.
|
||||||
del layer.w13_input_scale
|
del layer.w13_input_scale
|
||||||
del layer.w2_input_scale
|
del layer.w2_input_scale
|
||||||
self.fused_experts_func = None
|
|
||||||
else:
|
|
||||||
from vllm.model_executor.layers.fused_moe import fused_experts
|
|
||||||
self.fused_experts_func = fused_experts
|
|
||||||
|
|
||||||
if self.use_cutlass:
|
if self.use_cutlass:
|
||||||
device = layer.w13_weight.device
|
device = layer.w13_weight.device
|
||||||
@ -722,11 +724,20 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
|
|||||||
device=device,
|
device=device,
|
||||||
dtype=torch.int64)
|
dtype=torch.int64)
|
||||||
|
|
||||||
|
def maybe_make_prepare_finalize(
|
||||||
|
self) -> Optional[mk.FusedMoEPrepareAndFinalize]:
|
||||||
|
if self.use_marlin or self.rocm_aiter_moe_enabled:
|
||||||
|
return None
|
||||||
|
else:
|
||||||
|
return super().maybe_make_prepare_finalize()
|
||||||
|
|
||||||
def select_gemm_impl(
|
def select_gemm_impl(
|
||||||
self, prepare_finalize: FusedMoEPrepareAndFinalize,
|
self,
|
||||||
moe: FusedMoEConfig,
|
prepare_finalize: mk.FusedMoEPrepareAndFinalize,
|
||||||
layer: torch.nn.Module) -> FusedMoEPermuteExpertsUnpermute:
|
layer: torch.nn.Module,
|
||||||
|
) -> FusedMoEPermuteExpertsUnpermute:
|
||||||
# cutlass path
|
# cutlass path
|
||||||
|
assert self.moe_quant_config is not None
|
||||||
if self.use_cutlass:
|
if self.use_cutlass:
|
||||||
from vllm.model_executor.layers.fused_moe import (
|
from vllm.model_executor.layers.fused_moe import (
|
||||||
CutlassBatchedExpertsFp8, CutlassExpertsFp8)
|
CutlassBatchedExpertsFp8, CutlassExpertsFp8)
|
||||||
@ -740,26 +751,24 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
|
|||||||
logger.debug("CutlassBatchedExpertsFp8(%s)",
|
logger.debug("CutlassBatchedExpertsFp8(%s)",
|
||||||
self.__class__.__name__)
|
self.__class__.__name__)
|
||||||
experts = CutlassBatchedExpertsFp8(
|
experts = CutlassBatchedExpertsFp8(
|
||||||
moe.num_local_experts,
|
self.moe.num_local_experts,
|
||||||
num_dispatchers,
|
num_dispatchers,
|
||||||
moe.in_dtype,
|
self.moe.in_dtype,
|
||||||
self.input_quant.strategy == QuantizationStrategy.TOKEN,
|
|
||||||
self.weight_quant.strategy == QuantizationStrategy.CHANNEL,
|
|
||||||
ab_strides1=self.ab_strides1_c_strides2,
|
ab_strides1=self.ab_strides1_c_strides2,
|
||||||
ab_strides2=self.ab_strides2,
|
ab_strides2=self.ab_strides2,
|
||||||
c_strides1=self.c_strides1,
|
c_strides1=self.c_strides1,
|
||||||
c_strides2=self.ab_strides1_c_strides2,
|
c_strides2=self.ab_strides1_c_strides2,
|
||||||
|
quant_config=self.moe_quant_config,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
logger.debug("CutlassExpertsFp8(%s)", self.__class__.__name__)
|
logger.debug("CutlassExpertsFp8(%s)", self.__class__.__name__)
|
||||||
experts = CutlassExpertsFp8(
|
experts = CutlassExpertsFp8(
|
||||||
moe.in_dtype,
|
self.moe.in_dtype,
|
||||||
self.input_quant.strategy == QuantizationStrategy.TOKEN,
|
|
||||||
self.weight_quant.strategy == QuantizationStrategy.CHANNEL,
|
|
||||||
ab_strides1=self.ab_strides1_c_strides2,
|
ab_strides1=self.ab_strides1_c_strides2,
|
||||||
ab_strides2=self.ab_strides2,
|
ab_strides2=self.ab_strides2,
|
||||||
c_strides1=self.c_strides1,
|
c_strides1=self.c_strides1,
|
||||||
c_strides2=self.ab_strides1_c_strides2,
|
c_strides2=self.ab_strides1_c_strides2,
|
||||||
|
quant_config=self.moe_quant_config,
|
||||||
)
|
)
|
||||||
|
|
||||||
self.disable_expert_map = (num_dispatchers > 1
|
self.disable_expert_map = (num_dispatchers > 1
|
||||||
@ -774,29 +783,40 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
|
|||||||
|
|
||||||
assert not self.rocm_aiter_moe_enabled and not self.use_marlin
|
assert not self.rocm_aiter_moe_enabled and not self.use_marlin
|
||||||
|
|
||||||
logger.debug("BatchedTritonExperts(%s)", self.__class__.__name__)
|
|
||||||
|
|
||||||
if (prepare_finalize.activation_format ==
|
if (prepare_finalize.activation_format ==
|
||||||
FusedMoEActivationFormat.BatchedExperts):
|
FusedMoEActivationFormat.BatchedExperts):
|
||||||
max_num_tokens_per_rank = prepare_finalize.max_num_tokens_per_rank(
|
max_num_tokens_per_rank = prepare_finalize.max_num_tokens_per_rank(
|
||||||
)
|
)
|
||||||
assert max_num_tokens_per_rank is not None
|
assert max_num_tokens_per_rank is not None
|
||||||
|
|
||||||
|
logger.debug("BatchedTritonExperts(%s)", self.__class__.__name__)
|
||||||
return BatchedTritonExperts(
|
return BatchedTritonExperts(
|
||||||
max_num_tokens=max_num_tokens_per_rank,
|
max_num_tokens=max_num_tokens_per_rank,
|
||||||
num_dispatchers=prepare_finalize.num_dispatchers(),
|
num_dispatchers=prepare_finalize.num_dispatchers(),
|
||||||
use_fp8_w8a8=True,
|
quant_config=self.moe_quant_config,
|
||||||
block_shape=self.quant_config.weight_block_size,
|
|
||||||
per_act_token_quant=(
|
|
||||||
self.input_quant.strategy == QuantizationStrategy.TOKEN),
|
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
return TritonExperts(
|
logger.debug("TritonExperts(%s)", self.__class__.__name__)
|
||||||
use_fp8_w8a8=True,
|
return TritonExperts(self.moe_quant_config)
|
||||||
block_shape=self.quant_config.weight_block_size,
|
|
||||||
per_act_token_quant=(
|
def get_fused_moe_quant_config(
|
||||||
self.input_quant.strategy == QuantizationStrategy.TOKEN),
|
self, layer: torch.nn.Module) -> Optional[FusedMoEQuantConfig]:
|
||||||
)
|
if self.use_marlin:
|
||||||
|
return None
|
||||||
|
|
||||||
|
per_act_token = (
|
||||||
|
self.input_quant.strategy == QuantizationStrategy.TOKEN)
|
||||||
|
per_channel_quant = (
|
||||||
|
self.weight_quant.strategy == QuantizationStrategy.CHANNEL)
|
||||||
|
|
||||||
|
return fp8_w8a8_moe_quant_config(
|
||||||
|
w1_scale=layer.w13_weight_scale,
|
||||||
|
w2_scale=layer.w2_weight_scale,
|
||||||
|
a1_scale=layer.w13_input_scale,
|
||||||
|
a2_scale=layer.w2_input_scale,
|
||||||
|
per_act_token_quant=per_act_token,
|
||||||
|
per_out_ch_quant=per_channel_quant,
|
||||||
|
)
|
||||||
|
|
||||||
def apply(
|
def apply(
|
||||||
self,
|
self,
|
||||||
@ -841,92 +861,19 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
|
|||||||
indices_type=self.topk_indices_dtype,
|
indices_type=self.topk_indices_dtype,
|
||||||
)
|
)
|
||||||
|
|
||||||
# cutlass path
|
per_act_token = (
|
||||||
if self.use_cutlass:
|
self.input_quant.strategy == QuantizationStrategy.TOKEN)
|
||||||
per_act_token = (
|
per_channel_quant = (
|
||||||
self.input_quant.strategy == QuantizationStrategy.TOKEN)
|
self.weight_quant.strategy == QuantizationStrategy.CHANNEL)
|
||||||
per_channel_quant = (
|
|
||||||
self.weight_quant.strategy == QuantizationStrategy.CHANNEL)
|
|
||||||
|
|
||||||
# small-batch fallback on SM100
|
#
|
||||||
if self.is_fp8_w8a8_sm100 and topk_ids.shape[0] <= 8:
|
# Note: the order here is important. self.fused_experts can override
|
||||||
from vllm.model_executor.layers.fused_moe import fused_experts
|
# cutlass fp8 or fused_experts but not marlin or rocm.
|
||||||
return fused_experts(
|
#
|
||||||
hidden_states=x,
|
|
||||||
w1=layer.w13_weight,
|
|
||||||
w2=layer.w2_weight,
|
|
||||||
topk_weights=topk_weights,
|
|
||||||
topk_ids=topk_ids,
|
|
||||||
inplace=True,
|
|
||||||
activation=activation,
|
|
||||||
apply_router_weight_on_input=apply_router_weight_on_input,
|
|
||||||
use_fp8_w8a8=True,
|
|
||||||
per_channel_quant=per_channel_quant,
|
|
||||||
global_num_experts=global_num_experts,
|
|
||||||
expert_map=None if self.disable_expert_map else expert_map,
|
|
||||||
w1_scale=layer.w13_weight_scale,
|
|
||||||
w2_scale=layer.w2_weight_scale,
|
|
||||||
a1_scale=layer.w13_input_scale,
|
|
||||||
a2_scale=layer.w2_input_scale)
|
|
||||||
|
|
||||||
if self.fused_experts is None:
|
|
||||||
from vllm.model_executor.layers.fused_moe.cutlass_moe import (
|
|
||||||
cutlass_moe_fp8)
|
|
||||||
return cutlass_moe_fp8(
|
|
||||||
x,
|
|
||||||
layer.w13_weight,
|
|
||||||
layer.w2_weight,
|
|
||||||
topk_weights,
|
|
||||||
topk_ids,
|
|
||||||
per_act_token=per_act_token,
|
|
||||||
activation=activation,
|
|
||||||
global_num_experts=global_num_experts,
|
|
||||||
expert_map=None if self.disable_expert_map else expert_map,
|
|
||||||
w1_scale=layer.w13_weight_scale,
|
|
||||||
w2_scale=layer.w2_weight_scale,
|
|
||||||
ab_strides1=self.ab_strides1_c_strides2,
|
|
||||||
ab_strides2=self.ab_strides2,
|
|
||||||
c_strides1=self.c_strides1,
|
|
||||||
c_strides2=self.ab_strides1_c_strides2,
|
|
||||||
a1_scale=layer.w13_input_scale,
|
|
||||||
a2_scale=layer.w2_input_scale,
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
return self.fused_experts(
|
|
||||||
x,
|
|
||||||
layer.w13_weight,
|
|
||||||
layer.w2_weight,
|
|
||||||
topk_weights,
|
|
||||||
topk_ids,
|
|
||||||
activation=activation,
|
|
||||||
global_num_experts=global_num_experts,
|
|
||||||
expert_map=None if self.disable_expert_map else expert_map,
|
|
||||||
w1_scale=layer.w13_weight_scale,
|
|
||||||
w2_scale=layer.w2_weight_scale,
|
|
||||||
a1_scale=layer.w13_input_scale,
|
|
||||||
a2_scale=layer.w2_input_scale,
|
|
||||||
)
|
|
||||||
|
|
||||||
if self.rocm_aiter_moe_enabled:
|
|
||||||
return self.rocm_aiter_fused_experts_func(
|
|
||||||
hidden_states=x,
|
|
||||||
w1=layer.w13_weight,
|
|
||||||
w2=layer.w2_weight,
|
|
||||||
topk_weights=topk_weights,
|
|
||||||
topk_ids=topk_ids,
|
|
||||||
activation=activation,
|
|
||||||
apply_router_weight_on_input=apply_router_weight_on_input,
|
|
||||||
use_fp8_w8a8=True,
|
|
||||||
per_channel_quant=self.weight_quant.strategy ==
|
|
||||||
QuantizationStrategy.CHANNEL,
|
|
||||||
w1_scale=layer.w13_weight_scale,
|
|
||||||
w2_scale=layer.w2_weight_scale,
|
|
||||||
a1_scale=layer.w13_input_scale,
|
|
||||||
a2_scale=layer.w2_input_scale,
|
|
||||||
expert_map=expert_map)
|
|
||||||
if self.use_marlin:
|
if self.use_marlin:
|
||||||
assert activation == "silu", (
|
assert activation == "silu", (
|
||||||
f"{activation} not supported for Marlin MoE.")
|
f"{activation} not supported for Marlin MoE.")
|
||||||
|
assert self.fused_experts is None
|
||||||
return torch.ops.vllm.fused_marlin_moe(
|
return torch.ops.vllm.fused_marlin_moe(
|
||||||
x,
|
x,
|
||||||
layer.w13_weight,
|
layer.w13_weight,
|
||||||
@ -944,26 +891,95 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
|
|||||||
expert_map=expert_map,
|
expert_map=expert_map,
|
||||||
workspace=layer.workspace)
|
workspace=layer.workspace)
|
||||||
|
|
||||||
assert self.fused_experts_func is not None
|
elif self.rocm_aiter_moe_enabled:
|
||||||
|
from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import ( # noqa E501
|
||||||
|
rocm_aiter_fused_experts)
|
||||||
|
assert per_act_token == per_channel_quant
|
||||||
|
assert self.moe_quant_config is not None
|
||||||
|
assert self.fused_experts is None
|
||||||
|
return rocm_aiter_fused_experts(
|
||||||
|
hidden_states=x,
|
||||||
|
w1=layer.w13_weight,
|
||||||
|
w2=layer.w2_weight,
|
||||||
|
topk_weights=topk_weights,
|
||||||
|
topk_ids=topk_ids,
|
||||||
|
activation=activation,
|
||||||
|
apply_router_weight_on_input=apply_router_weight_on_input,
|
||||||
|
expert_map=expert_map,
|
||||||
|
quant_config=self.moe_quant_config,
|
||||||
|
)
|
||||||
|
|
||||||
return self.fused_experts_func(
|
elif self.fused_experts is not None:
|
||||||
hidden_states=x,
|
return self.fused_experts(
|
||||||
w1=layer.w13_weight,
|
x,
|
||||||
w2=layer.w2_weight,
|
layer.w13_weight,
|
||||||
topk_weights=topk_weights,
|
layer.w2_weight,
|
||||||
topk_ids=topk_ids,
|
topk_weights,
|
||||||
inplace=True,
|
topk_ids,
|
||||||
activation=activation,
|
activation=activation,
|
||||||
apply_router_weight_on_input=apply_router_weight_on_input,
|
global_num_experts=global_num_experts,
|
||||||
use_fp8_w8a8=True,
|
expert_map=None if self.disable_expert_map else expert_map,
|
||||||
per_channel_quant=self.weight_quant.strategy ==
|
)
|
||||||
QuantizationStrategy.CHANNEL,
|
|
||||||
global_num_experts=global_num_experts,
|
# cutlass path
|
||||||
expert_map=expert_map,
|
elif self.use_cutlass:
|
||||||
w1_scale=layer.w13_weight_scale,
|
assert self.moe_quant_config is not None
|
||||||
w2_scale=layer.w2_weight_scale,
|
|
||||||
a1_scale=layer.w13_input_scale,
|
# small-batch fallback on SM100
|
||||||
a2_scale=layer.w2_input_scale)
|
if self.is_fp8_w8a8_sm100 and topk_ids.shape[0] <= 8:
|
||||||
|
from vllm.model_executor.layers.fused_moe import fused_experts
|
||||||
|
assert per_act_token == per_channel_quant
|
||||||
|
return fused_experts(
|
||||||
|
hidden_states=x,
|
||||||
|
w1=layer.w13_weight,
|
||||||
|
w2=layer.w2_weight,
|
||||||
|
topk_weights=topk_weights,
|
||||||
|
topk_ids=topk_ids,
|
||||||
|
inplace=True,
|
||||||
|
activation=activation,
|
||||||
|
apply_router_weight_on_input=apply_router_weight_on_input,
|
||||||
|
global_num_experts=global_num_experts,
|
||||||
|
expert_map=None if self.disable_expert_map else expert_map,
|
||||||
|
quant_config=self.moe_quant_config,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
from vllm.model_executor.layers.fused_moe.cutlass_moe import (
|
||||||
|
cutlass_moe_fp8)
|
||||||
|
assert per_act_token == per_channel_quant
|
||||||
|
assert self.moe_quant_config is not None
|
||||||
|
return cutlass_moe_fp8(
|
||||||
|
x,
|
||||||
|
layer.w13_weight,
|
||||||
|
layer.w2_weight,
|
||||||
|
topk_weights,
|
||||||
|
topk_ids,
|
||||||
|
quant_config=self.moe_quant_config,
|
||||||
|
activation=activation,
|
||||||
|
global_num_experts=global_num_experts,
|
||||||
|
expert_map=None if self.disable_expert_map else expert_map,
|
||||||
|
ab_strides1=self.ab_strides1_c_strides2,
|
||||||
|
ab_strides2=self.ab_strides2,
|
||||||
|
c_strides1=self.c_strides1,
|
||||||
|
c_strides2=self.ab_strides1_c_strides2,
|
||||||
|
)
|
||||||
|
|
||||||
|
else:
|
||||||
|
from vllm.model_executor.layers.fused_moe import fused_experts
|
||||||
|
assert per_act_token == per_channel_quant
|
||||||
|
assert self.moe_quant_config is not None
|
||||||
|
return fused_experts(
|
||||||
|
hidden_states=x,
|
||||||
|
w1=layer.w13_weight,
|
||||||
|
w2=layer.w2_weight,
|
||||||
|
topk_weights=topk_weights,
|
||||||
|
topk_ids=topk_ids,
|
||||||
|
inplace=True,
|
||||||
|
activation=activation,
|
||||||
|
apply_router_weight_on_input=apply_router_weight_on_input,
|
||||||
|
global_num_experts=global_num_experts,
|
||||||
|
expert_map=expert_map,
|
||||||
|
quant_config=self.moe_quant_config,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class CompressedTensorsW8A8Int8MoEMethod(CompressedTensorsMoEMethod):
|
class CompressedTensorsW8A8Int8MoEMethod(CompressedTensorsMoEMethod):
|
||||||
@ -1049,6 +1065,16 @@ class CompressedTensorsW8A8Int8MoEMethod(CompressedTensorsMoEMethod):
|
|||||||
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
|
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
def get_fused_moe_quant_config(
|
||||||
|
self, layer: torch.nn.Module) -> Optional[FusedMoEQuantConfig]:
|
||||||
|
return int8_w8a8_moe_quant_config(
|
||||||
|
w1_scale=layer.w13_weight_scale,
|
||||||
|
w2_scale=layer.w2_weight_scale,
|
||||||
|
a1_scale=layer.w13_input_scale,
|
||||||
|
a2_scale=layer.w2_input_scale,
|
||||||
|
per_act_token_quant=True,
|
||||||
|
)
|
||||||
|
|
||||||
def apply(
|
def apply(
|
||||||
self,
|
self,
|
||||||
layer: torch.nn.Module,
|
layer: torch.nn.Module,
|
||||||
@ -1104,14 +1130,10 @@ class CompressedTensorsW8A8Int8MoEMethod(CompressedTensorsMoEMethod):
|
|||||||
inplace=True,
|
inplace=True,
|
||||||
activation=activation,
|
activation=activation,
|
||||||
apply_router_weight_on_input=apply_router_weight_on_input,
|
apply_router_weight_on_input=apply_router_weight_on_input,
|
||||||
use_int8_w8a8=True,
|
|
||||||
per_channel_quant=True,
|
|
||||||
global_num_experts=global_num_experts,
|
global_num_experts=global_num_experts,
|
||||||
expert_map=expert_map,
|
expert_map=expert_map,
|
||||||
w1_scale=layer.w13_weight_scale,
|
quant_config=self.moe_quant_config,
|
||||||
w2_scale=layer.w2_weight_scale,
|
)
|
||||||
a1_scale=layer.w13_input_scale,
|
|
||||||
a2_scale=layer.w2_input_scale)
|
|
||||||
|
|
||||||
|
|
||||||
class CompressedTensorsWNA16MarlinMoEMethod(CompressedTensorsMoEMethod):
|
class CompressedTensorsWNA16MarlinMoEMethod(CompressedTensorsMoEMethod):
|
||||||
@ -1355,6 +1377,10 @@ class CompressedTensorsWNA16MarlinMoEMethod(CompressedTensorsMoEMethod):
|
|||||||
|
|
||||||
layer.workspace = marlin_make_workspace_new(device, 4)
|
layer.workspace = marlin_make_workspace_new(device, 4)
|
||||||
|
|
||||||
|
def get_fused_moe_quant_config(
|
||||||
|
self, layer: torch.nn.Module) -> Optional[FusedMoEQuantConfig]:
|
||||||
|
return None
|
||||||
|
|
||||||
def apply(
|
def apply(
|
||||||
self,
|
self,
|
||||||
layer: torch.nn.Module,
|
layer: torch.nn.Module,
|
||||||
@ -1588,6 +1614,20 @@ class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod):
|
|||||||
layer.w2_weight_scale.transpose(1, 2).contiguous(),
|
layer.w2_weight_scale.transpose(1, 2).contiguous(),
|
||||||
requires_grad=False)
|
requires_grad=False)
|
||||||
|
|
||||||
|
def get_fused_moe_quant_config(
|
||||||
|
self, layer: torch.nn.Module) -> Optional[FusedMoEQuantConfig]:
|
||||||
|
assert self.num_bits == 4 or self.num_bits == 8
|
||||||
|
config_builder = (int4_w4a16_moe_quant_config if self.num_bits == 4
|
||||||
|
else int8_w8a16_moe_quant_config)
|
||||||
|
|
||||||
|
return config_builder(
|
||||||
|
w1_scale=layer.w13_weight_scale,
|
||||||
|
w2_scale=layer.w2_weight_scale,
|
||||||
|
w1_zp=None,
|
||||||
|
w2_zp=None,
|
||||||
|
block_shape=[0, self.group_size],
|
||||||
|
)
|
||||||
|
|
||||||
def apply(
|
def apply(
|
||||||
self,
|
self,
|
||||||
layer: torch.nn.Module,
|
layer: torch.nn.Module,
|
||||||
@ -1641,13 +1681,8 @@ class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod):
|
|||||||
topk_ids=topk_ids,
|
topk_ids=topk_ids,
|
||||||
inplace=True,
|
inplace=True,
|
||||||
activation=activation,
|
activation=activation,
|
||||||
use_int4_w4a16=self.num_bits == 4,
|
|
||||||
use_int8_w8a16=self.num_bits == 8,
|
|
||||||
global_num_experts=global_num_experts,
|
|
||||||
apply_router_weight_on_input=apply_router_weight_on_input,
|
apply_router_weight_on_input=apply_router_weight_on_input,
|
||||||
|
global_num_experts=global_num_experts,
|
||||||
expert_map=expert_map,
|
expert_map=expert_map,
|
||||||
w1_scale=layer.w13_weight_scale,
|
quant_config=self.moe_quant_config,
|
||||||
w2_scale=layer.w2_weight_scale,
|
)
|
||||||
w1_zp=None,
|
|
||||||
w2_zp=None,
|
|
||||||
block_shape=[0, self.group_size])
|
|
||||||
|
|||||||
@ -8,6 +8,8 @@ import torch
|
|||||||
from vllm.distributed import get_tensor_model_parallel_rank, get_tp_group
|
from vllm.distributed import get_tensor_model_parallel_rank, get_tp_group
|
||||||
from vllm.model_executor.layers.fused_moe import (FusedMoE, FusedMoEConfig,
|
from vllm.model_executor.layers.fused_moe import (FusedMoE, FusedMoEConfig,
|
||||||
FusedMoEMethodBase)
|
FusedMoEMethodBase)
|
||||||
|
from vllm.model_executor.layers.fused_moe.config import (
|
||||||
|
FusedMoEQuantConfig, int8_w8a16_moe_quant_config)
|
||||||
from vllm.model_executor.layers.linear import (LinearBase,
|
from vllm.model_executor.layers.linear import (LinearBase,
|
||||||
UnquantizedLinearMethod)
|
UnquantizedLinearMethod)
|
||||||
from vllm.model_executor.layers.quantization import QuantizationMethods
|
from vllm.model_executor.layers.quantization import QuantizationMethods
|
||||||
@ -106,6 +108,13 @@ class ExpertsInt8MoEMethod(FusedMoEMethodBase):
|
|||||||
requires_grad=False)
|
requires_grad=False)
|
||||||
layer.register_parameter("w2_scale", w2_scale)
|
layer.register_parameter("w2_scale", w2_scale)
|
||||||
|
|
||||||
|
def get_fused_moe_quant_config(
|
||||||
|
self, layer: torch.nn.Module) -> Optional[FusedMoEQuantConfig]:
|
||||||
|
return int8_w8a16_moe_quant_config(w1_scale=layer.w13_scale,
|
||||||
|
w2_scale=layer.w2_scale,
|
||||||
|
w1_zp=None,
|
||||||
|
w2_zp=None)
|
||||||
|
|
||||||
def apply(
|
def apply(
|
||||||
self,
|
self,
|
||||||
layer: torch.nn.Module,
|
layer: torch.nn.Module,
|
||||||
@ -159,12 +168,11 @@ class ExpertsInt8MoEMethod(FusedMoEMethodBase):
|
|||||||
topk_ids=topk_ids,
|
topk_ids=topk_ids,
|
||||||
inplace=True,
|
inplace=True,
|
||||||
activation=activation,
|
activation=activation,
|
||||||
use_int8_w8a16=True,
|
|
||||||
global_num_experts=global_num_experts,
|
|
||||||
apply_router_weight_on_input=apply_router_weight_on_input,
|
apply_router_weight_on_input=apply_router_weight_on_input,
|
||||||
|
global_num_experts=global_num_experts,
|
||||||
expert_map=expert_map,
|
expert_map=expert_map,
|
||||||
w1_scale=layer.w13_scale,
|
quant_config=self.moe_quant_config,
|
||||||
w2_scale=layer.w2_scale)
|
)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def quantizing_weight_loader(layer, weight_loader):
|
def quantizing_weight_loader(layer, weight_loader):
|
||||||
|
|||||||
@ -14,9 +14,11 @@ from vllm import _custom_ops as ops
|
|||||||
from vllm.distributed import get_tensor_model_parallel_world_size
|
from vllm.distributed import get_tensor_model_parallel_world_size
|
||||||
from vllm.logger import init_logger
|
from vllm.logger import init_logger
|
||||||
from vllm.model_executor.layers.fused_moe import (
|
from vllm.model_executor.layers.fused_moe import (
|
||||||
FusedMoE, FusedMoEActivationFormat, FusedMoEConfig, FusedMoEMethodBase,
|
FusedMoE, FusedMoEActivationFormat, FusedMoEMethodBase,
|
||||||
FusedMoEPermuteExpertsUnpermute, FusedMoEPrepareAndFinalize,
|
FusedMoEPermuteExpertsUnpermute, FusedMoEPrepareAndFinalize,
|
||||||
FusedMoeWeightScaleSupported)
|
FusedMoeWeightScaleSupported)
|
||||||
|
from vllm.model_executor.layers.fused_moe.config import (
|
||||||
|
FusedMoEQuantConfig, fp8_w8a8_moe_quant_config)
|
||||||
from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase,
|
from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase,
|
||||||
UnquantizedLinearMethod)
|
UnquantizedLinearMethod)
|
||||||
from vllm.model_executor.layers.quantization import QuantizationMethods
|
from vllm.model_executor.layers.quantization import QuantizationMethods
|
||||||
@ -575,20 +577,6 @@ class Fp8MoEMethod(FusedMoEMethodBase):
|
|||||||
"CutlassBlockScaledGroupedGemm not supported on the current "
|
"CutlassBlockScaledGroupedGemm not supported on the current "
|
||||||
"platform.")
|
"platform.")
|
||||||
|
|
||||||
def maybe_make_prepare_finalize(
|
|
||||||
self,
|
|
||||||
moe: FusedMoEConfig,
|
|
||||||
) -> Optional[mk.FusedMoEPrepareAndFinalize]:
|
|
||||||
if self.flashinfer_moe_backend != FlashinferMoeBackend.CUTLASS:
|
|
||||||
return super().maybe_make_prepare_finalize(moe)
|
|
||||||
|
|
||||||
prepare_finalize = build_flashinfer_fp8_cutlass_moe_prepare_finalize(
|
|
||||||
moe,
|
|
||||||
layer=self.layer,
|
|
||||||
)
|
|
||||||
logger.debug_once("%s", prepare_finalize.__class__.__name__)
|
|
||||||
return prepare_finalize
|
|
||||||
|
|
||||||
def create_weights(self, layer: Module, num_experts: int, hidden_size: int,
|
def create_weights(self, layer: Module, num_experts: int, hidden_size: int,
|
||||||
intermediate_size_per_partition: int,
|
intermediate_size_per_partition: int,
|
||||||
params_dtype: torch.dtype, **extra_weight_attrs):
|
params_dtype: torch.dtype, **extra_weight_attrs):
|
||||||
@ -928,10 +916,23 @@ class Fp8MoEMethod(FusedMoEMethodBase):
|
|||||||
layer.w2_weight_scale_inv = get_col_major_tma_aligned_tensor(
|
layer.w2_weight_scale_inv = get_col_major_tma_aligned_tensor(
|
||||||
layer.w2_weight_scale_inv)
|
layer.w2_weight_scale_inv)
|
||||||
|
|
||||||
|
def maybe_make_prepare_finalize(
|
||||||
|
self) -> Optional[mk.FusedMoEPrepareAndFinalize]:
|
||||||
|
if (self.rocm_aiter_moe_enabled or self.use_marlin
|
||||||
|
or self.flashinfer_moe_backend
|
||||||
|
== FlashinferMoeBackend.TENSORRT_LLM):
|
||||||
|
return None
|
||||||
|
elif self.flashinfer_moe_backend == FlashinferMoeBackend.CUTLASS:
|
||||||
|
prepare_finalize = (
|
||||||
|
build_flashinfer_fp8_cutlass_moe_prepare_finalize(self.moe))
|
||||||
|
logger.debug_once("%s", prepare_finalize.__class__.__name__)
|
||||||
|
return prepare_finalize
|
||||||
|
else:
|
||||||
|
return super().maybe_make_prepare_finalize()
|
||||||
|
|
||||||
def select_gemm_impl(
|
def select_gemm_impl(
|
||||||
self,
|
self,
|
||||||
prepare_finalize: FusedMoEPrepareAndFinalize,
|
prepare_finalize: FusedMoEPrepareAndFinalize,
|
||||||
moe: FusedMoEConfig,
|
|
||||||
layer: torch.nn.Module,
|
layer: torch.nn.Module,
|
||||||
) -> FusedMoEPermuteExpertsUnpermute:
|
) -> FusedMoEPermuteExpertsUnpermute:
|
||||||
from vllm.model_executor.layers.fused_moe import (
|
from vllm.model_executor.layers.fused_moe import (
|
||||||
@ -940,6 +941,8 @@ class Fp8MoEMethod(FusedMoEMethodBase):
|
|||||||
assert not self.use_marlin and not self.rocm_aiter_moe_enabled, (
|
assert not self.use_marlin and not self.rocm_aiter_moe_enabled, (
|
||||||
"Marlin and ROCm AITER are not supported with all2all yet.")
|
"Marlin and ROCm AITER are not supported with all2all yet.")
|
||||||
|
|
||||||
|
assert self.moe_quant_config is not None
|
||||||
|
|
||||||
if (prepare_finalize.activation_format ==
|
if (prepare_finalize.activation_format ==
|
||||||
FusedMoEActivationFormat.BatchedExperts):
|
FusedMoEActivationFormat.BatchedExperts):
|
||||||
max_num_tokens_per_rank = (
|
max_num_tokens_per_rank = (
|
||||||
@ -953,15 +956,13 @@ class Fp8MoEMethod(FusedMoEMethodBase):
|
|||||||
return BatchedTritonOrDeepGemmExperts(
|
return BatchedTritonOrDeepGemmExperts(
|
||||||
max_num_tokens=max_num_tokens_per_rank,
|
max_num_tokens=max_num_tokens_per_rank,
|
||||||
num_dispatchers=prepare_finalize.num_dispatchers(),
|
num_dispatchers=prepare_finalize.num_dispatchers(),
|
||||||
use_fp8_w8a8=True,
|
quant_config=self.moe_quant_config,
|
||||||
block_shape=self.quant_config.weight_block_size,
|
|
||||||
per_act_token_quant=False,
|
|
||||||
allow_deep_gemm=self.allow_deep_gemm,
|
allow_deep_gemm=self.allow_deep_gemm,
|
||||||
)
|
)
|
||||||
elif self.flashinfer_moe_backend == FlashinferMoeBackend.CUTLASS:
|
elif self.flashinfer_moe_backend == FlashinferMoeBackend.CUTLASS:
|
||||||
experts = select_cutlass_fp8_gemm_impl(
|
experts = select_cutlass_fp8_gemm_impl(
|
||||||
moe,
|
self.moe,
|
||||||
self.layer,
|
self.moe_quant_config,
|
||||||
)
|
)
|
||||||
logger.debug_once("Using %s", experts.__class__.__name__)
|
logger.debug_once("Using %s", experts.__class__.__name__)
|
||||||
return experts
|
return experts
|
||||||
@ -971,11 +972,25 @@ class Fp8MoEMethod(FusedMoEMethodBase):
|
|||||||
self.__class__.__name__, self.quant_config.weight_block_size,
|
self.__class__.__name__, self.quant_config.weight_block_size,
|
||||||
False)
|
False)
|
||||||
return TritonOrDeepGemmExperts(
|
return TritonOrDeepGemmExperts(
|
||||||
use_fp8_w8a8=True,
|
quant_config=self.moe_quant_config,
|
||||||
block_shape=self.quant_config.weight_block_size,
|
|
||||||
allow_deep_gemm=self.allow_deep_gemm,
|
allow_deep_gemm=self.allow_deep_gemm,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def get_fused_moe_quant_config(
|
||||||
|
self, layer: torch.nn.Module) -> Optional[FusedMoEQuantConfig]:
|
||||||
|
if self.use_marlin:
|
||||||
|
return None
|
||||||
|
|
||||||
|
return fp8_w8a8_moe_quant_config(
|
||||||
|
w1_scale=(layer.w13_weight_scale_inv
|
||||||
|
if self.block_quant else layer.w13_weight_scale),
|
||||||
|
w2_scale=(layer.w2_weight_scale_inv
|
||||||
|
if self.block_quant else layer.w2_weight_scale),
|
||||||
|
a1_scale=layer.w13_input_scale,
|
||||||
|
a2_scale=layer.w2_input_scale,
|
||||||
|
block_shape=self.quant_config.weight_block_size,
|
||||||
|
)
|
||||||
|
|
||||||
def apply(
|
def apply(
|
||||||
self,
|
self,
|
||||||
layer: torch.nn.Module,
|
layer: torch.nn.Module,
|
||||||
@ -1005,12 +1020,14 @@ class Fp8MoEMethod(FusedMoEMethodBase):
|
|||||||
assert logical_replica_count is not None
|
assert logical_replica_count is not None
|
||||||
assert isinstance(layer, FusedMoE)
|
assert isinstance(layer, FusedMoE)
|
||||||
|
|
||||||
if self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM:
|
if (self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM
|
||||||
|
and self.fused_experts is None):
|
||||||
assert activation == 'silu', (
|
assert activation == 'silu', (
|
||||||
f"Expected 'silu' activation but got {activation}")
|
f"Expected 'silu' activation but got {activation}")
|
||||||
assert scoring_func == 'sigmoid', (
|
assert scoring_func == 'sigmoid', (
|
||||||
f"Expected 'sigmoid' scoring func but got {scoring_func}")
|
f"Expected 'sigmoid' scoring func but got {scoring_func}")
|
||||||
if self.block_quant:
|
if self.block_quant:
|
||||||
|
import vllm.model_executor.layers.fused_moe.flashinfer_trtllm_moe # noqa: E501, F401
|
||||||
assert (renormalize and use_grouped_topk
|
assert (renormalize and use_grouped_topk
|
||||||
and custom_routing_function is None)
|
and custom_routing_function is None)
|
||||||
|
|
||||||
@ -1066,9 +1083,14 @@ class Fp8MoEMethod(FusedMoEMethodBase):
|
|||||||
logical_replica_count=logical_replica_count,
|
logical_replica_count=logical_replica_count,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
#
|
||||||
|
# Note: the order of checks is important since self.fused_experts
|
||||||
|
# can override fused_experts or cutlass but not rocm or marlin.
|
||||||
|
#
|
||||||
if self.rocm_aiter_moe_enabled:
|
if self.rocm_aiter_moe_enabled:
|
||||||
from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import ( # noqa: E501
|
from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import ( # noqa: E501
|
||||||
rocm_aiter_fused_experts)
|
rocm_aiter_fused_experts)
|
||||||
|
assert self.fused_experts is None
|
||||||
return rocm_aiter_fused_experts(
|
return rocm_aiter_fused_experts(
|
||||||
x,
|
x,
|
||||||
layer.w13_weight,
|
layer.w13_weight,
|
||||||
@ -1076,19 +1098,13 @@ class Fp8MoEMethod(FusedMoEMethodBase):
|
|||||||
topk_weights=topk_weights,
|
topk_weights=topk_weights,
|
||||||
topk_ids=topk_ids,
|
topk_ids=topk_ids,
|
||||||
activation=activation,
|
activation=activation,
|
||||||
use_fp8_w8a8=True,
|
|
||||||
apply_router_weight_on_input=apply_router_weight_on_input,
|
apply_router_weight_on_input=apply_router_weight_on_input,
|
||||||
w1_scale=(layer.w13_weight_scale_inv
|
expert_map=expert_map,
|
||||||
if self.block_quant else layer.w13_weight_scale),
|
quant_config=self.moe_quant_config)
|
||||||
w2_scale=(layer.w2_weight_scale_inv
|
|
||||||
if self.block_quant else layer.w2_weight_scale),
|
|
||||||
a1_scale=layer.w13_input_scale,
|
|
||||||
a2_scale=layer.w2_input_scale,
|
|
||||||
block_shape=self.quant_config.weight_block_size,
|
|
||||||
expert_map=expert_map)
|
|
||||||
elif self.use_marlin:
|
elif self.use_marlin:
|
||||||
assert activation == "silu", (
|
assert activation == "silu", (
|
||||||
f"{activation} not supported for Marlin MoE.")
|
f"{activation} not supported for Marlin MoE.")
|
||||||
|
assert self.fused_experts is None
|
||||||
return torch.ops.vllm.fused_marlin_moe(
|
return torch.ops.vllm.fused_marlin_moe(
|
||||||
x,
|
x,
|
||||||
layer.w13_weight,
|
layer.w13_weight,
|
||||||
@ -1105,40 +1121,8 @@ class Fp8MoEMethod(FusedMoEMethodBase):
|
|||||||
global_num_experts=global_num_experts,
|
global_num_experts=global_num_experts,
|
||||||
expert_map=expert_map,
|
expert_map=expert_map,
|
||||||
workspace=layer.workspace)
|
workspace=layer.workspace)
|
||||||
elif self.flashinfer_moe_backend == FlashinferMoeBackend.CUTLASS:
|
elif self.fused_experts:
|
||||||
assert self.block_quant is None
|
return self.fused_experts(
|
||||||
assert (not renormalize and custom_routing_function is not None)
|
|
||||||
assert activation == 'silu', (
|
|
||||||
f"Expected 'silu' activation but got {activation}")
|
|
||||||
assert scoring_func == 'sigmoid', (
|
|
||||||
f"Expected 'sigmoid' scoring func but got {scoring_func}")
|
|
||||||
if self.fused_experts is not None:
|
|
||||||
return self.fused_experts(
|
|
||||||
x,
|
|
||||||
layer.w13_weight,
|
|
||||||
layer.w2_weight,
|
|
||||||
topk_weights,
|
|
||||||
topk_ids,
|
|
||||||
inplace=False,
|
|
||||||
activation=activation,
|
|
||||||
global_num_experts=global_num_experts,
|
|
||||||
expert_map=expert_map,
|
|
||||||
apply_router_weight_on_input=apply_router_weight_on_input,
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
return flashinfer_cutlass_moe_fp8(
|
|
||||||
x,
|
|
||||||
layer,
|
|
||||||
topk_weights,
|
|
||||||
topk_ids,
|
|
||||||
inplace=False,
|
|
||||||
activation=activation,
|
|
||||||
global_num_experts=global_num_experts,
|
|
||||||
expert_map=expert_map,
|
|
||||||
apply_router_weight_on_input=apply_router_weight_on_input,
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
common_kwargs = dict(
|
|
||||||
hidden_states=x,
|
hidden_states=x,
|
||||||
w1=layer.w13_weight,
|
w1=layer.w13_weight,
|
||||||
w2=layer.w2_weight,
|
w2=layer.w2_weight,
|
||||||
@ -1149,26 +1133,43 @@ class Fp8MoEMethod(FusedMoEMethodBase):
|
|||||||
global_num_experts=global_num_experts,
|
global_num_experts=global_num_experts,
|
||||||
apply_router_weight_on_input=apply_router_weight_on_input,
|
apply_router_weight_on_input=apply_router_weight_on_input,
|
||||||
expert_map=expert_map,
|
expert_map=expert_map,
|
||||||
w1_scale=(layer.w13_weight_scale_inv
|
|
||||||
if self.block_quant else layer.w13_weight_scale),
|
|
||||||
w2_scale=(layer.w2_weight_scale_inv
|
|
||||||
if self.block_quant else layer.w2_weight_scale),
|
|
||||||
a1_scale=layer.w13_input_scale,
|
|
||||||
a2_scale=layer.w2_input_scale,
|
|
||||||
)
|
)
|
||||||
|
elif self.flashinfer_moe_backend == FlashinferMoeBackend.CUTLASS:
|
||||||
|
assert self.block_quant is None
|
||||||
|
assert (not renormalize and custom_routing_function is not None)
|
||||||
|
assert activation == 'silu', (
|
||||||
|
f"Expected 'silu' activation but got {activation}")
|
||||||
|
assert scoring_func == 'sigmoid', (
|
||||||
|
f"Expected 'sigmoid' scoring func but got {scoring_func}")
|
||||||
|
|
||||||
if self.fused_experts is not None:
|
return flashinfer_cutlass_moe_fp8(
|
||||||
return self.fused_experts(**common_kwargs)
|
x,
|
||||||
else:
|
layer,
|
||||||
from vllm.model_executor.layers.fused_moe import fused_experts
|
topk_weights,
|
||||||
return fused_experts(
|
topk_ids,
|
||||||
**common_kwargs,
|
inplace=False,
|
||||||
use_fp8_w8a8=True,
|
activation=activation,
|
||||||
block_shape=self.quant_config.weight_block_size,
|
global_num_experts=global_num_experts,
|
||||||
allow_deep_gemm=self.allow_deep_gemm,
|
expert_map=expert_map,
|
||||||
allow_cutlass_block_scaled_grouped_gemm=(
|
apply_router_weight_on_input=apply_router_weight_on_input,
|
||||||
self.allow_cutlass_block_scaled_grouped_gemm),
|
)
|
||||||
)
|
else:
|
||||||
|
from vllm.model_executor.layers.fused_moe import fused_experts
|
||||||
|
return fused_experts(
|
||||||
|
hidden_states=x,
|
||||||
|
w1=layer.w13_weight,
|
||||||
|
w2=layer.w2_weight,
|
||||||
|
topk_weights=topk_weights,
|
||||||
|
topk_ids=topk_ids,
|
||||||
|
inplace=True,
|
||||||
|
activation=activation,
|
||||||
|
global_num_experts=global_num_experts,
|
||||||
|
apply_router_weight_on_input=apply_router_weight_on_input,
|
||||||
|
expert_map=expert_map,
|
||||||
|
quant_config=self.moe_quant_config,
|
||||||
|
allow_deep_gemm=self.allow_deep_gemm,
|
||||||
|
allow_cutlass_block_scaled_grouped_gemm=(
|
||||||
|
self.allow_cutlass_block_scaled_grouped_gemm))
|
||||||
|
|
||||||
|
|
||||||
class Fp8KVCacheMethod(BaseKVCacheMethod):
|
class Fp8KVCacheMethod(BaseKVCacheMethod):
|
||||||
|
|||||||
@ -10,8 +10,9 @@ from torch.nn.parameter import Parameter, UninitializedParameter
|
|||||||
|
|
||||||
from vllm import _custom_ops as ops
|
from vllm import _custom_ops as ops
|
||||||
from vllm.logger import init_logger
|
from vllm.logger import init_logger
|
||||||
|
from vllm.model_executor.layers.fused_moe.config import (FusedMoEConfig,
|
||||||
|
FusedMoEQuantConfig)
|
||||||
from vllm.model_executor.layers.fused_moe.layer import (FusedMoE,
|
from vllm.model_executor.layers.fused_moe.layer import (FusedMoE,
|
||||||
FusedMoEConfig,
|
|
||||||
FusedMoEMethodBase)
|
FusedMoEMethodBase)
|
||||||
from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase,
|
from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase,
|
||||||
UnquantizedLinearMethod)
|
UnquantizedLinearMethod)
|
||||||
@ -518,6 +519,10 @@ class GGUFMoEMethod(FusedMoEMethodBase):
|
|||||||
set_weight_attrs(w2_qweight_type, extra_weight_attrs)
|
set_weight_attrs(w2_qweight_type, extra_weight_attrs)
|
||||||
layer.register_parameter("w2_qweight_type", w2_qweight_type)
|
layer.register_parameter("w2_qweight_type", w2_qweight_type)
|
||||||
|
|
||||||
|
def get_fused_moe_quant_config(
|
||||||
|
self, layer: torch.nn.Module) -> Optional[FusedMoEQuantConfig]:
|
||||||
|
return None
|
||||||
|
|
||||||
def apply(
|
def apply(
|
||||||
self,
|
self,
|
||||||
layer: torch.nn.Module,
|
layer: torch.nn.Module,
|
||||||
|
|||||||
@ -9,8 +9,10 @@ import torch
|
|||||||
import vllm.model_executor.layers.fused_moe # noqa
|
import vllm.model_executor.layers.fused_moe # noqa
|
||||||
from vllm import _custom_ops as ops
|
from vllm import _custom_ops as ops
|
||||||
from vllm.logger import init_logger
|
from vllm.logger import init_logger
|
||||||
|
from vllm.model_executor.layers.fused_moe.config import (FusedMoEConfig,
|
||||||
|
FusedMoEQuantConfig)
|
||||||
from vllm.model_executor.layers.fused_moe.layer import (
|
from vllm.model_executor.layers.fused_moe.layer import (
|
||||||
FusedMoE, FusedMoEConfig, FusedMoEMethodBase, FusedMoeWeightScaleSupported,
|
FusedMoE, FusedMoEMethodBase, FusedMoeWeightScaleSupported,
|
||||||
UnquantizedFusedMoEMethod)
|
UnquantizedFusedMoEMethod)
|
||||||
from vllm.model_executor.layers.linear import (LinearMethodBase,
|
from vllm.model_executor.layers.linear import (LinearMethodBase,
|
||||||
set_weight_attrs)
|
set_weight_attrs)
|
||||||
@ -632,6 +634,10 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase):
|
|||||||
if hasattr(layer, "w2_bias") and layer.w2_bias is not None:
|
if hasattr(layer, "w2_bias") and layer.w2_bias is not None:
|
||||||
layer.w2_bias.data = marlin_permute_bias(layer.w2_bias)
|
layer.w2_bias.data = marlin_permute_bias(layer.w2_bias)
|
||||||
|
|
||||||
|
def get_fused_moe_quant_config(
|
||||||
|
self, layer: torch.nn.Module) -> Optional[FusedMoEQuantConfig]:
|
||||||
|
return None
|
||||||
|
|
||||||
def apply(
|
def apply(
|
||||||
self,
|
self,
|
||||||
layer: torch.nn.Module,
|
layer: torch.nn.Module,
|
||||||
|
|||||||
@ -11,6 +11,7 @@ from torch.nn.parameter import Parameter
|
|||||||
from vllm._ipex_ops import ipex_ops as ops
|
from vllm._ipex_ops import ipex_ops as ops
|
||||||
from vllm.model_executor.layers.fused_moe import (FusedMoEMethodBase,
|
from vllm.model_executor.layers.fused_moe import (FusedMoEMethodBase,
|
||||||
FusedMoeWeightScaleSupported)
|
FusedMoeWeightScaleSupported)
|
||||||
|
from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig
|
||||||
from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase,
|
from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase,
|
||||||
UnquantizedLinearMethod)
|
UnquantizedLinearMethod)
|
||||||
from vllm.model_executor.layers.quantization import QuantizationMethods
|
from vllm.model_executor.layers.quantization import QuantizationMethods
|
||||||
@ -375,6 +376,10 @@ class XPUFp8MoEMethod(FusedMoEMethodBase):
|
|||||||
use_prepack=True,
|
use_prepack=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def get_fused_moe_quant_config(
|
||||||
|
self, layer: torch.nn.Module) -> Optional[FusedMoEQuantConfig]:
|
||||||
|
return None
|
||||||
|
|
||||||
def apply(
|
def apply(
|
||||||
self,
|
self,
|
||||||
layer: torch.nn.Module,
|
layer: torch.nn.Module,
|
||||||
|
|||||||
@ -11,7 +11,9 @@ import vllm.envs as envs
|
|||||||
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
|
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
|
||||||
from vllm._custom_ops import cutlass_scaled_fp4_mm, scaled_fp4_quant
|
from vllm._custom_ops import cutlass_scaled_fp4_mm, scaled_fp4_quant
|
||||||
from vllm.logger import init_logger
|
from vllm.logger import init_logger
|
||||||
from vllm.model_executor.layers.fused_moe.config import FusedMoEConfig
|
from vllm.model_executor.layers.fused_moe.config import (
|
||||||
|
FusedMoEConfig, FusedMoEQuantConfig, fp8_w8a8_moe_quant_config,
|
||||||
|
nvfp4_moe_quant_config)
|
||||||
from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe import (
|
from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe import (
|
||||||
is_valid_flashinfer_cutlass_fused_moe)
|
is_valid_flashinfer_cutlass_fused_moe)
|
||||||
from vllm.model_executor.layers.fused_moe.layer import (
|
from vllm.model_executor.layers.fused_moe.layer import (
|
||||||
@ -294,8 +296,6 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase):
|
|||||||
cutlass_fp8_supported)
|
cutlass_fp8_supported)
|
||||||
self.cutlass_fp8_supported = cutlass_fp8_supported()
|
self.cutlass_fp8_supported = cutlass_fp8_supported()
|
||||||
self.flashinfer_moe_backend: Optional[FlashinferMoeBackend] = None
|
self.flashinfer_moe_backend: Optional[FlashinferMoeBackend] = None
|
||||||
self.fused_experts: Optional[
|
|
||||||
mk.FusedMoEModularKernel] = None # type: ignore
|
|
||||||
if envs.VLLM_USE_FLASHINFER_MOE_FP8 and has_flashinfer_moe():
|
if envs.VLLM_USE_FLASHINFER_MOE_FP8 and has_flashinfer_moe():
|
||||||
self.flashinfer_moe_backend = get_flashinfer_moe_backend()
|
self.flashinfer_moe_backend = get_flashinfer_moe_backend()
|
||||||
logger.info_once(
|
logger.info_once(
|
||||||
@ -303,29 +303,27 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase):
|
|||||||
)
|
)
|
||||||
|
|
||||||
def maybe_make_prepare_finalize(
|
def maybe_make_prepare_finalize(
|
||||||
self,
|
self, ) -> Optional[mk.FusedMoEPrepareAndFinalize]:
|
||||||
moe: FusedMoEConfig,
|
# TRT LLM not supported with all2all yet.
|
||||||
) -> Optional[mk.FusedMoEPrepareAndFinalize]:
|
if self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM:
|
||||||
if self.fused_experts is not None or \
|
return None
|
||||||
self.flashinfer_moe_backend != FlashinferMoeBackend.CUTLASS:
|
elif self.flashinfer_moe_backend == FlashinferMoeBackend.CUTLASS:
|
||||||
return super().maybe_make_prepare_finalize(moe)
|
prepare_finalize = (
|
||||||
|
build_flashinfer_fp8_cutlass_moe_prepare_finalize(self.moe))
|
||||||
prepare_finalize = build_flashinfer_fp8_cutlass_moe_prepare_finalize(
|
logger.debug_once("%s", prepare_finalize.__class__.__name__)
|
||||||
moe,
|
return prepare_finalize
|
||||||
layer=self.layer,
|
else:
|
||||||
)
|
return super().maybe_make_prepare_finalize()
|
||||||
logger.debug_once("%s", prepare_finalize.__class__.__name__)
|
|
||||||
return prepare_finalize
|
|
||||||
|
|
||||||
def select_gemm_impl(
|
def select_gemm_impl(
|
||||||
self,
|
self,
|
||||||
prepare_finalize: mk.FusedMoEPrepareAndFinalize,
|
prepare_finalize: mk.FusedMoEPrepareAndFinalize,
|
||||||
moe: FusedMoEConfig,
|
|
||||||
layer: torch.nn.Module,
|
layer: torch.nn.Module,
|
||||||
) -> mk.FusedMoEPermuteExpertsUnpermute:
|
) -> mk.FusedMoEPermuteExpertsUnpermute:
|
||||||
|
assert self.moe_quant_config is not None
|
||||||
experts = select_cutlass_fp8_gemm_impl(
|
experts = select_cutlass_fp8_gemm_impl(
|
||||||
moe,
|
self.moe,
|
||||||
self.layer,
|
self.moe_quant_config,
|
||||||
)
|
)
|
||||||
logger.debug_once("Using %s", experts.__class__.__name__)
|
logger.debug_once("Using %s", experts.__class__.__name__)
|
||||||
return experts
|
return experts
|
||||||
@ -479,6 +477,19 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase):
|
|||||||
rotate_flashinfer_fp8_moe_weights(layer.w13_weight,
|
rotate_flashinfer_fp8_moe_weights(layer.w13_weight,
|
||||||
layer.w2_weight)
|
layer.w2_weight)
|
||||||
|
|
||||||
|
def get_fused_moe_quant_config(
|
||||||
|
self, layer: torch.nn.Module) -> Optional[FusedMoEQuantConfig]:
|
||||||
|
if self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM:
|
||||||
|
return None
|
||||||
|
|
||||||
|
return fp8_w8a8_moe_quant_config(
|
||||||
|
w1_scale=layer.w13_weight_scale,
|
||||||
|
w2_scale=layer.w2_weight_scale,
|
||||||
|
a1_scale=layer.w13_input_scale,
|
||||||
|
a2_scale=layer.w2_input_scale,
|
||||||
|
per_act_token_quant=False,
|
||||||
|
)
|
||||||
|
|
||||||
def apply(
|
def apply(
|
||||||
self,
|
self,
|
||||||
layer: torch.nn.Module,
|
layer: torch.nn.Module,
|
||||||
@ -507,6 +518,7 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase):
|
|||||||
"EPLB not supported for `ModelOptFp8MoEMethod` yet.")
|
"EPLB not supported for `ModelOptFp8MoEMethod` yet.")
|
||||||
|
|
||||||
if self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM:
|
if self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM:
|
||||||
|
assert self.fused_experts is None
|
||||||
assert activation == 'silu', (
|
assert activation == 'silu', (
|
||||||
f"Expected 'silu' activation but got {activation}")
|
f"Expected 'silu' activation but got {activation}")
|
||||||
assert not renormalize
|
assert not renormalize
|
||||||
@ -537,55 +549,56 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase):
|
|||||||
indices_type=self.topk_indices_dtype,
|
indices_type=self.topk_indices_dtype,
|
||||||
)
|
)
|
||||||
|
|
||||||
if self.flashinfer_moe_backend == FlashinferMoeBackend.CUTLASS:
|
#
|
||||||
|
# Note: the order here is important. self.fused_experts can override
|
||||||
|
# cutlass or fused_experts.
|
||||||
|
#
|
||||||
|
if self.fused_experts is not None:
|
||||||
|
return self.fused_experts(
|
||||||
|
x,
|
||||||
|
layer.w13_weight,
|
||||||
|
layer.w2_weight,
|
||||||
|
topk_weights,
|
||||||
|
topk_ids,
|
||||||
|
inplace=False,
|
||||||
|
activation=activation,
|
||||||
|
global_num_experts=global_num_experts,
|
||||||
|
expert_map=expert_map,
|
||||||
|
apply_router_weight_on_input=apply_router_weight_on_input,
|
||||||
|
)
|
||||||
|
elif self.flashinfer_moe_backend == FlashinferMoeBackend.CUTLASS:
|
||||||
assert not renormalize
|
assert not renormalize
|
||||||
assert activation == 'silu', (
|
assert activation == 'silu', (
|
||||||
f"Expected 'silu' activation but got {activation}")
|
f"Expected 'silu' activation but got {activation}")
|
||||||
if self.fused_experts is not None:
|
return flashinfer_cutlass_moe_fp8(
|
||||||
return self.fused_experts(
|
x,
|
||||||
x,
|
layer,
|
||||||
layer.w13_weight,
|
topk_weights,
|
||||||
layer.w2_weight,
|
topk_ids,
|
||||||
topk_weights,
|
inplace=False,
|
||||||
topk_ids,
|
activation=activation,
|
||||||
inplace=False,
|
global_num_experts=global_num_experts,
|
||||||
activation=activation,
|
expert_map=expert_map,
|
||||||
global_num_experts=global_num_experts,
|
apply_router_weight_on_input=apply_router_weight_on_input,
|
||||||
expert_map=expert_map,
|
)
|
||||||
apply_router_weight_on_input=apply_router_weight_on_input,
|
else:
|
||||||
)
|
from vllm.model_executor.layers.fused_moe.fused_moe import (
|
||||||
else:
|
fused_experts)
|
||||||
return flashinfer_cutlass_moe_fp8(
|
assert self.moe_quant_config is not None
|
||||||
x,
|
|
||||||
layer,
|
return fused_experts(
|
||||||
topk_weights,
|
x,
|
||||||
topk_ids,
|
layer.w13_weight,
|
||||||
inplace=False,
|
layer.w2_weight,
|
||||||
activation=activation,
|
topk_weights=topk_weights,
|
||||||
global_num_experts=global_num_experts,
|
topk_ids=topk_ids,
|
||||||
expert_map=expert_map,
|
inplace=True,
|
||||||
apply_router_weight_on_input=apply_router_weight_on_input,
|
activation=activation,
|
||||||
)
|
quant_config=self.moe_quant_config,
|
||||||
from vllm.model_executor.layers.fused_moe.fused_moe import (
|
global_num_experts=global_num_experts,
|
||||||
fused_experts)
|
expert_map=expert_map,
|
||||||
return fused_experts(
|
apply_router_weight_on_input=apply_router_weight_on_input,
|
||||||
x,
|
)
|
||||||
layer.w13_weight,
|
|
||||||
layer.w2_weight,
|
|
||||||
topk_weights=topk_weights,
|
|
||||||
topk_ids=topk_ids,
|
|
||||||
inplace=True,
|
|
||||||
activation=activation,
|
|
||||||
use_fp8_w8a8=True,
|
|
||||||
per_channel_quant=False,
|
|
||||||
global_num_experts=global_num_experts,
|
|
||||||
expert_map=expert_map,
|
|
||||||
w1_scale=layer.w13_weight_scale,
|
|
||||||
w2_scale=layer.w2_weight_scale,
|
|
||||||
a1_scale=layer.w13_input_scale,
|
|
||||||
a2_scale=layer.w2_input_scale,
|
|
||||||
apply_router_weight_on_input=apply_router_weight_on_input,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class ModelOptNvFp4Config(QuantizationConfig):
|
class ModelOptNvFp4Config(QuantizationConfig):
|
||||||
@ -1034,33 +1047,30 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
|
|||||||
" for ModelOptNvFp4FusedMoE.")
|
" for ModelOptNvFp4FusedMoE.")
|
||||||
|
|
||||||
def maybe_make_prepare_finalize(
|
def maybe_make_prepare_finalize(
|
||||||
self,
|
self) -> Optional[mk.FusedMoEPrepareAndFinalize]:
|
||||||
moe: FusedMoEConfig,
|
if (self.use_marlin
|
||||||
) -> Optional[mk.FusedMoEPrepareAndFinalize]:
|
or (self.allow_flashinfer and self.flashinfer_moe_backend
|
||||||
if (self.allow_flashinfer and self.flashinfer_moe_backend
|
== FlashinferMoeBackend.TENSORRT_LLM)):
|
||||||
== FlashinferMoeBackend.CUTLASS):
|
return None
|
||||||
|
elif (self.allow_flashinfer
|
||||||
|
and self.flashinfer_moe_backend == FlashinferMoeBackend.CUTLASS):
|
||||||
|
# For now, fp4 moe only works with the flashinfer dispatcher.
|
||||||
prepare_finalize = (
|
prepare_finalize = (
|
||||||
build_flashinfer_fp4_cutlass_moe_prepare_finalize(
|
build_flashinfer_fp4_cutlass_moe_prepare_finalize(self.moe))
|
||||||
moe,
|
|
||||||
a1_gscale=self.layer.w13_input_scale_quant,
|
|
||||||
))
|
|
||||||
logger.debug_once("%s", prepare_finalize.__class__.__name__)
|
logger.debug_once("%s", prepare_finalize.__class__.__name__)
|
||||||
return prepare_finalize
|
return prepare_finalize
|
||||||
|
else:
|
||||||
return super().maybe_make_prepare_finalize(moe)
|
return super().maybe_make_prepare_finalize()
|
||||||
|
|
||||||
def select_gemm_impl(
|
def select_gemm_impl(
|
||||||
self,
|
self,
|
||||||
prepare_finalize: mk.FusedMoEPrepareAndFinalize,
|
prepare_finalize: mk.FusedMoEPrepareAndFinalize,
|
||||||
moe: FusedMoEConfig,
|
|
||||||
layer: torch.nn.Module,
|
layer: torch.nn.Module,
|
||||||
) -> mk.FusedMoEPermuteExpertsUnpermute:
|
) -> mk.FusedMoEPermuteExpertsUnpermute:
|
||||||
|
assert self.moe_quant_config is not None
|
||||||
experts = select_nvfp4_gemm_impl(
|
experts = select_nvfp4_gemm_impl(
|
||||||
moe,
|
self.moe,
|
||||||
g1_alphas=self.layer.g1_alphas,
|
self.moe_quant_config,
|
||||||
g2_alphas=self.layer.g2_alphas,
|
|
||||||
a1_gscale=self.layer.w13_input_scale_quant,
|
|
||||||
a2_gscale=self.layer.w2_input_scale_quant,
|
|
||||||
allow_flashinfer=self.allow_flashinfer,
|
allow_flashinfer=self.allow_flashinfer,
|
||||||
)
|
)
|
||||||
logger.debug_once("Using %s", experts.__class__.__name__)
|
logger.debug_once("Using %s", experts.__class__.__name__)
|
||||||
@ -1360,6 +1370,21 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
|
|||||||
layer.w2_weight = Parameter(layer.w2_weight.data,
|
layer.w2_weight = Parameter(layer.w2_weight.data,
|
||||||
requires_grad=False)
|
requires_grad=False)
|
||||||
|
|
||||||
|
def get_fused_moe_quant_config(
|
||||||
|
self, layer: torch.nn.Module) -> Optional[FusedMoEQuantConfig]:
|
||||||
|
if (self.use_marlin or self.flashinfer_moe_backend
|
||||||
|
== FlashinferMoeBackend.TENSORRT_LLM):
|
||||||
|
return None
|
||||||
|
|
||||||
|
return nvfp4_moe_quant_config(
|
||||||
|
w1_scale=layer.w13_weight_scale,
|
||||||
|
w2_scale=layer.w2_weight_scale,
|
||||||
|
g1_alphas=layer.g1_alphas,
|
||||||
|
g2_alphas=layer.g2_alphas,
|
||||||
|
a1_gscale=layer.w13_input_scale_quant,
|
||||||
|
a2_gscale=layer.w2_input_scale_quant,
|
||||||
|
)
|
||||||
|
|
||||||
def apply(
|
def apply(
|
||||||
self,
|
self,
|
||||||
layer: torch.nn.Module,
|
layer: torch.nn.Module,
|
||||||
@ -1388,12 +1413,14 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
|
|||||||
"EPLB not supported for `ModelOptNvFp4FusedMoE` yet.")
|
"EPLB not supported for `ModelOptNvFp4FusedMoE` yet.")
|
||||||
assert activation == "silu", "Only SiLU activation is supported."
|
assert activation == "silu", "Only SiLU activation is supported."
|
||||||
|
|
||||||
if self.allow_flashinfer and \
|
if (self.allow_flashinfer and self.flashinfer_moe_backend
|
||||||
self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM:
|
== FlashinferMoeBackend.TENSORRT_LLM):
|
||||||
import flashinfer
|
import flashinfer
|
||||||
|
|
||||||
from vllm.model_executor.models.llama4 import Llama4MoE
|
from vllm.model_executor.models.llama4 import Llama4MoE
|
||||||
|
|
||||||
|
assert self.fused_experts is None
|
||||||
|
|
||||||
a1_gscale = layer.w13_input_scale_quant
|
a1_gscale = layer.w13_input_scale_quant
|
||||||
(hidden_states_fp4,
|
(hidden_states_fp4,
|
||||||
hidden_states_scale_linear_fp4) = flashinfer.fp4_quantize(
|
hidden_states_scale_linear_fp4) = flashinfer.fp4_quantize(
|
||||||
@ -1457,7 +1484,13 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
|
|||||||
e_score_correction_bias=e_score_correction_bias,
|
e_score_correction_bias=e_score_correction_bias,
|
||||||
indices_type=self.topk_indices_dtype)
|
indices_type=self.topk_indices_dtype)
|
||||||
|
|
||||||
|
#
|
||||||
|
# Note: the order here is important. self.fused_experts can override
|
||||||
|
# flashinfer cutlass, cutlass fp4 or fused_experts but not marlin or
|
||||||
|
# trtllm.
|
||||||
|
#
|
||||||
if self.use_marlin:
|
if self.use_marlin:
|
||||||
|
assert self.fused_experts is None
|
||||||
return torch.ops.vllm.fused_marlin_moe(
|
return torch.ops.vllm.fused_marlin_moe(
|
||||||
x,
|
x,
|
||||||
layer.w13_weight,
|
layer.w13_weight,
|
||||||
@ -1477,7 +1510,7 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
|
|||||||
expert_map=expert_map,
|
expert_map=expert_map,
|
||||||
workspace=layer.workspace)
|
workspace=layer.workspace)
|
||||||
|
|
||||||
if self.fused_experts is not None:
|
elif self.fused_experts is not None:
|
||||||
assert self.allow_flashinfer and \
|
assert self.allow_flashinfer and \
|
||||||
self.flashinfer_moe_backend == FlashinferMoeBackend.CUTLASS
|
self.flashinfer_moe_backend == FlashinferMoeBackend.CUTLASS
|
||||||
|
|
||||||
@ -1485,7 +1518,7 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
|
|||||||
x, layer.w13_weight, layer.w2_weight), (
|
x, layer.w13_weight, layer.w2_weight), (
|
||||||
"Flashinfer CUTLASS Fused MoE not applicable!")
|
"Flashinfer CUTLASS Fused MoE not applicable!")
|
||||||
|
|
||||||
out = self.fused_experts(
|
return self.fused_experts(
|
||||||
hidden_states=x,
|
hidden_states=x,
|
||||||
w1=layer.w13_weight,
|
w1=layer.w13_weight,
|
||||||
w2=layer.w2_weight,
|
w2=layer.w2_weight,
|
||||||
@ -1495,28 +1528,22 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
|
|||||||
activation=activation,
|
activation=activation,
|
||||||
global_num_experts=global_num_experts,
|
global_num_experts=global_num_experts,
|
||||||
expert_map=expert_map,
|
expert_map=expert_map,
|
||||||
w1_scale=layer.w13_weight_scale,
|
|
||||||
w2_scale=layer.w2_weight_scale,
|
|
||||||
apply_router_weight_on_input=apply_router_weight_on_input,
|
apply_router_weight_on_input=apply_router_weight_on_input,
|
||||||
)
|
)
|
||||||
elif (self.allow_flashinfer
|
elif (self.allow_flashinfer
|
||||||
and self.flashinfer_moe_backend == FlashinferMoeBackend.CUTLASS):
|
and self.flashinfer_moe_backend == FlashinferMoeBackend.CUTLASS):
|
||||||
from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe import ( # noqa: E501
|
from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe import ( # noqa: E501
|
||||||
flashinfer_cutlass_moe_fp4)
|
flashinfer_cutlass_moe_fp4)
|
||||||
|
assert self.moe_quant_config is not None
|
||||||
|
|
||||||
out = flashinfer_cutlass_moe_fp4(
|
return flashinfer_cutlass_moe_fp4(
|
||||||
hidden_states=x,
|
hidden_states=x,
|
||||||
w1=layer.w13_weight,
|
w1=layer.w13_weight,
|
||||||
w2=layer.w2_weight,
|
w2=layer.w2_weight,
|
||||||
topk_weights=topk_weights,
|
topk_weights=topk_weights,
|
||||||
topk_ids=topk_ids,
|
topk_ids=topk_ids,
|
||||||
w1_scale=layer.w13_weight_scale,
|
quant_config=self.moe_quant_config,
|
||||||
w2_scale=layer.w2_weight_scale,
|
inplace=False,
|
||||||
g1_alphas=layer.g1_alphas,
|
|
||||||
g2_alphas=layer.g2_alphas,
|
|
||||||
a1_gscale=layer.w13_input_scale_quant,
|
|
||||||
a2_gscale=layer.w2_input_scale_quant,
|
|
||||||
inplace=False, # TODO(shuw): fix later, now output is high prec
|
|
||||||
activation=activation,
|
activation=activation,
|
||||||
global_num_experts=global_num_experts,
|
global_num_experts=global_num_experts,
|
||||||
expert_map=expert_map,
|
expert_map=expert_map,
|
||||||
@ -1527,23 +1554,19 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
|
|||||||
# only (no EP).
|
# only (no EP).
|
||||||
from vllm.model_executor.layers.fused_moe.cutlass_moe import (
|
from vllm.model_executor.layers.fused_moe.cutlass_moe import (
|
||||||
cutlass_moe_fp4)
|
cutlass_moe_fp4)
|
||||||
out = cutlass_moe_fp4(
|
assert self.moe_quant_config is not None
|
||||||
|
return cutlass_moe_fp4(
|
||||||
a=x,
|
a=x,
|
||||||
w1_fp4=layer.w13_weight,
|
w1_fp4=layer.w13_weight,
|
||||||
w2_fp4=layer.w2_weight,
|
w2_fp4=layer.w2_weight,
|
||||||
w1_blockscale=layer.w13_weight_scale,
|
|
||||||
w2_blockscale=layer.w2_weight_scale,
|
|
||||||
g1_alphas=layer.g1_alphas,
|
|
||||||
g2_alphas=layer.g2_alphas,
|
|
||||||
a1_gscale=layer.w13_input_scale_quant,
|
|
||||||
a2_gscale=layer.w2_input_scale_quant,
|
|
||||||
topk_weights=topk_weights,
|
topk_weights=topk_weights,
|
||||||
topk_ids=topk_ids,
|
topk_ids=topk_ids,
|
||||||
|
quant_config=self.moe_quant_config,
|
||||||
|
expert_map=expert_map,
|
||||||
|
apply_router_weight_on_input=apply_router_weight_on_input,
|
||||||
|
# TODO: derive from arguments
|
||||||
m=x.shape[0],
|
m=x.shape[0],
|
||||||
n=layer.w2_weight.shape[2] * 2,
|
n=layer.w2_weight.shape[2] * 2,
|
||||||
k=x.shape[1],
|
k=x.shape[1],
|
||||||
e=layer.w13_weight.shape[0],
|
e=layer.w13_weight.shape[0],
|
||||||
expert_map=expert_map,
|
)
|
||||||
apply_router_weight_on_input=apply_router_weight_on_input)
|
|
||||||
|
|
||||||
return out
|
|
||||||
|
|||||||
@ -6,6 +6,9 @@ from typing import Any, Callable, Optional, Union
|
|||||||
import torch
|
import torch
|
||||||
|
|
||||||
from vllm.distributed import get_tensor_model_parallel_rank, get_tp_group
|
from vllm.distributed import get_tensor_model_parallel_rank, get_tp_group
|
||||||
|
from vllm.model_executor.layers.fused_moe.config import (
|
||||||
|
FusedMoEQuantConfig, int4_w4a16_moe_quant_config,
|
||||||
|
int8_w8a16_moe_quant_config)
|
||||||
from vllm.model_executor.layers.fused_moe.layer import (
|
from vllm.model_executor.layers.fused_moe.layer import (
|
||||||
FusedMoE, FusedMoEConfig, FusedMoEMethodBase, FusedMoeWeightScaleSupported)
|
FusedMoE, FusedMoEConfig, FusedMoEMethodBase, FusedMoeWeightScaleSupported)
|
||||||
from vllm.model_executor.layers.linear import (LinearBase,
|
from vllm.model_executor.layers.linear import (LinearBase,
|
||||||
@ -283,6 +286,22 @@ class MoeWNA16Method(FusedMoEMethodBase):
|
|||||||
layer.register_parameter(key, param)
|
layer.register_parameter(key, param)
|
||||||
set_weight_attrs(param, extra_weight_attrs)
|
set_weight_attrs(param, extra_weight_attrs)
|
||||||
|
|
||||||
|
def get_fused_moe_quant_config(
|
||||||
|
self, layer: torch.nn.Module) -> Optional[FusedMoEQuantConfig]:
|
||||||
|
weight_bits = self.quant_config.weight_bits
|
||||||
|
has_zp = self.quant_config.has_zp
|
||||||
|
assert weight_bits == 4 or weight_bits == 8
|
||||||
|
config_builder = (int4_w4a16_moe_quant_config
|
||||||
|
if weight_bits == 4 else int8_w8a16_moe_quant_config)
|
||||||
|
|
||||||
|
return config_builder(
|
||||||
|
w1_scale=layer.w13_scales,
|
||||||
|
w2_scale=layer.w2_scales,
|
||||||
|
w1_zp=layer.w13_qzeros if has_zp else None,
|
||||||
|
w2_zp=layer.w2_qzeros if has_zp else None,
|
||||||
|
block_shape=[0, layer.group_size],
|
||||||
|
)
|
||||||
|
|
||||||
def apply(
|
def apply(
|
||||||
self,
|
self,
|
||||||
layer: torch.nn.Module,
|
layer: torch.nn.Module,
|
||||||
@ -327,9 +346,6 @@ class MoeWNA16Method(FusedMoEMethodBase):
|
|||||||
e_score_correction_bias=e_score_correction_bias,
|
e_score_correction_bias=e_score_correction_bias,
|
||||||
indices_type=self.topk_indices_dtype)
|
indices_type=self.topk_indices_dtype)
|
||||||
|
|
||||||
weight_bits = self.quant_config.weight_bits
|
|
||||||
has_zp = self.quant_config.has_zp
|
|
||||||
|
|
||||||
return fused_experts(
|
return fused_experts(
|
||||||
x,
|
x,
|
||||||
layer.w13_qweight,
|
layer.w13_qweight,
|
||||||
@ -337,16 +353,11 @@ class MoeWNA16Method(FusedMoEMethodBase):
|
|||||||
topk_weights=topk_weights,
|
topk_weights=topk_weights,
|
||||||
topk_ids=topk_ids,
|
topk_ids=topk_ids,
|
||||||
inplace=True,
|
inplace=True,
|
||||||
use_int4_w4a16=weight_bits == 4,
|
|
||||||
use_int8_w8a16=weight_bits == 8,
|
|
||||||
global_num_experts=global_num_experts,
|
|
||||||
apply_router_weight_on_input=apply_router_weight_on_input,
|
apply_router_weight_on_input=apply_router_weight_on_input,
|
||||||
|
global_num_experts=global_num_experts,
|
||||||
expert_map=expert_map,
|
expert_map=expert_map,
|
||||||
w1_scale=layer.w13_scales,
|
quant_config=self.moe_quant_config,
|
||||||
w2_scale=layer.w2_scales,
|
)
|
||||||
w1_zp=layer.w13_qzeros if has_zp else None,
|
|
||||||
w2_zp=layer.w2_qzeros if has_zp else None,
|
|
||||||
block_shape=[0, layer.group_size])
|
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def get_weight_loader(layer, weight_loader):
|
def get_weight_loader(layer, weight_loader):
|
||||||
|
|||||||
@ -12,6 +12,8 @@ from vllm.logger import init_logger
|
|||||||
from vllm.model_executor.layers.fused_moe import (FusedMoE, FusedMoEConfig,
|
from vllm.model_executor.layers.fused_moe import (FusedMoE, FusedMoEConfig,
|
||||||
FusedMoEMethodBase)
|
FusedMoEMethodBase)
|
||||||
from vllm.model_executor.layers.fused_moe import modular_kernel as mk
|
from vllm.model_executor.layers.fused_moe import modular_kernel as mk
|
||||||
|
from vllm.model_executor.layers.fused_moe.config import (
|
||||||
|
FusedMoEQuantConfig, mxfp4_w4a4_moe_quant_config)
|
||||||
from vllm.model_executor.layers.fused_moe.trtllm_moe import TrtLlmGenExperts
|
from vllm.model_executor.layers.fused_moe.trtllm_moe import TrtLlmGenExperts
|
||||||
from vllm.model_executor.layers.linear import (LinearBase,
|
from vllm.model_executor.layers.linear import (LinearBase,
|
||||||
UnquantizedLinearMethod)
|
UnquantizedLinearMethod)
|
||||||
@ -629,10 +631,29 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
|
|||||||
|
|
||||||
return tile_tokens_dim
|
return tile_tokens_dim
|
||||||
|
|
||||||
|
def get_fused_moe_quant_config(
|
||||||
|
self, layer: torch.nn.Module) -> Optional[FusedMoEQuantConfig]:
|
||||||
|
|
||||||
|
if self.mxfp4_backend == Mxfp4Backend.MARLIN:
|
||||||
|
return None
|
||||||
|
|
||||||
|
if self.mxfp4_backend == Mxfp4Backend.TRITON:
|
||||||
|
w1_scale = layer.w13_precision_config
|
||||||
|
w2_scale = layer.w2_precision_config
|
||||||
|
else:
|
||||||
|
w1_scale = layer.w13_weight_scale
|
||||||
|
w2_scale = layer.w2_weight_scale
|
||||||
|
|
||||||
|
return mxfp4_w4a4_moe_quant_config(
|
||||||
|
w1_bias=layer.w13_bias,
|
||||||
|
w2_bias=layer.w2_bias,
|
||||||
|
w1_scale=w1_scale,
|
||||||
|
w2_scale=w2_scale,
|
||||||
|
)
|
||||||
|
|
||||||
def select_gemm_impl(
|
def select_gemm_impl(
|
||||||
self,
|
self,
|
||||||
prepare_finalize: mk.FusedMoEPrepareAndFinalize,
|
prepare_finalize: mk.FusedMoEPrepareAndFinalize,
|
||||||
moe: FusedMoEConfig,
|
|
||||||
layer: torch.nn.Module,
|
layer: torch.nn.Module,
|
||||||
) -> mk.FusedMoEPermuteExpertsUnpermute:
|
) -> mk.FusedMoEPermuteExpertsUnpermute:
|
||||||
if (prepare_finalize.activation_format ==
|
if (prepare_finalize.activation_format ==
|
||||||
@ -647,11 +668,12 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
|
|||||||
"gemm1_alpha": layer.gemm1_alpha,
|
"gemm1_alpha": layer.gemm1_alpha,
|
||||||
"gemm1_beta": layer.gemm1_beta,
|
"gemm1_beta": layer.gemm1_beta,
|
||||||
"gemm1_clamp_limit": layer.gemm1_clamp_limit,
|
"gemm1_clamp_limit": layer.gemm1_clamp_limit,
|
||||||
"w13_bias": layer.w13_bias,
|
# TODO(bnell): part of quant_config
|
||||||
"w2_bias": layer.w2_bias,
|
|
||||||
"max_capture_size": self.max_capture_size,
|
"max_capture_size": self.max_capture_size,
|
||||||
}
|
}
|
||||||
return TrtLlmGenExperts(moe, **kwargs)
|
assert self.moe_quant_config is not None
|
||||||
|
return TrtLlmGenExperts(self.moe, self.moe_quant_config,
|
||||||
|
**kwargs)
|
||||||
else:
|
else:
|
||||||
# Use matmul_ogs from triton_kernels here!
|
# Use matmul_ogs from triton_kernels here!
|
||||||
raise NotImplementedError(
|
raise NotImplementedError(
|
||||||
@ -710,8 +732,6 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
|
|||||||
activation=activation,
|
activation=activation,
|
||||||
global_num_experts=global_num_experts,
|
global_num_experts=global_num_experts,
|
||||||
expert_map=expert_map,
|
expert_map=expert_map,
|
||||||
w1_scale=layer.w13_weight_scale,
|
|
||||||
w2_scale=layer.w2_weight_scale,
|
|
||||||
apply_router_weight_on_input=apply_router_weight_on_input,
|
apply_router_weight_on_input=apply_router_weight_on_input,
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -941,10 +961,7 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
|
|||||||
renormalize=renormalize,
|
renormalize=renormalize,
|
||||||
global_num_experts=global_num_experts,
|
global_num_experts=global_num_experts,
|
||||||
expert_map=expert_map,
|
expert_map=expert_map,
|
||||||
w1_bias=layer.w13_bias,
|
quant_config=self.moe_quant_config,
|
||||||
w2_bias=layer.w2_bias,
|
|
||||||
w1_precision=self.w13_precision_config,
|
|
||||||
w2_precision=self.w2_precision_config,
|
|
||||||
apply_router_weight_on_input=apply_router_weight_on_input,
|
apply_router_weight_on_input=apply_router_weight_on_input,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
|
|||||||
@ -11,6 +11,9 @@ from vllm.logger import init_logger
|
|||||||
from vllm.model_executor.layers.fused_moe import (FusedMoE, FusedMoEConfig,
|
from vllm.model_executor.layers.fused_moe import (FusedMoE, FusedMoEConfig,
|
||||||
FusedMoEMethodBase,
|
FusedMoEMethodBase,
|
||||||
FusedMoeWeightScaleSupported)
|
FusedMoeWeightScaleSupported)
|
||||||
|
from vllm.model_executor.layers.fused_moe.config import (
|
||||||
|
FusedMoEQuantConfig, fp8_w8a8_moe_quant_config,
|
||||||
|
mxfp4_w4a4_moe_quant_config)
|
||||||
from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (
|
from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (
|
||||||
is_rocm_aiter_moe_enabled)
|
is_rocm_aiter_moe_enabled)
|
||||||
from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import (
|
from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import (
|
||||||
@ -287,6 +290,16 @@ class QuarkW8A8Fp8MoEMethod(QuarkMoEMethod):
|
|||||||
from vllm.model_executor.layers.fused_moe import fused_experts
|
from vllm.model_executor.layers.fused_moe import fused_experts
|
||||||
self.fused_experts_func = fused_experts
|
self.fused_experts_func = fused_experts
|
||||||
|
|
||||||
|
def get_fused_moe_quant_config(
|
||||||
|
self, layer: torch.nn.Module) -> Optional[FusedMoEQuantConfig]:
|
||||||
|
return fp8_w8a8_moe_quant_config(
|
||||||
|
w1_scale=layer.w13_weight_scale,
|
||||||
|
w2_scale=layer.w2_weight_scale,
|
||||||
|
a1_scale=layer.w13_input_scale,
|
||||||
|
a2_scale=layer.w2_input_scale,
|
||||||
|
per_act_token_quant=self.weight_qscheme == "per_channel",
|
||||||
|
)
|
||||||
|
|
||||||
def apply(
|
def apply(
|
||||||
self,
|
self,
|
||||||
layer: torch.nn.Module,
|
layer: torch.nn.Module,
|
||||||
@ -339,12 +352,7 @@ class QuarkW8A8Fp8MoEMethod(QuarkMoEMethod):
|
|||||||
topk_ids=topk_ids,
|
topk_ids=topk_ids,
|
||||||
activation=activation,
|
activation=activation,
|
||||||
apply_router_weight_on_input=apply_router_weight_on_input,
|
apply_router_weight_on_input=apply_router_weight_on_input,
|
||||||
use_fp8_w8a8=True,
|
quant_config=self.moe_quant_config,
|
||||||
per_channel_quant=self.weight_qscheme == "per_channel",
|
|
||||||
w1_scale=layer.w13_weight_scale,
|
|
||||||
w2_scale=layer.w2_weight_scale,
|
|
||||||
a1_scale=layer.w13_input_scale,
|
|
||||||
a2_scale=layer.w2_input_scale,
|
|
||||||
expert_map=expert_map)
|
expert_map=expert_map)
|
||||||
if self.use_marlin:
|
if self.use_marlin:
|
||||||
assert activation == "silu", (
|
assert activation == "silu", (
|
||||||
@ -376,14 +384,9 @@ class QuarkW8A8Fp8MoEMethod(QuarkMoEMethod):
|
|||||||
inplace=True,
|
inplace=True,
|
||||||
activation=activation,
|
activation=activation,
|
||||||
apply_router_weight_on_input=apply_router_weight_on_input,
|
apply_router_weight_on_input=apply_router_weight_on_input,
|
||||||
use_fp8_w8a8=True,
|
|
||||||
per_channel_quant=self.weight_qscheme == "per_channel",
|
|
||||||
global_num_experts=global_num_experts,
|
global_num_experts=global_num_experts,
|
||||||
expert_map=expert_map,
|
expert_map=expert_map,
|
||||||
w1_scale=layer.w13_weight_scale,
|
quant_config=self.moe_quant_config)
|
||||||
w2_scale=layer.w2_weight_scale,
|
|
||||||
a1_scale=layer.w13_input_scale,
|
|
||||||
a2_scale=layer.w2_input_scale)
|
|
||||||
|
|
||||||
|
|
||||||
class QuarkW4A4MXFp4MoEMethod(QuarkMoEMethod):
|
class QuarkW4A4MXFp4MoEMethod(QuarkMoEMethod):
|
||||||
@ -487,6 +490,16 @@ class QuarkW4A4MXFp4MoEMethod(QuarkMoEMethod):
|
|||||||
layer.register_parameter("w13_weight_scale", w13_weight_scale)
|
layer.register_parameter("w13_weight_scale", w13_weight_scale)
|
||||||
layer.register_parameter("w2_weight_scale", w2_weight_scale)
|
layer.register_parameter("w2_weight_scale", w2_weight_scale)
|
||||||
|
|
||||||
|
def get_fused_moe_quant_config(
|
||||||
|
self, layer: torch.nn.Module) -> Optional[FusedMoEQuantConfig]:
|
||||||
|
return mxfp4_w4a4_moe_quant_config(
|
||||||
|
w1_scale=layer.w13_weight_scale,
|
||||||
|
w2_scale=layer.w2_weight_scale,
|
||||||
|
a1_scale=None,
|
||||||
|
a2_scale=None,
|
||||||
|
block_shape=None,
|
||||||
|
)
|
||||||
|
|
||||||
def apply(
|
def apply(
|
||||||
self,
|
self,
|
||||||
layer: torch.nn.Module,
|
layer: torch.nn.Module,
|
||||||
@ -539,15 +552,10 @@ class QuarkW4A4MXFp4MoEMethod(QuarkMoEMethod):
|
|||||||
topk_weights=topk_weights,
|
topk_weights=topk_weights,
|
||||||
topk_ids=topk_ids,
|
topk_ids=topk_ids,
|
||||||
inplace=True,
|
inplace=True,
|
||||||
use_mxfp4_w4a4=True,
|
activation=activation,
|
||||||
global_num_experts=global_num_experts,
|
global_num_experts=global_num_experts,
|
||||||
apply_router_weight_on_input=apply_router_weight_on_input,
|
apply_router_weight_on_input=apply_router_weight_on_input,
|
||||||
expert_map=expert_map,
|
expert_map=expert_map,
|
||||||
w1_scale=layer.w13_weight_scale,
|
quant_config=self.moe_quant_config,
|
||||||
w2_scale=layer.w2_weight_scale,
|
|
||||||
a1_scale=None,
|
|
||||||
a2_scale=None,
|
|
||||||
block_shape=None,
|
|
||||||
activation=activation,
|
|
||||||
)
|
)
|
||||||
return out
|
return out
|
||||||
|
|||||||
@ -12,6 +12,9 @@ from torch.nn.parameter import Parameter
|
|||||||
from vllm.logger import init_logger
|
from vllm.logger import init_logger
|
||||||
from vllm.model_executor.layers.fused_moe import (FusedMoE, FusedMoEConfig,
|
from vllm.model_executor.layers.fused_moe import (FusedMoE, FusedMoEConfig,
|
||||||
FusedMoEMethodBase)
|
FusedMoEMethodBase)
|
||||||
|
from vllm.model_executor.layers.fused_moe.config import (
|
||||||
|
FusedMoEQuantConfig, int4_w4a16_moe_quant_config,
|
||||||
|
int8_w8a16_moe_quant_config)
|
||||||
from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase,
|
from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase,
|
||||||
set_weight_attrs)
|
set_weight_attrs)
|
||||||
from vllm.model_executor.layers.quantization import QuantizationMethods
|
from vllm.model_executor.layers.quantization import QuantizationMethods
|
||||||
@ -269,6 +272,21 @@ class RTNMoEMethod(FusedMoEMethodBase):
|
|||||||
fix_weights(layer, "w13_weight", weight_bits == 4)
|
fix_weights(layer, "w13_weight", weight_bits == 4)
|
||||||
fix_weights(layer, "w2_weight", weight_bits == 4)
|
fix_weights(layer, "w2_weight", weight_bits == 4)
|
||||||
|
|
||||||
|
def get_fused_moe_quant_config(
|
||||||
|
self, layer: torch.nn.Module) -> Optional[FusedMoEQuantConfig]:
|
||||||
|
weight_bits = self.quant_config.weight_bits
|
||||||
|
group_size = self.quant_config.group_size
|
||||||
|
assert weight_bits == 4 or weight_bits == 8
|
||||||
|
config_builder = (int4_w4a16_moe_quant_config
|
||||||
|
if weight_bits == 4 else int8_w8a16_moe_quant_config)
|
||||||
|
return config_builder(
|
||||||
|
w1_scale=layer.w13_scale,
|
||||||
|
w2_scale=layer.w2_scale,
|
||||||
|
w1_zp=None,
|
||||||
|
w2_zp=None,
|
||||||
|
block_shape=[0, group_size],
|
||||||
|
)
|
||||||
|
|
||||||
def apply(
|
def apply(
|
||||||
self,
|
self,
|
||||||
layer: torch.nn.Module,
|
layer: torch.nn.Module,
|
||||||
@ -314,10 +332,7 @@ class RTNMoEMethod(FusedMoEMethodBase):
|
|||||||
e_score_correction_bias=e_score_correction_bias,
|
e_score_correction_bias=e_score_correction_bias,
|
||||||
indices_type=self.topk_indices_dtype)
|
indices_type=self.topk_indices_dtype)
|
||||||
|
|
||||||
weight_bits = self.quant_config.weight_bits
|
return fused_experts(
|
||||||
group_size = self.quant_config.group_size
|
|
||||||
|
|
||||||
ret = fused_experts(
|
|
||||||
x,
|
x,
|
||||||
layer.w13_weight,
|
layer.w13_weight,
|
||||||
layer.w2_weight,
|
layer.w2_weight,
|
||||||
@ -325,16 +340,11 @@ class RTNMoEMethod(FusedMoEMethodBase):
|
|||||||
topk_ids=topk_ids,
|
topk_ids=topk_ids,
|
||||||
inplace=True,
|
inplace=True,
|
||||||
activation=activation,
|
activation=activation,
|
||||||
use_int4_w4a16=weight_bits == 4,
|
|
||||||
use_int8_w8a16=weight_bits == 8,
|
|
||||||
global_num_experts=global_num_experts,
|
|
||||||
w1_scale=layer.w13_scale,
|
|
||||||
w2_scale=layer.w2_scale,
|
|
||||||
apply_router_weight_on_input=apply_router_weight_on_input,
|
apply_router_weight_on_input=apply_router_weight_on_input,
|
||||||
|
global_num_experts=global_num_experts,
|
||||||
expert_map=expert_map,
|
expert_map=expert_map,
|
||||||
block_shape=[0, group_size])
|
quant_config=self.moe_quant_config,
|
||||||
|
)
|
||||||
return ret
|
|
||||||
|
|
||||||
|
|
||||||
def rtn_quantize(tensor: torch.Tensor, num_bits: int,
|
def rtn_quantize(tensor: torch.Tensor, num_bits: int,
|
||||||
|
|||||||
@ -7,7 +7,8 @@ import torch
|
|||||||
|
|
||||||
import vllm.envs as envs
|
import vllm.envs as envs
|
||||||
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
|
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
|
||||||
from vllm.model_executor.layers.fused_moe.config import FusedMoEConfig
|
from vllm.model_executor.layers.fused_moe.config import (FusedMoEConfig,
|
||||||
|
FusedMoEQuantConfig)
|
||||||
from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe import (
|
from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe import (
|
||||||
FlashInferExperts)
|
FlashInferExperts)
|
||||||
from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_prepare_finalize import ( # noqa: E501
|
from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_prepare_finalize import ( # noqa: E501
|
||||||
@ -47,32 +48,23 @@ def reorder_w1w3_to_w3w1(weight: torch.Tensor,
|
|||||||
|
|
||||||
|
|
||||||
def build_flashinfer_fp4_cutlass_moe_prepare_finalize(
|
def build_flashinfer_fp4_cutlass_moe_prepare_finalize(
|
||||||
moe: FusedMoEConfig,
|
moe: FusedMoEConfig) -> mk.FusedMoEPrepareAndFinalize:
|
||||||
a1_gscale: torch.Tensor,
|
|
||||||
) -> mk.FusedMoEPrepareAndFinalize:
|
|
||||||
"""Create a FlashInfer CUTLASS fused-MoE prepare finalize kernel"""
|
"""Create a FlashInfer CUTLASS fused-MoE prepare finalize kernel"""
|
||||||
use_dp = moe.moe_parallel_config.dp_size > 1
|
use_dp = moe.moe_parallel_config.dp_size > 1
|
||||||
return FlashInferCutlassMoEPrepareAndFinalize(use_dp, a1_gscale=a1_gscale)
|
return FlashInferCutlassMoEPrepareAndFinalize(use_dp)
|
||||||
|
|
||||||
|
|
||||||
def select_nvfp4_gemm_impl(
|
def select_nvfp4_gemm_impl(
|
||||||
moe: FusedMoEConfig,
|
moe: FusedMoEConfig,
|
||||||
g1_alphas: torch.Tensor,
|
moe_quant_config: FusedMoEQuantConfig,
|
||||||
g2_alphas: torch.Tensor,
|
|
||||||
a1_gscale: torch.Tensor,
|
|
||||||
a2_gscale: torch.Tensor,
|
|
||||||
allow_flashinfer: bool,
|
allow_flashinfer: bool,
|
||||||
) -> mk.FusedMoEPermuteExpertsUnpermute:
|
) -> mk.FusedMoEPermuteExpertsUnpermute:
|
||||||
"""Return a GEMM *experts* implementation for NV-FP4 fused-MoE layers"""
|
"""Return a GEMM *experts* implementation for NV-FP4 fused-MoE layers"""
|
||||||
|
|
||||||
if allow_flashinfer:
|
if allow_flashinfer:
|
||||||
return FlashInferExperts(
|
return FlashInferExperts(
|
||||||
g1_alphas=g1_alphas,
|
|
||||||
g2_alphas=g2_alphas,
|
|
||||||
a1_gscale=a1_gscale,
|
|
||||||
a2_gscale=a2_gscale,
|
|
||||||
out_dtype=moe.in_dtype,
|
out_dtype=moe.in_dtype,
|
||||||
quant_dtype="nvfp4",
|
quant_config=moe_quant_config,
|
||||||
ep_rank=moe.moe_parallel_config.ep_rank,
|
ep_rank=moe.moe_parallel_config.ep_rank,
|
||||||
ep_size=moe.moe_parallel_config.ep_size,
|
ep_size=moe.moe_parallel_config.ep_size,
|
||||||
tp_rank=moe.moe_parallel_config.tp_rank,
|
tp_rank=moe.moe_parallel_config.tp_rank,
|
||||||
|
|||||||
@ -8,7 +8,8 @@ import torch
|
|||||||
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
|
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
|
||||||
from vllm import envs
|
from vllm import envs
|
||||||
from vllm.logger import init_logger
|
from vllm.logger import init_logger
|
||||||
from vllm.model_executor.layers.fused_moe.config import FusedMoEConfig
|
from vllm.model_executor.layers.fused_moe.config import (FusedMoEConfig,
|
||||||
|
FusedMoEQuantConfig)
|
||||||
from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe import (
|
from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe import (
|
||||||
FlashInferExperts)
|
FlashInferExperts)
|
||||||
from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_prepare_finalize import ( # noqa: E501
|
from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_prepare_finalize import ( # noqa: E501
|
||||||
@ -99,6 +100,8 @@ def apply_flashinfer_per_tensor_scale_fp8(
|
|||||||
apply_router_weight_on_input: bool,
|
apply_router_weight_on_input: bool,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
from flashinfer.fused_moe import RoutingMethodType
|
from flashinfer.fused_moe import RoutingMethodType
|
||||||
|
|
||||||
|
import vllm.model_executor.layers.fused_moe.flashinfer_trtllm_moe # noqa: E501, F401
|
||||||
assert layer.output1_scales_scalar is not None, (
|
assert layer.output1_scales_scalar is not None, (
|
||||||
"Expected output1_scales_scalar to be initialized")
|
"Expected output1_scales_scalar to be initialized")
|
||||||
assert layer.output1_scales_scalar is not None, (
|
assert layer.output1_scales_scalar is not None, (
|
||||||
@ -167,34 +170,23 @@ def register_moe_scaling_factors(layer: torch.nn.Module) -> None:
|
|||||||
|
|
||||||
|
|
||||||
def build_flashinfer_fp8_cutlass_moe_prepare_finalize(
|
def build_flashinfer_fp8_cutlass_moe_prepare_finalize(
|
||||||
moe: Optional[FusedMoEConfig],
|
moe: Optional[FusedMoEConfig], ) -> mk.FusedMoEPrepareAndFinalize:
|
||||||
layer: torch.nn.Module,
|
|
||||||
) -> mk.FusedMoEPrepareAndFinalize:
|
|
||||||
"""Create a FlashInfer CUTLASS fused-MoE prepare finalize kernel"""
|
"""Create a FlashInfer CUTLASS fused-MoE prepare finalize kernel"""
|
||||||
use_dp = moe.moe_parallel_config.dp_size > 1 if moe is not None else False
|
use_dp = moe.moe_parallel_config.dp_size > 1 if moe is not None else False
|
||||||
return FlashInferCutlassMoEPrepareAndFinalize(
|
return FlashInferCutlassMoEPrepareAndFinalize(use_dp)
|
||||||
use_dp, a1_gscale=layer.w13_input_scale)
|
|
||||||
|
|
||||||
|
|
||||||
def select_cutlass_fp8_gemm_impl(
|
def select_cutlass_fp8_gemm_impl(
|
||||||
moe: Optional[FusedMoEConfig],
|
moe: Optional[FusedMoEConfig],
|
||||||
layer: torch.nn.Module,
|
quant_config: FusedMoEQuantConfig,
|
||||||
out_dtype: Optional[torch.dtype] = None,
|
out_dtype: Optional[torch.dtype] = None,
|
||||||
) -> mk.FusedMoEPermuteExpertsUnpermute:
|
) -> mk.FusedMoEPermuteExpertsUnpermute:
|
||||||
"""Return a GEMM *experts* implementation for fused-MoE layers"""
|
"""Return a GEMM *experts* implementation for fused-MoE layers"""
|
||||||
|
|
||||||
from vllm.model_executor.models.llama4 import Llama4MoE
|
|
||||||
assert layer.custom_routing_function == Llama4MoE.custom_routing_function, \
|
|
||||||
"FusedMoE flashinfer kernels are only supported for Llama4"
|
|
||||||
|
|
||||||
if moe is not None:
|
if moe is not None:
|
||||||
return FlashInferExperts(
|
return FlashInferExperts(
|
||||||
g1_alphas=layer.output1_scales_gate_scalar,
|
|
||||||
g2_alphas=layer.output2_scales_scalar,
|
|
||||||
a1_gscale=layer.w13_input_scale,
|
|
||||||
a2_gscale=layer.w2_input_scale_inv,
|
|
||||||
out_dtype=moe.in_dtype,
|
out_dtype=moe.in_dtype,
|
||||||
quant_dtype=torch.float8_e4m3fn,
|
quant_config=quant_config,
|
||||||
ep_rank=moe.moe_parallel_config.ep_rank,
|
ep_rank=moe.moe_parallel_config.ep_rank,
|
||||||
ep_size=moe.moe_parallel_config.ep_size,
|
ep_size=moe.moe_parallel_config.ep_size,
|
||||||
tp_rank=moe.moe_parallel_config.tp_rank,
|
tp_rank=moe.moe_parallel_config.tp_rank,
|
||||||
@ -204,12 +196,8 @@ def select_cutlass_fp8_gemm_impl(
|
|||||||
assert out_dtype is not None, (
|
assert out_dtype is not None, (
|
||||||
"If moe config is None, out_dtype must be passed")
|
"If moe config is None, out_dtype must be passed")
|
||||||
return FlashInferExperts(
|
return FlashInferExperts(
|
||||||
g1_alphas=layer.output1_scales_gate_scalar,
|
|
||||||
g2_alphas=layer.output2_scales_scalar,
|
|
||||||
a1_gscale=layer.w13_input_scale,
|
|
||||||
a2_gscale=layer.w2_input_scale_inv,
|
|
||||||
out_dtype=out_dtype,
|
out_dtype=out_dtype,
|
||||||
quant_dtype=torch.float8_e4m3fn,
|
quant_config=quant_config,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@ -224,11 +212,13 @@ def flashinfer_cutlass_moe_fp8(
|
|||||||
expert_map: Optional[torch.Tensor] = None,
|
expert_map: Optional[torch.Tensor] = None,
|
||||||
apply_router_weight_on_input: bool = False,
|
apply_router_weight_on_input: bool = False,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
|
quant_config = layer.quant_method.get_fused_moe_quant_config(layer)
|
||||||
|
assert quant_config is not None
|
||||||
|
|
||||||
fused_experts = mk.FusedMoEModularKernel(
|
fused_experts = mk.FusedMoEModularKernel(
|
||||||
build_flashinfer_fp8_cutlass_moe_prepare_finalize(moe=None,
|
build_flashinfer_fp8_cutlass_moe_prepare_finalize(moe=None),
|
||||||
layer=layer),
|
|
||||||
select_cutlass_fp8_gemm_impl(moe=None,
|
select_cutlass_fp8_gemm_impl(moe=None,
|
||||||
layer=layer,
|
quant_config=quant_config,
|
||||||
out_dtype=hidden_states.dtype))
|
out_dtype=hidden_states.dtype))
|
||||||
|
|
||||||
return fused_experts(
|
return fused_experts(
|
||||||
|
|||||||
@ -411,6 +411,7 @@ def per_token_group_quant_fp8(
|
|||||||
x_s = torch.empty(shape, device=x.device, dtype=torch.float32)
|
x_s = torch.empty(shape, device=x.device, dtype=torch.float32)
|
||||||
|
|
||||||
# prefer CUDA kernel if available
|
# prefer CUDA kernel if available
|
||||||
|
# TODO(bnell): this causes some fp8 moe test to fail.
|
||||||
if current_platform.is_cuda() and x.is_contiguous():
|
if current_platform.is_cuda() and x.is_contiguous():
|
||||||
torch.ops._C.per_token_group_fp8_quant(x, x_q, x_s, group_size, eps,
|
torch.ops._C.per_token_group_fp8_quant(x, x_q, x_s, group_size, eps,
|
||||||
fp8_min, fp8_max, use_ue8m0)
|
fp8_min, fp8_max, use_ue8m0)
|
||||||
|
|||||||
@ -15,8 +15,8 @@ from vllm.distributed import (divide, get_tensor_model_parallel_rank,
|
|||||||
tensor_model_parallel_all_reduce)
|
tensor_model_parallel_all_reduce)
|
||||||
from vllm.model_executor.layers.activation import (get_act_and_mul_fn,
|
from vllm.model_executor.layers.activation import (get_act_and_mul_fn,
|
||||||
get_act_fn)
|
get_act_fn)
|
||||||
from vllm.model_executor.layers.fused_moe.fused_moe import (
|
from vllm.model_executor.layers.fused_moe import (activation_without_mul,
|
||||||
fused_topk, torch_vllm_outplace_fused_experts)
|
fused_topk)
|
||||||
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
|
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
|
||||||
MergedColumnParallelLinear,
|
MergedColumnParallelLinear,
|
||||||
QKVParallelLinear,
|
QKVParallelLinear,
|
||||||
@ -230,7 +230,7 @@ class NomicMoE(nn.Module):
|
|||||||
self.hidden_size = hidden_size
|
self.hidden_size = hidden_size
|
||||||
self.total_intermediate_size = intermediate_size
|
self.total_intermediate_size = intermediate_size
|
||||||
self.intermediate_size = divide(intermediate_size, self.tp_size)
|
self.intermediate_size = divide(intermediate_size, self.tp_size)
|
||||||
self.hidden_act = hidden_act
|
self.hidden_act = activation_without_mul(hidden_act)
|
||||||
|
|
||||||
if params_dtype is None:
|
if params_dtype is None:
|
||||||
params_dtype = torch.get_default_dtype()
|
params_dtype = torch.get_default_dtype()
|
||||||
@ -297,14 +297,14 @@ class NomicMoE(nn.Module):
|
|||||||
router_logits,
|
router_logits,
|
||||||
self.top_k,
|
self.top_k,
|
||||||
renormalize=False)
|
renormalize=False)
|
||||||
final_hidden_states = torch_vllm_outplace_fused_experts(
|
|
||||||
|
final_hidden_states = torch.ops.vllm.outplace_fused_experts(
|
||||||
hidden_states=hidden_states,
|
hidden_states=hidden_states,
|
||||||
w1=self.w1,
|
w1=self.w1,
|
||||||
w2=self.w2,
|
w2=self.w2,
|
||||||
topk_weights=topk_weights,
|
topk_weights=topk_weights,
|
||||||
topk_ids=topk_ids,
|
topk_ids=topk_ids,
|
||||||
activation=self.hidden_act,
|
activation=self.hidden_act,
|
||||||
is_act_and_mul=False,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
if self.tp_size > 1:
|
if self.tp_size > 1:
|
||||||
|
|||||||
@ -37,7 +37,7 @@ from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank,
|
|||||||
get_tensor_model_parallel_world_size,
|
get_tensor_model_parallel_world_size,
|
||||||
tensor_model_parallel_all_reduce)
|
tensor_model_parallel_all_reduce)
|
||||||
from vllm.model_executor.layers.activation import SiluAndMul
|
from vllm.model_executor.layers.activation import SiluAndMul
|
||||||
from vllm.model_executor.layers.fused_moe import fused_moe
|
from vllm.model_executor.layers.fused_moe import fused_experts, fused_topk
|
||||||
from vllm.model_executor.layers.layernorm import RMSNorm
|
from vllm.model_executor.layers.layernorm import RMSNorm
|
||||||
from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
|
from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
|
||||||
QKVParallelLinear,
|
QKVParallelLinear,
|
||||||
@ -163,13 +163,19 @@ class DeepseekMoE(nn.Module):
|
|||||||
shared_output = self.shared_experts(hidden_states)
|
shared_output = self.shared_experts(hidden_states)
|
||||||
# router_logits: (num_tokens, n_experts)
|
# router_logits: (num_tokens, n_experts)
|
||||||
router_logits, _ = self.gate(hidden_states)
|
router_logits, _ = self.gate(hidden_states)
|
||||||
final_hidden_states = fused_moe(hidden_states,
|
|
||||||
self.w1,
|
topk_weights, topk_ids, _ = fused_topk(
|
||||||
self.w2,
|
hidden_states,
|
||||||
router_logits,
|
router_logits,
|
||||||
self.top_k,
|
self.top_k,
|
||||||
renormalize=self.config.norm_topk_prob,
|
renormalize=self.config.norm_topk_prob)
|
||||||
inplace=True)
|
|
||||||
|
final_hidden_states = fused_experts(hidden_states,
|
||||||
|
self.w1,
|
||||||
|
self.w2,
|
||||||
|
topk_weights,
|
||||||
|
topk_ids,
|
||||||
|
inplace=True)
|
||||||
|
|
||||||
if self.config.n_shared_experts is not None:
|
if self.config.n_shared_experts is not None:
|
||||||
final_hidden_states = final_hidden_states + shared_output
|
final_hidden_states = final_hidden_states + shared_output
|
||||||
|
|||||||
@ -39,7 +39,7 @@ from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank,
|
|||||||
get_tensor_model_parallel_world_size,
|
get_tensor_model_parallel_world_size,
|
||||||
tensor_model_parallel_all_reduce)
|
tensor_model_parallel_all_reduce)
|
||||||
from vllm.model_executor.layers.activation import FatreluAndMul, SiluAndMul
|
from vllm.model_executor.layers.activation import FatreluAndMul, SiluAndMul
|
||||||
from vllm.model_executor.layers.fused_moe import fused_moe
|
from vllm.model_executor.layers.fused_moe import fused_experts, fused_topk
|
||||||
from vllm.model_executor.layers.layernorm import RMSNorm
|
from vllm.model_executor.layers.layernorm import RMSNorm
|
||||||
from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
|
from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
|
||||||
QKVParallelLinear,
|
QKVParallelLinear,
|
||||||
@ -136,13 +136,18 @@ class MiniCPMMoE(nn.Module):
|
|||||||
hidden_states = hidden_states.view(-1, self.hidden_size)
|
hidden_states = hidden_states.view(-1, self.hidden_size)
|
||||||
# router_logits: (num_tokens, n_experts)
|
# router_logits: (num_tokens, n_experts)
|
||||||
router_logits, _ = self.gate(hidden_states)
|
router_logits, _ = self.gate(hidden_states)
|
||||||
final_hidden_states = fused_moe(hidden_states,
|
|
||||||
self.ws,
|
topk_weights, topk_ids, _ = fused_topk(hidden_states,
|
||||||
self.w2s,
|
router_logits,
|
||||||
router_logits,
|
self.top_k,
|
||||||
self.top_k,
|
renormalize=True)
|
||||||
renormalize=True,
|
|
||||||
inplace=True)
|
final_hidden_states = fused_experts(hidden_states,
|
||||||
|
self.ws,
|
||||||
|
self.w2s,
|
||||||
|
topk_weights,
|
||||||
|
topk_ids,
|
||||||
|
inplace=True)
|
||||||
|
|
||||||
if self.tp_size > 1:
|
if self.tp_size > 1:
|
||||||
final_hidden_states = tensor_model_parallel_all_reduce(
|
final_hidden_states = tensor_model_parallel_all_reduce(
|
||||||
|
|||||||
@ -81,9 +81,14 @@ def _fp8_linear_may_use_deep_gemm(module: torch.nn.Module) -> bool:
|
|||||||
|
|
||||||
|
|
||||||
def _fused_moe_grouped_gemm_may_use_deep_gemm(module: torch.nn.Module) -> bool:
|
def _fused_moe_grouped_gemm_may_use_deep_gemm(module: torch.nn.Module) -> bool:
|
||||||
if not (isinstance(module, FusedMoE)
|
if not isinstance(module, FusedMoE):
|
||||||
and module.moe_config.quant_dtype == torch.float8_e4m3fn
|
return False
|
||||||
and module.moe_config.block_shape == deep_gemm_block_shape()):
|
|
||||||
|
moe_quant_config = module.quant_method.get_fused_moe_quant_config(module)
|
||||||
|
|
||||||
|
if (moe_quant_config is None
|
||||||
|
or moe_quant_config.quant_dtype != torch.float8_e4m3fn
|
||||||
|
or moe_quant_config.block_shape != deep_gemm_block_shape()):
|
||||||
return False
|
return False
|
||||||
|
|
||||||
if not isinstance(module.quant_method.fused_experts,
|
if not isinstance(module.quant_method.fused_experts,
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user