mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-03-25 04:07:56 +08:00
[Kernel] Delegate construction of FusedMoEQuantConfig to FusedMoEMethodBase subclasses (#22537)
Signed-off-by: Bill Nell <bnell@redhat.com>
This commit is contained in:
parent
e6585ddb45
commit
5963b98b46
@ -13,6 +13,10 @@ import torch.utils.benchmark as benchmark
|
||||
|
||||
from vllm import _custom_ops as ops
|
||||
from vllm.config import ParallelConfig, VllmConfig, set_current_vllm_config
|
||||
from vllm.model_executor.layers.fused_moe.config import (
|
||||
fp8_w8a8_moe_quant_config,
|
||||
nvfp4_moe_quant_config,
|
||||
)
|
||||
from vllm.model_executor.layers.fused_moe.cutlass_moe import cutlass_moe_fp4
|
||||
from vllm.model_executor.layers.fused_moe.fused_moe import fused_experts, fused_topk
|
||||
from vllm.scalar_type import scalar_types
|
||||
@ -140,6 +144,12 @@ def bench_run(
|
||||
a_fp8_scale: torch.Tensor,
|
||||
num_repeats: int,
|
||||
):
|
||||
quant_config = fp8_w8a8_moe_quant_config(
|
||||
w1_scale=w1_scale,
|
||||
w2_scale=w2_scale,
|
||||
a1_scale=a_fp8_scale,
|
||||
)
|
||||
|
||||
for _ in range(num_repeats):
|
||||
fused_experts(
|
||||
a,
|
||||
@ -147,10 +157,7 @@ def bench_run(
|
||||
w2,
|
||||
topk_weights,
|
||||
topk_ids,
|
||||
use_fp8_w8a8=True,
|
||||
w1_scale=w1_scale,
|
||||
w2_scale=w2_scale,
|
||||
a1_scale=a_fp8_scale,
|
||||
quant_config=quant_config,
|
||||
)
|
||||
|
||||
def run_cutlass_moe_fp4(
|
||||
@ -172,25 +179,27 @@ def bench_run(
|
||||
device: torch.device,
|
||||
num_repeats: int,
|
||||
):
|
||||
quant_config = nvfp4_moe_quant_config(
|
||||
a1_gscale=a1_gs,
|
||||
a2_gscale=a2_gs,
|
||||
w1_scale=w1_blockscale,
|
||||
w2_scale=w2_blockscale,
|
||||
g1_alphas=w1_gs,
|
||||
g2_alphas=w2_gs,
|
||||
)
|
||||
for _ in range(num_repeats):
|
||||
with nvtx.annotate("cutlass_moe_fp4", color="green"):
|
||||
cutlass_moe_fp4(
|
||||
a=a,
|
||||
a1_gscale=a1_gs,
|
||||
a2_gscale=a2_gs,
|
||||
w1_fp4=w1_fp4,
|
||||
w1_blockscale=w1_blockscale,
|
||||
w1_alphas=w1_gs,
|
||||
w2_fp4=w2_fp4,
|
||||
w2_blockscale=w2_blockscale,
|
||||
w2_alphas=w2_gs,
|
||||
topk_weights=topk_weights,
|
||||
topk_ids=topk_ids,
|
||||
m=m,
|
||||
n=n,
|
||||
k=k,
|
||||
e=num_experts,
|
||||
device=device,
|
||||
quant_config=quant_config,
|
||||
)
|
||||
|
||||
def run_cutlass_from_graph(
|
||||
@ -211,26 +220,29 @@ def bench_run(
|
||||
e: int,
|
||||
device: torch.device,
|
||||
):
|
||||
quant_config = nvfp4_moe_quant_config(
|
||||
a1_gscale=a1_gs,
|
||||
a2_gscale=a2_gs,
|
||||
w1_scale=w1_blockscale,
|
||||
w2_scale=w2_blockscale,
|
||||
g1_alphas=w1_gs,
|
||||
g2_alphas=w2_gs,
|
||||
)
|
||||
|
||||
with set_current_vllm_config(
|
||||
VllmConfig(parallel_config=ParallelConfig(pipeline_parallel_size=1))
|
||||
):
|
||||
return cutlass_moe_fp4(
|
||||
a=a,
|
||||
a1_gscale=a1_gs,
|
||||
w1_fp4=w1_fp4,
|
||||
w1_blockscale=w1_blockscale,
|
||||
w1_alphas=w1_alphas,
|
||||
a2_gscale=a2_gs,
|
||||
w2_fp4=w2_fp4,
|
||||
w2_blockscale=w2_blockscale,
|
||||
w2_alphas=w2_alphas,
|
||||
topk_weights=topk_weights,
|
||||
topk_ids=topk_ids,
|
||||
m=m,
|
||||
n=n,
|
||||
k=k,
|
||||
e=num_experts,
|
||||
device=device,
|
||||
quant_config=quant_config,
|
||||
)
|
||||
|
||||
def run_triton_from_graph(
|
||||
@ -246,16 +258,18 @@ def bench_run(
|
||||
with set_current_vllm_config(
|
||||
VllmConfig(parallel_config=ParallelConfig(pipeline_parallel_size=1))
|
||||
):
|
||||
quant_config = fp8_w8a8_moe_quant_config(
|
||||
w1_scale=w1_scale,
|
||||
w2_scale=w2_scale,
|
||||
a1_scale=a_fp8_scale,
|
||||
)
|
||||
return fused_experts(
|
||||
a,
|
||||
w1,
|
||||
w2,
|
||||
topk_weights,
|
||||
topk_ids,
|
||||
use_fp8_w8a8=True,
|
||||
w1_scale=w1_scale,
|
||||
w2_scale=w2_scale,
|
||||
a1_scale=a_fp8_scale,
|
||||
quant_config=quant_config,
|
||||
)
|
||||
|
||||
def replay_graph(graph, num_repeats):
|
||||
|
||||
@ -7,6 +7,7 @@ from benchmark_shapes import WEIGHT_SHAPES_MOE
|
||||
|
||||
from vllm import _custom_ops as ops
|
||||
from vllm.config import ParallelConfig, VllmConfig, set_current_vllm_config
|
||||
from vllm.model_executor.layers.fused_moe.config import fp8_w8a8_moe_quant_config
|
||||
from vllm.model_executor.layers.fused_moe.cutlass_moe import cutlass_moe_fp8
|
||||
from vllm.model_executor.layers.fused_moe.fused_moe import (
|
||||
fused_experts,
|
||||
@ -96,6 +97,11 @@ def bench_run(
|
||||
a_scale: torch.Tensor,
|
||||
num_repeats: int,
|
||||
):
|
||||
quant_config = fp8_w8a8_moe_quant_config(
|
||||
w1_scale=w1_scale,
|
||||
w2_scale=w2_scale,
|
||||
a1_scale=a_scale,
|
||||
)
|
||||
for _ in range(num_repeats):
|
||||
fused_experts(
|
||||
a,
|
||||
@ -103,10 +109,7 @@ def bench_run(
|
||||
w2,
|
||||
topk_weights,
|
||||
topk_ids,
|
||||
use_fp8_w8a8=True,
|
||||
w1_scale=w1_scale,
|
||||
w2_scale=w2_scale,
|
||||
a1_scale=a_scale,
|
||||
quant_config=quant_config,
|
||||
)
|
||||
|
||||
def run_cutlass_moe(
|
||||
@ -125,6 +128,12 @@ def bench_run(
|
||||
per_act_token: bool,
|
||||
num_repeats: int,
|
||||
):
|
||||
quant_config = fp8_w8a8_moe_quant_config(
|
||||
w1_scale=w1_scale,
|
||||
w2_scale=w2_scale,
|
||||
per_act_token_quant=per_act_token,
|
||||
)
|
||||
|
||||
for _ in range(num_repeats):
|
||||
cutlass_moe_fp8(
|
||||
a,
|
||||
@ -132,14 +141,11 @@ def bench_run(
|
||||
w2,
|
||||
topk_weights,
|
||||
topk_ids,
|
||||
w1_scale,
|
||||
w2_scale,
|
||||
ab_strides1,
|
||||
ab_strides2,
|
||||
c_strides1,
|
||||
c_strides2,
|
||||
per_act_token,
|
||||
a1_scale=None,
|
||||
quant_config=quant_config,
|
||||
)
|
||||
|
||||
def run_cutlass_from_graph(
|
||||
@ -156,6 +162,12 @@ def bench_run(
|
||||
topk_weights: torch.Tensor,
|
||||
topk_ids: torch.Tensor,
|
||||
):
|
||||
quant_config = fp8_w8a8_moe_quant_config(
|
||||
w1_scale=w1_scale,
|
||||
w2_scale=w2_scale,
|
||||
per_act_token_quant=per_act_token,
|
||||
)
|
||||
|
||||
with set_current_vllm_config(
|
||||
VllmConfig(parallel_config=ParallelConfig(pipeline_parallel_size=1))
|
||||
):
|
||||
@ -165,14 +177,11 @@ def bench_run(
|
||||
w2_q,
|
||||
topk_weights,
|
||||
topk_ids,
|
||||
w1_scale,
|
||||
w2_scale,
|
||||
ab_strides1,
|
||||
ab_strides2,
|
||||
c_strides1,
|
||||
c_strides2,
|
||||
per_act_token,
|
||||
a1_scale=None,
|
||||
quant_config=quant_config,
|
||||
)
|
||||
|
||||
def run_triton_from_graph(
|
||||
@ -185,6 +194,11 @@ def bench_run(
|
||||
w2_scale: torch.Tensor,
|
||||
a_scale: torch.Tensor,
|
||||
):
|
||||
quant_config = fp8_w8a8_moe_quant_config(
|
||||
w1_scale=w1_scale,
|
||||
w2_scale=w2_scale,
|
||||
a1_scale=a_scale,
|
||||
)
|
||||
with set_current_vllm_config(
|
||||
VllmConfig(parallel_config=ParallelConfig(pipeline_parallel_size=1))
|
||||
):
|
||||
@ -194,10 +208,7 @@ def bench_run(
|
||||
w2,
|
||||
topk_weights,
|
||||
topk_ids,
|
||||
use_fp8_w8a8=True,
|
||||
w1_scale=w1_scale,
|
||||
w2_scale=w2_scale,
|
||||
a1_scale=a_scale,
|
||||
quant_config=quant_config,
|
||||
)
|
||||
|
||||
def replay_graph(graph, num_repeats):
|
||||
|
||||
@ -14,6 +14,10 @@ import ray
|
||||
import torch
|
||||
from ray.experimental.tqdm_ray import tqdm
|
||||
|
||||
from vllm.model_executor.layers.fused_moe.config import (
|
||||
FusedMoEQuantConfig,
|
||||
_get_config_dtype_str,
|
||||
)
|
||||
from vllm.model_executor.layers.fused_moe.fused_moe import *
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.transformers_utils.config import get_config
|
||||
@ -134,43 +138,36 @@ def benchmark_config(
|
||||
def run():
|
||||
from vllm.model_executor.layers.fused_moe import override_config
|
||||
|
||||
if use_fp8_w8a8:
|
||||
quant_dtype = torch.float8_e4m3fn
|
||||
elif use_int8_w8a16:
|
||||
quant_dtype = torch.int8
|
||||
else:
|
||||
quant_dtype = None
|
||||
|
||||
quant_config = FusedMoEQuantConfig.make(
|
||||
quant_dtype=quant_dtype,
|
||||
w1_scale=w1_scale,
|
||||
w2_scale=w2_scale,
|
||||
a1_scale=a1_scale,
|
||||
a2_scale=a2_scale,
|
||||
block_shape=block_quant_shape,
|
||||
)
|
||||
|
||||
with override_config(config):
|
||||
if use_deep_gemm:
|
||||
topk_weights, topk_ids, token_expert_indices = fused_topk(
|
||||
x, input_gating, topk, False
|
||||
)
|
||||
return fused_experts(
|
||||
x,
|
||||
w1,
|
||||
w2,
|
||||
topk_weights,
|
||||
topk_ids,
|
||||
inplace=True,
|
||||
use_fp8_w8a8=use_fp8_w8a8,
|
||||
w1_scale=w1_scale,
|
||||
w2_scale=w2_scale,
|
||||
a1_scale=a1_scale,
|
||||
a2_scale=a2_scale,
|
||||
block_shape=block_quant_shape,
|
||||
allow_deep_gemm=True,
|
||||
)
|
||||
else:
|
||||
fused_moe(
|
||||
x,
|
||||
w1,
|
||||
w2,
|
||||
input_gating,
|
||||
topk,
|
||||
renormalize=True,
|
||||
inplace=True,
|
||||
use_fp8_w8a8=use_fp8_w8a8,
|
||||
use_int8_w8a16=use_int8_w8a16,
|
||||
w1_scale=w1_scale,
|
||||
w2_scale=w2_scale,
|
||||
a1_scale=a1_scale,
|
||||
a2_scale=a2_scale,
|
||||
block_shape=block_quant_shape,
|
||||
)
|
||||
topk_weights, topk_ids, token_expert_indices = fused_topk(
|
||||
x, input_gating, topk, renormalize=not use_deep_gemm
|
||||
)
|
||||
return fused_experts(
|
||||
x,
|
||||
w1,
|
||||
w2,
|
||||
topk_weights,
|
||||
topk_ids,
|
||||
inplace=True,
|
||||
quant_config=quant_config,
|
||||
allow_deep_gemm=use_deep_gemm,
|
||||
)
|
||||
|
||||
# JIT compilation & warmup
|
||||
run()
|
||||
@ -414,7 +411,7 @@ class BenchmarkWorker:
|
||||
use_deep_gemm: bool = False,
|
||||
) -> tuple[dict[str, int], float]:
|
||||
current_platform.seed_everything(self.seed)
|
||||
dtype_str = get_config_dtype_str(
|
||||
dtype_str = _get_config_dtype_str(
|
||||
dtype, use_int8_w8a16=use_int8_w8a16, use_fp8_w8a8=use_fp8_w8a8
|
||||
)
|
||||
# NOTE(woosuk): The current naming convention uses w2.shape[2], which
|
||||
@ -547,7 +544,7 @@ def save_configs(
|
||||
block_quant_shape: list[int],
|
||||
save_dir: str,
|
||||
) -> None:
|
||||
dtype_str = get_config_dtype_str(
|
||||
dtype_str = _get_config_dtype_str(
|
||||
dtype, use_int8_w8a16=use_int8_w8a16, use_fp8_w8a8=use_fp8_w8a8
|
||||
)
|
||||
|
||||
|
||||
@ -20,7 +20,7 @@ from vllm.model_executor.layers.fused_moe.config import (
|
||||
from vllm.model_executor.layers.fused_moe.fused_moe import fused_topk
|
||||
from vllm.utils import has_deep_ep, has_deep_gemm, has_pplx
|
||||
|
||||
from .mk_objects import (expert_info, make_fused_experts,
|
||||
from .mk_objects import (TestMoEQuantConfig, expert_info, make_fused_experts,
|
||||
make_prepare_finalize, prepare_finalize_info)
|
||||
from .parallel_utils import ProcessGroupInfo
|
||||
|
||||
@ -40,7 +40,7 @@ class Config:
|
||||
E: int
|
||||
topks: Union[list[int], int]
|
||||
dtype: torch.dtype
|
||||
quant_config: Optional[FusedMoEQuantConfig]
|
||||
quant_config: Optional[TestMoEQuantConfig]
|
||||
|
||||
prepare_finalize_type: mk.FusedMoEPrepareAndFinalize
|
||||
fused_experts_type: mk.FusedMoEPermuteExpertsUnpermute
|
||||
@ -52,7 +52,7 @@ class Config:
|
||||
|
||||
def __post_init__(self):
|
||||
if self.quant_config is None:
|
||||
self.quant_config = FusedMoEQuantConfig()
|
||||
self.quant_config = TestMoEQuantConfig(None, False, False, None)
|
||||
|
||||
def describe(self) -> str:
|
||||
s = ""
|
||||
@ -275,21 +275,19 @@ class WeightTensors:
|
||||
or self.w1.dtype == torch.uint8 or self.w1.dtype == torch.int8)
|
||||
|
||||
def to_current_device(self):
|
||||
self.w1 = self.w1.to(device=torch.cuda.current_device())
|
||||
self.w2 = self.w2.to(device=torch.cuda.current_device())
|
||||
device = torch.cuda.current_device()
|
||||
self.w1 = self.w1.to(device=device)
|
||||
self.w2 = self.w2.to(device=device)
|
||||
|
||||
if self.is_quantized():
|
||||
assert self.w1_scale is not None
|
||||
assert self.w2_scale is not None
|
||||
self.w1_scale = self.w1_scale.to(
|
||||
device=torch.cuda.current_device())
|
||||
self.w2_scale = self.w2_scale.to(
|
||||
device=torch.cuda.current_device())
|
||||
if self.w1_scale is not None:
|
||||
self.w1_scale = self.w1_scale.to(device=device)
|
||||
if self.w2_scale is not None:
|
||||
self.w2_scale = self.w2_scale.to(device=device)
|
||||
|
||||
if self.w1_gs is not None:
|
||||
assert self.w2_gs is not None
|
||||
self.w1_gs = self.w1_gs.to(device=torch.cuda.current_device())
|
||||
self.w2_gs = self.w2_gs.to(device=torch.cuda.current_device())
|
||||
self.w1_gs = self.w1_gs.to(device=device)
|
||||
if self.w2_gs is not None:
|
||||
self.w2_gs = self.w2_gs.to(device=device)
|
||||
|
||||
def slice_weights(self, rank: int,
|
||||
num_local_experts: int) -> "WeightTensors":
|
||||
@ -297,20 +295,12 @@ class WeightTensors:
|
||||
e = s + num_local_experts
|
||||
w1 = self.w1[s:e, :, :]
|
||||
w2 = self.w2[s:e, :, :]
|
||||
|
||||
w1_scale, w2_scale = (None, None)
|
||||
if self.is_quantized():
|
||||
assert self.w1_scale is not None
|
||||
assert self.w2_scale is not None
|
||||
w1_scale = self.w1_scale[s:e, :, :]
|
||||
w2_scale = self.w2_scale[s:e, :, :]
|
||||
|
||||
w1_gs = self.w1_gs
|
||||
w2_gs = self.w2_gs
|
||||
if w1_gs is not None:
|
||||
assert w2_gs is not None
|
||||
w1_gs = w1_gs[s:e]
|
||||
w2_gs = w2_gs[s:e]
|
||||
w1_scale = self.w1_scale[
|
||||
s:e, :, :] if self.w1_scale is not None else None
|
||||
w2_scale = self.w2_scale[
|
||||
s:e, :, :] if self.w2_scale is not None else None
|
||||
w1_gs = self.w1_gs[s:e] if self.w1_gs is not None else None
|
||||
w2_gs = self.w2_gs[s:e] if self.w2_gs is not None else None
|
||||
|
||||
return WeightTensors(w1, w2, w1_scale, w2_scale, w1_gs, w2_gs)
|
||||
|
||||
@ -323,7 +313,8 @@ class WeightTensors:
|
||||
in_dtype=config.dtype,
|
||||
quant_dtype=config.quant_dtype,
|
||||
block_shape=config.quant_block_shape,
|
||||
per_act_token_quant=config.is_per_out_ch_quant,
|
||||
per_out_ch_quant=config.
|
||||
is_per_act_token_quant, # or config.is_per_out_ch_quant
|
||||
)
|
||||
return WeightTensors(w1=w1,
|
||||
w2=w2,
|
||||
@ -342,8 +333,6 @@ class RankTensors:
|
||||
topk_ids: torch.Tensor
|
||||
expert_map: Optional[torch.Tensor]
|
||||
|
||||
quant_config: Optional[FusedMoEQuantConfig]
|
||||
|
||||
def describe(self):
|
||||
s = ""
|
||||
s += "== Rank Tensors: \n"
|
||||
@ -426,7 +415,6 @@ class RankTensors:
|
||||
topk_weights=topk_weights,
|
||||
topk_ids=topk_ids,
|
||||
expert_map=expert_map,
|
||||
quant_config=config.quant_config,
|
||||
)
|
||||
|
||||
|
||||
@ -522,10 +510,16 @@ def reference_moe_impl(config: Config, weights: WeightTensors,
|
||||
and config.supports_apply_weight_on_input())
|
||||
|
||||
|
||||
def _make_gscale(num_experts: int) -> torch.Tensor:
|
||||
return torch.ones((num_experts, ),
|
||||
device=torch.cuda.current_device(),
|
||||
dtype=torch.float32)
|
||||
|
||||
|
||||
def make_modular_kernel(
|
||||
config: Config,
|
||||
vllm_config: VllmConfig,
|
||||
weights: WeightTensors,
|
||||
quant_config: FusedMoEQuantConfig,
|
||||
) -> mk.FusedMoEModularKernel:
|
||||
|
||||
def next_power_of_2(x):
|
||||
@ -548,20 +542,20 @@ def make_modular_kernel(
|
||||
num_local_experts=config.num_local_experts,
|
||||
moe_parallel_config=moe_parallel_config,
|
||||
in_dtype=config.dtype,
|
||||
quant_config=config.quant_config,
|
||||
max_num_tokens=next_power_of_2(config.M),
|
||||
)
|
||||
|
||||
# make modular kernel
|
||||
prepare_finalize = make_prepare_finalize(config.prepare_finalize_type,
|
||||
config.all2all_backend(), moe)
|
||||
config.all2all_backend(), moe,
|
||||
quant_config)
|
||||
|
||||
fused_experts = make_fused_experts(
|
||||
config.fused_experts_type,
|
||||
moe,
|
||||
quant_config,
|
||||
prepare_finalize.num_dispatchers(),
|
||||
weights.w1_gs,
|
||||
weights.w2_gs,
|
||||
config.N,
|
||||
)
|
||||
|
||||
modular_kernel = mk.FusedMoEModularKernel(
|
||||
@ -583,12 +577,38 @@ def run_modular_kernel(
|
||||
# weights for rank
|
||||
rank_weights = weights.slice_weights(pgi.rank, config.num_local_experts)
|
||||
|
||||
mk = make_modular_kernel(config, vllm_config, weights)
|
||||
if config.quant_dtype == "nvfp4":
|
||||
gscale = _make_gscale(config.num_local_experts)
|
||||
else:
|
||||
gscale = None
|
||||
|
||||
quant_config = FusedMoEQuantConfig.make(
|
||||
config.quant_dtype,
|
||||
w1_scale=rank_weights.w1_scale,
|
||||
w2_scale=rank_weights.w2_scale,
|
||||
a1_scale=rank_tensors.hidden_states_scale,
|
||||
g1_alphas=(1 / rank_weights.w1_gs)
|
||||
if rank_weights.w1_gs is not None else None,
|
||||
g2_alphas=(1 / rank_weights.w2_gs)
|
||||
if rank_weights.w2_gs is not None else None,
|
||||
a1_gscale=gscale,
|
||||
a2_gscale=gscale,
|
||||
block_shape=config.quant_block_shape,
|
||||
per_act_token_quant=config.is_per_act_token_quant,
|
||||
per_out_ch_quant=config.is_per_out_ch_quant,
|
||||
)
|
||||
|
||||
mk = make_modular_kernel(config, vllm_config, quant_config)
|
||||
|
||||
# impls might update the tensor in place
|
||||
hidden_states = rank_tensors.hidden_states.clone()
|
||||
|
||||
topk_ids = rank_tensors.topk_ids.to(
|
||||
mk.prepare_finalize.topk_indices_dtype())
|
||||
|
||||
mk_kwargs = {
|
||||
"hidden_states":
|
||||
rank_tensors.hidden_states.clone(
|
||||
), # impls might update the tensor in place
|
||||
hidden_states,
|
||||
"w1":
|
||||
rank_weights.w1,
|
||||
"w2":
|
||||
@ -596,15 +616,9 @@ def run_modular_kernel(
|
||||
"topk_weights":
|
||||
rank_tensors.topk_weights,
|
||||
"topk_ids":
|
||||
rank_tensors.topk_ids.to(mk.prepare_finalize.topk_indices_dtype()),
|
||||
topk_ids,
|
||||
"expert_map":
|
||||
rank_tensors.expert_map,
|
||||
"w1_scale":
|
||||
rank_weights.w1_scale,
|
||||
"w2_scale":
|
||||
rank_weights.w2_scale,
|
||||
"a1_scale":
|
||||
rank_tensors.hidden_states_scale,
|
||||
"global_num_experts":
|
||||
config.E,
|
||||
"apply_router_weight_on_input":
|
||||
|
||||
@ -10,7 +10,8 @@ import torch
|
||||
from tqdm import tqdm
|
||||
|
||||
from vllm.config import VllmConfig, set_current_vllm_config
|
||||
from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig
|
||||
from vllm.model_executor.layers.fused_moe.config import (
|
||||
FUSED_MOE_UNQUANTIZED_CONFIG)
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
from .common import (Config, RankTensors, WeightTensors, reference_moe_impl,
|
||||
@ -86,7 +87,7 @@ def make_feature_matrix(csv_file_path: str):
|
||||
quant_config_dict = config_dict['quant_config']
|
||||
del config_dict['quant_config']
|
||||
if quant_config_dict is None:
|
||||
quant_config = FusedMoEQuantConfig(None)
|
||||
quant_config = FUSED_MOE_UNQUANTIZED_CONFIG
|
||||
quant_config_dict = asdict(quant_config)
|
||||
|
||||
config_dict |= quant_config_dict
|
||||
|
||||
@ -32,6 +32,14 @@ from vllm.utils.deep_gemm import is_deep_gemm_supported
|
||||
from vllm.utils.flashinfer import has_flashinfer_cutlass_fused_moe
|
||||
|
||||
|
||||
@dataclass
|
||||
class TestMoEQuantConfig:
|
||||
quant_dtype: Union[torch.dtype, str, None]
|
||||
per_out_ch_quant: bool
|
||||
per_act_token_quant: bool
|
||||
block_shape: Optional[list[int]]
|
||||
|
||||
|
||||
@dataclass
|
||||
class PrepareFinalizeInfo:
|
||||
activation_format: mk.FusedMoEActivationFormat
|
||||
@ -66,7 +74,7 @@ common_float_types: list[Union[torch.dtype, str]] = [
|
||||
torch.float8_e4m3fn, torch.bfloat16, torch.float16, torch.float32
|
||||
]
|
||||
common_float_and_int_types = common_float_types + [torch.int8]
|
||||
nv_fp4_types = ["nvfp4"]
|
||||
nvfp4_types = ["nvfp4"]
|
||||
fp8_types = [torch.float8_e4m3fn]
|
||||
|
||||
|
||||
@ -219,7 +227,7 @@ if (has_flashinfer_cutlass_fused_moe()
|
||||
register_prepare_and_finalize(
|
||||
FlashInferCutlassMoEPrepareAndFinalize,
|
||||
standard_format,
|
||||
nv_fp4_types,
|
||||
nvfp4_types,
|
||||
blocked_quantization_support=True,
|
||||
backend=None,
|
||||
force_multigpu=True,
|
||||
@ -229,7 +237,7 @@ if (has_flashinfer_cutlass_fused_moe()
|
||||
register_experts(
|
||||
FlashInferExperts,
|
||||
standard_format,
|
||||
nv_fp4_types,
|
||||
nvfp4_types,
|
||||
blocked_quantization_support=True,
|
||||
supports_chunking=True,
|
||||
# Note: this is a hack to get it to run for now
|
||||
@ -306,39 +314,39 @@ if cutlass_fp4_supported():
|
||||
register_experts(
|
||||
CutlassExpertsFp4,
|
||||
standard_format,
|
||||
nv_fp4_types,
|
||||
nvfp4_types,
|
||||
blocked_quantization_support=True,
|
||||
supports_chunking=True,
|
||||
supports_expert_map=False,
|
||||
)
|
||||
|
||||
MK_QUANT_CONFIGS = [
|
||||
MK_QUANT_CONFIGS: list[Optional[TestMoEQuantConfig]] = [
|
||||
None,
|
||||
# per-channel / per-column weights and per-tensor activations
|
||||
FusedMoEQuantConfig(quant_dtype=torch.float8_e4m3fn,
|
||||
per_out_ch_quant=True,
|
||||
per_act_token_quant=False,
|
||||
block_shape=None),
|
||||
TestMoEQuantConfig(quant_dtype=torch.float8_e4m3fn,
|
||||
per_out_ch_quant=True,
|
||||
per_act_token_quant=False,
|
||||
block_shape=None),
|
||||
# per-channel / per-column weights and per-token activations
|
||||
FusedMoEQuantConfig(quant_dtype=torch.float8_e4m3fn,
|
||||
per_out_ch_quant=True,
|
||||
per_act_token_quant=True,
|
||||
block_shape=None),
|
||||
TestMoEQuantConfig(quant_dtype=torch.float8_e4m3fn,
|
||||
per_out_ch_quant=True,
|
||||
per_act_token_quant=True,
|
||||
block_shape=None),
|
||||
# per-tensor weights and per-tensor activations
|
||||
FusedMoEQuantConfig(quant_dtype=torch.float8_e4m3fn,
|
||||
per_out_ch_quant=False,
|
||||
per_act_token_quant=False,
|
||||
block_shape=None),
|
||||
TestMoEQuantConfig(quant_dtype=torch.float8_e4m3fn,
|
||||
per_out_ch_quant=False,
|
||||
per_act_token_quant=False,
|
||||
block_shape=None),
|
||||
# per-tensor weights and per-token activations
|
||||
FusedMoEQuantConfig(quant_dtype=torch.float8_e4m3fn,
|
||||
per_out_ch_quant=False,
|
||||
per_act_token_quant=True,
|
||||
block_shape=None),
|
||||
TestMoEQuantConfig(quant_dtype=torch.float8_e4m3fn,
|
||||
per_out_ch_quant=False,
|
||||
per_act_token_quant=True,
|
||||
block_shape=None),
|
||||
# block-quantized weights and 128 block per-token activations
|
||||
FusedMoEQuantConfig(quant_dtype=torch.float8_e4m3fn,
|
||||
per_out_ch_quant=False,
|
||||
per_act_token_quant=False,
|
||||
block_shape=[128, 128]),
|
||||
TestMoEQuantConfig(quant_dtype=torch.float8_e4m3fn,
|
||||
per_out_ch_quant=False,
|
||||
per_act_token_quant=False,
|
||||
block_shape=[128, 128]),
|
||||
# TODO (varun) : Should we test the following combinations ?
|
||||
# block-quantized weights and per-token activations
|
||||
# block-quantized weights and per-tensor activations
|
||||
@ -346,33 +354,27 @@ MK_QUANT_CONFIGS = [
|
||||
|
||||
if cutlass_fp4_supported() or has_flashinfer_cutlass_fused_moe():
|
||||
MK_QUANT_CONFIGS += [
|
||||
FusedMoEQuantConfig(quant_dtype="nvfp4",
|
||||
per_out_ch_quant=False,
|
||||
per_act_token_quant=False,
|
||||
block_shape=None),
|
||||
TestMoEQuantConfig(quant_dtype="nvfp4",
|
||||
per_out_ch_quant=False,
|
||||
per_act_token_quant=False,
|
||||
block_shape=None),
|
||||
]
|
||||
|
||||
|
||||
def _make_gscale(num_experts: int) -> torch.Tensor:
|
||||
return torch.ones((num_experts, ),
|
||||
device=torch.cuda.current_device(),
|
||||
dtype=torch.float32)
|
||||
|
||||
|
||||
def make_prepare_finalize(
|
||||
prepare_finalize_type: mk.FusedMoEPrepareAndFinalize,
|
||||
backend: Optional[str],
|
||||
moe: FusedMoEConfig,
|
||||
quant_config: FusedMoEQuantConfig,
|
||||
) -> mk.FusedMoEPrepareAndFinalize:
|
||||
if backend != "naive" and backend is not None:
|
||||
prepare_finalize = FusedMoEMethodBase._maybe_make_prepare_finalize(moe)
|
||||
prepare_finalize = FusedMoEMethodBase._maybe_make_prepare_finalize(
|
||||
moe, quant_config)
|
||||
assert prepare_finalize is not None
|
||||
return prepare_finalize
|
||||
elif prepare_finalize_type == FlashInferCutlassMoEPrepareAndFinalize:
|
||||
return FlashInferCutlassMoEPrepareAndFinalize(
|
||||
use_dp=moe.moe_parallel_config.dp_size > 1,
|
||||
a1_gscale=_make_gscale(moe.num_local_experts),
|
||||
)
|
||||
use_dp=moe.moe_parallel_config.dp_size > 1)
|
||||
else:
|
||||
return MoEPrepareAndFinalizeNoEP()
|
||||
|
||||
@ -383,34 +385,39 @@ def _slice(rank: int, num_local_experts: int, t: torch.Tensor) -> torch.Tensor:
|
||||
return t[s:e]
|
||||
|
||||
|
||||
def make_cutlass_strides(
|
||||
e: int,
|
||||
n: int,
|
||||
k: int,
|
||||
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
ab_strides1 = torch.full((e, ), k, device="cuda", dtype=torch.int64)
|
||||
ab_strides2 = torch.full((e, ), n, device="cuda", dtype=torch.int64)
|
||||
c_strides1 = torch.full((e, ), 2 * n, device="cuda", dtype=torch.int64)
|
||||
c_strides2 = torch.full((e, ), k, device="cuda", dtype=torch.int64)
|
||||
return ab_strides1, ab_strides2, c_strides1, c_strides2
|
||||
|
||||
|
||||
def make_fused_experts(
|
||||
fused_experts_type: mk.FusedMoEPermuteExpertsUnpermute,
|
||||
moe: FusedMoEConfig,
|
||||
quant_config: FusedMoEQuantConfig,
|
||||
num_dispatchers: int,
|
||||
w1_gs: Optional[torch.Tensor],
|
||||
w2_gs: Optional[torch.Tensor],
|
||||
N: int,
|
||||
) -> mk.FusedMoEPermuteExpertsUnpermute:
|
||||
|
||||
use_fp8 = moe.quant_dtype == torch.float8_e4m3fn
|
||||
batch_kwargs = {
|
||||
"max_num_tokens": moe.max_num_tokens,
|
||||
"num_dispatchers": num_dispatchers,
|
||||
}
|
||||
quant_kwargs = {
|
||||
"use_fp8_w8a8": use_fp8,
|
||||
"use_int8_w8a8": False,
|
||||
"use_int8_w8a16": False,
|
||||
"use_int4_w4a16": False,
|
||||
"block_shape": moe.block_shape,
|
||||
"per_act_token_quant": moe.per_act_token_quant,
|
||||
"quant_config": quant_config,
|
||||
}
|
||||
deepgemm_kwargs = {"allow_deep_gemm": has_deep_gemm()}
|
||||
|
||||
torch.set_printoptions(threshold=0, edgeitems=0, linewidth=10000)
|
||||
|
||||
if fused_experts_type == BatchedDeepGemmExperts:
|
||||
kwargs = batch_kwargs | {
|
||||
"block_shape": moe.block_shape,
|
||||
"per_act_token_quant": moe.per_act_token_quant,
|
||||
}
|
||||
kwargs = batch_kwargs | quant_kwargs
|
||||
print(f"Making BatchedDeepGemmExperts {kwargs} ...")
|
||||
experts = BatchedDeepGemmExperts(**kwargs)
|
||||
elif fused_experts_type == BatchedTritonExperts:
|
||||
@ -422,8 +429,8 @@ def make_fused_experts(
|
||||
print(f"Making BatchedTritonOrDeepGemmExperts {kwargs} ...")
|
||||
experts = BatchedTritonOrDeepGemmExperts(**kwargs)
|
||||
elif fused_experts_type == DeepGemmExperts:
|
||||
print("Making DeepGemmExperts () ...")
|
||||
experts = DeepGemmExperts()
|
||||
print("Making DeepGemmExperts {quant_config} ...")
|
||||
experts = DeepGemmExperts(quant_config)
|
||||
elif fused_experts_type == TritonExperts:
|
||||
kwargs = quant_kwargs
|
||||
print(f"Making TritonExperts {kwargs} ...")
|
||||
@ -437,62 +444,50 @@ def make_fused_experts(
|
||||
print(f"Making NaiveBatchedExperts {kwargs} ...")
|
||||
experts = NaiveBatchedExperts(**kwargs)
|
||||
elif fused_experts_type == CutlassExpertsFp8:
|
||||
strides = make_cutlass_strides(moe.num_experts, N, moe.hidden_dim)
|
||||
kwargs = {
|
||||
"out_dtype": moe.in_dtype,
|
||||
"per_act_token_quant": moe.per_act_token_quant,
|
||||
"per_out_ch_quant": moe.per_out_ch_quant,
|
||||
"block_shape": moe.block_shape,
|
||||
}
|
||||
"ab_strides1": strides[0],
|
||||
"ab_strides2": strides[1],
|
||||
"c_strides1": strides[2],
|
||||
"c_strides2": strides[3],
|
||||
} | quant_kwargs
|
||||
print(f"Making CutlassExpertsFp8 {kwargs} ...")
|
||||
experts = CutlassExpertsFp8(**kwargs)
|
||||
elif fused_experts_type == CutlassBatchedExpertsFp8:
|
||||
strides = make_cutlass_strides(moe.num_experts, N, moe.hidden_dim)
|
||||
kwargs = {
|
||||
"max_experts_per_worker": moe.num_local_experts,
|
||||
"num_dispatchers": num_dispatchers,
|
||||
"out_dtype": moe.in_dtype,
|
||||
"per_act_token_quant": moe.per_act_token_quant,
|
||||
"per_out_ch_quant": moe.per_out_ch_quant,
|
||||
"block_shape": moe.block_shape,
|
||||
}
|
||||
"ab_strides1": strides[0],
|
||||
"ab_strides2": strides[1],
|
||||
"c_strides1": strides[2],
|
||||
"c_strides2": strides[3],
|
||||
} | quant_kwargs
|
||||
print(f"Making CutlassBatchedExpertsFp8 {kwargs} ...")
|
||||
experts = CutlassBatchedExpertsFp8(**kwargs)
|
||||
elif fused_experts_type == CutlassExpertsFp4:
|
||||
assert w1_gs is not None and w2_gs is not None
|
||||
num_experts = moe.num_local_experts
|
||||
rank = moe.moe_parallel_config.dp_rank
|
||||
kwargs = {
|
||||
"g1_alphas": _slice(rank, num_experts, (1 / w1_gs)),
|
||||
"g2_alphas": _slice(rank, num_experts, (1 / w2_gs)),
|
||||
"a1_gscale": _make_gscale(num_experts),
|
||||
"a2_gscale": _make_gscale(num_experts),
|
||||
"max_experts_per_worker": num_experts,
|
||||
"out_dtype": moe.in_dtype,
|
||||
"per_act_token_quant": moe.per_act_token_quant,
|
||||
"per_out_ch_quant": moe.per_out_ch_quant,
|
||||
"block_shape": moe.block_shape,
|
||||
"max_experts_per_worker": moe.num_local_experts,
|
||||
"num_dispatchers": num_dispatchers,
|
||||
}
|
||||
"out_dtype": moe.in_dtype,
|
||||
} | quant_kwargs
|
||||
print(f"Making CutlassExpertsFp4 {kwargs} ...")
|
||||
experts = CutlassExpertsFp4(**kwargs)
|
||||
elif fused_experts_type == FlashInferExperts:
|
||||
assert w1_gs is not None and w2_gs is not None
|
||||
num_experts = moe.num_local_experts
|
||||
rank = moe.moe_parallel_config.dp_rank
|
||||
kwargs = {
|
||||
"g1_alphas": _slice(rank, num_experts, (1 / w1_gs)),
|
||||
"g2_alphas": _slice(rank, num_experts, (1 / w2_gs)),
|
||||
"a1_gscale": _make_gscale(num_experts),
|
||||
"a2_gscale": _make_gscale(num_experts),
|
||||
"out_dtype": moe.in_dtype,
|
||||
"quant_dtype": "nvfp4",
|
||||
"ep_rank": moe.ep_rank,
|
||||
"ep_size": moe.ep_size,
|
||||
"tp_rank": moe.tp_rank,
|
||||
"tp_size": moe.tp_size,
|
||||
}
|
||||
} | quant_kwargs
|
||||
print(f"Making FlashInferExperts {kwargs} ...")
|
||||
experts = FlashInferExperts(**kwargs)
|
||||
else:
|
||||
raise RuntimeError(f"Unknown fused experts type: {fused_experts_type}")
|
||||
|
||||
torch.set_printoptions(threshold=1000, edgeitems=5, linewidth=80)
|
||||
|
||||
return experts
|
||||
|
||||
@ -6,6 +6,8 @@ import torch
|
||||
|
||||
from vllm.model_executor.layers.fused_moe.batched_deep_gemm_moe import (
|
||||
BatchedDeepGemmExperts)
|
||||
from vllm.model_executor.layers.fused_moe.config import (
|
||||
fp8_w8a8_moe_quant_config)
|
||||
from vllm.model_executor.layers.fused_moe.fused_batched_moe import (
|
||||
BatchedPrepareAndFinalize, BatchedTritonExperts)
|
||||
from vllm.model_executor.layers.fused_moe.modular_kernel import (
|
||||
@ -56,13 +58,18 @@ def test_batched_deepgemm_vs_triton(E: int, T: int, K: int, N: int, topk: int,
|
||||
rank=0,
|
||||
)
|
||||
|
||||
quant_config = fp8_w8a8_moe_quant_config(
|
||||
w1_scale=w1_s,
|
||||
w2_scale=w2_s,
|
||||
per_act_token_quant=False,
|
||||
block_shape=BLOCK_SIZE,
|
||||
)
|
||||
|
||||
# triton (reference)
|
||||
triton_experts = BatchedTritonExperts(
|
||||
max_num_tokens=max_num_tokens,
|
||||
num_dispatchers=1,
|
||||
use_fp8_w8a8=True,
|
||||
per_act_token_quant=False,
|
||||
block_shape=BLOCK_SIZE,
|
||||
quant_config=quant_config,
|
||||
)
|
||||
mk_triton = FusedMoEModularKernel(prep_finalize, triton_experts)
|
||||
|
||||
@ -73,8 +80,6 @@ def test_batched_deepgemm_vs_triton(E: int, T: int, K: int, N: int, topk: int,
|
||||
topk_weights=topk_weights,
|
||||
topk_ids=topk_ids,
|
||||
inplace=False,
|
||||
w1_scale=w1_s,
|
||||
w2_scale=w2_s,
|
||||
global_num_experts=E,
|
||||
)
|
||||
|
||||
@ -82,8 +87,7 @@ def test_batched_deepgemm_vs_triton(E: int, T: int, K: int, N: int, topk: int,
|
||||
deepgemm_experts = BatchedDeepGemmExperts(
|
||||
max_num_tokens=max_num_tokens,
|
||||
num_dispatchers=1,
|
||||
block_shape=BLOCK_SIZE,
|
||||
per_act_token_quant=False,
|
||||
quant_config=quant_config,
|
||||
)
|
||||
mk_deepgemm = FusedMoEModularKernel(prep_finalize, deepgemm_experts)
|
||||
|
||||
@ -94,8 +98,6 @@ def test_batched_deepgemm_vs_triton(E: int, T: int, K: int, N: int, topk: int,
|
||||
topk_weights=topk_weights,
|
||||
topk_ids=topk_ids,
|
||||
inplace=False,
|
||||
w1_scale=w1_s,
|
||||
w2_scale=w2_s,
|
||||
global_num_experts=E,
|
||||
)
|
||||
|
||||
|
||||
@ -140,7 +140,7 @@ def test_batched_mm(num_experts: int, max_tokens_per_expert: int, K: int,
|
||||
in_dtype=act_dtype,
|
||||
quant_dtype=quant_dtype,
|
||||
block_shape=block_shape,
|
||||
per_act_token_quant=per_act_token_quant,
|
||||
per_out_ch_quant=per_act_token_quant,
|
||||
)
|
||||
|
||||
out_shape = (num_experts, max_tokens_per_expert, N)
|
||||
@ -250,7 +250,7 @@ def test_fused_moe_batched_experts(
|
||||
block_shape=block_shape,
|
||||
in_dtype=act_dtype,
|
||||
quant_dtype=quant_dtype,
|
||||
per_act_token_quant=per_act_token_quant,
|
||||
per_out_ch_quant=per_act_token_quant,
|
||||
)
|
||||
|
||||
if input_scales and quant_dtype is not None:
|
||||
|
||||
@ -4,7 +4,7 @@
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from tests.kernels.moe.utils import make_test_weights
|
||||
from tests.kernels.moe.utils import make_test_quant_config, make_test_weights
|
||||
from tests.kernels.quant_utils import (native_per_token_group_quant_fp8,
|
||||
native_w8a8_block_matmul)
|
||||
from vllm.config import VllmConfig, set_current_vllm_config
|
||||
@ -161,22 +161,17 @@ def test_w8a8_block_fp8_fused_moe(M, N, K, E, topk, block_size, dtype, seed,
|
||||
a = torch.randn((M, K), dtype=dtype) / 10
|
||||
score = torch.randn((M, E), dtype=dtype)
|
||||
|
||||
(_, w1, w1_s, _), (_, w2, w2_s,
|
||||
_) = make_test_weights(E,
|
||||
N,
|
||||
K,
|
||||
dtype,
|
||||
torch.float8_e4m3fn,
|
||||
per_act_token_quant=False,
|
||||
block_shape=block_size)
|
||||
w1, w2, quant_config = make_test_quant_config(
|
||||
E,
|
||||
N,
|
||||
K,
|
||||
dtype,
|
||||
quant_dtype=torch.float8_e4m3fn,
|
||||
per_act_token_quant=False,
|
||||
block_shape=block_size,
|
||||
)
|
||||
|
||||
m_fused_moe = modular_triton_fused_moe(use_fp8_w8a8=True,
|
||||
use_int8_w8a8=False,
|
||||
use_int8_w8a16=False,
|
||||
use_int4_w4a16=False,
|
||||
use_mxfp4_w4a4=False,
|
||||
per_act_token_quant=False,
|
||||
block_shape=block_size)
|
||||
m_fused_moe = modular_triton_fused_moe(quant_config)
|
||||
|
||||
topk_weights, topk_ids, _ = fused_topk(a, score.float(), topk, False)
|
||||
|
||||
@ -186,37 +181,24 @@ def test_w8a8_block_fp8_fused_moe(M, N, K, E, topk, block_size, dtype, seed,
|
||||
a,
|
||||
w1,
|
||||
w2,
|
||||
w1_s,
|
||||
w2_s,
|
||||
quant_config.w1_scale,
|
||||
quant_config.w2_scale,
|
||||
topk_weights,
|
||||
topk_ids,
|
||||
block_size,
|
||||
)
|
||||
|
||||
out = fused_experts(
|
||||
a,
|
||||
w1,
|
||||
w2,
|
||||
topk_weights,
|
||||
topk_ids,
|
||||
use_fp8_w8a8=True,
|
||||
w1_scale=w1_s,
|
||||
w2_scale=w2_s,
|
||||
block_shape=block_size,
|
||||
)
|
||||
out = fused_experts(a,
|
||||
w1,
|
||||
w2,
|
||||
topk_weights,
|
||||
topk_ids,
|
||||
quant_config=quant_config)
|
||||
|
||||
m_out = m_fused_moe(
|
||||
a,
|
||||
w1,
|
||||
w2,
|
||||
topk_weights,
|
||||
topk_ids,
|
||||
w1_scale=w1_s,
|
||||
w2_scale=w2_s,
|
||||
)
|
||||
m_out = m_fused_moe(a, w1, w2, topk_weights, topk_ids)
|
||||
|
||||
# 0.039 only needed for [40000-4608-7168-2-1-block_size852-dtype852-0]
|
||||
tol = 0.035 if M < 40000 else 0.039
|
||||
# 0.039 only needed for M >= 8192
|
||||
tol = 0.035 if M < 8192 else 0.039
|
||||
torch.testing.assert_close(out, ref_out, atol=tol, rtol=tol)
|
||||
torch.testing.assert_close(m_out, ref_out, atol=tol, rtol=tol)
|
||||
|
||||
@ -248,14 +230,15 @@ def test_w8a8_block_fp8_deep_gemm_fused_moe(M, N, K, E, topk, seed,
|
||||
a = torch.randn((M, K), dtype=dtype) / 10
|
||||
score = torch.randn((M, E), dtype=dtype)
|
||||
|
||||
(_, w1, w1_s, _), (_, w2, w2_s,
|
||||
_) = make_test_weights(E,
|
||||
N,
|
||||
K,
|
||||
dtype,
|
||||
torch.float8_e4m3fn,
|
||||
per_act_token_quant=False,
|
||||
block_shape=block_size)
|
||||
(_, w1, w1_s, _), (_, w2, w2_s, _) = make_test_weights(
|
||||
E,
|
||||
N,
|
||||
K,
|
||||
dtype,
|
||||
torch.float8_e4m3fn,
|
||||
per_out_ch_quant=False,
|
||||
block_shape=block_size,
|
||||
)
|
||||
|
||||
# Note: for now use_compile will error out if the problem size is
|
||||
# large enough to trigger chunking. I'm leaving the flag and
|
||||
|
||||
@ -4,12 +4,12 @@
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from tests.kernels.moe.utils import make_test_weights
|
||||
from tests.kernels.moe.utils import make_test_quant_config
|
||||
from tests.kernels.quant_utils import (native_per_token_group_quant_int8,
|
||||
native_w8a8_block_matmul)
|
||||
from vllm.config import VllmConfig, set_current_vllm_config
|
||||
from vllm.model_executor.layers.activation import SiluAndMul
|
||||
from vllm.model_executor.layers.fused_moe import fused_moe
|
||||
from vllm.model_executor.layers.fused_moe import fused_experts, fused_topk
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
if current_platform.get_device_capability() < (7, 0):
|
||||
@ -50,7 +50,7 @@ MNK_FACTORS = [
|
||||
(2048, 128, 128),
|
||||
(2048, 1024, 7168),
|
||||
(2048, 4096, 512),
|
||||
(2048, 4096, 7168),
|
||||
(2048, 4096, 4096),
|
||||
]
|
||||
|
||||
E = [8, 24]
|
||||
@ -117,31 +117,28 @@ def test_w8a8_block_int8_fused_moe(M, N, K, E, topk, block_size, dtype, seed):
|
||||
|
||||
a = torch.randn((M, K), dtype=dtype) / 10
|
||||
score = torch.randn((M, E), dtype=dtype)
|
||||
topk_weights, topk_ids, _ = fused_topk(a, score.float(), topk, False)
|
||||
|
||||
(_, w1, w1_s, _), (_, w2, w2_s,
|
||||
_) = make_test_weights(E,
|
||||
N,
|
||||
K,
|
||||
dtype,
|
||||
torch.int8,
|
||||
per_act_token_quant=False,
|
||||
block_shape=block_size)
|
||||
w1, w2, quant_config = make_test_quant_config(
|
||||
E,
|
||||
N,
|
||||
K,
|
||||
dtype,
|
||||
quant_dtype=torch.int8,
|
||||
per_act_token_quant=False,
|
||||
block_shape=block_size,
|
||||
)
|
||||
|
||||
# Set the context to avoid lots of warning spam.
|
||||
with set_current_vllm_config(vllm_config):
|
||||
out = fused_moe(
|
||||
a,
|
||||
w1,
|
||||
w2,
|
||||
score,
|
||||
topk,
|
||||
renormalize=False,
|
||||
use_int8_w8a8=True,
|
||||
w1_scale=w1_s,
|
||||
w2_scale=w2_s,
|
||||
block_shape=block_size,
|
||||
)
|
||||
ref_out = torch_w8a8_block_int8_moe(a, w1, w2, w1_s, w2_s, score, topk,
|
||||
out = fused_experts(a,
|
||||
w1,
|
||||
w2,
|
||||
topk_weights,
|
||||
topk_ids,
|
||||
quant_config=quant_config)
|
||||
ref_out = torch_w8a8_block_int8_moe(a, w1, w2, quant_config.w1_scale,
|
||||
quant_config.w2_scale, score, topk,
|
||||
block_size)
|
||||
|
||||
# Check results
|
||||
|
||||
@ -1,5 +1,6 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import copy
|
||||
import dataclasses
|
||||
from math import prod
|
||||
from typing import Optional
|
||||
@ -9,6 +10,8 @@ import torch
|
||||
|
||||
from vllm import _custom_ops as ops
|
||||
from vllm.config import ParallelConfig, VllmConfig, set_current_vllm_config
|
||||
from vllm.model_executor.layers.fused_moe.config import (
|
||||
FUSED_MOE_UNQUANTIZED_CONFIG, fp8_w8a8_moe_quant_config)
|
||||
from vllm.model_executor.layers.fused_moe.cutlass_moe import (
|
||||
cutlass_moe_fp8, run_cutlass_moe_fp8)
|
||||
from vllm.model_executor.layers.fused_moe.fused_moe import (fused_experts,
|
||||
@ -154,7 +157,7 @@ def run_with_expert_maps(num_experts: int, num_local_experts: int,
|
||||
def slice_experts():
|
||||
slice_params = [
|
||||
"w1_q", "w2_q", "ab_strides1", "ab_strides2", "c_strides1",
|
||||
"c_strides2", "w1_scale", "w2_scale"
|
||||
"c_strides2"
|
||||
]
|
||||
full_tensors = {
|
||||
k: v
|
||||
@ -162,6 +165,8 @@ def run_with_expert_maps(num_experts: int, num_local_experts: int,
|
||||
if k in slice_params and k in cutlass_moe_kwargs
|
||||
}
|
||||
|
||||
quant_config = cutlass_moe_kwargs["quant_config"]
|
||||
|
||||
for i in range(0, num_experts, num_local_experts):
|
||||
s, e = i, i + num_local_experts
|
||||
|
||||
@ -178,6 +183,12 @@ def run_with_expert_maps(num_experts: int, num_local_experts: int,
|
||||
for k, t in full_tensors.items():
|
||||
cutlass_moe_kwargs[k] = t[s:e]
|
||||
|
||||
new_quant_config = copy.deepcopy(quant_config)
|
||||
new_quant_config._w1.scale = quant_config.w1_scale[s:e]
|
||||
new_quant_config._w2.scale = quant_config.w2_scale[s:e]
|
||||
|
||||
cutlass_moe_kwargs["quant_config"] = new_quant_config
|
||||
|
||||
yield cutlass_moe_kwargs
|
||||
|
||||
out_tensor = torch.zeros_like(cutlass_moe_kwargs["a"])
|
||||
@ -191,6 +202,7 @@ def run_8_bit(moe_tensors: MOETensors8Bit,
|
||||
topk_weights: torch.Tensor,
|
||||
topk_ids: torch.Tensor,
|
||||
per_act_token: bool,
|
||||
per_out_ch: bool,
|
||||
num_local_experts: Optional[int] = None) -> torch.Tensor:
|
||||
assert not any([
|
||||
t is None for t in [
|
||||
@ -199,20 +211,27 @@ def run_8_bit(moe_tensors: MOETensors8Bit,
|
||||
]
|
||||
])
|
||||
|
||||
quant_config = fp8_w8a8_moe_quant_config(
|
||||
w1_scale=moe_tensors.w1_scale,
|
||||
w2_scale=moe_tensors.w2_scale,
|
||||
per_act_token_quant=per_act_token,
|
||||
per_out_ch_quant=per_out_ch,
|
||||
# Set to moe_tensors.a_scale iff static scales + per tensor.
|
||||
# This is not currently being tested.
|
||||
a1_scale=None,
|
||||
)
|
||||
|
||||
kwargs = {
|
||||
'a': moe_tensors.a,
|
||||
'w1_q': moe_tensors.w1_q, # type: ignore[union-attr]
|
||||
'w2_q': moe_tensors.w2_q, # type: ignore[union-attr]
|
||||
'topk_weights': topk_weights,
|
||||
'topk_ids': topk_ids,
|
||||
'w1_scale': moe_tensors.w1_scale,
|
||||
'w2_scale': moe_tensors.w2_scale,
|
||||
'ab_strides1': moe_tensors.ab_strides1,
|
||||
'ab_strides2': moe_tensors.ab_strides2,
|
||||
'c_strides1': moe_tensors.c_strides1,
|
||||
'c_strides2': moe_tensors.c_strides2,
|
||||
'per_act_token': per_act_token,
|
||||
'a1_scale': None #moe_tensors.a_scale
|
||||
'quant_config': quant_config,
|
||||
}
|
||||
|
||||
num_experts = moe_tensors.w1.size(0)
|
||||
@ -261,16 +280,23 @@ def test_cutlass_moe_8_bit_no_graph(
|
||||
|
||||
# Note that we are using the dequantized versions of the tensors.
|
||||
# Using a, w1 and w2 directly results in minor output differences.
|
||||
triton_output = fused_experts(mt.a_d, mt.w1_d, mt.w2_d, topk_weights,
|
||||
topk_ids)
|
||||
|
||||
quant_config = FUSED_MOE_UNQUANTIZED_CONFIG
|
||||
triton_output = fused_experts(mt.a_d,
|
||||
mt.w1_d,
|
||||
mt.w2_d,
|
||||
topk_weights,
|
||||
topk_ids,
|
||||
quant_config=quant_config)
|
||||
|
||||
if ep_size is not None:
|
||||
assert e % ep_size == 0, "Cannot distribute experts evenly"
|
||||
number_local_experts = e // ep_size
|
||||
else:
|
||||
number_local_experts = None
|
||||
|
||||
cutlass_output = run_8_bit(mt, topk_weights, topk_ids, per_act_token,
|
||||
number_local_experts)
|
||||
per_out_ch, number_local_experts)
|
||||
|
||||
# Note 5.5 only needed for larger problem sizes, 5 works ok for
|
||||
# the rest.
|
||||
@ -315,14 +341,19 @@ def test_cutlass_moe_8_bit_cuda_graph(
|
||||
|
||||
# Note that we are using the dequantized versions of the tensors.
|
||||
# Using a, w1 and w2 directly results in minor output differences.
|
||||
triton_output = fused_experts(mt.a_d, mt.w1_d, mt.w2_d, topk_weights,
|
||||
topk_ids)
|
||||
quant_config = FUSED_MOE_UNQUANTIZED_CONFIG
|
||||
triton_output = fused_experts(mt.a_d,
|
||||
mt.w1_d,
|
||||
mt.w2_d,
|
||||
topk_weights,
|
||||
topk_ids,
|
||||
quant_config=quant_config)
|
||||
|
||||
stream = torch.cuda.Stream()
|
||||
graph = torch.cuda.CUDAGraph()
|
||||
with torch.cuda.graph(graph, stream=stream):
|
||||
cutlass_output = run_8_bit(mt, topk_weights, topk_ids,
|
||||
per_act_token)
|
||||
per_act_token, per_out_ch)
|
||||
|
||||
torch.cuda.synchronize()
|
||||
graph.replay()
|
||||
|
||||
@ -15,6 +15,8 @@ from torch.distributed import ProcessGroup
|
||||
from typing_extensions import ParamSpec
|
||||
|
||||
from vllm.config import VllmConfig, set_current_vllm_config
|
||||
from vllm.model_executor.layers.fused_moe.config import (
|
||||
FusedMoEQuantConfig, fp8_w8a8_moe_quant_config)
|
||||
from vllm.model_executor.layers.fused_moe.fused_moe import fused_experts
|
||||
from vllm.model_executor.layers.fused_moe.modular_kernel import (
|
||||
FusedMoEModularKernel)
|
||||
@ -71,9 +73,12 @@ def make_block_quant_fp8_weights(
|
||||
Return weights w1q, w2q, w1_scale, w2_scale
|
||||
"""
|
||||
(_, w1q, w1_scale, _), (_, w2q, w2_scale,
|
||||
_) = make_test_weights(e, n, k, torch.bfloat16,
|
||||
_) = make_test_weights(e,
|
||||
n,
|
||||
k,
|
||||
torch.bfloat16,
|
||||
torch.float8_e4m3fn,
|
||||
block_size)
|
||||
block_shape=block_size)
|
||||
return w1q, w2q, w1_scale, w2_scale
|
||||
|
||||
|
||||
@ -130,10 +135,11 @@ class TestTensors:
|
||||
config=config)
|
||||
|
||||
|
||||
def make_ll_modular_kernel(pg: ProcessGroup, pgi: ProcessGroupInfo,
|
||||
max_tokens_per_rank: int, dp_size: int,
|
||||
hidden_size: int, q_dtype: Optional[torch.dtype],
|
||||
test_config: TestConfig) -> FusedMoEModularKernel:
|
||||
def make_ll_modular_kernel(
|
||||
pg: ProcessGroup, pgi: ProcessGroupInfo, max_tokens_per_rank: int,
|
||||
dp_size: int, hidden_size: int, q_dtype: Optional[torch.dtype],
|
||||
test_config: TestConfig,
|
||||
quant_config: FusedMoEQuantConfig) -> FusedMoEModularKernel:
|
||||
|
||||
assert test_config.low_latency
|
||||
assert test_config.use_fp8_dispatch is not None
|
||||
@ -154,17 +160,18 @@ def make_ll_modular_kernel(pg: ProcessGroup, pgi: ProcessGroupInfo,
|
||||
fused_experts = BatchedDeepGemmExperts(
|
||||
max_num_tokens=max_tokens_per_rank,
|
||||
num_dispatchers=pgi.world_size // dp_size,
|
||||
block_shape=test_config.block_size,
|
||||
per_act_token_quant=test_config.per_act_token_quant)
|
||||
quant_config=quant_config,
|
||||
)
|
||||
mk = FusedMoEModularKernel(prepare_finalize=a2a,
|
||||
fused_experts=fused_experts)
|
||||
return mk
|
||||
|
||||
|
||||
def make_ht_modular_kernel(pg: ProcessGroup, pgi: ProcessGroupInfo,
|
||||
dp_size: int, num_local_experts: int,
|
||||
q_dtype: Optional[torch.dtype],
|
||||
test_config: TestConfig) -> FusedMoEModularKernel:
|
||||
def make_ht_modular_kernel(
|
||||
pg: ProcessGroup, pgi: ProcessGroupInfo, dp_size: int,
|
||||
num_local_experts: int, q_dtype: Optional[torch.dtype],
|
||||
test_config: TestConfig,
|
||||
quant_config: FusedMoEQuantConfig) -> FusedMoEModularKernel:
|
||||
|
||||
assert not test_config.low_latency
|
||||
assert test_config.use_fp8_dispatch is None
|
||||
@ -178,15 +185,16 @@ def make_ht_modular_kernel(pg: ProcessGroup, pgi: ProcessGroupInfo,
|
||||
q_dtype=q_dtype,
|
||||
block_shape=test_config.block_size)
|
||||
|
||||
fused_experts = DeepGemmExperts()
|
||||
fused_experts = DeepGemmExperts(quant_config)
|
||||
mk = FusedMoEModularKernel(prepare_finalize=a2a,
|
||||
fused_experts=fused_experts)
|
||||
return mk
|
||||
|
||||
|
||||
def make_modular_kernel(pg: ProcessGroup, pgi: ProcessGroupInfo, dp_size: int,
|
||||
num_local_experts: int,
|
||||
test_tensors: TestTensors) -> FusedMoEModularKernel:
|
||||
def make_modular_kernel(
|
||||
pg: ProcessGroup, pgi: ProcessGroupInfo, dp_size: int,
|
||||
num_local_experts: int, test_tensors: TestTensors,
|
||||
quant_config: FusedMoEQuantConfig) -> FusedMoEModularKernel:
|
||||
|
||||
q_dtype = torch.float8_e4m3fn
|
||||
test_config = test_tensors.config
|
||||
@ -204,10 +212,16 @@ def make_modular_kernel(pg: ProcessGroup, pgi: ProcessGroupInfo, dp_size: int,
|
||||
dp_size=dp_size,
|
||||
hidden_size=hidden_size,
|
||||
q_dtype=q_dtype,
|
||||
test_config=test_config)
|
||||
test_config=test_config,
|
||||
quant_config=quant_config)
|
||||
else:
|
||||
mk = make_ht_modular_kernel(pg, pgi, dp_size, num_local_experts,
|
||||
q_dtype, test_config)
|
||||
mk = make_ht_modular_kernel(pg,
|
||||
pgi,
|
||||
dp_size,
|
||||
num_local_experts,
|
||||
q_dtype,
|
||||
test_config,
|
||||
quant_config=quant_config)
|
||||
|
||||
return mk
|
||||
|
||||
@ -233,17 +247,23 @@ def deepep_deepgemm_moe_impl(pg: ProcessGroup, pgi: ProcessGroupInfo,
|
||||
return expert_map.to(device=torch.cuda.current_device(),
|
||||
dtype=torch.int32)
|
||||
|
||||
quant_config = fp8_w8a8_moe_quant_config(
|
||||
w1_scale=w1_scale,
|
||||
w2_scale=w2_scale,
|
||||
# Low-Latency kernels can't dispatch scales.
|
||||
a1_scale=(None if test_config.low_latency else
|
||||
test_tensors.rank_token_scales),
|
||||
block_shape=test_config.block_size,
|
||||
)
|
||||
|
||||
# Make modular kernel
|
||||
mk: FusedMoEModularKernel = make_modular_kernel(
|
||||
pg=pg,
|
||||
pgi=pgi,
|
||||
dp_size=dp_size,
|
||||
num_local_experts=num_local_experts,
|
||||
test_tensors=test_tensors)
|
||||
|
||||
# Low-Latency kernels can't dispatch scales.
|
||||
a1_scale = (None
|
||||
if test_config.low_latency else test_tensors.rank_token_scales)
|
||||
test_tensors=test_tensors,
|
||||
quant_config=quant_config)
|
||||
|
||||
out = mk.forward(hidden_states=test_tensors.rank_tokens,
|
||||
w1=w1,
|
||||
@ -254,12 +274,6 @@ def deepep_deepgemm_moe_impl(pg: ProcessGroup, pgi: ProcessGroupInfo,
|
||||
activation="silu",
|
||||
global_num_experts=num_experts,
|
||||
expert_map=build_expert_map(),
|
||||
w1_scale=w1_scale,
|
||||
w2_scale=w2_scale,
|
||||
w1_zp=None,
|
||||
w2_zp=None,
|
||||
a1_scale=a1_scale,
|
||||
a2_scale=None,
|
||||
apply_router_weight_on_input=False)
|
||||
return out
|
||||
|
||||
@ -269,6 +283,13 @@ def triton_impl(a: torch.Tensor, topk_ids: torch.Tensor,
|
||||
w1_scale: torch.Tensor, w2_scale: torch.Tensor,
|
||||
a1_scale: torch.Tensor, block_shape: list[int]):
|
||||
|
||||
quant_config = fp8_w8a8_moe_quant_config(
|
||||
w1_scale=w1_scale,
|
||||
w2_scale=w2_scale,
|
||||
a1_scale=a1_scale,
|
||||
block_shape=block_shape,
|
||||
)
|
||||
|
||||
return fused_experts(
|
||||
hidden_states=a,
|
||||
w1=w1,
|
||||
@ -276,11 +297,7 @@ def triton_impl(a: torch.Tensor, topk_ids: torch.Tensor,
|
||||
topk_weights=topk_weights,
|
||||
topk_ids=topk_ids,
|
||||
inplace=False,
|
||||
use_fp8_w8a8=True,
|
||||
w1_scale=w1_scale,
|
||||
w2_scale=w2_scale,
|
||||
a1_scale=a1_scale,
|
||||
block_shape=block_shape,
|
||||
quant_config=quant_config,
|
||||
# Make sure this is set to False so we
|
||||
# don't end up comparing the same implementation.
|
||||
allow_deep_gemm=False)
|
||||
|
||||
@ -15,6 +15,7 @@ from vllm import _custom_ops as ops
|
||||
from vllm.config import VllmConfig, set_current_vllm_config
|
||||
from vllm.model_executor.layers.activation import SiluAndMul
|
||||
from vllm.model_executor.layers.fused_moe import TritonExperts
|
||||
from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig
|
||||
from vllm.model_executor.layers.fused_moe.fused_batched_moe import (
|
||||
BatchedTritonExperts)
|
||||
from vllm.model_executor.layers.fused_moe.modular_kernel import (
|
||||
@ -129,11 +130,9 @@ def make_modular_kernel(
|
||||
num_local_experts: int,
|
||||
q_dtype: Optional[torch.dtype],
|
||||
use_fp8_dispatch: bool,
|
||||
per_act_token_quant: bool,
|
||||
quant_config: FusedMoEQuantConfig,
|
||||
) -> FusedMoEModularKernel:
|
||||
|
||||
is_quantized = q_dtype is not None
|
||||
|
||||
ht_args: Optional[DeepEPHTArgs] = None
|
||||
ll_args: Optional[DeepEPLLArgs] = None
|
||||
|
||||
@ -159,24 +158,14 @@ def make_modular_kernel(
|
||||
num_dispatchers = pgi.world_size // dp_size
|
||||
|
||||
if low_latency_mode:
|
||||
assert not per_act_token_quant, "not supported in ll mode"
|
||||
assert not quant_config.per_act_token_quant, "not supported in ll mode"
|
||||
fused_experts = BatchedTritonExperts(
|
||||
max_num_tokens=MAX_TOKENS_PER_RANK,
|
||||
num_dispatchers=num_dispatchers,
|
||||
use_fp8_w8a8=is_quantized,
|
||||
use_int8_w8a8=False,
|
||||
use_int8_w8a16=False,
|
||||
use_int4_w4a16=False,
|
||||
per_act_token_quant=False,
|
||||
quant_config=quant_config,
|
||||
)
|
||||
else:
|
||||
fused_experts = TritonExperts(
|
||||
use_fp8_w8a8=is_quantized,
|
||||
use_int8_w8a8=False,
|
||||
use_int8_w8a16=False,
|
||||
use_int4_w4a16=False,
|
||||
per_act_token_quant=per_act_token_quant,
|
||||
)
|
||||
fused_experts = TritonExperts(quant_config=quant_config)
|
||||
|
||||
mk = FusedMoEModularKernel(prepare_finalize=a2a,
|
||||
fused_experts=fused_experts)
|
||||
@ -217,11 +206,6 @@ def deep_ep_moe_impl(
|
||||
if is_quantized:
|
||||
q_dtype = torch.float8_e4m3fn
|
||||
|
||||
# Make modular kernel
|
||||
mk: FusedMoEModularKernel = make_modular_kernel(
|
||||
pg, pgi, low_latency_mode, hidden_size, dp_size, num_experts,
|
||||
num_local_experts, q_dtype, use_fp8_dispatch, per_act_token_quant)
|
||||
|
||||
out_hidden_states = torch.empty_like(test_tensors.rank_tokens)
|
||||
total_num_tokens = test_tensors.rank_tokens.size(0)
|
||||
|
||||
@ -236,6 +220,19 @@ def deep_ep_moe_impl(
|
||||
rank_token_scales_chunk = rank_token_scales_chunk[
|
||||
chunk_start:chunk_end]
|
||||
|
||||
quant_config = FusedMoEQuantConfig.make(
|
||||
q_dtype,
|
||||
w1_scale=w1_scale,
|
||||
w2_scale=w2_scale,
|
||||
per_act_token_quant=per_act_token_quant,
|
||||
a1_scale=rank_token_scales_chunk,
|
||||
)
|
||||
|
||||
# Make modular kernel
|
||||
mk: FusedMoEModularKernel = make_modular_kernel(
|
||||
pg, pgi, low_latency_mode, hidden_size, dp_size, num_experts,
|
||||
num_local_experts, q_dtype, use_fp8_dispatch, quant_config)
|
||||
|
||||
out = mk.forward(hidden_states=rank_tokens_chunk,
|
||||
w1=w1,
|
||||
w2=w2,
|
||||
@ -245,12 +242,6 @@ def deep_ep_moe_impl(
|
||||
activation="silu",
|
||||
global_num_experts=num_experts,
|
||||
expert_map=build_expert_map(),
|
||||
w1_scale=w1_scale,
|
||||
w2_scale=w2_scale,
|
||||
w1_zp=None,
|
||||
w2_zp=None,
|
||||
a1_scale=rank_token_scales_chunk,
|
||||
a2_scale=None,
|
||||
apply_router_weight_on_input=False)
|
||||
|
||||
if not skip_result_store:
|
||||
@ -407,7 +398,7 @@ DTYPES = [torch.bfloat16, torch.float8_e4m3fn]
|
||||
|
||||
|
||||
@pytest.mark.parametrize("dtype", DTYPES)
|
||||
@pytest.mark.parametrize("mnk", MNKs)
|
||||
@pytest.mark.parametrize("m,n,k", MNKs)
|
||||
@pytest.mark.parametrize("num_experts", [32])
|
||||
@pytest.mark.parametrize("topk", [6])
|
||||
@pytest.mark.parametrize("world_dp_size", [(2, 1)])
|
||||
@ -416,7 +407,9 @@ DTYPES = [torch.bfloat16, torch.float8_e4m3fn]
|
||||
@requires_deep_ep
|
||||
def test_deep_ep_moe(
|
||||
dtype: torch.dtype,
|
||||
mnk: tuple[int, int, int],
|
||||
m: int,
|
||||
n: int,
|
||||
k: int,
|
||||
num_experts: int,
|
||||
topk: int,
|
||||
world_dp_size: tuple[int, int],
|
||||
@ -424,7 +417,6 @@ def test_deep_ep_moe(
|
||||
):
|
||||
low_latency_mode = False
|
||||
use_fp8_dispatch = False
|
||||
m, n, k = mnk
|
||||
|
||||
current_platform.seed_everything(7)
|
||||
world_size, dp_size = world_dp_size
|
||||
@ -456,20 +448,24 @@ USE_FP8_DISPATCH = [True, False]
|
||||
|
||||
|
||||
@pytest.mark.parametrize("dtype", DTYPES)
|
||||
@pytest.mark.parametrize("mnk", MNKs)
|
||||
@pytest.mark.parametrize("m,n,k", MNKs)
|
||||
@pytest.mark.parametrize("num_experts", [32])
|
||||
@pytest.mark.parametrize("topk", [6])
|
||||
@pytest.mark.parametrize("world_dp_size", [(2, 1)])
|
||||
@pytest.mark.parametrize("use_fp8_dispatch", USE_FP8_DISPATCH)
|
||||
@multi_gpu_test(num_gpus=2)
|
||||
@requires_deep_ep
|
||||
def test_low_latency_deep_ep_moe(dtype: torch.dtype, mnk: tuple[int, int, int],
|
||||
num_experts: int, topk: int,
|
||||
world_dp_size: tuple[int, int],
|
||||
use_fp8_dispatch: bool):
|
||||
|
||||
def test_low_latency_deep_ep_moe(
|
||||
dtype: torch.dtype,
|
||||
m: int,
|
||||
n: int,
|
||||
k: int,
|
||||
num_experts: int,
|
||||
topk: int,
|
||||
world_dp_size: tuple[int, int],
|
||||
use_fp8_dispatch: bool,
|
||||
):
|
||||
low_latency_mode = True
|
||||
m, n, k = mnk
|
||||
|
||||
if (low_latency_mode
|
||||
and k not in DeepEPLLPrepareAndFinalize.SUPPORTED_HIDDEN_SIZES):
|
||||
|
||||
@ -11,6 +11,8 @@ import math
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from vllm.model_executor.layers.fused_moe.config import (
|
||||
fp8_w8a8_moe_quant_config)
|
||||
# vLLM fused-expert reference (Triton fallback + DeepGEMM option)
|
||||
from vllm.model_executor.layers.fused_moe.fused_moe import fused_experts
|
||||
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
|
||||
@ -94,6 +96,13 @@ def run_single_case(m, n, k, topk, num_experts, block_size):
|
||||
topk_weights, topk_ids = torch.topk(router_logits, k=topk, dim=-1)
|
||||
topk_weights = torch.nn.functional.softmax(topk_weights, dim=-1)
|
||||
|
||||
quant_config = fp8_w8a8_moe_quant_config(
|
||||
w1_scale=w1_s,
|
||||
w2_scale=w2_s,
|
||||
a1_scale=a1_scale,
|
||||
block_shape=block_size,
|
||||
)
|
||||
|
||||
# triton reference
|
||||
out_triton = fused_experts(
|
||||
hidden_states=tokens_bf16,
|
||||
@ -102,11 +111,7 @@ def run_single_case(m, n, k, topk, num_experts, block_size):
|
||||
topk_weights=topk_weights,
|
||||
topk_ids=topk_ids,
|
||||
inplace=False,
|
||||
use_fp8_w8a8=True,
|
||||
w1_scale=w1_s,
|
||||
w2_scale=w2_s,
|
||||
a1_scale=a1_scale,
|
||||
block_shape=block_size,
|
||||
quant_config=quant_config,
|
||||
allow_deep_gemm=False,
|
||||
)
|
||||
|
||||
@ -118,19 +123,14 @@ def run_single_case(m, n, k, topk, num_experts, block_size):
|
||||
topk_weights=topk_weights,
|
||||
topk_ids=topk_ids,
|
||||
inplace=False,
|
||||
use_fp8_w8a8=True,
|
||||
w1_scale=w1_s,
|
||||
w2_scale=w2_s,
|
||||
a1_scale=a1_scale,
|
||||
block_shape=block_size,
|
||||
quant_config=quant_config,
|
||||
allow_deep_gemm=True,
|
||||
)
|
||||
diff = calc_diff(out_deepgemm, out_triton)
|
||||
assert diff < 0.001, f"Diff exceeded 1%: {diff}"
|
||||
|
||||
|
||||
# Note: W1 has shape (E, 2N, K), so N = 512
|
||||
# can trigger the deepgemm path.
|
||||
# Note: N <= 512 will disable the deepgemm path due to performance issues.
|
||||
MNKs = [
|
||||
(1024, 768, 128),
|
||||
(1024, 768, 512),
|
||||
@ -144,15 +144,15 @@ TOPKS = [2, 6]
|
||||
NUM_EXPERTS = [32]
|
||||
|
||||
|
||||
@pytest.mark.parametrize("mnk", MNKs)
|
||||
@pytest.mark.parametrize(("m", "n", "k"), MNKs)
|
||||
@pytest.mark.parametrize("topk", TOPKS)
|
||||
@pytest.mark.parametrize("num_experts", NUM_EXPERTS)
|
||||
@pytest.mark.skipif(not is_deep_gemm_supported(),
|
||||
reason="Requires deep_gemm kernels")
|
||||
def test_deepgemm_vs_triton(mnk, topk, num_experts, monkeypatch):
|
||||
def test_deepgemm_vs_triton(m, n, k, topk, num_experts, monkeypatch):
|
||||
|
||||
with monkeypatch.context() as m:
|
||||
m.setenv("VLLM_USE_DEEP_GEMM", "1")
|
||||
with monkeypatch.context() as mp:
|
||||
mp.setenv("VLLM_USE_DEEP_GEMM", "1")
|
||||
|
||||
_fused_moe_mod = importlib.import_module(
|
||||
"vllm.model_executor.layers.fused_moe.fused_moe")
|
||||
@ -168,8 +168,6 @@ def test_deepgemm_vs_triton(mnk, topk, num_experts, monkeypatch):
|
||||
monkeypatch.setattr(_fused_moe_mod, "deep_gemm_moe_fp8",
|
||||
_spy_deep_gemm_moe_fp8)
|
||||
|
||||
m, n, k = mnk
|
||||
|
||||
if topk > num_experts:
|
||||
pytest.skip(f"topk={topk} > num_experts={num_experts}")
|
||||
|
||||
|
||||
@ -6,6 +6,8 @@ import pytest
|
||||
import torch
|
||||
|
||||
from vllm.config import ParallelConfig, VllmConfig, set_current_vllm_config
|
||||
from vllm.model_executor.layers.fused_moe.config import (
|
||||
fp8_w8a8_moe_quant_config)
|
||||
from vllm.model_executor.layers.fused_moe.fused_moe import fused_experts
|
||||
from vllm.model_executor.layers.fused_moe.layer import FusedMoE
|
||||
from vllm.model_executor.layers.quantization.utils.flashinfer_utils import (
|
||||
@ -145,6 +147,14 @@ def test_flashinfer_per_tensor_moe_fp8_no_graph(
|
||||
custom_routing_function=Llama4MoE.custom_routing_function,
|
||||
scoring_func="softmax")
|
||||
|
||||
quant_config = fp8_w8a8_moe_quant_config(
|
||||
w1_scale=td.w13_weight_scale,
|
||||
w2_scale=td.w2_weight_scale,
|
||||
a1_scale=td.a1_scale,
|
||||
a2_scale=td.a2_scale,
|
||||
per_act_token_quant=False,
|
||||
)
|
||||
|
||||
output = fused_experts(
|
||||
td.hidden_states,
|
||||
td.w13_quantized,
|
||||
@ -153,15 +163,10 @@ def test_flashinfer_per_tensor_moe_fp8_no_graph(
|
||||
topk_ids=topk_ids,
|
||||
inplace=False,
|
||||
activation="silu",
|
||||
use_fp8_w8a8=True,
|
||||
per_channel_quant=False,
|
||||
global_num_experts=e,
|
||||
expert_map=None,
|
||||
w1_scale=td.w13_weight_scale,
|
||||
w2_scale=td.w2_weight_scale,
|
||||
a1_scale=td.a1_scale,
|
||||
a2_scale=td.a2_scale,
|
||||
apply_router_weight_on_input=True,
|
||||
quant_config=quant_config,
|
||||
)
|
||||
|
||||
flashinfer_output = apply_flashinfer_per_tensor_scale_fp8(
|
||||
@ -210,6 +215,14 @@ def test_flashinfer_cutlass_moe_fp8_no_graph(
|
||||
custom_routing_function=Llama4MoE.custom_routing_function,
|
||||
scoring_func="softmax")
|
||||
|
||||
quant_config = fp8_w8a8_moe_quant_config(
|
||||
w1_scale=td.w13_weight_scale,
|
||||
w2_scale=td.w2_weight_scale,
|
||||
a1_scale=td.a1_scale,
|
||||
a2_scale=td.a2_scale,
|
||||
per_act_token_quant=False,
|
||||
)
|
||||
|
||||
output = fused_experts(
|
||||
td.hidden_states,
|
||||
td.w13_quantized,
|
||||
@ -218,15 +231,10 @@ def test_flashinfer_cutlass_moe_fp8_no_graph(
|
||||
topk_ids=topk_ids,
|
||||
inplace=False,
|
||||
activation="silu",
|
||||
use_fp8_w8a8=True,
|
||||
per_channel_quant=False,
|
||||
global_num_experts=e,
|
||||
expert_map=None,
|
||||
w1_scale=td.w13_weight_scale,
|
||||
w2_scale=td.w2_weight_scale,
|
||||
a1_scale=td.a1_scale,
|
||||
a2_scale=td.a2_scale,
|
||||
apply_router_weight_on_input=True,
|
||||
quant_config=quant_config,
|
||||
)
|
||||
|
||||
td.layer.dp_size = 1
|
||||
|
||||
@ -3,7 +3,7 @@
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from tests.kernels.moe.utils import make_test_weights
|
||||
from tests.kernels.moe.utils import make_test_quant_config
|
||||
from tests.kernels.quantization.nvfp4_utils import (FLOAT4_E2M1_MAX,
|
||||
FLOAT8_E4M3_MAX,
|
||||
dequantize_nvfp4_to_dtype)
|
||||
@ -41,7 +41,6 @@ MNK_FACTORS = [
|
||||
|
||||
@pytest.mark.parametrize("m,n,k", MNK_FACTORS)
|
||||
@pytest.mark.parametrize("e", [40, 64, 256])
|
||||
#@pytest.mark.parametrize("e", [128, 256])
|
||||
@pytest.mark.parametrize("topk", [1, 6, 8])
|
||||
@pytest.mark.parametrize("dtype", [torch.half, torch.bfloat16])
|
||||
@torch.inference_mode()
|
||||
@ -56,16 +55,15 @@ def test_flashinfer_fp4_moe_no_graph(m: int, n: int, k: int, e: int, topk: int,
|
||||
|
||||
quant_blocksize = 16
|
||||
|
||||
(_, w1_q, w1_blockscale,
|
||||
w1_gs), (_, w2_q, w2_blockscale, w2_gs) = make_test_weights(
|
||||
e,
|
||||
n,
|
||||
k,
|
||||
in_dtype=dtype,
|
||||
quant_dtype="nvfp4",
|
||||
block_shape=None, # use quant_blocksize?
|
||||
per_act_token_quant=False,
|
||||
)
|
||||
w1_q, w2_q, quant_config = make_test_quant_config(
|
||||
e,
|
||||
n,
|
||||
k,
|
||||
in_dtype=dtype,
|
||||
quant_dtype="nvfp4",
|
||||
block_shape=None,
|
||||
per_act_token_quant=False,
|
||||
)
|
||||
|
||||
score = torch.randn((m, e), device="cuda", dtype=dtype)
|
||||
topk_weights, topk_ids, _ = fused_topk(a,
|
||||
@ -73,35 +71,17 @@ def test_flashinfer_fp4_moe_no_graph(m: int, n: int, k: int, e: int, topk: int,
|
||||
topk,
|
||||
renormalize=False)
|
||||
|
||||
a1_gs = torch.ones((e, ), device="cuda", dtype=torch.float32)
|
||||
a2_gs = torch.ones((e, ), device="cuda", dtype=torch.float32)
|
||||
|
||||
assert is_valid_flashinfer_cutlass_fused_moe(a, w1_q, w2_q)
|
||||
|
||||
assert w1_gs is not None
|
||||
assert w2_gs is not None
|
||||
assert w1_blockscale is not None
|
||||
assert w2_blockscale is not None
|
||||
|
||||
flashinfer_experts = FusedMoEModularKernel(
|
||||
MoEPrepareAndFinalizeNoEP(),
|
||||
FlashInferExperts(
|
||||
a1_gscale=a1_gs,
|
||||
g1_alphas=(1 / w1_gs),
|
||||
a2_gscale=a2_gs,
|
||||
g2_alphas=(1 / w2_gs),
|
||||
out_dtype=dtype,
|
||||
quant_dtype="nvfp4",
|
||||
))
|
||||
FlashInferExperts(out_dtype=dtype, quant_config=quant_config),
|
||||
)
|
||||
|
||||
flashinfer_output = flashinfer_experts(
|
||||
hidden_states=a,
|
||||
w1=w1_q,
|
||||
w1_scale=w1_blockscale,
|
||||
w2=w2_q,
|
||||
w2_scale=w2_blockscale,
|
||||
a1_scale=a1_gs,
|
||||
a2_scale=a2_gs,
|
||||
topk_weights=topk_weights,
|
||||
topk_ids=topk_ids,
|
||||
)
|
||||
@ -122,18 +102,18 @@ def test_flashinfer_fp4_moe_no_graph(m: int, n: int, k: int, e: int, topk: int,
|
||||
w2_d = torch.empty((e, k, n), device="cuda", dtype=dtype)
|
||||
|
||||
for idx in range(0, e):
|
||||
w1_d[idx] = dequantize_nvfp4_to_dtype(w1_q[idx],
|
||||
w1_blockscale[idx],
|
||||
w1_gs[idx],
|
||||
dtype=dtype,
|
||||
device=w1_q.device,
|
||||
block_size=quant_blocksize)
|
||||
w2_d[idx] = dequantize_nvfp4_to_dtype(w2_q[idx],
|
||||
w2_blockscale[idx],
|
||||
w2_gs[idx],
|
||||
dtype=dtype,
|
||||
device=w2_q.device,
|
||||
block_size=quant_blocksize)
|
||||
w1_d[idx] = dequantize_nvfp4_to_dtype(
|
||||
w1_q[idx],
|
||||
quant_config.w1_scale[idx], (1 / quant_config.g1_alphas[idx]),
|
||||
dtype=dtype,
|
||||
device=w1_q.device,
|
||||
block_size=quant_blocksize)
|
||||
w2_d[idx] = dequantize_nvfp4_to_dtype(
|
||||
w2_q[idx],
|
||||
quant_config.w2_scale[idx], (1 / quant_config.g2_alphas[idx]),
|
||||
dtype=dtype,
|
||||
device=w2_q.device,
|
||||
block_size=quant_blocksize)
|
||||
|
||||
torch_output = torch_moe(a_in_dtype, w1_d, w2_d, score, topk)
|
||||
|
||||
|
||||
@ -23,6 +23,7 @@ from triton_kernels.tensor import FP4, convert_layout, wrap_torch_tensor
|
||||
from triton_kernels.tensor_details import layout
|
||||
from triton_kernels.testing import assert_close
|
||||
|
||||
from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig
|
||||
from vllm.model_executor.layers.fused_moe.fused_batched_moe import (
|
||||
BatchedPrepareAndFinalize)
|
||||
from vllm.model_executor.layers.fused_moe.fused_moe import fused_topk
|
||||
@ -293,6 +294,13 @@ def test_equiv(num_token, a_dtype, w_dtype, tp):
|
||||
pc2,
|
||||
) = init_compute_data(M, K, N, E, a_dtype, w_dtype, num_warps=8)
|
||||
|
||||
quant_config = FusedMoEQuantConfig.make(
|
||||
w1_bias=w1_bias_tri,
|
||||
w2_bias=w2_bias_tri,
|
||||
w1_precision=pc1,
|
||||
w2_precision=pc2,
|
||||
)
|
||||
|
||||
out_triton_monolithic = triton_kernel_moe_forward(
|
||||
hidden_states=x_tri,
|
||||
w1=w1_tri,
|
||||
@ -300,10 +308,7 @@ def test_equiv(num_token, a_dtype, w_dtype, tp):
|
||||
gating_output=exp_data_tri,
|
||||
topk=topk,
|
||||
renormalize=True,
|
||||
w1_bias=w1_bias_tri,
|
||||
w2_bias=w2_bias_tri,
|
||||
w1_precision=pc1,
|
||||
w2_precision=pc2,
|
||||
quant_config=quant_config,
|
||||
)
|
||||
out_triton_monolithic = out_triton_monolithic[..., :K]
|
||||
|
||||
@ -336,6 +341,13 @@ def batched_moe(
|
||||
) -> torch.Tensor:
|
||||
max_num_tokens = round_up(a.shape[0], 64)
|
||||
|
||||
quant_config = FusedMoEQuantConfig.make(
|
||||
w1_precision=w1_precision,
|
||||
w2_precision=w2_precision,
|
||||
w1_bias=w1_bias,
|
||||
w2_bias=w2_bias,
|
||||
)
|
||||
|
||||
fused_experts = FusedMoEModularKernel(
|
||||
BatchedPrepareAndFinalize(
|
||||
max_num_tokens,
|
||||
@ -344,19 +356,12 @@ def batched_moe(
|
||||
rank=0,
|
||||
),
|
||||
BatchedOAITritonExperts(
|
||||
None,
|
||||
max_num_tokens=max_num_tokens,
|
||||
num_dispatchers=1,
|
||||
w1_precision=w1_precision,
|
||||
w2_precision=w2_precision,
|
||||
quant_config=quant_config,
|
||||
),
|
||||
)
|
||||
|
||||
extra_expert_args = {
|
||||
"w1_bias": w1_bias,
|
||||
"w2_bias": w2_bias,
|
||||
}
|
||||
|
||||
topk_weight, topk_ids, _ = fused_topk(a, gating_output, topk, renormalize)
|
||||
|
||||
return fused_experts(
|
||||
@ -365,7 +370,6 @@ def batched_moe(
|
||||
w2,
|
||||
topk_weight,
|
||||
topk_ids,
|
||||
extra_expert_args=extra_expert_args,
|
||||
)
|
||||
|
||||
|
||||
|
||||
@ -12,7 +12,6 @@ import torch
|
||||
|
||||
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
|
||||
from vllm.config import VllmConfig, current_platform, set_current_vllm_config
|
||||
from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig
|
||||
from vllm.utils import has_deep_ep, has_deep_gemm, has_pplx
|
||||
from vllm.utils.flashinfer import has_flashinfer_cutlass_fused_moe
|
||||
|
||||
@ -22,7 +21,8 @@ from .modular_kernel_tools.common import (Config, RankTensors, WeightTensors,
|
||||
run_modular_kernel)
|
||||
from .modular_kernel_tools.mk_objects import (
|
||||
MK_FUSED_EXPERT_TYPES, MK_MULTI_GPU_PREPARE_FINALIZE_TYPES,
|
||||
MK_QUANT_CONFIGS, MK_SINGLE_GPU_PREPARE_FINALIZE_TYPES, expert_info)
|
||||
MK_QUANT_CONFIGS, MK_SINGLE_GPU_PREPARE_FINALIZE_TYPES, TestMoEQuantConfig,
|
||||
expert_info)
|
||||
from .modular_kernel_tools.parallel_utils import (ProcessGroupInfo,
|
||||
parallel_launch_with_config)
|
||||
|
||||
@ -55,7 +55,7 @@ def rank_worker(
|
||||
pgi: ProcessGroupInfo,
|
||||
vllm_config: VllmConfig,
|
||||
cpu_group,
|
||||
config: Config,
|
||||
base_config: Config,
|
||||
weights: WeightTensors,
|
||||
verbose: bool,
|
||||
):
|
||||
@ -63,42 +63,44 @@ def rank_worker(
|
||||
|
||||
# sanity check
|
||||
from vllm import envs
|
||||
if config.fused_moe_chunk_size is not None:
|
||||
assert (config.fused_moe_chunk_size == envs.VLLM_FUSED_MOE_CHUNK_SIZE)
|
||||
if base_config.fused_moe_chunk_size is not None:
|
||||
assert (
|
||||
base_config.fused_moe_chunk_size == envs.VLLM_FUSED_MOE_CHUNK_SIZE)
|
||||
|
||||
# get weights to this device
|
||||
weights.to_current_device()
|
||||
|
||||
Ms = config.Ms
|
||||
Ms = base_config.Ms
|
||||
assert isinstance(Ms, list)
|
||||
TOPKs = config.topks
|
||||
TOPKs = base_config.topks
|
||||
assert isinstance(TOPKs, list)
|
||||
|
||||
exceptions = []
|
||||
count = 0
|
||||
|
||||
for m, topk in product(Ms, TOPKs):
|
||||
# override m and topk
|
||||
config = copy.deepcopy(base_config)
|
||||
config.Ms = m
|
||||
config.topks = topk
|
||||
|
||||
try:
|
||||
print(f"Running[{pgi.rank}]: m={m}, topk={topk} ...")
|
||||
count = count + 1
|
||||
# override m and topk
|
||||
cfgx = copy.deepcopy(config)
|
||||
cfgx.Ms = m
|
||||
cfgx.topks = topk
|
||||
|
||||
# inputs for rank
|
||||
rank_tensors = RankTensors.make(cfgx, pgi)
|
||||
rank_tensors = RankTensors.make(config, pgi)
|
||||
|
||||
# modular kernel out
|
||||
mk_out = run_modular_kernel(pgi, vllm_config, cfgx, weights,
|
||||
mk_out = run_modular_kernel(pgi, vllm_config, config, weights,
|
||||
rank_tensors)
|
||||
|
||||
with set_current_vllm_config(vllm_config):
|
||||
ref_out = reference_moe_impl(cfgx, weights, rank_tensors)
|
||||
ref_out = reference_moe_impl(config, weights, rank_tensors)
|
||||
|
||||
if config.quant_dtype == "nvfp4":
|
||||
atol = 1e-1
|
||||
rtol = 1e-1
|
||||
atol = 1e-1 if config.K < 4096 else 2e-1
|
||||
rtol = 1e-1 if config.K < 4096 else 2e-1
|
||||
else:
|
||||
atol = 3e-2
|
||||
rtol = 3e-2
|
||||
@ -132,7 +134,7 @@ Ms = [32, 64]
|
||||
# hidden sizes, making this too large will cause fp4 tests to fail.
|
||||
# Also needs to be a multiple of 1024 for deep_gemm.
|
||||
Ks = [2048]
|
||||
Ns = [2048]
|
||||
Ns = [1024]
|
||||
TOPKs = [4, 1]
|
||||
Es = [32]
|
||||
DTYPEs = [torch.bfloat16]
|
||||
@ -167,7 +169,7 @@ def is_nyi_config(config: Config) -> bool:
|
||||
@meets_multi_gpu_requirements
|
||||
def test_modular_kernel_combinations_multigpu(
|
||||
k: int, n: int, e: int, dtype: torch.dtype,
|
||||
quant_config: Optional[FusedMoEQuantConfig],
|
||||
quant_config: Optional[TestMoEQuantConfig],
|
||||
combination: tuple[mk.FusedMoEPrepareAndFinalize,
|
||||
mk.FusedMoEPermuteExpertsUnpermute],
|
||||
fused_moe_chunk_size: Optional[int], world_size: int, pytestconfig):
|
||||
@ -208,7 +210,7 @@ def test_modular_kernel_combinations_multigpu(
|
||||
@pytest.mark.parametrize("world_size", [1])
|
||||
def test_modular_kernel_combinations_singlegpu(
|
||||
k: int, n: int, e: int, dtype: torch.dtype,
|
||||
quant_config: Optional[FusedMoEQuantConfig],
|
||||
quant_config: Optional[TestMoEQuantConfig],
|
||||
combination: tuple[mk.FusedMoEPrepareAndFinalize,
|
||||
mk.FusedMoEPermuteExpertsUnpermute],
|
||||
fused_moe_chunk_size: Optional[int], world_size: int, pytestconfig):
|
||||
|
||||
@ -15,11 +15,14 @@ from transformers import MixtralConfig
|
||||
from transformers.models.mixtral.modeling_mixtral import MixtralSparseMoeBlock
|
||||
|
||||
import vllm.model_executor.layers.fused_moe # noqa
|
||||
from tests.kernels.moe.utils import fused_moe
|
||||
from tests.kernels.utils import opcheck, stack_and_dev, torch_moe
|
||||
from vllm.config import VllmConfig, set_current_vllm_config
|
||||
from vllm.distributed.parallel_state import init_distributed_environment
|
||||
from vllm.forward_context import set_forward_context
|
||||
from vllm.model_executor.layers.fused_moe import fused_moe
|
||||
from vllm.model_executor.layers.fused_moe.config import (
|
||||
FUSED_MOE_UNQUANTIZED_CONFIG, int4_w4a16_moe_quant_config,
|
||||
int8_w8a16_moe_quant_config)
|
||||
from vllm.model_executor.layers.fused_moe.fused_moe import (
|
||||
fused_topk, modular_triton_fused_moe)
|
||||
from vllm.model_executor.layers.fused_moe.moe_torch_iterative import (
|
||||
@ -187,14 +190,9 @@ def test_fused_moe(
|
||||
#
|
||||
# Setup test functions
|
||||
#
|
||||
quant_config = FUSED_MOE_UNQUANTIZED_CONFIG
|
||||
|
||||
m_fused_moe_fn = modular_triton_fused_moe(use_fp8_w8a8=False,
|
||||
use_int8_w8a8=False,
|
||||
use_int8_w8a16=False,
|
||||
use_int4_w4a16=False,
|
||||
use_mxfp4_w4a4=False,
|
||||
per_act_token_quant=False,
|
||||
block_shape=None)
|
||||
m_fused_moe_fn = modular_triton_fused_moe(quant_config)
|
||||
|
||||
def m_fused_moe(
|
||||
a: torch.Tensor,
|
||||
@ -340,6 +338,18 @@ def test_fused_moe_wn16(m: int, n: int, k: int, e: int, topk: int,
|
||||
else:
|
||||
e_map = None
|
||||
|
||||
if weight_bits == 4:
|
||||
quant_config_builder = int4_w4a16_moe_quant_config
|
||||
else:
|
||||
assert weight_bits == 8
|
||||
quant_config_builder = int8_w8a16_moe_quant_config
|
||||
|
||||
quant_config = quant_config_builder(w1_scale=w1_scales,
|
||||
w2_scale=w2_scales,
|
||||
w1_zp=w1_qzeros if has_zp else None,
|
||||
w2_zp=w2_qzeros if has_zp else None,
|
||||
block_shape=[0, group_size])
|
||||
|
||||
with set_current_vllm_config(vllm_config):
|
||||
triton_output = fused_moe(a,
|
||||
w1_qweight,
|
||||
@ -347,15 +357,9 @@ def test_fused_moe_wn16(m: int, n: int, k: int, e: int, topk: int,
|
||||
score,
|
||||
topk,
|
||||
renormalize=False,
|
||||
use_int4_w4a16=weight_bits == 4,
|
||||
use_int8_w8a16=weight_bits == 8,
|
||||
global_num_experts=e,
|
||||
expert_map=e_map,
|
||||
w1_scale=w1_scales,
|
||||
w2_scale=w2_scales,
|
||||
w1_zp=w1_qzeros if has_zp else None,
|
||||
w2_zp=w2_qzeros if has_zp else None,
|
||||
block_shape=[0, group_size])
|
||||
quant_config=quant_config)
|
||||
torch_output = torch_moe(a,
|
||||
w1_ref,
|
||||
w2_ref,
|
||||
|
||||
@ -10,6 +10,7 @@ from tests.kernels.quantization.nvfp4_utils import (FLOAT4_E2M1_MAX,
|
||||
from tests.kernels.utils import torch_moe
|
||||
from vllm import _custom_ops as ops
|
||||
from vllm.config import ParallelConfig, VllmConfig, set_current_vllm_config
|
||||
from vllm.model_executor.layers.fused_moe.config import nvfp4_moe_quant_config
|
||||
from vllm.model_executor.layers.fused_moe.cutlass_moe import cutlass_moe_fp4
|
||||
from vllm.model_executor.layers.fused_moe.fused_moe import fused_topk
|
||||
from vllm.platforms import current_platform
|
||||
@ -56,7 +57,7 @@ def test_cutlass_fp4_moe_no_graph(m: int, n: int, k: int, e: int, topk: int,
|
||||
in_dtype=dtype,
|
||||
quant_dtype="nvfp4",
|
||||
block_shape=None, # use quant_blocksize?
|
||||
per_act_token_quant=False,
|
||||
per_out_ch_quant=False,
|
||||
)
|
||||
|
||||
score = torch.randn((m, e), device="cuda", dtype=dtype)
|
||||
@ -73,18 +74,22 @@ def test_cutlass_fp4_moe_no_graph(m: int, n: int, k: int, e: int, topk: int,
|
||||
assert w1_blockscale is not None
|
||||
assert w2_blockscale is not None
|
||||
|
||||
quant_config = nvfp4_moe_quant_config(
|
||||
g1_alphas=(1 / w1_gs),
|
||||
g2_alphas=(1 / w2_gs),
|
||||
a1_gscale=a1_gs,
|
||||
a2_gscale=a2_gs,
|
||||
w1_scale=w1_blockscale,
|
||||
w2_scale=w2_blockscale,
|
||||
)
|
||||
|
||||
cutlass_output = cutlass_moe_fp4(
|
||||
a=a,
|
||||
a1_gscale=a1_gs,
|
||||
w1_fp4=w1_q,
|
||||
w1_blockscale=w1_blockscale,
|
||||
g1_alphas=(1 / w1_gs),
|
||||
a2_gscale=a2_gs,
|
||||
w2_fp4=w2_q,
|
||||
w2_blockscale=w2_blockscale,
|
||||
g2_alphas=(1 / w2_gs),
|
||||
topk_weights=topk_weights,
|
||||
topk_ids=topk_ids,
|
||||
quant_config=quant_config,
|
||||
m=m,
|
||||
n=n,
|
||||
k=k,
|
||||
|
||||
@ -9,6 +9,8 @@ import torch
|
||||
from tests.kernels.utils import torch_experts
|
||||
from vllm import _custom_ops as ops
|
||||
from vllm.config import VllmConfig, set_current_vllm_config
|
||||
from vllm.model_executor.layers.fused_moe.config import (
|
||||
fp8_w8a8_moe_quant_config)
|
||||
from vllm.model_executor.layers.fused_moe.cutlass_moe import (
|
||||
CutlassBatchedExpertsFp8)
|
||||
from vllm.model_executor.layers.fused_moe.fused_moe import fused_topk
|
||||
@ -143,10 +145,16 @@ def pplx_cutlass_moe(
|
||||
device="cuda",
|
||||
dtype=torch.int64)
|
||||
|
||||
experts = CutlassBatchedExpertsFp8(num_local_experts, num_dispatchers,
|
||||
out_dtype, per_act_token, per_out_ch,
|
||||
ab_strides1, ab_strides2, c_strides1,
|
||||
c_strides2)
|
||||
experts = CutlassBatchedExpertsFp8(
|
||||
num_local_experts, num_dispatchers, out_dtype, ab_strides1,
|
||||
ab_strides2, c_strides1, c_strides2,
|
||||
fp8_w8a8_moe_quant_config(
|
||||
per_act_token_quant=per_act_token,
|
||||
per_out_ch_quant=per_out_ch,
|
||||
w1_scale=chunk_by_rank(w1_scale, rank, world_size),
|
||||
w2_scale=chunk_by_rank(w2_scale, rank, world_size),
|
||||
a1_scale=chunk_by_rank(a1_scale, rank, world_size)
|
||||
if per_act_token else a1_scale[rank]))
|
||||
|
||||
fused_cutlass_experts = FusedMoEModularKernel(
|
||||
prepare_finalize,
|
||||
@ -167,10 +175,7 @@ def pplx_cutlass_moe(
|
||||
chunk_topk_ids,
|
||||
global_num_experts=num_experts,
|
||||
expert_map=None, #TODO
|
||||
w1_scale=chunk_by_rank(w1_scale, rank, world_size),
|
||||
w2_scale=chunk_by_rank(w2_scale, rank, world_size),
|
||||
a1_scale=chunk_by_rank(a1_scale, rank, world_size)
|
||||
if per_act_token else a1_scale[rank])
|
||||
)
|
||||
|
||||
torch.cuda.synchronize()
|
||||
|
||||
|
||||
@ -58,7 +58,7 @@ BATCHED_MOE_MNK_FACTORS = [
|
||||
]
|
||||
|
||||
PPLX_COMBOS = [
|
||||
# TODO: figure out why this fails, seems to be test problem
|
||||
# TODO(bnell): figure out why this fails, seems to be test problem
|
||||
#(1, 128, 128),
|
||||
(2, 128, 512),
|
||||
(3, 1024, 2048),
|
||||
@ -360,18 +360,18 @@ def pplx_prepare_finalize(
|
||||
|
||||
b_a, b_a_scale, expert_num_tokens, _, _ = prepare_finalize.prepare(
|
||||
a_chunk,
|
||||
a1_scale,
|
||||
a2_scale,
|
||||
chunk_topk_weight,
|
||||
chunk_topk_ids,
|
||||
num_experts,
|
||||
None,
|
||||
False,
|
||||
FusedMoEQuantConfig(
|
||||
FusedMoEQuantConfig.make(
|
||||
quant_dtype,
|
||||
per_act_token_quant,
|
||||
False,
|
||||
block_shape,
|
||||
per_act_token_quant=per_act_token_quant,
|
||||
per_out_ch_quant=False,
|
||||
block_shape=block_shape,
|
||||
a1_scale=a1_scale,
|
||||
a2_scale=a2_scale,
|
||||
),
|
||||
)
|
||||
|
||||
@ -540,20 +540,6 @@ def pplx_moe(
|
||||
|
||||
topk_ids = topk_ids.to(dtype=torch.uint32)
|
||||
|
||||
experts = BatchedTritonExperts(
|
||||
max_num_tokens=max_num_tokens,
|
||||
num_dispatchers=prepare_finalize.num_dispatchers(),
|
||||
use_fp8_w8a8=quant_dtype == torch.float8_e4m3fn,
|
||||
block_shape=block_shape,
|
||||
per_act_token_quant=per_act_token_quant,
|
||||
)
|
||||
|
||||
fused_experts = FusedMoEModularKernel(
|
||||
prepare_finalize,
|
||||
experts,
|
||||
shared_experts,
|
||||
)
|
||||
|
||||
# Note: workers with the same dp_rank must use the exact same inputs.
|
||||
a_chunk = chunk_by_rank(a, rank, world_size)
|
||||
chunk_topk_weight = chunk_by_rank(topk_weight, rank, world_size)
|
||||
@ -567,6 +553,28 @@ def pplx_moe(
|
||||
a1_scale_chunk = chunk_scales_by_rank(a1_scale, rank, world_size)
|
||||
a2_scale_chunk = chunk_scales_by_rank(a2_scale, rank, world_size)
|
||||
|
||||
quant_config = FusedMoEQuantConfig.make(
|
||||
quant_dtype,
|
||||
block_shape=block_shape,
|
||||
per_act_token_quant=per_act_token_quant,
|
||||
w1_scale=w1_scale_chunk,
|
||||
w2_scale=w2_scale_chunk,
|
||||
a1_scale=a1_scale_chunk,
|
||||
a2_scale=a2_scale_chunk,
|
||||
)
|
||||
|
||||
experts = BatchedTritonExperts(
|
||||
max_num_tokens=max_num_tokens,
|
||||
num_dispatchers=prepare_finalize.num_dispatchers(),
|
||||
quant_config=quant_config,
|
||||
)
|
||||
|
||||
fused_experts = FusedMoEModularKernel(
|
||||
prepare_finalize,
|
||||
experts,
|
||||
shared_experts,
|
||||
)
|
||||
|
||||
# Note: for now use_compile will error out if the problem size is
|
||||
# large enough to trigger chunking. I'm leaving the flag and
|
||||
# setup code in case we are able to revisit this later.
|
||||
@ -585,10 +593,6 @@ def pplx_moe(
|
||||
w2_chunk,
|
||||
chunk_topk_weight,
|
||||
chunk_topk_ids,
|
||||
w1_scale=w1_scale_chunk,
|
||||
w2_scale=w2_scale_chunk,
|
||||
a1_scale=a1_scale_chunk,
|
||||
a2_scale=a2_scale_chunk,
|
||||
global_num_experts=num_experts)
|
||||
|
||||
if use_cudagraphs:
|
||||
@ -605,10 +609,6 @@ def pplx_moe(
|
||||
w2_chunk,
|
||||
chunk_topk_weight,
|
||||
chunk_topk_ids,
|
||||
w1_scale=w1_scale_chunk,
|
||||
w2_scale=w2_scale_chunk,
|
||||
a1_scale=a1_scale_chunk,
|
||||
a2_scale=a2_scale_chunk,
|
||||
global_num_experts=num_experts)
|
||||
|
||||
torch.cuda.synchronize()
|
||||
@ -820,7 +820,7 @@ def test_pplx_moe_slow(
|
||||
k,
|
||||
quant_dtype=quant_dtype,
|
||||
block_shape=block_shape,
|
||||
per_act_token_quant=per_act_token_quant,
|
||||
per_out_ch_quant=per_act_token_quant,
|
||||
)
|
||||
|
||||
parallel_launch(world_size, _pplx_moe, dp_size, a, w1, w2, score, topk, e,
|
||||
@ -897,7 +897,7 @@ def _pplx_test_loop(pgi: ProcessGroupInfo, dp_size: int, use_internode: bool,
|
||||
k,
|
||||
quant_dtype=quant_dtype,
|
||||
block_shape=block_shape,
|
||||
per_act_token_quant=per_act_token_quant,
|
||||
per_out_ch_quant=per_act_token_quant,
|
||||
)
|
||||
args["w1"] = w1
|
||||
args["w2"] = w2
|
||||
|
||||
@ -7,10 +7,12 @@ import itertools
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from tests.kernels.moe.utils import fused_moe
|
||||
from vllm import _custom_ops as ops
|
||||
from vllm.config import VllmConfig, set_current_vllm_config
|
||||
from vllm.model_executor.layers.activation import SiluAndMul
|
||||
from vllm.model_executor.layers.fused_moe import fused_moe
|
||||
from vllm.model_executor.layers.fused_moe.config import (
|
||||
fp8_w8a8_moe_quant_config)
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
if current_platform.get_device_capability() < (9, 0):
|
||||
@ -152,11 +154,12 @@ def test_w8a8_fp8_fused_moe(M, N, K, E, topk, dtype, seed):
|
||||
score,
|
||||
topk,
|
||||
renormalize=False,
|
||||
use_fp8_w8a8=True, # using fp8
|
||||
per_channel_quant=True,
|
||||
w1_scale=w1_s,
|
||||
w2_scale=w2_s,
|
||||
block_shape=None, # Not using block quantization
|
||||
quant_config=fp8_w8a8_moe_quant_config(
|
||||
per_act_token_quant=True,
|
||||
w1_scale=w1_s,
|
||||
w2_scale=w2_s,
|
||||
block_shape=None, # Not using block quantization
|
||||
),
|
||||
)
|
||||
|
||||
# Check results
|
||||
|
||||
@ -9,7 +9,8 @@ from tests.kernels.quant_utils import per_block_cast_to_int8
|
||||
from tests.kernels.quantization.nvfp4_utils import (FLOAT4_E2M1_MAX,
|
||||
FLOAT8_E4M3_MAX)
|
||||
from vllm.model_executor.layers.activation import SiluAndMul
|
||||
from vllm.model_executor.layers.fused_moe import fused_experts
|
||||
from vllm.model_executor.layers.fused_moe import fused_experts, fused_topk
|
||||
from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig
|
||||
from vllm.model_executor.layers.fused_moe.fused_batched_moe import (
|
||||
BatchedPrepareAndFinalize, BatchedTritonExperts, NaiveBatchedExperts)
|
||||
from vllm.model_executor.layers.fused_moe.modular_kernel import (
|
||||
@ -34,18 +35,22 @@ def triton_moe(
|
||||
per_act_token_quant=False,
|
||||
block_shape: Optional[list[int]] = None,
|
||||
) -> torch.Tensor:
|
||||
quant_config = FusedMoEQuantConfig.make(
|
||||
quant_dtype,
|
||||
per_act_token_quant=per_act_token_quant,
|
||||
block_shape=block_shape,
|
||||
w1_scale=w1_scale,
|
||||
w2_scale=w2_scale,
|
||||
a1_scale=a1_scale,
|
||||
a2_scale=a2_scale,
|
||||
)
|
||||
|
||||
return fused_experts(a,
|
||||
w1,
|
||||
w2,
|
||||
topk_weight,
|
||||
topk_ids,
|
||||
w1_scale=w1_scale,
|
||||
w2_scale=w2_scale,
|
||||
a1_scale=a1_scale,
|
||||
a2_scale=a2_scale,
|
||||
per_channel_quant=per_act_token_quant,
|
||||
use_fp8_w8a8=quant_dtype == torch.float8_e4m3fn,
|
||||
block_shape=block_shape)
|
||||
quant_config=quant_config)
|
||||
|
||||
|
||||
def batched_moe(
|
||||
@ -64,6 +69,16 @@ def batched_moe(
|
||||
) -> torch.Tensor:
|
||||
max_num_tokens = round_up(a.shape[0], 64)
|
||||
|
||||
quant_config = FusedMoEQuantConfig.make(
|
||||
quant_dtype,
|
||||
per_act_token_quant=per_act_token_quant,
|
||||
block_shape=block_shape,
|
||||
w1_scale=w1_scale,
|
||||
w2_scale=w2_scale,
|
||||
a1_scale=a1_scale,
|
||||
a2_scale=a2_scale,
|
||||
)
|
||||
|
||||
fused_experts = FusedMoEModularKernel(
|
||||
BatchedPrepareAndFinalize(max_num_tokens,
|
||||
num_dispatchers=1,
|
||||
@ -72,21 +87,11 @@ def batched_moe(
|
||||
BatchedTritonExperts(
|
||||
max_num_tokens=max_num_tokens,
|
||||
num_dispatchers=1,
|
||||
use_fp8_w8a8=quant_dtype == torch.float8_e4m3fn,
|
||||
per_act_token_quant=per_act_token_quant,
|
||||
block_shape=block_shape,
|
||||
quant_config=quant_config,
|
||||
),
|
||||
)
|
||||
|
||||
return fused_experts(a,
|
||||
w1,
|
||||
w2,
|
||||
topk_weight,
|
||||
topk_ids,
|
||||
w1_scale=w1_scale,
|
||||
w2_scale=w2_scale,
|
||||
a1_scale=a1_scale,
|
||||
a2_scale=a2_scale)
|
||||
return fused_experts(a, w1, w2, topk_weight, topk_ids)
|
||||
|
||||
|
||||
def naive_batched_moe(
|
||||
@ -105,6 +110,16 @@ def naive_batched_moe(
|
||||
) -> torch.Tensor:
|
||||
max_num_tokens = round_up(a.shape[0], 64)
|
||||
|
||||
quant_config = FusedMoEQuantConfig.make(
|
||||
quant_dtype,
|
||||
per_act_token_quant=per_act_token_quant,
|
||||
block_shape=block_shape,
|
||||
w1_scale=w1_scale,
|
||||
w2_scale=w2_scale,
|
||||
a1_scale=a1_scale,
|
||||
a2_scale=a2_scale,
|
||||
)
|
||||
|
||||
fused_experts = FusedMoEModularKernel(
|
||||
BatchedPrepareAndFinalize(max_num_tokens,
|
||||
num_dispatchers=1,
|
||||
@ -113,21 +128,11 @@ def naive_batched_moe(
|
||||
NaiveBatchedExperts(
|
||||
max_num_tokens=max_num_tokens,
|
||||
num_dispatchers=1,
|
||||
use_fp8_w8a8=quant_dtype == torch.float8_e4m3fn,
|
||||
per_act_token_quant=per_act_token_quant,
|
||||
block_shape=block_shape,
|
||||
quant_config=quant_config,
|
||||
),
|
||||
)
|
||||
|
||||
return fused_experts(a,
|
||||
w1,
|
||||
w2,
|
||||
topk_weight,
|
||||
topk_ids,
|
||||
w1_scale=w1_scale,
|
||||
w2_scale=w2_scale,
|
||||
a1_scale=a1_scale,
|
||||
a2_scale=a2_scale)
|
||||
return fused_experts(a, w1, w2, topk_weight, topk_ids)
|
||||
|
||||
|
||||
def chunk_scales(scales: Optional[torch.Tensor], start: int,
|
||||
@ -216,7 +221,7 @@ def make_test_weight(
|
||||
in_dtype: torch.dtype = torch.bfloat16,
|
||||
quant_dtype: Union[torch.dtype, str, None] = None,
|
||||
block_shape: Optional[list[int]] = None,
|
||||
per_act_token_quant: bool = False,
|
||||
per_out_ch_quant: bool = False,
|
||||
) -> tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor],
|
||||
Optional[torch.Tensor]]:
|
||||
w_16 = torch.randn((e, rows, cols), device="cuda", dtype=in_dtype) / 15
|
||||
@ -228,7 +233,7 @@ def make_test_weight(
|
||||
w_gs_l = [None] * e
|
||||
for idx in range(e):
|
||||
w_l[idx], w_s_l[idx], w_gs_l[idx] = moe_quantize_weights(
|
||||
w_16[idx], None, quant_dtype, per_act_token_quant, block_shape)
|
||||
w_16[idx], None, quant_dtype, per_out_ch_quant, block_shape)
|
||||
|
||||
w = torch.stack(w_l)
|
||||
w_s = torch.stack(w_s_l)
|
||||
@ -258,16 +263,16 @@ def make_test_weights(
|
||||
in_dtype: torch.dtype = torch.bfloat16,
|
||||
quant_dtype: Union[torch.dtype, str, None] = None,
|
||||
block_shape: Optional[list[int]] = None,
|
||||
per_act_token_quant: bool = False,
|
||||
per_out_ch_quant: bool = False,
|
||||
) -> tuple[tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor],
|
||||
Optional[torch.Tensor]],
|
||||
tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor],
|
||||
Optional[torch.Tensor]]]:
|
||||
return (
|
||||
make_test_weight(e, 2 * n, k, in_dtype, quant_dtype, block_shape,
|
||||
per_act_token_quant),
|
||||
per_out_ch_quant),
|
||||
make_test_weight(e, k, n, in_dtype, quant_dtype, block_shape,
|
||||
per_act_token_quant),
|
||||
per_out_ch_quant),
|
||||
)
|
||||
|
||||
|
||||
@ -285,6 +290,76 @@ def per_token_cast_to_fp8(
|
||||
return fp8_data.view(m, n + pad_size)[:, :n], (x_amax / 448.0).view(m, -1)
|
||||
|
||||
|
||||
def make_test_quant_config(
|
||||
e: int,
|
||||
n: int,
|
||||
k: int,
|
||||
in_dtype: torch.dtype,
|
||||
quant_dtype: Union[torch.dtype, str, None] = None,
|
||||
per_act_token_quant: bool = False,
|
||||
block_shape: Optional[list[int]] = None,
|
||||
) -> tuple[torch.Tensor, torch.Tensor, FusedMoEQuantConfig]:
|
||||
(_, w1, w1_s, w1_gs), (_, w2, w2_s, w2_gs) = make_test_weights(
|
||||
e,
|
||||
n,
|
||||
k,
|
||||
in_dtype,
|
||||
quant_dtype,
|
||||
per_out_ch_quant=per_act_token_quant,
|
||||
block_shape=block_shape,
|
||||
)
|
||||
|
||||
# Hacky/trivial scales for nvfp4.
|
||||
a1_gscale: Optional[torch.Tensor] = None
|
||||
a2_gscale: Optional[torch.Tensor] = None
|
||||
if quant_dtype == "nvfp4":
|
||||
a1_gscale = torch.ones((e, ), device="cuda", dtype=torch.float32)
|
||||
a2_gscale = torch.ones((e, ), device="cuda", dtype=torch.float32)
|
||||
a1_scale = a1_gscale
|
||||
a2_scale = a2_gscale
|
||||
else:
|
||||
a1_scale = None
|
||||
a2_scale = None
|
||||
|
||||
return w1, w2, FusedMoEQuantConfig.make(
|
||||
quant_dtype,
|
||||
per_act_token_quant=per_act_token_quant,
|
||||
block_shape=block_shape,
|
||||
w1_scale=w1_s,
|
||||
w2_scale=w2_s,
|
||||
a1_gscale=a1_gscale,
|
||||
a2_gscale=a2_gscale,
|
||||
a1_scale=a1_scale,
|
||||
a2_scale=a2_scale,
|
||||
# TODO: make sure this is handled properly
|
||||
g1_alphas=(1 / w1_gs) if w1_gs is not None else None,
|
||||
g2_alphas=(1 / w2_gs) if w2_gs is not None else None,
|
||||
)
|
||||
|
||||
|
||||
def fused_moe(
|
||||
hidden_states: torch.Tensor,
|
||||
w1: torch.Tensor,
|
||||
w2: torch.Tensor,
|
||||
score: torch.Tensor,
|
||||
topk: int,
|
||||
renormalize: bool = False,
|
||||
quant_config: Optional[FusedMoEQuantConfig] = None,
|
||||
global_num_experts: int = -1,
|
||||
expert_map: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
topk_weights, topk_ids, _ = fused_topk(hidden_states, score.float(), topk,
|
||||
renormalize)
|
||||
return fused_experts(hidden_states,
|
||||
w1,
|
||||
w2,
|
||||
topk_weights,
|
||||
topk_ids,
|
||||
global_num_experts=global_num_experts,
|
||||
expert_map=expert_map,
|
||||
quant_config=quant_config)
|
||||
|
||||
|
||||
# CustomOp?
|
||||
class BaselineMM(torch.nn.Module):
|
||||
|
||||
|
||||
@ -8,7 +8,8 @@ import pytest
|
||||
import torch
|
||||
|
||||
from vllm.model_executor.layers.activation import SiluAndMul
|
||||
from vllm.model_executor.layers.fused_moe import fused_moe
|
||||
from vllm.model_executor.layers.fused_moe import fused_experts
|
||||
from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig
|
||||
from vllm.model_executor.layers.quantization.utils.int8_utils import (
|
||||
per_token_quant_int8)
|
||||
from vllm.platforms import current_platform
|
||||
@ -42,7 +43,8 @@ def native_w8a8_per_token_matmul(A, B, As, Bs, output_dtype=torch.float16):
|
||||
return C.reshape(origin_C_shape).to(output_dtype)
|
||||
|
||||
|
||||
def torch_w8a8_per_column_moe(a, w1, w2, w1_s, w2_s, score, topk):
|
||||
def torch_w8a8_per_column_moe(a, w1, w2, w1_s, w2_s, topk, topk_weight,
|
||||
topk_ids):
|
||||
"""This function performs fused moe with per-column int8 quantization
|
||||
using native torch."""
|
||||
|
||||
@ -57,8 +59,6 @@ def torch_w8a8_per_column_moe(a, w1, w2, w1_s, w2_s, score, topk):
|
||||
out = torch.zeros(B * topk, w2.shape[1], dtype=a.dtype, device=a.device)
|
||||
|
||||
# Calculate routing
|
||||
score = torch.softmax(score, dim=-1, dtype=torch.float32)
|
||||
topk_weight, topk_ids = torch.topk(score, topk)
|
||||
topk_weight = topk_weight.view(-1)
|
||||
topk_ids = topk_ids.view(-1)
|
||||
# Process each expert
|
||||
@ -127,20 +127,27 @@ def test_w8a8_fp8_fused_moe(M, N, K, E, topk, dtype, seed):
|
||||
w1_s = torch.rand(E, 2 * N, device=w1_fp32.device) * factor_for_scale
|
||||
w2_s = torch.rand(E, K, device=w2_fp32.device) * factor_for_scale
|
||||
score = torch.randn((M, E), dtype=dtype)
|
||||
score = torch.softmax(score, dim=-1, dtype=torch.float32)
|
||||
topk_weights, topk_ids = torch.topk(score, topk)
|
||||
|
||||
ref_out = torch_w8a8_per_column_moe(a, w1, w2, w1_s, w2_s, score, topk)
|
||||
out = fused_moe(
|
||||
ref_out = torch_w8a8_per_column_moe(a, w1, w2, w1_s, w2_s, topk,
|
||||
topk_weights, topk_ids)
|
||||
|
||||
quant_config = FusedMoEQuantConfig.make(
|
||||
torch.int8,
|
||||
per_act_token_quant=True,
|
||||
block_shape=None,
|
||||
w1_scale=w1_s,
|
||||
w2_scale=w2_s,
|
||||
)
|
||||
|
||||
out = fused_experts(
|
||||
a,
|
||||
w1,
|
||||
w2,
|
||||
score,
|
||||
topk,
|
||||
renormalize=False,
|
||||
use_int8_w8a8=True, # Using int8-w8a8
|
||||
per_channel_quant=True,
|
||||
w1_scale=w1_s,
|
||||
w2_scale=w2_s,
|
||||
block_shape=None, # Not using block quantization
|
||||
topk_weights,
|
||||
topk_ids,
|
||||
quant_config=quant_config,
|
||||
)
|
||||
|
||||
# Check results
|
||||
|
||||
@ -10,6 +10,7 @@ from vllm.model_executor.layers.fused_moe.layer import (
|
||||
from vllm.model_executor.layers.fused_moe.modular_kernel import (
|
||||
FusedMoEActivationFormat, FusedMoEPermuteExpertsUnpermute,
|
||||
FusedMoEPrepareAndFinalize)
|
||||
from vllm.model_executor.layers.fused_moe.utils import activation_without_mul
|
||||
from vllm.triton_utils import HAS_TRITON
|
||||
|
||||
_config: Optional[dict[str, Any]] = None
|
||||
@ -36,6 +37,7 @@ __all__ = [
|
||||
"FusedMoEPermuteExpertsUnpermute",
|
||||
"FusedMoEActivationFormat",
|
||||
"FusedMoEPrepareAndFinalize",
|
||||
"activation_without_mul",
|
||||
"override_config",
|
||||
"get_config",
|
||||
]
|
||||
@ -43,7 +45,6 @@ __all__ = [
|
||||
if HAS_TRITON:
|
||||
# import to register the custom ops
|
||||
import vllm.model_executor.layers.fused_moe.fused_marlin_moe # noqa
|
||||
import vllm.model_executor.layers.fused_moe.fused_moe # noqa
|
||||
from vllm.model_executor.layers.fused_moe.batched_deep_gemm_moe import (
|
||||
BatchedDeepGemmExperts)
|
||||
from vllm.model_executor.layers.fused_moe.batched_triton_or_deep_gemm_moe import ( # noqa: E501
|
||||
@ -56,13 +57,12 @@ if HAS_TRITON:
|
||||
from vllm.model_executor.layers.fused_moe.fused_batched_moe import (
|
||||
BatchedTritonExperts)
|
||||
from vllm.model_executor.layers.fused_moe.fused_moe import (
|
||||
TritonExperts, fused_experts, fused_moe, fused_topk,
|
||||
get_config_file_name, grouped_topk)
|
||||
TritonExperts, fused_experts, fused_topk, get_config_file_name,
|
||||
grouped_topk)
|
||||
from vllm.model_executor.layers.fused_moe.triton_deep_gemm_moe import (
|
||||
TritonOrDeepGemmExperts)
|
||||
|
||||
__all__ += [
|
||||
"fused_moe",
|
||||
"fused_topk",
|
||||
"fused_experts",
|
||||
"get_config_file_name",
|
||||
|
||||
@ -8,6 +8,8 @@ import torch
|
||||
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig
|
||||
from vllm.model_executor.layers.fused_moe.deep_gemm_utils import (
|
||||
deep_gemm_block_shape)
|
||||
from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import (
|
||||
TopKWeightAndReduceDelegate)
|
||||
from vllm.model_executor.layers.fused_moe.utils import _resize_cache
|
||||
@ -212,27 +214,20 @@ def silu_mul_fp8_quant_deep_gemm_cuda(
|
||||
|
||||
|
||||
class BatchedDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
|
||||
# The Deep Gemm kernels only support block size of 128
|
||||
DEEPGEMM_BLOCK_SHAPE: list[int] = [128, 128]
|
||||
|
||||
def __init__(self,
|
||||
max_num_tokens: int,
|
||||
num_dispatchers: int,
|
||||
block_shape: list[int],
|
||||
per_act_token_quant=False):
|
||||
def __init__(
|
||||
self,
|
||||
max_num_tokens: int,
|
||||
num_dispatchers: int,
|
||||
quant_config: FusedMoEQuantConfig,
|
||||
):
|
||||
"""
|
||||
max_num_tokens: Maximum number of tokens from a DP Rank
|
||||
num_dispatchers: The number of DP dispatchers.
|
||||
block_shape: Block quantization block shape.
|
||||
per_act_token_quant: Per activation token quantization flag.
|
||||
quant_config: Quantization configuration
|
||||
"""
|
||||
super().__init__(
|
||||
FusedMoEQuantConfig(
|
||||
quant_dtype=torch.float8_e4m3fn,
|
||||
per_act_token_quant=per_act_token_quant,
|
||||
block_shape=block_shape,
|
||||
))
|
||||
assert self.block_shape == self.DEEPGEMM_BLOCK_SHAPE
|
||||
super().__init__(quant_config)
|
||||
assert self.block_shape == deep_gemm_block_shape()
|
||||
self.max_num_tokens = max_num_tokens
|
||||
self.num_dispatchers = num_dispatchers
|
||||
|
||||
@ -290,12 +285,7 @@ class BatchedDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
|
||||
activation: str,
|
||||
global_num_experts: int,
|
||||
expert_map: Optional[torch.Tensor],
|
||||
w1_scale: Optional[torch.Tensor],
|
||||
w2_scale: Optional[torch.Tensor],
|
||||
w1_zp: Optional[torch.Tensor],
|
||||
w2_zp: Optional[torch.Tensor],
|
||||
a1q_scale: Optional[torch.Tensor],
|
||||
a2_scale: Optional[torch.Tensor],
|
||||
workspace13: torch.Tensor,
|
||||
workspace2: torch.Tensor,
|
||||
expert_tokens_meta: Optional[mk.ExpertTokensMetadata],
|
||||
@ -321,11 +311,11 @@ class BatchedDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
|
||||
# for the M expectation of each batch, correctly setting this value
|
||||
# may lead to better performance.
|
||||
expected_m = max_num_tokens
|
||||
fp8_m_grouped_gemm_nt_masked((a1q, a1q_scale), (w1, w1_scale),
|
||||
fp8_m_grouped_gemm_nt_masked((a1q, a1q_scale), (w1, self.w1_scale),
|
||||
workspace1, expert_num_tokens, expected_m)
|
||||
|
||||
a2q, a2q_scale = silu_mul_fp8_quant_deep_gemm_cuda(
|
||||
workspace1, expert_num_tokens)
|
||||
|
||||
fp8_m_grouped_gemm_nt_masked((a2q, a2q_scale), (w2, w2_scale), output,
|
||||
expert_num_tokens, expected_m)
|
||||
fp8_m_grouped_gemm_nt_masked((a2q, a2q_scale), (w2, self.w2_scale),
|
||||
output, expert_num_tokens, expected_m)
|
||||
|
||||
@ -8,55 +8,37 @@ import vllm.model_executor.layers.fused_moe.modular_kernel as mk
|
||||
from vllm.model_executor.layers.fused_moe.batched_deep_gemm_moe import (
|
||||
BatchedDeepGemmExperts)
|
||||
from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig
|
||||
from vllm.model_executor.layers.fused_moe.deep_gemm_utils import (
|
||||
deep_gemm_block_shape)
|
||||
from vllm.model_executor.layers.fused_moe.fused_batched_moe import (
|
||||
BatchedTritonExperts)
|
||||
|
||||
|
||||
class BatchedTritonOrDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
|
||||
|
||||
def __init__(self,
|
||||
max_num_tokens: int,
|
||||
num_dispatchers: int,
|
||||
use_fp8_w8a8: bool = False,
|
||||
use_int8_w8a8: bool = False,
|
||||
use_int8_w8a16: bool = False,
|
||||
use_int4_w4a16: bool = False,
|
||||
block_shape: Optional[list[int]] = None,
|
||||
per_act_token_quant: bool = False,
|
||||
allow_deep_gemm: bool = False):
|
||||
assert not use_int8_w8a8, "NYI"
|
||||
assert not use_int8_w8a16, "NYI"
|
||||
assert not use_int4_w4a16, "NYI"
|
||||
|
||||
super().__init__(
|
||||
FusedMoEQuantConfig.make(
|
||||
use_fp8_w8a8=use_fp8_w8a8,
|
||||
use_int8_w8a8=use_int8_w8a8,
|
||||
use_int8_w8a16=use_int8_w8a16,
|
||||
use_int4_w4a16=use_int4_w4a16,
|
||||
block_shape=block_shape,
|
||||
per_act_token_quant=per_act_token_quant,
|
||||
))
|
||||
def __init__(
|
||||
self,
|
||||
max_num_tokens: int,
|
||||
num_dispatchers: int,
|
||||
quant_config: FusedMoEQuantConfig,
|
||||
allow_deep_gemm: bool = False,
|
||||
):
|
||||
super().__init__(quant_config)
|
||||
|
||||
self.batched_triton_experts = BatchedTritonExperts(
|
||||
max_num_tokens=max_num_tokens,
|
||||
num_dispatchers=num_dispatchers,
|
||||
use_fp8_w8a8=use_fp8_w8a8,
|
||||
use_int8_w8a8=use_int8_w8a8,
|
||||
use_int8_w8a16=use_int8_w8a16,
|
||||
use_int4_w4a16=use_int4_w4a16,
|
||||
per_act_token_quant=self.per_act_token_quant,
|
||||
block_shape=self.block_shape,
|
||||
quant_config=self.quant_config,
|
||||
)
|
||||
|
||||
self.allow_deep_gemm = (allow_deep_gemm and use_fp8_w8a8
|
||||
and self.block_shape
|
||||
== BatchedDeepGemmExperts.DEEPGEMM_BLOCK_SHAPE)
|
||||
self.allow_deep_gemm = (allow_deep_gemm
|
||||
and self.quant_config.use_fp8_w8a8 and
|
||||
self.block_shape == deep_gemm_block_shape())
|
||||
|
||||
self.batched_deep_gemm_experts = BatchedDeepGemmExperts(
|
||||
max_num_tokens=max_num_tokens,
|
||||
num_dispatchers=num_dispatchers,
|
||||
block_shape=self.block_shape, # type: ignore[arg-type]
|
||||
quant_config=self.quant_config,
|
||||
) if self.allow_deep_gemm else None
|
||||
|
||||
assert (self.batched_deep_gemm_experts is not None
|
||||
@ -143,12 +125,7 @@ class BatchedTritonOrDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
|
||||
activation: str,
|
||||
global_num_experts: int,
|
||||
expert_map: Optional[torch.Tensor],
|
||||
w1_scale: Optional[torch.Tensor],
|
||||
w2_scale: Optional[torch.Tensor],
|
||||
w1_zp: Optional[torch.Tensor],
|
||||
w2_zp: Optional[torch.Tensor],
|
||||
a1q_scale: Optional[torch.Tensor],
|
||||
a2_scale: Optional[torch.Tensor],
|
||||
workspace13: torch.Tensor,
|
||||
workspace2: torch.Tensor,
|
||||
expert_tokens_meta: Optional[mk.ExpertTokensMetadata],
|
||||
@ -158,7 +135,6 @@ class BatchedTritonOrDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
|
||||
if self.allow_deep_gemm else self.batched_triton_experts)
|
||||
assert experts is not None
|
||||
experts.apply(output, hidden_states, w1, w2, topk_weights, topk_ids,
|
||||
activation, global_num_experts, expert_map, w1_scale,
|
||||
w2_scale, w1_zp, w2_zp, a1q_scale, a2_scale, workspace13,
|
||||
workspace2, expert_tokens_meta,
|
||||
activation, global_num_experts, expert_map, a1q_scale,
|
||||
workspace13, workspace2, expert_tokens_meta,
|
||||
apply_router_weight_on_input)
|
||||
|
||||
@ -1,103 +1,322 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional, Union
|
||||
from typing import TYPE_CHECKING, Optional, Union
|
||||
|
||||
import torch
|
||||
from compressed_tensors.quantization import (QuantizationArgs,
|
||||
QuantizationStrategy,
|
||||
QuantizationType)
|
||||
|
||||
import vllm.envs as envs
|
||||
from vllm.config import ParallelConfig
|
||||
from vllm.distributed import get_dp_group, get_tensor_model_parallel_rank
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.quantization.base_config import (
|
||||
QuantizationConfig)
|
||||
from vllm.utils import cdiv
|
||||
from vllm.model_executor.layers.quantization.utils.quant_utils import (
|
||||
GroupShape)
|
||||
from vllm.utils import cdiv, has_triton_kernels
|
||||
from vllm.utils.flashinfer import has_flashinfer_cutlass_fused_moe
|
||||
|
||||
if TYPE_CHECKING and has_triton_kernels:
|
||||
from triton_kernels.matmul_ogs import PrecisionConfig
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
def _get_quant_config_quantization_args(
|
||||
quant_config: Optional[QuantizationConfig],
|
||||
prop_name: str,
|
||||
) -> Optional[QuantizationArgs]:
|
||||
if (quant_config is not None and hasattr(quant_config, 'target_scheme_map')
|
||||
and "Linear" in quant_config.target_scheme_map and
|
||||
"input_activations" in quant_config.target_scheme_map["Linear"]):
|
||||
return quant_config.target_scheme_map["Linear"].get(prop_name)
|
||||
else:
|
||||
return None
|
||||
|
||||
|
||||
def get_quant_config_input_quant(
|
||||
quant_config: Optional[QuantizationConfig]
|
||||
) -> Optional[QuantizationArgs]:
|
||||
return _get_quant_config_quantization_args(quant_config,
|
||||
"input_activations")
|
||||
|
||||
|
||||
def get_quant_config_weight_quant(
|
||||
quant_config: Optional[QuantizationConfig]
|
||||
) -> Optional[QuantizationArgs]:
|
||||
return _get_quant_config_quantization_args(quant_config, "weights")
|
||||
|
||||
|
||||
def get_config_quant_dtype(
|
||||
use_fp8_w8a8: bool,
|
||||
use_int8_w8a8: bool,
|
||||
use_int8_w8a16: bool,
|
||||
use_int4_w4a16: bool,
|
||||
use_mxfp4_w4a4: bool,
|
||||
) -> Union[None, torch.dtype, str]:
|
||||
def _get_config_dtype_str(
|
||||
dtype: torch.dtype,
|
||||
use_fp8_w8a8: bool = False,
|
||||
use_int8_w8a16: bool = False,
|
||||
use_int4_w4a16: bool = False,
|
||||
use_mxfp4_w4a4: bool = False,
|
||||
) -> Optional[str]:
|
||||
"""
|
||||
Return a string used to construct the filename that contains the
|
||||
tuning info for a particular quantization scheme. See
|
||||
try_get_optimal_moe_config in fused_moe.py.
|
||||
"""
|
||||
if use_fp8_w8a8:
|
||||
return torch.float8_e4m3fn
|
||||
elif use_int8_w8a8:
|
||||
return torch.int8
|
||||
return "fp8_w8a8"
|
||||
elif use_int8_w8a16:
|
||||
return "int8_w8a16"
|
||||
elif use_int4_w4a16:
|
||||
return "int4_w4a16"
|
||||
elif use_mxfp4_w4a4:
|
||||
return "mxfp4"
|
||||
return "mxfp4_w4a4"
|
||||
elif dtype == torch.float:
|
||||
# avoiding cases where kernel fails when float32 MoE
|
||||
# use fp16/bfloat16 configs
|
||||
return "float32"
|
||||
return None
|
||||
|
||||
|
||||
def _quant_flags_to_group_shape(
|
||||
quant_dtype: Union[torch.dtype, str, None],
|
||||
per_act_token_quant: bool,
|
||||
per_out_ch_quant: bool,
|
||||
block_shape: Optional[list[int]],
|
||||
) -> tuple[Optional[GroupShape], Optional[GroupShape]]:
|
||||
"""
|
||||
Convert MoE quantization flags into more generic GroupShapes.
|
||||
"""
|
||||
a_shape: Optional[GroupShape]
|
||||
w_shape: Optional[GroupShape]
|
||||
if block_shape is not None:
|
||||
assert not per_act_token_quant
|
||||
assert not per_out_ch_quant
|
||||
# TODO(bnell): this is not quite right for activations since first
|
||||
# dim should be 1.
|
||||
a_shape = GroupShape(row=block_shape[0], col=block_shape[1])
|
||||
w_shape = GroupShape(row=block_shape[0], col=block_shape[1])
|
||||
else:
|
||||
w_shape = None
|
||||
a_shape = None if quant_dtype is None else GroupShape.PER_TENSOR
|
||||
|
||||
if per_act_token_quant:
|
||||
a_shape = GroupShape.PER_TOKEN
|
||||
|
||||
if per_out_ch_quant:
|
||||
w_shape = GroupShape.PER_TOKEN
|
||||
|
||||
return a_shape, w_shape
|
||||
|
||||
|
||||
@dataclass
|
||||
class FusedMoEQuantDesc:
|
||||
"""
|
||||
A quantization descriptor for fused MoE ops. This class can describe
|
||||
either activations or weights.
|
||||
"""
|
||||
|
||||
# The quantized type of this parameters. None means unquantized or
|
||||
# already quantized.
|
||||
# TODO (bnell): use scalar_type instead of Union.
|
||||
dtype: Union[torch.dtype, str, None] = None
|
||||
|
||||
# A field that describes the quantization group shape, from quant_utils.py.
|
||||
# * (-1, -1) for per-tensor quantization
|
||||
# * (1, -1) for per-row quantization
|
||||
# * (-1, 1) for per-column quantization
|
||||
# * (128, 128) for 128x128 deepseek style block quantization
|
||||
# * (1, 128) for deepseek style activation quantization
|
||||
# (i.e. per-token-per-group)
|
||||
shape: Optional[GroupShape] = None
|
||||
|
||||
# Quantization scales.
|
||||
# TODO(bnell): maybe put PrecisionConfigs in subclass of QuantDesc?
|
||||
scale: Union[torch.Tensor, "PrecisionConfig", None] = None
|
||||
|
||||
# Quantization alphas or gscales, used for nvfp4 types.
|
||||
# TODO(bnell): put some of these in subclasses
|
||||
alpha_or_gscale: Optional[torch.Tensor] = None
|
||||
|
||||
# Zero points for int4/int8 types
|
||||
zp: Optional[torch.Tensor] = None
|
||||
|
||||
# Biases for GPT triton MoE
|
||||
bias: Optional[torch.Tensor] = None
|
||||
|
||||
|
||||
# TODO(bnell): have subclasses for specific moe methods?
|
||||
# e.g. for specific arguments bias, precision, etc.
|
||||
@dataclass
|
||||
class FusedMoEQuantConfig:
|
||||
# The post quantization activation type.
|
||||
# TODO (bnell): use scalar_type instead of Union.
|
||||
quant_dtype: Union[torch.dtype, str, None] = None
|
||||
per_act_token_quant: bool = False
|
||||
per_out_ch_quant: bool = False
|
||||
block_shape: Optional[list[int]] = None
|
||||
"""
|
||||
The FusedMoEQuantConfig contains all the quantization parameters for
|
||||
a single FusedMoEMethodBase operation. It consists of four
|
||||
FusedMoEQuantDescs, one for each activation and set of weights.
|
||||
|
||||
# TODO: add col major flag?
|
||||
# add detailed quant info for input, intermediates, weights, etc?
|
||||
Each FusedMoEMethodBase must implement a get_fused_moe_quant_config
|
||||
method to construct a FusedMoEQuantConfig for use with that class.
|
||||
|
||||
FusedMoEQuant configs are only used for modular kernels, fused_experts
|
||||
(from fused_moe.py), cutlass_moe_fp[48], rocm_aiter_fused_experts and
|
||||
triton_kernel_moe_forward. Other MoE methods can ignore the
|
||||
FusedMoEQuantConfig (for now) and hardcode it to None.
|
||||
|
||||
There are currently some restrictions on what can be expressed:
|
||||
- Most MoE ops only support similar quantization strategies for
|
||||
each parameter, e.g. both weights must have the same GroupShape
|
||||
and both activations must share the same GroupShape. One exception to
|
||||
this is the cutlass moe which allows per channel quantization on the
|
||||
outputs. Note: this restrictions are not always rigorously checked.
|
||||
- Not all fused MoE functions support all the parameters, e.g. zero points,
|
||||
global scales, alphas and biases are not universally supported.
|
||||
- Fully general GroupShapes are not allowed. Activations only support
|
||||
per token, per tensor or K-blocked.
|
||||
- Weights are not required to have a GroupShape since they have already
|
||||
been quantized.
|
||||
|
||||
Other notes:
|
||||
- PrecisionConfigs are specific to GPT OSS Triton.
|
||||
- As a follow up it would probably make sense to subclass FusedMoEQuantDesc
|
||||
or FusedMoEQuantConfig for particular FusedMoEMethodBase subclasses
|
||||
so that only the required quantization parameters are used/stored.
|
||||
"""
|
||||
|
||||
# TODO(bnell) make sure a1_scales/a2_scales don't interfere with chunking
|
||||
_a1: FusedMoEQuantDesc
|
||||
_a2: FusedMoEQuantDesc
|
||||
_w1: FusedMoEQuantDesc
|
||||
_w2: FusedMoEQuantDesc
|
||||
|
||||
def __post_init__(self):
|
||||
assert (not self.per_act_token_quant
|
||||
or self.block_shape is None), "illegal quantization"
|
||||
|
||||
#
|
||||
# Convenience accessors for various properties.
|
||||
#
|
||||
|
||||
@property
|
||||
def quant_dtype(self) -> Union[torch.dtype, str, None]:
|
||||
return self._a1.dtype
|
||||
|
||||
@property
|
||||
def is_quantized(self) -> bool:
|
||||
return self.quant_dtype is not None
|
||||
|
||||
@property
|
||||
def is_per_act_token(self) -> bool:
|
||||
return self.per_act_token_quant
|
||||
return self._a1.shape == GroupShape.PER_TOKEN
|
||||
|
||||
@property
|
||||
def per_act_token_quant(self) -> bool:
|
||||
return self._a1.shape == GroupShape.PER_TOKEN
|
||||
|
||||
@property
|
||||
def per_out_ch_quant(self) -> bool:
|
||||
return self._w1.shape == GroupShape.PER_TOKEN
|
||||
|
||||
@property
|
||||
def is_per_tensor(self) -> bool:
|
||||
return self._a1.shape == GroupShape.PER_TENSOR
|
||||
|
||||
@property
|
||||
def block_shape(self) -> Optional[list[int]]:
|
||||
if (self._a1.shape is not None
|
||||
and self._a1.shape != GroupShape.PER_TENSOR
|
||||
and self._a1.shape != GroupShape.PER_TOKEN):
|
||||
return [self._a1.shape.row, self._a1.shape.col]
|
||||
else:
|
||||
return None
|
||||
|
||||
@property
|
||||
def is_block_quantized(self) -> bool:
|
||||
return self.block_shape is not None
|
||||
|
||||
@property
|
||||
def is_per_tensor(self) -> bool:
|
||||
return not self.per_act_token_quant and self.block_shape is None
|
||||
def a1_scale(self) -> Optional[torch.Tensor]:
|
||||
assert self._a1.scale is None or isinstance(self._a1.scale,
|
||||
torch.Tensor)
|
||||
return self._a1.scale
|
||||
|
||||
@property
|
||||
def a1_gscale(self) -> Optional[torch.Tensor]:
|
||||
return self._a1.alpha_or_gscale
|
||||
|
||||
@property
|
||||
def a2_scale(self) -> Optional[torch.Tensor]:
|
||||
assert self._a2.scale is None or isinstance(self._a2.scale,
|
||||
torch.Tensor)
|
||||
return self._a2.scale
|
||||
|
||||
@property
|
||||
def a2_gscale(self) -> Optional[torch.Tensor]:
|
||||
return self._a2.alpha_or_gscale
|
||||
|
||||
@property
|
||||
def w1_scale(self) -> Optional[torch.Tensor]:
|
||||
assert self._w1.scale is None or isinstance(self._w1.scale,
|
||||
torch.Tensor)
|
||||
return self._w1.scale
|
||||
|
||||
@property
|
||||
def w1_zp(self) -> Optional[torch.Tensor]:
|
||||
return self._w1.zp
|
||||
|
||||
@property
|
||||
def w1_bias(self) -> Optional[torch.Tensor]:
|
||||
return self._w1.bias
|
||||
|
||||
@property
|
||||
def w1_precision(self) -> Optional["PrecisionConfig"]:
|
||||
assert self._w1.scale is None or isinstance(self._w1.scale,
|
||||
PrecisionConfig)
|
||||
return self._w1.scale
|
||||
|
||||
@property
|
||||
def g1_alphas(self) -> Optional[torch.Tensor]:
|
||||
return self._w1.alpha_or_gscale
|
||||
|
||||
@property
|
||||
def w2_scale(self) -> Optional[torch.Tensor]:
|
||||
assert self._w2.scale is None or isinstance(self._w2.scale,
|
||||
torch.Tensor)
|
||||
return self._w2.scale
|
||||
|
||||
@property
|
||||
def w2_zp(self) -> Optional[torch.Tensor]:
|
||||
return self._w2.zp
|
||||
|
||||
@property
|
||||
def w2_bias(self) -> Optional[torch.Tensor]:
|
||||
return self._w2.bias
|
||||
|
||||
@property
|
||||
def w2_precision(self) -> Optional["PrecisionConfig"]:
|
||||
assert self._w2.scale is None or isinstance(self._w2.scale,
|
||||
PrecisionConfig)
|
||||
return self._w2.scale
|
||||
|
||||
@property
|
||||
def g2_alphas(self) -> Optional[torch.Tensor]:
|
||||
return self._w2.alpha_or_gscale
|
||||
|
||||
@property
|
||||
def use_fp8_w8a8(self) -> bool:
|
||||
return self.quant_dtype == torch.float8_e4m3fn
|
||||
|
||||
@property
|
||||
def use_int8_w8a8(self) -> bool:
|
||||
return self.quant_dtype == torch.int8
|
||||
|
||||
@property
|
||||
def use_int8_w8a16(self) -> bool:
|
||||
return (self._a1.dtype is None and self._w1.dtype == torch.int8)
|
||||
|
||||
@property
|
||||
def use_int4_w4a16(self) -> bool:
|
||||
return (self._a1.dtype is None and self._w1.dtype == "int4")
|
||||
|
||||
@property
|
||||
def use_mxfp4_w4a4(self) -> bool:
|
||||
return self.quant_dtype == "mxfp4"
|
||||
|
||||
@property
|
||||
def use_nvfp4_w4a4(self) -> bool:
|
||||
return self.quant_dtype == "nvfp4"
|
||||
|
||||
def config_name(self, dtype: torch.dtype) -> Optional[str]:
|
||||
"""
|
||||
Return a string used to construct the filename that contains the
|
||||
tuning info for a particular quantization scheme. See
|
||||
try_get_optimal_moe_config in fused_moe.py.
|
||||
"""
|
||||
return _get_config_dtype_str(
|
||||
use_fp8_w8a8=self.use_fp8_w8a8,
|
||||
use_int8_w8a16=self.use_int8_w8a16,
|
||||
use_int4_w4a16=self.use_int4_w4a16,
|
||||
use_mxfp4_w4a4=self.use_mxfp4_w4a4,
|
||||
dtype=dtype,
|
||||
)
|
||||
|
||||
def scale_shape(
|
||||
self,
|
||||
max_tokens: int,
|
||||
hidden_dim: int,
|
||||
) -> Optional[tuple[int, int]]:
|
||||
"""
|
||||
Construct the proper activation scale shape for this
|
||||
config.
|
||||
"""
|
||||
if self.is_quantized:
|
||||
if self.is_block_quantized:
|
||||
assert self.block_shape is not None
|
||||
@ -117,6 +336,10 @@ class FusedMoEQuantConfig:
|
||||
max_tokens: int,
|
||||
hidden_dim: int,
|
||||
) -> Optional[tuple[int, int, int]]:
|
||||
"""
|
||||
Construct the proper activation batched scale shape for this
|
||||
config, e.g. (num experts, *scale_shape).
|
||||
"""
|
||||
if self.is_quantized:
|
||||
scale_shape = self.scale_shape(max_tokens, hidden_dim)
|
||||
assert scale_shape is not None
|
||||
@ -126,38 +349,218 @@ class FusedMoEQuantConfig:
|
||||
|
||||
@staticmethod
|
||||
def make(
|
||||
use_fp8_w8a8: bool = False,
|
||||
use_int8_w8a8: bool = False,
|
||||
use_int8_w8a16: bool = False,
|
||||
use_int4_w4a16: bool = False,
|
||||
use_mxfp4_w4a4: bool = False,
|
||||
quant_dtype: Union[torch.dtype, str, None] = None,
|
||||
per_act_token_quant: bool = False,
|
||||
per_out_ch_quant: bool = False,
|
||||
block_shape: Optional[list[int]] = None,
|
||||
w1_scale: Union[torch.Tensor, "PrecisionConfig", None] = None,
|
||||
w2_scale: Union[torch.Tensor, "PrecisionConfig", None] = None,
|
||||
a1_scale: Optional[torch.Tensor] = None,
|
||||
a2_scale: Optional[torch.Tensor] = None,
|
||||
g1_alphas: Optional[torch.Tensor] = None,
|
||||
g2_alphas: Optional[torch.Tensor] = None,
|
||||
a1_gscale: Optional[torch.Tensor] = None,
|
||||
a2_gscale: Optional[torch.Tensor] = None,
|
||||
w1_bias: Optional[torch.Tensor] = None,
|
||||
w2_bias: Optional[torch.Tensor] = None,
|
||||
w1_zp: Optional[torch.Tensor] = None,
|
||||
w2_zp: Optional[torch.Tensor] = None,
|
||||
) -> "FusedMoEQuantConfig":
|
||||
assert sum([
|
||||
int(flag) for flag in [
|
||||
use_fp8_w8a8,
|
||||
use_int8_w8a8,
|
||||
use_int8_w8a16,
|
||||
use_int4_w4a16,
|
||||
use_mxfp4_w4a4,
|
||||
]
|
||||
]) <= 1, "Quantization flags are mutually exclusive."
|
||||
"""
|
||||
General builder function for a FusedMoEQuantConfig.
|
||||
- quant_dtype: Optional quantization type. None if activations are
|
||||
unquantized or quantized prior to calling. Note: "nvfp4" and
|
||||
"mxfp4" are the only valid string values for quant_dtype.
|
||||
- per_act_token_quant: Activations have per token quantization.
|
||||
- per_out_ch_quant: Outputs have per channel quantization. (only
|
||||
for cutlass).
|
||||
- block_shape: Optional block size for block-wise quantization.
|
||||
Incompatible with per_act_token and per_out_ch quant.
|
||||
- w1_scale: Optional scale to be used for w1.
|
||||
- w2_scale: Optional scale to be used for w2.
|
||||
- a1_scale: Optional scale to be used for a1.
|
||||
- a2_scale: Optional scale to be used for a2.
|
||||
- g1_alphas: Optional global quantization scales for w1 (for nvfp4).
|
||||
- g2_alphas: Optional global quantization scales for w2 (for nvfp4).
|
||||
- a1_gscale: Optional global quantization scales for a1 (for nvfp4).
|
||||
- a2_gscale: Optional global quantization scales for a2 (for nvfp4).
|
||||
- w1_bias: Optional biases for w1 (GPT OSS Triton).
|
||||
- w2_bias: Optional biases for w1 (GPT OSS Triton).
|
||||
- w1_zp: Optional w1 zero points for int4/int8 quantization.
|
||||
- w2_zp: Optional w2 zero points for int4/int8 quantization.
|
||||
"""
|
||||
assert (not isinstance(quant_dtype, str) or quant_dtype == "nvfp4"
|
||||
or quant_dtype == "mxfp4")
|
||||
a_shape, w_shape = _quant_flags_to_group_shape(quant_dtype,
|
||||
per_act_token_quant,
|
||||
per_out_ch_quant,
|
||||
block_shape)
|
||||
quant_config = FusedMoEQuantConfig(
|
||||
_a1=FusedMoEQuantDesc(quant_dtype, a_shape, a1_scale, a1_gscale),
|
||||
_a2=FusedMoEQuantDesc(quant_dtype, a_shape, a2_scale, a2_gscale),
|
||||
_w1=FusedMoEQuantDesc(quant_dtype, w_shape, w1_scale, g1_alphas,
|
||||
w1_zp, w1_bias),
|
||||
_w2=FusedMoEQuantDesc(quant_dtype, w_shape, w2_scale, g2_alphas,
|
||||
w2_zp, w2_bias),
|
||||
)
|
||||
assert quant_config.per_act_token_quant == per_act_token_quant
|
||||
assert quant_config.per_out_ch_quant == per_out_ch_quant
|
||||
assert quant_config.block_shape == block_shape
|
||||
return quant_config
|
||||
|
||||
quant_dtype = get_config_quant_dtype(
|
||||
use_fp8_w8a8=use_fp8_w8a8,
|
||||
use_int8_w8a8=use_int8_w8a8,
|
||||
use_int8_w8a16=use_int8_w8a16,
|
||||
use_int4_w4a16=use_int4_w4a16,
|
||||
use_mxfp4_w4a4=use_mxfp4_w4a4,
|
||||
)
|
||||
return FusedMoEQuantConfig(
|
||||
quant_dtype,
|
||||
per_act_token_quant,
|
||||
per_out_ch_quant,
|
||||
block_shape,
|
||||
)
|
||||
|
||||
def fp8_w8a8_moe_quant_config(
|
||||
w1_scale: torch.Tensor,
|
||||
w2_scale: torch.Tensor,
|
||||
a1_scale: Optional[torch.Tensor] = None,
|
||||
a2_scale: Optional[torch.Tensor] = None,
|
||||
per_act_token_quant: bool = False,
|
||||
per_out_ch_quant: bool = False,
|
||||
block_shape: Optional[list[int]] = None,
|
||||
) -> FusedMoEQuantConfig:
|
||||
"""
|
||||
Construct a quant config for fp8 activations and fp8 weights.
|
||||
"""
|
||||
return FusedMoEQuantConfig.make(torch.float8_e4m3fn,
|
||||
w1_scale=w1_scale,
|
||||
w2_scale=w2_scale,
|
||||
a1_scale=a1_scale,
|
||||
a2_scale=a2_scale,
|
||||
per_act_token_quant=per_act_token_quant,
|
||||
per_out_ch_quant=per_out_ch_quant,
|
||||
block_shape=block_shape)
|
||||
|
||||
|
||||
def int8_w8a8_moe_quant_config(
|
||||
w1_scale: torch.Tensor,
|
||||
w2_scale: torch.Tensor,
|
||||
a1_scale: Optional[torch.Tensor],
|
||||
a2_scale: Optional[torch.Tensor],
|
||||
per_act_token_quant: bool = False,
|
||||
) -> FusedMoEQuantConfig:
|
||||
"""
|
||||
Construct a quant config for int8 activations and int8 weights.
|
||||
"""
|
||||
return FusedMoEQuantConfig.make(
|
||||
torch.int8,
|
||||
w1_scale=w1_scale,
|
||||
w2_scale=w2_scale,
|
||||
a1_scale=a1_scale,
|
||||
a2_scale=a2_scale,
|
||||
per_act_token_quant=per_act_token_quant,
|
||||
per_out_ch_quant=False,
|
||||
block_shape=None,
|
||||
)
|
||||
|
||||
|
||||
def mxfp4_w4a4_moe_quant_config(
|
||||
w1_scale: Union[torch.Tensor, "PrecisionConfig"],
|
||||
w2_scale: Union[torch.Tensor, "PrecisionConfig"],
|
||||
a1_scale: Optional[torch.Tensor] = None,
|
||||
a2_scale: Optional[torch.Tensor] = None,
|
||||
w1_bias: Optional[torch.Tensor] = None,
|
||||
w2_bias: Optional[torch.Tensor] = None,
|
||||
block_shape: Optional[list[int]] = None,
|
||||
) -> FusedMoEQuantConfig:
|
||||
"""
|
||||
Construct a quant config for mxfp4 activations and mxfp4 weights.
|
||||
"""
|
||||
return FusedMoEQuantConfig.make(
|
||||
"mxfp4",
|
||||
w1_scale=w1_scale,
|
||||
w2_scale=w2_scale,
|
||||
a1_scale=a1_scale,
|
||||
a2_scale=a2_scale,
|
||||
w1_bias=w1_bias,
|
||||
w2_bias=w2_bias,
|
||||
per_act_token_quant=False,
|
||||
per_out_ch_quant=False,
|
||||
block_shape=block_shape,
|
||||
)
|
||||
|
||||
|
||||
def nvfp4_moe_quant_config(
|
||||
g1_alphas: torch.Tensor,
|
||||
g2_alphas: torch.Tensor,
|
||||
a1_gscale: torch.Tensor,
|
||||
a2_gscale: torch.Tensor,
|
||||
w1_scale: torch.Tensor,
|
||||
w2_scale: torch.Tensor,
|
||||
) -> FusedMoEQuantConfig:
|
||||
"""
|
||||
Construct a quant config for mxfp4 activations and nvp4 weights.
|
||||
"""
|
||||
return FusedMoEQuantConfig.make(
|
||||
"nvfp4",
|
||||
w1_scale=w1_scale,
|
||||
w2_scale=w2_scale,
|
||||
a1_gscale=a1_gscale,
|
||||
a2_gscale=a2_gscale,
|
||||
g1_alphas=g1_alphas,
|
||||
g2_alphas=g2_alphas,
|
||||
per_act_token_quant=False,
|
||||
per_out_ch_quant=False,
|
||||
block_shape=None,
|
||||
)
|
||||
|
||||
|
||||
def int4_w4a16_moe_quant_config(
|
||||
w1_scale: torch.Tensor,
|
||||
w2_scale: torch.Tensor,
|
||||
w1_zp: Optional[torch.Tensor],
|
||||
w2_zp: Optional[torch.Tensor],
|
||||
block_shape: Optional[list[int]] = None,
|
||||
) -> FusedMoEQuantConfig:
|
||||
"""
|
||||
Construct a quant config for 16-bit float activations and int4 weights.
|
||||
Note: Activations are pre-quantized.
|
||||
"""
|
||||
group_shape = GroupShape(*block_shape) if block_shape is not None else None
|
||||
return FusedMoEQuantConfig(
|
||||
_a1=FusedMoEQuantDesc(shape=group_shape),
|
||||
_a2=FusedMoEQuantDesc(shape=group_shape),
|
||||
_w1=FusedMoEQuantDesc("int4", group_shape, w1_scale, None, w1_zp),
|
||||
_w2=FusedMoEQuantDesc("int4", group_shape, w2_scale, None, w2_zp),
|
||||
)
|
||||
|
||||
|
||||
def int8_w8a16_moe_quant_config(
|
||||
w1_scale: torch.Tensor,
|
||||
w2_scale: torch.Tensor,
|
||||
w1_zp: Optional[torch.Tensor],
|
||||
w2_zp: Optional[torch.Tensor],
|
||||
block_shape: Optional[list[int]] = None,
|
||||
) -> FusedMoEQuantConfig:
|
||||
"""
|
||||
Construct a quant config for 16-bit float activations and int8 weights.
|
||||
Note: Activations are pre-quantized.
|
||||
"""
|
||||
group_shape = GroupShape(*block_shape) if block_shape is not None else None
|
||||
return FusedMoEQuantConfig(
|
||||
_a1=FusedMoEQuantDesc(shape=group_shape),
|
||||
_a2=FusedMoEQuantDesc(shape=group_shape),
|
||||
_w1=FusedMoEQuantDesc(torch.int8, group_shape, w1_scale, None, w1_zp),
|
||||
_w2=FusedMoEQuantDesc(torch.int8, group_shape, w2_scale, None, w2_zp),
|
||||
)
|
||||
|
||||
|
||||
def biased_moe_quant_config(
|
||||
w1_bias: Optional[torch.Tensor],
|
||||
w2_bias: Optional[torch.Tensor],
|
||||
) -> FusedMoEQuantConfig:
|
||||
"""
|
||||
Construct a quant config for unquantized activations with biases.
|
||||
"""
|
||||
return FusedMoEQuantConfig(
|
||||
_a1=FusedMoEQuantDesc(),
|
||||
_a2=FusedMoEQuantDesc(),
|
||||
_w1=FusedMoEQuantDesc(bias=w1_bias),
|
||||
_w2=FusedMoEQuantDesc(bias=w2_bias),
|
||||
)
|
||||
|
||||
|
||||
# A FusedMoEQuantConfig constant for an unquantized MoE op.
|
||||
FUSED_MOE_UNQUANTIZED_CONFIG: FusedMoEQuantConfig = FusedMoEQuantConfig.make()
|
||||
|
||||
|
||||
@dataclass
|
||||
@ -315,8 +718,6 @@ class FusedMoEConfig:
|
||||
# The activation type.
|
||||
in_dtype: torch.dtype
|
||||
|
||||
quant_config: Optional[FusedMoEQuantConfig] = None
|
||||
|
||||
max_num_tokens: int = envs.VLLM_MOE_DP_CHUNK_SIZE
|
||||
|
||||
has_bias: bool = False
|
||||
@ -328,34 +729,6 @@ class FusedMoEConfig:
|
||||
|
||||
assert self.max_num_tokens > 0
|
||||
|
||||
@property
|
||||
def quant_dtype(self) -> Union[torch.dtype, str, None]:
|
||||
if self.quant_config is not None:
|
||||
return self.quant_config.quant_dtype
|
||||
else:
|
||||
return None
|
||||
|
||||
@property
|
||||
def block_shape(self) -> Optional[list[int]]:
|
||||
if self.quant_config is not None:
|
||||
return self.quant_config.block_shape
|
||||
else:
|
||||
return None
|
||||
|
||||
@property
|
||||
def per_act_token_quant(self) -> bool:
|
||||
if self.quant_config is not None:
|
||||
return self.quant_config.per_act_token_quant
|
||||
else:
|
||||
return False
|
||||
|
||||
@property
|
||||
def per_out_ch_quant(self) -> bool:
|
||||
if self.quant_config is not None:
|
||||
return self.quant_config.per_out_ch_quant
|
||||
else:
|
||||
return False
|
||||
|
||||
@property
|
||||
def tp_size(self):
|
||||
return self.moe_parallel_config.tp_size
|
||||
@ -401,97 +774,6 @@ class FusedMoEConfig:
|
||||
"""
|
||||
Whether to use FlashInfer cutlass kernels for NVFP4 MoE.
|
||||
"""
|
||||
return (self.quant_config is not None
|
||||
and self.quant_config.quant_dtype == "nvfp4"
|
||||
and envs.VLLM_USE_FLASHINFER_MOE_FP4
|
||||
return (envs.VLLM_USE_FLASHINFER_MOE_FP4
|
||||
and has_flashinfer_cutlass_fused_moe()
|
||||
and envs.VLLM_FLASHINFER_MOE_BACKEND == "throughput")
|
||||
|
||||
@staticmethod
|
||||
def make(
|
||||
num_experts: int,
|
||||
experts_per_token: int,
|
||||
hidden_dim: int,
|
||||
num_local_experts: int,
|
||||
moe_parallel_config: FusedMoEParallelConfig,
|
||||
in_dtype: torch.dtype,
|
||||
max_num_tokens: int = envs.VLLM_MOE_DP_CHUNK_SIZE,
|
||||
quant_config: Optional[Union[FusedMoEQuantConfig,
|
||||
QuantizationConfig]] = None,
|
||||
has_bias: bool = False,
|
||||
) -> "FusedMoEConfig":
|
||||
|
||||
_quant_config: Optional[FusedMoEQuantConfig] = None
|
||||
|
||||
if quant_config is not None and isinstance(quant_config,
|
||||
QuantizationConfig):
|
||||
if hasattr(quant_config, 'weight_block_size'):
|
||||
block_shape = quant_config.weight_block_size
|
||||
else:
|
||||
block_shape = None
|
||||
per_act_token_quant = False
|
||||
per_out_ch_quant = False
|
||||
quant_dtype: Union[torch.dtype, str, None] = None
|
||||
|
||||
input_quant = get_quant_config_input_quant(quant_config)
|
||||
weight_quant = get_quant_config_weight_quant(quant_config)
|
||||
|
||||
if input_quant is not None:
|
||||
per_act_token_quant = (input_quant.strategy
|
||||
== QuantizationStrategy.TOKEN
|
||||
if input_quant is not None else False)
|
||||
|
||||
if input_quant.num_bits == 8:
|
||||
if input_quant.type == QuantizationType.FLOAT:
|
||||
quant_dtype = torch.float8_e4m3fn
|
||||
elif input_quant.type == QuantizationType.INT:
|
||||
quant_dtype = torch.int8
|
||||
|
||||
from vllm.model_executor.layers.quantization.fp8 import Fp8Config
|
||||
if quant_dtype is None and isinstance(quant_config, Fp8Config):
|
||||
quant_dtype = torch.float8_e4m3fn
|
||||
|
||||
from vllm.model_executor.layers.quantization.mxfp4 import (
|
||||
Mxfp4Config)
|
||||
if (quant_dtype is None and isinstance(quant_config, Mxfp4Config)
|
||||
and envs.VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8):
|
||||
quant_dtype = "mxfp8"
|
||||
|
||||
from vllm.model_executor.layers.quantization.modelopt import (
|
||||
ModelOptNvFp4Config)
|
||||
if quant_dtype is None and isinstance(quant_config,
|
||||
ModelOptNvFp4Config):
|
||||
quant_dtype = "nvfp4"
|
||||
|
||||
if weight_quant is not None:
|
||||
per_out_ch_quant = (
|
||||
weight_quant.strategy == QuantizationStrategy.CHANNEL)
|
||||
|
||||
if quant_dtype is not None:
|
||||
_quant_config = FusedMoEQuantConfig(
|
||||
quant_dtype=quant_dtype,
|
||||
per_act_token_quant=per_act_token_quant,
|
||||
per_out_ch_quant=per_out_ch_quant,
|
||||
block_shape=block_shape,
|
||||
)
|
||||
else:
|
||||
_quant_config = FusedMoEQuantConfig()
|
||||
if moe_parallel_config.dp_size > 1:
|
||||
logger.warning_once("MoE DP setup unable to determine "
|
||||
"quantization scheme or unsupported "
|
||||
"quantization type. This model will "
|
||||
"not run with DP enabled.")
|
||||
else:
|
||||
_quant_config = quant_config
|
||||
|
||||
return FusedMoEConfig(
|
||||
num_experts=num_experts,
|
||||
experts_per_token=experts_per_token,
|
||||
hidden_dim=hidden_dim,
|
||||
num_local_experts=num_local_experts,
|
||||
moe_parallel_config=moe_parallel_config,
|
||||
in_dtype=in_dtype,
|
||||
quant_config=_quant_config,
|
||||
max_num_tokens=max_num_tokens,
|
||||
has_bias=has_bias,
|
||||
)
|
||||
|
||||
@ -211,21 +211,14 @@ class CutlassExpertsFp8Base(mk.FusedMoEPermuteExpertsUnpermute):
|
||||
def __init__(
|
||||
self,
|
||||
out_dtype: Optional[torch.dtype],
|
||||
per_act_token_quant: bool,
|
||||
per_out_ch_quant: bool,
|
||||
ab_strides1: torch.Tensor,
|
||||
ab_strides2: torch.Tensor,
|
||||
c_strides1: torch.Tensor,
|
||||
c_strides2: torch.Tensor,
|
||||
block_shape: Optional[list[int]] = None,
|
||||
quant_config: FusedMoEQuantConfig,
|
||||
):
|
||||
super().__init__(
|
||||
FusedMoEQuantConfig(
|
||||
quant_dtype=torch.float8_e4m3fn,
|
||||
per_act_token_quant=per_act_token_quant,
|
||||
per_out_ch_quant=per_out_ch_quant,
|
||||
block_shape=block_shape,
|
||||
))
|
||||
assert quant_config.use_fp8_w8a8
|
||||
super().__init__(quant_config)
|
||||
self.out_dtype = out_dtype
|
||||
self.ab_strides1 = ab_strides1
|
||||
self.ab_strides2 = ab_strides2
|
||||
@ -247,19 +240,14 @@ class CutlassExpertsFp8Base(mk.FusedMoEPermuteExpertsUnpermute):
|
||||
activation: str,
|
||||
global_num_experts: int,
|
||||
expert_map: Optional[torch.Tensor],
|
||||
w1_scale: Optional[torch.Tensor],
|
||||
w2_scale: Optional[torch.Tensor],
|
||||
w1_zp: Optional[torch.Tensor],
|
||||
w2_zp: Optional[torch.Tensor],
|
||||
a1q_scale: Optional[torch.Tensor],
|
||||
a2_scale: Optional[torch.Tensor],
|
||||
workspace13: torch.Tensor,
|
||||
workspace2: torch.Tensor,
|
||||
expert_tokens_meta: Optional[mk.ExpertTokensMetadata],
|
||||
apply_router_weight_on_input: bool,
|
||||
):
|
||||
assert w1_zp is None, "w1_zp is not supported in CUTLASS MoE"
|
||||
assert w2_zp is None, "w2_zp is not supported in CUTLASS MoE"
|
||||
assert self.w1_zp is None, "w1_zp is not supported in CUTLASS MoE"
|
||||
assert self.w2_zp is None, "w2_zp is not supported in CUTLASS MoE"
|
||||
|
||||
expert_num_tokens = None
|
||||
if expert_tokens_meta is not None:
|
||||
@ -273,9 +261,10 @@ class CutlassExpertsFp8Base(mk.FusedMoEPermuteExpertsUnpermute):
|
||||
in_dtype = hidden_states.dtype
|
||||
run_cutlass_moe_fp8(
|
||||
output, hidden_states, w1, w2, topk_ids, activation_callable,
|
||||
global_num_experts, expert_map, w1_scale, w2_scale, a1q_scale,
|
||||
a2_scale, self.ab_strides1, self.ab_strides2, self.c_strides1,
|
||||
self.c_strides2, workspace13, workspace2, expert_num_tokens,
|
||||
global_num_experts, expert_map, self.w1_scale, self.w2_scale,
|
||||
a1q_scale, self.a2_scale, self.ab_strides1, self.ab_strides2,
|
||||
self.c_strides1, self.c_strides2, workspace13, workspace2,
|
||||
expert_num_tokens,
|
||||
self.out_dtype if self.out_dtype is not None else in_dtype,
|
||||
self.per_act_token_quant, self.per_out_ch_quant,
|
||||
use_batched_format, topk_weights)
|
||||
@ -286,23 +275,19 @@ class CutlassExpertsFp8(CutlassExpertsFp8Base):
|
||||
def __init__(
|
||||
self,
|
||||
out_dtype: Optional[torch.dtype],
|
||||
per_act_token_quant: bool,
|
||||
per_out_ch_quant: bool,
|
||||
ab_strides1: torch.Tensor,
|
||||
ab_strides2: torch.Tensor,
|
||||
c_strides1: torch.Tensor,
|
||||
c_strides2: torch.Tensor,
|
||||
block_shape: Optional[list[int]] = None,
|
||||
quant_config: FusedMoEQuantConfig,
|
||||
):
|
||||
super().__init__(
|
||||
out_dtype,
|
||||
per_act_token_quant,
|
||||
per_out_ch_quant,
|
||||
ab_strides1,
|
||||
ab_strides2,
|
||||
c_strides1,
|
||||
c_strides2,
|
||||
block_shape,
|
||||
quant_config,
|
||||
)
|
||||
|
||||
@property
|
||||
@ -348,23 +333,19 @@ class CutlassBatchedExpertsFp8(CutlassExpertsFp8Base):
|
||||
max_experts_per_worker: int,
|
||||
num_dispatchers: int,
|
||||
out_dtype: Optional[torch.dtype],
|
||||
per_act_token_quant: bool,
|
||||
per_out_ch_quant: bool,
|
||||
ab_strides1: torch.Tensor,
|
||||
ab_strides2: torch.Tensor,
|
||||
c_strides1: torch.Tensor,
|
||||
c_strides2: torch.Tensor,
|
||||
block_shape: Optional[list[int]] = None,
|
||||
quant_config: FusedMoEQuantConfig,
|
||||
):
|
||||
super().__init__(
|
||||
out_dtype,
|
||||
per_act_token_quant,
|
||||
per_out_ch_quant,
|
||||
ab_strides1,
|
||||
ab_strides2,
|
||||
c_strides1,
|
||||
c_strides2,
|
||||
block_shape,
|
||||
quant_config,
|
||||
)
|
||||
assert max_experts_per_worker > 0
|
||||
self.max_experts_per_worker = max_experts_per_worker
|
||||
@ -414,16 +395,12 @@ def cutlass_moe_fp8(
|
||||
w2_q: torch.Tensor,
|
||||
topk_weights: torch.Tensor,
|
||||
topk_ids: torch.Tensor,
|
||||
w1_scale: torch.Tensor,
|
||||
w2_scale: torch.Tensor,
|
||||
ab_strides1: torch.Tensor,
|
||||
ab_strides2: torch.Tensor,
|
||||
c_strides1: torch.Tensor,
|
||||
c_strides2: torch.Tensor,
|
||||
per_act_token: Optional[bool] = None,
|
||||
quant_config: FusedMoEQuantConfig,
|
||||
activation: str = "silu",
|
||||
a1_scale: Optional[torch.Tensor] = None,
|
||||
a2_scale: Optional[torch.Tensor] = None,
|
||||
expert_map: Optional[torch.Tensor] = None,
|
||||
apply_router_weight_on_input: bool = False,
|
||||
global_num_experts: int = -1,
|
||||
@ -475,10 +452,18 @@ def cutlass_moe_fp8(
|
||||
Returns:
|
||||
- torch.Tensor: The fp16 output tensor after applying the MoE layer.
|
||||
"""
|
||||
if per_act_token is None:
|
||||
per_act_token = a1_scale.numel() != 1 if a1_scale is not None else (
|
||||
a2_scale.numel() != 1 if a2_scale is not None else False)
|
||||
per_out_ch = w1_scale.numel() != w1_q.size(0)
|
||||
assert quant_config is not None
|
||||
|
||||
if quant_config.a1_scale is not None:
|
||||
assert (quant_config.per_act_token_quant ==
|
||||
quant_config.a1_scale.numel() != 1)
|
||||
if quant_config.a2_scale is not None:
|
||||
assert (quant_config.per_act_token_quant ==
|
||||
quant_config.a2_scale.numel() != 1)
|
||||
|
||||
assert (quant_config.w1_scale is None
|
||||
or (quant_config.per_out_ch_quant == (quant_config.w1_scale.size(1)
|
||||
== w1_q.size(1))))
|
||||
|
||||
num_experts = global_num_experts if global_num_experts != -1 else w1_q.size(
|
||||
0)
|
||||
@ -487,12 +472,11 @@ def cutlass_moe_fp8(
|
||||
MoEPrepareAndFinalizeNoEP(),
|
||||
CutlassExpertsFp8(
|
||||
out_dtype=a.dtype,
|
||||
per_act_token_quant=per_act_token,
|
||||
per_out_ch_quant=per_out_ch,
|
||||
ab_strides1=ab_strides1,
|
||||
ab_strides2=ab_strides2,
|
||||
c_strides1=c_strides1,
|
||||
c_strides2=c_strides2,
|
||||
quant_config=quant_config,
|
||||
),
|
||||
)
|
||||
|
||||
@ -502,14 +486,9 @@ def cutlass_moe_fp8(
|
||||
w2_q,
|
||||
topk_weights,
|
||||
topk_ids,
|
||||
False,
|
||||
activation,
|
||||
num_experts,
|
||||
expert_map,
|
||||
w1_scale,
|
||||
w2_scale,
|
||||
a1_scale=a1_scale,
|
||||
a2_scale=a2_scale,
|
||||
activation=activation,
|
||||
global_num_experts=num_experts,
|
||||
expert_map=expert_map,
|
||||
apply_router_weight_on_input=apply_router_weight_on_input,
|
||||
)
|
||||
|
||||
@ -542,7 +521,7 @@ def run_cutlass_moe_fp4(
|
||||
) -> None:
|
||||
"""
|
||||
MoE implementation for FP4 Inputs
|
||||
|
||||
|
||||
# Gemm 1
|
||||
a: Input tensor: [m, k] (half/bfloat16)
|
||||
a1_gscale: Activation scale per expert: [e] (float32)
|
||||
@ -552,16 +531,16 @@ def run_cutlass_moe_fp4(
|
||||
full precision)
|
||||
w1_blockscale: [e, 2 * n, k // block_size] (float8_e4m3)
|
||||
(Block size = 16 for NVFP4)
|
||||
|
||||
|
||||
# Gemm 2
|
||||
a2_gscale: Activation scale per expert: [e]
|
||||
w2(down projection) (not an argument to cutlass_moe_fp4): [e, k, n]
|
||||
w2_fp4: [e, k, n // 2], dtype: torch.uint8 (stacked E2M1)
|
||||
w2_blockscale: [e, k, n // block_size], dtype: float8_e4m3
|
||||
|
||||
|
||||
topk_weights: [m, topk] dtype: float8
|
||||
topk_ids: [m, topk] dtype: float8
|
||||
|
||||
|
||||
m, n, k: Unquantized weight shapes, dtype: int
|
||||
e: number of experts, dtype: int
|
||||
|
||||
@ -652,42 +631,21 @@ def run_cutlass_moe_fp4(
|
||||
return
|
||||
|
||||
|
||||
# Split into batched and non-batched
|
||||
class CutlassExpertsFp4(mk.FusedMoEPermuteExpertsUnpermute):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
g1_alphas: torch.Tensor,
|
||||
g2_alphas: torch.Tensor,
|
||||
a1_gscale: torch.Tensor,
|
||||
a2_gscale: torch.Tensor,
|
||||
max_experts_per_worker: int,
|
||||
out_dtype: torch.dtype,
|
||||
per_act_token_quant: bool,
|
||||
per_out_ch_quant: bool,
|
||||
block_shape: Optional[list[int]] = None,
|
||||
quant_config: FusedMoEQuantConfig,
|
||||
use_batched_format: bool = False,
|
||||
):
|
||||
super().__init__(
|
||||
# NVFP4 requires two levels of quantization, which involves
|
||||
# computing some scaling factors dynamically. This makes it
|
||||
# incompatible with the typical prepare -> MoE -> finalize
|
||||
# pipeline. Move the quantization logic into the MoE body.
|
||||
FusedMoEQuantConfig(
|
||||
quant_dtype=None, # skip quantization in prepare/finalize
|
||||
per_act_token_quant=per_act_token_quant,
|
||||
per_out_ch_quant=per_out_ch_quant,
|
||||
block_shape=block_shape,
|
||||
))
|
||||
super().__init__(quant_config)
|
||||
self.max_experts_per_worker = max_experts_per_worker
|
||||
self.out_dtype = out_dtype
|
||||
self.use_batched_format = use_batched_format
|
||||
|
||||
# TODO(bnell): put this stuff into quant config?
|
||||
self.g1_alphas = g1_alphas
|
||||
self.g2_alphas = g2_alphas
|
||||
self.a1_gscale = a1_gscale
|
||||
self.a2_gscale = a2_gscale
|
||||
|
||||
@property
|
||||
def activation_formats(
|
||||
self
|
||||
@ -746,12 +704,7 @@ class CutlassExpertsFp4(mk.FusedMoEPermuteExpertsUnpermute):
|
||||
activation: str,
|
||||
global_num_experts: int,
|
||||
expert_map: Optional[torch.Tensor],
|
||||
w1_scale: torch.Tensor,
|
||||
w2_scale: torch.Tensor,
|
||||
w1_zp: Optional[torch.Tensor],
|
||||
w2_zp: Optional[torch.Tensor],
|
||||
a1q_scale: Optional[torch.Tensor],
|
||||
a2_scale: torch.Tensor,
|
||||
a1q_scale: Optional[torch.Tensor], # unused
|
||||
workspace13: Optional[torch.Tensor],
|
||||
workspace2: Optional[torch.Tensor],
|
||||
expert_tokens_meta: Optional[mk.ExpertTokensMetadata],
|
||||
@ -765,11 +718,11 @@ class CutlassExpertsFp4(mk.FusedMoEPermuteExpertsUnpermute):
|
||||
a=hidden_states,
|
||||
a1_gscale=self.a1_gscale,
|
||||
w1_fp4=w1,
|
||||
w1_blockscale=w1_scale,
|
||||
w1_blockscale=self.w1_scale,
|
||||
w1_alphas=self.g1_alphas,
|
||||
a2_gscale=self.a2_gscale,
|
||||
w2_fp4=w2,
|
||||
w2_blockscale=w2_scale,
|
||||
w2_blockscale=self.w2_scale,
|
||||
w2_alphas=self.g2_alphas,
|
||||
topk_weights=topk_weights,
|
||||
topk_ids=topk_ids,
|
||||
@ -788,14 +741,9 @@ def cutlass_moe_fp4(
|
||||
a: torch.Tensor,
|
||||
w1_fp4: torch.Tensor,
|
||||
w2_fp4: torch.Tensor,
|
||||
w1_blockscale: torch.Tensor,
|
||||
w2_blockscale: torch.Tensor,
|
||||
g1_alphas: torch.Tensor,
|
||||
g2_alphas: torch.Tensor,
|
||||
a1_gscale: torch.Tensor,
|
||||
a2_gscale: torch.Tensor,
|
||||
topk_weights: torch.Tensor,
|
||||
topk_ids: torch.Tensor,
|
||||
quant_config: FusedMoEQuantConfig,
|
||||
m: int,
|
||||
n: int,
|
||||
k: int,
|
||||
@ -805,17 +753,31 @@ def cutlass_moe_fp4(
|
||||
assert expert_map is None, ("Expert Parallelism / expert_map "
|
||||
"is currently not supported for "
|
||||
"ModelOptNvFp4FusedMoE's cutlass_moe_fp4.")
|
||||
|
||||
# TODO(bnell): this feels a bit hacky
|
||||
# NVFP4 requires two levels of quantization, which involves
|
||||
# computing some scaling factors dynamically. This makes it
|
||||
# incompatible with the typical prepare -> MoE -> finalize
|
||||
# pipeline. Move the quantization logic into the MoE body.
|
||||
quant_config = FusedMoEQuantConfig.make(
|
||||
quant_dtype=None, # skip quantization in prepare/finalize
|
||||
per_act_token_quant=quant_config.per_act_token_quant,
|
||||
per_out_ch_quant=quant_config.per_out_ch_quant,
|
||||
block_shape=quant_config.block_shape,
|
||||
g1_alphas=quant_config.g1_alphas,
|
||||
g2_alphas=quant_config.g2_alphas,
|
||||
a1_gscale=quant_config.a1_gscale,
|
||||
a2_gscale=quant_config.a2_gscale,
|
||||
w1_scale=quant_config.w1_scale,
|
||||
w2_scale=quant_config.w2_scale,
|
||||
)
|
||||
|
||||
fn = mk.FusedMoEModularKernel(
|
||||
MoEPrepareAndFinalizeNoEP(),
|
||||
CutlassExpertsFp4(
|
||||
g1_alphas,
|
||||
g2_alphas,
|
||||
a1_gscale,
|
||||
a2_gscale,
|
||||
max_experts_per_worker=e,
|
||||
out_dtype=a.dtype,
|
||||
per_act_token_quant=False,
|
||||
per_out_ch_quant=False,
|
||||
quant_config=quant_config,
|
||||
use_batched_format=False,
|
||||
),
|
||||
)
|
||||
@ -830,10 +792,6 @@ def cutlass_moe_fp4(
|
||||
activation="silu",
|
||||
global_num_experts=e,
|
||||
expert_map=None,
|
||||
w1_scale=w1_blockscale,
|
||||
w2_scale=w2_blockscale,
|
||||
a1_scale=None,
|
||||
a2_scale=None,
|
||||
apply_router_weight_on_input=apply_router_weight_on_input,
|
||||
)
|
||||
|
||||
@ -891,6 +849,7 @@ def _valid_cutlass_block_scaled_grouped_gemm(
|
||||
return True
|
||||
|
||||
|
||||
# TODO(bnell): would be nice combine/integrate with regular cutlass_fp8.
|
||||
def run_cutlass_block_scaled_fused_experts(
|
||||
a: torch.Tensor,
|
||||
w1: torch.Tensor,
|
||||
|
||||
@ -1,6 +1,5 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import functools
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
@ -9,9 +8,11 @@ from tqdm import tqdm
|
||||
import vllm.envs as env
|
||||
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig
|
||||
from vllm.model_executor.layers.fused_moe.config import (
|
||||
FusedMoEQuantConfig, fp8_w8a8_moe_quant_config)
|
||||
from vllm.model_executor.layers.fused_moe.deep_gemm_utils import (
|
||||
compute_aligned_M, deepgemm_moe_permute, deepgemm_unpermute_and_reduce)
|
||||
compute_aligned_M, deep_gemm_block_shape, deepgemm_moe_permute,
|
||||
deepgemm_unpermute_and_reduce)
|
||||
from vllm.model_executor.layers.fused_moe.prepare_finalize import (
|
||||
MoEPrepareAndFinalizeNoEP)
|
||||
from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import (
|
||||
@ -25,14 +26,6 @@ from vllm.utils.deep_gemm import m_grouped_fp8_gemm_nt_contiguous
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
@functools.cache
|
||||
def deep_gemm_block_shape() -> list[int]:
|
||||
# Lazy import to avoid CUDA initialization problems.
|
||||
import deep_gemm as dg
|
||||
block = dg.get_m_alignment_for_contiguous_layout()
|
||||
return [block, block]
|
||||
|
||||
|
||||
def _valid_deep_gemm_shape(M: int, N: int, K: int) -> bool:
|
||||
align = deep_gemm_block_shape()[0]
|
||||
return align <= M and N % align == 0 and K % align == 0
|
||||
@ -163,13 +156,12 @@ def warmup_deepgemm_gg_contiguous_kernels(w1: torch.Tensor, w2: torch.Tensor,
|
||||
|
||||
class DeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
FusedMoEQuantConfig(
|
||||
quant_dtype=torch.float8_e4m3fn,
|
||||
per_act_token_quant=False,
|
||||
block_shape=deep_gemm_block_shape(),
|
||||
))
|
||||
def __init__(self, quant_config: FusedMoEQuantConfig):
|
||||
super().__init__(quant_config)
|
||||
assert quant_config.block_shape == deep_gemm_block_shape()
|
||||
assert quant_config.quant_dtype == torch.float8_e4m3fn
|
||||
assert not quant_config.per_act_token_quant
|
||||
assert not quant_config.per_out_ch_quant
|
||||
|
||||
@property
|
||||
def activation_formats(
|
||||
@ -221,21 +213,17 @@ class DeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
|
||||
activation: str,
|
||||
global_num_experts: int,
|
||||
expert_map: Optional[torch.Tensor],
|
||||
w1_scale: Optional[torch.Tensor],
|
||||
w2_scale: Optional[torch.Tensor],
|
||||
w1_zp: Optional[torch.Tensor],
|
||||
w2_zp: Optional[torch.Tensor],
|
||||
a1q_scale: Optional[torch.Tensor],
|
||||
a2_scale: Optional[torch.Tensor],
|
||||
workspace13: torch.Tensor,
|
||||
workspace2: torch.Tensor,
|
||||
expert_tokens_meta: Optional[mk.ExpertTokensMetadata],
|
||||
apply_router_weight_on_input: bool,
|
||||
):
|
||||
assert self.block_shape is not None
|
||||
assert a1q_scale is not None
|
||||
assert w1_scale is not None
|
||||
assert w2_scale is not None
|
||||
assert self.a2_scale is None
|
||||
assert self.block_shape is not None
|
||||
assert self.w1_scale is not None
|
||||
assert self.w2_scale is not None
|
||||
|
||||
a1q = hidden_states
|
||||
_, N, K = w1.size()
|
||||
@ -270,7 +258,7 @@ class DeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
|
||||
aq_out=a1q_perm)
|
||||
assert a1q.size(0) == M_sum
|
||||
|
||||
m_grouped_fp8_gemm_nt_contiguous((a1q, a1q_scale), (w1, w1_scale),
|
||||
m_grouped_fp8_gemm_nt_contiguous((a1q, a1q_scale), (w1, self.w1_scale),
|
||||
mm1_out, expert_ids)
|
||||
|
||||
self.activation(activation, act_out, mm1_out.view(-1, N))
|
||||
@ -281,7 +269,7 @@ class DeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
|
||||
column_major_scales=True,
|
||||
out_q=quant_out)
|
||||
|
||||
m_grouped_fp8_gemm_nt_contiguous((a2q, a2q_scale), (w2, w2_scale),
|
||||
m_grouped_fp8_gemm_nt_contiguous((a2q, a2q_scale), (w2, self.w2_scale),
|
||||
mm2_out, expert_ids)
|
||||
|
||||
if apply_router_weight_on_input:
|
||||
@ -348,9 +336,16 @@ def deep_gemm_moe_fp8(
|
||||
Returns:
|
||||
- torch.Tensor: The bfloat16 output tensor after applying the MoE layer.
|
||||
"""
|
||||
quant_config = fp8_w8a8_moe_quant_config(
|
||||
w1_scale=w1_scale,
|
||||
w2_scale=w2_scale,
|
||||
a1_scale=a1_scale,
|
||||
a2_scale=a2_scale,
|
||||
block_shape=deep_gemm_block_shape())
|
||||
|
||||
fn = mk.FusedMoEModularKernel(
|
||||
MoEPrepareAndFinalizeNoEP(),
|
||||
DeepGemmExperts(),
|
||||
DeepGemmExperts(quant_config),
|
||||
)
|
||||
return fn(
|
||||
hidden_states,
|
||||
@ -358,13 +353,9 @@ def deep_gemm_moe_fp8(
|
||||
w2,
|
||||
topk_weights,
|
||||
topk_ids,
|
||||
inplace,
|
||||
activation,
|
||||
global_num_experts,
|
||||
expert_map,
|
||||
w1_scale=w1_scale,
|
||||
w2_scale=w2_scale,
|
||||
a1_scale=a1_scale,
|
||||
a2_scale=a2_scale,
|
||||
inplace=inplace,
|
||||
activation=activation,
|
||||
global_num_experts=global_num_experts,
|
||||
expert_map=expert_map,
|
||||
apply_router_weight_on_input=apply_router_weight_on_input,
|
||||
)
|
||||
|
||||
@ -183,8 +183,6 @@ class DeepEPHTPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
|
||||
def prepare_async(
|
||||
self,
|
||||
a1: torch.Tensor,
|
||||
a1_scale: Optional[torch.Tensor],
|
||||
a2_scale: Optional[torch.Tensor],
|
||||
topk_weights: torch.Tensor,
|
||||
topk_ids: torch.Tensor,
|
||||
num_experts: int,
|
||||
@ -204,7 +202,7 @@ class DeepEPHTPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
|
||||
# Quant and Dispatch
|
||||
a1q, a1q_scale = moe_kernel_quantize_input(
|
||||
a1,
|
||||
a1_scale,
|
||||
quant_config.a1_scale,
|
||||
quant_dtype=quant_config.quant_dtype,
|
||||
per_act_token_quant=quant_config.per_act_token_quant,
|
||||
block_shape=quant_config.block_shape,
|
||||
@ -215,7 +213,7 @@ class DeepEPHTPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
|
||||
else:
|
||||
a1q = a1
|
||||
a1q_scale = None
|
||||
a1_post_scale = a1_scale
|
||||
a1_post_scale = quant_config.a1_scale
|
||||
|
||||
return (lambda *args: None,
|
||||
self._do_dispatch(tokens=a1q,
|
||||
@ -229,8 +227,6 @@ class DeepEPHTPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
|
||||
def prepare(
|
||||
self,
|
||||
a1: torch.Tensor,
|
||||
a1_scale: Optional[torch.Tensor],
|
||||
a2_scale: Optional[torch.Tensor],
|
||||
topk_weights: torch.Tensor,
|
||||
topk_ids: torch.Tensor,
|
||||
num_experts: int,
|
||||
@ -238,9 +234,8 @@ class DeepEPHTPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
|
||||
apply_router_weight_on_input: bool,
|
||||
quant_config: FusedMoEQuantConfig,
|
||||
) -> mk.PrepareResultType:
|
||||
(_, receiver) = self.prepare_async(a1, a1_scale, a2_scale,
|
||||
topk_weights, topk_ids, num_experts,
|
||||
expert_map,
|
||||
(_, receiver) = self.prepare_async(a1, topk_weights, topk_ids,
|
||||
num_experts, expert_map,
|
||||
apply_router_weight_on_input,
|
||||
quant_config)
|
||||
return receiver()
|
||||
|
||||
@ -77,15 +77,13 @@ class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
|
||||
def _do_quant(
|
||||
self,
|
||||
x: Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]],
|
||||
a1_scale: Optional[torch.Tensor],
|
||||
a1_dtype: torch.dtype,
|
||||
quant_dtype: Union[torch.dtype, str, None],
|
||||
per_act_token_quant: bool,
|
||||
block_shape: Optional[list[int]],
|
||||
quant_config: FusedMoEQuantConfig,
|
||||
) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
|
||||
|
||||
block_k = block_shape[1] if block_shape is not None else None
|
||||
if self.use_fp8_dispatch:
|
||||
block_k = quant_config.block_shape[
|
||||
1] if quant_config.block_shape is not None else None
|
||||
if block_k == DEEPEP_QUANT_BLOCK_SIZE:
|
||||
# DeepEP kernels did the quantization for us.
|
||||
x, x_scales = x
|
||||
@ -101,12 +99,12 @@ class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
|
||||
|
||||
# TODO (varun): Optimization - Use a batched version of quant
|
||||
x = x.view((-1, hidden_dim))
|
||||
x, x_scales = moe_kernel_quantize_input(x, a1_scale, quant_dtype,
|
||||
per_act_token_quant,
|
||||
block_shape)
|
||||
x, x_scales = moe_kernel_quantize_input(
|
||||
x, quant_config.a1_scale, quant_config.quant_dtype,
|
||||
quant_config.per_act_token_quant, quant_config.block_shape)
|
||||
x = x.view((num_experts, -1, hidden_dim))
|
||||
|
||||
if quant_dtype is not None:
|
||||
if quant_config.quant_dtype is not None:
|
||||
assert x_scales is not None
|
||||
x_scales = normalize_batched_scales_shape(x_scales, num_experts)
|
||||
|
||||
@ -118,8 +116,6 @@ class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
|
||||
def prepare_async(
|
||||
self,
|
||||
a1: torch.Tensor,
|
||||
a1_scale: Optional[torch.Tensor],
|
||||
a2_scale: Optional[torch.Tensor],
|
||||
topk_weights: torch.Tensor,
|
||||
topk_ids: torch.Tensor,
|
||||
num_experts: int,
|
||||
@ -139,9 +135,10 @@ class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
|
||||
assert hidden_size % 128 == 0, \
|
||||
"DeepEP kernels quantize the inputs in blocks of shape 128"
|
||||
|
||||
has_per_token_scales = a1_scale.numel(
|
||||
) != 1 if a1_scale is not None else (
|
||||
a2_scale.numel() != 1 if a2_scale is not None else False)
|
||||
has_per_token_scales = quant_config.a1_scale.numel(
|
||||
) != 1 if quant_config.a1_scale is not None else (
|
||||
quant_config.a2_scale.numel() != 1
|
||||
if quant_config.a2_scale is not None else False)
|
||||
assert not has_per_token_scales, (
|
||||
"low_latency kernels doesn't support dispatching per-token scales")
|
||||
|
||||
@ -163,20 +160,21 @@ class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
|
||||
return_recv_hook=True)
|
||||
self.handles[a2a_idx] = handle
|
||||
|
||||
return (hook, lambda: self._receiver(expert_x, expert_num_tokens,
|
||||
a1_scale, a1.dtype, quant_config))
|
||||
return (
|
||||
hook,
|
||||
lambda: self._receiver(expert_x, expert_num_tokens, quant_config.
|
||||
a1_scale, a1.dtype, quant_config))
|
||||
|
||||
def _receiver(
|
||||
self,
|
||||
expert_x: Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]],
|
||||
expert_num_tokens: torch.Tensor,
|
||||
a1_scale,
|
||||
a1_dtype,
|
||||
a1_scale: Optional[torch.Tensor],
|
||||
a1_dtype: torch.dtype,
|
||||
quant_config: FusedMoEQuantConfig,
|
||||
) -> mk.PrepareResultType:
|
||||
expert_x, expert_x_scale = self._do_quant(
|
||||
expert_x, a1_scale, a1_dtype, quant_config.quant_dtype,
|
||||
quant_config.per_act_token_quant, quant_config.block_shape)
|
||||
expert_x, expert_x_scale = self._do_quant(expert_x, a1_dtype,
|
||||
quant_config)
|
||||
|
||||
expert_tokens_meta = mk.ExpertTokensMetadata(
|
||||
expert_num_tokens=expert_num_tokens, expert_num_tokens_cpu=None)
|
||||
@ -186,8 +184,6 @@ class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
|
||||
def prepare(
|
||||
self,
|
||||
a1: torch.Tensor,
|
||||
a1_scale: Optional[torch.Tensor],
|
||||
a2_scale: Optional[torch.Tensor],
|
||||
topk_weights: torch.Tensor,
|
||||
topk_ids: torch.Tensor,
|
||||
num_experts: int,
|
||||
@ -195,8 +191,7 @@ class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
|
||||
apply_router_weight_on_input: bool,
|
||||
quant_config: FusedMoEQuantConfig,
|
||||
) -> mk.PrepareResultType:
|
||||
hook, receiver = self.prepare_async(a1, a1_scale, a2_scale,
|
||||
topk_weights, topk_ids,
|
||||
hook, receiver = self.prepare_async(a1, topk_weights, topk_ids,
|
||||
num_experts, expert_map,
|
||||
apply_router_weight_on_input,
|
||||
quant_config)
|
||||
|
||||
@ -1,6 +1,6 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from typing import Optional, Union
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
|
||||
@ -44,33 +44,20 @@ class FlashInferExperts(mk.FusedMoEPermuteExpertsUnpermute):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
g1_alphas: torch.Tensor,
|
||||
g2_alphas: torch.Tensor,
|
||||
a1_gscale: torch.Tensor,
|
||||
a2_gscale: torch.Tensor,
|
||||
out_dtype: torch.dtype,
|
||||
quant_dtype: Union[torch.dtype, str, None],
|
||||
quant_config: FusedMoEQuantConfig,
|
||||
ep_rank: int = 0,
|
||||
ep_size: int = 1,
|
||||
tp_rank: int = 0,
|
||||
tp_size: int = 1,
|
||||
):
|
||||
super().__init__(
|
||||
FusedMoEQuantConfig(
|
||||
quant_dtype=quant_dtype,
|
||||
per_act_token_quant=False,
|
||||
block_shape=None,
|
||||
))
|
||||
assert quant_dtype in ("nvfp4", torch.float8_e4m3fn), (
|
||||
super().__init__(quant_config)
|
||||
assert quant_config.quant_dtype in ("nvfp4", torch.float8_e4m3fn), (
|
||||
"Only nvfp4,fp8 quantization are currently supported.")
|
||||
self.ep_rank = ep_rank
|
||||
self.ep_size = ep_size
|
||||
self.tp_rank = tp_rank
|
||||
self.tp_size = tp_size
|
||||
self.g1_alphas = g1_alphas
|
||||
self.g2_alphas = g2_alphas
|
||||
self.a1_gscale = a1_gscale
|
||||
self.a2_gscale = a2_gscale
|
||||
self.out_dtype = out_dtype
|
||||
|
||||
@property
|
||||
@ -141,12 +128,7 @@ class FlashInferExperts(mk.FusedMoEPermuteExpertsUnpermute):
|
||||
activation: str,
|
||||
global_num_experts: int,
|
||||
expert_map: Optional[torch.Tensor],
|
||||
w1_scale: Optional[torch.Tensor],
|
||||
w2_scale: Optional[torch.Tensor],
|
||||
w1_zp: Optional[torch.Tensor],
|
||||
w2_zp: Optional[torch.Tensor],
|
||||
a1q_scale: Optional[torch.Tensor],
|
||||
a2_scale: Optional[torch.Tensor], # Not used
|
||||
workspace13: Optional[torch.Tensor],
|
||||
workspace2: Optional[torch.Tensor],
|
||||
expert_tokens_meta: Optional[mk.ExpertTokensMetadata],
|
||||
@ -162,17 +144,17 @@ class FlashInferExperts(mk.FusedMoEPermuteExpertsUnpermute):
|
||||
fc2_expert_weights = w2
|
||||
else:
|
||||
# Ensure w1_scale and w2_scale are not None before calling view
|
||||
assert w1_scale is not None and w2_scale is not None, (
|
||||
assert self.w1_scale is not None and self.w2_scale is not None, (
|
||||
"w1_scale and w2_scale must not "
|
||||
"be None for FlashInferExperts")
|
||||
# Flashinfer CUTLASS kernel takes scalar global scales,
|
||||
# min because inv_scale.
|
||||
quant_scales = [
|
||||
self.a1_gscale,
|
||||
w1_scale.view(torch.int32),
|
||||
self.w1_scale.view(torch.int32),
|
||||
self.g1_alphas,
|
||||
self.a2_gscale,
|
||||
w2_scale.view(torch.int32),
|
||||
self.w2_scale.view(torch.int32),
|
||||
self.g2_alphas,
|
||||
]
|
||||
# FlashInfer API requires weight to be long for nvfp4
|
||||
@ -202,12 +184,7 @@ def flashinfer_cutlass_moe_fp4(
|
||||
w2: torch.Tensor,
|
||||
topk_weights: torch.Tensor,
|
||||
topk_ids: torch.Tensor,
|
||||
w1_scale: torch.Tensor,
|
||||
w2_scale: torch.Tensor,
|
||||
g1_alphas: torch.Tensor,
|
||||
g2_alphas: torch.Tensor,
|
||||
a1_gscale: torch.Tensor,
|
||||
a2_gscale: torch.Tensor,
|
||||
quant_config: FusedMoEQuantConfig,
|
||||
inplace: bool = False,
|
||||
activation: str = "silu",
|
||||
global_num_experts: int = -1,
|
||||
@ -216,15 +193,10 @@ def flashinfer_cutlass_moe_fp4(
|
||||
) -> torch.Tensor:
|
||||
|
||||
fused_experts = mk.FusedMoEModularKernel(
|
||||
FlashInferCutlassMoEPrepareAndFinalize(use_dp=False,
|
||||
a1_gscale=a1_gscale),
|
||||
FlashInferCutlassMoEPrepareAndFinalize(use_dp=False),
|
||||
FlashInferExperts(
|
||||
g1_alphas=g1_alphas,
|
||||
g2_alphas=g2_alphas,
|
||||
a1_gscale=a1_gscale,
|
||||
a2_gscale=a2_gscale,
|
||||
out_dtype=hidden_states.dtype,
|
||||
quant_dtype="nvfp4",
|
||||
quant_config=quant_config,
|
||||
))
|
||||
|
||||
return fused_experts(
|
||||
@ -237,7 +209,5 @@ def flashinfer_cutlass_moe_fp4(
|
||||
activation=activation,
|
||||
global_num_experts=global_num_experts,
|
||||
expert_map=expert_map,
|
||||
w1_scale=w1_scale,
|
||||
w2_scale=w2_scale,
|
||||
apply_router_weight_on_input=apply_router_weight_on_input,
|
||||
)
|
||||
|
||||
@ -22,13 +22,11 @@ class FlashInferCutlassMoEPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
|
||||
def __init__(
|
||||
self,
|
||||
use_dp: bool,
|
||||
a1_gscale: Optional[torch.Tensor],
|
||||
num_dispatchers: int = 1,
|
||||
):
|
||||
super().__init__()
|
||||
self.num_dispatchers_ = num_dispatchers
|
||||
self.use_dp = use_dp
|
||||
self.a1_gscale = a1_gscale
|
||||
self.local_tokens = None
|
||||
|
||||
@property
|
||||
@ -47,14 +45,11 @@ class FlashInferCutlassMoEPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
|
||||
def prepare(
|
||||
self,
|
||||
a1: torch.Tensor,
|
||||
a1_scale: Optional[torch.Tensor], # Not used
|
||||
a2_scale: Optional[torch.Tensor], # Not used
|
||||
topk_weights: torch.Tensor,
|
||||
topk_ids: torch.Tensor,
|
||||
num_experts: int,
|
||||
expert_map: Optional[torch.Tensor],
|
||||
apply_router_weight_on_input: bool,
|
||||
# TODO(bnell): use quant_config + scales instead of ctor args
|
||||
quant_config: FusedMoEQuantConfig,
|
||||
) -> mk.PrepareResultType:
|
||||
|
||||
@ -67,7 +62,7 @@ class FlashInferCutlassMoEPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
|
||||
|
||||
a1q, a1q_scale = moe_kernel_quantize_input(
|
||||
a1,
|
||||
self.a1_gscale,
|
||||
quant_config.a1_gscale,
|
||||
quant_config.quant_dtype,
|
||||
quant_config.per_act_token_quant,
|
||||
quant_config.block_shape,
|
||||
|
||||
185
vllm/model_executor/layers/fused_moe/flashinfer_trtllm_moe.py
Normal file
185
vllm/model_executor/layers/fused_moe/flashinfer_trtllm_moe.py
Normal file
@ -0,0 +1,185 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from typing import List # noqa: UP035
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.model_executor.layers.fused_moe.utils import (
|
||||
moe_kernel_quantize_input)
|
||||
from vllm.model_executor.layers.quantization.utils.flashinfer_utils import (
|
||||
calculate_tile_tokens_dim)
|
||||
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
|
||||
per_token_group_quant_fp8)
|
||||
from vllm.utils import direct_register_custom_op
|
||||
|
||||
|
||||
def flashinfer_fused_moe_blockscale_fp8(
|
||||
routing_logits: torch.Tensor,
|
||||
routing_bias: torch.Tensor,
|
||||
x: torch.Tensor,
|
||||
w13_weight: torch.Tensor,
|
||||
w13_weight_scale_inv: torch.Tensor,
|
||||
w2_weight: torch.Tensor,
|
||||
w2_weight_scale_inv: torch.Tensor,
|
||||
global_num_experts: int,
|
||||
top_k: int,
|
||||
num_expert_group: int,
|
||||
topk_group: int,
|
||||
intermediate_size: int,
|
||||
expert_offset: int,
|
||||
local_num_experts: int,
|
||||
block_shape: List[int], #noqa: UP006
|
||||
routed_scaling: float = 1.0) -> torch.Tensor:
|
||||
from vllm.utils.flashinfer import flashinfer_trtllm_fp8_block_scale_moe
|
||||
assert top_k <= global_num_experts
|
||||
assert top_k <= 8
|
||||
assert topk_group <= 4
|
||||
assert global_num_experts > num_expert_group
|
||||
assert global_num_experts % num_expert_group == 0
|
||||
assert global_num_experts % 4 == 0
|
||||
assert top_k < (topk_group * global_num_experts / num_expert_group)
|
||||
assert block_shape == [128, 128]
|
||||
|
||||
a_q, a_sf = per_token_group_quant_fp8(x, block_shape[1])
|
||||
# NOTE: scales of hidden states have to be transposed!
|
||||
a_sf_t = a_sf.t().contiguous()
|
||||
return flashinfer_trtllm_fp8_block_scale_moe(
|
||||
routing_logits=routing_logits,
|
||||
routing_bias=routing_bias,
|
||||
hidden_states=a_q,
|
||||
hidden_states_scale=a_sf_t,
|
||||
gemm1_weights=w13_weight,
|
||||
gemm1_weights_scale=w13_weight_scale_inv,
|
||||
gemm2_weights=w2_weight,
|
||||
gemm2_weights_scale=w2_weight_scale_inv,
|
||||
num_experts=global_num_experts,
|
||||
top_k=top_k,
|
||||
n_group=num_expert_group,
|
||||
topk_group=topk_group,
|
||||
intermediate_size=intermediate_size,
|
||||
local_expert_offset=expert_offset,
|
||||
local_num_experts=local_num_experts,
|
||||
routed_scaling_factor=routed_scaling,
|
||||
tile_tokens_dim=calculate_tile_tokens_dim(x.shape[0], top_k,
|
||||
global_num_experts),
|
||||
routing_method_type=2, # DeepSeek-styled routing method
|
||||
use_shuffled_weight=False,
|
||||
)
|
||||
|
||||
|
||||
def flashinfer_fused_moe_blockscale_fp8_fake(
|
||||
routing_logits: torch.Tensor,
|
||||
routing_bias: torch.Tensor,
|
||||
x: torch.Tensor,
|
||||
w13_weight: torch.Tensor,
|
||||
w13_weight_scale_inv: torch.Tensor,
|
||||
w2_weight: torch.Tensor,
|
||||
w2_weight_scale_inv: torch.Tensor,
|
||||
global_num_experts: int,
|
||||
top_k: int,
|
||||
num_expert_group: int,
|
||||
topk_group: int,
|
||||
intermediate_size: int,
|
||||
expert_offset: int,
|
||||
local_num_experts: int,
|
||||
block_shape: list[int],
|
||||
routed_scaling: float = 1.0) -> torch.Tensor:
|
||||
return torch.empty_like(x)
|
||||
|
||||
|
||||
# TODO(bnell): Does this really need to be a torch.op?
|
||||
direct_register_custom_op(
|
||||
op_name="flashinfer_fused_moe_blockscale_fp8",
|
||||
op_func=flashinfer_fused_moe_blockscale_fp8,
|
||||
mutates_args=[],
|
||||
fake_impl=flashinfer_fused_moe_blockscale_fp8_fake,
|
||||
tags=(torch.Tag.needs_fixed_stride_order, ),
|
||||
)
|
||||
|
||||
|
||||
def flashinfer_fused_moe_per_tensor_scale_fp8(
|
||||
routing_logits: torch.Tensor,
|
||||
routing_bias: Optional[torch.Tensor],
|
||||
hidden_states: torch.Tensor,
|
||||
input_scale: torch.Tensor,
|
||||
gemm1_weights: torch.Tensor,
|
||||
gemm2_weights: torch.Tensor,
|
||||
output1_scales_scalar: torch.Tensor,
|
||||
output1_scales_gate_scalar: torch.Tensor,
|
||||
output2_scales_scalar: torch.Tensor,
|
||||
num_experts: int,
|
||||
top_k: int,
|
||||
num_expert_group: Optional[int],
|
||||
topk_group: Optional[int],
|
||||
intermediate_size: int,
|
||||
local_expert_offset: int,
|
||||
local_num_experts: int,
|
||||
use_routing_scales_on_input: bool,
|
||||
routing_method_type: int,
|
||||
routed_scaling_factor: float = 1.0) -> torch.Tensor:
|
||||
num_expert_group = num_expert_group if num_expert_group is not None else 0
|
||||
topk_group = topk_group if topk_group is not None else 0
|
||||
|
||||
quant_hidden_states, _ = moe_kernel_quantize_input(
|
||||
hidden_states,
|
||||
input_scale,
|
||||
quant_dtype=torch.float8_e4m3fn,
|
||||
per_act_token_quant=False)
|
||||
|
||||
from vllm.utils.flashinfer import (
|
||||
flashinfer_trtllm_fp8_per_tensor_scale_moe)
|
||||
return flashinfer_trtllm_fp8_per_tensor_scale_moe(
|
||||
routing_logits=routing_logits,
|
||||
routing_bias=routing_bias,
|
||||
hidden_states=quant_hidden_states,
|
||||
gemm1_weights=gemm1_weights,
|
||||
output1_scales_scalar=output1_scales_scalar,
|
||||
output1_scales_gate_scalar=output1_scales_gate_scalar,
|
||||
gemm2_weights=gemm2_weights,
|
||||
output2_scales_scalar=output2_scales_scalar,
|
||||
num_experts=num_experts,
|
||||
top_k=top_k,
|
||||
n_group=num_expert_group,
|
||||
topk_group=topk_group,
|
||||
intermediate_size=intermediate_size,
|
||||
local_expert_offset=local_expert_offset,
|
||||
local_num_experts=local_num_experts,
|
||||
routed_scaling_factor=routed_scaling_factor,
|
||||
use_routing_scales_on_input=use_routing_scales_on_input,
|
||||
tile_tokens_dim=calculate_tile_tokens_dim(hidden_states.shape[0],
|
||||
top_k, num_experts),
|
||||
routing_method_type=routing_method_type)
|
||||
|
||||
|
||||
def flashinfer_fused_moe_per_tensor_scale_fp8_fake(
|
||||
routing_logits: torch.Tensor,
|
||||
routing_bias: Optional[torch.Tensor],
|
||||
hidden_states: torch.Tensor,
|
||||
input_scale: torch.Tensor,
|
||||
gemm1_weights: torch.Tensor,
|
||||
gemm2_weights: torch.Tensor,
|
||||
output1_scales_scalar: torch.Tensor,
|
||||
output1_scales_gate_scalar: torch.Tensor,
|
||||
output2_scales_scalar: torch.Tensor,
|
||||
num_experts: int,
|
||||
top_k: int,
|
||||
num_expert_group: Optional[int],
|
||||
topk_group: Optional[int],
|
||||
intermediate_size: int,
|
||||
local_expert_offset: int,
|
||||
local_num_experts: int,
|
||||
use_routing_scales_on_input: bool,
|
||||
routing_method_type: int,
|
||||
routed_scaling_factor: float = 1.0) -> torch.Tensor:
|
||||
return torch.empty_like(hidden_states)
|
||||
|
||||
|
||||
# TODO(bnell): Does this really need to be a torch.op?
|
||||
direct_register_custom_op(
|
||||
op_name="flashinfer_fused_moe_per_tensor_scale_fp8",
|
||||
op_func=flashinfer_fused_moe_per_tensor_scale_fp8,
|
||||
mutates_args=["hidden_states"],
|
||||
fake_impl=flashinfer_fused_moe_per_tensor_scale_fp8_fake,
|
||||
tags=(torch.Tag.needs_fixed_stride_order, ),
|
||||
)
|
||||
@ -8,7 +8,7 @@ import torch
|
||||
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
|
||||
from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig
|
||||
from vllm.model_executor.layers.fused_moe.fused_moe import (
|
||||
get_config_dtype_str, try_get_optimal_moe_config)
|
||||
try_get_optimal_moe_config)
|
||||
from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import (
|
||||
TopKWeightAndReduceDelegate, TopKWeightAndReduceNaiveBatched)
|
||||
from vllm.model_executor.layers.fused_moe.utils import (
|
||||
@ -498,8 +498,6 @@ class BatchedPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
|
||||
def prepare(
|
||||
self,
|
||||
a1: torch.Tensor,
|
||||
a1_scale: Optional[torch.Tensor],
|
||||
a2_scale: Optional[torch.Tensor],
|
||||
topk_weights: torch.Tensor,
|
||||
topk_ids: torch.Tensor,
|
||||
num_experts: int,
|
||||
@ -545,14 +543,13 @@ class BatchedPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
|
||||
dtype=torch.float32,
|
||||
device=a1.device)
|
||||
else:
|
||||
assert a1_scale is None
|
||||
assert quant_config.a1_scale is None
|
||||
b_a1_scale = None
|
||||
|
||||
first_expert = num_local_experts * self.rank
|
||||
last_expert = first_expert + num_local_experts
|
||||
|
||||
a1_scale = normalize_scales_shape(a1_scale)
|
||||
a2_scale = normalize_scales_shape(a2_scale)
|
||||
a1_scale = normalize_scales_shape(quant_config.a1_scale)
|
||||
|
||||
for expert_id in range(first_expert, last_expert):
|
||||
topks = torch.any(topk_ids == expert_id, dim=1).flatten()
|
||||
@ -623,28 +620,13 @@ class NaiveBatchedExperts(mk.FusedMoEPermuteExpertsUnpermute):
|
||||
self,
|
||||
max_num_tokens: int,
|
||||
num_dispatchers: int,
|
||||
use_fp8_w8a8: bool = False,
|
||||
use_int8_w8a8: bool = False,
|
||||
use_int8_w8a16: bool = False,
|
||||
use_int4_w4a16: bool = False,
|
||||
use_mxfp4_w4a4: bool = False,
|
||||
block_shape: Optional[list[int]] = None,
|
||||
per_act_token_quant: bool = False,
|
||||
quant_config: FusedMoEQuantConfig,
|
||||
):
|
||||
super().__init__(
|
||||
FusedMoEQuantConfig.make(
|
||||
use_fp8_w8a8=use_fp8_w8a8,
|
||||
use_int8_w8a8=use_int8_w8a8,
|
||||
use_int8_w8a16=use_int8_w8a16,
|
||||
use_int4_w4a16=use_int4_w4a16,
|
||||
use_mxfp4_w4a4=use_mxfp4_w4a4,
|
||||
per_act_token_quant=per_act_token_quant,
|
||||
block_shape=block_shape,
|
||||
))
|
||||
assert not use_int8_w8a8, "NYI"
|
||||
assert not use_int8_w8a16, "NYI"
|
||||
assert not use_int4_w4a16, "NYI"
|
||||
assert not use_mxfp4_w4a4, "NYI"
|
||||
super().__init__(quant_config)
|
||||
assert not self.quant_config.use_int8_w8a8, "NYI"
|
||||
assert not self.quant_config.use_int8_w8a16, "NYI"
|
||||
assert not self.quant_config.use_int4_w4a16, "NYI"
|
||||
assert not self.quant_config.use_mxfp4_w4a4, "NYI"
|
||||
self.max_num_tokens = max_num_tokens
|
||||
self.num_dispatchers = num_dispatchers
|
||||
|
||||
@ -705,12 +687,7 @@ class NaiveBatchedExperts(mk.FusedMoEPermuteExpertsUnpermute):
|
||||
activation: str,
|
||||
global_num_experts: int,
|
||||
expert_map: Optional[torch.Tensor],
|
||||
w1_scale: Optional[torch.Tensor],
|
||||
w2_scale: Optional[torch.Tensor],
|
||||
w1_zp: Optional[torch.Tensor],
|
||||
w2_zp: Optional[torch.Tensor],
|
||||
a1q_scale: Optional[torch.Tensor],
|
||||
a2_scale: Optional[torch.Tensor],
|
||||
workspace13: torch.Tensor,
|
||||
workspace2: torch.Tensor,
|
||||
expert_tokens_meta: Optional[mk.ExpertTokensMetadata],
|
||||
@ -740,10 +717,10 @@ class NaiveBatchedExperts(mk.FusedMoEPermuteExpertsUnpermute):
|
||||
tmp = _resize_cache(workspace2, (num, N))
|
||||
|
||||
if self.quant_config.is_quantized:
|
||||
assert a1q_scale is not None and w1_scale is not None
|
||||
assert a1q_scale is not None and self.w1_scale is not None
|
||||
input = self.dequant(hidden_states[expert, :, :],
|
||||
a1q_scale[expert])
|
||||
w1_dq = self.dequant(w1[expert], w1_scale[expert])
|
||||
w1_dq = self.dequant(w1[expert], self.w1_scale[expert])
|
||||
input = input[:num] @ w1_dq.transpose(0, 1)
|
||||
else:
|
||||
input = hidden_states[expert, :num, :] @ w1[expert].transpose(
|
||||
@ -752,8 +729,8 @@ class NaiveBatchedExperts(mk.FusedMoEPermuteExpertsUnpermute):
|
||||
self.activation(activation, tmp, input.to(tmp.dtype))
|
||||
|
||||
if self.quant_config.is_quantized:
|
||||
assert w2_scale is not None
|
||||
w2_dq = self.dequant(w2[expert], w2_scale[expert])
|
||||
assert self.w2_scale is not None
|
||||
w2_dq = self.dequant(w2[expert], self.w2_scale[expert])
|
||||
else:
|
||||
w2_dq = w2[expert]
|
||||
|
||||
@ -840,35 +817,15 @@ class BatchedTritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
|
||||
self,
|
||||
max_num_tokens: int,
|
||||
num_dispatchers: int,
|
||||
use_fp8_w8a8: bool = False,
|
||||
use_int8_w8a8: bool = False,
|
||||
use_int8_w8a16: bool = False,
|
||||
use_int4_w4a16: bool = False,
|
||||
use_mxfp4_w4a4: bool = False,
|
||||
per_act_token_quant: bool = False,
|
||||
block_shape: Optional[list[int]] = None,
|
||||
quant_config: FusedMoEQuantConfig,
|
||||
):
|
||||
super().__init__(
|
||||
FusedMoEQuantConfig.make(
|
||||
use_fp8_w8a8=use_fp8_w8a8,
|
||||
use_int8_w8a8=use_int8_w8a8,
|
||||
use_int8_w8a16=use_int8_w8a16,
|
||||
use_int4_w4a16=use_int4_w4a16,
|
||||
use_mxfp4_w4a4=use_mxfp4_w4a4,
|
||||
per_act_token_quant=per_act_token_quant,
|
||||
block_shape=block_shape,
|
||||
))
|
||||
assert not use_int8_w8a8, "NYI"
|
||||
assert not use_int8_w8a16, "NYI"
|
||||
assert not use_int4_w4a16, "NYI"
|
||||
assert not use_mxfp4_w4a4, "NYI"
|
||||
super().__init__(quant_config)
|
||||
assert not self.quant_config.use_int8_w8a8, "NYI"
|
||||
assert not self.quant_config.use_int8_w8a16, "NYI"
|
||||
assert not self.quant_config.use_int4_w4a16, "NYI"
|
||||
assert not self.quant_config.use_mxfp4_w4a4, "NYI"
|
||||
assert max_num_tokens > 0
|
||||
assert num_dispatchers > 0
|
||||
self.use_fp8_w8a8 = use_fp8_w8a8
|
||||
self.use_int8_w8a8 = use_int8_w8a8
|
||||
self.use_int4_w4a16 = use_int4_w4a16
|
||||
self.use_int8_w8a16 = use_int8_w8a16
|
||||
self.use_mxfp4_w4a4 = use_mxfp4_w4a4
|
||||
self.max_num_tokens = max_num_tokens
|
||||
self.num_dispatchers = num_dispatchers
|
||||
|
||||
@ -921,19 +878,14 @@ class BatchedTritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
|
||||
activation: str,
|
||||
global_num_experts: int,
|
||||
expert_map: Optional[torch.Tensor],
|
||||
w1_scale: Optional[torch.Tensor],
|
||||
w2_scale: Optional[torch.Tensor],
|
||||
w1_zp: Optional[torch.Tensor],
|
||||
w2_zp: Optional[torch.Tensor],
|
||||
a1q_scale: Optional[torch.Tensor],
|
||||
a2_scale: Optional[torch.Tensor],
|
||||
workspace13: torch.Tensor,
|
||||
workspace2: torch.Tensor,
|
||||
expert_tokens_meta: Optional[mk.ExpertTokensMetadata],
|
||||
apply_router_weight_on_input: bool,
|
||||
):
|
||||
# Check constraints.
|
||||
if self.use_int4_w4a16:
|
||||
if self.quant_config.use_int4_w4a16:
|
||||
assert hidden_states.size(-1) // 2 == w1.size(2), (
|
||||
"Hidden size mismatch")
|
||||
else:
|
||||
@ -958,11 +910,7 @@ class BatchedTritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
|
||||
assert w1.size(0) == E
|
||||
assert w2.size(0) == E
|
||||
|
||||
config_dtype = get_config_dtype_str(use_fp8_w8a8=self.use_fp8_w8a8,
|
||||
use_int8_w8a16=self.use_int8_w8a16,
|
||||
use_int4_w4a16=self.use_int4_w4a16,
|
||||
use_mxfp4_w4a4=self.use_mxfp4_w4a4,
|
||||
dtype=hidden_states.dtype)
|
||||
config_dtype = self.quant_config.config_name(hidden_states.dtype)
|
||||
|
||||
config = try_get_optimal_moe_config(
|
||||
w1.size(),
|
||||
@ -992,7 +940,8 @@ class BatchedTritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
|
||||
intermediate_cache2 = _resize_cache(workspace2,
|
||||
(E, max_num_tokens, N // 2))
|
||||
|
||||
if self.use_fp8_w8a8:
|
||||
# TODO(bnell): should this be done for any quantized type?
|
||||
if self.quant_config.use_fp8_w8a8:
|
||||
intermediate_cache1.fill_(0)
|
||||
|
||||
a1q_scale = normalize_batched_scales_shape(a1q_scale, E)
|
||||
@ -1005,11 +954,11 @@ class BatchedTritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
|
||||
expert_num_tokens=expert_num_tokens,
|
||||
compute_type=compute_type,
|
||||
A_scale=a1q_scale,
|
||||
B_scale=w1_scale,
|
||||
B_zp=w1_zp,
|
||||
use_fp8_w8a8=self.use_fp8_w8a8,
|
||||
use_int8_w8a16=self.use_int8_w8a16,
|
||||
use_int4_w4a16=self.use_int4_w4a16,
|
||||
B_scale=self.w1_scale,
|
||||
B_zp=self.w1_zp,
|
||||
use_fp8_w8a8=self.quant_config.use_fp8_w8a8,
|
||||
use_int8_w8a16=self.quant_config.use_int8_w8a16,
|
||||
use_int4_w4a16=self.quant_config.use_int4_w4a16,
|
||||
config=config,
|
||||
per_act_token_quant=self.per_act_token_quant,
|
||||
block_shape=self.block_shape)
|
||||
@ -1021,7 +970,7 @@ class BatchedTritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
|
||||
intermediate_cache1.view(-1, N))
|
||||
|
||||
qintermediate_cache2, a2q_scale = batched_moe_kernel_quantize_input(
|
||||
intermediate_cache2, a2_scale, max_num_tokens, E, N,
|
||||
intermediate_cache2, self.a2_scale, max_num_tokens, E, N,
|
||||
expert_num_tokens, self.quant_dtype, self.per_act_token_quant,
|
||||
self.block_shape)
|
||||
|
||||
@ -1032,11 +981,11 @@ class BatchedTritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
|
||||
expert_num_tokens=expert_num_tokens,
|
||||
compute_type=compute_type,
|
||||
A_scale=a2q_scale,
|
||||
B_scale=w2_scale,
|
||||
B_zp=w2_zp,
|
||||
use_fp8_w8a8=self.use_fp8_w8a8,
|
||||
use_int8_w8a16=self.use_int8_w8a16,
|
||||
use_int4_w4a16=self.use_int4_w4a16,
|
||||
B_scale=self.w2_scale,
|
||||
B_zp=self.w2_zp,
|
||||
use_fp8_w8a8=self.quant_config.use_fp8_w8a8,
|
||||
use_int8_w8a16=self.quant_config.use_int8_w8a16,
|
||||
use_int4_w4a16=self.quant_config.use_int4_w4a16,
|
||||
config=config,
|
||||
per_act_token_quant=self.per_act_token_quant,
|
||||
block_shape=self.block_shape)
|
||||
|
||||
@ -1,13 +1,13 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""Fused MoE kernel."""
|
||||
"""Fused MoE Triton kernels."""
|
||||
import functools
|
||||
import json
|
||||
import os
|
||||
# torch.compile needs typing.List. It will fail torch.library.infer_schema
|
||||
# otherwise
|
||||
from typing import List # noqa: UP035
|
||||
from typing import Any, Callable, Optional
|
||||
from typing import Any, Callable, Optional, Union
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
@ -18,7 +18,7 @@ from vllm import _custom_ops as ops
|
||||
from vllm.logger import init_logger
|
||||
# yapf: disable
|
||||
from vllm.model_executor.layers.fused_moe.config import (
|
||||
FusedMoEQuantConfig, get_config_quant_dtype)
|
||||
FUSED_MOE_UNQUANTIZED_CONFIG, FusedMoEQuantConfig, _get_config_dtype_str)
|
||||
from vllm.model_executor.layers.fused_moe.cutlass_moe import (
|
||||
_valid_cutlass_block_scaled_grouped_gemm,
|
||||
run_cutlass_block_scaled_fused_experts)
|
||||
@ -32,11 +32,7 @@ from vllm.model_executor.layers.fused_moe.prepare_finalize import (
|
||||
from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import (
|
||||
TopKWeightAndReduceNoOP)
|
||||
from vllm.model_executor.layers.fused_moe.utils import (
|
||||
_resize_cache, moe_kernel_quantize_input)
|
||||
from vllm.model_executor.layers.quantization.utils.flashinfer_utils import (
|
||||
calculate_tile_tokens_dim)
|
||||
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
|
||||
per_token_group_quant_fp8)
|
||||
_resize_cache, activation_without_mul, moe_kernel_quantize_input)
|
||||
from vllm.model_executor.layers.quantization.utils.mxfp4_utils import (
|
||||
dequant_mxfp4)
|
||||
from vllm.platforms import current_platform
|
||||
@ -1049,87 +1045,66 @@ def fused_grouped_topk(
|
||||
return topk_values.to(torch.float32), topk_indices.to(torch.int32)
|
||||
|
||||
|
||||
def get_config_dtype_str(
|
||||
dtype: torch.dtype,
|
||||
use_int4_w4a16: Optional[bool] = False,
|
||||
use_int8_w8a16: Optional[bool] = False,
|
||||
use_fp8_w8a8: Optional[bool] = False,
|
||||
use_mxfp4_w4a4: Optional[bool] = False) -> Optional[str]:
|
||||
if use_fp8_w8a8:
|
||||
return "fp8_w8a8"
|
||||
elif use_int8_w8a16:
|
||||
return "int8_w8a16"
|
||||
elif use_int4_w4a16:
|
||||
return "int4_w4a16"
|
||||
elif use_mxfp4_w4a4:
|
||||
return "mxfp4_w4a4"
|
||||
elif dtype == torch.float:
|
||||
# avoiding cases where kernel fails when float32 MoE
|
||||
# use fp16/bfloat16 configs
|
||||
return "float32"
|
||||
return None
|
||||
|
||||
|
||||
def inplace_fused_experts(
|
||||
hidden_states: torch.Tensor,
|
||||
w1: torch.Tensor,
|
||||
w2: torch.Tensor,
|
||||
topk_weights: torch.Tensor,
|
||||
topk_ids: torch.Tensor,
|
||||
activation: str = "silu",
|
||||
is_act_and_mul: bool = True,
|
||||
apply_router_weight_on_input: bool = False,
|
||||
use_fp8_w8a8: bool = False,
|
||||
use_int8_w8a8: bool = False,
|
||||
use_int8_w8a16: bool = False,
|
||||
use_int4_w4a16: bool = False,
|
||||
use_mxfp4_w4a4: bool = False,
|
||||
per_channel_quant: bool = False,
|
||||
global_num_experts: int = -1,
|
||||
expert_map: Optional[torch.Tensor] = None,
|
||||
w1_scale: Optional[torch.Tensor] = None,
|
||||
w2_scale: Optional[torch.Tensor] = None,
|
||||
w1_zp: Optional[torch.Tensor] = None,
|
||||
w2_zp: Optional[torch.Tensor] = None,
|
||||
a1_scale: Optional[torch.Tensor] = None,
|
||||
a2_scale: Optional[torch.Tensor] = None,
|
||||
block_shape: Optional[List[int]] = None, #noqa: UP006
|
||||
w1_bias: Optional[torch.Tensor] = None,
|
||||
w2_bias: Optional[torch.Tensor] = None) -> None:
|
||||
hidden_states: torch.Tensor,
|
||||
w1: torch.Tensor,
|
||||
w2: torch.Tensor,
|
||||
topk_weights: torch.Tensor,
|
||||
topk_ids: torch.Tensor,
|
||||
activation: str = "silu",
|
||||
apply_router_weight_on_input: bool = False,
|
||||
use_fp8_w8a8: bool = False,
|
||||
use_int8_w8a8: bool = False,
|
||||
use_int8_w8a16: bool = False,
|
||||
use_int4_w4a16: bool = False,
|
||||
use_mxfp4_w4a4: bool = False,
|
||||
per_channel_quant: bool = False,
|
||||
global_num_experts: int = -1,
|
||||
expert_map: Optional[torch.Tensor] = None,
|
||||
w1_scale: Optional[torch.Tensor] = None,
|
||||
w2_scale: Optional[torch.Tensor] = None,
|
||||
w1_zp: Optional[torch.Tensor] = None,
|
||||
w2_zp: Optional[torch.Tensor] = None,
|
||||
a1_scale: Optional[torch.Tensor] = None,
|
||||
a2_scale: Optional[torch.Tensor] = None,
|
||||
block_shape: Optional[List[int]] = None, #noqa: UP006
|
||||
w1_bias: Optional[torch.Tensor] = None,
|
||||
w2_bias: Optional[torch.Tensor] = None,
|
||||
) -> None:
|
||||
fused_experts_impl(hidden_states, w1, w2, topk_weights, topk_ids, True,
|
||||
activation, is_act_and_mul,
|
||||
apply_router_weight_on_input, use_fp8_w8a8,
|
||||
activation, apply_router_weight_on_input, use_fp8_w8a8,
|
||||
use_int8_w8a8, use_int8_w8a16, use_int4_w4a16,
|
||||
use_mxfp4_w4a4, per_channel_quant, global_num_experts,
|
||||
expert_map, w1_scale, w2_scale, w1_zp, w2_zp, a1_scale,
|
||||
a2_scale, block_shape, w1_bias, w2_bias)
|
||||
|
||||
|
||||
def inplace_fused_experts_fake(hidden_states: torch.Tensor,
|
||||
w1: torch.Tensor,
|
||||
w2: torch.Tensor,
|
||||
topk_weights: torch.Tensor,
|
||||
topk_ids: torch.Tensor,
|
||||
activation: str = "silu",
|
||||
is_act_and_mul: bool = True,
|
||||
apply_router_weight_on_input: bool = False,
|
||||
use_fp8_w8a8: bool = False,
|
||||
use_int8_w8a8: bool = False,
|
||||
use_int8_w8a16: bool = False,
|
||||
use_int4_w4a16: bool = False,
|
||||
use_mxfp4_w4a4: bool = False,
|
||||
per_channel_quant: bool = False,
|
||||
global_num_experts: int = -1,
|
||||
expert_map: Optional[torch.Tensor] = None,
|
||||
w1_scale: Optional[torch.Tensor] = None,
|
||||
w2_scale: Optional[torch.Tensor] = None,
|
||||
w1_zp: Optional[torch.Tensor] = None,
|
||||
w2_zp: Optional[torch.Tensor] = None,
|
||||
a1_scale: Optional[torch.Tensor] = None,
|
||||
a2_scale: Optional[torch.Tensor] = None,
|
||||
block_shape: Optional[list[int]] = None,
|
||||
w1_bias: Optional[torch.Tensor] = None,
|
||||
w2_bias: Optional[torch.Tensor] = None) -> None:
|
||||
def inplace_fused_experts_fake(
|
||||
hidden_states: torch.Tensor,
|
||||
w1: torch.Tensor,
|
||||
w2: torch.Tensor,
|
||||
topk_weights: torch.Tensor,
|
||||
topk_ids: torch.Tensor,
|
||||
activation: str = "silu",
|
||||
apply_router_weight_on_input: bool = False,
|
||||
use_fp8_w8a8: bool = False,
|
||||
use_int8_w8a8: bool = False,
|
||||
use_int8_w8a16: bool = False,
|
||||
use_int4_w4a16: bool = False,
|
||||
use_mxfp4_w4a4: bool = False,
|
||||
per_channel_quant: bool = False,
|
||||
global_num_experts: int = -1,
|
||||
expert_map: Optional[torch.Tensor] = None,
|
||||
w1_scale: Optional[torch.Tensor] = None,
|
||||
w2_scale: Optional[torch.Tensor] = None,
|
||||
w1_zp: Optional[torch.Tensor] = None,
|
||||
w2_zp: Optional[torch.Tensor] = None,
|
||||
a1_scale: Optional[torch.Tensor] = None,
|
||||
a2_scale: Optional[torch.Tensor] = None,
|
||||
block_shape: Optional[List[int]] = None, #noqa: UP006
|
||||
w1_bias: Optional[torch.Tensor] = None,
|
||||
w2_bias: Optional[torch.Tensor] = None,
|
||||
) -> None:
|
||||
pass
|
||||
|
||||
|
||||
@ -1143,175 +1118,6 @@ direct_register_custom_op(
|
||||
)
|
||||
|
||||
|
||||
def flashinfer_fused_moe_blockscale_fp8(
|
||||
routing_logits: torch.Tensor,
|
||||
routing_bias: torch.Tensor,
|
||||
x: torch.Tensor,
|
||||
w13_weight: torch.Tensor,
|
||||
w13_weight_scale_inv: torch.Tensor,
|
||||
w2_weight: torch.Tensor,
|
||||
w2_weight_scale_inv: torch.Tensor,
|
||||
global_num_experts: int,
|
||||
top_k: int,
|
||||
num_expert_group: int,
|
||||
topk_group: int,
|
||||
intermediate_size: int,
|
||||
expert_offset: int,
|
||||
local_num_experts: int,
|
||||
block_shape: List[int], #noqa: UP006
|
||||
routed_scaling: float = 1.0) -> torch.Tensor:
|
||||
from vllm.utils.flashinfer import flashinfer_trtllm_fp8_block_scale_moe
|
||||
assert top_k <= global_num_experts
|
||||
assert top_k <= 8
|
||||
assert topk_group <= 4
|
||||
assert global_num_experts > num_expert_group
|
||||
assert global_num_experts % num_expert_group == 0
|
||||
assert global_num_experts % 4 == 0
|
||||
assert top_k < (topk_group * global_num_experts / num_expert_group)
|
||||
assert block_shape == [128, 128]
|
||||
|
||||
a_q, a_sf = per_token_group_quant_fp8(x, block_shape[1])
|
||||
# NOTE: scales of hidden states have to be transposed!
|
||||
a_sf_t = a_sf.t().contiguous()
|
||||
return flashinfer_trtllm_fp8_block_scale_moe(
|
||||
routing_logits=routing_logits,
|
||||
routing_bias=routing_bias,
|
||||
hidden_states=a_q,
|
||||
hidden_states_scale=a_sf_t,
|
||||
gemm1_weights=w13_weight,
|
||||
gemm1_weights_scale=w13_weight_scale_inv,
|
||||
gemm2_weights=w2_weight,
|
||||
gemm2_weights_scale=w2_weight_scale_inv,
|
||||
num_experts=global_num_experts,
|
||||
top_k=top_k,
|
||||
n_group=num_expert_group,
|
||||
topk_group=topk_group,
|
||||
intermediate_size=intermediate_size,
|
||||
local_expert_offset=expert_offset,
|
||||
local_num_experts=local_num_experts,
|
||||
routed_scaling_factor=routed_scaling,
|
||||
tile_tokens_dim=calculate_tile_tokens_dim(x.shape[0], top_k,
|
||||
global_num_experts),
|
||||
routing_method_type=2, # DeepSeek-styled routing method
|
||||
use_shuffled_weight=False,
|
||||
)
|
||||
|
||||
|
||||
def flashinfer_fused_moe_blockscale_fp8_fake(
|
||||
routing_logits: torch.Tensor,
|
||||
routing_bias: torch.Tensor,
|
||||
x: torch.Tensor,
|
||||
w13_weight: torch.Tensor,
|
||||
w13_weight_scale_inv: torch.Tensor,
|
||||
w2_weight: torch.Tensor,
|
||||
w2_weight_scale_inv: torch.Tensor,
|
||||
global_num_experts: int,
|
||||
top_k: int,
|
||||
num_expert_group: int,
|
||||
topk_group: int,
|
||||
intermediate_size: int,
|
||||
expert_offset: int,
|
||||
local_num_experts: int,
|
||||
block_shape: list[int],
|
||||
routed_scaling: float = 1.0) -> torch.Tensor:
|
||||
return torch.empty_like(x)
|
||||
|
||||
|
||||
direct_register_custom_op(
|
||||
op_name="flashinfer_fused_moe_blockscale_fp8",
|
||||
op_func=flashinfer_fused_moe_blockscale_fp8,
|
||||
mutates_args=[],
|
||||
fake_impl=flashinfer_fused_moe_blockscale_fp8_fake,
|
||||
tags=(torch.Tag.needs_fixed_stride_order, ),
|
||||
)
|
||||
|
||||
|
||||
def flashinfer_fused_moe_per_tensor_scale_fp8(
|
||||
routing_logits: torch.Tensor,
|
||||
routing_bias: Optional[torch.Tensor],
|
||||
hidden_states: torch.Tensor,
|
||||
input_scale: torch.Tensor,
|
||||
gemm1_weights: torch.Tensor,
|
||||
gemm2_weights: torch.Tensor,
|
||||
output1_scales_scalar: torch.Tensor,
|
||||
output1_scales_gate_scalar: torch.Tensor,
|
||||
output2_scales_scalar: torch.Tensor,
|
||||
num_experts: int,
|
||||
top_k: int,
|
||||
num_expert_group: Optional[int],
|
||||
topk_group: Optional[int],
|
||||
intermediate_size: int,
|
||||
local_expert_offset: int,
|
||||
local_num_experts: int,
|
||||
use_routing_scales_on_input: bool,
|
||||
routing_method_type: int,
|
||||
routed_scaling_factor: float = 1.0) -> torch.Tensor:
|
||||
num_expert_group = num_expert_group if num_expert_group is not None else 0
|
||||
topk_group = topk_group if topk_group is not None else 0
|
||||
|
||||
quant_hidden_states, _ = moe_kernel_quantize_input(
|
||||
hidden_states,
|
||||
input_scale,
|
||||
quant_dtype=torch.float8_e4m3fn,
|
||||
per_act_token_quant=False)
|
||||
|
||||
from vllm.utils.flashinfer import (
|
||||
flashinfer_trtllm_fp8_per_tensor_scale_moe)
|
||||
return flashinfer_trtllm_fp8_per_tensor_scale_moe(
|
||||
routing_logits=routing_logits,
|
||||
routing_bias=routing_bias,
|
||||
hidden_states=quant_hidden_states,
|
||||
gemm1_weights=gemm1_weights,
|
||||
output1_scales_scalar=output1_scales_scalar,
|
||||
output1_scales_gate_scalar=output1_scales_gate_scalar,
|
||||
gemm2_weights=gemm2_weights,
|
||||
output2_scales_scalar=output2_scales_scalar,
|
||||
num_experts=num_experts,
|
||||
top_k=top_k,
|
||||
n_group=num_expert_group,
|
||||
topk_group=topk_group,
|
||||
intermediate_size=intermediate_size,
|
||||
local_expert_offset=local_expert_offset,
|
||||
local_num_experts=local_num_experts,
|
||||
routed_scaling_factor=routed_scaling_factor,
|
||||
use_routing_scales_on_input=use_routing_scales_on_input,
|
||||
tile_tokens_dim=calculate_tile_tokens_dim(hidden_states.shape[0],
|
||||
top_k, num_experts),
|
||||
routing_method_type=routing_method_type)
|
||||
|
||||
|
||||
def flashinfer_fused_moe_per_tensor_scale_fp8_fake(
|
||||
routing_logits: torch.Tensor,
|
||||
routing_bias: Optional[torch.Tensor],
|
||||
hidden_states: torch.Tensor,
|
||||
input_scale: torch.Tensor,
|
||||
gemm1_weights: torch.Tensor,
|
||||
gemm2_weights: torch.Tensor,
|
||||
output1_scales_scalar: torch.Tensor,
|
||||
output1_scales_gate_scalar: torch.Tensor,
|
||||
output2_scales_scalar: torch.Tensor,
|
||||
num_experts: int,
|
||||
top_k: int,
|
||||
num_expert_group: Optional[int],
|
||||
topk_group: Optional[int],
|
||||
intermediate_size: int,
|
||||
local_expert_offset: int,
|
||||
local_num_experts: int,
|
||||
use_routing_scales_on_input: bool,
|
||||
routing_method_type: int,
|
||||
routed_scaling_factor: float = 1.0) -> torch.Tensor:
|
||||
pass
|
||||
|
||||
|
||||
direct_register_custom_op(
|
||||
op_name="flashinfer_fused_moe_per_tensor_scale_fp8",
|
||||
op_func=flashinfer_fused_moe_per_tensor_scale_fp8,
|
||||
mutates_args=["hidden_states"],
|
||||
fake_impl=flashinfer_fused_moe_per_tensor_scale_fp8_fake,
|
||||
tags=(torch.Tag.needs_fixed_stride_order, ),
|
||||
)
|
||||
|
||||
|
||||
def outplace_fused_experts(
|
||||
hidden_states: torch.Tensor,
|
||||
w1: torch.Tensor,
|
||||
@ -1319,7 +1125,6 @@ def outplace_fused_experts(
|
||||
topk_weights: torch.Tensor,
|
||||
topk_ids: torch.Tensor,
|
||||
activation: str = "silu",
|
||||
is_act_and_mul: bool = True,
|
||||
apply_router_weight_on_input: bool = False,
|
||||
use_fp8_w8a8: bool = False,
|
||||
use_int8_w8a8: bool = False,
|
||||
@ -1341,37 +1146,37 @@ def outplace_fused_experts(
|
||||
) -> torch.Tensor:
|
||||
return fused_experts_impl(
|
||||
hidden_states, w1, w2, topk_weights, topk_ids, False, activation,
|
||||
is_act_and_mul, apply_router_weight_on_input, use_fp8_w8a8,
|
||||
use_int8_w8a8, use_int8_w8a16, use_int4_w4a16, use_mxfp4_w4a4,
|
||||
per_channel_quant, global_num_experts, expert_map, w1_scale, w2_scale,
|
||||
w1_zp, w2_zp, a1_scale, a2_scale, block_shape, w1_bias, w2_bias)
|
||||
apply_router_weight_on_input, use_fp8_w8a8, use_int8_w8a8,
|
||||
use_int8_w8a16, use_int4_w4a16, use_mxfp4_w4a4, per_channel_quant,
|
||||
global_num_experts, expert_map, w1_scale, w2_scale, w1_zp, w2_zp,
|
||||
a1_scale, a2_scale, block_shape, w1_bias, w2_bias)
|
||||
|
||||
|
||||
def outplace_fused_experts_fake(
|
||||
hidden_states: torch.Tensor,
|
||||
w1: torch.Tensor,
|
||||
w2: torch.Tensor,
|
||||
topk_weights: torch.Tensor,
|
||||
topk_ids: torch.Tensor,
|
||||
activation: str = "silu",
|
||||
is_act_and_mul: bool = True,
|
||||
use_fp8_w8a8: bool = False,
|
||||
use_int8_w8a8: bool = False,
|
||||
use_int8_w8a16: bool = False,
|
||||
use_int4_w4a16: bool = False,
|
||||
use_mxfp4_w4a4: bool = False,
|
||||
per_channel_quant: bool = False,
|
||||
global_num_experts: int = -1,
|
||||
expert_map: Optional[torch.Tensor] = None,
|
||||
w1_scale: Optional[torch.Tensor] = None,
|
||||
w2_scale: Optional[torch.Tensor] = None,
|
||||
w1_zp: Optional[torch.Tensor] = None,
|
||||
w2_zp: Optional[torch.Tensor] = None,
|
||||
a1_scale: Optional[torch.Tensor] = None,
|
||||
a2_scale: Optional[torch.Tensor] = None,
|
||||
block_shape: Optional[list[int]] = None,
|
||||
w1_bias: Optional[torch.Tensor] = None,
|
||||
w2_bias: Optional[torch.Tensor] = None) -> torch.Tensor:
|
||||
hidden_states: torch.Tensor,
|
||||
w1: torch.Tensor,
|
||||
w2: torch.Tensor,
|
||||
topk_weights: torch.Tensor,
|
||||
topk_ids: torch.Tensor,
|
||||
activation: str = "silu",
|
||||
use_fp8_w8a8: bool = False,
|
||||
use_int8_w8a8: bool = False,
|
||||
use_int8_w8a16: bool = False,
|
||||
use_int4_w4a16: bool = False,
|
||||
use_mxfp4_w4a4: bool = False,
|
||||
per_channel_quant: bool = False,
|
||||
global_num_experts: int = -1,
|
||||
expert_map: Optional[torch.Tensor] = None,
|
||||
w1_scale: Optional[torch.Tensor] = None,
|
||||
w2_scale: Optional[torch.Tensor] = None,
|
||||
w1_zp: Optional[torch.Tensor] = None,
|
||||
w2_zp: Optional[torch.Tensor] = None,
|
||||
a1_scale: Optional[torch.Tensor] = None,
|
||||
a2_scale: Optional[torch.Tensor] = None,
|
||||
block_shape: Optional[list[int]] = None,
|
||||
w1_bias: Optional[torch.Tensor] = None,
|
||||
w2_bias: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
return torch.empty_like(hidden_states)
|
||||
|
||||
|
||||
@ -1403,45 +1208,36 @@ def dispatch_fused_experts_func(inplace: bool) -> Callable[..., torch.Tensor]:
|
||||
|
||||
# TODO (bnell): replace this with modular op. Can get rid of inplace/outplace
|
||||
# torch ops.
|
||||
def fused_experts(hidden_states: torch.Tensor,
|
||||
w1: torch.Tensor,
|
||||
w2: torch.Tensor,
|
||||
topk_weights: torch.Tensor,
|
||||
topk_ids: torch.Tensor,
|
||||
inplace: bool = False,
|
||||
activation: str = "silu",
|
||||
is_act_and_mul: bool = True,
|
||||
apply_router_weight_on_input: bool = False,
|
||||
use_fp8_w8a8: bool = False,
|
||||
use_int8_w8a8: bool = False,
|
||||
use_int8_w8a16: bool = False,
|
||||
use_int4_w4a16: bool = False,
|
||||
use_mxfp4_w4a4: bool = False,
|
||||
per_channel_quant: bool = False,
|
||||
global_num_experts: int = -1,
|
||||
expert_map: Optional[torch.Tensor] = None,
|
||||
w1_scale: Optional[torch.Tensor] = None,
|
||||
w2_scale: Optional[torch.Tensor] = None,
|
||||
w1_zp: Optional[torch.Tensor] = None,
|
||||
w2_zp: Optional[torch.Tensor] = None,
|
||||
a1_scale: Optional[torch.Tensor] = None,
|
||||
a2_scale: Optional[torch.Tensor] = None,
|
||||
block_shape: Optional[list[int]] = None,
|
||||
allow_deep_gemm: bool = False,
|
||||
allow_cutlass_block_scaled_grouped_gemm: bool = False,
|
||||
w1_bias: Optional[torch.Tensor] = None,
|
||||
w2_bias: Optional[torch.Tensor] = None) -> torch.Tensor:
|
||||
def fused_experts(
|
||||
hidden_states: torch.Tensor,
|
||||
w1: torch.Tensor,
|
||||
w2: torch.Tensor,
|
||||
topk_weights: torch.Tensor,
|
||||
topk_ids: torch.Tensor,
|
||||
inplace: bool = False,
|
||||
activation: str = "silu",
|
||||
apply_router_weight_on_input: bool = False,
|
||||
global_num_experts: int = -1,
|
||||
expert_map: Optional[torch.Tensor] = None,
|
||||
quant_config: Optional[FusedMoEQuantConfig] = None,
|
||||
allow_deep_gemm: bool = False,
|
||||
allow_cutlass_block_scaled_grouped_gemm: bool = False,
|
||||
) -> torch.Tensor:
|
||||
|
||||
if quant_config is None:
|
||||
quant_config = FUSED_MOE_UNQUANTIZED_CONFIG
|
||||
use_fp8_w8a8 = quant_config.use_fp8_w8a8
|
||||
|
||||
# For now, disable DeepGemm for small N (<= 512) until better
|
||||
# permute/unpermute ops are available.
|
||||
# However, on B200, we use DeepGemm for all cases because they only support
|
||||
# E8M0 scale, which means we requantize the weight and input to the specific
|
||||
# scale. Fallen back to cutlass or triton for some cases would cause
|
||||
# accuracy issue.
|
||||
if (allow_deep_gemm and use_fp8_w8a8 and
|
||||
if (allow_deep_gemm and quant_config.use_fp8_w8a8 and
|
||||
(is_deep_gemm_e8m0_used() or _valid_deep_gemm(hidden_states, w1, w2))):
|
||||
assert quant_config is not None
|
||||
assert apply_router_weight_on_input is False
|
||||
assert is_act_and_mul, (
|
||||
"DeepGemm only supports is_act_and_mul=True for now.")
|
||||
return deep_gemm_moe_fp8(
|
||||
hidden_states=hidden_states,
|
||||
w1=w1,
|
||||
@ -1452,22 +1248,23 @@ def fused_experts(hidden_states: torch.Tensor,
|
||||
activation=activation,
|
||||
global_num_experts=global_num_experts,
|
||||
expert_map=expert_map,
|
||||
w1_scale=w1_scale,
|
||||
w2_scale=w2_scale,
|
||||
a1_scale=a1_scale,
|
||||
a2_scale=a2_scale,
|
||||
w1_scale=quant_config.w1_scale,
|
||||
w2_scale=quant_config.w2_scale,
|
||||
a1_scale=quant_config.a1_scale,
|
||||
a2_scale=quant_config.a2_scale,
|
||||
apply_router_weight_on_input=apply_router_weight_on_input,
|
||||
)
|
||||
elif (allow_cutlass_block_scaled_grouped_gemm and use_fp8_w8a8
|
||||
and _valid_cutlass_block_scaled_grouped_gemm(
|
||||
w1, w2, inplace, activation, apply_router_weight_on_input,
|
||||
expert_map)):
|
||||
assert quant_config is not None
|
||||
return run_cutlass_block_scaled_fused_experts(
|
||||
a=hidden_states,
|
||||
w1=w1,
|
||||
w2=w2,
|
||||
w1_scale=w1_scale,
|
||||
w2_scale=w2_scale,
|
||||
w1_scale=quant_config.w1_scale,
|
||||
w2_scale=quant_config.w2_scale,
|
||||
topk_weights=topk_weights,
|
||||
topk_ids=topk_ids)
|
||||
else:
|
||||
@ -1478,26 +1275,49 @@ def fused_experts(hidden_states: torch.Tensor,
|
||||
topk_weights=topk_weights,
|
||||
topk_ids=topk_ids,
|
||||
activation=activation,
|
||||
is_act_and_mul=is_act_and_mul,
|
||||
apply_router_weight_on_input=apply_router_weight_on_input,
|
||||
use_fp8_w8a8=use_fp8_w8a8,
|
||||
use_int8_w8a8=use_int8_w8a8,
|
||||
use_int8_w8a16=use_int8_w8a16,
|
||||
use_int4_w4a16=use_int4_w4a16,
|
||||
use_mxfp4_w4a4=use_mxfp4_w4a4,
|
||||
per_channel_quant=per_channel_quant,
|
||||
use_fp8_w8a8=quant_config.use_fp8_w8a8,
|
||||
use_int8_w8a8=quant_config.use_int8_w8a8,
|
||||
use_int8_w8a16=quant_config.use_int8_w8a16,
|
||||
use_int4_w4a16=quant_config.use_int4_w4a16,
|
||||
use_mxfp4_w4a4=quant_config.use_mxfp4_w4a4,
|
||||
per_channel_quant=quant_config.per_act_token_quant,
|
||||
global_num_experts=global_num_experts,
|
||||
expert_map=expert_map,
|
||||
w1_scale=w1_scale,
|
||||
w2_scale=w2_scale,
|
||||
w1_zp=w1_zp,
|
||||
w2_zp=w2_zp,
|
||||
a1_scale=a1_scale,
|
||||
a2_scale=a2_scale,
|
||||
block_shape=block_shape,
|
||||
w1_bias=w1_bias,
|
||||
w2_bias=w2_bias,
|
||||
)
|
||||
w1_scale=quant_config.w1_scale,
|
||||
w2_scale=quant_config.w2_scale,
|
||||
w1_zp=quant_config.w1_zp,
|
||||
w2_zp=quant_config.w2_zp,
|
||||
a1_scale=quant_config.a1_scale,
|
||||
a2_scale=quant_config.a2_scale,
|
||||
block_shape=quant_config.block_shape,
|
||||
w1_bias=quant_config.w1_bias,
|
||||
w2_bias=quant_config.w2_bias)
|
||||
|
||||
|
||||
SILU_NO_MUL: str = activation_without_mul("silu")
|
||||
GELU_NO_MUL: str = activation_without_mul("gelu")
|
||||
|
||||
|
||||
def _get_config_quant_dtype(
|
||||
use_fp8_w8a8: bool,
|
||||
use_int8_w8a8: bool,
|
||||
use_mxfp4_w4a4: bool,
|
||||
) -> Union[None, torch.dtype, str]:
|
||||
"""
|
||||
Get the quantization type based on the quantization strategy flags.
|
||||
We don't have a quant_config at this point so we need to work backwards.
|
||||
A return type of None means no quantization is required because the
|
||||
input is unquantized or has been quantized prior to calling
|
||||
fused_experts_impl.
|
||||
"""
|
||||
if use_fp8_w8a8:
|
||||
return torch.float8_e4m3fn
|
||||
elif use_int8_w8a8:
|
||||
return torch.int8
|
||||
elif use_mxfp4_w4a4:
|
||||
return "mxfp4"
|
||||
return None
|
||||
|
||||
|
||||
def fused_experts_impl(
|
||||
@ -1508,7 +1328,6 @@ def fused_experts_impl(
|
||||
topk_ids: torch.Tensor,
|
||||
inplace: bool = False,
|
||||
activation: str = "silu",
|
||||
is_act_and_mul: bool = True,
|
||||
apply_router_weight_on_input: bool = False,
|
||||
use_fp8_w8a8: bool = False,
|
||||
use_int8_w8a8: bool = False,
|
||||
@ -1557,17 +1376,18 @@ def fused_experts_impl(
|
||||
# https://github.com/vllm-project/vllm/issues/5938
|
||||
CHUNK_SIZE = envs.VLLM_FUSED_MOE_CHUNK_SIZE
|
||||
M = min(num_tokens, CHUNK_SIZE)
|
||||
config_dtype = get_config_dtype_str(use_fp8_w8a8=use_fp8_w8a8,
|
||||
use_int8_w8a16=use_int8_w8a16,
|
||||
use_int4_w4a16=use_int4_w4a16,
|
||||
use_mxfp4_w4a4=use_mxfp4_w4a4,
|
||||
dtype=hidden_states.dtype)
|
||||
|
||||
qtype = get_config_quant_dtype(use_fp8_w8a8=use_fp8_w8a8,
|
||||
use_int8_w8a8=use_int8_w8a8,
|
||||
use_int8_w8a16=use_int8_w8a16,
|
||||
use_int4_w4a16=use_int4_w4a16,
|
||||
use_mxfp4_w4a4=use_mxfp4_w4a4)
|
||||
config_dtype = _get_config_dtype_str(use_fp8_w8a8=use_fp8_w8a8,
|
||||
use_int8_w8a16=use_int8_w8a16,
|
||||
use_int4_w4a16=use_int4_w4a16,
|
||||
use_mxfp4_w4a4=use_mxfp4_w4a4,
|
||||
dtype=hidden_states.dtype)
|
||||
|
||||
# Note: for use_int8_w8a16 or use_int4_w4a16, the activations are
|
||||
# quantized prior to calling fused_experts.
|
||||
quant_dtype = _get_config_quant_dtype(use_fp8_w8a8=use_fp8_w8a8,
|
||||
use_int8_w8a8=use_int8_w8a8,
|
||||
use_mxfp4_w4a4=use_mxfp4_w4a4)
|
||||
|
||||
get_config_func = functools.partial(
|
||||
try_get_optimal_moe_config,
|
||||
@ -1640,7 +1460,7 @@ def fused_experts_impl(
|
||||
qcurr_hidden_states, a1q_scale = moe_kernel_quantize_input(
|
||||
A=curr_hidden_states,
|
||||
A_scale=a1_scale,
|
||||
quant_dtype=qtype,
|
||||
quant_dtype=quant_dtype,
|
||||
per_act_token_quant=per_channel_quant,
|
||||
block_shape=block_shape)
|
||||
|
||||
@ -1671,30 +1491,29 @@ def fused_experts_impl(
|
||||
B_bias=w1_bias)
|
||||
|
||||
# Activation function with multiplication
|
||||
if activation == "silu" and is_act_and_mul:
|
||||
if activation == "silu":
|
||||
torch.ops._C.silu_and_mul(intermediate_cache2,
|
||||
intermediate_cache1.view(-1, N))
|
||||
elif activation == "gelu" and is_act_and_mul:
|
||||
elif activation == "gelu":
|
||||
torch.ops._C.gelu_and_mul(intermediate_cache2,
|
||||
intermediate_cache1.view(-1, N))
|
||||
elif activation == "swigluoai" and is_act_and_mul:
|
||||
elif activation == "swigluoai":
|
||||
# alpha = 1.702, limit = 7.0
|
||||
torch.ops._C.swigluoai_and_mul(intermediate_cache2,
|
||||
intermediate_cache1.view(-1, N))
|
||||
# Activation function without multiplication
|
||||
elif activation == "silu":
|
||||
elif activation == SILU_NO_MUL:
|
||||
intermediate_cache2 = F.silu(intermediate_cache1.view(-1, N))
|
||||
elif activation == "gelu":
|
||||
elif activation == GELU_NO_MUL:
|
||||
intermediate_cache2 = F.gelu(intermediate_cache1.view(-1, N))
|
||||
|
||||
else:
|
||||
raise ValueError(f"Unsupported FusedMoe activation: {activation}, "
|
||||
f"with is_act_and_mul={is_act_and_mul}.")
|
||||
raise ValueError(f"Unsupported FusedMoe activation: {activation}.")
|
||||
|
||||
qintermediate_cache2, a2q_scale = moe_kernel_quantize_input(
|
||||
A=intermediate_cache2,
|
||||
A_scale=a2_scale,
|
||||
quant_dtype=qtype,
|
||||
quant_dtype=quant_dtype,
|
||||
per_act_token_quant=per_channel_quant,
|
||||
block_shape=block_shape)
|
||||
|
||||
@ -1726,164 +1545,13 @@ def fused_experts_impl(
|
||||
return out_hidden_states
|
||||
|
||||
|
||||
def fused_moe(
|
||||
hidden_states: torch.Tensor,
|
||||
w1: torch.Tensor,
|
||||
w2: torch.Tensor,
|
||||
gating_output: torch.Tensor,
|
||||
topk: int,
|
||||
renormalize: bool,
|
||||
inplace: bool = False,
|
||||
activation: str = "silu",
|
||||
is_act_and_mul: bool = True,
|
||||
use_grouped_topk: bool = False,
|
||||
num_expert_group: Optional[int] = None,
|
||||
topk_group: Optional[int] = None,
|
||||
custom_routing_function: Optional[Callable] = None,
|
||||
use_fp8_w8a8: bool = False,
|
||||
use_int8_w8a8: bool = False,
|
||||
use_int8_w8a16: bool = False,
|
||||
use_int4_w4a16: bool = False,
|
||||
use_mxfp4_w4a4: bool = False,
|
||||
per_channel_quant: bool = False,
|
||||
global_num_experts: int = -1,
|
||||
expert_map: Optional[torch.Tensor] = None,
|
||||
w1_scale: Optional[torch.Tensor] = None,
|
||||
w2_scale: Optional[torch.Tensor] = None,
|
||||
w1_zp: Optional[torch.Tensor] = None,
|
||||
w2_zp: Optional[torch.Tensor] = None,
|
||||
a1_scale: Optional[torch.Tensor] = None,
|
||||
a2_scale: Optional[torch.Tensor] = None,
|
||||
block_shape: Optional[list[int]] = None,
|
||||
w1_bias: Optional[torch.Tensor] = None,
|
||||
w2_bias: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
This function computes a Mixture of Experts (MoE) layer using two sets of
|
||||
weights, w1 and w2, and top-k gating mechanism.
|
||||
|
||||
Parameters:
|
||||
- hidden_states (torch.Tensor): The input tensor to the MoE layer.
|
||||
- w1 (torch.Tensor): The first set of expert weights.
|
||||
- w2 (torch.Tensor): The second set of expert weights.
|
||||
- gating_output (torch.Tensor): The output of the gating operation
|
||||
(before softmax).
|
||||
- topk (int): The number of top-k experts to select.
|
||||
- renormalize (bool): If True, renormalize the top-k weights to sum to 1.
|
||||
- inplace (bool): If True, perform the operation in-place.
|
||||
Defaults to False.
|
||||
- activation (str): The activation function to apply after the first
|
||||
MoE layer.
|
||||
- is_act_and_mul (bool): If True, use activation-and-mul function for
|
||||
activation (self-gated activation), otherwise use activation function
|
||||
for activation (ungated activation).
|
||||
- num_expert_group: Optional[int]: additional parameter for grouped_topk
|
||||
- topk_group: Optional[int]: additional parameter for grouped_topk
|
||||
- use_grouped_topk: If True, use grouped_topk instead of fused_topk
|
||||
note: Deepseekv2 model uses grouped_topk
|
||||
- use_fp8_w8a8 (bool): If True, use fp8 arithmetic to compute the inner
|
||||
products for w1 and w2. Defaults to False.
|
||||
- use_int8_w8a8 (bool): If True, use int8 arithmetic to compute the inner
|
||||
products for w1 and w2. Defaults to False.
|
||||
- use_int8_w8a16 (bool): If True, use matmul of int8 weight and bf16/fp16
|
||||
activation to compute the inner products for w1 and w2.
|
||||
Defaults to False.
|
||||
- use_int4_w4a16 (bool): If True, use matmul of int4 weight and bf16/fp16
|
||||
activation to compute the inner products for w1 and w2.
|
||||
Defaults to False.
|
||||
- use_mxfp4_w4a4 (bool): If True, use matmul of OCP MXFP4 weight and
|
||||
OCP MXFP4 activation to compute the inner products for w1 and w2.
|
||||
Defaults to False.
|
||||
- global_num_experts (int): The total number of experts in the global
|
||||
expert space.
|
||||
- expert_map (Optional[torch.Tensor]): A tensor mapping expert indices
|
||||
from the global expert space to the local expert space of the expert
|
||||
parallel shard.
|
||||
- w1_scale (Optional[torch.Tensor]): Optional scale to be used for
|
||||
w1.
|
||||
- w2_scale (Optional[torch.Tensor]): Optional scale to be used for
|
||||
w2.
|
||||
- a1_scale (Optional[torch.Tensor]): Optional scale to be used for
|
||||
a1.
|
||||
- a2_scale (Optional[torch.Tensor]): Optional scale to be used for
|
||||
a2.
|
||||
- block_shape: (Optional[list[int]]): Optional block size for block-wise
|
||||
quantization.
|
||||
|
||||
Returns:
|
||||
- torch.Tensor: The output tensor after applying the MoE layer.
|
||||
"""
|
||||
if not is_act_and_mul:
|
||||
assert inplace is False, (
|
||||
"is_act_and_mul=False is not supported with inplace=True")
|
||||
|
||||
if use_grouped_topk:
|
||||
assert num_expert_group is not None and topk_group is not None
|
||||
topk_weights, topk_ids = grouped_topk(hidden_states, gating_output,
|
||||
topk, renormalize,
|
||||
num_expert_group, topk_group)
|
||||
elif custom_routing_function is None:
|
||||
topk_weights, topk_ids, token_expert_indices = fused_topk(
|
||||
hidden_states, gating_output, topk, renormalize)
|
||||
else:
|
||||
topk_weights, topk_ids = custom_routing_function(
|
||||
hidden_states, gating_output, topk, renormalize)
|
||||
|
||||
return fused_experts(hidden_states,
|
||||
w1,
|
||||
w2,
|
||||
topk_weights,
|
||||
topk_ids,
|
||||
inplace=inplace,
|
||||
activation=activation,
|
||||
is_act_and_mul=is_act_and_mul,
|
||||
use_fp8_w8a8=use_fp8_w8a8,
|
||||
use_int8_w8a8=use_int8_w8a8,
|
||||
use_int8_w8a16=use_int8_w8a16,
|
||||
use_int4_w4a16=use_int4_w4a16,
|
||||
use_mxfp4_w4a4=use_mxfp4_w4a4,
|
||||
per_channel_quant=per_channel_quant,
|
||||
global_num_experts=global_num_experts,
|
||||
expert_map=expert_map,
|
||||
w1_scale=w1_scale,
|
||||
w2_scale=w2_scale,
|
||||
w1_zp=w1_zp,
|
||||
w2_zp=w2_zp,
|
||||
a1_scale=a1_scale,
|
||||
a2_scale=a2_scale,
|
||||
block_shape=block_shape,
|
||||
w1_bias=w1_bias,
|
||||
w2_bias=w2_bias)
|
||||
|
||||
|
||||
class TritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
use_fp8_w8a8: bool = False,
|
||||
use_int8_w8a8: bool = False,
|
||||
use_int8_w8a16: bool = False,
|
||||
use_int4_w4a16: bool = False,
|
||||
use_mxfp4_w4a4: bool = False,
|
||||
per_act_token_quant: bool = False,
|
||||
block_shape: Optional[list[int]] = None,
|
||||
quant_config: FusedMoEQuantConfig,
|
||||
):
|
||||
super().__init__(
|
||||
FusedMoEQuantConfig.make(
|
||||
use_fp8_w8a8=use_fp8_w8a8,
|
||||
use_int8_w8a8=use_int8_w8a8,
|
||||
use_int8_w8a16=use_int8_w8a16,
|
||||
use_int4_w4a16=use_int4_w4a16,
|
||||
use_mxfp4_w4a4=use_mxfp4_w4a4,
|
||||
per_act_token_quant=per_act_token_quant,
|
||||
block_shape=block_shape,
|
||||
))
|
||||
|
||||
self.use_fp8_w8a8 = use_fp8_w8a8
|
||||
self.use_int4_w4a16 = use_int4_w4a16
|
||||
self.use_int8_w8a8 = use_int8_w8a8
|
||||
self.use_int8_w8a16 = use_int8_w8a16
|
||||
self.use_mxfp4_w4a4 = use_mxfp4_w4a4
|
||||
super().__init__(quant_config)
|
||||
|
||||
@property
|
||||
def activation_formats(
|
||||
@ -1929,19 +1597,14 @@ class TritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
|
||||
activation: str,
|
||||
global_num_experts: int,
|
||||
expert_map: Optional[torch.Tensor],
|
||||
w1_scale: Optional[torch.Tensor],
|
||||
w2_scale: Optional[torch.Tensor],
|
||||
w1_zp: Optional[torch.Tensor],
|
||||
w2_zp: Optional[torch.Tensor],
|
||||
a1q_scale: Optional[torch.Tensor],
|
||||
a2_scale: Optional[torch.Tensor],
|
||||
workspace13: torch.Tensor,
|
||||
workspace2: torch.Tensor,
|
||||
expert_tokens_meta: Optional[mk.ExpertTokensMetadata],
|
||||
apply_router_weight_on_input: bool,
|
||||
):
|
||||
# Check constraints.
|
||||
if self.use_int4_w4a16:
|
||||
if self.quant_config.use_int4_w4a16:
|
||||
assert hidden_states.size(-1) // 2 == w1.size(2), (
|
||||
"Hidden size mismatch")
|
||||
else:
|
||||
@ -1964,17 +1627,11 @@ class TritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
|
||||
if global_num_experts == -1:
|
||||
global_num_experts = E
|
||||
|
||||
config_dtype = get_config_dtype_str(use_fp8_w8a8=self.use_fp8_w8a8,
|
||||
use_int8_w8a16=self.use_int8_w8a16,
|
||||
use_int4_w4a16=self.use_int4_w4a16,
|
||||
use_mxfp4_w4a4=self.use_mxfp4_w4a4,
|
||||
dtype=hidden_states.dtype)
|
||||
|
||||
config = try_get_optimal_moe_config(
|
||||
w1.size(),
|
||||
w2.size(),
|
||||
top_k_num,
|
||||
config_dtype,
|
||||
self.quant_config.config_name(hidden_states.dtype),
|
||||
num_tokens,
|
||||
block_shape=self.block_shape,
|
||||
)
|
||||
@ -2008,8 +1665,8 @@ class TritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
|
||||
w1,
|
||||
intermediate_cache1,
|
||||
a1q_scale,
|
||||
w1_scale,
|
||||
w1_zp,
|
||||
self.w1_scale,
|
||||
self.w1_zp,
|
||||
None, # topk_weights
|
||||
sorted_token_ids,
|
||||
expert_ids,
|
||||
@ -2018,13 +1675,13 @@ class TritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
|
||||
top_k_num,
|
||||
config,
|
||||
compute_type=compute_type,
|
||||
use_fp8_w8a8=self.use_fp8_w8a8,
|
||||
use_int8_w8a8=self.use_int8_w8a8,
|
||||
use_int8_w8a16=self.use_int8_w8a16,
|
||||
use_int4_w4a16=self.use_int4_w4a16,
|
||||
use_fp8_w8a8=self.quant_config.use_fp8_w8a8,
|
||||
use_int8_w8a8=self.quant_config.use_int8_w8a8,
|
||||
use_int8_w8a16=self.quant_config.use_int8_w8a16,
|
||||
use_int4_w4a16=self.quant_config.use_int4_w4a16,
|
||||
per_channel_quant=self.per_act_token_quant,
|
||||
block_shape=self.block_shape,
|
||||
B_bias=None # TODO support B_bias
|
||||
B_bias=self.w1_bias,
|
||||
)
|
||||
|
||||
self.activation(activation, intermediate_cache2,
|
||||
@ -2033,7 +1690,7 @@ class TritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
|
||||
a2q_scale: Optional[torch.Tensor] = None
|
||||
|
||||
qintermediate_cache2, a2q_scale = moe_kernel_quantize_input(
|
||||
intermediate_cache2, a2_scale, self.quant_dtype,
|
||||
intermediate_cache2, self.a2_scale, self.quant_dtype,
|
||||
self.per_act_token_quant, self.block_shape)
|
||||
|
||||
invoke_fused_moe_kernel(
|
||||
@ -2041,8 +1698,8 @@ class TritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
|
||||
w2,
|
||||
intermediate_cache3,
|
||||
a2q_scale,
|
||||
w2_scale,
|
||||
w2_zp,
|
||||
self.w2_scale,
|
||||
self.w2_zp,
|
||||
topk_weights,
|
||||
sorted_token_ids,
|
||||
expert_ids,
|
||||
@ -2051,36 +1708,21 @@ class TritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
|
||||
1,
|
||||
config,
|
||||
compute_type=compute_type,
|
||||
use_fp8_w8a8=self.use_fp8_w8a8,
|
||||
use_int8_w8a8=self.use_int8_w8a8,
|
||||
use_int8_w8a16=self.use_int8_w8a16,
|
||||
use_int4_w4a16=self.use_int4_w4a16,
|
||||
use_fp8_w8a8=self.quant_config.use_fp8_w8a8,
|
||||
use_int8_w8a8=self.quant_config.use_int8_w8a8,
|
||||
use_int8_w8a16=self.quant_config.use_int8_w8a16,
|
||||
use_int4_w4a16=self.quant_config.use_int4_w4a16,
|
||||
per_channel_quant=self.per_act_token_quant,
|
||||
block_shape=self.block_shape,
|
||||
B_bias=None # TODO support B_bias
|
||||
B_bias=self.w2_bias,
|
||||
)
|
||||
|
||||
ops.moe_sum(intermediate_cache3, output)
|
||||
|
||||
|
||||
def modular_triton_fused_moe(
|
||||
use_fp8_w8a8: bool,
|
||||
use_int8_w8a8: bool,
|
||||
use_int8_w8a16: bool,
|
||||
use_int4_w4a16: bool,
|
||||
use_mxfp4_w4a4: bool,
|
||||
per_act_token_quant: bool,
|
||||
block_shape: Optional[list[int]] = None,
|
||||
) -> mk.FusedMoEModularKernel:
|
||||
quant_config: FusedMoEQuantConfig) -> mk.FusedMoEModularKernel:
|
||||
return mk.FusedMoEModularKernel(
|
||||
MoEPrepareAndFinalizeNoEP(),
|
||||
TritonExperts(
|
||||
use_fp8_w8a8=use_fp8_w8a8,
|
||||
use_int8_w8a8=use_int8_w8a8,
|
||||
use_int8_w8a16=use_int8_w8a16,
|
||||
use_int4_w4a16=use_int4_w4a16,
|
||||
use_mxfp4_w4a4=use_mxfp4_w4a4,
|
||||
per_act_token_quant=per_act_token_quant,
|
||||
block_shape=block_shape,
|
||||
),
|
||||
TritonExperts(quant_config),
|
||||
)
|
||||
|
||||
@ -1,11 +1,13 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from typing import TYPE_CHECKING, Optional
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
|
||||
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.fused_moe.config import (
|
||||
FUSED_MOE_UNQUANTIZED_CONFIG, FusedMoEQuantConfig)
|
||||
from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import (
|
||||
TopKWeightAndReduceDelegate)
|
||||
from vllm.utils import has_triton_kernels
|
||||
@ -23,9 +25,6 @@ if has_triton_kernels():
|
||||
"Failed to import Triton kernels. Please make sure your triton "
|
||||
"version is compatible.")
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from triton_kernels.matmul_ogs import PrecisionConfig
|
||||
|
||||
|
||||
def triton_kernel_moe_forward(
|
||||
hidden_states: torch.Tensor,
|
||||
@ -35,20 +34,10 @@ def triton_kernel_moe_forward(
|
||||
topk: int,
|
||||
renormalize: bool,
|
||||
activation: str = "silu",
|
||||
quant_config: Optional[FusedMoEQuantConfig] = None,
|
||||
apply_router_weight_on_input: bool = False,
|
||||
use_fp8_w8a8: bool = False,
|
||||
per_channel_quant: bool = False,
|
||||
global_num_experts: int = -1,
|
||||
expert_map: Optional[torch.Tensor] = None,
|
||||
w1_scale: Optional[torch.Tensor] = None,
|
||||
w2_scale: Optional[torch.Tensor] = None,
|
||||
w1_bias: Optional[torch.Tensor] = None,
|
||||
w2_bias: Optional[torch.Tensor] = None,
|
||||
w1_precision: Optional["PrecisionConfig"] = None,
|
||||
w2_precision: Optional["PrecisionConfig"] = None,
|
||||
a1_scale: Optional[torch.Tensor] = None,
|
||||
a2_scale: Optional[torch.Tensor] = None,
|
||||
block_shape: Optional[list[int]] = None,
|
||||
) -> torch.Tensor:
|
||||
|
||||
routing_data, gather_idx, scatter_idx = routing(gating_output,
|
||||
@ -64,20 +53,10 @@ def triton_kernel_moe_forward(
|
||||
gather_idx,
|
||||
scatter_idx,
|
||||
activation=activation,
|
||||
quant_config=quant_config,
|
||||
apply_router_weight_on_input=apply_router_weight_on_input,
|
||||
use_fp8_w8a8=use_fp8_w8a8,
|
||||
per_channel_quant=per_channel_quant,
|
||||
global_num_experts=global_num_experts,
|
||||
expert_map=expert_map,
|
||||
w1_scale=w1_scale,
|
||||
w2_scale=w2_scale,
|
||||
w1_bias=w1_bias,
|
||||
w2_bias=w2_bias,
|
||||
w1_precision=w1_precision,
|
||||
w2_precision=w2_precision,
|
||||
a1_scale=a1_scale,
|
||||
a2_scale=a2_scale,
|
||||
block_shape=block_shape)
|
||||
expert_map=expert_map)
|
||||
|
||||
|
||||
# This is a triton implementation of the fused_experts function
|
||||
@ -90,28 +69,23 @@ def triton_kernel_fused_experts(
|
||||
gather_indx, # GatherIndx
|
||||
scatter_indx, # ScatterIndx
|
||||
activation: str = "silu",
|
||||
quant_config: Optional[FusedMoEQuantConfig] = None,
|
||||
swiglu_alpha: float = 1.702,
|
||||
swiglu_limit: float = 7.0,
|
||||
apply_router_weight_on_input: bool = False,
|
||||
use_fp8_w8a8: bool = False,
|
||||
per_channel_quant: bool = False,
|
||||
global_num_experts: int = -1,
|
||||
expert_map: Optional[torch.Tensor] = None,
|
||||
w1_scale: Optional[torch.Tensor] = None,
|
||||
w2_scale: Optional[torch.Tensor] = None,
|
||||
w1_bias: Optional[torch.Tensor] = None,
|
||||
w2_bias: Optional[torch.Tensor] = None,
|
||||
w1_precision: Optional["PrecisionConfig"] = None,
|
||||
w2_precision: Optional["PrecisionConfig"] = None,
|
||||
a1_scale: Optional[torch.Tensor] = None,
|
||||
a2_scale: Optional[torch.Tensor] = None,
|
||||
block_shape: Optional[list[int]] = None,
|
||||
a1q_scale: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
if quant_config is None:
|
||||
quant_config = FUSED_MOE_UNQUANTIZED_CONFIG
|
||||
|
||||
# type check, uint8 means mxfp4
|
||||
assert hidden_states.dtype == torch.bfloat16
|
||||
assert w1_bias is None or w1_bias.dtype == torch.float32
|
||||
assert w2_bias is None or w2_bias.dtype == torch.float32
|
||||
assert (quant_config.w1_bias is None
|
||||
or quant_config.w1_bias.dtype == torch.float32)
|
||||
assert (quant_config.w2_bias is None
|
||||
or quant_config.w2_bias.dtype == torch.float32)
|
||||
|
||||
# Shape check, only check non-mxfp4
|
||||
assert hidden_states.shape[-1] == w1.shape[-2]
|
||||
@ -130,20 +104,20 @@ def triton_kernel_fused_experts(
|
||||
intermediate_cache1 = matmul_ogs(
|
||||
hidden_states,
|
||||
w1,
|
||||
w1_bias,
|
||||
quant_config.w1_bias,
|
||||
routing_data,
|
||||
gather_indx=gather_indx,
|
||||
precision_config=w1_precision,
|
||||
precision_config=quant_config.w1_precision,
|
||||
gammas=gammas if apply_router_weight_on_input else None,
|
||||
fused_activation=act)
|
||||
|
||||
intermediate_cache3 = matmul_ogs(
|
||||
intermediate_cache1,
|
||||
w2,
|
||||
w2_bias,
|
||||
quant_config.w2_bias,
|
||||
routing_data,
|
||||
scatter_indx=scatter_indx,
|
||||
precision_config=w2_precision,
|
||||
precision_config=quant_config.w2_precision,
|
||||
gammas=None if apply_router_weight_on_input else gammas,
|
||||
y=output_tensor,
|
||||
)
|
||||
@ -154,21 +128,13 @@ class BatchedOAITritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
quant_config,
|
||||
max_num_tokens: int,
|
||||
num_dispatchers: int,
|
||||
w1_precision: "PrecisionConfig",
|
||||
w2_precision: "PrecisionConfig",
|
||||
w1_bias: Optional[torch.Tensor],
|
||||
w2_bias: Optional[torch.Tensor],
|
||||
quant_config: FusedMoEQuantConfig,
|
||||
):
|
||||
super().__init__(quant_config)
|
||||
self.max_num_tokens = max_num_tokens
|
||||
self.num_dispatchers = num_dispatchers
|
||||
self.w1_precision = w1_precision
|
||||
self.w2_precision = w2_precision
|
||||
self.w1_bias = w1_bias
|
||||
self.w2_bias = w2_bias
|
||||
|
||||
@property
|
||||
def activation_formats(
|
||||
@ -212,12 +178,7 @@ class BatchedOAITritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
|
||||
activation: str,
|
||||
global_num_experts: int,
|
||||
expert_map: Optional[torch.Tensor],
|
||||
w1_scale: Optional[torch.Tensor],
|
||||
w2_scale: Optional[torch.Tensor],
|
||||
w1_zp: Optional[torch.Tensor],
|
||||
w2_zp: Optional[torch.Tensor],
|
||||
a1q_scale: Optional[torch.Tensor],
|
||||
a2_scale: Optional[torch.Tensor],
|
||||
workspace13: torch.Tensor,
|
||||
workspace2: torch.Tensor,
|
||||
expert_tokens_meta: Optional[mk.ExpertTokensMetadata],
|
||||
@ -228,20 +189,12 @@ class BatchedOAITritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
|
||||
hidden_states,
|
||||
w1,
|
||||
w2,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
routing_data=None,
|
||||
gather_indx=None,
|
||||
scatter_indx=None,
|
||||
activation=activation,
|
||||
quant_config=self.quant_config,
|
||||
apply_router_weight_on_input=False,
|
||||
use_fp8_w8a8=False,
|
||||
per_channel_quant=False,
|
||||
global_num_experts=global_num_experts,
|
||||
expert_map=expert_map,
|
||||
w1_scale=w1_scale,
|
||||
w2_scale=w2_scale,
|
||||
w1_bias=self.w1_bias,
|
||||
w2_bias=self.w2_bias,
|
||||
w1_precision=self.w1_precision,
|
||||
w2_precision=self.w2_precision,
|
||||
a1_scale=a1q_scale,
|
||||
a2_scale=a2_scale)
|
||||
a1q_scale=a1q_scale)
|
||||
|
||||
@ -22,7 +22,8 @@ from vllm.logger import init_logger
|
||||
from vllm.model_executor.custom_op import CustomOp
|
||||
# yapf: disable
|
||||
from vllm.model_executor.layers.fused_moe.config import (
|
||||
FusedMoEConfig, FusedMoEParallelConfig)
|
||||
FUSED_MOE_UNQUANTIZED_CONFIG, FusedMoEConfig, FusedMoEParallelConfig,
|
||||
FusedMoEQuantConfig, biased_moe_quant_config)
|
||||
# yapf: enable
|
||||
from vllm.model_executor.layers.fused_moe.modular_kernel import (
|
||||
FusedMoEActivationFormat, FusedMoEModularKernel,
|
||||
@ -78,11 +79,11 @@ class FusedMoeWeightScaleSupported(Enum):
|
||||
|
||||
class FusedMoEMethodBase(QuantizeMethodBase):
|
||||
|
||||
# TODO(bnell): also pass quant_config?
|
||||
def __init__(self, moe: FusedMoEConfig):
|
||||
super().__init__()
|
||||
self.moe = moe
|
||||
self.fused_experts: Optional[Callable] = None
|
||||
self.moe_quant_config: Optional[FusedMoEQuantConfig] = None
|
||||
self.fused_experts: Optional[FusedMoEModularKernel] = None
|
||||
self.topk_indices_dtype = None
|
||||
|
||||
@abstractmethod
|
||||
@ -103,23 +104,28 @@ class FusedMoEMethodBase(QuantizeMethodBase):
|
||||
|
||||
@staticmethod
|
||||
def _maybe_make_prepare_finalize(
|
||||
moe: FusedMoEConfig, ) -> Optional[FusedMoEPrepareAndFinalize]:
|
||||
moe: FusedMoEConfig,
|
||||
quant_config: Optional[FusedMoEQuantConfig],
|
||||
) -> Optional[FusedMoEPrepareAndFinalize]:
|
||||
all2all_manager = get_ep_group().device_communicator.all2all_manager
|
||||
assert all2all_manager is not None
|
||||
|
||||
prepare_finalize: Optional[FusedMoEPrepareAndFinalize] = None
|
||||
|
||||
# TODO: could allow this now
|
||||
assert not moe.use_flashinfer_cutlass_kernels, \
|
||||
"Must be created in modelopt.py"
|
||||
|
||||
if moe.use_pplx_kernels:
|
||||
assert quant_config is not None
|
||||
|
||||
hidden_dim_bytes, hidden_scale_bytes = pplx_hidden_dim_scale_bytes(
|
||||
moe.max_num_tokens,
|
||||
moe.hidden_dim,
|
||||
moe.in_dtype,
|
||||
moe.quant_dtype,
|
||||
per_act_token_quant=moe.per_act_token_quant,
|
||||
block_shape=moe.block_shape,
|
||||
quant_config.quant_dtype,
|
||||
per_act_token_quant=quant_config.per_act_token_quant,
|
||||
block_shape=quant_config.block_shape,
|
||||
)
|
||||
|
||||
all_to_all_args = dict(
|
||||
@ -165,6 +171,7 @@ class FusedMoEMethodBase(QuantizeMethodBase):
|
||||
)
|
||||
|
||||
elif moe.use_deepep_ll_kernels:
|
||||
assert quant_config is not None
|
||||
all_to_all_args = dict(
|
||||
max_num_tokens_per_dp_rank=moe.max_num_tokens,
|
||||
token_hidden_size=moe.hidden_dim,
|
||||
@ -174,13 +181,11 @@ class FusedMoEMethodBase(QuantizeMethodBase):
|
||||
all2all_manager.world_size)
|
||||
handle = all2all_manager.get_handle(all_to_all_args)
|
||||
|
||||
# Note : We may want to use FP8 dispatch even otherwise just to
|
||||
# reduce datamovement
|
||||
use_fp8_dispatch = (moe.quant_config is not None
|
||||
and moe.quant_config.quant_dtype
|
||||
== current_platform.fp8_dtype()
|
||||
and moe.quant_config.block_shape
|
||||
== DEEPEP_QUANT_BLOCK_SHAPE)
|
||||
# Note: We may want to use FP8 dispatch just to reduce
|
||||
# data movement.
|
||||
use_fp8_dispatch = (
|
||||
quant_config.quant_dtype == current_platform.fp8_dtype()
|
||||
and quant_config.block_shape == DEEPEP_QUANT_BLOCK_SHAPE)
|
||||
|
||||
prepare_finalize = DeepEPLLPrepareAndFinalize(
|
||||
handle,
|
||||
@ -192,11 +197,10 @@ class FusedMoEMethodBase(QuantizeMethodBase):
|
||||
return prepare_finalize
|
||||
|
||||
def maybe_make_prepare_finalize(
|
||||
self,
|
||||
moe: FusedMoEConfig,
|
||||
) -> Optional[FusedMoEPrepareAndFinalize]:
|
||||
if moe.moe_parallel_config.use_all2all_kernels:
|
||||
return FusedMoEMethodBase._maybe_make_prepare_finalize(moe)
|
||||
self) -> Optional[FusedMoEPrepareAndFinalize]:
|
||||
if self.moe.moe_parallel_config.use_all2all_kernels:
|
||||
return FusedMoEMethodBase._maybe_make_prepare_finalize(
|
||||
self.moe, self.moe_quant_config)
|
||||
else:
|
||||
return None
|
||||
|
||||
@ -204,7 +208,13 @@ class FusedMoEMethodBase(QuantizeMethodBase):
|
||||
# prepare_communication_buffer_for_model.
|
||||
def init_prepare_finalize(self, layer: torch.nn.Module):
|
||||
assert self.moe is not None
|
||||
prepare_finalize = self.maybe_make_prepare_finalize(self.moe)
|
||||
|
||||
# We must get the quant config here so that the layer is
|
||||
# completely initialized, i.e. all weights loaded and post
|
||||
# processed.
|
||||
self.moe_quant_config = self.get_fused_moe_quant_config(layer)
|
||||
|
||||
prepare_finalize = self.maybe_make_prepare_finalize()
|
||||
|
||||
if prepare_finalize is not None:
|
||||
logger.debug("%s for %s(%s)", prepare_finalize.__class__.__name__,
|
||||
@ -213,7 +223,7 @@ class FusedMoEMethodBase(QuantizeMethodBase):
|
||||
assert self.fused_experts is None, \
|
||||
f"Attempt to override experts for {id(self)}!"
|
||||
self.topk_indices_dtype = prepare_finalize.topk_indices_dtype()
|
||||
experts = self.select_gemm_impl(prepare_finalize, self.moe, layer)
|
||||
experts = self.select_gemm_impl(prepare_finalize, layer)
|
||||
self.fused_experts = FusedMoEModularKernel(
|
||||
prepare_finalize,
|
||||
experts,
|
||||
@ -223,7 +233,6 @@ class FusedMoEMethodBase(QuantizeMethodBase):
|
||||
def select_gemm_impl(
|
||||
self,
|
||||
prepare_finalize: FusedMoEPrepareAndFinalize,
|
||||
moe: FusedMoEConfig,
|
||||
layer: torch.nn.Module,
|
||||
) -> FusedMoEPermuteExpertsUnpermute:
|
||||
# based on the all2all implementation, select the appropriate
|
||||
@ -232,6 +241,11 @@ class FusedMoEMethodBase(QuantizeMethodBase):
|
||||
f"{self.__class__.__name__} must select appropriate gemm "
|
||||
"implementation based on the prepare_finalize")
|
||||
|
||||
@abstractmethod
|
||||
def get_fused_moe_quant_config(
|
||||
self, layer: torch.nn.Module) -> Optional[FusedMoEQuantConfig]:
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def apply(
|
||||
self,
|
||||
@ -265,7 +279,6 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
|
||||
|
||||
def __init__(self, moe: FusedMoEConfig):
|
||||
super().__init__(moe)
|
||||
self.has_bias = self.moe.has_bias
|
||||
self.rocm_aiter_moe_enabled = is_rocm_aiter_moe_enabled()
|
||||
if self.rocm_aiter_moe_enabled:
|
||||
from .rocm_aiter_fused_moe import rocm_aiter_fused_experts
|
||||
@ -273,23 +286,30 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
|
||||
else:
|
||||
self.rocm_aiter_fused_experts = None # type: ignore
|
||||
|
||||
def maybe_make_prepare_finalize(
|
||||
self) -> Optional[FusedMoEPrepareAndFinalize]:
|
||||
if self.rocm_aiter_moe_enabled:
|
||||
return None
|
||||
else:
|
||||
return super().maybe_make_prepare_finalize()
|
||||
|
||||
def select_gemm_impl(
|
||||
self,
|
||||
prepare_finalize: FusedMoEPrepareAndFinalize,
|
||||
# TODO(bnell): Remove. Every layer should have an moe config object.
|
||||
moe: FusedMoEConfig,
|
||||
layer: torch.nn.Module,
|
||||
) -> FusedMoEPermuteExpertsUnpermute:
|
||||
assert self.moe_quant_config is not None
|
||||
if (prepare_finalize.activation_format ==
|
||||
FusedMoEActivationFormat.BatchedExperts):
|
||||
logger.debug("BatchedTritonExperts %s", self.moe)
|
||||
return BatchedTritonExperts(
|
||||
max_num_tokens=self.moe.max_num_tokens,
|
||||
num_dispatchers=prepare_finalize.num_dispatchers(),
|
||||
quant_config=self.moe_quant_config,
|
||||
)
|
||||
else:
|
||||
logger.debug("TritonExperts %s", self.moe)
|
||||
return TritonExperts()
|
||||
return TritonExperts(self.moe_quant_config)
|
||||
|
||||
def create_weights(self, layer: torch.nn.Module, num_experts: int,
|
||||
hidden_size: int, intermediate_size_per_partition: int,
|
||||
@ -303,7 +323,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
|
||||
requires_grad=False)
|
||||
layer.register_parameter("w13_weight", w13_weight)
|
||||
set_weight_attrs(w13_weight, extra_weight_attrs)
|
||||
if self.has_bias:
|
||||
if self.moe.has_bias:
|
||||
w13_bias = torch.nn.Parameter(torch.zeros(
|
||||
num_experts,
|
||||
2 * intermediate_size_per_partition,
|
||||
@ -320,7 +340,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
|
||||
requires_grad=False)
|
||||
layer.register_parameter("w2_weight", w2_weight)
|
||||
set_weight_attrs(w2_weight, extra_weight_attrs)
|
||||
if self.has_bias:
|
||||
if self.moe.has_bias:
|
||||
w2_bias = torch.nn.Parameter(torch.zeros(num_experts,
|
||||
hidden_size,
|
||||
dtype=params_dtype),
|
||||
@ -442,6 +462,16 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
|
||||
logical_replica_count=logical_replica_count,
|
||||
)
|
||||
|
||||
def get_fused_moe_quant_config(
|
||||
self, layer: torch.nn.Module) -> Optional[FusedMoEQuantConfig]:
|
||||
if self.moe.has_bias:
|
||||
return biased_moe_quant_config(
|
||||
layer.w13_bias,
|
||||
layer.w2_bias,
|
||||
)
|
||||
else:
|
||||
return FUSED_MOE_UNQUANTIZED_CONFIG
|
||||
|
||||
def forward_cuda(
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
@ -486,6 +516,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
|
||||
logical_replica_count=logical_replica_count)
|
||||
|
||||
if self.rocm_aiter_moe_enabled:
|
||||
assert self.fused_experts is None
|
||||
return self.rocm_aiter_fused_experts(
|
||||
hidden_states=x,
|
||||
w1=layer.w13_weight,
|
||||
@ -496,7 +527,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
|
||||
activation=activation,
|
||||
apply_router_weight_on_input=apply_router_weight_on_input)
|
||||
elif self.fused_experts is not None:
|
||||
if self.has_bias:
|
||||
if self.moe.has_bias:
|
||||
raise ValueError(
|
||||
"FusedMoEModularKernel does not support bias.")
|
||||
return self.fused_experts(
|
||||
@ -517,12 +548,11 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
|
||||
hidden_states=x,
|
||||
w1=layer.w13_weight,
|
||||
w2=layer.w2_weight,
|
||||
w1_bias=layer.w13_bias if self.has_bias else None,
|
||||
w2_bias=layer.w2_bias if self.has_bias else None,
|
||||
topk_weights=topk_weights,
|
||||
topk_ids=topk_ids,
|
||||
inplace=True,
|
||||
activation=activation,
|
||||
quant_config=self.moe_quant_config,
|
||||
apply_router_weight_on_input=apply_router_weight_on_input,
|
||||
global_num_experts=global_num_experts,
|
||||
expert_map=expert_map,
|
||||
@ -933,16 +963,18 @@ class FusedMoE(CustomOp):
|
||||
# since model_config is not set in the pytest test.
|
||||
model_dtype = params_dtype
|
||||
|
||||
moe = FusedMoEConfig.make(num_experts=self.global_num_experts,
|
||||
experts_per_token=top_k,
|
||||
hidden_dim=hidden_size,
|
||||
num_local_experts=self.local_num_experts,
|
||||
moe_parallel_config=self.moe_parallel_config,
|
||||
in_dtype=model_dtype,
|
||||
max_num_tokens=envs.VLLM_MOE_DP_CHUNK_SIZE,
|
||||
quant_config=quant_config,
|
||||
has_bias=has_bias)
|
||||
moe = FusedMoEConfig(
|
||||
num_experts=self.global_num_experts,
|
||||
experts_per_token=top_k,
|
||||
hidden_dim=hidden_size,
|
||||
num_local_experts=self.local_num_experts,
|
||||
moe_parallel_config=self.moe_parallel_config,
|
||||
in_dtype=model_dtype,
|
||||
max_num_tokens=envs.VLLM_MOE_DP_CHUNK_SIZE,
|
||||
has_bias=has_bias,
|
||||
)
|
||||
self.moe_config = moe
|
||||
self.moe_quant_config: Optional[FusedMoEQuantConfig] = None
|
||||
self.quant_config = quant_config
|
||||
|
||||
# Note: get_quant_method will look at the layer's local_num_experts
|
||||
@ -990,6 +1022,9 @@ class FusedMoE(CustomOp):
|
||||
# Chunked all2all staging tensor
|
||||
self.batched_hidden_states: Optional[torch.Tensor] = None
|
||||
self.batched_router_logits: Optional[torch.Tensor] = None
|
||||
|
||||
# TODO(bnell): flashinfer uses non-batched format.
|
||||
# Does it really need a batched buffer?
|
||||
if (self.moe_parallel_config.use_pplx_kernels
|
||||
or self.moe_parallel_config.use_deepep_ll_kernels
|
||||
or self.moe_config.use_flashinfer_cutlass_kernels):
|
||||
@ -1062,7 +1097,9 @@ class FusedMoE(CustomOp):
|
||||
|
||||
@property
|
||||
def use_flashinfer_cutlass_kernels(self):
|
||||
return self.moe_config.use_flashinfer_cutlass_kernels
|
||||
return (self.moe_quant_config is not None
|
||||
and self.moe_quant_config.quant_dtype == "nvfp4"
|
||||
and self.moe_config.use_flashinfer_cutlass_kernels)
|
||||
|
||||
def update_expert_map(self):
|
||||
# ep_size and ep_rank should already be updated
|
||||
@ -1492,6 +1529,11 @@ class FusedMoE(CustomOp):
|
||||
self.logical_to_physical_map = logical_to_physical_map[moe_layer_idx]
|
||||
self.logical_replica_count = logical_replica_count[moe_layer_idx]
|
||||
|
||||
def ensure_moe_quant_config(self):
|
||||
if self.quant_method.moe_quant_config is None:
|
||||
self.quant_method.moe_quant_config = (
|
||||
self.quant_method.get_fused_moe_quant_config(self))
|
||||
|
||||
@staticmethod
|
||||
def select_experts(
|
||||
hidden_states: torch.Tensor,
|
||||
@ -1711,6 +1753,8 @@ class FusedMoE(CustomOp):
|
||||
assert (
|
||||
self.batched_router_logits.size(-1) == full_router_logits.size(-1))
|
||||
|
||||
self.ensure_moe_quant_config()
|
||||
|
||||
full_fused_final_hidden_states = torch.empty_like(full_hidden_states)
|
||||
if self.shared_experts is not None:
|
||||
full_shared_final_hidden_states = torch.empty_like(
|
||||
@ -1825,14 +1869,17 @@ class FusedMoE(CustomOp):
|
||||
router_logits: torch.Tensor,
|
||||
) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
|
||||
assert self.quant_method is not None
|
||||
|
||||
self.ensure_moe_quant_config()
|
||||
|
||||
# Route to the chunked forward path using the FlashInfer Cutlass kernel
|
||||
# only when data parallelism (DP) is enabled.
|
||||
use_flashinfer_cutlass_kernels = (
|
||||
self.dp_size > 1
|
||||
and self.moe_config.use_flashinfer_cutlass_kernels)
|
||||
_use_flashinfer_cutlass_kernels = (self.dp_size > 1 and
|
||||
self.use_flashinfer_cutlass_kernels)
|
||||
|
||||
if (self.moe_parallel_config.use_pplx_kernels
|
||||
or self.moe_parallel_config.use_deepep_ll_kernels
|
||||
or use_flashinfer_cutlass_kernels):
|
||||
or _use_flashinfer_cutlass_kernels):
|
||||
return self.forward_impl_chunked(hidden_states, router_logits)
|
||||
|
||||
do_naive_dispatch_combine: bool = (
|
||||
|
||||
@ -177,8 +177,6 @@ class FusedMoEPrepareAndFinalize(ABC):
|
||||
def prepare(
|
||||
self,
|
||||
a1: torch.Tensor,
|
||||
a1_scale: Optional[torch.Tensor],
|
||||
a2_scale: Optional[torch.Tensor],
|
||||
topk_weights: torch.Tensor,
|
||||
topk_ids: torch.Tensor,
|
||||
num_experts: int,
|
||||
@ -189,9 +187,6 @@ class FusedMoEPrepareAndFinalize(ABC):
|
||||
"""
|
||||
Perform any quantization (and/or) dispatching needed for this kernel.
|
||||
- a1: The (unquantized) input to the MoE layer.
|
||||
- a1_scale: Optional scales for a1
|
||||
- a2_scale: Optional scales for the second MoE gemm. Required to make
|
||||
sure the quantization is consistent for both gemms.
|
||||
- topk_ids: The topk ids.
|
||||
- topk_weights: The topk weights.
|
||||
- num_experts: The total number of experts in the global expert space.
|
||||
@ -199,10 +194,11 @@ class FusedMoEPrepareAndFinalize(ABC):
|
||||
space to the local expert space of the expert parallel shard.
|
||||
- apply_router_weight_on_input: When True, apply the weights to the
|
||||
activations, before quantization + dispatching.
|
||||
- quant_config: Quantization info provided by the fused experts.
|
||||
|
||||
Returns a tuple of:
|
||||
- quantized + dispatched a.
|
||||
- quantized + dispatched a1_scales.
|
||||
- Optional quantized + dispatched a1_scales.
|
||||
- Optional ExpertTokensMetadata containing gpu/cpu tensors
|
||||
as big as the number of local experts with the information about the
|
||||
number of tokens assigned to each local expert.
|
||||
@ -220,8 +216,6 @@ class FusedMoEPrepareAndFinalize(ABC):
|
||||
def prepare_async(
|
||||
self,
|
||||
a1: torch.Tensor,
|
||||
a1_scale: Optional[torch.Tensor],
|
||||
a2_scale: Optional[torch.Tensor],
|
||||
topk_weights: torch.Tensor,
|
||||
topk_ids: torch.Tensor,
|
||||
num_experts: int,
|
||||
@ -316,6 +310,7 @@ class FusedMoEPrepareAndFinalize(ABC):
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
# TODO: add supported activations method (return string)
|
||||
class FusedMoEPermuteExpertsUnpermute(ABC):
|
||||
"""
|
||||
An abstract base class for the [Permute-Experts-Unpermute] step described
|
||||
@ -324,12 +319,12 @@ class FusedMoEPermuteExpertsUnpermute(ABC):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
quant_config: Optional[FusedMoEQuantConfig],
|
||||
quant_config: FusedMoEQuantConfig,
|
||||
):
|
||||
if quant_config is not None:
|
||||
self.quant_config = quant_config
|
||||
else:
|
||||
self.quant_config = FusedMoEQuantConfig()
|
||||
"""
|
||||
quant_config: Quantization parameters for this experts instance.
|
||||
"""
|
||||
self.quant_config = quant_config
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
@ -341,6 +336,11 @@ class FusedMoEPermuteExpertsUnpermute(ABC):
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
#
|
||||
# Various helpers for accessing quantization parameters from the
|
||||
# quant_config.
|
||||
#
|
||||
|
||||
@property
|
||||
def quant_dtype(self) -> Optional[torch.dtype]:
|
||||
return self.quant_config.quant_dtype
|
||||
@ -357,6 +357,54 @@ class FusedMoEPermuteExpertsUnpermute(ABC):
|
||||
def per_out_ch_quant(self) -> bool:
|
||||
return self.quant_config.per_out_ch_quant
|
||||
|
||||
@property
|
||||
def a1_scale(self) -> Optional[torch.Tensor]:
|
||||
return self.quant_config.a1_scale
|
||||
|
||||
@property
|
||||
def a2_scale(self) -> Optional[torch.Tensor]:
|
||||
return self.quant_config.a2_scale
|
||||
|
||||
@property
|
||||
def a1_gscale(self) -> Optional[torch.Tensor]:
|
||||
return self.quant_config.a1_gscale
|
||||
|
||||
@property
|
||||
def a2_gscale(self) -> Optional[torch.Tensor]:
|
||||
return self.quant_config.a2_gscale
|
||||
|
||||
@property
|
||||
def w1_scale(self) -> Optional[torch.Tensor]:
|
||||
return self.quant_config.w1_scale
|
||||
|
||||
@property
|
||||
def w2_scale(self) -> Optional[torch.Tensor]:
|
||||
return self.quant_config.w2_scale
|
||||
|
||||
@property
|
||||
def w1_zp(self) -> Optional[torch.Tensor]:
|
||||
return self.quant_config.w1_zp
|
||||
|
||||
@property
|
||||
def w2_zp(self) -> Optional[torch.Tensor]:
|
||||
return self.quant_config.w2_zp
|
||||
|
||||
@property
|
||||
def w1_bias(self) -> Optional[torch.Tensor]:
|
||||
return self.quant_config.w1_bias
|
||||
|
||||
@property
|
||||
def w2_bias(self) -> Optional[torch.Tensor]:
|
||||
return self.quant_config.w2_bias
|
||||
|
||||
@property
|
||||
def g1_alphas(self) -> Optional[torch.Tensor]:
|
||||
return self.quant_config.g1_alphas
|
||||
|
||||
@property
|
||||
def g2_alphas(self) -> Optional[torch.Tensor]:
|
||||
return self.quant_config.g2_alphas
|
||||
|
||||
# TODO (bnell): make this return a CHUNK_SIZE or None instead?
|
||||
@abstractmethod
|
||||
def supports_chunking(self) -> bool:
|
||||
@ -433,12 +481,7 @@ class FusedMoEPermuteExpertsUnpermute(ABC):
|
||||
activation: str,
|
||||
global_num_experts: int,
|
||||
expert_map: Optional[torch.Tensor],
|
||||
w1_scale: Optional[torch.Tensor],
|
||||
w2_scale: Optional[torch.Tensor],
|
||||
w1_zp: Optional[torch.Tensor],
|
||||
w2_zp: Optional[torch.Tensor],
|
||||
a1q_scale: Optional[torch.Tensor],
|
||||
a2_scale: Optional[torch.Tensor],
|
||||
workspace13: torch.Tensor,
|
||||
workspace2: torch.Tensor,
|
||||
expert_tokens_meta: Optional[ExpertTokensMetadata],
|
||||
@ -455,7 +498,7 @@ class FusedMoEPermuteExpertsUnpermute(ABC):
|
||||
- w1 (torch.Tensor): The first set of expert weights.
|
||||
- w2 (torch.Tensor): The second set of expert weights.
|
||||
- topk_weights: A map of row to expert weights. Some implementations
|
||||
choose to do weight application.
|
||||
choose to do weight application.
|
||||
- topk_ids (torch.Tensor): A map of row to expert id.
|
||||
- activation (str): The activation function to apply after the first
|
||||
MoE layer.
|
||||
@ -464,15 +507,9 @@ class FusedMoEPermuteExpertsUnpermute(ABC):
|
||||
- expert_map (Optional[torch.Tensor]): A tensor mapping expert indices
|
||||
from the global expert space to the local expert space of the expert
|
||||
parallel shard.
|
||||
- w1_scale (Optional[torch.Tensor]): Optional scale to be used for w1.
|
||||
- w2_scale (Optional[torch.Tensor]): Optional scale to be used for w2.
|
||||
- w1_zp (Optional[torch.Tensor]): Optional zero points to be used for
|
||||
w1.
|
||||
- w2_zp (Optional[torch.Tensor]): Optional zero points to be used for
|
||||
w2.
|
||||
- a1q_scale (Optional[torch.Tensor]): Optional quantized scale to be
|
||||
used for a1.
|
||||
- a2_scale (Optional[torch.Tensor]): Optional scale to be used for a2.
|
||||
used for a1. Result of quantization from prepare/finalize and not
|
||||
from the FusedMoEQuantConfig.
|
||||
- workspace13 (torch.Tensor): A scratch tensor used for gemm outputs
|
||||
must be large enough to hold output of either MoE gemm.
|
||||
- workspace2 (torch.Tensor): A scratch tensor used for the activation
|
||||
@ -559,12 +596,7 @@ class FusedMoEModularKernel(torch.nn.Module):
|
||||
global_num_experts: int,
|
||||
local_num_experts: int,
|
||||
expert_map: Optional[torch.Tensor],
|
||||
w1_scale: Optional[torch.Tensor],
|
||||
w2_scale: Optional[torch.Tensor],
|
||||
w1_zp: Optional[torch.Tensor],
|
||||
w2_zp: Optional[torch.Tensor],
|
||||
a1q_scale: Optional[torch.Tensor],
|
||||
a2_scale: Optional[torch.Tensor],
|
||||
expert_tokens_meta: Optional[ExpertTokensMetadata],
|
||||
apply_router_weight_on_input: bool,
|
||||
) -> torch.Tensor:
|
||||
@ -601,12 +633,7 @@ class FusedMoEModularKernel(torch.nn.Module):
|
||||
activation=activation,
|
||||
global_num_experts=global_num_experts,
|
||||
expert_map=expert_map,
|
||||
w1_scale=w1_scale,
|
||||
w2_scale=w2_scale,
|
||||
w1_zp=w1_zp,
|
||||
w2_zp=w2_zp,
|
||||
a1q_scale=a1q_scale,
|
||||
a2_scale=a2_scale,
|
||||
workspace13=workspace13,
|
||||
workspace2=workspace2,
|
||||
expert_tokens_meta=expert_tokens_meta,
|
||||
@ -627,12 +654,7 @@ class FusedMoEModularKernel(torch.nn.Module):
|
||||
global_num_experts: int,
|
||||
local_num_experts: int,
|
||||
expert_map: Optional[torch.Tensor],
|
||||
w1_scale: Optional[torch.Tensor],
|
||||
w2_scale: Optional[torch.Tensor],
|
||||
w1_zp: Optional[torch.Tensor],
|
||||
w2_zp: Optional[torch.Tensor],
|
||||
a1q_scale: Optional[torch.Tensor],
|
||||
a2_scale: Optional[torch.Tensor],
|
||||
expert_tokens_meta: Optional[ExpertTokensMetadata],
|
||||
apply_router_weight_on_input: bool,
|
||||
) -> torch.Tensor:
|
||||
@ -658,12 +680,7 @@ class FusedMoEModularKernel(torch.nn.Module):
|
||||
global_num_experts=global_num_experts,
|
||||
local_num_experts=local_num_experts,
|
||||
expert_map=expert_map,
|
||||
w1_scale=w1_scale,
|
||||
w2_scale=w2_scale,
|
||||
w1_zp=w1_zp,
|
||||
w2_zp=w2_zp,
|
||||
a1q_scale=a1q_scale,
|
||||
a2_scale=a2_scale,
|
||||
expert_tokens_meta=expert_tokens_meta,
|
||||
apply_router_weight_on_input=apply_router_weight_on_input,
|
||||
)
|
||||
@ -685,9 +702,13 @@ class FusedMoEModularKernel(torch.nn.Module):
|
||||
Optional[torch.Tensor], torch.Tensor, torch.Tensor]:
|
||||
s = chunk_idx * CHUNK_SIZE
|
||||
e = min(s + CHUNK_SIZE, M)
|
||||
return (a1q[s:e], _chunk_scales(a1q_scale, s, e),
|
||||
_chunk_scales(a2_scale, s,
|
||||
e), topk_ids[s:e], topk_weights[s:e])
|
||||
return (
|
||||
a1q[s:e],
|
||||
_chunk_scales(a1q_scale, s, e),
|
||||
_chunk_scales(self.fused_experts.a2_scale, s, e),
|
||||
topk_ids[s:e],
|
||||
topk_weights[s:e],
|
||||
)
|
||||
|
||||
def slice_output_tensor(chunk_idx: int) -> torch.Tensor:
|
||||
assert fused_out.size(0) % M == 0, (
|
||||
@ -744,12 +765,7 @@ class FusedMoEModularKernel(torch.nn.Module):
|
||||
global_num_experts=global_num_experts,
|
||||
local_num_experts=local_num_experts,
|
||||
expert_map=expert_map,
|
||||
w1_scale=w1_scale,
|
||||
w2_scale=w2_scale,
|
||||
w1_zp=w1_zp,
|
||||
w2_zp=w2_zp,
|
||||
a1q_scale=c_a1q_scale,
|
||||
a2_scale=c_a2_scale,
|
||||
expert_tokens_meta=c_expert_tokens_meta,
|
||||
apply_router_weight_on_input=apply_router_weight_on_input,
|
||||
)
|
||||
@ -767,12 +783,6 @@ class FusedMoEModularKernel(torch.nn.Module):
|
||||
activation: str = "silu",
|
||||
global_num_experts: int = -1,
|
||||
expert_map: Optional[torch.Tensor] = None,
|
||||
w1_scale: Optional[torch.Tensor] = None,
|
||||
w2_scale: Optional[torch.Tensor] = None,
|
||||
w1_zp: Optional[torch.Tensor] = None,
|
||||
w2_zp: Optional[torch.Tensor] = None,
|
||||
a1_scale: Optional[torch.Tensor] = None,
|
||||
a2_scale: Optional[torch.Tensor] = None,
|
||||
apply_router_weight_on_input: bool = False,
|
||||
) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
|
||||
"""
|
||||
@ -795,14 +805,6 @@ class FusedMoEModularKernel(torch.nn.Module):
|
||||
- expert_map (Optional[torch.Tensor]): A tensor mapping expert indices
|
||||
from the global expert space to the local expert space of the expert
|
||||
parallel shard.
|
||||
- w1_scale (Optional[torch.Tensor]): Optional scale to be used for w1.
|
||||
- w2_scale (Optional[torch.Tensor]): Optional scale to be used for w2.
|
||||
- w1_zp (Optional[torch.Tensor]): Optional zero points to be used for
|
||||
w1.
|
||||
- w2_zp (Optional[torch.Tensor]): Optional zero points to be used for
|
||||
w2.
|
||||
- a1_scale (Optional[torch.Tensor]): Optional scale to be used for a1.
|
||||
- a2_scale (Optional[torch.Tensor]): Optional scale to be used for a2.
|
||||
- apply_router_weight_on_input (bool): When true, the topk weights are
|
||||
applied directly on the inputs. This is only applicable when topk is
|
||||
1.
|
||||
@ -832,8 +834,6 @@ class FusedMoEModularKernel(torch.nn.Module):
|
||||
(a1q, a1q_scale, expert_tokens_meta, _expert_topk_ids,
|
||||
_expert_topk_weights) = self.prepare_finalize.prepare(
|
||||
a1,
|
||||
a1_scale,
|
||||
a2_scale,
|
||||
topk_weights,
|
||||
topk_ids,
|
||||
global_num_experts,
|
||||
@ -846,8 +846,6 @@ class FusedMoEModularKernel(torch.nn.Module):
|
||||
dbo_maybe_run_recv_hook()
|
||||
hook, receiver = self.prepare_finalize.prepare_async(
|
||||
a1,
|
||||
a1_scale,
|
||||
a2_scale,
|
||||
topk_weights,
|
||||
topk_ids,
|
||||
global_num_experts,
|
||||
@ -897,12 +895,7 @@ class FusedMoEModularKernel(torch.nn.Module):
|
||||
global_num_experts=global_num_experts,
|
||||
local_num_experts=local_num_experts,
|
||||
expert_map=expert_map,
|
||||
w1_scale=w1_scale,
|
||||
w2_scale=w2_scale,
|
||||
w1_zp=w1_zp,
|
||||
w2_zp=w2_zp,
|
||||
a1q_scale=a1q_scale,
|
||||
a2_scale=a2_scale,
|
||||
expert_tokens_meta=expert_tokens_meta,
|
||||
apply_router_weight_on_input=apply_router_weight_on_input,
|
||||
)
|
||||
|
||||
@ -95,8 +95,6 @@ class PplxPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
|
||||
def prepare_async(
|
||||
self,
|
||||
a1: torch.Tensor,
|
||||
a1_scale: Optional[torch.Tensor],
|
||||
a2_scale: Optional[torch.Tensor],
|
||||
topk_weights: torch.Tensor,
|
||||
topk_ids: torch.Tensor,
|
||||
num_experts: int,
|
||||
@ -130,8 +128,10 @@ class PplxPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
|
||||
|
||||
repeat_cols = 4
|
||||
repeat_rows = 1 if quant_config.per_act_token_quant else a1.size(0)
|
||||
# TODO(bnell): always pass quant_config.a1_scale?
|
||||
a1q, a1q_scale = moe_kernel_quantize_input(
|
||||
a1, (None if quant_config.per_act_token_quant else a1_scale),
|
||||
a1, (None if quant_config.per_act_token_quant else
|
||||
quant_config.a1_scale),
|
||||
quant_dtype=quant_config.quant_dtype,
|
||||
per_act_token_quant=quant_config.per_act_token_quant,
|
||||
block_shape=quant_config.block_shape)
|
||||
@ -253,8 +253,6 @@ class PplxPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
|
||||
def prepare(
|
||||
self,
|
||||
a1: torch.Tensor,
|
||||
a1_scale: Optional[torch.Tensor],
|
||||
a2_scale: Optional[torch.Tensor],
|
||||
topk_weights: torch.Tensor,
|
||||
topk_ids: torch.Tensor,
|
||||
num_experts: int,
|
||||
@ -264,8 +262,6 @@ class PplxPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
|
||||
) -> mk.PrepareResultType:
|
||||
hook, receiver = self.prepare_async(
|
||||
a1,
|
||||
a1_scale,
|
||||
a2_scale,
|
||||
topk_weights,
|
||||
topk_ids,
|
||||
num_experts,
|
||||
|
||||
@ -30,8 +30,6 @@ class MoEPrepareAndFinalizeNoEP(mk.FusedMoEPrepareAndFinalize):
|
||||
def prepare(
|
||||
self,
|
||||
a1: torch.Tensor,
|
||||
a1_scale: Optional[torch.Tensor],
|
||||
a2_scale: Optional[torch.Tensor],
|
||||
topk_weights: torch.Tensor,
|
||||
topk_ids: torch.Tensor,
|
||||
num_experts: int,
|
||||
@ -48,7 +46,7 @@ class MoEPrepareAndFinalizeNoEP(mk.FusedMoEPrepareAndFinalize):
|
||||
a1.mul_(topk_weights.to(a1.dtype))
|
||||
|
||||
a1q, a1q_scale = moe_kernel_quantize_input(
|
||||
a1, a1_scale, quant_config.quant_dtype,
|
||||
a1, quant_config.a1_scale, quant_config.quant_dtype,
|
||||
quant_config.per_act_token_quant, quant_config.block_shape)
|
||||
|
||||
return a1q, a1q_scale, None, None, None
|
||||
|
||||
@ -7,6 +7,8 @@ from typing import Optional
|
||||
import torch
|
||||
|
||||
from vllm import envs
|
||||
from vllm.model_executor.layers.fused_moe.config import (
|
||||
FUSED_MOE_UNQUANTIZED_CONFIG, FusedMoEQuantConfig)
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.utils import direct_register_custom_op
|
||||
|
||||
@ -305,21 +307,18 @@ def rocm_aiter_grouped_topk(
|
||||
|
||||
|
||||
def rocm_aiter_fused_experts(
|
||||
hidden_states: torch.Tensor,
|
||||
w1: torch.Tensor,
|
||||
w2: torch.Tensor,
|
||||
topk_weights: torch.Tensor,
|
||||
topk_ids: torch.Tensor,
|
||||
activation: str = "silu",
|
||||
apply_router_weight_on_input: bool = False,
|
||||
use_fp8_w8a8: bool = False,
|
||||
per_channel_quant: bool = False,
|
||||
w1_scale: Optional[torch.Tensor] = None,
|
||||
w2_scale: Optional[torch.Tensor] = None,
|
||||
a1_scale: Optional[torch.Tensor] = None,
|
||||
a2_scale: Optional[torch.Tensor] = None,
|
||||
block_shape: Optional[list[int]] = None,
|
||||
expert_map: Optional[torch.Tensor] = None) -> torch.Tensor:
|
||||
hidden_states: torch.Tensor,
|
||||
w1: torch.Tensor,
|
||||
w2: torch.Tensor,
|
||||
topk_weights: torch.Tensor,
|
||||
topk_ids: torch.Tensor,
|
||||
activation: str = "silu",
|
||||
apply_router_weight_on_input: bool = False,
|
||||
expert_map: Optional[torch.Tensor] = None,
|
||||
quant_config: Optional[FusedMoEQuantConfig] = None,
|
||||
) -> torch.Tensor:
|
||||
if quant_config is None:
|
||||
quant_config = FUSED_MOE_UNQUANTIZED_CONFIG
|
||||
|
||||
activation_method = (ActivationMethod.SILU
|
||||
if activation == "silu" else ActivationMethod.GELU)
|
||||
@ -333,7 +332,8 @@ def rocm_aiter_fused_experts(
|
||||
expert_mask = None
|
||||
|
||||
# w8a8 per-channel quantization
|
||||
if per_channel_quant and apply_router_weight_on_input and use_fp8_w8a8:
|
||||
if (quant_config.per_act_token_quant and apply_router_weight_on_input
|
||||
and quant_config.use_fp8_w8a8):
|
||||
# AITER tkw1 kernel for FP8 models with `apply_router_weight_on_input`
|
||||
# This applies topk_weights on the GEMM output of the first FC layer
|
||||
# rather than the second FC.
|
||||
@ -349,8 +349,8 @@ def rocm_aiter_fused_experts(
|
||||
w2,
|
||||
topk_weights,
|
||||
topk_ids,
|
||||
fc1_scale=w1_scale,
|
||||
fc2_scale=w2_scale,
|
||||
fc1_scale=quant_config.w1_scale,
|
||||
fc2_scale=quant_config.w2_scale,
|
||||
fc1_smooth_scale=None,
|
||||
fc2_smooth_scale=None,
|
||||
a16=False,
|
||||
@ -362,14 +362,14 @@ def rocm_aiter_fused_experts(
|
||||
quant_method = QuantMethod.NO.value
|
||||
|
||||
# w8a8 block-scaled
|
||||
if block_shape is not None and use_fp8_w8a8:
|
||||
if quant_config.block_shape is not None and quant_config.use_fp8_w8a8:
|
||||
assert not apply_router_weight_on_input, (
|
||||
"apply_router_weight_on_input is\
|
||||
not supported for block scaled moe")
|
||||
assert w1_scale is not None
|
||||
assert w2_scale is not None
|
||||
assert quant_config.w1_scale is not None
|
||||
assert quant_config.w2_scale is not None
|
||||
quant_method = QuantMethod.BLOCK_128x128.value
|
||||
elif use_fp8_w8a8:
|
||||
elif quant_config.use_fp8_w8a8:
|
||||
# Currently only per tensor quantization method is enabled.
|
||||
quant_method = QuantMethod.PER_TENSOR.value
|
||||
|
||||
@ -390,10 +390,10 @@ def rocm_aiter_fused_experts(
|
||||
expert_mask=expert_mask,
|
||||
quant_method=quant_method,
|
||||
activation_method=activation_method,
|
||||
w1_scale=w1_scale,
|
||||
w2_scale=w2_scale,
|
||||
a1_scale=a1_scale,
|
||||
a2_scale=a2_scale,
|
||||
w1_scale=quant_config.w1_scale,
|
||||
w2_scale=quant_config.w2_scale,
|
||||
a1_scale=quant_config.a1_scale,
|
||||
a2_scale=quant_config.a2_scale,
|
||||
doweight_stage1=apply_router_weight_on_input)
|
||||
|
||||
|
||||
|
||||
@ -7,7 +7,8 @@ import torch
|
||||
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
|
||||
from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig
|
||||
from vllm.model_executor.layers.fused_moe.deep_gemm_moe import (
|
||||
DeepGemmExperts, _valid_deep_gemm, _valid_deep_gemm_shape,
|
||||
DeepGemmExperts, _valid_deep_gemm, _valid_deep_gemm_shape)
|
||||
from vllm.model_executor.layers.fused_moe.deep_gemm_utils import (
|
||||
deep_gemm_block_shape)
|
||||
from vllm.model_executor.layers.fused_moe.fused_moe import TritonExperts
|
||||
from vllm.utils.deep_gemm import is_deep_gemm_e8m0_used
|
||||
@ -17,40 +18,19 @@ class TritonOrDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
use_fp8_w8a8: bool = False,
|
||||
use_int8_w8a8: bool = False,
|
||||
use_int8_w8a16: bool = False,
|
||||
use_int4_w4a16: bool = False,
|
||||
use_mxfp4_w4a4: bool = False,
|
||||
per_act_token_quant: bool = False,
|
||||
block_shape: Optional[list[int]] = None,
|
||||
quant_config: FusedMoEQuantConfig,
|
||||
allow_deep_gemm: bool = False,
|
||||
):
|
||||
super().__init__(
|
||||
FusedMoEQuantConfig.make(
|
||||
use_fp8_w8a8=use_fp8_w8a8,
|
||||
use_int8_w8a8=use_int8_w8a8,
|
||||
use_int8_w8a16=use_int8_w8a16,
|
||||
use_int4_w4a16=use_int4_w4a16,
|
||||
use_mxfp4_w4a4=use_mxfp4_w4a4,
|
||||
per_act_token_quant=per_act_token_quant,
|
||||
block_shape=block_shape,
|
||||
))
|
||||
self.triton_expert = TritonExperts(
|
||||
use_fp8_w8a8=use_fp8_w8a8,
|
||||
use_int8_w8a8=use_int8_w8a8,
|
||||
use_int4_w4a16=use_int4_w4a16,
|
||||
use_int8_w8a16=use_int8_w8a16,
|
||||
use_mxfp4_w4a4=use_mxfp4_w4a4,
|
||||
per_act_token_quant=per_act_token_quant,
|
||||
block_shape=block_shape,
|
||||
)
|
||||
super().__init__(quant_config)
|
||||
|
||||
self.allow_deep_gemm = (allow_deep_gemm and use_fp8_w8a8 and
|
||||
self.triton_expert = TritonExperts(quant_config)
|
||||
|
||||
self.allow_deep_gemm = (allow_deep_gemm
|
||||
and self.quant_config.use_fp8_w8a8 and
|
||||
self.block_shape == deep_gemm_block_shape())
|
||||
|
||||
self.deep_gemm_expert = DeepGemmExperts(
|
||||
) if self.allow_deep_gemm else None
|
||||
self.quant_config) if self.allow_deep_gemm else None
|
||||
|
||||
@property
|
||||
def activation_formats(
|
||||
@ -130,12 +110,7 @@ class TritonOrDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
|
||||
activation: str,
|
||||
global_num_experts: int,
|
||||
expert_map: Optional[torch.Tensor],
|
||||
w1_scale: Optional[torch.Tensor],
|
||||
w2_scale: Optional[torch.Tensor],
|
||||
w1_zp: Optional[torch.Tensor],
|
||||
w2_zp: Optional[torch.Tensor],
|
||||
a1q_scale: Optional[torch.Tensor],
|
||||
a2_scale: Optional[torch.Tensor],
|
||||
workspace13: torch.Tensor,
|
||||
workspace2: torch.Tensor,
|
||||
expert_tokens_meta: Optional[mk.ExpertTokensMetadata],
|
||||
@ -158,12 +133,7 @@ class TritonOrDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
|
||||
activation,
|
||||
global_num_experts,
|
||||
expert_map,
|
||||
w1_scale,
|
||||
w2_scale,
|
||||
w1_zp,
|
||||
w2_zp,
|
||||
a1q_scale,
|
||||
a2_scale,
|
||||
workspace13,
|
||||
workspace2,
|
||||
expert_tokens_meta,
|
||||
|
||||
@ -5,7 +5,8 @@ from typing import Optional
|
||||
import torch
|
||||
|
||||
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
|
||||
from vllm.model_executor.layers.fused_moe.config import FusedMoEConfig
|
||||
from vllm.model_executor.layers.fused_moe.config import (FusedMoEConfig,
|
||||
FusedMoEQuantConfig)
|
||||
from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import (
|
||||
TopKWeightAndReduceNoOP)
|
||||
from vllm.utils import next_power_of_2
|
||||
@ -16,20 +17,17 @@ class TrtLlmGenExperts(mk.FusedMoEPermuteExpertsUnpermute):
|
||||
def __init__(
|
||||
self,
|
||||
moe: FusedMoEConfig,
|
||||
quant_config: FusedMoEQuantConfig,
|
||||
gemm1_alpha,
|
||||
gemm1_beta,
|
||||
gemm1_clamp_limit,
|
||||
w13_bias,
|
||||
w2_bias,
|
||||
max_capture_size,
|
||||
):
|
||||
super().__init__(moe.quant_config)
|
||||
super().__init__(quant_config)
|
||||
self.moe = moe
|
||||
self.gemm1_alpha = gemm1_alpha
|
||||
self.gemm1_beta = gemm1_beta
|
||||
self.gemm1_clamp_limit = gemm1_clamp_limit
|
||||
self.w13_bias = w13_bias
|
||||
self.w2_bias = w2_bias
|
||||
self.max_capture_size = max_capture_size
|
||||
|
||||
@property
|
||||
@ -104,12 +102,7 @@ class TrtLlmGenExperts(mk.FusedMoEPermuteExpertsUnpermute):
|
||||
activation: str,
|
||||
global_num_experts: int,
|
||||
expert_map: Optional[torch.Tensor],
|
||||
w1_scale: Optional[torch.Tensor],
|
||||
w2_scale: Optional[torch.Tensor],
|
||||
w1_zp: Optional[torch.Tensor],
|
||||
w2_zp: Optional[torch.Tensor],
|
||||
a1q_scale: Optional[torch.Tensor],
|
||||
a2_scale: Optional[torch.Tensor],
|
||||
workspace13: torch.Tensor,
|
||||
workspace2: torch.Tensor,
|
||||
expert_tokens_meta: Optional[mk.ExpertTokensMetadata],
|
||||
@ -129,8 +122,8 @@ class TrtLlmGenExperts(mk.FusedMoEPermuteExpertsUnpermute):
|
||||
packed_tensor = (topk_ids.to(torch.int32) << 16) | topk_weights.to(
|
||||
torch.bfloat16).view(torch.int16)
|
||||
|
||||
assert w1_scale is not None
|
||||
assert w2_scale is not None
|
||||
assert self.w1_scale is not None
|
||||
assert self.w2_scale is not None
|
||||
kwargs = {
|
||||
"topk_ids":
|
||||
packed_tensor,
|
||||
@ -143,9 +136,9 @@ class TrtLlmGenExperts(mk.FusedMoEPermuteExpertsUnpermute):
|
||||
"gemm1_weights":
|
||||
w1,
|
||||
"gemm1_weights_scale":
|
||||
w1_scale,
|
||||
self.w1_scale,
|
||||
"gemm1_bias":
|
||||
self.w13_bias,
|
||||
self.w1_bias,
|
||||
"gemm1_alpha":
|
||||
self.gemm1_alpha,
|
||||
"gemm1_beta":
|
||||
@ -155,7 +148,7 @@ class TrtLlmGenExperts(mk.FusedMoEPermuteExpertsUnpermute):
|
||||
"gemm2_weights":
|
||||
w2,
|
||||
"gemm2_weights_scale":
|
||||
w2_scale,
|
||||
self.w2_scale,
|
||||
"gemm2_bias":
|
||||
self.w2_bias,
|
||||
"output1_scale_scalar":
|
||||
|
||||
@ -268,3 +268,7 @@ def _validate_scale_shape(
|
||||
assert block_shape is not None
|
||||
expected = (a.shape[0], cdiv(a.shape[1], block_shape[1]))
|
||||
assert a_scale.shape == expected, f"{a_scale.shape} == {expected}"
|
||||
|
||||
|
||||
def activation_without_mul(activation: str) -> str:
|
||||
return activation + "_no_mul"
|
||||
|
||||
@ -9,8 +9,10 @@ from torch.nn import Parameter
|
||||
import vllm.model_executor.layers.fused_moe # noqa
|
||||
from vllm import _custom_ops as ops
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.fused_moe.config import (FusedMoEConfig,
|
||||
FusedMoEQuantConfig)
|
||||
from vllm.model_executor.layers.fused_moe.layer import (
|
||||
FusedMoE, FusedMoEConfig, FusedMoEMethodBase, FusedMoeWeightScaleSupported,
|
||||
FusedMoE, FusedMoEMethodBase, FusedMoeWeightScaleSupported,
|
||||
UnquantizedFusedMoEMethod)
|
||||
from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase,
|
||||
UnquantizedLinearMethod,
|
||||
@ -483,6 +485,10 @@ class AWQMoEMethod(FusedMoEMethodBase):
|
||||
if hasattr(layer, "w2_bias") and layer.w2_bias is not None:
|
||||
layer.w2_bias.data = marlin_permute_bias(layer.w2_bias)
|
||||
|
||||
def get_fused_moe_quant_config(
|
||||
self, layer: torch.nn.Module) -> Optional[FusedMoEQuantConfig]:
|
||||
return None
|
||||
|
||||
def apply(
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
|
||||
@ -6,8 +6,9 @@ from typing import Any, Callable, Optional, Union
|
||||
import torch
|
||||
from packaging import version
|
||||
|
||||
from vllm.model_executor.layers.fused_moe.config import (FusedMoEConfig,
|
||||
FusedMoEQuantConfig)
|
||||
from vllm.model_executor.layers.fused_moe.layer import (FusedMoE,
|
||||
FusedMoEConfig,
|
||||
FusedMoEMethodBase)
|
||||
from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase,
|
||||
UnquantizedLinearMethod,
|
||||
@ -452,6 +453,10 @@ class BitsAndBytesMoEMethod(FusedMoEMethodBase):
|
||||
**extra_weight_attrs,
|
||||
)
|
||||
|
||||
def get_fused_moe_quant_config(
|
||||
self, layer: torch.nn.Module) -> Optional[FusedMoEQuantConfig]:
|
||||
return None
|
||||
|
||||
def apply(
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
@ -509,6 +514,7 @@ class BitsAndBytesMoEMethod(FusedMoEMethodBase):
|
||||
apply_router_weight_on_input=apply_router_weight_on_input,
|
||||
global_num_experts=global_num_experts,
|
||||
expert_map=expert_map,
|
||||
quant_config=self.moe_quant_config,
|
||||
)
|
||||
|
||||
def _create_weights_4bit(
|
||||
|
||||
@ -16,8 +16,11 @@ from vllm import _custom_ops as ops
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.fused_moe import (
|
||||
FusedMoE, FusedMoEActivationFormat, FusedMoEConfig, FusedMoEMethodBase,
|
||||
FusedMoEPermuteExpertsUnpermute, FusedMoEPrepareAndFinalize,
|
||||
FusedMoeWeightScaleSupported)
|
||||
FusedMoEPermuteExpertsUnpermute, FusedMoeWeightScaleSupported)
|
||||
from vllm.model_executor.layers.fused_moe.config import (
|
||||
FusedMoEQuantConfig, fp8_w8a8_moe_quant_config,
|
||||
int4_w4a16_moe_quant_config, int8_w8a8_moe_quant_config,
|
||||
int8_w8a16_moe_quant_config, nvfp4_moe_quant_config)
|
||||
from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe import (
|
||||
is_valid_flashinfer_cutlass_fused_moe)
|
||||
from vllm.model_executor.layers.quantization.compressed_tensors.schemes.compressed_tensors_wNa16 import ( # noqa
|
||||
@ -122,7 +125,7 @@ class CompressedTensorsMoEMethod(FusedMoEMethodBase):
|
||||
return CompressedTensorsWNA16MarlinMoEMethod(
|
||||
quant_config, layer.moe_config)
|
||||
elif quant_config._is_fp4a4_nvfp4(weight_quant, input_quant):
|
||||
return CompressedTensorsW4A4MoeMethod(layer.moe_config, layer)
|
||||
return CompressedTensorsW4A4MoeMethod(layer.moe_config)
|
||||
elif (quant_config._is_fp8_w8a8_sm90(weight_quant, input_quant)
|
||||
or quant_config._is_fp8_w8a8_sm100(weight_quant, input_quant)
|
||||
or quant_config._is_fp8_w8a8(weight_quant, input_quant)):
|
||||
@ -138,7 +141,7 @@ class CompressedTensorsMoEMethod(FusedMoEMethodBase):
|
||||
|
||||
class CompressedTensorsW4A4MoeMethod(CompressedTensorsMoEMethod):
|
||||
|
||||
def __init__(self, moe: FusedMoEConfig, layer: torch.nn.Module):
|
||||
def __init__(self, moe: FusedMoEConfig):
|
||||
from vllm.model_executor.layers.quantization.utils.nvfp4_moe_support import ( # noqa: E501
|
||||
detect_nvfp4_moe_support)
|
||||
super().__init__(moe)
|
||||
@ -147,7 +150,6 @@ class CompressedTensorsW4A4MoeMethod(CompressedTensorsMoEMethod):
|
||||
self.allow_flashinfer = _nvfp4.allow_flashinfer
|
||||
self.use_marlin = _nvfp4.use_marlin
|
||||
self.group_size = 16
|
||||
self.layer = layer
|
||||
|
||||
def create_weights(self, layer: torch.nn.Module, num_experts: int,
|
||||
hidden_size: int, intermediate_size_per_partition: int,
|
||||
@ -305,37 +307,46 @@ class CompressedTensorsW4A4MoeMethod(CompressedTensorsMoEMethod):
|
||||
(layer.w2_input_global_scale), requires_grad=False)
|
||||
|
||||
def maybe_make_prepare_finalize(
|
||||
self,
|
||||
moe: FusedMoEConfig,
|
||||
) -> Optional[mk.FusedMoEPrepareAndFinalize]:
|
||||
if not self.allow_flashinfer:
|
||||
return super().maybe_make_prepare_finalize(moe)
|
||||
self) -> Optional[mk.FusedMoEPrepareAndFinalize]:
|
||||
if self.use_marlin:
|
||||
return None
|
||||
elif not self.allow_flashinfer:
|
||||
return super().maybe_make_prepare_finalize()
|
||||
|
||||
prepare_finalize = build_flashinfer_fp4_cutlass_moe_prepare_finalize(
|
||||
moe,
|
||||
a1_gscale=self.layer.w13_input_scale_quant,
|
||||
)
|
||||
self.moe)
|
||||
logger.debug_once("%s", prepare_finalize.__class__.__name__)
|
||||
return prepare_finalize
|
||||
|
||||
def select_gemm_impl(
|
||||
self,
|
||||
prepare_finalize: mk.FusedMoEPrepareAndFinalize,
|
||||
moe: FusedMoEConfig,
|
||||
layer: torch.nn.Module,
|
||||
) -> mk.FusedMoEPermuteExpertsUnpermute:
|
||||
assert self.moe_quant_config is not None
|
||||
"""Return the appropriate GEMM experts implementation."""
|
||||
experts = select_nvfp4_gemm_impl(
|
||||
moe,
|
||||
g1_alphas=self.layer.g1_alphas,
|
||||
g2_alphas=self.layer.g2_alphas,
|
||||
a1_gscale=self.layer.w13_input_scale_quant,
|
||||
a2_gscale=self.layer.w2_input_scale_quant,
|
||||
self.moe,
|
||||
self.moe_quant_config,
|
||||
allow_flashinfer=self.allow_flashinfer,
|
||||
)
|
||||
logger.debug_once("Using %s", experts.__class__.__name__)
|
||||
return experts
|
||||
|
||||
def get_fused_moe_quant_config(
|
||||
self, layer: torch.nn.Module) -> Optional[FusedMoEQuantConfig]:
|
||||
if self.use_marlin:
|
||||
return None
|
||||
|
||||
return nvfp4_moe_quant_config(
|
||||
g1_alphas=layer.g1_alphas,
|
||||
g2_alphas=layer.g2_alphas,
|
||||
a1_gscale=layer.w13_input_scale_quant,
|
||||
a2_gscale=layer.w2_input_scale_quant,
|
||||
w1_scale=layer.w13_weight_scale,
|
||||
w2_scale=layer.w2_weight_scale,
|
||||
)
|
||||
|
||||
def apply(
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
@ -359,8 +370,6 @@ class CompressedTensorsW4A4MoeMethod(CompressedTensorsMoEMethod):
|
||||
logical_to_physical_map: Optional[torch.Tensor] = None,
|
||||
logical_replica_count: Optional[torch.Tensor] = None,
|
||||
) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
|
||||
assert self.fused_experts is None
|
||||
|
||||
if enable_eplb:
|
||||
raise NotImplementedError("EPLB not supported for "
|
||||
"`CompressedTensorsW4A4MoeMethod` yet.")
|
||||
@ -381,7 +390,12 @@ class CompressedTensorsW4A4MoeMethod(CompressedTensorsMoEMethod):
|
||||
indices_type=self.topk_indices_dtype,
|
||||
)
|
||||
|
||||
#
|
||||
# Note: the order here is important. self.fused_experts can override
|
||||
# flashinfer cutlass, cutlass fp4 or fused_experts but not marlin.
|
||||
#
|
||||
if self.use_marlin:
|
||||
assert self.fused_experts is None
|
||||
return torch.ops.vllm.fused_marlin_moe(
|
||||
x,
|
||||
layer.w13_weight,
|
||||
@ -401,8 +415,7 @@ class CompressedTensorsW4A4MoeMethod(CompressedTensorsMoEMethod):
|
||||
expert_map=expert_map,
|
||||
workspace=layer.workspace)
|
||||
|
||||
# FlashInfer fused experts path
|
||||
if self.fused_experts is not None:
|
||||
elif self.fused_experts is not None:
|
||||
assert is_valid_flashinfer_cutlass_fused_moe(
|
||||
x, layer.w13_weight, layer.w2_weight), (
|
||||
"Flashinfer CUTLASS Fused MoE not applicable!")
|
||||
@ -417,11 +430,10 @@ class CompressedTensorsW4A4MoeMethod(CompressedTensorsMoEMethod):
|
||||
activation=activation,
|
||||
global_num_experts=global_num_experts,
|
||||
expert_map=expert_map,
|
||||
w1_scale=layer.w13_weight_scale,
|
||||
w2_scale=layer.w2_weight_scale,
|
||||
apply_router_weight_on_input=apply_router_weight_on_input,
|
||||
)
|
||||
|
||||
# FlashInfer fused experts path
|
||||
elif self.allow_flashinfer:
|
||||
from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe import ( # noqa: E501
|
||||
flashinfer_cutlass_moe_fp4)
|
||||
@ -430,51 +442,46 @@ class CompressedTensorsW4A4MoeMethod(CompressedTensorsMoEMethod):
|
||||
x, layer.w13_weight, layer.w2_weight), (
|
||||
"Flashinfer CUTLASS Fused MoE not applicable!")
|
||||
|
||||
assert self.moe_quant_config is not None
|
||||
|
||||
return flashinfer_cutlass_moe_fp4(
|
||||
hidden_states=x,
|
||||
w1=layer.w13_weight,
|
||||
w2=layer.w2_weight,
|
||||
topk_weights=topk_weights,
|
||||
topk_ids=topk_ids,
|
||||
quant_config=self.moe_quant_config,
|
||||
inplace=False, # TODO(shuw): fix later, now output is high prec
|
||||
activation=activation,
|
||||
global_num_experts=global_num_experts,
|
||||
expert_map=expert_map,
|
||||
w1_scale=layer.w13_weight_scale,
|
||||
w2_scale=layer.w2_weight_scale,
|
||||
g1_alphas=layer.g1_alphas,
|
||||
g2_alphas=layer.g2_alphas,
|
||||
a1_gscale=layer.w13_input_scale_quant,
|
||||
a2_gscale=layer.w2_input_scale_quant,
|
||||
apply_router_weight_on_input=apply_router_weight_on_input,
|
||||
)
|
||||
else:
|
||||
from vllm.model_executor.layers.fused_moe.cutlass_moe import (
|
||||
cutlass_moe_fp4)
|
||||
|
||||
assert expert_map is None, ("Expert Parallelism / expert_map "
|
||||
"is currently not supported for "
|
||||
"CompressedTensorsW4A4MoeMethod.")
|
||||
from vllm.model_executor.layers.fused_moe.cutlass_moe import (
|
||||
cutlass_moe_fp4)
|
||||
assert expert_map is None, ("Expert Parallelism / expert_map "
|
||||
"is currently not supported for "
|
||||
"CompressedTensorsW4A4MoeMethod.")
|
||||
assert self.moe_quant_config is not None
|
||||
|
||||
# Cutlass moe takes in activations in BF16/Half precision
|
||||
# and fp4 quantized weights loaded from the checkpoint
|
||||
return cutlass_moe_fp4(
|
||||
a=x,
|
||||
w1_fp4=layer.w13_weight,
|
||||
w2_fp4=layer.w2_weight,
|
||||
w1_blockscale=layer.w13_weight_scale,
|
||||
w2_blockscale=layer.w2_weight_scale,
|
||||
g1_alphas=layer.g1_alphas,
|
||||
g2_alphas=layer.g2_alphas,
|
||||
a1_gscale=layer.w13_input_scale_quant,
|
||||
a2_gscale=layer.w2_input_scale_quant,
|
||||
topk_weights=topk_weights,
|
||||
topk_ids=topk_ids,
|
||||
m=x.shape[0],
|
||||
n=layer.w2_weight.shape[2] * 2,
|
||||
k=x.shape[1],
|
||||
e=layer.w13_weight.shape[0],
|
||||
apply_router_weight_on_input=apply_router_weight_on_input).to(
|
||||
x.dtype)
|
||||
# Cutlass moe takes in activations in BF16/Half precision
|
||||
# and fp4 quantized weights loaded from the checkpoint
|
||||
return cutlass_moe_fp4(
|
||||
a=x,
|
||||
w1_fp4=layer.w13_weight,
|
||||
w2_fp4=layer.w2_weight,
|
||||
topk_weights=topk_weights,
|
||||
topk_ids=topk_ids,
|
||||
quant_config=self.moe_quant_config,
|
||||
apply_router_weight_on_input=apply_router_weight_on_input,
|
||||
# TODO(bnell): derive these from arguments
|
||||
m=x.shape[0],
|
||||
n=layer.w2_weight.shape[2] * 2,
|
||||
k=x.shape[1],
|
||||
e=layer.w13_weight.shape[0],
|
||||
).to(x.dtype)
|
||||
|
||||
|
||||
class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
|
||||
@ -692,16 +699,11 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
|
||||
layer.w2_weight = torch.nn.Parameter(shuffled_w2,
|
||||
requires_grad=False)
|
||||
|
||||
self.rocm_aiter_fused_experts_func = rocm_aiter_fused_experts
|
||||
elif self.use_marlin:
|
||||
prepare_moe_fp8_layer_for_marlin(layer, False)
|
||||
# Activations not quantized for marlin.
|
||||
del layer.w13_input_scale
|
||||
del layer.w2_input_scale
|
||||
self.fused_experts_func = None
|
||||
else:
|
||||
from vllm.model_executor.layers.fused_moe import fused_experts
|
||||
self.fused_experts_func = fused_experts
|
||||
|
||||
if self.use_cutlass:
|
||||
device = layer.w13_weight.device
|
||||
@ -722,11 +724,20 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
|
||||
device=device,
|
||||
dtype=torch.int64)
|
||||
|
||||
def maybe_make_prepare_finalize(
|
||||
self) -> Optional[mk.FusedMoEPrepareAndFinalize]:
|
||||
if self.use_marlin or self.rocm_aiter_moe_enabled:
|
||||
return None
|
||||
else:
|
||||
return super().maybe_make_prepare_finalize()
|
||||
|
||||
def select_gemm_impl(
|
||||
self, prepare_finalize: FusedMoEPrepareAndFinalize,
|
||||
moe: FusedMoEConfig,
|
||||
layer: torch.nn.Module) -> FusedMoEPermuteExpertsUnpermute:
|
||||
self,
|
||||
prepare_finalize: mk.FusedMoEPrepareAndFinalize,
|
||||
layer: torch.nn.Module,
|
||||
) -> FusedMoEPermuteExpertsUnpermute:
|
||||
# cutlass path
|
||||
assert self.moe_quant_config is not None
|
||||
if self.use_cutlass:
|
||||
from vllm.model_executor.layers.fused_moe import (
|
||||
CutlassBatchedExpertsFp8, CutlassExpertsFp8)
|
||||
@ -740,26 +751,24 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
|
||||
logger.debug("CutlassBatchedExpertsFp8(%s)",
|
||||
self.__class__.__name__)
|
||||
experts = CutlassBatchedExpertsFp8(
|
||||
moe.num_local_experts,
|
||||
self.moe.num_local_experts,
|
||||
num_dispatchers,
|
||||
moe.in_dtype,
|
||||
self.input_quant.strategy == QuantizationStrategy.TOKEN,
|
||||
self.weight_quant.strategy == QuantizationStrategy.CHANNEL,
|
||||
self.moe.in_dtype,
|
||||
ab_strides1=self.ab_strides1_c_strides2,
|
||||
ab_strides2=self.ab_strides2,
|
||||
c_strides1=self.c_strides1,
|
||||
c_strides2=self.ab_strides1_c_strides2,
|
||||
quant_config=self.moe_quant_config,
|
||||
)
|
||||
else:
|
||||
logger.debug("CutlassExpertsFp8(%s)", self.__class__.__name__)
|
||||
experts = CutlassExpertsFp8(
|
||||
moe.in_dtype,
|
||||
self.input_quant.strategy == QuantizationStrategy.TOKEN,
|
||||
self.weight_quant.strategy == QuantizationStrategy.CHANNEL,
|
||||
self.moe.in_dtype,
|
||||
ab_strides1=self.ab_strides1_c_strides2,
|
||||
ab_strides2=self.ab_strides2,
|
||||
c_strides1=self.c_strides1,
|
||||
c_strides2=self.ab_strides1_c_strides2,
|
||||
quant_config=self.moe_quant_config,
|
||||
)
|
||||
|
||||
self.disable_expert_map = (num_dispatchers > 1
|
||||
@ -774,29 +783,40 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
|
||||
|
||||
assert not self.rocm_aiter_moe_enabled and not self.use_marlin
|
||||
|
||||
logger.debug("BatchedTritonExperts(%s)", self.__class__.__name__)
|
||||
|
||||
if (prepare_finalize.activation_format ==
|
||||
FusedMoEActivationFormat.BatchedExperts):
|
||||
max_num_tokens_per_rank = prepare_finalize.max_num_tokens_per_rank(
|
||||
)
|
||||
assert max_num_tokens_per_rank is not None
|
||||
|
||||
logger.debug("BatchedTritonExperts(%s)", self.__class__.__name__)
|
||||
return BatchedTritonExperts(
|
||||
max_num_tokens=max_num_tokens_per_rank,
|
||||
num_dispatchers=prepare_finalize.num_dispatchers(),
|
||||
use_fp8_w8a8=True,
|
||||
block_shape=self.quant_config.weight_block_size,
|
||||
per_act_token_quant=(
|
||||
self.input_quant.strategy == QuantizationStrategy.TOKEN),
|
||||
quant_config=self.moe_quant_config,
|
||||
)
|
||||
else:
|
||||
return TritonExperts(
|
||||
use_fp8_w8a8=True,
|
||||
block_shape=self.quant_config.weight_block_size,
|
||||
per_act_token_quant=(
|
||||
self.input_quant.strategy == QuantizationStrategy.TOKEN),
|
||||
)
|
||||
logger.debug("TritonExperts(%s)", self.__class__.__name__)
|
||||
return TritonExperts(self.moe_quant_config)
|
||||
|
||||
def get_fused_moe_quant_config(
|
||||
self, layer: torch.nn.Module) -> Optional[FusedMoEQuantConfig]:
|
||||
if self.use_marlin:
|
||||
return None
|
||||
|
||||
per_act_token = (
|
||||
self.input_quant.strategy == QuantizationStrategy.TOKEN)
|
||||
per_channel_quant = (
|
||||
self.weight_quant.strategy == QuantizationStrategy.CHANNEL)
|
||||
|
||||
return fp8_w8a8_moe_quant_config(
|
||||
w1_scale=layer.w13_weight_scale,
|
||||
w2_scale=layer.w2_weight_scale,
|
||||
a1_scale=layer.w13_input_scale,
|
||||
a2_scale=layer.w2_input_scale,
|
||||
per_act_token_quant=per_act_token,
|
||||
per_out_ch_quant=per_channel_quant,
|
||||
)
|
||||
|
||||
def apply(
|
||||
self,
|
||||
@ -841,92 +861,19 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
|
||||
indices_type=self.topk_indices_dtype,
|
||||
)
|
||||
|
||||
# cutlass path
|
||||
if self.use_cutlass:
|
||||
per_act_token = (
|
||||
self.input_quant.strategy == QuantizationStrategy.TOKEN)
|
||||
per_channel_quant = (
|
||||
self.weight_quant.strategy == QuantizationStrategy.CHANNEL)
|
||||
per_act_token = (
|
||||
self.input_quant.strategy == QuantizationStrategy.TOKEN)
|
||||
per_channel_quant = (
|
||||
self.weight_quant.strategy == QuantizationStrategy.CHANNEL)
|
||||
|
||||
# small-batch fallback on SM100
|
||||
if self.is_fp8_w8a8_sm100 and topk_ids.shape[0] <= 8:
|
||||
from vllm.model_executor.layers.fused_moe import fused_experts
|
||||
return fused_experts(
|
||||
hidden_states=x,
|
||||
w1=layer.w13_weight,
|
||||
w2=layer.w2_weight,
|
||||
topk_weights=topk_weights,
|
||||
topk_ids=topk_ids,
|
||||
inplace=True,
|
||||
activation=activation,
|
||||
apply_router_weight_on_input=apply_router_weight_on_input,
|
||||
use_fp8_w8a8=True,
|
||||
per_channel_quant=per_channel_quant,
|
||||
global_num_experts=global_num_experts,
|
||||
expert_map=None if self.disable_expert_map else expert_map,
|
||||
w1_scale=layer.w13_weight_scale,
|
||||
w2_scale=layer.w2_weight_scale,
|
||||
a1_scale=layer.w13_input_scale,
|
||||
a2_scale=layer.w2_input_scale)
|
||||
|
||||
if self.fused_experts is None:
|
||||
from vllm.model_executor.layers.fused_moe.cutlass_moe import (
|
||||
cutlass_moe_fp8)
|
||||
return cutlass_moe_fp8(
|
||||
x,
|
||||
layer.w13_weight,
|
||||
layer.w2_weight,
|
||||
topk_weights,
|
||||
topk_ids,
|
||||
per_act_token=per_act_token,
|
||||
activation=activation,
|
||||
global_num_experts=global_num_experts,
|
||||
expert_map=None if self.disable_expert_map else expert_map,
|
||||
w1_scale=layer.w13_weight_scale,
|
||||
w2_scale=layer.w2_weight_scale,
|
||||
ab_strides1=self.ab_strides1_c_strides2,
|
||||
ab_strides2=self.ab_strides2,
|
||||
c_strides1=self.c_strides1,
|
||||
c_strides2=self.ab_strides1_c_strides2,
|
||||
a1_scale=layer.w13_input_scale,
|
||||
a2_scale=layer.w2_input_scale,
|
||||
)
|
||||
else:
|
||||
return self.fused_experts(
|
||||
x,
|
||||
layer.w13_weight,
|
||||
layer.w2_weight,
|
||||
topk_weights,
|
||||
topk_ids,
|
||||
activation=activation,
|
||||
global_num_experts=global_num_experts,
|
||||
expert_map=None if self.disable_expert_map else expert_map,
|
||||
w1_scale=layer.w13_weight_scale,
|
||||
w2_scale=layer.w2_weight_scale,
|
||||
a1_scale=layer.w13_input_scale,
|
||||
a2_scale=layer.w2_input_scale,
|
||||
)
|
||||
|
||||
if self.rocm_aiter_moe_enabled:
|
||||
return self.rocm_aiter_fused_experts_func(
|
||||
hidden_states=x,
|
||||
w1=layer.w13_weight,
|
||||
w2=layer.w2_weight,
|
||||
topk_weights=topk_weights,
|
||||
topk_ids=topk_ids,
|
||||
activation=activation,
|
||||
apply_router_weight_on_input=apply_router_weight_on_input,
|
||||
use_fp8_w8a8=True,
|
||||
per_channel_quant=self.weight_quant.strategy ==
|
||||
QuantizationStrategy.CHANNEL,
|
||||
w1_scale=layer.w13_weight_scale,
|
||||
w2_scale=layer.w2_weight_scale,
|
||||
a1_scale=layer.w13_input_scale,
|
||||
a2_scale=layer.w2_input_scale,
|
||||
expert_map=expert_map)
|
||||
#
|
||||
# Note: the order here is important. self.fused_experts can override
|
||||
# cutlass fp8 or fused_experts but not marlin or rocm.
|
||||
#
|
||||
if self.use_marlin:
|
||||
assert activation == "silu", (
|
||||
f"{activation} not supported for Marlin MoE.")
|
||||
assert self.fused_experts is None
|
||||
return torch.ops.vllm.fused_marlin_moe(
|
||||
x,
|
||||
layer.w13_weight,
|
||||
@ -944,26 +891,95 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
|
||||
expert_map=expert_map,
|
||||
workspace=layer.workspace)
|
||||
|
||||
assert self.fused_experts_func is not None
|
||||
elif self.rocm_aiter_moe_enabled:
|
||||
from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import ( # noqa E501
|
||||
rocm_aiter_fused_experts)
|
||||
assert per_act_token == per_channel_quant
|
||||
assert self.moe_quant_config is not None
|
||||
assert self.fused_experts is None
|
||||
return rocm_aiter_fused_experts(
|
||||
hidden_states=x,
|
||||
w1=layer.w13_weight,
|
||||
w2=layer.w2_weight,
|
||||
topk_weights=topk_weights,
|
||||
topk_ids=topk_ids,
|
||||
activation=activation,
|
||||
apply_router_weight_on_input=apply_router_weight_on_input,
|
||||
expert_map=expert_map,
|
||||
quant_config=self.moe_quant_config,
|
||||
)
|
||||
|
||||
return self.fused_experts_func(
|
||||
hidden_states=x,
|
||||
w1=layer.w13_weight,
|
||||
w2=layer.w2_weight,
|
||||
topk_weights=topk_weights,
|
||||
topk_ids=topk_ids,
|
||||
inplace=True,
|
||||
activation=activation,
|
||||
apply_router_weight_on_input=apply_router_weight_on_input,
|
||||
use_fp8_w8a8=True,
|
||||
per_channel_quant=self.weight_quant.strategy ==
|
||||
QuantizationStrategy.CHANNEL,
|
||||
global_num_experts=global_num_experts,
|
||||
expert_map=expert_map,
|
||||
w1_scale=layer.w13_weight_scale,
|
||||
w2_scale=layer.w2_weight_scale,
|
||||
a1_scale=layer.w13_input_scale,
|
||||
a2_scale=layer.w2_input_scale)
|
||||
elif self.fused_experts is not None:
|
||||
return self.fused_experts(
|
||||
x,
|
||||
layer.w13_weight,
|
||||
layer.w2_weight,
|
||||
topk_weights,
|
||||
topk_ids,
|
||||
activation=activation,
|
||||
global_num_experts=global_num_experts,
|
||||
expert_map=None if self.disable_expert_map else expert_map,
|
||||
)
|
||||
|
||||
# cutlass path
|
||||
elif self.use_cutlass:
|
||||
assert self.moe_quant_config is not None
|
||||
|
||||
# small-batch fallback on SM100
|
||||
if self.is_fp8_w8a8_sm100 and topk_ids.shape[0] <= 8:
|
||||
from vllm.model_executor.layers.fused_moe import fused_experts
|
||||
assert per_act_token == per_channel_quant
|
||||
return fused_experts(
|
||||
hidden_states=x,
|
||||
w1=layer.w13_weight,
|
||||
w2=layer.w2_weight,
|
||||
topk_weights=topk_weights,
|
||||
topk_ids=topk_ids,
|
||||
inplace=True,
|
||||
activation=activation,
|
||||
apply_router_weight_on_input=apply_router_weight_on_input,
|
||||
global_num_experts=global_num_experts,
|
||||
expert_map=None if self.disable_expert_map else expert_map,
|
||||
quant_config=self.moe_quant_config,
|
||||
)
|
||||
else:
|
||||
from vllm.model_executor.layers.fused_moe.cutlass_moe import (
|
||||
cutlass_moe_fp8)
|
||||
assert per_act_token == per_channel_quant
|
||||
assert self.moe_quant_config is not None
|
||||
return cutlass_moe_fp8(
|
||||
x,
|
||||
layer.w13_weight,
|
||||
layer.w2_weight,
|
||||
topk_weights,
|
||||
topk_ids,
|
||||
quant_config=self.moe_quant_config,
|
||||
activation=activation,
|
||||
global_num_experts=global_num_experts,
|
||||
expert_map=None if self.disable_expert_map else expert_map,
|
||||
ab_strides1=self.ab_strides1_c_strides2,
|
||||
ab_strides2=self.ab_strides2,
|
||||
c_strides1=self.c_strides1,
|
||||
c_strides2=self.ab_strides1_c_strides2,
|
||||
)
|
||||
|
||||
else:
|
||||
from vllm.model_executor.layers.fused_moe import fused_experts
|
||||
assert per_act_token == per_channel_quant
|
||||
assert self.moe_quant_config is not None
|
||||
return fused_experts(
|
||||
hidden_states=x,
|
||||
w1=layer.w13_weight,
|
||||
w2=layer.w2_weight,
|
||||
topk_weights=topk_weights,
|
||||
topk_ids=topk_ids,
|
||||
inplace=True,
|
||||
activation=activation,
|
||||
apply_router_weight_on_input=apply_router_weight_on_input,
|
||||
global_num_experts=global_num_experts,
|
||||
expert_map=expert_map,
|
||||
quant_config=self.moe_quant_config,
|
||||
)
|
||||
|
||||
|
||||
class CompressedTensorsW8A8Int8MoEMethod(CompressedTensorsMoEMethod):
|
||||
@ -1049,6 +1065,16 @@ class CompressedTensorsW8A8Int8MoEMethod(CompressedTensorsMoEMethod):
|
||||
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
|
||||
pass
|
||||
|
||||
def get_fused_moe_quant_config(
|
||||
self, layer: torch.nn.Module) -> Optional[FusedMoEQuantConfig]:
|
||||
return int8_w8a8_moe_quant_config(
|
||||
w1_scale=layer.w13_weight_scale,
|
||||
w2_scale=layer.w2_weight_scale,
|
||||
a1_scale=layer.w13_input_scale,
|
||||
a2_scale=layer.w2_input_scale,
|
||||
per_act_token_quant=True,
|
||||
)
|
||||
|
||||
def apply(
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
@ -1104,14 +1130,10 @@ class CompressedTensorsW8A8Int8MoEMethod(CompressedTensorsMoEMethod):
|
||||
inplace=True,
|
||||
activation=activation,
|
||||
apply_router_weight_on_input=apply_router_weight_on_input,
|
||||
use_int8_w8a8=True,
|
||||
per_channel_quant=True,
|
||||
global_num_experts=global_num_experts,
|
||||
expert_map=expert_map,
|
||||
w1_scale=layer.w13_weight_scale,
|
||||
w2_scale=layer.w2_weight_scale,
|
||||
a1_scale=layer.w13_input_scale,
|
||||
a2_scale=layer.w2_input_scale)
|
||||
quant_config=self.moe_quant_config,
|
||||
)
|
||||
|
||||
|
||||
class CompressedTensorsWNA16MarlinMoEMethod(CompressedTensorsMoEMethod):
|
||||
@ -1355,6 +1377,10 @@ class CompressedTensorsWNA16MarlinMoEMethod(CompressedTensorsMoEMethod):
|
||||
|
||||
layer.workspace = marlin_make_workspace_new(device, 4)
|
||||
|
||||
def get_fused_moe_quant_config(
|
||||
self, layer: torch.nn.Module) -> Optional[FusedMoEQuantConfig]:
|
||||
return None
|
||||
|
||||
def apply(
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
@ -1588,6 +1614,20 @@ class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod):
|
||||
layer.w2_weight_scale.transpose(1, 2).contiguous(),
|
||||
requires_grad=False)
|
||||
|
||||
def get_fused_moe_quant_config(
|
||||
self, layer: torch.nn.Module) -> Optional[FusedMoEQuantConfig]:
|
||||
assert self.num_bits == 4 or self.num_bits == 8
|
||||
config_builder = (int4_w4a16_moe_quant_config if self.num_bits == 4
|
||||
else int8_w8a16_moe_quant_config)
|
||||
|
||||
return config_builder(
|
||||
w1_scale=layer.w13_weight_scale,
|
||||
w2_scale=layer.w2_weight_scale,
|
||||
w1_zp=None,
|
||||
w2_zp=None,
|
||||
block_shape=[0, self.group_size],
|
||||
)
|
||||
|
||||
def apply(
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
@ -1641,13 +1681,8 @@ class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod):
|
||||
topk_ids=topk_ids,
|
||||
inplace=True,
|
||||
activation=activation,
|
||||
use_int4_w4a16=self.num_bits == 4,
|
||||
use_int8_w8a16=self.num_bits == 8,
|
||||
global_num_experts=global_num_experts,
|
||||
apply_router_weight_on_input=apply_router_weight_on_input,
|
||||
global_num_experts=global_num_experts,
|
||||
expert_map=expert_map,
|
||||
w1_scale=layer.w13_weight_scale,
|
||||
w2_scale=layer.w2_weight_scale,
|
||||
w1_zp=None,
|
||||
w2_zp=None,
|
||||
block_shape=[0, self.group_size])
|
||||
quant_config=self.moe_quant_config,
|
||||
)
|
||||
|
||||
@ -8,6 +8,8 @@ import torch
|
||||
from vllm.distributed import get_tensor_model_parallel_rank, get_tp_group
|
||||
from vllm.model_executor.layers.fused_moe import (FusedMoE, FusedMoEConfig,
|
||||
FusedMoEMethodBase)
|
||||
from vllm.model_executor.layers.fused_moe.config import (
|
||||
FusedMoEQuantConfig, int8_w8a16_moe_quant_config)
|
||||
from vllm.model_executor.layers.linear import (LinearBase,
|
||||
UnquantizedLinearMethod)
|
||||
from vllm.model_executor.layers.quantization import QuantizationMethods
|
||||
@ -106,6 +108,13 @@ class ExpertsInt8MoEMethod(FusedMoEMethodBase):
|
||||
requires_grad=False)
|
||||
layer.register_parameter("w2_scale", w2_scale)
|
||||
|
||||
def get_fused_moe_quant_config(
|
||||
self, layer: torch.nn.Module) -> Optional[FusedMoEQuantConfig]:
|
||||
return int8_w8a16_moe_quant_config(w1_scale=layer.w13_scale,
|
||||
w2_scale=layer.w2_scale,
|
||||
w1_zp=None,
|
||||
w2_zp=None)
|
||||
|
||||
def apply(
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
@ -159,12 +168,11 @@ class ExpertsInt8MoEMethod(FusedMoEMethodBase):
|
||||
topk_ids=topk_ids,
|
||||
inplace=True,
|
||||
activation=activation,
|
||||
use_int8_w8a16=True,
|
||||
global_num_experts=global_num_experts,
|
||||
apply_router_weight_on_input=apply_router_weight_on_input,
|
||||
global_num_experts=global_num_experts,
|
||||
expert_map=expert_map,
|
||||
w1_scale=layer.w13_scale,
|
||||
w2_scale=layer.w2_scale)
|
||||
quant_config=self.moe_quant_config,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def quantizing_weight_loader(layer, weight_loader):
|
||||
|
||||
@ -14,9 +14,11 @@ from vllm import _custom_ops as ops
|
||||
from vllm.distributed import get_tensor_model_parallel_world_size
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.fused_moe import (
|
||||
FusedMoE, FusedMoEActivationFormat, FusedMoEConfig, FusedMoEMethodBase,
|
||||
FusedMoE, FusedMoEActivationFormat, FusedMoEMethodBase,
|
||||
FusedMoEPermuteExpertsUnpermute, FusedMoEPrepareAndFinalize,
|
||||
FusedMoeWeightScaleSupported)
|
||||
from vllm.model_executor.layers.fused_moe.config import (
|
||||
FusedMoEQuantConfig, fp8_w8a8_moe_quant_config)
|
||||
from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase,
|
||||
UnquantizedLinearMethod)
|
||||
from vllm.model_executor.layers.quantization import QuantizationMethods
|
||||
@ -575,20 +577,6 @@ class Fp8MoEMethod(FusedMoEMethodBase):
|
||||
"CutlassBlockScaledGroupedGemm not supported on the current "
|
||||
"platform.")
|
||||
|
||||
def maybe_make_prepare_finalize(
|
||||
self,
|
||||
moe: FusedMoEConfig,
|
||||
) -> Optional[mk.FusedMoEPrepareAndFinalize]:
|
||||
if self.flashinfer_moe_backend != FlashinferMoeBackend.CUTLASS:
|
||||
return super().maybe_make_prepare_finalize(moe)
|
||||
|
||||
prepare_finalize = build_flashinfer_fp8_cutlass_moe_prepare_finalize(
|
||||
moe,
|
||||
layer=self.layer,
|
||||
)
|
||||
logger.debug_once("%s", prepare_finalize.__class__.__name__)
|
||||
return prepare_finalize
|
||||
|
||||
def create_weights(self, layer: Module, num_experts: int, hidden_size: int,
|
||||
intermediate_size_per_partition: int,
|
||||
params_dtype: torch.dtype, **extra_weight_attrs):
|
||||
@ -928,10 +916,23 @@ class Fp8MoEMethod(FusedMoEMethodBase):
|
||||
layer.w2_weight_scale_inv = get_col_major_tma_aligned_tensor(
|
||||
layer.w2_weight_scale_inv)
|
||||
|
||||
def maybe_make_prepare_finalize(
|
||||
self) -> Optional[mk.FusedMoEPrepareAndFinalize]:
|
||||
if (self.rocm_aiter_moe_enabled or self.use_marlin
|
||||
or self.flashinfer_moe_backend
|
||||
== FlashinferMoeBackend.TENSORRT_LLM):
|
||||
return None
|
||||
elif self.flashinfer_moe_backend == FlashinferMoeBackend.CUTLASS:
|
||||
prepare_finalize = (
|
||||
build_flashinfer_fp8_cutlass_moe_prepare_finalize(self.moe))
|
||||
logger.debug_once("%s", prepare_finalize.__class__.__name__)
|
||||
return prepare_finalize
|
||||
else:
|
||||
return super().maybe_make_prepare_finalize()
|
||||
|
||||
def select_gemm_impl(
|
||||
self,
|
||||
prepare_finalize: FusedMoEPrepareAndFinalize,
|
||||
moe: FusedMoEConfig,
|
||||
layer: torch.nn.Module,
|
||||
) -> FusedMoEPermuteExpertsUnpermute:
|
||||
from vllm.model_executor.layers.fused_moe import (
|
||||
@ -940,6 +941,8 @@ class Fp8MoEMethod(FusedMoEMethodBase):
|
||||
assert not self.use_marlin and not self.rocm_aiter_moe_enabled, (
|
||||
"Marlin and ROCm AITER are not supported with all2all yet.")
|
||||
|
||||
assert self.moe_quant_config is not None
|
||||
|
||||
if (prepare_finalize.activation_format ==
|
||||
FusedMoEActivationFormat.BatchedExperts):
|
||||
max_num_tokens_per_rank = (
|
||||
@ -953,15 +956,13 @@ class Fp8MoEMethod(FusedMoEMethodBase):
|
||||
return BatchedTritonOrDeepGemmExperts(
|
||||
max_num_tokens=max_num_tokens_per_rank,
|
||||
num_dispatchers=prepare_finalize.num_dispatchers(),
|
||||
use_fp8_w8a8=True,
|
||||
block_shape=self.quant_config.weight_block_size,
|
||||
per_act_token_quant=False,
|
||||
quant_config=self.moe_quant_config,
|
||||
allow_deep_gemm=self.allow_deep_gemm,
|
||||
)
|
||||
elif self.flashinfer_moe_backend == FlashinferMoeBackend.CUTLASS:
|
||||
experts = select_cutlass_fp8_gemm_impl(
|
||||
moe,
|
||||
self.layer,
|
||||
self.moe,
|
||||
self.moe_quant_config,
|
||||
)
|
||||
logger.debug_once("Using %s", experts.__class__.__name__)
|
||||
return experts
|
||||
@ -971,11 +972,25 @@ class Fp8MoEMethod(FusedMoEMethodBase):
|
||||
self.__class__.__name__, self.quant_config.weight_block_size,
|
||||
False)
|
||||
return TritonOrDeepGemmExperts(
|
||||
use_fp8_w8a8=True,
|
||||
block_shape=self.quant_config.weight_block_size,
|
||||
quant_config=self.moe_quant_config,
|
||||
allow_deep_gemm=self.allow_deep_gemm,
|
||||
)
|
||||
|
||||
def get_fused_moe_quant_config(
|
||||
self, layer: torch.nn.Module) -> Optional[FusedMoEQuantConfig]:
|
||||
if self.use_marlin:
|
||||
return None
|
||||
|
||||
return fp8_w8a8_moe_quant_config(
|
||||
w1_scale=(layer.w13_weight_scale_inv
|
||||
if self.block_quant else layer.w13_weight_scale),
|
||||
w2_scale=(layer.w2_weight_scale_inv
|
||||
if self.block_quant else layer.w2_weight_scale),
|
||||
a1_scale=layer.w13_input_scale,
|
||||
a2_scale=layer.w2_input_scale,
|
||||
block_shape=self.quant_config.weight_block_size,
|
||||
)
|
||||
|
||||
def apply(
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
@ -1005,12 +1020,14 @@ class Fp8MoEMethod(FusedMoEMethodBase):
|
||||
assert logical_replica_count is not None
|
||||
assert isinstance(layer, FusedMoE)
|
||||
|
||||
if self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM:
|
||||
if (self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM
|
||||
and self.fused_experts is None):
|
||||
assert activation == 'silu', (
|
||||
f"Expected 'silu' activation but got {activation}")
|
||||
assert scoring_func == 'sigmoid', (
|
||||
f"Expected 'sigmoid' scoring func but got {scoring_func}")
|
||||
if self.block_quant:
|
||||
import vllm.model_executor.layers.fused_moe.flashinfer_trtllm_moe # noqa: E501, F401
|
||||
assert (renormalize and use_grouped_topk
|
||||
and custom_routing_function is None)
|
||||
|
||||
@ -1066,9 +1083,14 @@ class Fp8MoEMethod(FusedMoEMethodBase):
|
||||
logical_replica_count=logical_replica_count,
|
||||
)
|
||||
|
||||
#
|
||||
# Note: the order of checks is important since self.fused_experts
|
||||
# can override fused_experts or cutlass but not rocm or marlin.
|
||||
#
|
||||
if self.rocm_aiter_moe_enabled:
|
||||
from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import ( # noqa: E501
|
||||
rocm_aiter_fused_experts)
|
||||
assert self.fused_experts is None
|
||||
return rocm_aiter_fused_experts(
|
||||
x,
|
||||
layer.w13_weight,
|
||||
@ -1076,19 +1098,13 @@ class Fp8MoEMethod(FusedMoEMethodBase):
|
||||
topk_weights=topk_weights,
|
||||
topk_ids=topk_ids,
|
||||
activation=activation,
|
||||
use_fp8_w8a8=True,
|
||||
apply_router_weight_on_input=apply_router_weight_on_input,
|
||||
w1_scale=(layer.w13_weight_scale_inv
|
||||
if self.block_quant else layer.w13_weight_scale),
|
||||
w2_scale=(layer.w2_weight_scale_inv
|
||||
if self.block_quant else layer.w2_weight_scale),
|
||||
a1_scale=layer.w13_input_scale,
|
||||
a2_scale=layer.w2_input_scale,
|
||||
block_shape=self.quant_config.weight_block_size,
|
||||
expert_map=expert_map)
|
||||
expert_map=expert_map,
|
||||
quant_config=self.moe_quant_config)
|
||||
elif self.use_marlin:
|
||||
assert activation == "silu", (
|
||||
f"{activation} not supported for Marlin MoE.")
|
||||
assert self.fused_experts is None
|
||||
return torch.ops.vllm.fused_marlin_moe(
|
||||
x,
|
||||
layer.w13_weight,
|
||||
@ -1105,40 +1121,8 @@ class Fp8MoEMethod(FusedMoEMethodBase):
|
||||
global_num_experts=global_num_experts,
|
||||
expert_map=expert_map,
|
||||
workspace=layer.workspace)
|
||||
elif self.flashinfer_moe_backend == FlashinferMoeBackend.CUTLASS:
|
||||
assert self.block_quant is None
|
||||
assert (not renormalize and custom_routing_function is not None)
|
||||
assert activation == 'silu', (
|
||||
f"Expected 'silu' activation but got {activation}")
|
||||
assert scoring_func == 'sigmoid', (
|
||||
f"Expected 'sigmoid' scoring func but got {scoring_func}")
|
||||
if self.fused_experts is not None:
|
||||
return self.fused_experts(
|
||||
x,
|
||||
layer.w13_weight,
|
||||
layer.w2_weight,
|
||||
topk_weights,
|
||||
topk_ids,
|
||||
inplace=False,
|
||||
activation=activation,
|
||||
global_num_experts=global_num_experts,
|
||||
expert_map=expert_map,
|
||||
apply_router_weight_on_input=apply_router_weight_on_input,
|
||||
)
|
||||
else:
|
||||
return flashinfer_cutlass_moe_fp8(
|
||||
x,
|
||||
layer,
|
||||
topk_weights,
|
||||
topk_ids,
|
||||
inplace=False,
|
||||
activation=activation,
|
||||
global_num_experts=global_num_experts,
|
||||
expert_map=expert_map,
|
||||
apply_router_weight_on_input=apply_router_weight_on_input,
|
||||
)
|
||||
else:
|
||||
common_kwargs = dict(
|
||||
elif self.fused_experts:
|
||||
return self.fused_experts(
|
||||
hidden_states=x,
|
||||
w1=layer.w13_weight,
|
||||
w2=layer.w2_weight,
|
||||
@ -1149,26 +1133,43 @@ class Fp8MoEMethod(FusedMoEMethodBase):
|
||||
global_num_experts=global_num_experts,
|
||||
apply_router_weight_on_input=apply_router_weight_on_input,
|
||||
expert_map=expert_map,
|
||||
w1_scale=(layer.w13_weight_scale_inv
|
||||
if self.block_quant else layer.w13_weight_scale),
|
||||
w2_scale=(layer.w2_weight_scale_inv
|
||||
if self.block_quant else layer.w2_weight_scale),
|
||||
a1_scale=layer.w13_input_scale,
|
||||
a2_scale=layer.w2_input_scale,
|
||||
)
|
||||
elif self.flashinfer_moe_backend == FlashinferMoeBackend.CUTLASS:
|
||||
assert self.block_quant is None
|
||||
assert (not renormalize and custom_routing_function is not None)
|
||||
assert activation == 'silu', (
|
||||
f"Expected 'silu' activation but got {activation}")
|
||||
assert scoring_func == 'sigmoid', (
|
||||
f"Expected 'sigmoid' scoring func but got {scoring_func}")
|
||||
|
||||
if self.fused_experts is not None:
|
||||
return self.fused_experts(**common_kwargs)
|
||||
else:
|
||||
from vllm.model_executor.layers.fused_moe import fused_experts
|
||||
return fused_experts(
|
||||
**common_kwargs,
|
||||
use_fp8_w8a8=True,
|
||||
block_shape=self.quant_config.weight_block_size,
|
||||
allow_deep_gemm=self.allow_deep_gemm,
|
||||
allow_cutlass_block_scaled_grouped_gemm=(
|
||||
self.allow_cutlass_block_scaled_grouped_gemm),
|
||||
)
|
||||
return flashinfer_cutlass_moe_fp8(
|
||||
x,
|
||||
layer,
|
||||
topk_weights,
|
||||
topk_ids,
|
||||
inplace=False,
|
||||
activation=activation,
|
||||
global_num_experts=global_num_experts,
|
||||
expert_map=expert_map,
|
||||
apply_router_weight_on_input=apply_router_weight_on_input,
|
||||
)
|
||||
else:
|
||||
from vllm.model_executor.layers.fused_moe import fused_experts
|
||||
return fused_experts(
|
||||
hidden_states=x,
|
||||
w1=layer.w13_weight,
|
||||
w2=layer.w2_weight,
|
||||
topk_weights=topk_weights,
|
||||
topk_ids=topk_ids,
|
||||
inplace=True,
|
||||
activation=activation,
|
||||
global_num_experts=global_num_experts,
|
||||
apply_router_weight_on_input=apply_router_weight_on_input,
|
||||
expert_map=expert_map,
|
||||
quant_config=self.moe_quant_config,
|
||||
allow_deep_gemm=self.allow_deep_gemm,
|
||||
allow_cutlass_block_scaled_grouped_gemm=(
|
||||
self.allow_cutlass_block_scaled_grouped_gemm))
|
||||
|
||||
|
||||
class Fp8KVCacheMethod(BaseKVCacheMethod):
|
||||
|
||||
@ -10,8 +10,9 @@ from torch.nn.parameter import Parameter, UninitializedParameter
|
||||
|
||||
from vllm import _custom_ops as ops
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.fused_moe.config import (FusedMoEConfig,
|
||||
FusedMoEQuantConfig)
|
||||
from vllm.model_executor.layers.fused_moe.layer import (FusedMoE,
|
||||
FusedMoEConfig,
|
||||
FusedMoEMethodBase)
|
||||
from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase,
|
||||
UnquantizedLinearMethod)
|
||||
@ -518,6 +519,10 @@ class GGUFMoEMethod(FusedMoEMethodBase):
|
||||
set_weight_attrs(w2_qweight_type, extra_weight_attrs)
|
||||
layer.register_parameter("w2_qweight_type", w2_qweight_type)
|
||||
|
||||
def get_fused_moe_quant_config(
|
||||
self, layer: torch.nn.Module) -> Optional[FusedMoEQuantConfig]:
|
||||
return None
|
||||
|
||||
def apply(
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
|
||||
@ -9,8 +9,10 @@ import torch
|
||||
import vllm.model_executor.layers.fused_moe # noqa
|
||||
from vllm import _custom_ops as ops
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.fused_moe.config import (FusedMoEConfig,
|
||||
FusedMoEQuantConfig)
|
||||
from vllm.model_executor.layers.fused_moe.layer import (
|
||||
FusedMoE, FusedMoEConfig, FusedMoEMethodBase, FusedMoeWeightScaleSupported,
|
||||
FusedMoE, FusedMoEMethodBase, FusedMoeWeightScaleSupported,
|
||||
UnquantizedFusedMoEMethod)
|
||||
from vllm.model_executor.layers.linear import (LinearMethodBase,
|
||||
set_weight_attrs)
|
||||
@ -632,6 +634,10 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase):
|
||||
if hasattr(layer, "w2_bias") and layer.w2_bias is not None:
|
||||
layer.w2_bias.data = marlin_permute_bias(layer.w2_bias)
|
||||
|
||||
def get_fused_moe_quant_config(
|
||||
self, layer: torch.nn.Module) -> Optional[FusedMoEQuantConfig]:
|
||||
return None
|
||||
|
||||
def apply(
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
|
||||
@ -11,6 +11,7 @@ from torch.nn.parameter import Parameter
|
||||
from vllm._ipex_ops import ipex_ops as ops
|
||||
from vllm.model_executor.layers.fused_moe import (FusedMoEMethodBase,
|
||||
FusedMoeWeightScaleSupported)
|
||||
from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig
|
||||
from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase,
|
||||
UnquantizedLinearMethod)
|
||||
from vllm.model_executor.layers.quantization import QuantizationMethods
|
||||
@ -375,6 +376,10 @@ class XPUFp8MoEMethod(FusedMoEMethodBase):
|
||||
use_prepack=True,
|
||||
)
|
||||
|
||||
def get_fused_moe_quant_config(
|
||||
self, layer: torch.nn.Module) -> Optional[FusedMoEQuantConfig]:
|
||||
return None
|
||||
|
||||
def apply(
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
|
||||
@ -11,7 +11,9 @@ import vllm.envs as envs
|
||||
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
|
||||
from vllm._custom_ops import cutlass_scaled_fp4_mm, scaled_fp4_quant
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.fused_moe.config import FusedMoEConfig
|
||||
from vllm.model_executor.layers.fused_moe.config import (
|
||||
FusedMoEConfig, FusedMoEQuantConfig, fp8_w8a8_moe_quant_config,
|
||||
nvfp4_moe_quant_config)
|
||||
from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe import (
|
||||
is_valid_flashinfer_cutlass_fused_moe)
|
||||
from vllm.model_executor.layers.fused_moe.layer import (
|
||||
@ -294,8 +296,6 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase):
|
||||
cutlass_fp8_supported)
|
||||
self.cutlass_fp8_supported = cutlass_fp8_supported()
|
||||
self.flashinfer_moe_backend: Optional[FlashinferMoeBackend] = None
|
||||
self.fused_experts: Optional[
|
||||
mk.FusedMoEModularKernel] = None # type: ignore
|
||||
if envs.VLLM_USE_FLASHINFER_MOE_FP8 and has_flashinfer_moe():
|
||||
self.flashinfer_moe_backend = get_flashinfer_moe_backend()
|
||||
logger.info_once(
|
||||
@ -303,29 +303,27 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase):
|
||||
)
|
||||
|
||||
def maybe_make_prepare_finalize(
|
||||
self,
|
||||
moe: FusedMoEConfig,
|
||||
) -> Optional[mk.FusedMoEPrepareAndFinalize]:
|
||||
if self.fused_experts is not None or \
|
||||
self.flashinfer_moe_backend != FlashinferMoeBackend.CUTLASS:
|
||||
return super().maybe_make_prepare_finalize(moe)
|
||||
|
||||
prepare_finalize = build_flashinfer_fp8_cutlass_moe_prepare_finalize(
|
||||
moe,
|
||||
layer=self.layer,
|
||||
)
|
||||
logger.debug_once("%s", prepare_finalize.__class__.__name__)
|
||||
return prepare_finalize
|
||||
self, ) -> Optional[mk.FusedMoEPrepareAndFinalize]:
|
||||
# TRT LLM not supported with all2all yet.
|
||||
if self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM:
|
||||
return None
|
||||
elif self.flashinfer_moe_backend == FlashinferMoeBackend.CUTLASS:
|
||||
prepare_finalize = (
|
||||
build_flashinfer_fp8_cutlass_moe_prepare_finalize(self.moe))
|
||||
logger.debug_once("%s", prepare_finalize.__class__.__name__)
|
||||
return prepare_finalize
|
||||
else:
|
||||
return super().maybe_make_prepare_finalize()
|
||||
|
||||
def select_gemm_impl(
|
||||
self,
|
||||
prepare_finalize: mk.FusedMoEPrepareAndFinalize,
|
||||
moe: FusedMoEConfig,
|
||||
layer: torch.nn.Module,
|
||||
) -> mk.FusedMoEPermuteExpertsUnpermute:
|
||||
assert self.moe_quant_config is not None
|
||||
experts = select_cutlass_fp8_gemm_impl(
|
||||
moe,
|
||||
self.layer,
|
||||
self.moe,
|
||||
self.moe_quant_config,
|
||||
)
|
||||
logger.debug_once("Using %s", experts.__class__.__name__)
|
||||
return experts
|
||||
@ -479,6 +477,19 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase):
|
||||
rotate_flashinfer_fp8_moe_weights(layer.w13_weight,
|
||||
layer.w2_weight)
|
||||
|
||||
def get_fused_moe_quant_config(
|
||||
self, layer: torch.nn.Module) -> Optional[FusedMoEQuantConfig]:
|
||||
if self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM:
|
||||
return None
|
||||
|
||||
return fp8_w8a8_moe_quant_config(
|
||||
w1_scale=layer.w13_weight_scale,
|
||||
w2_scale=layer.w2_weight_scale,
|
||||
a1_scale=layer.w13_input_scale,
|
||||
a2_scale=layer.w2_input_scale,
|
||||
per_act_token_quant=False,
|
||||
)
|
||||
|
||||
def apply(
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
@ -507,6 +518,7 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase):
|
||||
"EPLB not supported for `ModelOptFp8MoEMethod` yet.")
|
||||
|
||||
if self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM:
|
||||
assert self.fused_experts is None
|
||||
assert activation == 'silu', (
|
||||
f"Expected 'silu' activation but got {activation}")
|
||||
assert not renormalize
|
||||
@ -537,55 +549,56 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase):
|
||||
indices_type=self.topk_indices_dtype,
|
||||
)
|
||||
|
||||
if self.flashinfer_moe_backend == FlashinferMoeBackend.CUTLASS:
|
||||
#
|
||||
# Note: the order here is important. self.fused_experts can override
|
||||
# cutlass or fused_experts.
|
||||
#
|
||||
if self.fused_experts is not None:
|
||||
return self.fused_experts(
|
||||
x,
|
||||
layer.w13_weight,
|
||||
layer.w2_weight,
|
||||
topk_weights,
|
||||
topk_ids,
|
||||
inplace=False,
|
||||
activation=activation,
|
||||
global_num_experts=global_num_experts,
|
||||
expert_map=expert_map,
|
||||
apply_router_weight_on_input=apply_router_weight_on_input,
|
||||
)
|
||||
elif self.flashinfer_moe_backend == FlashinferMoeBackend.CUTLASS:
|
||||
assert not renormalize
|
||||
assert activation == 'silu', (
|
||||
f"Expected 'silu' activation but got {activation}")
|
||||
if self.fused_experts is not None:
|
||||
return self.fused_experts(
|
||||
x,
|
||||
layer.w13_weight,
|
||||
layer.w2_weight,
|
||||
topk_weights,
|
||||
topk_ids,
|
||||
inplace=False,
|
||||
activation=activation,
|
||||
global_num_experts=global_num_experts,
|
||||
expert_map=expert_map,
|
||||
apply_router_weight_on_input=apply_router_weight_on_input,
|
||||
)
|
||||
else:
|
||||
return flashinfer_cutlass_moe_fp8(
|
||||
x,
|
||||
layer,
|
||||
topk_weights,
|
||||
topk_ids,
|
||||
inplace=False,
|
||||
activation=activation,
|
||||
global_num_experts=global_num_experts,
|
||||
expert_map=expert_map,
|
||||
apply_router_weight_on_input=apply_router_weight_on_input,
|
||||
)
|
||||
from vllm.model_executor.layers.fused_moe.fused_moe import (
|
||||
fused_experts)
|
||||
return fused_experts(
|
||||
x,
|
||||
layer.w13_weight,
|
||||
layer.w2_weight,
|
||||
topk_weights=topk_weights,
|
||||
topk_ids=topk_ids,
|
||||
inplace=True,
|
||||
activation=activation,
|
||||
use_fp8_w8a8=True,
|
||||
per_channel_quant=False,
|
||||
global_num_experts=global_num_experts,
|
||||
expert_map=expert_map,
|
||||
w1_scale=layer.w13_weight_scale,
|
||||
w2_scale=layer.w2_weight_scale,
|
||||
a1_scale=layer.w13_input_scale,
|
||||
a2_scale=layer.w2_input_scale,
|
||||
apply_router_weight_on_input=apply_router_weight_on_input,
|
||||
)
|
||||
return flashinfer_cutlass_moe_fp8(
|
||||
x,
|
||||
layer,
|
||||
topk_weights,
|
||||
topk_ids,
|
||||
inplace=False,
|
||||
activation=activation,
|
||||
global_num_experts=global_num_experts,
|
||||
expert_map=expert_map,
|
||||
apply_router_weight_on_input=apply_router_weight_on_input,
|
||||
)
|
||||
else:
|
||||
from vllm.model_executor.layers.fused_moe.fused_moe import (
|
||||
fused_experts)
|
||||
assert self.moe_quant_config is not None
|
||||
|
||||
return fused_experts(
|
||||
x,
|
||||
layer.w13_weight,
|
||||
layer.w2_weight,
|
||||
topk_weights=topk_weights,
|
||||
topk_ids=topk_ids,
|
||||
inplace=True,
|
||||
activation=activation,
|
||||
quant_config=self.moe_quant_config,
|
||||
global_num_experts=global_num_experts,
|
||||
expert_map=expert_map,
|
||||
apply_router_weight_on_input=apply_router_weight_on_input,
|
||||
)
|
||||
|
||||
|
||||
class ModelOptNvFp4Config(QuantizationConfig):
|
||||
@ -1034,33 +1047,30 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
|
||||
" for ModelOptNvFp4FusedMoE.")
|
||||
|
||||
def maybe_make_prepare_finalize(
|
||||
self,
|
||||
moe: FusedMoEConfig,
|
||||
) -> Optional[mk.FusedMoEPrepareAndFinalize]:
|
||||
if (self.allow_flashinfer and self.flashinfer_moe_backend
|
||||
== FlashinferMoeBackend.CUTLASS):
|
||||
self) -> Optional[mk.FusedMoEPrepareAndFinalize]:
|
||||
if (self.use_marlin
|
||||
or (self.allow_flashinfer and self.flashinfer_moe_backend
|
||||
== FlashinferMoeBackend.TENSORRT_LLM)):
|
||||
return None
|
||||
elif (self.allow_flashinfer
|
||||
and self.flashinfer_moe_backend == FlashinferMoeBackend.CUTLASS):
|
||||
# For now, fp4 moe only works with the flashinfer dispatcher.
|
||||
prepare_finalize = (
|
||||
build_flashinfer_fp4_cutlass_moe_prepare_finalize(
|
||||
moe,
|
||||
a1_gscale=self.layer.w13_input_scale_quant,
|
||||
))
|
||||
build_flashinfer_fp4_cutlass_moe_prepare_finalize(self.moe))
|
||||
logger.debug_once("%s", prepare_finalize.__class__.__name__)
|
||||
return prepare_finalize
|
||||
|
||||
return super().maybe_make_prepare_finalize(moe)
|
||||
else:
|
||||
return super().maybe_make_prepare_finalize()
|
||||
|
||||
def select_gemm_impl(
|
||||
self,
|
||||
prepare_finalize: mk.FusedMoEPrepareAndFinalize,
|
||||
moe: FusedMoEConfig,
|
||||
layer: torch.nn.Module,
|
||||
) -> mk.FusedMoEPermuteExpertsUnpermute:
|
||||
assert self.moe_quant_config is not None
|
||||
experts = select_nvfp4_gemm_impl(
|
||||
moe,
|
||||
g1_alphas=self.layer.g1_alphas,
|
||||
g2_alphas=self.layer.g2_alphas,
|
||||
a1_gscale=self.layer.w13_input_scale_quant,
|
||||
a2_gscale=self.layer.w2_input_scale_quant,
|
||||
self.moe,
|
||||
self.moe_quant_config,
|
||||
allow_flashinfer=self.allow_flashinfer,
|
||||
)
|
||||
logger.debug_once("Using %s", experts.__class__.__name__)
|
||||
@ -1360,6 +1370,21 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
|
||||
layer.w2_weight = Parameter(layer.w2_weight.data,
|
||||
requires_grad=False)
|
||||
|
||||
def get_fused_moe_quant_config(
|
||||
self, layer: torch.nn.Module) -> Optional[FusedMoEQuantConfig]:
|
||||
if (self.use_marlin or self.flashinfer_moe_backend
|
||||
== FlashinferMoeBackend.TENSORRT_LLM):
|
||||
return None
|
||||
|
||||
return nvfp4_moe_quant_config(
|
||||
w1_scale=layer.w13_weight_scale,
|
||||
w2_scale=layer.w2_weight_scale,
|
||||
g1_alphas=layer.g1_alphas,
|
||||
g2_alphas=layer.g2_alphas,
|
||||
a1_gscale=layer.w13_input_scale_quant,
|
||||
a2_gscale=layer.w2_input_scale_quant,
|
||||
)
|
||||
|
||||
def apply(
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
@ -1388,12 +1413,14 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
|
||||
"EPLB not supported for `ModelOptNvFp4FusedMoE` yet.")
|
||||
assert activation == "silu", "Only SiLU activation is supported."
|
||||
|
||||
if self.allow_flashinfer and \
|
||||
self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM:
|
||||
if (self.allow_flashinfer and self.flashinfer_moe_backend
|
||||
== FlashinferMoeBackend.TENSORRT_LLM):
|
||||
import flashinfer
|
||||
|
||||
from vllm.model_executor.models.llama4 import Llama4MoE
|
||||
|
||||
assert self.fused_experts is None
|
||||
|
||||
a1_gscale = layer.w13_input_scale_quant
|
||||
(hidden_states_fp4,
|
||||
hidden_states_scale_linear_fp4) = flashinfer.fp4_quantize(
|
||||
@ -1457,7 +1484,13 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
|
||||
e_score_correction_bias=e_score_correction_bias,
|
||||
indices_type=self.topk_indices_dtype)
|
||||
|
||||
#
|
||||
# Note: the order here is important. self.fused_experts can override
|
||||
# flashinfer cutlass, cutlass fp4 or fused_experts but not marlin or
|
||||
# trtllm.
|
||||
#
|
||||
if self.use_marlin:
|
||||
assert self.fused_experts is None
|
||||
return torch.ops.vllm.fused_marlin_moe(
|
||||
x,
|
||||
layer.w13_weight,
|
||||
@ -1477,7 +1510,7 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
|
||||
expert_map=expert_map,
|
||||
workspace=layer.workspace)
|
||||
|
||||
if self.fused_experts is not None:
|
||||
elif self.fused_experts is not None:
|
||||
assert self.allow_flashinfer and \
|
||||
self.flashinfer_moe_backend == FlashinferMoeBackend.CUTLASS
|
||||
|
||||
@ -1485,7 +1518,7 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
|
||||
x, layer.w13_weight, layer.w2_weight), (
|
||||
"Flashinfer CUTLASS Fused MoE not applicable!")
|
||||
|
||||
out = self.fused_experts(
|
||||
return self.fused_experts(
|
||||
hidden_states=x,
|
||||
w1=layer.w13_weight,
|
||||
w2=layer.w2_weight,
|
||||
@ -1495,28 +1528,22 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
|
||||
activation=activation,
|
||||
global_num_experts=global_num_experts,
|
||||
expert_map=expert_map,
|
||||
w1_scale=layer.w13_weight_scale,
|
||||
w2_scale=layer.w2_weight_scale,
|
||||
apply_router_weight_on_input=apply_router_weight_on_input,
|
||||
)
|
||||
elif (self.allow_flashinfer
|
||||
and self.flashinfer_moe_backend == FlashinferMoeBackend.CUTLASS):
|
||||
from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe import ( # noqa: E501
|
||||
flashinfer_cutlass_moe_fp4)
|
||||
assert self.moe_quant_config is not None
|
||||
|
||||
out = flashinfer_cutlass_moe_fp4(
|
||||
return flashinfer_cutlass_moe_fp4(
|
||||
hidden_states=x,
|
||||
w1=layer.w13_weight,
|
||||
w2=layer.w2_weight,
|
||||
topk_weights=topk_weights,
|
||||
topk_ids=topk_ids,
|
||||
w1_scale=layer.w13_weight_scale,
|
||||
w2_scale=layer.w2_weight_scale,
|
||||
g1_alphas=layer.g1_alphas,
|
||||
g2_alphas=layer.g2_alphas,
|
||||
a1_gscale=layer.w13_input_scale_quant,
|
||||
a2_gscale=layer.w2_input_scale_quant,
|
||||
inplace=False, # TODO(shuw): fix later, now output is high prec
|
||||
quant_config=self.moe_quant_config,
|
||||
inplace=False,
|
||||
activation=activation,
|
||||
global_num_experts=global_num_experts,
|
||||
expert_map=expert_map,
|
||||
@ -1527,23 +1554,19 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
|
||||
# only (no EP).
|
||||
from vllm.model_executor.layers.fused_moe.cutlass_moe import (
|
||||
cutlass_moe_fp4)
|
||||
out = cutlass_moe_fp4(
|
||||
assert self.moe_quant_config is not None
|
||||
return cutlass_moe_fp4(
|
||||
a=x,
|
||||
w1_fp4=layer.w13_weight,
|
||||
w2_fp4=layer.w2_weight,
|
||||
w1_blockscale=layer.w13_weight_scale,
|
||||
w2_blockscale=layer.w2_weight_scale,
|
||||
g1_alphas=layer.g1_alphas,
|
||||
g2_alphas=layer.g2_alphas,
|
||||
a1_gscale=layer.w13_input_scale_quant,
|
||||
a2_gscale=layer.w2_input_scale_quant,
|
||||
topk_weights=topk_weights,
|
||||
topk_ids=topk_ids,
|
||||
quant_config=self.moe_quant_config,
|
||||
expert_map=expert_map,
|
||||
apply_router_weight_on_input=apply_router_weight_on_input,
|
||||
# TODO: derive from arguments
|
||||
m=x.shape[0],
|
||||
n=layer.w2_weight.shape[2] * 2,
|
||||
k=x.shape[1],
|
||||
e=layer.w13_weight.shape[0],
|
||||
expert_map=expert_map,
|
||||
apply_router_weight_on_input=apply_router_weight_on_input)
|
||||
|
||||
return out
|
||||
)
|
||||
|
||||
@ -6,6 +6,9 @@ from typing import Any, Callable, Optional, Union
|
||||
import torch
|
||||
|
||||
from vllm.distributed import get_tensor_model_parallel_rank, get_tp_group
|
||||
from vllm.model_executor.layers.fused_moe.config import (
|
||||
FusedMoEQuantConfig, int4_w4a16_moe_quant_config,
|
||||
int8_w8a16_moe_quant_config)
|
||||
from vllm.model_executor.layers.fused_moe.layer import (
|
||||
FusedMoE, FusedMoEConfig, FusedMoEMethodBase, FusedMoeWeightScaleSupported)
|
||||
from vllm.model_executor.layers.linear import (LinearBase,
|
||||
@ -283,6 +286,22 @@ class MoeWNA16Method(FusedMoEMethodBase):
|
||||
layer.register_parameter(key, param)
|
||||
set_weight_attrs(param, extra_weight_attrs)
|
||||
|
||||
def get_fused_moe_quant_config(
|
||||
self, layer: torch.nn.Module) -> Optional[FusedMoEQuantConfig]:
|
||||
weight_bits = self.quant_config.weight_bits
|
||||
has_zp = self.quant_config.has_zp
|
||||
assert weight_bits == 4 or weight_bits == 8
|
||||
config_builder = (int4_w4a16_moe_quant_config
|
||||
if weight_bits == 4 else int8_w8a16_moe_quant_config)
|
||||
|
||||
return config_builder(
|
||||
w1_scale=layer.w13_scales,
|
||||
w2_scale=layer.w2_scales,
|
||||
w1_zp=layer.w13_qzeros if has_zp else None,
|
||||
w2_zp=layer.w2_qzeros if has_zp else None,
|
||||
block_shape=[0, layer.group_size],
|
||||
)
|
||||
|
||||
def apply(
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
@ -327,9 +346,6 @@ class MoeWNA16Method(FusedMoEMethodBase):
|
||||
e_score_correction_bias=e_score_correction_bias,
|
||||
indices_type=self.topk_indices_dtype)
|
||||
|
||||
weight_bits = self.quant_config.weight_bits
|
||||
has_zp = self.quant_config.has_zp
|
||||
|
||||
return fused_experts(
|
||||
x,
|
||||
layer.w13_qweight,
|
||||
@ -337,16 +353,11 @@ class MoeWNA16Method(FusedMoEMethodBase):
|
||||
topk_weights=topk_weights,
|
||||
topk_ids=topk_ids,
|
||||
inplace=True,
|
||||
use_int4_w4a16=weight_bits == 4,
|
||||
use_int8_w8a16=weight_bits == 8,
|
||||
global_num_experts=global_num_experts,
|
||||
apply_router_weight_on_input=apply_router_weight_on_input,
|
||||
global_num_experts=global_num_experts,
|
||||
expert_map=expert_map,
|
||||
w1_scale=layer.w13_scales,
|
||||
w2_scale=layer.w2_scales,
|
||||
w1_zp=layer.w13_qzeros if has_zp else None,
|
||||
w2_zp=layer.w2_qzeros if has_zp else None,
|
||||
block_shape=[0, layer.group_size])
|
||||
quant_config=self.moe_quant_config,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def get_weight_loader(layer, weight_loader):
|
||||
|
||||
@ -12,6 +12,8 @@ from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.fused_moe import (FusedMoE, FusedMoEConfig,
|
||||
FusedMoEMethodBase)
|
||||
from vllm.model_executor.layers.fused_moe import modular_kernel as mk
|
||||
from vllm.model_executor.layers.fused_moe.config import (
|
||||
FusedMoEQuantConfig, mxfp4_w4a4_moe_quant_config)
|
||||
from vllm.model_executor.layers.fused_moe.trtllm_moe import TrtLlmGenExperts
|
||||
from vllm.model_executor.layers.linear import (LinearBase,
|
||||
UnquantizedLinearMethod)
|
||||
@ -629,10 +631,29 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
|
||||
|
||||
return tile_tokens_dim
|
||||
|
||||
def get_fused_moe_quant_config(
|
||||
self, layer: torch.nn.Module) -> Optional[FusedMoEQuantConfig]:
|
||||
|
||||
if self.mxfp4_backend == Mxfp4Backend.MARLIN:
|
||||
return None
|
||||
|
||||
if self.mxfp4_backend == Mxfp4Backend.TRITON:
|
||||
w1_scale = layer.w13_precision_config
|
||||
w2_scale = layer.w2_precision_config
|
||||
else:
|
||||
w1_scale = layer.w13_weight_scale
|
||||
w2_scale = layer.w2_weight_scale
|
||||
|
||||
return mxfp4_w4a4_moe_quant_config(
|
||||
w1_bias=layer.w13_bias,
|
||||
w2_bias=layer.w2_bias,
|
||||
w1_scale=w1_scale,
|
||||
w2_scale=w2_scale,
|
||||
)
|
||||
|
||||
def select_gemm_impl(
|
||||
self,
|
||||
prepare_finalize: mk.FusedMoEPrepareAndFinalize,
|
||||
moe: FusedMoEConfig,
|
||||
layer: torch.nn.Module,
|
||||
) -> mk.FusedMoEPermuteExpertsUnpermute:
|
||||
if (prepare_finalize.activation_format ==
|
||||
@ -647,11 +668,12 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
|
||||
"gemm1_alpha": layer.gemm1_alpha,
|
||||
"gemm1_beta": layer.gemm1_beta,
|
||||
"gemm1_clamp_limit": layer.gemm1_clamp_limit,
|
||||
"w13_bias": layer.w13_bias,
|
||||
"w2_bias": layer.w2_bias,
|
||||
# TODO(bnell): part of quant_config
|
||||
"max_capture_size": self.max_capture_size,
|
||||
}
|
||||
return TrtLlmGenExperts(moe, **kwargs)
|
||||
assert self.moe_quant_config is not None
|
||||
return TrtLlmGenExperts(self.moe, self.moe_quant_config,
|
||||
**kwargs)
|
||||
else:
|
||||
# Use matmul_ogs from triton_kernels here!
|
||||
raise NotImplementedError(
|
||||
@ -710,8 +732,6 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
|
||||
activation=activation,
|
||||
global_num_experts=global_num_experts,
|
||||
expert_map=expert_map,
|
||||
w1_scale=layer.w13_weight_scale,
|
||||
w2_scale=layer.w2_weight_scale,
|
||||
apply_router_weight_on_input=apply_router_weight_on_input,
|
||||
)
|
||||
|
||||
@ -941,10 +961,7 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
|
||||
renormalize=renormalize,
|
||||
global_num_experts=global_num_experts,
|
||||
expert_map=expert_map,
|
||||
w1_bias=layer.w13_bias,
|
||||
w2_bias=layer.w2_bias,
|
||||
w1_precision=self.w13_precision_config,
|
||||
w2_precision=self.w2_precision_config,
|
||||
quant_config=self.moe_quant_config,
|
||||
apply_router_weight_on_input=apply_router_weight_on_input,
|
||||
)
|
||||
else:
|
||||
|
||||
@ -11,6 +11,9 @@ from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.fused_moe import (FusedMoE, FusedMoEConfig,
|
||||
FusedMoEMethodBase,
|
||||
FusedMoeWeightScaleSupported)
|
||||
from vllm.model_executor.layers.fused_moe.config import (
|
||||
FusedMoEQuantConfig, fp8_w8a8_moe_quant_config,
|
||||
mxfp4_w4a4_moe_quant_config)
|
||||
from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (
|
||||
is_rocm_aiter_moe_enabled)
|
||||
from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import (
|
||||
@ -287,6 +290,16 @@ class QuarkW8A8Fp8MoEMethod(QuarkMoEMethod):
|
||||
from vllm.model_executor.layers.fused_moe import fused_experts
|
||||
self.fused_experts_func = fused_experts
|
||||
|
||||
def get_fused_moe_quant_config(
|
||||
self, layer: torch.nn.Module) -> Optional[FusedMoEQuantConfig]:
|
||||
return fp8_w8a8_moe_quant_config(
|
||||
w1_scale=layer.w13_weight_scale,
|
||||
w2_scale=layer.w2_weight_scale,
|
||||
a1_scale=layer.w13_input_scale,
|
||||
a2_scale=layer.w2_input_scale,
|
||||
per_act_token_quant=self.weight_qscheme == "per_channel",
|
||||
)
|
||||
|
||||
def apply(
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
@ -339,12 +352,7 @@ class QuarkW8A8Fp8MoEMethod(QuarkMoEMethod):
|
||||
topk_ids=topk_ids,
|
||||
activation=activation,
|
||||
apply_router_weight_on_input=apply_router_weight_on_input,
|
||||
use_fp8_w8a8=True,
|
||||
per_channel_quant=self.weight_qscheme == "per_channel",
|
||||
w1_scale=layer.w13_weight_scale,
|
||||
w2_scale=layer.w2_weight_scale,
|
||||
a1_scale=layer.w13_input_scale,
|
||||
a2_scale=layer.w2_input_scale,
|
||||
quant_config=self.moe_quant_config,
|
||||
expert_map=expert_map)
|
||||
if self.use_marlin:
|
||||
assert activation == "silu", (
|
||||
@ -376,14 +384,9 @@ class QuarkW8A8Fp8MoEMethod(QuarkMoEMethod):
|
||||
inplace=True,
|
||||
activation=activation,
|
||||
apply_router_weight_on_input=apply_router_weight_on_input,
|
||||
use_fp8_w8a8=True,
|
||||
per_channel_quant=self.weight_qscheme == "per_channel",
|
||||
global_num_experts=global_num_experts,
|
||||
expert_map=expert_map,
|
||||
w1_scale=layer.w13_weight_scale,
|
||||
w2_scale=layer.w2_weight_scale,
|
||||
a1_scale=layer.w13_input_scale,
|
||||
a2_scale=layer.w2_input_scale)
|
||||
quant_config=self.moe_quant_config)
|
||||
|
||||
|
||||
class QuarkW4A4MXFp4MoEMethod(QuarkMoEMethod):
|
||||
@ -487,6 +490,16 @@ class QuarkW4A4MXFp4MoEMethod(QuarkMoEMethod):
|
||||
layer.register_parameter("w13_weight_scale", w13_weight_scale)
|
||||
layer.register_parameter("w2_weight_scale", w2_weight_scale)
|
||||
|
||||
def get_fused_moe_quant_config(
|
||||
self, layer: torch.nn.Module) -> Optional[FusedMoEQuantConfig]:
|
||||
return mxfp4_w4a4_moe_quant_config(
|
||||
w1_scale=layer.w13_weight_scale,
|
||||
w2_scale=layer.w2_weight_scale,
|
||||
a1_scale=None,
|
||||
a2_scale=None,
|
||||
block_shape=None,
|
||||
)
|
||||
|
||||
def apply(
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
@ -539,15 +552,10 @@ class QuarkW4A4MXFp4MoEMethod(QuarkMoEMethod):
|
||||
topk_weights=topk_weights,
|
||||
topk_ids=topk_ids,
|
||||
inplace=True,
|
||||
use_mxfp4_w4a4=True,
|
||||
activation=activation,
|
||||
global_num_experts=global_num_experts,
|
||||
apply_router_weight_on_input=apply_router_weight_on_input,
|
||||
expert_map=expert_map,
|
||||
w1_scale=layer.w13_weight_scale,
|
||||
w2_scale=layer.w2_weight_scale,
|
||||
a1_scale=None,
|
||||
a2_scale=None,
|
||||
block_shape=None,
|
||||
activation=activation,
|
||||
quant_config=self.moe_quant_config,
|
||||
)
|
||||
return out
|
||||
|
||||
@ -12,6 +12,9 @@ from torch.nn.parameter import Parameter
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.fused_moe import (FusedMoE, FusedMoEConfig,
|
||||
FusedMoEMethodBase)
|
||||
from vllm.model_executor.layers.fused_moe.config import (
|
||||
FusedMoEQuantConfig, int4_w4a16_moe_quant_config,
|
||||
int8_w8a16_moe_quant_config)
|
||||
from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase,
|
||||
set_weight_attrs)
|
||||
from vllm.model_executor.layers.quantization import QuantizationMethods
|
||||
@ -269,6 +272,21 @@ class RTNMoEMethod(FusedMoEMethodBase):
|
||||
fix_weights(layer, "w13_weight", weight_bits == 4)
|
||||
fix_weights(layer, "w2_weight", weight_bits == 4)
|
||||
|
||||
def get_fused_moe_quant_config(
|
||||
self, layer: torch.nn.Module) -> Optional[FusedMoEQuantConfig]:
|
||||
weight_bits = self.quant_config.weight_bits
|
||||
group_size = self.quant_config.group_size
|
||||
assert weight_bits == 4 or weight_bits == 8
|
||||
config_builder = (int4_w4a16_moe_quant_config
|
||||
if weight_bits == 4 else int8_w8a16_moe_quant_config)
|
||||
return config_builder(
|
||||
w1_scale=layer.w13_scale,
|
||||
w2_scale=layer.w2_scale,
|
||||
w1_zp=None,
|
||||
w2_zp=None,
|
||||
block_shape=[0, group_size],
|
||||
)
|
||||
|
||||
def apply(
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
@ -314,10 +332,7 @@ class RTNMoEMethod(FusedMoEMethodBase):
|
||||
e_score_correction_bias=e_score_correction_bias,
|
||||
indices_type=self.topk_indices_dtype)
|
||||
|
||||
weight_bits = self.quant_config.weight_bits
|
||||
group_size = self.quant_config.group_size
|
||||
|
||||
ret = fused_experts(
|
||||
return fused_experts(
|
||||
x,
|
||||
layer.w13_weight,
|
||||
layer.w2_weight,
|
||||
@ -325,16 +340,11 @@ class RTNMoEMethod(FusedMoEMethodBase):
|
||||
topk_ids=topk_ids,
|
||||
inplace=True,
|
||||
activation=activation,
|
||||
use_int4_w4a16=weight_bits == 4,
|
||||
use_int8_w8a16=weight_bits == 8,
|
||||
global_num_experts=global_num_experts,
|
||||
w1_scale=layer.w13_scale,
|
||||
w2_scale=layer.w2_scale,
|
||||
apply_router_weight_on_input=apply_router_weight_on_input,
|
||||
global_num_experts=global_num_experts,
|
||||
expert_map=expert_map,
|
||||
block_shape=[0, group_size])
|
||||
|
||||
return ret
|
||||
quant_config=self.moe_quant_config,
|
||||
)
|
||||
|
||||
|
||||
def rtn_quantize(tensor: torch.Tensor, num_bits: int,
|
||||
|
||||
@ -7,7 +7,8 @@ import torch
|
||||
|
||||
import vllm.envs as envs
|
||||
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
|
||||
from vllm.model_executor.layers.fused_moe.config import FusedMoEConfig
|
||||
from vllm.model_executor.layers.fused_moe.config import (FusedMoEConfig,
|
||||
FusedMoEQuantConfig)
|
||||
from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe import (
|
||||
FlashInferExperts)
|
||||
from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_prepare_finalize import ( # noqa: E501
|
||||
@ -47,32 +48,23 @@ def reorder_w1w3_to_w3w1(weight: torch.Tensor,
|
||||
|
||||
|
||||
def build_flashinfer_fp4_cutlass_moe_prepare_finalize(
|
||||
moe: FusedMoEConfig,
|
||||
a1_gscale: torch.Tensor,
|
||||
) -> mk.FusedMoEPrepareAndFinalize:
|
||||
moe: FusedMoEConfig) -> mk.FusedMoEPrepareAndFinalize:
|
||||
"""Create a FlashInfer CUTLASS fused-MoE prepare finalize kernel"""
|
||||
use_dp = moe.moe_parallel_config.dp_size > 1
|
||||
return FlashInferCutlassMoEPrepareAndFinalize(use_dp, a1_gscale=a1_gscale)
|
||||
return FlashInferCutlassMoEPrepareAndFinalize(use_dp)
|
||||
|
||||
|
||||
def select_nvfp4_gemm_impl(
|
||||
moe: FusedMoEConfig,
|
||||
g1_alphas: torch.Tensor,
|
||||
g2_alphas: torch.Tensor,
|
||||
a1_gscale: torch.Tensor,
|
||||
a2_gscale: torch.Tensor,
|
||||
moe_quant_config: FusedMoEQuantConfig,
|
||||
allow_flashinfer: bool,
|
||||
) -> mk.FusedMoEPermuteExpertsUnpermute:
|
||||
"""Return a GEMM *experts* implementation for NV-FP4 fused-MoE layers"""
|
||||
|
||||
if allow_flashinfer:
|
||||
return FlashInferExperts(
|
||||
g1_alphas=g1_alphas,
|
||||
g2_alphas=g2_alphas,
|
||||
a1_gscale=a1_gscale,
|
||||
a2_gscale=a2_gscale,
|
||||
out_dtype=moe.in_dtype,
|
||||
quant_dtype="nvfp4",
|
||||
quant_config=moe_quant_config,
|
||||
ep_rank=moe.moe_parallel_config.ep_rank,
|
||||
ep_size=moe.moe_parallel_config.ep_size,
|
||||
tp_rank=moe.moe_parallel_config.tp_rank,
|
||||
|
||||
@ -8,7 +8,8 @@ import torch
|
||||
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
|
||||
from vllm import envs
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.fused_moe.config import FusedMoEConfig
|
||||
from vllm.model_executor.layers.fused_moe.config import (FusedMoEConfig,
|
||||
FusedMoEQuantConfig)
|
||||
from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe import (
|
||||
FlashInferExperts)
|
||||
from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_prepare_finalize import ( # noqa: E501
|
||||
@ -99,6 +100,8 @@ def apply_flashinfer_per_tensor_scale_fp8(
|
||||
apply_router_weight_on_input: bool,
|
||||
) -> torch.Tensor:
|
||||
from flashinfer.fused_moe import RoutingMethodType
|
||||
|
||||
import vllm.model_executor.layers.fused_moe.flashinfer_trtllm_moe # noqa: E501, F401
|
||||
assert layer.output1_scales_scalar is not None, (
|
||||
"Expected output1_scales_scalar to be initialized")
|
||||
assert layer.output1_scales_scalar is not None, (
|
||||
@ -167,34 +170,23 @@ def register_moe_scaling_factors(layer: torch.nn.Module) -> None:
|
||||
|
||||
|
||||
def build_flashinfer_fp8_cutlass_moe_prepare_finalize(
|
||||
moe: Optional[FusedMoEConfig],
|
||||
layer: torch.nn.Module,
|
||||
) -> mk.FusedMoEPrepareAndFinalize:
|
||||
moe: Optional[FusedMoEConfig], ) -> mk.FusedMoEPrepareAndFinalize:
|
||||
"""Create a FlashInfer CUTLASS fused-MoE prepare finalize kernel"""
|
||||
use_dp = moe.moe_parallel_config.dp_size > 1 if moe is not None else False
|
||||
return FlashInferCutlassMoEPrepareAndFinalize(
|
||||
use_dp, a1_gscale=layer.w13_input_scale)
|
||||
return FlashInferCutlassMoEPrepareAndFinalize(use_dp)
|
||||
|
||||
|
||||
def select_cutlass_fp8_gemm_impl(
|
||||
moe: Optional[FusedMoEConfig],
|
||||
layer: torch.nn.Module,
|
||||
quant_config: FusedMoEQuantConfig,
|
||||
out_dtype: Optional[torch.dtype] = None,
|
||||
) -> mk.FusedMoEPermuteExpertsUnpermute:
|
||||
"""Return a GEMM *experts* implementation for fused-MoE layers"""
|
||||
|
||||
from vllm.model_executor.models.llama4 import Llama4MoE
|
||||
assert layer.custom_routing_function == Llama4MoE.custom_routing_function, \
|
||||
"FusedMoE flashinfer kernels are only supported for Llama4"
|
||||
|
||||
if moe is not None:
|
||||
return FlashInferExperts(
|
||||
g1_alphas=layer.output1_scales_gate_scalar,
|
||||
g2_alphas=layer.output2_scales_scalar,
|
||||
a1_gscale=layer.w13_input_scale,
|
||||
a2_gscale=layer.w2_input_scale_inv,
|
||||
out_dtype=moe.in_dtype,
|
||||
quant_dtype=torch.float8_e4m3fn,
|
||||
quant_config=quant_config,
|
||||
ep_rank=moe.moe_parallel_config.ep_rank,
|
||||
ep_size=moe.moe_parallel_config.ep_size,
|
||||
tp_rank=moe.moe_parallel_config.tp_rank,
|
||||
@ -204,12 +196,8 @@ def select_cutlass_fp8_gemm_impl(
|
||||
assert out_dtype is not None, (
|
||||
"If moe config is None, out_dtype must be passed")
|
||||
return FlashInferExperts(
|
||||
g1_alphas=layer.output1_scales_gate_scalar,
|
||||
g2_alphas=layer.output2_scales_scalar,
|
||||
a1_gscale=layer.w13_input_scale,
|
||||
a2_gscale=layer.w2_input_scale_inv,
|
||||
out_dtype=out_dtype,
|
||||
quant_dtype=torch.float8_e4m3fn,
|
||||
quant_config=quant_config,
|
||||
)
|
||||
|
||||
|
||||
@ -224,11 +212,13 @@ def flashinfer_cutlass_moe_fp8(
|
||||
expert_map: Optional[torch.Tensor] = None,
|
||||
apply_router_weight_on_input: bool = False,
|
||||
) -> torch.Tensor:
|
||||
quant_config = layer.quant_method.get_fused_moe_quant_config(layer)
|
||||
assert quant_config is not None
|
||||
|
||||
fused_experts = mk.FusedMoEModularKernel(
|
||||
build_flashinfer_fp8_cutlass_moe_prepare_finalize(moe=None,
|
||||
layer=layer),
|
||||
build_flashinfer_fp8_cutlass_moe_prepare_finalize(moe=None),
|
||||
select_cutlass_fp8_gemm_impl(moe=None,
|
||||
layer=layer,
|
||||
quant_config=quant_config,
|
||||
out_dtype=hidden_states.dtype))
|
||||
|
||||
return fused_experts(
|
||||
|
||||
@ -411,6 +411,7 @@ def per_token_group_quant_fp8(
|
||||
x_s = torch.empty(shape, device=x.device, dtype=torch.float32)
|
||||
|
||||
# prefer CUDA kernel if available
|
||||
# TODO(bnell): this causes some fp8 moe test to fail.
|
||||
if current_platform.is_cuda() and x.is_contiguous():
|
||||
torch.ops._C.per_token_group_fp8_quant(x, x_q, x_s, group_size, eps,
|
||||
fp8_min, fp8_max, use_ue8m0)
|
||||
|
||||
@ -15,8 +15,8 @@ from vllm.distributed import (divide, get_tensor_model_parallel_rank,
|
||||
tensor_model_parallel_all_reduce)
|
||||
from vllm.model_executor.layers.activation import (get_act_and_mul_fn,
|
||||
get_act_fn)
|
||||
from vllm.model_executor.layers.fused_moe.fused_moe import (
|
||||
fused_topk, torch_vllm_outplace_fused_experts)
|
||||
from vllm.model_executor.layers.fused_moe import (activation_without_mul,
|
||||
fused_topk)
|
||||
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
|
||||
MergedColumnParallelLinear,
|
||||
QKVParallelLinear,
|
||||
@ -230,7 +230,7 @@ class NomicMoE(nn.Module):
|
||||
self.hidden_size = hidden_size
|
||||
self.total_intermediate_size = intermediate_size
|
||||
self.intermediate_size = divide(intermediate_size, self.tp_size)
|
||||
self.hidden_act = hidden_act
|
||||
self.hidden_act = activation_without_mul(hidden_act)
|
||||
|
||||
if params_dtype is None:
|
||||
params_dtype = torch.get_default_dtype()
|
||||
@ -297,14 +297,14 @@ class NomicMoE(nn.Module):
|
||||
router_logits,
|
||||
self.top_k,
|
||||
renormalize=False)
|
||||
final_hidden_states = torch_vllm_outplace_fused_experts(
|
||||
|
||||
final_hidden_states = torch.ops.vllm.outplace_fused_experts(
|
||||
hidden_states=hidden_states,
|
||||
w1=self.w1,
|
||||
w2=self.w2,
|
||||
topk_weights=topk_weights,
|
||||
topk_ids=topk_ids,
|
||||
activation=self.hidden_act,
|
||||
is_act_and_mul=False,
|
||||
)
|
||||
|
||||
if self.tp_size > 1:
|
||||
|
||||
@ -37,7 +37,7 @@ from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank,
|
||||
get_tensor_model_parallel_world_size,
|
||||
tensor_model_parallel_all_reduce)
|
||||
from vllm.model_executor.layers.activation import SiluAndMul
|
||||
from vllm.model_executor.layers.fused_moe import fused_moe
|
||||
from vllm.model_executor.layers.fused_moe import fused_experts, fused_topk
|
||||
from vllm.model_executor.layers.layernorm import RMSNorm
|
||||
from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
|
||||
QKVParallelLinear,
|
||||
@ -163,13 +163,19 @@ class DeepseekMoE(nn.Module):
|
||||
shared_output = self.shared_experts(hidden_states)
|
||||
# router_logits: (num_tokens, n_experts)
|
||||
router_logits, _ = self.gate(hidden_states)
|
||||
final_hidden_states = fused_moe(hidden_states,
|
||||
self.w1,
|
||||
self.w2,
|
||||
router_logits,
|
||||
self.top_k,
|
||||
renormalize=self.config.norm_topk_prob,
|
||||
inplace=True)
|
||||
|
||||
topk_weights, topk_ids, _ = fused_topk(
|
||||
hidden_states,
|
||||
router_logits,
|
||||
self.top_k,
|
||||
renormalize=self.config.norm_topk_prob)
|
||||
|
||||
final_hidden_states = fused_experts(hidden_states,
|
||||
self.w1,
|
||||
self.w2,
|
||||
topk_weights,
|
||||
topk_ids,
|
||||
inplace=True)
|
||||
|
||||
if self.config.n_shared_experts is not None:
|
||||
final_hidden_states = final_hidden_states + shared_output
|
||||
|
||||
@ -39,7 +39,7 @@ from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank,
|
||||
get_tensor_model_parallel_world_size,
|
||||
tensor_model_parallel_all_reduce)
|
||||
from vllm.model_executor.layers.activation import FatreluAndMul, SiluAndMul
|
||||
from vllm.model_executor.layers.fused_moe import fused_moe
|
||||
from vllm.model_executor.layers.fused_moe import fused_experts, fused_topk
|
||||
from vllm.model_executor.layers.layernorm import RMSNorm
|
||||
from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
|
||||
QKVParallelLinear,
|
||||
@ -136,13 +136,18 @@ class MiniCPMMoE(nn.Module):
|
||||
hidden_states = hidden_states.view(-1, self.hidden_size)
|
||||
# router_logits: (num_tokens, n_experts)
|
||||
router_logits, _ = self.gate(hidden_states)
|
||||
final_hidden_states = fused_moe(hidden_states,
|
||||
self.ws,
|
||||
self.w2s,
|
||||
router_logits,
|
||||
self.top_k,
|
||||
renormalize=True,
|
||||
inplace=True)
|
||||
|
||||
topk_weights, topk_ids, _ = fused_topk(hidden_states,
|
||||
router_logits,
|
||||
self.top_k,
|
||||
renormalize=True)
|
||||
|
||||
final_hidden_states = fused_experts(hidden_states,
|
||||
self.ws,
|
||||
self.w2s,
|
||||
topk_weights,
|
||||
topk_ids,
|
||||
inplace=True)
|
||||
|
||||
if self.tp_size > 1:
|
||||
final_hidden_states = tensor_model_parallel_all_reduce(
|
||||
|
||||
@ -702,4 +702,4 @@ class Qwen3MoeForCausalLM(nn.Module, SupportsPP, SupportsLoRA,
|
||||
return loader.load_weights(weights)
|
||||
|
||||
def get_expert_mapping(self) -> list[tuple[str, str, int, str]]:
|
||||
return self.model.get_expert_mapping()
|
||||
return self.model.get_expert_mapping()
|
||||
@ -81,9 +81,14 @@ def _fp8_linear_may_use_deep_gemm(module: torch.nn.Module) -> bool:
|
||||
|
||||
|
||||
def _fused_moe_grouped_gemm_may_use_deep_gemm(module: torch.nn.Module) -> bool:
|
||||
if not (isinstance(module, FusedMoE)
|
||||
and module.moe_config.quant_dtype == torch.float8_e4m3fn
|
||||
and module.moe_config.block_shape == deep_gemm_block_shape()):
|
||||
if not isinstance(module, FusedMoE):
|
||||
return False
|
||||
|
||||
moe_quant_config = module.quant_method.get_fused_moe_quant_config(module)
|
||||
|
||||
if (moe_quant_config is None
|
||||
or moe_quant_config.quant_dtype != torch.float8_e4m3fn
|
||||
or moe_quant_config.block_shape != deep_gemm_block_shape()):
|
||||
return False
|
||||
|
||||
if not isinstance(module.quant_method.fused_experts,
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user