[Kernels] Clean up FusedMoeMethodBase and modular kernel setup. Remove extra arguments from modular kernel methods. (#22035)

Signed-off-by: Bill Nell <bnell@redhat.com>
Co-authored-by: Michael Goin <mgoin64@gmail.com>
This commit is contained in:
bnellnm 2025-08-15 14:46:00 -04:00 committed by GitHub
parent 48b01fd4d4
commit 8ad7285ea2
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
54 changed files with 2010 additions and 1293 deletions

View File

@ -399,6 +399,7 @@ steps:
- label: Kernels MoE Test %N
mirror_hardwares: [amdexperimental]
source_file_dependencies:
- csrc/quantization/cutlass_w8a8/moe/
- csrc/moe/
- tests/kernels/moe
- vllm/model_executor/layers/fused_moe/

View File

@ -175,11 +175,19 @@ implementations that input `FusedMoEActivationFormat.Standard` support chunking
### FusedMoEModularKernel Initialization
`FusedMoEMethodBase` class has 2 methods that are collectively responsible in creating the `FusedMoEModularKernel` object. They are,
`FusedMoEMethodBase` class has 3 methods that are collectively responsible in creating the `FusedMoEModularKernel` object. They are,
* maybe_make_prepare_finalize,
* select_gemm_impl, and
* init_prepare_finalize
#### maybe_make_prepare_finalize
The `maybe_make_prepare_finalize` method is responsbile for constructing an instance of `FusedMoEPrepareAndFinalize` when appropriate based on the current all2all backend, e.g. when EP + DP is enabled. The base class method currently constructs all the `FusedMoEPrepareAndFinalize` objects for the EP+DP case. Derived classes can override this method to construct prepare/finalize objects for different scenarios, e.g. `ModelOptNvFp4FusedMoE` can construct a `FlashInferCutlassMoEPrepareAndFinalize` for the EP+TP case.
Please refer to the implementations in,
* `ModelOptNvFp4FusedMoE`
#### select_gemm_impl
The `select_gemm_impl` method is undefined in the base class. It is the responsibility of the derived class to implement a method that constructs a valid/appropriate `FusedMoEPermuteExpertsUnpermute` object.

View File

@ -70,12 +70,27 @@ def parse_args():
default=64,
help=("Maximum number of sequences to be processed in a single iteration."),
)
parser.add_argument(
"--max-model-len",
type=int,
help=("Maximum number of tokens to be processed in a single iteration."),
)
parser.add_argument(
"--timeout",
type=int,
default=300,
help=("Number of seconds before unresponsive process is killed."),
)
parser.add_argument(
"--gpu-memory-utilization",
type=float,
default=0.8,
help=("Fraction of GPU memory vLLM is allowed to allocate (0.0, 1.0]."),
)
parser.add_argument(
"--quantization",
type=str,
)
return parser.parse_args()
@ -90,7 +105,9 @@ def main(
enforce_eager,
trust_remote_code,
max_num_seqs,
max_model_len,
gpu_memory_utilization,
quantization,
):
os.environ["VLLM_DP_RANK"] = str(global_dp_rank)
os.environ["VLLM_DP_RANK_LOCAL"] = str(local_dp_rank)
@ -142,7 +159,9 @@ def main(
enable_expert_parallel=True,
trust_remote_code=trust_remote_code,
max_num_seqs=max_num_seqs,
max_model_len=max_model_len,
gpu_memory_utilization=gpu_memory_utilization,
quantization=quantization,
)
outputs = llm.generate(prompts, sampling_params)
# Print the outputs.
@ -198,14 +217,16 @@ if __name__ == "__main__":
args.enforce_eager,
args.trust_remote_code,
args.max_num_seqs,
args.max_model_len,
args.gpu_memory_utilization,
args.quantization,
),
)
proc.start()
procs.append(proc)
exit_code = 0
for proc in procs:
proc.join(timeout=300)
proc.join(timeout=args.timeout)
if proc.exitcode is None:
print(f"Killing process {proc.pid} that didn't stop within 5 minutes.")
proc.kill()

View File

@ -7,41 +7,22 @@ import torch
import vllm._custom_ops as ops
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from tests.kernels.moe.utils import make_test_weights, per_token_cast_to_fp8
from tests.kernels.quantization.nvfp4_utils import (FLOAT4_E2M1_MAX,
FLOAT8_E4M3_MAX,
dequantize_nvfp4_to_dtype)
from tests.kernels.utils import torch_experts
from vllm.config import VllmConfig
from vllm.distributed import get_dp_group, get_tensor_model_parallel_world_size
# Fused experts and PrepareFinalize imports
from vllm.model_executor.layers.fused_moe.batched_deep_gemm_moe import (
BatchedDeepGemmExperts)
from vllm.model_executor.layers.fused_moe.batched_triton_or_deep_gemm_moe import ( # noqa: E501
BatchedTritonOrDeepGemmExperts)
from vllm.forward_context import set_forward_context
from vllm.model_executor.layers.fused_moe.config import (
FusedMoEConfig, FusedMoEParallelConfig, FusedMoEQuantConfig)
from vllm.model_executor.layers.fused_moe.cutlass_moe import CutlassExpertsFp8
from vllm.model_executor.layers.fused_moe.deep_gemm_moe import DeepGemmExperts
from vllm.model_executor.layers.fused_moe.fused_batched_moe import (
BatchedTritonExperts, NaiveBatchedExperts)
from vllm.model_executor.layers.fused_moe.fused_moe import fused_topk
from vllm.model_executor.layers.fused_moe.layer import (FusedMoEMethodBase,
TritonExperts)
from vllm.model_executor.layers.fused_moe.prepare_finalize import (
MoEPrepareAndFinalizeNoEP)
from vllm.model_executor.layers.fused_moe.triton_deep_gemm_moe import (
TritonOrDeepGemmExperts)
from vllm.utils import has_deep_ep, has_deep_gemm, has_pplx
from .mk_objects import (expert_info, make_fused_experts,
make_prepare_finalize, prepare_finalize_info)
from .parallel_utils import ProcessGroupInfo
from .utils import (make_block_quant_fp8_weights, make_non_quant_weights,
make_quant_fp8_weights, per_token_cast_to_fp8)
if has_pplx():
from vllm.model_executor.layers.fused_moe.pplx_prepare_finalize import (
PplxPrepareAndFinalize)
if has_deep_ep():
from vllm.model_executor.layers.fused_moe.deepep_ht_prepare_finalize import ( # noqa: E501
DeepEPHTPrepareAndFinalize)
from vllm.model_executor.layers.fused_moe.deepep_ll_prepare_finalize import ( # noqa: E501
DeepEPLLPrepareAndFinalize)
def _describe_tensor(t: Optional[torch.Tensor], name: str) -> str:
@ -69,17 +50,24 @@ class Config:
torch_trace_dir_path: Optional[str] = None
def __post_init__(self):
if self.quant_config is None:
self.quant_config = FusedMoEQuantConfig()
def describe(self) -> str:
s = ""
s += "== Config:\n"
s += f" world_size={self.world_size}\n"
s += f" PF={self.prepare_finalize_type.__name__}\n"
s += f" FE={self.fused_experts_type.__name__}\n"
s += f" E={self.E}\n"
s += f" Ms={self.Ms}\n"
s += f" N={self.N}\n"
s += f" K={self.K}\n"
s += f" topk={self.topks}\n"
s += f" dtype={self.dtype}\n"
s += f" fused_moe_chunk_size={self.fused_moe_chunk_size}\n"
s += " Quant:\n"
s += f" fused_moe_chunk_size={self.fused_moe_chunk_size} \n "
if self.quant_config is not None:
s += f" q_dtype={self.quant_dtype}\n"
s += f" q_block_shape={self.quant_block_shape}\n"
@ -95,34 +83,28 @@ class Config:
return self.Ms
@property
def quant_dtype(self) -> Optional[torch.dtype]:
if self.quant_config is None:
return None
def quant_dtype(self) -> Union[torch.dtype, str, None]:
assert self.quant_config is not None
return self.quant_config.quant_dtype
@property
def is_per_act_token_quant(self) -> bool:
if self.quant_config is None:
return False
assert self.quant_config is not None
return self.quant_config.per_act_token_quant
@property
def is_per_tensor_act_quant(self) -> bool:
if self.quant_config is None:
return False
return (not self.is_per_act_token_quant
and self.quant_block_shape is None)
@property
def is_per_out_ch_quant(self) -> bool:
if self.quant_config is None:
return False
assert self.quant_config is not None
return self.quant_config.per_out_ch_quant
@property
def quant_block_shape(self) -> Optional[list[int]]:
if self.quant_config is None:
return None
assert self.quant_config is not None
return self.quant_config.block_shape
@property
@ -130,17 +112,6 @@ class Config:
assert isinstance(self.topks, int)
return self.topks
@property
def topk_ids_dtype(self) -> Optional[torch.dtype]:
topk_ids_dtype = None
if self.prepare_finalize_type == PplxPrepareAndFinalize:
topk_ids_dtype = torch.uint32
elif self.prepare_finalize_type in [
DeepEPHTPrepareAndFinalize, DeepEPLLPrepareAndFinalize
]:
topk_ids_dtype = torch.int64
return topk_ids_dtype
@property
def num_local_experts(self) -> int:
return self.E // self.world_size
@ -154,12 +125,17 @@ class Config:
vllm_config.parallel_config.enable_expert_parallel = True
env_dict = {
"VLLM_ALL2ALL_BACKEND": self.all2all_backend(),
"VLLM_USE_DEEP_GEMM": str(int(self.needs_deep_gemm())),
}
backend = self.all2all_backend()
if backend is not None:
env_dict.update({"VLLM_ALL2ALL_BACKEND": backend})
if self.fused_moe_chunk_size is not None:
env_dict.update(
{"VLLM_FUSED_MOE_CHUNK_SIZE": str(self.fused_moe_chunk_size)})
return vllm_config, env_dict
def is_fp8_block_quantized(self):
@ -167,85 +143,59 @@ class Config:
and self.quant_block_shape is not None)
def is_batched_prepare_finalize(self):
return self.prepare_finalize_type in [
PplxPrepareAndFinalize, DeepEPLLPrepareAndFinalize
]
info = prepare_finalize_info(self.prepare_finalize_type)
return (mk.FusedMoEActivationFormat.BatchedExperts ==
info.activation_format)
def is_batched_fused_experts(self):
return self.fused_experts_type in [
CutlassExpertsFp8, BatchedDeepGemmExperts, BatchedTritonExperts,
NaiveBatchedExperts, BatchedTritonOrDeepGemmExperts
]
info = expert_info(self.fused_experts_type)
return (mk.FusedMoEActivationFormat.BatchedExperts ==
info.activation_format)
def is_standard_fused_experts(self):
return self.fused_experts_type in [
CutlassExpertsFp8, DeepGemmExperts, TritonOrDeepGemmExperts,
TritonExperts
]
info = expert_info(self.fused_experts_type)
return mk.FusedMoEActivationFormat.Standard == info.activation_format
def is_fe_16bit_supported(self):
return self.fused_experts_type in [
BatchedTritonExperts, BatchedTritonOrDeepGemmExperts,
NaiveBatchedExperts, TritonExperts
]
def fe_supported_types(self):
info = expert_info(self.fused_experts_type)
return info.supported_dtypes
def is_fe_fp8_supported(self):
return self.fused_experts_type in [
BatchedDeepGemmExperts,
BatchedTritonExperts,
BatchedTritonOrDeepGemmExperts,
CutlassExpertsFp8,
DeepGemmExperts,
TritonExperts,
TritonOrDeepGemmExperts,
NaiveBatchedExperts,
]
def pf_supported_types(self):
info = prepare_finalize_info(self.prepare_finalize_type)
return info.supported_dtypes
def is_fe_block_fp8_supported(self):
return self.fused_experts_type in [
BatchedDeepGemmExperts,
BatchedTritonOrDeepGemmExperts,
DeepGemmExperts,
TritonExperts,
TritonOrDeepGemmExperts,
BatchedTritonExperts,
NaiveBatchedExperts,
]
def is_block_quant_supported(self):
info = expert_info(self.fused_experts_type)
return info.blocked_quantization_support
def is_fe_supports_chunking(self):
return self.fused_experts_type in [
CutlassExpertsFp8, DeepGemmExperts, TritonOrDeepGemmExperts,
TritonExperts
]
info = expert_info(self.fused_experts_type)
return info.supports_chunking
def supports_expert_map(self):
info = expert_info(self.fused_experts_type)
return info.supports_expert_map
def supports_apply_weight_on_input(self):
info = prepare_finalize_info(self.prepare_finalize_type)
return info.supports_apply_weight_on_input
def needs_deep_gemm(self):
return self.fused_experts_type in [
BatchedDeepGemmExperts,
DeepGemmExperts,
]
info = expert_info(self.fused_experts_type)
return info.needs_deep_gemm
def needs_pplx(self):
return self.prepare_finalize_type in [PplxPrepareAndFinalize]
info = prepare_finalize_info(self.prepare_finalize_type)
return info.backend == "pplx"
def needs_deep_ep(self):
return self.prepare_finalize_type in [
DeepEPHTPrepareAndFinalize, DeepEPLLPrepareAndFinalize
]
info = prepare_finalize_info(self.prepare_finalize_type)
return (info.backend == "deepep_high_throughput"
or info.backend == "deepep_low_latency")
def all2all_backend(self):
if self.needs_pplx():
return "pplx"
if self.prepare_finalize_type == DeepEPHTPrepareAndFinalize:
return "deepep_high_throughput"
if self.prepare_finalize_type == DeepEPLLPrepareAndFinalize:
return "deepep_low_latency"
return "naive"
def needs_all2all(self):
return self.prepare_finalize_type in [
PplxPrepareAndFinalize, DeepEPHTPrepareAndFinalize,
DeepEPLLPrepareAndFinalize
]
info = prepare_finalize_info(self.prepare_finalize_type)
return info.backend
def is_valid(self):
# Check prepare-finalize and fused-experts compatibility
@ -267,28 +217,28 @@ class Config:
# invalid quant config
return False
# check bf16 / fp16 support
is_16bit = (self.dtype.itemsize == 2 and self.quant_dtype is None)
if is_16bit and not self.is_fe_16bit_supported():
# check type support
if self.quant_dtype is None:
if (self.dtype not in self.pf_supported_types()
or self.dtype not in self.fe_supported_types()):
return False
else:
if (self.quant_dtype not in self.pf_supported_types()
or self.quant_dtype not in self.fe_supported_types()):
return False
# Check fp8 support
is_fp8 = self.quant_dtype == torch.float8_e4m3fn
if is_fp8 and not self.is_fe_fp8_supported():
return False
# Check fp8 block quanization support
# Check block quanization support
is_block_quatized = self.quant_block_shape is not None
if is_block_quatized and not is_fp8:
if is_block_quatized and self.quant_dtype is None:
return False
if is_block_quatized and not self.is_fe_block_fp8_supported():
if is_block_quatized and not self.is_block_quant_supported():
return False
# deep_gemm only works with block-quantized
if self.needs_deep_gemm() and not is_block_quatized:
return False
# Check dependencies
# Check dependencies (turn into asserts?)
if self.needs_deep_ep() and not has_deep_ep():
return False
if self.needs_deep_gemm() and not has_deep_gemm():
@ -305,6 +255,8 @@ class WeightTensors:
w2: torch.Tensor
w1_scale: Optional[torch.Tensor]
w2_scale: Optional[torch.Tensor]
w1_gs: Optional[torch.Tensor] = None
w2_gs: Optional[torch.Tensor] = None
def describe(self):
s = ""
@ -313,13 +265,20 @@ class WeightTensors:
s += f' - {_describe_tensor(self.w2, "w2")} \n'
s += f' - {_describe_tensor(self.w1_scale, "w1_scale")} \n'
s += f' - {_describe_tensor(self.w2_scale, "w2_scale")} \n'
s += f' - {_describe_tensor(self.w1_gs, "w1_gs")} \n'
s += f' - {_describe_tensor(self.w2_gs, "w2_gs")} \n'
return s
def is_quantized(self) -> bool:
# or w1_scale is not None?
return (self.w1.dtype == torch.float8_e4m3fn
or self.w1.dtype == torch.uint8 or self.w1.dtype == torch.int8)
def to_current_device(self):
self.w1 = self.w1.to(device=torch.cuda.current_device())
self.w2 = self.w2.to(device=torch.cuda.current_device())
is_quantized = self.w1.dtype == torch.float8_e4m3fn
if is_quantized:
if self.is_quantized():
assert self.w1_scale is not None
assert self.w2_scale is not None
self.w1_scale = self.w1_scale.to(
@ -327,56 +286,51 @@ class WeightTensors:
self.w2_scale = self.w2_scale.to(
device=torch.cuda.current_device())
if self.w1_gs is not None:
assert self.w2_gs is not None
self.w1_gs = self.w1_gs.to(device=torch.cuda.current_device())
self.w2_gs = self.w2_gs.to(device=torch.cuda.current_device())
def slice_weights(self, rank: int,
num_local_experts: int) -> "WeightTensors":
s = rank * num_local_experts
e = s + num_local_experts
w1 = self.w1[s:e, :, :]
w2 = self.w2[s:e, :, :]
is_quantized = self.w1.dtype == torch.float8_e4m3fn
w1_scale, w2_scale = (None, None)
if is_quantized:
if self.is_quantized():
assert self.w1_scale is not None
assert self.w2_scale is not None
w1_scale = self.w1_scale[s:e, :, :]
w2_scale = self.w2_scale[s:e, :, :]
return WeightTensors(w1, w2, w1_scale, w2_scale)
w1_gs = self.w1_gs
w2_gs = self.w2_gs
if w1_gs is not None:
assert w2_gs is not None
w1_gs = w1_gs[s:e]
w2_gs = w2_gs[s:e]
return WeightTensors(w1, w2, w1_scale, w2_scale, w1_gs, w2_gs)
@staticmethod
def make(config: Config) -> "WeightTensors":
if config.quant_dtype is None:
# just make normal dtype weights
w1, w2 = make_non_quant_weights(e=config.E,
n=config.N,
k=config.K,
dtype=config.dtype)
return WeightTensors(w1=w1, w2=w2, w1_scale=None, w2_scale=None)
assert config.quant_dtype == torch.float8_e4m3fn
if not config.is_fp8_block_quantized():
w1, w2, w1_scale, w2_scale = make_quant_fp8_weights(
(_, w1, w1_scale, w1_gs), (_, w2, w2_scale, w2_gs) = make_test_weights(
e=config.E,
n=config.N,
k=config.K,
per_out_channel_quant=config.is_per_out_ch_quant,
in_dtype=config.dtype,
quant_dtype=config.quant_dtype,
block_shape=config.quant_block_shape,
per_act_token_quant=config.is_per_out_ch_quant,
)
return WeightTensors(w1=w1,
w2=w2,
w1_scale=w1_scale,
w2_scale=w2_scale)
assert config.quant_block_shape is not None
w1, w2, w1_scale, w2_scale = make_block_quant_fp8_weights(
e=config.E,
n=config.N,
k=config.K,
block_size=config.quant_block_shape,
)
return WeightTensors(w1=w1,
w2=w2,
w1_scale=w1_scale,
w2_scale=w2_scale)
w2_scale=w2_scale,
w1_gs=w1_gs,
w2_gs=w2_gs)
@dataclass
@ -449,7 +403,6 @@ class RankTensors:
dtype=dtype)
topk_weights, topk_ids, _ = fused_topk(hidden_states, score, topk,
False)
topk_ids = topk_ids.to(config.topk_ids_dtype)
# distribute topk_ids evenly
for mi in range(m):
@ -457,7 +410,7 @@ class RankTensors:
topk_ids = topk_ids.to(device=torch.cuda.current_device())
expert_map = None
if config.world_size > 1:
if config.world_size > 1 and config.supports_expert_map():
expert_map = torch.full((global_num_experts, ),
fill_value=-1,
dtype=torch.int32)
@ -480,92 +433,100 @@ class RankTensors:
def reference_moe_impl(config: Config, weights: WeightTensors,
rank_tensors: RankTensors) -> torch.Tensor:
return torch_experts(a=rank_tensors.hidden_states,
w1=weights.w1,
w2=weights.w2,
if config.quant_dtype == "nvfp4":
quant_blocksize = 16
dtype = config.dtype
w1_q = weights.w1
w1_blockscale = weights.w1_scale
w1_gs = weights.w1_gs
w2_q = weights.w2
w2_blockscale = weights.w2_scale
w2_gs = weights.w2_gs
a_global_scale = ((FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX) / torch.amax(
rank_tensors.hidden_states.flatten(), dim=-1)).to(torch.float32)
assert w1_gs is not None
assert w2_gs is not None
assert w1_blockscale is not None
assert w2_blockscale is not None
assert w1_blockscale.shape[1] % 128 == 0
assert w1_blockscale.shape[2] % 4 == 0
assert w2_blockscale.shape[1] % 128 == 0
assert w2_blockscale.shape[2] % 4 == 0
a_fp4, a_scale_interleaved = ops.scaled_fp4_quant(
rank_tensors.hidden_states, a_global_scale)
a = dequantize_nvfp4_to_dtype(a_fp4,
a_scale_interleaved,
a_global_scale,
dtype=dtype,
device=a_fp4.device,
block_size=quant_blocksize)
e = w1_q.shape[0]
n = w1_q.shape[1] // 2
k = w2_q.shape[1]
w1 = torch.zeros((e, 2 * n, k), device="cuda", dtype=dtype)
w2 = torch.zeros((e, k, n), device="cuda", dtype=dtype)
for idx in range(0, e):
w1[idx] = dequantize_nvfp4_to_dtype(w1_q[idx],
w1_blockscale[idx],
w1_gs[idx],
dtype=dtype,
device=w1_q.device,
block_size=quant_blocksize)
w2[idx] = dequantize_nvfp4_to_dtype(w2_q[idx],
w2_blockscale[idx],
w2_gs[idx],
dtype=dtype,
device=w2_q.device,
block_size=quant_blocksize)
a_scale = None
w1_scale = None
w2_scale = None
quant_dtype = None
per_act_token_quant = False
block_shape = None
else:
a = rank_tensors.hidden_states
a_scale = rank_tensors.hidden_states_scale
w1 = weights.w1
w1_scale = weights.w1_scale
w2 = weights.w2
w2_scale = weights.w2_scale
quant_dtype = config.quant_dtype
per_act_token_quant = config.is_per_act_token_quant
block_shape = config.quant_block_shape
return torch_experts(a=a,
w1=w1,
w2=w2,
topk_weight=rank_tensors.topk_weights,
topk_ids=rank_tensors.topk_ids,
global_num_experts=config.E,
expert_map=None,
w1_scale=weights.w1_scale,
w2_scale=weights.w2_scale,
a1_scale=rank_tensors.hidden_states_scale,
quant_dtype=config.quant_dtype,
per_act_token_quant=config.is_per_act_token_quant,
block_shape=config.quant_block_shape,
apply_router_weights_on_input=config.topk == 1)
w1_scale=w1_scale,
w2_scale=w2_scale,
a1_scale=a_scale,
quant_dtype=quant_dtype,
per_act_token_quant=per_act_token_quant,
block_shape=block_shape,
apply_router_weights_on_input=config.topk == 1
and config.supports_apply_weight_on_input())
def make_fused_experts(
config: Config, moe: FusedMoEConfig,
num_dispatchers: int) -> mk.FusedMoEPermuteExpertsUnpermute:
use_fp8 = config.quant_dtype == torch.float8_e4m3fn
batch_kwargs = {
"max_num_tokens": moe.max_num_tokens,
"num_dispatchers": num_dispatchers,
}
quant_kwargs = {
"use_fp8_w8a8": use_fp8,
"use_int8_w8a8": False,
"use_int8_w8a16": False,
"use_int4_w4a16": False,
"block_shape": config.quant_block_shape,
"per_act_token_quant": config.is_per_act_token_quant,
}
deepgemm_kwargs = {"allow_deep_gemm": has_deep_gemm()}
if config.fused_experts_type == BatchedDeepGemmExperts:
kwargs = batch_kwargs | {
"block_shape": config.quant_block_shape,
"per_act_token_quant": config.is_per_act_token_quant,
}
print(f"Making BatchedDeepGemmExperts {kwargs} ...")
experts = BatchedDeepGemmExperts(**kwargs)
elif config.fused_experts_type == BatchedTritonExperts:
kwargs = batch_kwargs | quant_kwargs
print(f"Making BatchedTritonExperts {kwargs} ...")
experts = BatchedTritonExperts(**kwargs)
elif config.fused_experts_type == BatchedTritonOrDeepGemmExperts:
kwargs = batch_kwargs | quant_kwargs | deepgemm_kwargs
print(f"Making BatchedTritonOrDeepGemmExperts {kwargs} ...")
experts = BatchedTritonOrDeepGemmExperts(**kwargs)
elif config.fused_experts_type == DeepGemmExperts:
print("Making DeepGemmExperts () ...")
experts = DeepGemmExperts()
elif config.fused_experts_type == TritonExperts:
kwargs = quant_kwargs
print(f"Making TritonExperts {kwargs} ...")
experts = TritonExperts(**kwargs)
elif config.fused_experts_type == TritonOrDeepGemmExperts:
kwargs = quant_kwargs | deepgemm_kwargs
print(f"Making TritonOrDeepGemmExperts {kwargs} ...")
experts = TritonOrDeepGemmExperts(**kwargs)
elif config.fused_experts_type == NaiveBatchedExperts:
kwargs = batch_kwargs | quant_kwargs
print(f"Making NaiveBatchedExperts {kwargs} ...")
experts = NaiveBatchedExperts(**kwargs)
elif config.fused_experts_type == CutlassExpertsFp8:
use_batched_format = config.is_batched_prepare_finalize()
num_experts = (moe.num_local_experts
if use_batched_format else moe.num_experts)
kwargs = {
"max_experts_per_worker": num_experts,
"out_dtype": moe.in_dtype,
"per_act_token_quant": config.is_per_act_token_quant,
"per_out_ch_quant": config.is_per_out_ch_quant,
"block_shape": config.quant_block_shape,
"num_dispatchers": num_dispatchers,
"use_batched_format": use_batched_format
}
print(f"Making CutlassExpertsFp8 {kwargs} ...")
experts = CutlassExpertsFp8(**kwargs)
return experts
def make_modular_kernel(config: Config,
vllm_config: VllmConfig) -> mk.FusedMoEModularKernel:
def make_modular_kernel(
config: Config,
vllm_config: VllmConfig,
weights: WeightTensors,
) -> mk.FusedMoEModularKernel:
def next_power_of_2(x):
import math
@ -579,6 +540,7 @@ def make_modular_kernel(config: Config,
dp_size_=get_dp_group().world_size,
vllm_parallel_config=vllm_config.parallel_config,
)
moe = FusedMoEConfig(
num_experts=config.E,
experts_per_token=config.topk,
@ -591,15 +553,16 @@ def make_modular_kernel(config: Config,
)
# make modular kernel
prepare_finalize = None
if config.needs_all2all():
prepare_finalize = FusedMoEMethodBase.maybe_make_prepare_finalize(moe)
assert prepare_finalize is not None
else:
prepare_finalize = MoEPrepareAndFinalizeNoEP()
prepare_finalize = make_prepare_finalize(config.prepare_finalize_type,
config.all2all_backend(), moe)
fused_experts = make_fused_experts(config, moe,
prepare_finalize.num_dispatchers())
fused_experts = make_fused_experts(
config.fused_experts_type,
moe,
prepare_finalize.num_dispatchers(),
weights.w1_gs,
weights.w2_gs,
)
modular_kernel = mk.FusedMoEModularKernel(
prepare_finalize=prepare_finalize, fused_experts=fused_experts)
@ -620,22 +583,45 @@ def run_modular_kernel(
# weights for rank
rank_weights = weights.slice_weights(pgi.rank, config.num_local_experts)
mk = make_modular_kernel(config, vllm_config)
mk = make_modular_kernel(config, vllm_config, weights)
mk_kwargs = {
"hidden_states": rank_tensors.hidden_states.clone(
"hidden_states":
rank_tensors.hidden_states.clone(
), # impls might update the tensor in place
"w1": rank_weights.w1,
"w2": rank_weights.w2,
"topk_weights": rank_tensors.topk_weights,
"topk_ids": rank_tensors.topk_ids,
"expert_map": rank_tensors.expert_map,
"w1_scale": rank_weights.w1_scale,
"w2_scale": rank_weights.w2_scale,
"a1_scale": rank_tensors.hidden_states_scale,
"global_num_experts": config.E,
"apply_router_weight_on_input": config.topk == 1,
"w1":
rank_weights.w1,
"w2":
rank_weights.w2,
"topk_weights":
rank_tensors.topk_weights,
"topk_ids":
rank_tensors.topk_ids.to(mk.prepare_finalize.topk_indices_dtype()),
"expert_map":
rank_tensors.expert_map,
"w1_scale":
rank_weights.w1_scale,
"w2_scale":
rank_weights.w2_scale,
"a1_scale":
rank_tensors.hidden_states_scale,
"global_num_experts":
config.E,
"apply_router_weight_on_input":
config.topk == 1 and config.supports_apply_weight_on_input(),
}
num_tokens = rank_tensors.hidden_states.shape[0]
num_tokens_across_dp = torch.tensor([num_tokens] * config.world_size,
device="cuda",
dtype=torch.int)
with set_forward_context(
None,
vllm_config,
num_tokens=num_tokens,
num_tokens_across_dp=num_tokens_across_dp,
):
out = mk.forward(**mk_kwargs)
return out

View File

@ -1,58 +1,316 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from dataclasses import dataclass
from typing import Optional, Union
import torch
# Fused experts and PrepareFinalize imports
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from vllm.model_executor.layers.fused_moe.batched_deep_gemm_moe import (
BatchedDeepGemmExperts)
from vllm.model_executor.layers.fused_moe.batched_triton_or_deep_gemm_moe import ( # noqa: E501
BatchedTritonOrDeepGemmExperts)
from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig
from vllm.model_executor.layers.fused_moe.cutlass_moe import CutlassExpertsFp8
from vllm.model_executor.layers.fused_moe.config import (FusedMoEConfig,
FusedMoEQuantConfig)
from vllm.model_executor.layers.fused_moe.deep_gemm_moe import DeepGemmExperts
from vllm.model_executor.layers.fused_moe.fused_batched_moe import (
BatchedTritonExperts, NaiveBatchedExperts)
from vllm.model_executor.layers.fused_moe.layer import TritonExperts
from vllm.model_executor.layers.fused_moe.layer import (FusedMoEMethodBase,
TritonExperts)
from vllm.model_executor.layers.fused_moe.prepare_finalize import (
MoEPrepareAndFinalizeNoEP)
from vllm.model_executor.layers.fused_moe.triton_deep_gemm_moe import (
TritonOrDeepGemmExperts)
from vllm.utils import has_deep_ep, has_pplx
from vllm.model_executor.layers.quantization.utils.quant_utils import (
cutlass_fp4_supported)
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
cutlass_fp8_supported)
from vllm.platforms import current_platform
from vllm.utils import has_deep_ep, has_deep_gemm, has_pplx
from vllm.utils.deep_gemm import is_deep_gemm_supported
from vllm.utils.flashinfer import has_flashinfer_cutlass_fused_moe
if has_deep_ep():
@dataclass
class PrepareFinalizeInfo:
activation_format: mk.FusedMoEActivationFormat
supported_dtypes: list[Union[torch.dtype, str]]
blocked_quantization_support: bool
backend: Optional[str]
supports_apply_weight_on_input: bool = True
@dataclass
class ExpertInfo:
activation_format: mk.FusedMoEActivationFormat
supported_dtypes: list[Union[torch.dtype, str]]
blocked_quantization_support: bool
supports_chunking: bool
supports_expert_map: bool
needs_matching_quant: bool = False
needs_deep_gemm: bool = False
PREPARE_FINALIZE_INFO: dict[mk.FusedMoEPrepareAndFinalize,
PrepareFinalizeInfo] = {}
EXPERT_INFO: dict[mk.FusedMoEPermuteExpertsUnpermute, ExpertInfo] = {}
MK_ALL_PREPARE_FINALIZE_TYPES: list[mk.FusedMoEPrepareAndFinalize] = []
MK_MULTI_GPU_PREPARE_FINALIZE_TYPES: list[mk.FusedMoEPrepareAndFinalize] = []
MK_SINGLE_GPU_PREPARE_FINALIZE_TYPES: list[mk.FusedMoEPrepareAndFinalize] = []
MK_FUSED_EXPERT_TYPES: list[mk.FusedMoEPermuteExpertsUnpermute] = []
standard_format = mk.FusedMoEActivationFormat.Standard
batched_format = mk.FusedMoEActivationFormat.BatchedExperts
common_float_types: list[Union[torch.dtype, str]] = [
torch.float8_e4m3fn, torch.bfloat16, torch.float16, torch.float32
]
common_float_and_int_types = common_float_types + [torch.int8]
nv_fp4_types = ["nvfp4"]
fp8_types = [torch.float8_e4m3fn]
def register_prepare_and_finalize(
kind,
activation_format: mk.FusedMoEActivationFormat,
supported_dtypes: list[Union[torch.dtype, str]],
blocked_quantization_support: bool,
backend: Optional[str],
force_multigpu: bool = False,
supports_apply_weight_on_input: bool = True,
):
global PREPARE_FINALIZE_INFO
global MK_ALL_PREPARE_FINALIZE_TYPES
global MK_MULTI_GPU_PREPARE_FINALIZE_TYPES
global MK_SINGLE_GPU_PREPARE_FINALIZE_TYPES
assert kind not in PREPARE_FINALIZE_INFO
PREPARE_FINALIZE_INFO[kind] = PrepareFinalizeInfo(
activation_format,
supported_dtypes,
blocked_quantization_support,
backend,
supports_apply_weight_on_input,
)
MK_ALL_PREPARE_FINALIZE_TYPES.append(kind)
if backend is not None or force_multigpu:
MK_MULTI_GPU_PREPARE_FINALIZE_TYPES.append(kind)
else:
MK_SINGLE_GPU_PREPARE_FINALIZE_TYPES.append(kind)
def register_experts(
kind,
activation_format: mk.FusedMoEActivationFormat,
supported_dtypes: list[Union[torch.dtype, str]],
blocked_quantization_support: bool,
supports_chunking: bool,
supports_expert_map: bool,
needs_matching_quant: bool = False,
needs_deep_gemm: bool = False,
):
global EXPERT_INFO
global MK_FUSED_EXPERT_TYPES
assert kind not in EXPERT_INFO
EXPERT_INFO[kind] = ExpertInfo(
activation_format,
supported_dtypes,
blocked_quantization_support,
supports_chunking,
supports_expert_map,
needs_matching_quant,
needs_deep_gemm,
)
MK_FUSED_EXPERT_TYPES.append(kind)
def prepare_finalize_info(kind) -> PrepareFinalizeInfo:
info = PREPARE_FINALIZE_INFO.get(kind)
assert info is not None
return info
def expert_info(kind) -> ExpertInfo:
info = EXPERT_INFO.get(kind)
assert info is not None
return info
register_prepare_and_finalize(
MoEPrepareAndFinalizeNoEP,
standard_format,
common_float_types,
blocked_quantization_support=True,
backend=None,
)
register_experts(
BatchedTritonExperts,
batched_format,
common_float_types,
blocked_quantization_support=True,
supports_chunking=False,
supports_expert_map=False,
needs_matching_quant=True,
)
register_experts(
TritonExperts,
standard_format,
common_float_and_int_types,
blocked_quantization_support=True,
supports_chunking=True,
supports_expert_map=True,
needs_matching_quant=True,
)
register_experts(
NaiveBatchedExperts,
batched_format,
common_float_and_int_types,
blocked_quantization_support=True,
supports_chunking=False,
supports_expert_map=True,
)
# Disable on blackwell for now
if has_deep_ep() and not current_platform.has_device_capability(100):
from vllm.model_executor.layers.fused_moe.deepep_ht_prepare_finalize import ( # noqa: E501
DeepEPHTPrepareAndFinalize)
from vllm.model_executor.layers.fused_moe.deepep_ll_prepare_finalize import ( # noqa: E501
DeepEPLLPrepareAndFinalize)
register_prepare_and_finalize(
DeepEPHTPrepareAndFinalize,
standard_format,
common_float_types,
blocked_quantization_support=True,
backend="deepep_high_throughput",
)
register_prepare_and_finalize(
DeepEPLLPrepareAndFinalize,
batched_format,
common_float_types,
blocked_quantization_support=True,
backend="deepep_low_latency",
)
if has_pplx():
from vllm.model_executor.layers.fused_moe.pplx_prepare_finalize import (
PplxPrepareAndFinalize)
register_prepare_and_finalize(
PplxPrepareAndFinalize,
batched_format,
common_float_and_int_types,
blocked_quantization_support=True,
backend="pplx",
)
MK_MULTI_GPU_PREPARE_FINALIZE_TYPES = []
if has_pplx():
MK_MULTI_GPU_PREPARE_FINALIZE_TYPES += [PplxPrepareAndFinalize]
if has_deep_ep():
MK_MULTI_GPU_PREPARE_FINALIZE_TYPES += [
DeepEPHTPrepareAndFinalize, DeepEPLLPrepareAndFinalize
]
if (has_flashinfer_cutlass_fused_moe()
and current_platform.has_device_capability(100)):
from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe import ( # noqa: E501
FlashInferExperts)
from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_prepare_finalize import ( # noqa: E501
FlashInferCutlassMoEPrepareAndFinalize)
MK_SINGLE_GPU_PREPARE_FINALIZE_TYPES = [MoEPrepareAndFinalizeNoEP]
register_prepare_and_finalize(
FlashInferCutlassMoEPrepareAndFinalize,
standard_format,
nv_fp4_types,
blocked_quantization_support=True,
backend=None,
force_multigpu=True,
supports_apply_weight_on_input=False,
)
MK_ALL_PREPARE_FINALIZE_TYPES = (MK_MULTI_GPU_PREPARE_FINALIZE_TYPES +
MK_SINGLE_GPU_PREPARE_FINALIZE_TYPES)
register_experts(
FlashInferExperts,
standard_format,
nv_fp4_types,
blocked_quantization_support=True,
supports_chunking=True,
# Note: this is a hack to get it to run for now
supports_expert_map=True,
)
else:
FlashInferCutlassMoEPrepareAndFinalize = None
MK_FUSED_EXPERT_TYPES = [
if has_deep_gemm() and is_deep_gemm_supported():
register_experts(
BatchedDeepGemmExperts,
BatchedTritonExperts,
NaiveBatchedExperts,
BatchedTritonOrDeepGemmExperts,
CutlassExpertsFp8,
batched_format,
fp8_types,
blocked_quantization_support=True,
supports_chunking=False,
supports_expert_map=False,
needs_matching_quant=False,
needs_deep_gemm=True,
)
register_experts(
DeepGemmExperts,
standard_format,
fp8_types,
blocked_quantization_support=True,
supports_chunking=True,
supports_expert_map=True,
needs_matching_quant=False,
needs_deep_gemm=True,
),
register_experts(
BatchedTritonOrDeepGemmExperts,
batched_format,
common_float_and_int_types,
blocked_quantization_support=True,
supports_chunking=False,
supports_expert_map=False,
needs_matching_quant=True,
needs_deep_gemm=True,
)
register_experts(
TritonOrDeepGemmExperts,
TritonExperts,
]
standard_format,
common_float_and_int_types,
blocked_quantization_support=True,
supports_chunking=True,
supports_expert_map=True,
needs_matching_quant=True,
needs_deep_gemm=True,
)
if cutlass_fp8_supported():
from vllm.model_executor.layers.fused_moe import (CutlassBatchedExpertsFp8,
CutlassExpertsFp8)
register_experts(
CutlassExpertsFp8,
standard_format,
fp8_types,
blocked_quantization_support=False,
supports_chunking=True,
supports_expert_map=False,
)
register_experts(
CutlassBatchedExpertsFp8,
batched_format,
fp8_types,
blocked_quantization_support=False,
supports_chunking=False,
supports_expert_map=False,
)
if cutlass_fp4_supported():
from vllm.model_executor.layers.fused_moe.cutlass_moe import (
CutlassExpertsFp4)
register_experts(
CutlassExpertsFp4,
standard_format,
nv_fp4_types,
blocked_quantization_support=True,
supports_chunking=True,
supports_expert_map=False,
)
MK_QUANT_CONFIGS = [
None,
@ -85,3 +343,156 @@ MK_QUANT_CONFIGS = [
# block-quantized weights and per-token activations
# block-quantized weights and per-tensor activations
]
if cutlass_fp4_supported() or has_flashinfer_cutlass_fused_moe():
MK_QUANT_CONFIGS += [
FusedMoEQuantConfig(quant_dtype="nvfp4",
per_out_ch_quant=False,
per_act_token_quant=False,
block_shape=None),
]
def _make_gscale(num_experts: int) -> torch.Tensor:
return torch.ones((num_experts, ),
device=torch.cuda.current_device(),
dtype=torch.float32)
def make_prepare_finalize(
prepare_finalize_type: mk.FusedMoEPrepareAndFinalize,
backend: Optional[str],
moe: FusedMoEConfig,
) -> mk.FusedMoEPrepareAndFinalize:
if backend != "naive" and backend is not None:
prepare_finalize = FusedMoEMethodBase._maybe_make_prepare_finalize(moe)
assert prepare_finalize is not None
return prepare_finalize
elif prepare_finalize_type == FlashInferCutlassMoEPrepareAndFinalize:
return FlashInferCutlassMoEPrepareAndFinalize(
use_dp=moe.moe_parallel_config.dp_size > 1,
a1_gscale=_make_gscale(moe.num_local_experts),
)
else:
return MoEPrepareAndFinalizeNoEP()
def _slice(rank: int, num_local_experts: int, t: torch.Tensor) -> torch.Tensor:
s = rank * num_local_experts
e = s + num_local_experts
return t[s:e]
def make_fused_experts(
fused_experts_type: mk.FusedMoEPermuteExpertsUnpermute,
moe: FusedMoEConfig,
num_dispatchers: int,
w1_gs: Optional[torch.Tensor],
w2_gs: Optional[torch.Tensor],
) -> mk.FusedMoEPermuteExpertsUnpermute:
use_fp8 = moe.quant_dtype == torch.float8_e4m3fn
batch_kwargs = {
"max_num_tokens": moe.max_num_tokens,
"num_dispatchers": num_dispatchers,
}
quant_kwargs = {
"use_fp8_w8a8": use_fp8,
"use_int8_w8a8": False,
"use_int8_w8a16": False,
"use_int4_w4a16": False,
"block_shape": moe.block_shape,
"per_act_token_quant": moe.per_act_token_quant,
}
deepgemm_kwargs = {"allow_deep_gemm": has_deep_gemm()}
if fused_experts_type == BatchedDeepGemmExperts:
kwargs = batch_kwargs | {
"block_shape": moe.block_shape,
"per_act_token_quant": moe.per_act_token_quant,
}
print(f"Making BatchedDeepGemmExperts {kwargs} ...")
experts = BatchedDeepGemmExperts(**kwargs)
elif fused_experts_type == BatchedTritonExperts:
kwargs = batch_kwargs | quant_kwargs
print(f"Making BatchedTritonExperts {kwargs} ...")
experts = BatchedTritonExperts(**kwargs)
elif fused_experts_type == BatchedTritonOrDeepGemmExperts:
kwargs = batch_kwargs | quant_kwargs | deepgemm_kwargs
print(f"Making BatchedTritonOrDeepGemmExperts {kwargs} ...")
experts = BatchedTritonOrDeepGemmExperts(**kwargs)
elif fused_experts_type == DeepGemmExperts:
print("Making DeepGemmExperts () ...")
experts = DeepGemmExperts()
elif fused_experts_type == TritonExperts:
kwargs = quant_kwargs
print(f"Making TritonExperts {kwargs} ...")
experts = TritonExperts(**kwargs)
elif fused_experts_type == TritonOrDeepGemmExperts:
kwargs = quant_kwargs | deepgemm_kwargs
print(f"Making TritonOrDeepGemmExperts {kwargs} ...")
experts = TritonOrDeepGemmExperts(**kwargs)
elif fused_experts_type == NaiveBatchedExperts:
kwargs = batch_kwargs | quant_kwargs
print(f"Making NaiveBatchedExperts {kwargs} ...")
experts = NaiveBatchedExperts(**kwargs)
elif fused_experts_type == CutlassExpertsFp8:
kwargs = {
"out_dtype": moe.in_dtype,
"per_act_token_quant": moe.per_act_token_quant,
"per_out_ch_quant": moe.per_out_ch_quant,
"block_shape": moe.block_shape,
}
print(f"Making CutlassExpertsFp8 {kwargs} ...")
experts = CutlassExpertsFp8(**kwargs)
elif fused_experts_type == CutlassBatchedExpertsFp8:
kwargs = {
"max_experts_per_worker": moe.num_local_experts,
"num_dispatchers": num_dispatchers,
"out_dtype": moe.in_dtype,
"per_act_token_quant": moe.per_act_token_quant,
"per_out_ch_quant": moe.per_out_ch_quant,
"block_shape": moe.block_shape,
}
print(f"Making CutlassBatchedExpertsFp8 {kwargs} ...")
experts = CutlassBatchedExpertsFp8(**kwargs)
elif fused_experts_type == CutlassExpertsFp4:
assert w1_gs is not None and w2_gs is not None
num_experts = moe.num_local_experts
rank = moe.moe_parallel_config.dp_rank
kwargs = {
"g1_alphas": _slice(rank, num_experts, (1 / w1_gs)),
"g2_alphas": _slice(rank, num_experts, (1 / w2_gs)),
"a1_gscale": _make_gscale(num_experts),
"a2_gscale": _make_gscale(num_experts),
"max_experts_per_worker": num_experts,
"out_dtype": moe.in_dtype,
"per_act_token_quant": moe.per_act_token_quant,
"per_out_ch_quant": moe.per_out_ch_quant,
"block_shape": moe.block_shape,
"num_dispatchers": num_dispatchers,
}
print(f"Making CutlassExpertsFp4 {kwargs} ...")
experts = CutlassExpertsFp4(**kwargs)
elif fused_experts_type == FlashInferExperts:
assert w1_gs is not None and w2_gs is not None
num_experts = moe.num_local_experts
rank = moe.moe_parallel_config.dp_rank
kwargs = {
"g1_alphas": _slice(rank, num_experts, (1 / w1_gs)),
"g2_alphas": _slice(rank, num_experts, (1 / w2_gs)),
"a1_gscale": _make_gscale(num_experts),
"a2_gscale": _make_gscale(num_experts),
"out_dtype": moe.in_dtype,
"quant_dtype": "nvfp4",
"ep_rank": moe.ep_rank,
"ep_size": moe.ep_size,
"tp_rank": moe.tp_rank,
"tp_size": moe.tp_size,
}
print(f"Making FlashInferExperts {kwargs} ...")
experts = FlashInferExperts(**kwargs)
else:
raise RuntimeError(f"Unknown fused experts type: {fused_experts_type}")
return experts

View File

@ -52,7 +52,7 @@ def profile_modular_kernel(
rank_weights = weights.slice_weights(pgi.rank, config.num_local_experts)
# make modular kernel
mk = make_modular_kernel(config, vllm_config)
mk = make_modular_kernel(config, vllm_config, weights)
mk_kwargs = {
"hidden_states": rank_tensors.hidden_states,
@ -83,7 +83,7 @@ def rank_worker(
# sanity check
from vllm import envs
if config.fused_moe_chunk_size is not None:
assert (config.fused_moe_chunk_size == envs.VLLM_FUSED_MOE_CHUNK_SIZE)
assert config.fused_moe_chunk_size == envs.VLLM_FUSED_MOE_CHUNK_SIZE
# get weights to this device
weights.to_current_device()

View File

@ -1,117 +0,0 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import torch
import vllm._custom_ops as ops
from vllm.utils.deep_gemm import per_block_cast_to_fp8
def per_token_cast_to_fp8(
x: torch.Tensor, block_size: int) -> tuple[torch.Tensor, torch.Tensor]:
assert x.dim() == 2
m, n = x.shape
pad_size = (block_size - (n % block_size)) % block_size
x = torch.nn.functional.pad(x,
(0, pad_size), value=0) if pad_size > 0 else x
x_view = x.view(m, -1, block_size)
x_amax = x_view.abs().float().amax(dim=2).view(m, -1).clamp(1e-4)
fp8_data = (x_view * (448.0 / x_amax.unsqueeze(2))).to(torch.float8_e4m3fn)
return fp8_data.view(m, n + pad_size)[:, :n], (x_amax / 448.0).view(m, -1)
def make_non_quant_weights(
e: int,
n: int,
k: int,
dtype: torch.dtype,
) -> tuple[torch.Tensor, torch.Tensor]:
"""
Return weights w1, w2
"""
device = torch.cuda.current_device()
w1 = torch.randn((e, 2 * n, k), device=device, dtype=dtype) / 15
w2 = torch.randn((e, k, n), device=device, dtype=dtype) / 15
return w1, w2
def make_block_quant_fp8_weights(
e: int,
n: int,
k: int,
block_size: list[int],
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
"""
Return weights w1, w2, w1_scale, w2_scale
"""
dtype = torch.bfloat16
device = torch.cuda.current_device()
fp8_info = torch.finfo(torch.float8_e4m3fn)
fp8_max, fp8_min = fp8_info.max, fp8_info.min
w1_bf16, w2_bf16 = make_non_quant_weights(e, n, k, dtype)
w1_bf16 = w1_bf16.clamp(min=fp8_min, max=fp8_max).to(dtype=dtype)
w2_bf16 = w2_bf16.clamp(min=fp8_min, max=fp8_max).to(dtype=dtype)
block_n, block_k = block_size[0], block_size[1]
n_tiles_w1 = ((2 * n) + block_n - 1) // block_n
k_tiles_w1 = (k + block_k - 1) // block_k
n_tiles_w2 = (k + block_n - 1) // block_n
k_tiles_w2 = (n + block_k - 1) // block_k
w1 = torch.empty_like(w1_bf16, dtype=torch.float8_e4m3fn, device=device)
w2 = torch.empty_like(w2_bf16, dtype=torch.float8_e4m3fn, device=device)
w1_s = torch.empty((e, n_tiles_w1, k_tiles_w1),
device=device,
dtype=torch.float32)
w2_s = torch.empty((e, n_tiles_w2, k_tiles_w2),
device=device,
dtype=torch.float32)
assert w1_s.shape == (e, (2 * n + (block_n - 1)) // block_n,
(k + (block_k - 1)) // block_k)
assert (w2.shape[-2] + block_n - 1) // block_n == w2_s.shape[-2]
for i in range(e):
w1[i], w1_s[i] = per_block_cast_to_fp8(w1_bf16[i],
block_size=[block_k, block_n])
w2[i], w2_s[i] = per_block_cast_to_fp8(w2_bf16[i],
block_size=[block_k, block_n])
return w1, w2, w1_s, w2_s
def make_quant_fp8_weights(
e: int,
n: int,
k: int,
per_out_channel_quant: bool,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
"""
Return w1, w2, w1_scale, w2_scale
"""
q_dtype = torch.float8_e4m3fn
w1, w2 = make_non_quant_weights(e, n, k, dtype=torch.bfloat16)
# w1 -> w1_q, w2 -> w2_q
w1_q = torch.empty((e, 2 * n, k), device="cuda", dtype=q_dtype)
w2_q = torch.empty((e, k, n), device="cuda", dtype=q_dtype)
n_b_scales = 2 * n if per_out_channel_quant else 1
k_b_scales = k if per_out_channel_quant else 1
w1_scale = torch.empty((e, n_b_scales, 1),
device="cuda",
dtype=torch.float32)
w2_scale = torch.empty((e, k_b_scales, 1),
device="cuda",
dtype=torch.float32)
for expert in range(e):
w1_q[expert], w1_scale[expert] = ops.scaled_fp8_quant(
w1[expert], use_per_token_if_dynamic=per_out_channel_quant)
w2_q[expert], w2_scale[expert] = ops.scaled_fp8_quant(
w2[expert], use_per_token_if_dynamic=per_out_channel_quant)
return w1_q, w2_q, w1_scale, w2_scale

View File

@ -133,7 +133,7 @@ def test_batched_mm(num_experts: int, max_tokens_per_expert: int, K: int,
per_act_token_quant=per_act_token_quant,
)
B, B_q, B_scale, _, _, _ = make_test_weights(
(B, B_q, B_scale, _), _ = make_test_weights(
num_experts,
N // 2,
K,
@ -243,7 +243,7 @@ def test_fused_moe_batched_experts(
act_dtype = dtype
quant_dtype = None
w1_16, w1, w1_s, w2_16, w2, w2_s = make_test_weights(
(w1_16, w1, w1_s, _), (w2_16, w2, w2_s, _) = make_test_weights(
e,
n,
k,

View File

@ -161,7 +161,8 @@ def test_w8a8_block_fp8_fused_moe(M, N, K, E, topk, block_size, dtype, seed,
a = torch.randn((M, K), dtype=dtype) / 10
score = torch.randn((M, E), dtype=dtype)
_, w1, w1_s, _, w2, w2_s = make_test_weights(E,
(_, w1, w1_s, _), (_, w2, w2_s,
_) = make_test_weights(E,
N,
K,
dtype,
@ -173,6 +174,7 @@ def test_w8a8_block_fp8_fused_moe(M, N, K, E, topk, block_size, dtype, seed,
use_int8_w8a8=False,
use_int8_w8a16=False,
use_int4_w4a16=False,
use_mxfp4_w4a4=False,
per_act_token_quant=False,
block_shape=block_size)
@ -247,7 +249,8 @@ def test_w8a8_block_fp8_deep_gemm_fused_moe(M, N, K, E, topk, seed,
a = torch.randn((M, K), dtype=dtype) / 10
score = torch.randn((M, E), dtype=dtype)
_, w1, w1_s, _, w2, w2_s = make_test_weights(E,
(_, w1, w1_s, _), (_, w2, w2_s,
_) = make_test_weights(E,
N,
K,
dtype,

View File

@ -118,7 +118,8 @@ def test_w8a8_block_int8_fused_moe(M, N, K, E, topk, block_size, dtype, seed):
a = torch.randn((M, K), dtype=dtype) / 10
score = torch.randn((M, E), dtype=dtype)
_, w1, w1_s, _, w2, w2_s = make_test_weights(E,
(_, w1, w1_s, _), (_, w2, w2_s,
_) = make_test_weights(E,
N,
K,
dtype,

View File

@ -9,6 +9,7 @@ import random
import pytest
import torch
from tests.kernels.moe.utils import per_token_cast_to_fp8
from tests.kernels.utils import baseline_scaled_mm
from vllm import _custom_ops as ops
from vllm.platforms import current_platform
@ -16,20 +17,6 @@ from vllm.utils import cdiv
from vllm.utils.deep_gemm import per_block_cast_to_fp8
def per_token_cast_to_fp8(
x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
assert x.dim() == 2
m, n = x.shape
pad_size = (128 - (n % 128)) % 128
x = torch.nn.functional.pad(x,
(0, pad_size), value=0) if pad_size > 0 else x
x_view = x.view(m, -1, 128)
x_amax = x_view.abs().float().amax(dim=2).view(m, -1).clamp(1e-4)
fp8_data = (x_view *
(448.0 / x_amax.unsqueeze(2))).to(dtype=torch.float8_e4m3fn)
return fp8_data.view(m, n + pad_size)[:, :n], (x_amax / 448.0).view(m, -1)
@pytest.mark.parametrize("num_groups, expected_m_per_group, k, n", [
(4, 8192, 7168, 4096),
(4, 8192, 2048, 7168),
@ -76,7 +63,7 @@ def test_cutlass_grouped_gemm(
device=device,
dtype=torch.float))
for i in range(num_groups):
y_fp8[0][i], y_fp8[1][i] = per_block_cast_to_fp8(y[i])
y_fp8[0][i], y_fp8[1][i] = per_block_cast_to_fp8(y[i], [128, 128])
for i in range(num_groups):
a = x_fp8[0][ep_offset[i]:ep_offset[i + 1]]

View File

@ -70,8 +70,10 @@ def make_block_quant_fp8_weights(
"""
Return weights w1q, w2q, w1_scale, w2_scale
"""
w1, w1q, w1_scale, w2, w2q, w2_scale = make_test_weights(
e, n, k, torch.bfloat16, torch.float8_e4m3fn, block_size)
(_, w1q, w1_scale, _), (_, w2q, w2_scale,
_) = make_test_weights(e, n, k, torch.bfloat16,
torch.float8_e4m3fn,
block_size)
return w1q, w2q, w1_scale, w2_scale

View File

@ -132,9 +132,9 @@ def run_single_case(m, n, k, topk, num_experts, block_size):
# Note: W1 has shape (E, 2N, K), so N = 512
# can trigger the deepgemm path.
MNKs = [
(1024, 512, 128),
(1024, 512, 512),
(2048, 512, 512),
(1024, 768, 128),
(1024, 768, 512),
(2048, 768, 512),
(512, 1024, 1024),
(512, 2048, 2048),
(4096, 4096, 1024),

View File

@ -0,0 +1,147 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import pytest
import torch
from tests.kernels.moe.utils import make_test_weights
from tests.kernels.quantization.nvfp4_utils import (FLOAT4_E2M1_MAX,
FLOAT8_E4M3_MAX,
dequantize_nvfp4_to_dtype)
from tests.kernels.utils import torch_moe
from vllm import _custom_ops as ops
from vllm.config import ParallelConfig, VllmConfig, set_current_vllm_config
from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe import (
FlashInferExperts, is_valid_flashinfer_cutlass_fused_moe)
from vllm.model_executor.layers.fused_moe.fused_moe import fused_topk
from vllm.model_executor.layers.fused_moe.modular_kernel import (
FusedMoEModularKernel)
from vllm.model_executor.layers.fused_moe.prepare_finalize import (
MoEPrepareAndFinalizeNoEP)
from vllm.platforms import current_platform
from vllm.utils.flashinfer import has_flashinfer_cutlass_fused_moe
if not has_flashinfer_cutlass_fused_moe(
) or not current_platform.has_device_capability(100):
pytest.skip("Requires flashinfer_cutlass_fused_moe and nvfp4 support",
allow_module_level=True)
MNK_FACTORS = [
(2, 1024, 1024),
(2, 1024, 1536),
(2, 3072, 1024),
(2, 3072, 1536),
(64, 1024, 1024),
(64, 1024, 1536),
(64, 3072, 1024),
(64, 2048, 1536),
(224, 1024, 1024),
(224, 1024, 1536),
]
@pytest.mark.parametrize("m,n,k", MNK_FACTORS)
@pytest.mark.parametrize("e", [40, 64, 256])
#@pytest.mark.parametrize("e", [128, 256])
@pytest.mark.parametrize("topk", [1, 6, 8])
@pytest.mark.parametrize("dtype", [torch.half, torch.bfloat16])
@torch.inference_mode()
def test_flashinfer_fp4_moe_no_graph(m: int, n: int, k: int, e: int, topk: int,
dtype: torch.dtype):
current_platform.seed_everything(7)
with set_current_vllm_config(
VllmConfig(parallel_config=ParallelConfig(
pipeline_parallel_size=1))):
a = torch.randn((m, k), device="cuda", dtype=dtype) / 10
quant_blocksize = 16
(_, w1_q, w1_blockscale,
w1_gs), (_, w2_q, w2_blockscale, w2_gs) = make_test_weights(
e,
n,
k,
in_dtype=dtype,
quant_dtype="nvfp4",
block_shape=None, # use quant_blocksize?
per_act_token_quant=False,
)
score = torch.randn((m, e), device="cuda", dtype=dtype)
topk_weights, topk_ids, _ = fused_topk(a,
score,
topk,
renormalize=False)
a1_gs = torch.ones((e, ), device="cuda", dtype=torch.float32)
a2_gs = torch.ones((e, ), device="cuda", dtype=torch.float32)
assert is_valid_flashinfer_cutlass_fused_moe(a, w1_q, w2_q)
assert w1_gs is not None
assert w2_gs is not None
assert w1_blockscale is not None
assert w2_blockscale is not None
flashinfer_experts = FusedMoEModularKernel(
MoEPrepareAndFinalizeNoEP(),
FlashInferExperts(
a1_gscale=a1_gs,
g1_alphas=(1 / w1_gs),
a2_gscale=a2_gs,
g2_alphas=(1 / w2_gs),
out_dtype=dtype,
quant_dtype="nvfp4",
))
flashinfer_output = flashinfer_experts(
hidden_states=a,
w1=w1_q,
w1_scale=w1_blockscale,
w2=w2_q,
w2_scale=w2_blockscale,
a1_scale=a1_gs,
a2_scale=a2_gs,
topk_weights=topk_weights,
topk_ids=topk_ids,
)
# Reference check:
a_global_scale = ((FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX) /
torch.amax(a.flatten(), dim=-1)).to(torch.float32)
a_fp4, a_scale_interleaved = ops.scaled_fp4_quant(a, a_global_scale)
_, m_k = a_fp4.shape
a_in_dtype = dequantize_nvfp4_to_dtype(a_fp4,
a_scale_interleaved,
a_global_scale,
dtype=a.dtype,
device=a.device,
block_size=quant_blocksize)
w1_d = torch.empty((e, 2 * n, k), device="cuda", dtype=dtype)
w2_d = torch.empty((e, k, n), device="cuda", dtype=dtype)
for idx in range(0, e):
w1_d[idx] = dequantize_nvfp4_to_dtype(w1_q[idx],
w1_blockscale[idx],
w1_gs[idx],
dtype=dtype,
device=w1_q.device,
block_size=quant_blocksize)
w2_d[idx] = dequantize_nvfp4_to_dtype(w2_q[idx],
w2_blockscale[idx],
w2_gs[idx],
dtype=dtype,
device=w2_q.device,
block_size=quant_blocksize)
torch_output = torch_moe(a_in_dtype, w1_d, w2_d, score, topk)
torch.testing.assert_close(torch_output,
flashinfer_output,
atol=1e-1,
rtol=1e-1)
if __name__ == "__main__":
test_flashinfer_fp4_moe_no_graph((2, 1024, 1024), 40, 1, torch.half)

View File

@ -2,6 +2,8 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import copy
import textwrap
import traceback
from itertools import product
from typing import Optional
@ -10,41 +12,51 @@ import torch
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from vllm.config import VllmConfig, current_platform, set_current_vllm_config
from vllm.model_executor.layers.fused_moe.batched_triton_or_deep_gemm_moe import ( # noqa: E501
BatchedTritonOrDeepGemmExperts)
from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig
from vllm.model_executor.layers.fused_moe.cutlass_moe import CutlassExpertsFp8
from vllm.model_executor.layers.fused_moe.fused_batched_moe import (
BatchedTritonExperts)
from vllm.model_executor.layers.fused_moe.layer import TritonExperts
from vllm.model_executor.layers.fused_moe.triton_deep_gemm_moe import (
TritonOrDeepGemmExperts)
from vllm.utils import has_deep_ep, has_deep_gemm, has_pplx
from vllm.utils.flashinfer import has_flashinfer_cutlass_fused_moe
from .modular_kernel_tools.common import (Config, RankTensors, WeightTensors,
reference_moe_impl,
run_modular_kernel)
from .modular_kernel_tools.mk_objects import (
MK_FUSED_EXPERT_TYPES, MK_MULTI_GPU_PREPARE_FINALIZE_TYPES,
MK_QUANT_CONFIGS, MK_SINGLE_GPU_PREPARE_FINALIZE_TYPES)
MK_QUANT_CONFIGS, MK_SINGLE_GPU_PREPARE_FINALIZE_TYPES, expert_info)
from .modular_kernel_tools.parallel_utils import (ProcessGroupInfo,
parallel_launch_with_config)
# TODO (varun): These requirements are very strict and could be relaxed.
has_all_packages = (has_deep_ep() and has_deep_gemm() and has_pplx())
has_any_multi_gpu_package = (has_deep_ep() or has_deep_gemm() or has_pplx()
or has_flashinfer_cutlass_fused_moe())
meets_package_requirements = pytest.mark.skipif(
not has_all_packages,
reason="Requires deep_ep & deep_gemm & pplx packages",
meets_multi_gpu_requirements = pytest.mark.skipif(
not has_any_multi_gpu_package,
reason="Requires deep_ep or deep_gemm or pplx or flashinfer packages",
)
def format_result(verbose, msg, ex=None):
if ex is not None:
x = str(ex)
newx = x.strip(" \n\t")[:16]
if len(newx) < len(x):
newx = newx + " ..."
prefix = "E\t"
print(f"{textwrap.indent(traceback.format_exc(), prefix)}")
print(f"FAILED {msg} - {newx}\n")
elif verbose:
print(f"PASSED {msg}")
else:
print(".", end="")
def rank_worker(
pgi: ProcessGroupInfo,
vllm_config: VllmConfig,
cpu_group,
config: Config,
weights: WeightTensors,
verbose: bool,
):
current_platform.seed_everything(pgi.rank)
@ -61,8 +73,13 @@ def rank_worker(
TOPKs = config.topks
assert isinstance(TOPKs, list)
exceptions = []
count = 0
for m, topk in product(Ms, TOPKs):
print(f"Running m={m}, topk={topk} ...")
try:
print(f"Running[{pgi.rank}]: m={m}, topk={topk} ...")
count = count + 1
# override m and topk
cfgx = copy.deepcopy(config)
cfgx.Ms = m
@ -78,22 +95,42 @@ def rank_worker(
with set_current_vllm_config(vllm_config):
ref_out = reference_moe_impl(cfgx, weights, rank_tensors)
torch.testing.assert_close(ref_out, mk_out, atol=3e-2, rtol=3e-2)
if config.quant_dtype == "nvfp4":
atol = 1e-1
rtol = 1e-1
else:
atol = 3e-2
rtol = 3e-2
torch.testing.assert_close(ref_out, mk_out, atol=atol, rtol=rtol)
format_result(verbose, config.describe())
except Exception as ex:
format_result(verbose, config.describe(), ex)
exceptions.append(ex)
if len(exceptions) > 0:
raise RuntimeError(
f"{len(exceptions)} of {count} tests failed in child process, "
f"rank={pgi.rank}.")
else:
print(f"{count} of {count} tests passed in child process, "
f"rank={pgi.rank}.")
def run(config: Config):
def run(config: Config, verbose: bool):
assert config.is_valid()
print(f"Testing config \n{config.describe()} ...")
weights: WeightTensors = WeightTensors.make(config)
vllm_config, env_dict = config.make_env_data()
parallel_launch_with_config(config.world_size, rank_worker, vllm_config,
env_dict, config, weights)
env_dict, config, weights, verbose)
Ms = [32, 64]
Ks = [7168] # hidden sizes
# hidden sizes, making this too large will cause fp4 tests to fail.
# Also needs to be a multiple of 1024 for deep_gemm.
Ks = [2048]
Ns = [2048]
TOPKs = [4, 1]
Es = [32]
@ -103,19 +140,16 @@ FUSED_MOE_CHUNK_SIZEs = [None, 16]
def is_nyi_config(config: Config) -> bool:
# We know these configs to be legitimate. but still fail.
info = expert_info(config.fused_experts_type)
if (config.fused_experts_type in [
BatchedTritonExperts, BatchedTritonOrDeepGemmExperts,
TritonExperts, TritonOrDeepGemmExperts
]):
if info.needs_matching_quant:
# The triton kernels expect both per-act-token-quant and
# per-out-ch-quant or neither.
unsupported_quant_config = ((config.is_per_act_token_quant +
config.is_per_out_ch_quant) == 1)
return unsupported_quant_config
# cutlass kernels dont support expert_maps yet.
return config.fused_experts_type == CutlassExpertsFp8
return not info.supports_expert_map
@pytest.mark.parametrize("k", Ks)
@ -128,13 +162,13 @@ def is_nyi_config(config: Config) -> bool:
product(MK_MULTI_GPU_PREPARE_FINALIZE_TYPES, MK_FUSED_EXPERT_TYPES))
@pytest.mark.parametrize("fused_moe_chunk_size", FUSED_MOE_CHUNK_SIZEs)
@pytest.mark.parametrize("world_size", [2])
@meets_package_requirements
@meets_multi_gpu_requirements
def test_modular_kernel_combinations_multigpu(
k: int, n: int, e: int, dtype: torch.dtype,
quant_config: FusedMoEQuantConfig,
quant_config: Optional[FusedMoEQuantConfig],
combination: tuple[mk.FusedMoEPrepareAndFinalize,
mk.FusedMoEPermuteExpertsUnpermute],
fused_moe_chunk_size: Optional[int], world_size: int):
fused_moe_chunk_size: Optional[int], world_size: int, pytestconfig):
config = Config(
Ms=Ms,
@ -149,14 +183,15 @@ def test_modular_kernel_combinations_multigpu(
fused_moe_chunk_size=fused_moe_chunk_size,
world_size=world_size,
)
if not config.is_valid():
pytest.skip(f"Tests config {config} is not valid. Skipping ...")
if is_nyi_config(config):
pytest.skip(f"Tests config {config} is nyi. Skipping ...")
print(f"{config.describe()}")
run(config)
verbosity = pytestconfig.getoption('verbose')
run(config, verbosity > 0)
@pytest.mark.parametrize("k", Ks)
@ -169,13 +204,12 @@ def test_modular_kernel_combinations_multigpu(
product(MK_SINGLE_GPU_PREPARE_FINALIZE_TYPES, MK_FUSED_EXPERT_TYPES))
@pytest.mark.parametrize("fused_moe_chunk_size", FUSED_MOE_CHUNK_SIZEs)
@pytest.mark.parametrize("world_size", [1])
@meets_package_requirements
def test_modular_kernel_combinations_singlegpu(
k: int, n: int, e: int, dtype: torch.dtype,
quant_config: FusedMoEQuantConfig,
quant_config: Optional[FusedMoEQuantConfig],
combination: tuple[mk.FusedMoEPrepareAndFinalize,
mk.FusedMoEPermuteExpertsUnpermute],
fused_moe_chunk_size: Optional[int], world_size: int):
fused_moe_chunk_size: Optional[int], world_size: int, pytestconfig):
config = Config(
Ms=Ms,
K=k,
@ -196,7 +230,8 @@ def test_modular_kernel_combinations_singlegpu(
if is_nyi_config(config):
pytest.skip(f"Tests config {config} is nyi. Skipping ...")
run(config)
verbosity = pytestconfig.getoption('verbose')
run(config, verbosity > 0)
if __name__ == '__main__':
@ -211,4 +246,4 @@ if __name__ == '__main__':
args = parser.parse_args()
config = make_config(args)
run(config)
run(config, True)

View File

@ -3,6 +3,7 @@
import pytest
import torch
from tests.kernels.moe.utils import make_test_weights
from tests.kernels.quantization.nvfp4_utils import (FLOAT4_E2M1_MAX,
FLOAT8_E4M3_MAX,
dequantize_nvfp4_to_dtype)
@ -43,41 +44,20 @@ def test_cutlass_fp4_moe_no_graph(m: int, n: int, k: int, e: int, topk: int,
VllmConfig(parallel_config=ParallelConfig(
pipeline_parallel_size=1))):
a = torch.randn((m, k), device="cuda", dtype=dtype) / 10
w1 = torch.randn((e, 2 * n, k), device="cuda", dtype=dtype) / 10
quant_blocksize = 16
round_up = lambda x, y: (x + y - 1) // y * y
sf_w1_2n = round_up(2 * n, 128)
sf_w1_k = round_up(k // quant_blocksize, 4)
w1_blockscale = torch.empty((e, sf_w1_2n, sf_w1_k),
device="cuda",
dtype=torch.float8_e4m3fn)
w2 = torch.randn((e, k, n), device="cuda", dtype=dtype) / 10
sf_w2_k = round_up(k, 128)
sf_w2_n = round_up(n // quant_blocksize, 4)
w2_blockscale = torch.empty((e, sf_w2_k, sf_w2_n),
device="cuda",
dtype=torch.float8_e4m3fn)
a = torch.randn((m, k), device="cuda", dtype=dtype) / 10
w1_q = torch.empty((e, 2 * n, k // 2),
device="cuda",
dtype=torch.uint8)
w2_q = torch.empty((e, k, n // 2), device="cuda", dtype=torch.uint8)
w1_gs = torch.empty((e, ), device="cuda", dtype=torch.float32)
w2_gs = torch.empty((e, ), device="cuda", dtype=torch.float32)
for expert in range(e):
w1_amax = torch.abs(w1).max().to(torch.float32)
w2_amax = torch.abs(w2).max().to(torch.float32)
w1_gs[expert] = FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX / w1_amax
w2_gs[expert] = FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX / w2_amax
w1_q[expert], w1_blockscale[expert] = ops.scaled_fp4_quant(
w1[expert], w1_gs[expert])
w2_q[expert], w2_blockscale[expert] = ops.scaled_fp4_quant(
w2[expert], w2_gs[expert])
(_, w1_q, w1_blockscale,
w1_gs), (_, w2_q, w2_blockscale, w2_gs) = make_test_weights(
e,
n,
k,
in_dtype=dtype,
quant_dtype="nvfp4",
block_shape=None, # use quant_blocksize?
per_act_token_quant=False,
)
score = torch.randn((m, e), device="cuda", dtype=dtype)
topk_weights, topk_ids, _ = fused_topk(a,
@ -88,6 +68,11 @@ def test_cutlass_fp4_moe_no_graph(m: int, n: int, k: int, e: int, topk: int,
a1_gs = torch.ones((e, ), device="cuda", dtype=torch.float32)
a2_gs = torch.ones((e, ), device="cuda", dtype=torch.float32)
assert w1_gs is not None
assert w2_gs is not None
assert w1_blockscale is not None
assert w2_blockscale is not None
cutlass_output = cutlass_moe_fp4(
a=a,
a1_gscale=a1_gs,
@ -104,14 +89,13 @@ def test_cutlass_fp4_moe_no_graph(m: int, n: int, k: int, e: int, topk: int,
n=n,
k=k,
e=e,
device=a.device,
)
# Reference check:
a_global_scale = ((FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX) /
torch.amax(a.flatten(), dim=-1)).to(torch.float32)
a_fp4, a_scale_interleaved = ops.scaled_fp4_quant(a, a_global_scale)
_, m_k = a_fp4.shape
a_in_dtype = dequantize_nvfp4_to_dtype(a_fp4,
a_scale_interleaved,
a_global_scale,
@ -126,14 +110,14 @@ def test_cutlass_fp4_moe_no_graph(m: int, n: int, k: int, e: int, topk: int,
w1_d[idx] = dequantize_nvfp4_to_dtype(w1_q[idx],
w1_blockscale[idx],
w1_gs[idx],
dtype=w1.dtype,
device=w1.device,
dtype=dtype,
device=w1_q.device,
block_size=quant_blocksize)
w2_d[idx] = dequantize_nvfp4_to_dtype(w2_q[idx],
w2_blockscale[idx],
w2_gs[idx],
dtype=w2.dtype,
device=w2.device,
dtype=dtype,
device=w2_q.device,
block_size=quant_blocksize)
torch_output = torch_moe(a_in_dtype, w1_d, w2_d, score, topk)

View File

@ -9,7 +9,8 @@ import torch
from tests.kernels.utils import torch_experts
from vllm import _custom_ops as ops
from vllm.config import VllmConfig, set_current_vllm_config
from vllm.model_executor.layers.fused_moe.cutlass_moe import CutlassExpertsFp8
from vllm.model_executor.layers.fused_moe.cutlass_moe import (
CutlassBatchedExpertsFp8)
from vllm.model_executor.layers.fused_moe.fused_moe import fused_topk
from vllm.model_executor.layers.fused_moe.modular_kernel import (
FusedMoEModularKernel)
@ -123,12 +124,8 @@ def pplx_cutlass_moe(
num_local_experts=num_local_experts,
num_dispatchers=num_dispatchers)
experts = CutlassExpertsFp8(num_local_experts,
out_dtype,
per_act_token,
per_out_ch,
num_dispatchers=num_dispatchers,
use_batched_format=True)
experts = CutlassBatchedExpertsFp8(num_local_experts, num_dispatchers,
out_dtype, per_act_token, per_out_ch)
fused_cutlass_experts = FusedMoEModularKernel(
prepare_finalize,

View File

@ -770,7 +770,7 @@ def test_pplx_moe_slow(
a = torch.randn((m, k), device="cuda", dtype=torch.bfloat16) / 10
score = torch.randn((m, e), device="cuda", dtype=torch.bfloat16)
_, w1, w1_s, _, w2, w2_s = make_test_weights(
(_, w1, w1_s, _), (_, w2, w2_s, _) = make_test_weights(
e,
n,
k,
@ -836,7 +836,7 @@ def _pplx_test_loop(pgi: ProcessGroupInfo, dp_size: int, use_internode: bool,
args = dict()
if make_weights:
_, w1, w1_s, _, w2, w2_s = make_test_weights(
(_, w1, w1_s, _), (_, w2, w2_s, _) = make_test_weights(
e,
n,
k,

View File

@ -1,11 +1,13 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Optional
from typing import Optional, Union
import torch
import vllm._custom_ops as ops
from tests.kernels.quant_utils import per_block_cast_to_int8
from tests.kernels.quantization.nvfp4_utils import (FLOAT4_E2M1_MAX,
FLOAT8_E4M3_MAX)
from vllm.model_executor.layers.fused_moe import fused_experts
from vllm.model_executor.layers.fused_moe.fused_batched_moe import (
BatchedPrepareAndFinalize, BatchedTritonExperts, NaiveBatchedExperts)
@ -169,28 +171,41 @@ def make_quantized_test_activations(
def moe_quantize_weights(
w: torch.Tensor,
w_s: Optional[torch.Tensor],
quant_dtype: Optional[torch.dtype],
quant_dtype: Union[torch.dtype, str, None],
per_token_quant: bool,
block_shape: Optional[list[int]],
) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
assert (quant_dtype == torch.float8_e4m3fn
or quant_dtype == torch.int8), "only fp8/int8 supported"
) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]:
assert (quant_dtype == torch.float8_e4m3fn or quant_dtype == torch.int8
or quant_dtype == "nvfp4"), "only fp8/int8/nvfp4 supported"
w_gs = None
if block_shape is not None:
assert not per_token_quant
if quant_dtype == torch.int8:
w, w_s = per_block_cast_to_int8(w, block_shape)
else:
elif quant_dtype == torch.float8_e4m3fn:
w, w_s = per_block_cast_to_fp8(w, block_shape)
elif quant_dtype == "nvfp4":
raise RuntimeError("blocked quantization not supported for nvfp4")
else:
raise RuntimeError(f"Unsupported quant type {quant_dtype}")
else:
if quant_dtype == torch.int8:
w, w_s = ops.scaled_int8_quant(
w, w_s, use_per_token_if_dynamic=per_token_quant)
else:
elif quant_dtype == torch.float8_e4m3fn:
w, w_s = ops.scaled_fp8_quant(
w, w_s, use_per_token_if_dynamic=per_token_quant)
elif quant_dtype == "nvfp4":
assert not per_token_quant
w_amax = torch.abs(w).max().to(torch.float32)
w_gs = FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX / w_amax
w, w_s = ops.scaled_fp4_quant(w, w_gs)
else:
raise RuntimeError(f"Unsupported quant type {quant_dtype}")
return w, w_s
return w, w_s, w_gs
def make_test_weight(
@ -198,21 +213,26 @@ def make_test_weight(
rows: int,
cols: int,
in_dtype: torch.dtype = torch.bfloat16,
quant_dtype: Optional[torch.dtype] = None,
quant_dtype: Union[torch.dtype, str, None] = None,
block_shape: Optional[list[int]] = None,
per_act_token_quant: bool = False,
) -> tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
) -> tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor],
Optional[torch.Tensor]]:
w_16 = torch.randn((e, rows, cols), device="cuda", dtype=in_dtype) / 15
w_gs = None
if quant_dtype is not None:
w_l = [None] * e
w_s_l = [None] * e
w_gs_l = [None] * e
for idx in range(e):
w_l[idx], w_s_l[idx] = moe_quantize_weights(
w_l[idx], w_s_l[idx], w_gs_l[idx] = moe_quantize_weights(
w_16[idx], None, quant_dtype, per_act_token_quant, block_shape)
w = torch.stack(w_l)
w_s = torch.stack(w_s_l)
if e > 0 and w_gs_l[0] is not None:
w_gs = torch.stack(w_gs_l)
if w_s.ndim == 2:
assert w_s.shape[-1] == 1
w_s = w_s.view(-1, 1, 1)
@ -225,8 +245,9 @@ def make_test_weight(
else:
w = w_16
w_s = None
w_gs = None
return w_16, w, w_s
return w_16, w, w_s, w_gs
def make_test_weights(
@ -234,14 +255,30 @@ def make_test_weights(
n: int,
k: int,
in_dtype: torch.dtype = torch.bfloat16,
quant_dtype: Optional[torch.dtype] = None,
quant_dtype: Union[torch.dtype, str, None] = None,
block_shape: Optional[list[int]] = None,
per_act_token_quant: bool = False,
) -> tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor], torch.Tensor,
torch.Tensor, Optional[torch.Tensor]]:
) -> tuple[tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor],
Optional[torch.Tensor]],
tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor],
Optional[torch.Tensor]]]:
return (
*make_test_weight(e, 2 * n, k, in_dtype, quant_dtype, block_shape,
make_test_weight(e, 2 * n, k, in_dtype, quant_dtype, block_shape,
per_act_token_quant),
*make_test_weight(e, k, n, in_dtype, quant_dtype, block_shape,
make_test_weight(e, k, n, in_dtype, quant_dtype, block_shape,
per_act_token_quant),
)
def per_token_cast_to_fp8(
x: torch.Tensor,
block_size: int = 128) -> tuple[torch.Tensor, torch.Tensor]:
assert x.dim() == 2
m, n = x.shape
pad_size = (block_size - (n % block_size)) % block_size
x = torch.nn.functional.pad(x,
(0, pad_size), value=0) if pad_size > 0 else x
x_view = x.view(m, -1, block_size)
x_amax = x_view.abs().float().amax(dim=2).view(m, -1).clamp(1e-4)
fp8_data = (x_view * (448.0 / x_amax.unsqueeze(2))).to(torch.float8_e4m3fn)
return fp8_data.view(m, n + pad_size)[:, :n], (x_amax / 448.0).view(m, -1)

View File

@ -105,7 +105,8 @@ class DeviceCommunicatorBase:
# we initialize the all2all manager used in expert parallel.
use_ep = config.parallel_config.data_parallel_size > 1
self.use_all2all = "ep" in unique_name and use_ep
self.is_ep_communicator = "ep" in unique_name
self.use_all2all = self.is_ep_communicator and use_ep
self.all2all_manager: Optional[All2AllManagerBase] = None
def all_reduce(self, input_: torch.Tensor) -> torch.Tensor:
@ -246,7 +247,7 @@ class DeviceCommunicatorBase:
"""
Prepare the communication buffer for the model.
"""
if not self.use_all2all:
if not self.is_ep_communicator:
return
moe_modules = [
@ -254,7 +255,7 @@ class DeviceCommunicatorBase:
if module.__class__.__name__ == "FusedMoE"
]
for module in moe_modules:
module.quant_method.init_prepare_finalize(module.moe_config)
module.quant_method.init_prepare_finalize()
def dispatch(
self, hidden_states: torch.Tensor,

View File

@ -49,7 +49,8 @@ if HAS_TRITON:
from vllm.model_executor.layers.fused_moe.batched_triton_or_deep_gemm_moe import ( # noqa: E501
BatchedTritonOrDeepGemmExperts)
from vllm.model_executor.layers.fused_moe.cutlass_moe import (
CutlassExpertsFp8, cutlass_moe_fp4, cutlass_moe_fp8)
CutlassBatchedExpertsFp8, CutlassExpertsFp8, cutlass_moe_fp4,
cutlass_moe_fp8)
from vllm.model_executor.layers.fused_moe.deep_gemm_moe import (
DeepGemmExperts)
from vllm.model_executor.layers.fused_moe.fused_batched_moe import (
@ -69,6 +70,7 @@ if HAS_TRITON:
"cutlass_moe_fp8",
"cutlass_moe_fp4",
"CutlassExpertsFp8",
"CutlassBatchedExpertsFp8",
"TritonExperts",
"BatchedTritonExperts",
"DeepGemmExperts",

View File

@ -1,6 +1,6 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Any, Optional
from typing import Optional
import torch
@ -254,18 +254,28 @@ class BatchedDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
output = (num_experts, max_num_tokens * num_dispatchers, K)
return (workspace13, workspace2, output, a.dtype)
def apply(self, output: torch.Tensor, hidden_states: torch.Tensor,
w1: torch.Tensor, w2: torch.Tensor, topk_weights: torch.Tensor,
topk_ids: torch.Tensor, activation: str, global_num_experts: int,
def apply(
self,
output: torch.Tensor,
hidden_states: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
activation: str,
global_num_experts: int,
expert_map: Optional[torch.Tensor],
w1_scale: Optional[torch.Tensor],
w2_scale: Optional[torch.Tensor], w1_zp: Optional[torch.Tensor],
w2_zp: Optional[torch.Tensor], a1q_scale: Optional[torch.Tensor],
a2_scale: Optional[torch.Tensor], workspace13: torch.Tensor,
w2_scale: Optional[torch.Tensor],
w1_zp: Optional[torch.Tensor],
w2_zp: Optional[torch.Tensor],
a1q_scale: Optional[torch.Tensor],
a2_scale: Optional[torch.Tensor],
workspace13: torch.Tensor,
workspace2: torch.Tensor,
expert_tokens_meta: Optional[mk.ExpertTokensMetadata],
apply_router_weight_on_input: bool,
extra_expert_args: Optional[dict[str, Any]]):
):
assert expert_tokens_meta is not None
expert_num_tokens = expert_tokens_meta.expert_num_tokens

View File

@ -1,6 +1,6 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Any, Optional
from typing import Optional
import torch
@ -132,18 +132,28 @@ class BatchedTritonOrDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
a, aq, M, N, K, topk, global_num_experts, local_num_experts,
expert_tokens_metadata)
def apply(self, output: torch.Tensor, hidden_states: torch.Tensor,
w1: torch.Tensor, w2: torch.Tensor, topk_weights: torch.Tensor,
topk_ids: torch.Tensor, activation: str, global_num_experts: int,
def apply(
self,
output: torch.Tensor,
hidden_states: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
activation: str,
global_num_experts: int,
expert_map: Optional[torch.Tensor],
w1_scale: Optional[torch.Tensor],
w2_scale: Optional[torch.Tensor], w1_zp: Optional[torch.Tensor],
w2_zp: Optional[torch.Tensor], a1q_scale: Optional[torch.Tensor],
a2_scale: Optional[torch.Tensor], workspace13: torch.Tensor,
w2_scale: Optional[torch.Tensor],
w1_zp: Optional[torch.Tensor],
w2_zp: Optional[torch.Tensor],
a1q_scale: Optional[torch.Tensor],
a2_scale: Optional[torch.Tensor],
workspace13: torch.Tensor,
workspace2: torch.Tensor,
expert_tokens_meta: Optional[mk.ExpertTokensMetadata],
apply_router_weight_on_input: bool,
extra_expert_args: Optional[dict[str, Any]]):
):
experts = (self.batched_deep_gemm_experts
if self.allow_deep_gemm else self.batched_triton_experts)
assert experts is not None
@ -151,4 +161,4 @@ class BatchedTritonOrDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
activation, global_num_experts, expert_map, w1_scale,
w2_scale, w1_zp, w2_zp, a1q_scale, a2_scale, workspace13,
workspace2, expert_tokens_meta,
apply_router_weight_on_input, extra_expert_args)
apply_router_weight_on_input)

View File

@ -45,7 +45,6 @@ def get_quant_config_weight_quant(
return _get_quant_config_quantization_args(quant_config, "weights")
# TODO (bnell): use scalar_type instead of bools?
def get_config_quant_dtype(
use_fp8_w8a8: bool,
use_int8_w8a8: bool,
@ -65,7 +64,8 @@ def get_config_quant_dtype(
@dataclass
class FusedMoEQuantConfig:
# The post quantization activation type.
quant_dtype: Optional[torch.dtype] = None
# TODO (bnell): use scalar_type instead of Union.
quant_dtype: Union[torch.dtype, str, None] = None
per_act_token_quant: bool = False
per_out_ch_quant: bool = False
block_shape: Optional[list[int]] = None
@ -141,6 +141,7 @@ class FusedMoEQuantConfig:
use_int8_w8a8,
use_int8_w8a16,
use_int4_w4a16,
use_mxfp4_w4a4,
]
]) <= 1, "Quantization flags are mutually exclusive."
@ -334,7 +335,7 @@ class FusedMoEConfig:
assert self.max_num_tokens > 0
@property
def quant_dtype(self) -> Optional[torch.dtype]:
def quant_dtype(self) -> Union[torch.dtype, str, None]:
if self.quant_config is not None:
return self.quant_config.quant_dtype
else:
@ -429,7 +430,7 @@ class FusedMoEConfig:
block_shape = None
per_act_token_quant = False
per_out_ch_quant = False
quant_dtype: Optional[torch.dtype] = None
quant_dtype: Union[torch.dtype, str, None] = None
input_quant = get_quant_config_input_quant(quant_config)
weight_quant = get_quant_config_weight_quant(quant_config)
@ -453,7 +454,7 @@ class FusedMoEConfig:
ModelOptNvFp4Config)
if quant_dtype is None and isinstance(quant_config,
ModelOptNvFp4Config):
quant_dtype = torch.uint8
quant_dtype = "nvfp4"
if weight_quant is not None:
per_out_ch_quant = (

View File

@ -1,7 +1,7 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
""" CUTLASS based Fused MoE kernels."""
from typing import Any, Callable, Optional
from typing import Callable, Optional
import torch
@ -12,11 +12,10 @@ from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig
from vllm.model_executor.layers.fused_moe.prepare_finalize import (
MoEPrepareAndFinalizeNoEP)
from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import (
TopKWeightAndReduceDelegate)
TopKWeightAndReduceDelegate, TopKWeightAndReduceNoOP)
from vllm.model_executor.layers.fused_moe.utils import (_fp8_perm,
_fp8_quantize,
_resize_cache,
extract_required_args)
_resize_cache)
from vllm.scalar_type import scalar_types
logger = init_logger(__name__)
@ -213,19 +212,14 @@ def run_cutlass_moe_fp8(
output.copy_(c3[c_map].view(M * topk, K), non_blocking=True)
# TODO (bnell): split class batched vs. non-batched?
# maybe remove need for passing aq to workspace_shapes
class CutlassExpertsFp8(mk.FusedMoEPermuteExpertsUnpermute):
class CutlassExpertsFp8Base(mk.FusedMoEPermuteExpertsUnpermute):
def __init__(
self,
max_experts_per_worker: int,
out_dtype: Optional[torch.dtype],
per_act_token_quant: bool,
per_out_ch_quant: bool,
block_shape: Optional[list[int]] = None,
num_dispatchers: Optional[int] = None,
use_batched_format: bool = False,
):
super().__init__(
FusedMoEQuantConfig(
@ -234,33 +228,84 @@ class CutlassExpertsFp8(mk.FusedMoEPermuteExpertsUnpermute):
per_out_ch_quant=per_out_ch_quant,
block_shape=block_shape,
))
assert max_experts_per_worker > 0
assert not use_batched_format or num_dispatchers is not None
self.max_experts_per_worker = max_experts_per_worker
self.num_dispatchers = num_dispatchers
self.out_dtype = out_dtype
self.use_batched_format = use_batched_format
def finalize_weight_and_reduce_impl(self) -> mk.TopKWeightAndReduce:
# Let PrepareAndFinalize::finalize() decide the impl.
return TopKWeightAndReduceDelegate()
def apply(
self,
output: torch.Tensor,
hidden_states: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
activation: str,
global_num_experts: int,
expert_map: Optional[torch.Tensor],
w1_scale: Optional[torch.Tensor],
w2_scale: Optional[torch.Tensor],
w1_zp: Optional[torch.Tensor],
w2_zp: Optional[torch.Tensor],
a1q_scale: Optional[torch.Tensor],
a2_scale: Optional[torch.Tensor],
workspace13: torch.Tensor,
workspace2: torch.Tensor,
expert_tokens_meta: Optional[mk.ExpertTokensMetadata],
apply_router_weight_on_input: bool,
):
assert w1_zp is None, "w1_zp is not supported in CUTLASS MoE"
assert w2_zp is None, "w2_zp is not supported in CUTLASS MoE"
expert_num_tokens = None
if expert_tokens_meta is not None:
expert_num_tokens = expert_tokens_meta.expert_num_tokens
activation_callable = lambda o, i: self.activation(activation, o, i)
use_batched_format = self.activation_formats[
0] == mk.FusedMoEActivationFormat.BatchedExperts
in_dtype = hidden_states.dtype
run_cutlass_moe_fp8(
output, hidden_states, w1, w2, topk_ids, activation_callable,
global_num_experts, expert_map, w1_scale, w2_scale, a1q_scale,
a2_scale, workspace13, workspace2, expert_num_tokens,
self.out_dtype if self.out_dtype is not None else in_dtype,
self.per_act_token_quant, self.per_out_ch_quant,
use_batched_format)
class CutlassExpertsFp8(CutlassExpertsFp8Base):
def __init__(
self,
out_dtype: Optional[torch.dtype],
per_act_token_quant: bool,
per_out_ch_quant: bool,
block_shape: Optional[list[int]] = None,
):
super().__init__(
out_dtype,
per_act_token_quant,
per_out_ch_quant,
block_shape,
)
@property
def activation_formats(
self
) -> tuple[mk.FusedMoEActivationFormat, mk.FusedMoEActivationFormat]:
if self.use_batched_format:
return (mk.FusedMoEActivationFormat.BatchedExperts,
mk.FusedMoEActivationFormat.BatchedExperts)
else:
return (mk.FusedMoEActivationFormat.Standard,
mk.FusedMoEActivationFormat.Standard)
def supports_chunking(self) -> bool:
return not self.use_batched_format
return True
def supports_expert_map(self) -> bool:
return not self.use_batched_format
def finalize_weight_and_reduce_impl(self) -> mk.TopKWeightAndReduce:
# Let PrepareAndFinalize::finalize() decide the impl.
return TopKWeightAndReduceDelegate()
return True
def workspace_shapes(
self,
@ -274,54 +319,69 @@ class CutlassExpertsFp8(mk.FusedMoEPermuteExpertsUnpermute):
local_num_experts: int,
expert_tokens_meta: Optional[mk.ExpertTokensMetadata],
) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...], torch.dtype]:
workspace1: tuple[int, ...] = ()
workspace2: tuple[int, ...] = ()
output: tuple[int, ...] = ()
if self.use_batched_format:
padded_M = aq.size(1)
num_dp = self.num_dispatchers
assert num_dp is not None
workspace1 = (self.max_experts_per_worker, padded_M * num_dp,
max(N, K))
workspace2 = (self.max_experts_per_worker, padded_M * num_dp,
(N // 2))
output = (self.max_experts_per_worker, padded_M, K)
else:
workspace1 = (M * topk, max(N, K))
workspace2 = (M * topk, N // 2)
output = (M * topk, K)
return (workspace1, workspace2, output,
self.out_dtype if self.out_dtype is not None else a.dtype)
def apply(self, output: torch.Tensor, hidden_states: torch.Tensor,
w1: torch.Tensor, w2: torch.Tensor, topk_weights: torch.Tensor,
topk_ids: torch.Tensor, activation: str, global_num_experts: int,
expert_map: Optional[torch.Tensor],
w1_scale: Optional[torch.Tensor],
w2_scale: Optional[torch.Tensor], w1_zp: Optional[torch.Tensor],
w2_zp: Optional[torch.Tensor], a1q_scale: Optional[torch.Tensor],
a2_scale: Optional[torch.Tensor], workspace13: torch.Tensor,
workspace2: torch.Tensor,
class CutlassBatchedExpertsFp8(CutlassExpertsFp8Base):
def __init__(
self,
max_experts_per_worker: int,
num_dispatchers: int,
out_dtype: Optional[torch.dtype],
per_act_token_quant: bool,
per_out_ch_quant: bool,
block_shape: Optional[list[int]] = None,
):
super().__init__(
out_dtype,
per_act_token_quant,
per_out_ch_quant,
block_shape,
)
assert max_experts_per_worker > 0
self.max_experts_per_worker = max_experts_per_worker
self.num_dispatchers = num_dispatchers
@property
def activation_formats(
self
) -> tuple[mk.FusedMoEActivationFormat, mk.FusedMoEActivationFormat]:
return (mk.FusedMoEActivationFormat.BatchedExperts,
mk.FusedMoEActivationFormat.BatchedExperts)
def supports_chunking(self) -> bool:
return False
def supports_expert_map(self) -> bool:
return False
# TODO(bnell): maybe remove need for passing aq to workspace_shapes
def workspace_shapes(
self,
a: torch.Tensor,
aq: torch.Tensor,
M: int,
N: int,
K: int,
topk: int,
global_num_experts: int,
local_num_experts: int,
expert_tokens_meta: Optional[mk.ExpertTokensMetadata],
apply_router_weight_on_input: bool,
extra_expert_args: Optional[dict[str, Any]]):
assert w1_zp is None, "w1_zp is not supported in CUTLASS MoE"
assert w2_zp is None, "w2_zp is not supported in CUTLASS MoE"
expert_num_tokens = None
if expert_tokens_meta is not None:
expert_num_tokens = expert_tokens_meta.expert_num_tokens
activation_callable = lambda o, i: self.activation(activation, o, i)
in_dtype = hidden_states.dtype
run_cutlass_moe_fp8(
output, hidden_states, w1, w2, topk_ids, activation_callable,
global_num_experts, expert_map, w1_scale, w2_scale, a1q_scale,
a2_scale, workspace13, workspace2, expert_num_tokens,
self.out_dtype if self.out_dtype is not None else in_dtype,
self.per_act_token_quant, self.per_out_ch_quant,
self.use_batched_format)
) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...], torch.dtype]:
padded_M = aq.size(1)
num_dp = self.num_dispatchers
assert num_dp is not None
workspace1 = (self.max_experts_per_worker, padded_M * num_dp,
max(N, K))
workspace2 = (self.max_experts_per_worker, padded_M * num_dp, (N // 2))
output = (self.max_experts_per_worker, padded_M, K)
return (workspace1, workspace2, output,
self.out_dtype if self.out_dtype is not None else a.dtype)
def cutlass_moe_fp8(
@ -387,11 +447,9 @@ def cutlass_moe_fp8(
fn = mk.FusedMoEModularKernel(
MoEPrepareAndFinalizeNoEP(),
CutlassExpertsFp8(
max_experts_per_worker=num_experts,
out_dtype=a.dtype,
per_act_token_quant=per_act_token,
per_out_ch_quant=per_out_ch,
use_batched_format=False,
),
)
@ -476,8 +534,9 @@ def run_cutlass_moe_fp4(
e_w1, nx2_w1, half_k_w1 = w1_fp4.shape
e_w2, k_w2, half_n_w2 = w2_fp4.shape
assert (e_w1 == e_w2 and e_w1 == e), ("Number of experts must match",
" between weights.")
assert (e_w1 == e_w2
and e_w1 == e), ("Number of experts must match",
f" between weights. {e_w1}, {e_w2}, {e}")
assert (k_a == half_k_w1 * 2
and k == k_w2), ("Hidden size mismatch between a, w1 and w2")
assert (nx2_w1 == n * 2 and half_n_w2 * 2 == n), ("mismatch in "
@ -554,6 +613,10 @@ class CutlassExpertsFp4(mk.FusedMoEPermuteExpertsUnpermute):
def __init__(
self,
g1_alphas: torch.Tensor,
g2_alphas: torch.Tensor,
a1_gscale: torch.Tensor,
a2_gscale: torch.Tensor,
max_experts_per_worker: int,
out_dtype: torch.dtype,
per_act_token_quant: bool,
@ -562,8 +625,12 @@ class CutlassExpertsFp4(mk.FusedMoEPermuteExpertsUnpermute):
use_batched_format: bool = False,
):
super().__init__(
# NVFP4 requires two levels of quantization, which involves
# computing some scaling factors dynamically. This makes it
# incompatible with the typical prepare -> MoE -> finalize
# pipeline. Move the quantization logic into the MoE body.
FusedMoEQuantConfig(
quant_dtype=torch.uint8,
quant_dtype=None, # skip quantization in prepare/finalize
per_act_token_quant=per_act_token_quant,
per_out_ch_quant=per_out_ch_quant,
block_shape=block_shape,
@ -572,6 +639,12 @@ class CutlassExpertsFp4(mk.FusedMoEPermuteExpertsUnpermute):
self.out_dtype = out_dtype
self.use_batched_format = use_batched_format
# TODO(bnell): put this stuff into quant config?
self.g1_alphas = g1_alphas
self.g2_alphas = g2_alphas
self.a1_gscale = a1_gscale
self.a2_gscale = a2_gscale
@property
def activation_formats(
self
@ -590,8 +663,7 @@ class CutlassExpertsFp4(mk.FusedMoEPermuteExpertsUnpermute):
return True
def finalize_weight_and_reduce_impl(self) -> mk.TopKWeightAndReduce:
# Let PrepareAndFinalize::finalize() decide the impl.
return TopKWeightAndReduceDelegate()
return TopKWeightAndReduceNoOP()
def workspace_shapes(
self,
@ -620,34 +692,42 @@ class CutlassExpertsFp4(mk.FusedMoEPermuteExpertsUnpermute):
return (workspace1, workspace2, output,
self.out_dtype if self.out_dtype is not None else a.dtype)
def apply(self, output: torch.Tensor, hidden_states: torch.Tensor,
w1: torch.Tensor, w2: torch.Tensor, topk_weights: torch.Tensor,
topk_ids: torch.Tensor, activation: str, global_num_experts: int,
expert_map: Optional[torch.Tensor], w1_scale: torch.Tensor,
w2_scale: torch.Tensor, w1_zp: Optional[torch.Tensor],
w2_zp: Optional[torch.Tensor], a1q_scale: Optional[torch.Tensor],
a2_scale: torch.Tensor, workspace13: Optional[torch.Tensor],
def apply(
self,
output: torch.Tensor,
hidden_states: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
activation: str,
global_num_experts: int,
expert_map: Optional[torch.Tensor],
w1_scale: torch.Tensor,
w2_scale: torch.Tensor,
w1_zp: Optional[torch.Tensor],
w2_zp: Optional[torch.Tensor],
a1q_scale: Optional[torch.Tensor],
a2_scale: torch.Tensor,
workspace13: Optional[torch.Tensor],
workspace2: Optional[torch.Tensor],
expert_tokens_meta: Optional[mk.ExpertTokensMetadata],
apply_router_weight_on_input: bool,
extra_expert_args: Optional[dict[str, Any]]):
required_keys = [
"g1_alphas", "g2_alphas", "a1_gscale", "a2_gscale", "m", "n", "k",
"e", "device"
]
(g1_alphas, g2_alphas, a1_gscale, a2_gscale, m, n, k, e,
device) = extract_required_args(extra_expert_args, required_keys)
):
e, m, n, k, _ = mk._moe_problem_size(hidden_states, w1, w2, topk_ids)
n = w2.shape[2] * 2
run_cutlass_moe_fp4(
output=output,
a=hidden_states,
a1_gscale=a1_gscale,
a1_gscale=self.a1_gscale,
w1_fp4=w1,
w1_blockscale=w1_scale,
w1_alphas=g1_alphas,
a2_gscale=a2_gscale,
w1_alphas=self.g1_alphas,
a2_gscale=self.a2_gscale,
w2_fp4=w2,
w2_blockscale=w2_scale,
w2_alphas=g2_alphas,
w2_alphas=self.g2_alphas,
topk_weights=topk_weights,
topk_ids=topk_ids,
workspace13=workspace13,
@ -656,7 +736,7 @@ class CutlassExpertsFp4(mk.FusedMoEPermuteExpertsUnpermute):
n=n,
k=k,
e=e,
device=device,
device=hidden_states.device,
apply_router_weight_on_input=apply_router_weight_on_input,
)
@ -677,7 +757,6 @@ def cutlass_moe_fp4(
n: int,
k: int,
e: int,
device: torch.device,
expert_map: Optional[torch.Tensor] = None,
apply_router_weight_on_input: bool = False) -> torch.Tensor:
assert expert_map is None, ("Expert Parallelism / expert_map "
@ -686,6 +765,10 @@ def cutlass_moe_fp4(
fn = mk.FusedMoEModularKernel(
MoEPrepareAndFinalizeNoEP(),
CutlassExpertsFp4(
g1_alphas,
g2_alphas,
a1_gscale,
a2_gscale,
max_experts_per_worker=e,
out_dtype=a.dtype,
per_act_token_quant=False,
@ -693,29 +776,7 @@ def cutlass_moe_fp4(
use_batched_format=False,
),
)
extra_expert_args = {
'g1_alphas': g1_alphas,
'g2_alphas': g2_alphas,
'a1_gscale': a1_gscale,
'a2_gscale': a2_gscale,
'm': m,
'n': n,
'k': k,
'e': e,
'device': device,
}
# NVFP4 requires two levels of quantization, which involves computing some
# scaling factors dynamically. This makes it incompatible with the typical
# prepare -> MoE -> finalize pipeline. Move the quantization logic into the
# MoE body.
extra_prepare_args = {
'skip_quant': True,
}
# Similar reason as above.
extra_finalize_args = {
'skip_weight_reduce': True,
}
return fn(
hidden_states=a,
w1=w1_fp4,
@ -731,9 +792,6 @@ def cutlass_moe_fp4(
a1_scale=None,
a2_scale=None,
apply_router_weight_on_input=apply_router_weight_on_input,
extra_expert_args=extra_expert_args,
extra_prepare_args=extra_prepare_args,
extra_finalize_args=extra_finalize_args,
)
@ -824,16 +882,6 @@ def run_cutlass_block_scaled_fused_experts(
k = w1_q.size(1)
n = w2_q.size(1)
expert_offsets = torch.empty((num_experts + 1, ),
dtype=torch.int32,
device="cuda")
problem_sizes1 = torch.empty((num_experts, 3),
dtype=torch.int32,
device="cuda")
problem_sizes2 = torch.empty((num_experts, 3),
dtype=torch.int32,
device="cuda")
topk = topk_ids.size(1)
a_q, a1_scale = _fp8_quantize(a,
@ -842,6 +890,16 @@ def run_cutlass_block_scaled_fused_experts(
block_shape=[128, 128])
device = a_q.device
expert_offsets = torch.empty((num_experts + 1, ),
dtype=torch.int32,
device=device)
problem_sizes1 = torch.empty((num_experts, 3),
dtype=torch.int32,
device=device)
problem_sizes2 = torch.empty((num_experts, 3),
dtype=torch.int32,
device=device)
a_map = torch.empty((topk_ids.numel()), dtype=torch.int32, device=device)
c_map = torch.empty((topk_ids.numel()), dtype=torch.int32, device=device)

View File

@ -1,7 +1,7 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import functools
from typing import Any, Optional
from typing import Optional
import torch
from tqdm import tqdm
@ -230,7 +230,6 @@ class DeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
workspace2: torch.Tensor,
expert_tokens_meta: Optional[mk.ExpertTokensMetadata],
apply_router_weight_on_input: bool,
extra_expert_args: Optional[dict[str, Any]],
):
assert self.block_shape is not None
assert a1q_scale is not None

View File

@ -1,6 +1,6 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Any, Optional
from typing import Optional
import deep_ep
import torch
@ -127,12 +127,16 @@ class DeepEPHTPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
expert_topk_weights)
def prepare(
self, a1: torch.Tensor, a1_scale: Optional[torch.Tensor],
a2_scale: Optional[torch.Tensor], topk_weights: torch.Tensor,
topk_ids: torch.Tensor, num_experts: int,
expert_map: Optional[torch.Tensor], apply_router_weight_on_input: bool,
self,
a1: torch.Tensor,
a1_scale: Optional[torch.Tensor],
a2_scale: Optional[torch.Tensor],
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
num_experts: int,
expert_map: Optional[torch.Tensor],
apply_router_weight_on_input: bool,
quant_config: FusedMoEQuantConfig,
extra_prepare_args: Optional[dict[str, Any]]
) -> tuple[torch.Tensor, Optional[torch.Tensor],
Optional[mk.ExpertTokensMetadata], Optional[torch.Tensor],
Optional[torch.Tensor]]:
@ -187,11 +191,15 @@ class DeepEPHTPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
return (expert_x, expert_x_scale, expert_tokens_meta, expert_topk_ids,
expert_topk_weights)
def finalize(self, output: torch.Tensor, fused_expert_output: torch.Tensor,
topk_weights: torch.Tensor, topk_ids: torch.Tensor,
def finalize(
self,
output: torch.Tensor,
fused_expert_output: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
apply_router_weight_on_input: bool,
weight_and_reduce_impl: mk.TopKWeightAndReduce,
extra_finalize_args: Optional[dict[str, Any]]) -> None:
) -> None:
assert self.handle is not None

View File

@ -1,6 +1,6 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Any, Optional, Union
from typing import Optional, Union
import deep_ep
import torch
@ -77,7 +77,7 @@ class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
a1_scale: Optional[torch.Tensor],
a2_scale: Optional[torch.Tensor],
a1_dtype: torch.dtype,
quant_dtype: Optional[torch.dtype],
quant_dtype: Union[torch.dtype, str, None],
per_act_token_quant: bool,
block_shape: Optional[list[int]],
) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
@ -111,12 +111,16 @@ class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
return x, x_scales
def prepare(
self, a1: torch.Tensor, a1_scale: Optional[torch.Tensor],
a2_scale: Optional[torch.Tensor], topk_weights: torch.Tensor,
topk_ids: torch.Tensor, num_experts: int,
expert_map: Optional[torch.Tensor], apply_router_weight_on_input: bool,
self,
a1: torch.Tensor,
a1_scale: Optional[torch.Tensor],
a2_scale: Optional[torch.Tensor],
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
num_experts: int,
expert_map: Optional[torch.Tensor],
apply_router_weight_on_input: bool,
quant_config: FusedMoEQuantConfig,
extra_prepare_args: Optional[dict[str, Any]]
) -> tuple[torch.Tensor, Optional[torch.Tensor],
Optional[mk.ExpertTokensMetadata], Optional[torch.Tensor],
Optional[torch.Tensor]]:
@ -162,11 +166,15 @@ class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
return (expert_x, expert_x_scale, expert_tokens_meta, None, None)
def finalize(self, output: torch.Tensor, fused_expert_output: torch.Tensor,
topk_weights: torch.Tensor, topk_ids: torch.Tensor,
def finalize(
self,
output: torch.Tensor,
fused_expert_output: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
apply_router_weight_on_input: bool,
weight_and_reduce_impl: mk.TopKWeightAndReduce,
extra_finalize_args: Optional[dict[str, Any]]) -> None:
) -> None:
assert isinstance(
weight_and_reduce_impl, TopKWeightAndReduceDelegate
), ("Weight application and reduction happens in the combine kernel.")

View File

@ -1,6 +1,6 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Any, Optional
from typing import Optional, Union
import torch
@ -8,8 +8,7 @@ import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig
from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import (
TopKWeightAndReduceDelegate)
from vllm.model_executor.layers.fused_moe.utils import extract_required_args
TopKWeightAndReduceNoOP)
from vllm.utils.flashinfer import (flashinfer_cutlass_fused_moe,
has_flashinfer_cutlass_fused_moe)
@ -43,31 +42,34 @@ class FlashInferExperts(mk.FusedMoEPermuteExpertsUnpermute):
def __init__(
self,
use_nvfp4_w4a4: bool = False,
use_fp8_w8a8: bool = False,
use_dp: bool = False,
g1_alphas: torch.Tensor,
g2_alphas: torch.Tensor,
a1_gscale: torch.Tensor,
a2_gscale: torch.Tensor,
out_dtype: torch.dtype,
quant_dtype: Union[torch.dtype, str, None],
ep_rank: int = 0,
ep_size: int = 1,
tp_rank: int = 0,
tp_size: int = 1,
num_dispatchers: Optional[int] = None,
use_batched_format: bool = False,
):
super().__init__(
FusedMoEQuantConfig(
quant_dtype=torch.uint8,
quant_dtype=quant_dtype,
per_act_token_quant=False,
block_shape=None,
))
self.use_nvfp4_w4a4 = use_nvfp4_w4a4
self.use_fp8_w8a8 = use_fp8_w8a8
assert quant_dtype == "nvfp4", ("Only nvfp4 quantization is "
"currently supported.")
self.ep_rank = ep_rank
self.ep_size = ep_size
self.tp_rank = tp_rank
self.tp_size = tp_size
self.use_dp = use_dp
assert not use_batched_format or num_dispatchers is not None
self.num_dispatchers = num_dispatchers
self.g1_alphas = g1_alphas
self.g2_alphas = g2_alphas
self.a1_gscale = a1_gscale
self.a2_gscale = a2_gscale
self.out_dtype = out_dtype
@property
def activation_formats(
@ -84,8 +86,7 @@ class FlashInferExperts(mk.FusedMoEPermuteExpertsUnpermute):
return True
def finalize_weight_and_reduce_impl(self) -> mk.TopKWeightAndReduce:
# Let PrepareAndFinalize::finalize() decide the impl.
return TopKWeightAndReduceDelegate()
return TopKWeightAndReduceNoOP()
def workspace_shapes(
self,
@ -117,8 +118,6 @@ class FlashInferExperts(mk.FusedMoEPermuteExpertsUnpermute):
- Note: in order for activation chunking to work, the first dimension
of each tuple must be the number of tokens.
"""
assert self.use_nvfp4_w4a4 is True, ("Only nvfp4 quantization is "
"currently supported.")
aq_m, aq_n = aq.shape
workspace2 = ()
output_shape = (aq_m, aq_n * 2)
@ -149,21 +148,9 @@ class FlashInferExperts(mk.FusedMoEPermuteExpertsUnpermute):
workspace2: Optional[torch.Tensor],
expert_tokens_meta: Optional[mk.ExpertTokensMetadata],
apply_router_weight_on_input: Optional[bool],
extra_expert_args: Optional[dict[str, Any]],
):
assert extra_expert_args is not None, \
"extra_expert_args must be provided"
required_keys = [
'g1_alphas', 'g2_alphas', 'a1_gscale', 'a2_gscale', 'out_dtype'
]
g1_alphas, g2_alphas, a1_gscale, a2_gscale, out_dtype = (
extract_required_args(extra_expert_args, required_keys))
# Flashinfer CUTLASS kernel takes scalar global scales,
# min because inv_scale.
assert self.use_nvfp4_w4a4 is True, ("Only nvfp4 quantization is "
"currently supported.")
# Ensure w1_scale and w2_scale are not None before calling view
assert w1_scale is not None and w2_scale is not None, (
@ -171,12 +158,12 @@ class FlashInferExperts(mk.FusedMoEPermuteExpertsUnpermute):
"be None for FlashInferExperts")
quant_scales = [
a1_gscale,
self.a1_gscale,
w1_scale.view(torch.int32),
g1_alphas,
a2_gscale,
self.g1_alphas,
self.a2_gscale,
w2_scale.view(torch.int32),
g2_alphas,
self.g2_alphas,
]
_ = flashinfer_cutlass_fused_moe(
input=hidden_states,
@ -185,7 +172,7 @@ class FlashInferExperts(mk.FusedMoEPermuteExpertsUnpermute):
# FlashInfer API requires weight to be long for nvfp4
fc1_expert_weights=w1.view(torch.long),
fc2_expert_weights=w2.view(torch.long),
output_dtype=out_dtype,
output_dtype=self.out_dtype,
quant_scales=quant_scales,
input_sf=a1q_scale,
tp_size=self.tp_size,

View File

@ -1,6 +1,6 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Any, Optional
from typing import Optional
import torch
@ -9,7 +9,7 @@ from vllm.distributed import get_dp_group
from vllm.forward_context import get_forward_context
from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig
from vllm.model_executor.layers.fused_moe.utils import (
extract_required_args, moe_kernel_quantize_input)
moe_kernel_quantize_input)
from vllm.utils.flashinfer import nvfp4_block_scale_interleave
@ -21,16 +21,15 @@ class FlashInferCutlassMoEPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
def __init__(
self,
quant_dtype: Optional[torch.dtype] = None,
per_channel_quant: bool = False,
block_shape: Optional[list[int]] = None,
use_dp: bool,
a1_gscale: Optional[torch.Tensor],
num_dispatchers: int = 1,
):
super().__init__()
self.per_channel_quant = per_channel_quant
self.block_shape = block_shape
self.quant_dtype = quant_dtype
self.num_dispatchers_ = num_dispatchers
self.use_dp = use_dp
self.a1_gscale = a1_gscale
self.local_tokens = None
@property
def activation_format(self) -> mk.FusedMoEActivationFormat:
@ -55,10 +54,11 @@ class FlashInferCutlassMoEPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
num_experts: int,
expert_map: Optional[torch.Tensor],
apply_router_weight_on_input: bool,
# TODO(bnell): use quant_config + scales instead of ctor args
quant_config: FusedMoEQuantConfig,
extra_prepare_args: Optional[dict[str, Any]]
) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor],
Optional[torch.Tensor], Optional[torch.Tensor]]:
) -> tuple[torch.Tensor, Optional[torch.Tensor],
Optional[mk.ExpertTokensMetadata], Optional[torch.Tensor],
Optional[torch.Tensor]]:
if apply_router_weight_on_input:
topk = topk_ids.size(1)
@ -67,22 +67,22 @@ class FlashInferCutlassMoEPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
"apply_router_weight_on_input is only implemented for topk=1"
a1.mul_(topk_weights.to(a1.dtype))
(a1_gscale, use_dp, local_tokens) = extract_required_args(
extra_prepare_args, ['a1_gscale', 'use_dp', 'local_tokens'])
a1q, a1q_scale = moe_kernel_quantize_input(
a1,
a1_gscale,
self.a1_gscale,
quant_config.quant_dtype,
self.per_channel_quant,
self.block_shape,
is_fp4_scale_swizzled=not use_dp, # Swizzling after communication
quant_config.per_act_token_quant,
quant_config.block_shape,
# Swizzling after communication
is_fp4_scale_swizzled=not self.use_dp,
)
if use_dp:
if self.use_dp:
topk_weights, topk_ids, a1q, a1q_scale = \
get_dp_group().all_gatherv([topk_weights, topk_ids, a1q, a1q_scale], # noqa: E501
get_dp_group().all_gatherv(
[topk_weights, topk_ids, a1q, a1q_scale],
dim=0,
sizes=get_local_sizes())
sizes=get_local_sizes(),
)
a1_m, a1_n = a1q.shape
a1q_scale = nvfp4_block_scale_interleave(a1q_scale)
@ -91,13 +91,9 @@ class FlashInferCutlassMoEPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
def finalize(self, output: torch.Tensor, fused_expert_output: torch.Tensor,
topk_weights: torch.Tensor, topk_ids: torch.Tensor,
apply_router_weight_on_input: bool,
weight_and_reduce_impl: mk.TopKWeightAndReduce,
extra_finalize_args: Optional[dict[str, Any]]) -> None:
weight_and_reduce_impl: mk.TopKWeightAndReduce) -> None:
(use_dp,
local_tokens) = extract_required_args(extra_finalize_args,
['use_dp', 'local_tokens'])
if use_dp:
if self.use_dp:
fused_expert_output = get_dp_group().reduce_scatterv(
fused_expert_output, dim=0, sizes=get_local_sizes())
output.copy_(fused_expert_output)

View File

@ -1,7 +1,7 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Fused batched MoE kernel."""
from typing import Any, Optional
from typing import Optional
import torch
@ -496,12 +496,16 @@ class BatchedPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
return self.num_dispatchers_
def prepare(
self, a1: torch.Tensor, a1_scale: Optional[torch.Tensor],
a2_scale: Optional[torch.Tensor], topk_weights: torch.Tensor,
topk_ids: torch.Tensor, num_experts: int,
expert_map: Optional[torch.Tensor], apply_router_weight_on_input: bool,
self,
a1: torch.Tensor,
a1_scale: Optional[torch.Tensor],
a2_scale: Optional[torch.Tensor],
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
num_experts: int,
expert_map: Optional[torch.Tensor],
apply_router_weight_on_input: bool,
quant_config: FusedMoEQuantConfig,
extra_prepare_args: Optional[dict[str, Any]]
) -> tuple[torch.Tensor, Optional[torch.Tensor],
Optional[mk.ExpertTokensMetadata], Optional[torch.Tensor],
Optional[torch.Tensor]]:
@ -590,11 +594,15 @@ class BatchedPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
return b_a1, b_a1_scale, expert_tokens_meta, None, None
def finalize(self, output: torch.Tensor, fused_expert_output: torch.Tensor,
topk_weights: torch.Tensor, topk_ids: torch.Tensor,
def finalize(
self,
output: torch.Tensor,
fused_expert_output: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
apply_router_weight_on_input: bool,
weight_and_reduce_impl: mk.TopKWeightAndReduce,
extra_finalize_args: Optional[dict[str, Any]]) -> None:
) -> None:
if isinstance(weight_and_reduce_impl, TopKWeightAndReduceDelegate):
weight_and_reduce_impl = TopKWeightAndReduceNaiveBatched(self.rank)
weight_and_reduce_impl.apply(
@ -688,18 +696,28 @@ class NaiveBatchedExperts(mk.FusedMoEPermuteExpertsUnpermute):
else:
return t.to(f32) * group_broadcast(scale, t.shape)
def apply(self, output: torch.Tensor, hidden_states: torch.Tensor,
w1: torch.Tensor, w2: torch.Tensor, topk_weights: torch.Tensor,
topk_ids: torch.Tensor, activation: str, global_num_experts: int,
def apply(
self,
output: torch.Tensor,
hidden_states: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
activation: str,
global_num_experts: int,
expert_map: Optional[torch.Tensor],
w1_scale: Optional[torch.Tensor],
w2_scale: Optional[torch.Tensor], w1_zp: Optional[torch.Tensor],
w2_zp: Optional[torch.Tensor], a1q_scale: Optional[torch.Tensor],
a2_scale: Optional[torch.Tensor], workspace13: torch.Tensor,
w2_scale: Optional[torch.Tensor],
w1_zp: Optional[torch.Tensor],
w2_zp: Optional[torch.Tensor],
a1q_scale: Optional[torch.Tensor],
a2_scale: Optional[torch.Tensor],
workspace13: torch.Tensor,
workspace2: torch.Tensor,
expert_tokens_meta: Optional[mk.ExpertTokensMetadata],
apply_router_weight_on_input: bool,
extra_expert_args: Optional[dict[str, Any]]):
):
assert hidden_states.dim() == 3
assert expert_tokens_meta is not None
expert_num_tokens = expert_tokens_meta.expert_num_tokens
@ -894,18 +912,28 @@ class BatchedTritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
output = (num_experts, max_num_tokens * num_dp, K)
return (workspace13, workspace2, output, a.dtype)
def apply(self, output: torch.Tensor, hidden_states: torch.Tensor,
w1: torch.Tensor, w2: torch.Tensor, topk_weights: torch.Tensor,
topk_ids: torch.Tensor, activation: str, global_num_experts: int,
def apply(
self,
output: torch.Tensor,
hidden_states: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
activation: str,
global_num_experts: int,
expert_map: Optional[torch.Tensor],
w1_scale: Optional[torch.Tensor],
w2_scale: Optional[torch.Tensor], w1_zp: Optional[torch.Tensor],
w2_zp: Optional[torch.Tensor], a1q_scale: Optional[torch.Tensor],
a2_scale: Optional[torch.Tensor], workspace13: torch.Tensor,
w2_scale: Optional[torch.Tensor],
w1_zp: Optional[torch.Tensor],
w2_zp: Optional[torch.Tensor],
a1q_scale: Optional[torch.Tensor],
a2_scale: Optional[torch.Tensor],
workspace13: torch.Tensor,
workspace2: torch.Tensor,
expert_tokens_meta: Optional[mk.ExpertTokensMetadata],
apply_router_weight_on_input: bool,
extra_expert_args: Optional[dict[str, Any]]):
):
# Check constraints.
if self.use_int4_w4a16:
assert hidden_states.size(-1) // 2 == w1.size(2), (

View File

@ -1394,9 +1394,9 @@ def fused_experts(hidden_states: torch.Tensor,
# E8M0 scale, which means we requantize the weight and input to the specific
# scale. Fallen back to cutlass or triton for some cases would cause
# accuracy issue.
should_use_deep_gemm = is_blackwell_deep_gemm_e8m0_used(
) or _valid_deep_gemm(hidden_states, w1, w2)
if (allow_deep_gemm and use_fp8_w8a8 and should_use_deep_gemm):
if (allow_deep_gemm and use_fp8_w8a8
and (is_blackwell_deep_gemm_e8m0_used()
or _valid_deep_gemm(hidden_states, w1, w2))):
assert apply_router_weight_on_input is False
assert is_act_and_mul, (
"DeepGemm only supports is_act_and_mul=True for now.")
@ -1905,7 +1905,6 @@ class TritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
workspace2: torch.Tensor,
expert_tokens_meta: Optional[mk.ExpertTokensMetadata],
apply_router_weight_on_input: bool,
extra_expert_args: Optional[dict[str, Any]],
):
# Check constraints.
if self.use_int4_w4a16:

View File

@ -1,6 +1,6 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import TYPE_CHECKING, Any, Optional
from typing import TYPE_CHECKING, Optional
import torch
@ -8,7 +8,6 @@ import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import (
TopKWeightAndReduceDelegate)
from vllm.model_executor.layers.fused_moe.utils import extract_required_args
from vllm.utils import has_triton_kernels
logger = init_logger(__name__)
@ -160,12 +159,16 @@ class BatchedOAITritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
num_dispatchers: int,
w1_precision: "PrecisionConfig",
w2_precision: "PrecisionConfig",
w1_bias: Optional[torch.Tensor],
w2_bias: Optional[torch.Tensor],
):
super().__init__(quant_config)
self.max_num_tokens = max_num_tokens
self.num_dispatchers = num_dispatchers
self.w1_precision = w1_precision
self.w2_precision = w2_precision
self.w1_bias = w1_bias
self.w2_bias = w2_bias
@property
def activation_formats(
@ -219,11 +222,7 @@ class BatchedOAITritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
workspace2: torch.Tensor,
expert_tokens_meta: Optional[mk.ExpertTokensMetadata],
apply_router_weight_on_input: bool,
extra_expert_args: Optional[dict[str, Any]],
):
w1_bias, w2_bias = (extract_required_args(extra_expert_args,
["w1_bias", "w2_bias"]))
return triton_kernel_fused_experts(
output,
hidden_states,
@ -240,8 +239,8 @@ class BatchedOAITritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
expert_map=expert_map,
w1_scale=w1_scale,
w2_scale=w2_scale,
w1_bias=w1_bias,
w2_bias=w2_bias,
w1_bias=self.w1_bias,
w2_bias=self.w2_bias,
w1_precision=self.w1_precision,
w2_precision=self.w2_precision,
a1_scale=a1q_scale,

View File

@ -37,7 +37,6 @@ from vllm.platforms import current_platform
from vllm.platforms.interface import CpuArchEnum
from vllm.utils import (direct_register_custom_op, has_deep_ep, has_pplx,
round_up)
from vllm.utils.flashinfer import has_flashinfer
if current_platform.is_cuda_alike():
from .fused_batched_moe import BatchedTritonExperts
@ -49,9 +48,6 @@ if current_platform.is_cuda_alike():
from .deepep_ht_prepare_finalize import DeepEPHTPrepareAndFinalize
from .deepep_ll_prepare_finalize import (DEEPEP_QUANT_BLOCK_SHAPE,
DeepEPLLPrepareAndFinalize)
if has_flashinfer():
from .flashinfer_cutlass_prepare_finalize import (
FlashInferCutlassMoEPrepareAndFinalize)
else:
fused_experts = None # type: ignore
FusedMoEPermuteExpertsUnpermute = None # type: ignore
@ -80,7 +76,12 @@ class FusedMoeWeightScaleSupported(Enum):
class FusedMoEMethodBase(QuantizeMethodBase):
moe: FusedMoEConfig
# TODO(bnell): also pass quant_config?
def __init__(self, moe: FusedMoEConfig):
super().__init__()
self.moe = moe
self.fused_experts: Optional[Callable] = None
self.topk_indices_dtype = None
@abstractmethod
def create_weights(self, layer: torch.nn.Module, num_experts: int,
@ -99,16 +100,16 @@ class FusedMoEMethodBase(QuantizeMethodBase):
return False
@staticmethod
def maybe_make_prepare_finalize(
moe: FusedMoEConfig) -> Optional[FusedMoEPrepareAndFinalize]:
def _maybe_make_prepare_finalize(
moe: FusedMoEConfig, ) -> Optional[FusedMoEPrepareAndFinalize]:
all2all_manager = get_ep_group().device_communicator.all2all_manager
assert all2all_manager is not None
prepare_finalize: Optional[FusedMoEPrepareAndFinalize] = None
if moe.use_flashinfer_cutlass_kernels:
prepare_finalize = FlashInferCutlassMoEPrepareAndFinalize(
quant_dtype=moe.quant_dtype, )
assert not moe.use_flashinfer_cutlass_kernels, \
"Must be created in modelopt.py"
if moe.use_pplx_kernels:
hidden_dim_bytes, hidden_scale_bytes = pplx_hidden_dim_scale_bytes(
moe.max_num_tokens,
@ -188,14 +189,25 @@ class FusedMoEMethodBase(QuantizeMethodBase):
return prepare_finalize
def init_prepare_finalize(self, moe: FusedMoEConfig):
self.moe = moe
prepare_finalize = FusedMoEMethodBase.maybe_make_prepare_finalize(
self.moe)
def maybe_make_prepare_finalize(
self,
moe: FusedMoEConfig,
) -> Optional[FusedMoEPrepareAndFinalize]:
if moe.moe_parallel_config.use_all2all_kernels:
return FusedMoEMethodBase._maybe_make_prepare_finalize(moe)
else:
return None
def init_prepare_finalize(self):
assert self.moe is not None
prepare_finalize = self.maybe_make_prepare_finalize(self.moe)
self.topk_indices_dtype = None
if prepare_finalize is not None:
logger.debug("%s", prepare_finalize.__class__.__name__)
logger.debug("%s for %s(%s)", prepare_finalize.__class__.__name__,
self, id(self))
assert self.topk_indices_dtype is None
assert self.fused_experts is None, \
f"Attempt to override experts for {id(self)}!"
self.topk_indices_dtype = prepare_finalize.topk_indices_dtype()
experts = self.select_gemm_impl(prepare_finalize, self.moe)
self.fused_experts = FusedMoEModularKernel(
@ -214,12 +226,6 @@ class FusedMoEMethodBase(QuantizeMethodBase):
f"{self.__class__.__name__} must select appropriate gemm "
"implementation based on the prepare_finalize")
def maybe_swap_experts_impl(
self,
moe_parallel_config: FusedMoEParallelConfig,
):
pass
@abstractmethod
def apply(
self,
@ -251,10 +257,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
"""MoE method without quantization."""
def __init__(self, moe: FusedMoEConfig):
super().__init__()
self.fused_experts = fused_experts # type: ignore
self.topk_indices_dtype = None
self.moe = moe
super().__init__(moe)
self.has_bias = self.moe.has_bias
self.rocm_aiter_moe_enabled = is_rocm_aiter_moe_enabled()
if self.rocm_aiter_moe_enabled:
@ -266,6 +269,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
def select_gemm_impl(
self,
prepare_finalize: FusedMoEPrepareAndFinalize,
# TODO(bnell): Remove. Every layer should have an moe config object.
moe: FusedMoEConfig,
) -> FusedMoEPermuteExpertsUnpermute:
if (prepare_finalize.activation_format ==
@ -474,9 +478,11 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
expert_map=expert_map,
activation=activation,
apply_router_weight_on_input=apply_router_weight_on_input)
else:
# add w1_bias/w2_bias to kwargs if they exist
kwargs = dict(
elif self.fused_experts is not None:
if self.has_bias:
raise ValueError(
"FusedMoEModularKernel does not support bias.")
return self.fused_experts(
hidden_states=x,
w1=layer.w13_weight,
w2=layer.w2_weight,
@ -488,17 +494,22 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
global_num_experts=global_num_experts,
expert_map=expert_map,
)
if isinstance(self.fused_experts,
FusedMoEModularKernel) and self.has_bias:
raise ValueError(
"FusedMoEModularKernel does not support bias.")
if self.has_bias:
kwargs.update({
"w1_bias": getattr(layer, "w13_bias", None),
"w2_bias": getattr(layer, "w2_bias", None),
})
return self.fused_experts(**kwargs)
else:
assert fused_experts is not None
return fused_experts(
hidden_states=x,
w1=layer.w13_weight,
w2=layer.w2_weight,
w1_bias=layer.w13_bias if self.has_bias else None,
w2_bias=layer.w2_bias if self.has_bias else None,
topk_weights=topk_weights,
topk_ids=topk_ids,
inplace=True,
activation=activation,
apply_router_weight_on_input=apply_router_weight_on_input,
global_num_experts=global_num_experts,
expert_map=expert_map,
)
def forward_cpu(
self,
@ -868,8 +879,6 @@ class FusedMoE(CustomOp):
moe_quant_params["intermediate_size_full"] = intermediate_size
self.quant_method.create_weights(layer=self, **moe_quant_params)
if isinstance(self.quant_method, FusedMoEMethodBase):
self.quant_method.maybe_swap_experts_impl(self.moe_parallel_config)
# Chunked all2all staging tensor
self.batched_hidden_states: Optional[torch.Tensor] = None

View File

@ -4,7 +4,7 @@ from abc import ABC, abstractmethod
from dataclasses import dataclass
from enum import Enum
from math import prod
from typing import Any, Optional, final
from typing import Optional, final
import torch
@ -150,15 +150,23 @@ class FusedMoEPrepareAndFinalize(ABC):
@abstractmethod
def prepare(
self, a1: torch.Tensor, a1_scale: Optional[torch.Tensor],
a2_scale: Optional[torch.Tensor], topk_weights: torch.Tensor,
topk_ids: torch.Tensor, num_experts: int,
expert_map: Optional[torch.Tensor], apply_router_weight_on_input: bool,
self,
a1: torch.Tensor,
a1_scale: Optional[torch.Tensor],
a2_scale: Optional[torch.Tensor],
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
num_experts: int,
expert_map: Optional[torch.Tensor],
apply_router_weight_on_input: bool,
quant_config: FusedMoEQuantConfig,
extra_prepare_args: Optional[dict[str, Any]]
) -> tuple[torch.Tensor, Optional[torch.Tensor],
Optional[ExpertTokensMetadata], Optional[torch.Tensor],
Optional[torch.Tensor]]:
) -> tuple[
torch.Tensor,
Optional[torch.Tensor],
Optional[ExpertTokensMetadata],
Optional[torch.Tensor],
Optional[torch.Tensor],
]:
"""
Perform any quantization (and/or) dispatching needed
for this kernel.
@ -186,11 +194,15 @@ class FusedMoEPrepareAndFinalize(ABC):
raise NotImplementedError
@abstractmethod
def finalize(self, output: torch.Tensor, fused_expert_output: torch.Tensor,
topk_weights: torch.Tensor, topk_ids: torch.Tensor,
def finalize(
self,
output: torch.Tensor,
fused_expert_output: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
apply_router_weight_on_input: bool,
weight_and_reduce_impl: TopKWeightAndReduce,
extra_finalize_args: Optional[dict[str, Any]]) -> None:
) -> None:
"""
Perform any combine plus apply weights and perform a reduction on the
fused experts output.
@ -368,7 +380,6 @@ class FusedMoEPermuteExpertsUnpermute(ABC):
workspace2: torch.Tensor,
expert_tokens_meta: Optional[ExpertTokensMetadata],
apply_router_weight_on_input: bool,
extra_expert_args: Optional[dict[str, Any]],
):
"""
This function computes the intermediate result of a Mixture of Experts
@ -454,18 +465,27 @@ class FusedMoEModularKernel(torch.nn.Module):
f"{fused_experts.activation_formats[0]}")
def _do_fused_experts(
self, fused_out: Optional[torch.Tensor], a1: torch.Tensor,
a1q: torch.Tensor, w1: torch.Tensor, w2: torch.Tensor,
topk_weights: torch.Tensor, topk_ids: torch.Tensor,
activation: str, global_num_experts: int, local_num_experts: int,
self,
fused_out: Optional[torch.Tensor],
a1: torch.Tensor,
a1q: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
activation: str,
global_num_experts: int,
local_num_experts: int,
expert_map: Optional[torch.Tensor],
w1_scale: Optional[torch.Tensor], w2_scale: Optional[torch.Tensor],
w1_zp: Optional[torch.Tensor], w2_zp: Optional[torch.Tensor],
w1_scale: Optional[torch.Tensor],
w2_scale: Optional[torch.Tensor],
w1_zp: Optional[torch.Tensor],
w2_zp: Optional[torch.Tensor],
a1q_scale: Optional[torch.Tensor],
a2_scale: Optional[torch.Tensor],
expert_tokens_meta: Optional[ExpertTokensMetadata],
apply_router_weight_on_input: bool,
extra_expert_args: Optional[dict[str, Any]]) -> torch.Tensor:
) -> torch.Tensor:
_, M, N, K, top_k = _moe_problem_size(a1q, w1, w2, topk_ids)
@ -509,7 +529,7 @@ class FusedMoEModularKernel(torch.nn.Module):
workspace2=workspace2,
expert_tokens_meta=expert_tokens_meta,
apply_router_weight_on_input=apply_router_weight_on_input,
extra_expert_args=extra_expert_args)
)
return fused_out
@ -533,7 +553,6 @@ class FusedMoEModularKernel(torch.nn.Module):
a2_scale: Optional[torch.Tensor],
expert_tokens_meta: Optional[ExpertTokensMetadata],
apply_router_weight_on_input: bool,
extra_expert_args: Optional[dict[str, Any]],
) -> torch.Tensor:
_, M, N, K, top_k = _moe_problem_size(a1q, w1, w2, topk_ids)
@ -541,6 +560,9 @@ class FusedMoEModularKernel(torch.nn.Module):
CHUNK_SIZE = envs.VLLM_FUSED_MOE_CHUNK_SIZE
num_chunks = cdiv(M, CHUNK_SIZE)
# TODO(bnell): get rid of one level here, update slice functions
# to nops on num_chunks==1
if not self.fused_experts.supports_chunking() or num_chunks == 1:
return self._do_fused_experts(
fused_out=None,
@ -562,7 +584,7 @@ class FusedMoEModularKernel(torch.nn.Module):
a2_scale=a2_scale,
expert_tokens_meta=expert_tokens_meta,
apply_router_weight_on_input=apply_router_weight_on_input,
extra_expert_args=extra_expert_args)
)
# Chunking required case
assert num_chunks > 1
@ -618,15 +640,6 @@ class FusedMoEModularKernel(torch.nn.Module):
expert_num_tokens=c_expert_num_tokens,
expert_num_tokens_cpu=c_expert_num_tokens_cpu)
m = None
if extra_expert_args is not None and 'm' in extra_expert_args:
m = extra_expert_args.get('m')
if extra_expert_args is not None:
chunked_extra_expert_args = extra_expert_args
else:
chunked_extra_expert_args = {}
for chunk_idx in range(num_chunks):
c_a1q, c_a1q_scale, c_a2_scale, c_topk_ids, c_topk_weights = (
slice_input_tensors(chunk_idx))
@ -637,11 +650,6 @@ class FusedMoEModularKernel(torch.nn.Module):
expert_tokens_meta, c_topk_ids, local_num_experts,
expert_map)
s = chunk_idx * CHUNK_SIZE
e = min(s + CHUNK_SIZE, M)
if m is not None:
chunked_extra_expert_args['m'] = e - s
self._do_fused_experts(
fused_out=slice_output_tensor(chunk_idx),
a1=a1,
@ -662,7 +670,7 @@ class FusedMoEModularKernel(torch.nn.Module):
a2_scale=c_a2_scale,
expert_tokens_meta=c_expert_tokens_meta,
apply_router_weight_on_input=apply_router_weight_on_input,
extra_expert_args=chunked_extra_expert_args)
)
return fused_out
@ -684,9 +692,6 @@ class FusedMoEModularKernel(torch.nn.Module):
a1_scale: Optional[torch.Tensor] = None,
a2_scale: Optional[torch.Tensor] = None,
apply_router_weight_on_input: bool = False,
extra_expert_args: Optional[dict] = None,
extra_prepare_args: Optional[dict] = None,
extra_finalize_args: Optional[dict] = None,
) -> torch.Tensor:
"""
This function computes a Mixture of Experts (MoE) layer using two sets
@ -719,12 +724,6 @@ class FusedMoEModularKernel(torch.nn.Module):
- apply_router_weight_on_input (bool): When true, the topk weights are
applied directly on the inputs. This is only applicable when topk is
1.
- extra_expert_args (Optional[dict]): Extra keyword arguments to pass to
fused_experts.apply.
- extra_prepare_args (Optional[dict]): Extra keyword arguments to pass
to prepare.
- extra_finalize_args (Optional[dict]): Extra keyword arguments to pass
to finalize.
Returns:
- torch.Tensor: The output tensor after applying the MoE layer.
@ -748,7 +747,6 @@ class FusedMoEModularKernel(torch.nn.Module):
expert_map,
apply_router_weight_on_input,
self.fused_experts.quant_config,
extra_prepare_args,
)
# Maybe prepare gathered topk_ids and topk_weights from other EP ranks.
@ -786,12 +784,15 @@ class FusedMoEModularKernel(torch.nn.Module):
a2_scale=a2_scale,
expert_tokens_meta=expert_tokens_meta,
apply_router_weight_on_input=apply_router_weight_on_input,
extra_expert_args=extra_expert_args)
)
self.prepare_finalize.finalize(
output, fused_out, topk_weights, topk_ids,
output,
fused_out,
topk_weights,
topk_ids,
apply_router_weight_on_input,
self.fused_experts.finalize_weight_and_reduce_impl(),
extra_finalize_args)
)
return output

View File

@ -1,6 +1,6 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Any, Optional
from typing import Optional, Union
import pplx_kernels as pplx
import torch
@ -21,7 +21,7 @@ def pplx_hidden_dim_scale_bytes(
max_num_tokens: int,
hidden_dim: int,
in_dtype: torch.dtype,
quant_dtype: Optional[torch.dtype],
quant_dtype: Union[torch.dtype, str, None],
per_act_token_quant: bool,
block_shape: Optional[list[int]],
):
@ -32,6 +32,7 @@ def pplx_hidden_dim_scale_bytes(
# ceil_div(hidden_dim, block_size) * sizeof(float32)
# For per-token: set to 4 * sizeof(float32) (x4 for alignment)
if quant_dtype is not None:
assert isinstance(quant_dtype, torch.dtype)
assert quant_dtype.itemsize == 1
hidden_dim_bytes = hidden_dim * quant_dtype.itemsize
elem_size = torch.float32.itemsize
@ -89,12 +90,16 @@ class PplxPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
return self.num_dispatchers_
def prepare(
self, a1: torch.Tensor, a1_scale: Optional[torch.Tensor],
a2_scale: Optional[torch.Tensor], topk_weights: torch.Tensor,
topk_ids: torch.Tensor, num_experts: int,
expert_map: Optional[torch.Tensor], apply_router_weight_on_input: bool,
self,
a1: torch.Tensor,
a1_scale: Optional[torch.Tensor],
a2_scale: Optional[torch.Tensor],
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
num_experts: int,
expert_map: Optional[torch.Tensor],
apply_router_weight_on_input: bool,
quant_config: FusedMoEQuantConfig,
extra_prepare_args: Optional[dict[str, Any]]
) -> tuple[torch.Tensor, Optional[torch.Tensor],
Optional[mk.ExpertTokensMetadata], Optional[torch.Tensor],
Optional[torch.Tensor]]:
@ -213,11 +218,15 @@ class PplxPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
return expert_x, expert_x_scale, expert_tokens_meta, None, None
def finalize(self, output: torch.Tensor, fused_expert_output: torch.Tensor,
topk_weights: torch.Tensor, topk_ids: torch.Tensor,
def finalize(
self,
output: torch.Tensor,
fused_expert_output: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
apply_router_weight_on_input: bool,
weight_and_reduce_impl: mk.TopKWeightAndReduce,
extra_finalize_args: Optional[dict[str, Any]]) -> None:
) -> None:
assert isinstance(
weight_and_reduce_impl, TopKWeightAndReduceDelegate
), ("Weight application and reduction happens in the combine kernel.")

View File

@ -1,6 +1,6 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Any, Optional
from typing import Optional
import torch
@ -38,7 +38,6 @@ class MoEPrepareAndFinalizeNoEP(mk.FusedMoEPrepareAndFinalize):
expert_map: Optional[torch.Tensor],
apply_router_weight_on_input: bool,
quant_config: FusedMoEQuantConfig,
extra_prepare_args: Optional[dict[str, Any]],
) -> tuple[torch.Tensor, Optional[torch.Tensor],
Optional[mk.ExpertTokensMetadata], Optional[torch.Tensor],
Optional[torch.Tensor]]:
@ -50,27 +49,21 @@ class MoEPrepareAndFinalizeNoEP(mk.FusedMoEPrepareAndFinalize):
"apply_router_weight_on_input is only implemented for topk=1"
a1.mul_(topk_weights.to(a1.dtype))
if (extra_prepare_args is not None
and extra_prepare_args.get("skip_quant", True)):
# Skip quantization if explicitly requested
return a1, None, None, None, None
a1q, a1q_scale = moe_kernel_quantize_input(
a1, a1_scale, quant_config.quant_dtype,
quant_config.per_act_token_quant, quant_config.block_shape)
return a1q, a1q_scale, None, None, None
def finalize(self, output: torch.Tensor, fused_expert_output: torch.Tensor,
topk_weights: torch.Tensor, topk_ids: torch.Tensor,
def finalize(
self,
output: torch.Tensor,
fused_expert_output: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
apply_router_weight_on_input: bool,
weight_and_reduce_impl: mk.TopKWeightAndReduce,
extra_finalize_args: Optional[dict[str, Any]]) -> None:
if (extra_finalize_args is not None
and extra_finalize_args.get("skip_weight_reduce", True)):
assert output.shape == fused_expert_output.shape
output.copy_(fused_expert_output)
else:
) -> None:
if isinstance(weight_and_reduce_impl, TopKWeightAndReduceDelegate):
weight_and_reduce_impl = TopKWeightAndReduceContiguous()
weight_and_reduce_impl.apply(

View File

@ -1,6 +1,6 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Any, Optional
from typing import Optional
import torch
@ -119,18 +119,28 @@ class TritonOrDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
local_num_experts,
expert_tokens_meta)
def apply(self, output: torch.Tensor, hidden_states: torch.Tensor,
w1: torch.Tensor, w2: torch.Tensor, topk_weights: torch.Tensor,
topk_ids: torch.Tensor, activation: str, global_num_experts: int,
def apply(
self,
output: torch.Tensor,
hidden_states: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
activation: str,
global_num_experts: int,
expert_map: Optional[torch.Tensor],
w1_scale: Optional[torch.Tensor],
w2_scale: Optional[torch.Tensor], w1_zp: Optional[torch.Tensor],
w2_zp: Optional[torch.Tensor], a1q_scale: Optional[torch.Tensor],
a2_scale: Optional[torch.Tensor], workspace13: torch.Tensor,
w2_scale: Optional[torch.Tensor],
w1_zp: Optional[torch.Tensor],
w2_zp: Optional[torch.Tensor],
a1q_scale: Optional[torch.Tensor],
a2_scale: Optional[torch.Tensor],
workspace13: torch.Tensor,
workspace2: torch.Tensor,
expert_tokens_meta: Optional[mk.ExpertTokensMetadata],
apply_router_weight_on_input: bool,
extra_expert_args: Optional[dict[str, Any]]):
):
use_deep_gemm = (self.allow_deep_gemm
and (_valid_deep_gemm(hidden_states, w1, w2)
or is_blackwell_deep_gemm_e8m0_used()))
@ -158,5 +168,4 @@ class TritonOrDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
workspace2,
expert_tokens_meta,
apply_router_weight_on_input,
extra_expert_args,
)

View File

@ -1,7 +1,7 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from math import prod
from typing import Any, Optional, Union
from typing import Optional, Union
import torch
@ -189,7 +189,7 @@ def moe_kernel_quantize_input(
return _fp8_quantize(A, A_scale, per_act_token_quant, block_shape)
elif quant_dtype == torch.int8:
return _int8_quantize(A, A_scale, per_act_token_quant, block_shape)
elif quant_dtype == torch.uint8: # nvfp4
elif quant_dtype == "nvfp4":
return _fp4_quantize(A,
A_scale,
is_sf_swizzled_layout=is_fp4_scale_swizzled)
@ -252,17 +252,3 @@ def _validate_scale_shape(
assert block_shape is not None
expected = (a.shape[0], cdiv(a.shape[1], block_shape[1]))
assert a_scale.shape == expected, f"{a_scale.shape} == {expected}"
def extract_required_args(
extra_args: Optional[dict[str, Any]],
required_keys: list[str],
) -> tuple[Any, ...]:
if extra_args is None:
raise ValueError("`extra_args` must be provided.")
missing_keys = [k for k in required_keys if k not in extra_args]
if missing_keys:
raise ValueError(f"Missing keys in `extra_args`: {missing_keys}")
return tuple(extra_args[k] for k in required_keys)

View File

@ -241,7 +241,7 @@ class AutoRoundConfig(QuantizationConfig):
if isinstance(layer, FusedMoE):
if use_marlin:
return AWQMoEMethod(quant_args_marlin)
return AWQMoEMethod(quant_args_marlin, layer.moe)
from vllm.model_executor.layers.quantization.moe_wna16 import (
MoeWNA16Config)
@ -339,7 +339,7 @@ class AutoRoundConfig(QuantizationConfig):
}
return MoeWNA16Config.from_config(config).get_quant_method(
layer, prefix)
return GPTQMarlinMoEMethod(quant_args_marlin)
return GPTQMarlinMoEMethod(quant_args_marlin, layer.moe)
if isinstance(layer, (LinearBase, ParallelLMHead)):
if use_marlin:

View File

@ -113,7 +113,7 @@ class AWQConfig(QuantizationConfig):
}
awq_marlin_config = AWQMarlinConfig.from_config(
marlin_compatible_config_dict)
return AWQMoEMethod(awq_marlin_config)
return AWQMoEMethod(awq_marlin_config, layer.moe_config)
return None

View File

@ -10,7 +10,7 @@ import vllm.model_executor.layers.fused_moe # noqa
from vllm import _custom_ops as ops
from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe.layer import (
FusedMoE, FusedMoEMethodBase, FusedMoeWeightScaleSupported,
FusedMoE, FusedMoEConfig, FusedMoEMethodBase, FusedMoeWeightScaleSupported,
UnquantizedFusedMoEMethod)
from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase,
UnquantizedLinearMethod,
@ -151,7 +151,7 @@ class AWQMarlinConfig(QuantizationConfig):
"Falling back to Moe WNA16 kernels.")
return MoeWNA16Config.from_config(
self.full_config).get_quant_method(layer, prefix)
return AWQMoEMethod(self)
return AWQMoEMethod(self, layer.moe_config)
return None
@classmethod
@ -328,7 +328,12 @@ class AWQMarlinLinearMethod(LinearMethodBase):
class AWQMoEMethod(FusedMoEMethodBase):
def __init__(self, quant_config: AWQMarlinConfig):
def __init__(
self,
quant_config: AWQMarlinConfig,
moe: FusedMoEConfig,
):
super().__init__(moe)
self.quant_config = quant_config
if self.quant_config.weight_bits != 4:
raise ValueError("AWQMoEMethod only supports 4bit now.")
@ -500,6 +505,8 @@ class AWQMoEMethod(FusedMoEMethodBase):
logical_to_physical_map: Optional[torch.Tensor] = None,
logical_replica_count: Optional[torch.Tensor] = None,
) -> torch.Tensor:
assert self.fused_experts is None
if enable_eplb:
raise NotImplementedError(
"EPLB not supported for `AWQMoEMethod` yet.")
@ -516,7 +523,8 @@ class AWQMoEMethod(FusedMoEMethodBase):
num_expert_group=num_expert_group,
custom_routing_function=custom_routing_function,
scoring_func=scoring_func,
e_score_correction_bias=e_score_correction_bias)
e_score_correction_bias=e_score_correction_bias,
indices_type=self.topk_indices_dtype)
return torch.ops.vllm.fused_marlin_moe(
x,

View File

@ -7,6 +7,7 @@ import torch
from packaging import version
from vllm.model_executor.layers.fused_moe.layer import (FusedMoE,
FusedMoEConfig,
FusedMoEMethodBase)
from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase,
UnquantizedLinearMethod,
@ -132,7 +133,7 @@ class BitsAndBytesConfig(QuantizationConfig):
return UnquantizedLinearMethod()
return BitsAndBytesLinearMethod(self)
elif isinstance(layer, FusedMoE):
return BitsAndBytesMoEMethod(self)
return BitsAndBytesMoEMethod(self, layer.moe_config)
return None
@ -411,7 +412,12 @@ class BitsAndBytesMoEMethod(FusedMoEMethodBase):
quant_config: The BitsAndBytes quantization config.
"""
def __init__(self, quant_config: BitsAndBytesConfig):
def __init__(
self,
quant_config: BitsAndBytesConfig,
moe: FusedMoEConfig,
):
super().__init__(moe)
try:
import bitsandbytes
if version.parse(
@ -422,7 +428,6 @@ class BitsAndBytesMoEMethod(FusedMoEMethodBase):
raise ImportError("Please install bitsandbytes>=0.46.1 via "
"`pip install bitsandbytes>=0.46.1` to use "
"bitsandbytes quantizer.") from err
self.topk_indices_dtype = None
self.quant_config = quant_config
def create_weights(
@ -470,6 +475,7 @@ class BitsAndBytesMoEMethod(FusedMoEMethodBase):
logical_replica_count: Optional[torch.Tensor] = None,
) -> torch.Tensor:
from vllm.model_executor.layers.fused_moe import fused_experts
assert self.fused_experts is None
if enable_eplb:
raise NotImplementedError(

View File

@ -11,20 +11,21 @@ from compressed_tensors.quantization import (ActivationOrdering,
QuantizationStrategy)
import vllm.envs as envs
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from vllm import _custom_ops as ops
from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe import (
FusedMoE, FusedMoEActivationFormat, FusedMoEConfig, FusedMoEMethodBase,
FusedMoEPermuteExpertsUnpermute, FusedMoEPrepareAndFinalize,
FusedMoeWeightScaleSupported)
from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_prepare_finalize import ( # noqa
FlashInferCutlassMoEPrepareAndFinalize)
from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe import (
is_valid_flashinfer_cutlass_fused_moe)
from vllm.model_executor.layers.quantization.compressed_tensors.schemes.compressed_tensors_wNa16 import ( # noqa
WNA16_SUPPORTED_BITS, WNA16_SUPPORTED_TYPES_MAP)
from vllm.model_executor.layers.quantization.utils import replace_parameter
from vllm.model_executor.layers.quantization.utils.flashinfer_fp4_moe import (
build_flashinfer_fp4_cutlass_moe_kernel,
flashinfer_fp4_cutlass_moe_forward, reorder_w1w3_to_w3w1)
build_flashinfer_fp4_cutlass_moe_prepare_finalize, reorder_w1w3_to_w3w1,
select_nvfp4_gemm_impl)
from vllm.model_executor.layers.quantization.utils.marlin_utils import (
check_moe_marlin_supports_layer, marlin_make_workspace_new,
marlin_moe_permute_scales)
@ -58,6 +59,9 @@ __all__ = [
class CompressedTensorsMoEMethod(FusedMoEMethodBase):
def __init_(self, moe: FusedMoEConfig):
super().__init__(moe)
@staticmethod
def get_moe_method(
quant_config: "CompressedTensorsConfig", # type: ignore # noqa E501
@ -81,18 +85,22 @@ class CompressedTensorsMoEMethod(FusedMoEMethodBase):
"WNA16MoE is not supported with actorder=group/dynamic."
)
logger.info_once("Using CompressedTensorsWNA16MoEMethod")
return CompressedTensorsWNA16MoEMethod(quant_config)
return CompressedTensorsWNA16MoEMethod(quant_config,
layer.moe_config)
else:
logger.info_once("Using CompressedTensorsWNA16MarlinMoEMethod")
return CompressedTensorsWNA16MarlinMoEMethod(quant_config)
return CompressedTensorsWNA16MarlinMoEMethod(
quant_config, layer.moe_config)
elif quant_config._is_fp4a4_nvfp4(weight_quant, input_quant):
return CompressedTensorsW4A4MoeMethod()
return CompressedTensorsW4A4MoeMethod(layer.moe_config, layer)
elif (quant_config._is_fp8_w8a8_sm90(weight_quant, input_quant)
or quant_config._is_fp8_w8a8_sm100(weight_quant, input_quant)
or quant_config._is_fp8_w8a8(weight_quant, input_quant)):
return CompressedTensorsW8A8Fp8MoEMethod(quant_config)
return CompressedTensorsW8A8Fp8MoEMethod(quant_config,
layer.moe_config)
elif quant_config._is_dynamic_token_w8a8(weight_quant, input_quant):
return CompressedTensorsW8A8Int8MoEMethod(quant_config)
return CompressedTensorsW8A8Int8MoEMethod(quant_config,
layer.moe_config)
else:
raise RuntimeError(
f"Unsupported FusedMoe scheme: {weight_quant}, {input_quant}")
@ -100,15 +108,16 @@ class CompressedTensorsMoEMethod(FusedMoEMethodBase):
class CompressedTensorsW4A4MoeMethod(CompressedTensorsMoEMethod):
def __init__(self):
def __init__(self, moe: FusedMoEConfig, layer: torch.nn.Module):
from vllm.model_executor.layers.quantization.utils.nvfp4_moe_support import ( # noqa: E501
detect_nvfp4_moe_support)
super().__init__(moe)
_nvfp4 = detect_nvfp4_moe_support(self.__class__.__name__)
self.cutlass_nvfp4_supported = _nvfp4.cutlass_supported
self.allow_flashinfer = _nvfp4.allow_flashinfer
self.use_marlin = _nvfp4.use_marlin
self.group_size = 16
self.fused_experts = None # type: ignore[assignment]
self.layer = layer
def create_weights(self, layer: torch.nn.Module, num_experts: int,
hidden_size: int, intermediate_size_per_partition: int,
@ -265,19 +274,36 @@ class CompressedTensorsW4A4MoeMethod(CompressedTensorsMoEMethod):
layer.w2_input_scale_quant = torch.nn.Parameter(
(layer.w2_input_global_scale), requires_grad=False)
def maybe_swap_experts_impl(self, moe_parallel_config):
def maybe_make_prepare_finalize(
self,
moe: FusedMoEConfig,
) -> Optional[mk.FusedMoEPrepareAndFinalize]:
if not self.allow_flashinfer:
return
self.fused_experts = build_flashinfer_fp4_cutlass_moe_kernel(
moe_parallel_config)
return super().maybe_make_prepare_finalize(moe)
def select_gemm_impl(self, prepare_finalize, moe):
prepare_finalize = build_flashinfer_fp4_cutlass_moe_prepare_finalize(
moe,
a1_gscale=self.layer.w13_input_scale_quant,
)
logger.debug_once("%s", prepare_finalize.__class__.__name__)
return prepare_finalize
def select_gemm_impl(
self,
prepare_finalize: mk.FusedMoEPrepareAndFinalize,
moe: FusedMoEConfig,
) -> mk.FusedMoEPermuteExpertsUnpermute:
"""Return the appropriate GEMM experts implementation."""
assert moe is not None and prepare_finalize is not None
from vllm.model_executor.layers.quantization.utils.flashinfer_fp4_moe import ( # noqa: E501
select_nvfp4_gemm_impl)
return select_nvfp4_gemm_impl(self.allow_flashinfer, moe, logger)
experts = select_nvfp4_gemm_impl(
moe,
g1_alphas=self.layer.g1_alphas,
g2_alphas=self.layer.g2_alphas,
a1_gscale=self.layer.w13_input_scale_quant,
a2_gscale=self.layer.w2_input_scale_quant,
allow_flashinfer=self.allow_flashinfer,
)
logger.debug_once("Using %s", experts.__class__.__name__)
return experts
def apply(
self,
@ -301,6 +327,8 @@ class CompressedTensorsW4A4MoeMethod(CompressedTensorsMoEMethod):
logical_to_physical_map: Optional[torch.Tensor] = None,
logical_replica_count: Optional[torch.Tensor] = None,
) -> torch.Tensor:
assert self.fused_experts is None
if enable_eplb:
raise NotImplementedError("EPLB not supported for "
"`CompressedTensorsW4A4MoeMethod` yet.")
@ -317,6 +345,7 @@ class CompressedTensorsW4A4MoeMethod(CompressedTensorsMoEMethod):
custom_routing_function=custom_routing_function,
scoring_func=scoring_func,
e_score_correction_bias=e_score_correction_bias,
indices_type=self.topk_indices_dtype,
)
if self.use_marlin:
@ -340,15 +369,22 @@ class CompressedTensorsW4A4MoeMethod(CompressedTensorsMoEMethod):
# FlashInfer fused experts path
if self.fused_experts is not None:
return flashinfer_fp4_cutlass_moe_forward(
self.fused_experts,
layer,
x,
topk_weights,
topk_ids,
assert is_valid_flashinfer_cutlass_fused_moe(
x, layer.w13_weight, layer.w2_weight), (
"Flashinfer CUTLASS Fused MoE not applicable!")
return self.fused_experts(
hidden_states=x,
w1=layer.w13_weight,
w2=layer.w2_weight,
topk_weights=topk_weights,
topk_ids=topk_ids,
inplace=False, # TODO(shuw): fix later, now output is high prec
activation=activation,
global_num_experts=global_num_experts,
expert_map=expert_map,
w1_scale=layer.w13_blockscale_swizzled,
w2_scale=layer.w2_blockscale_swizzled,
apply_router_weight_on_input=apply_router_weight_on_input,
)
@ -376,7 +412,6 @@ class CompressedTensorsW4A4MoeMethod(CompressedTensorsMoEMethod):
n=layer.w2_weight.shape[2] * 2,
k=x.shape[1],
e=layer.w13_weight.shape[0],
device=x.device,
apply_router_weight_on_input=apply_router_weight_on_input).to(
x.dtype)
@ -385,14 +420,15 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
def __init__(
self,
quant_config: "CompressedTensorsConfig" # type: ignore # noqa E501
quant_config: "CompressedTensorsConfig", # type: ignore # noqa E501
moe: FusedMoEConfig,
):
super().__init__(moe)
self.quant_config = quant_config
self.weight_quant = self.quant_config.target_scheme_map["Linear"].get(
"weights")
self.input_quant = self.quant_config.target_scheme_map["Linear"].get(
"input_activations")
self.topk_indices_dtype = None
per_tensor = (self.weight_quant.strategy == QuantizationStrategy.TENSOR
and self.input_quant.strategy
@ -429,7 +465,6 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
self.weight_quant, self.input_quant)
self.use_cutlass = (quant_config._is_fp8_w8a8_sm90(
self.weight_quant, self.input_quant) or self.is_fp8_w8a8_sm100)
self.fused_experts = None # type: ignore[assignment]
self.disable_expert_map = False
def create_weights(self, layer: torch.nn.Module, num_experts: int,
@ -614,24 +649,30 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
) -> FusedMoEPermuteExpertsUnpermute:
# cutlass path
if self.use_cutlass:
from vllm.model_executor.layers.fused_moe import CutlassExpertsFp8
from vllm.model_executor.layers.fused_moe import (
CutlassBatchedExpertsFp8, CutlassExpertsFp8)
use_batched_format = (prepare_finalize.activation_format ==
FusedMoEActivationFormat.BatchedExperts)
experts: FusedMoEPermuteExpertsUnpermute
num_dispatchers = prepare_finalize.num_dispatchers()
num_experts = (moe.num_local_experts
if use_batched_format else moe.num_experts)
logger.debug("CutlassExpertsFp8(%s)", self.__class__.__name__)
experts = CutlassExpertsFp8(
num_experts,
if (prepare_finalize.activation_format ==
FusedMoEActivationFormat.BatchedExperts):
logger.debug("CutlassBatchedExpertsFp8(%s)",
self.__class__.__name__)
experts = CutlassBatchedExpertsFp8(
moe.num_local_experts,
num_dispatchers,
moe.in_dtype,
self.input_quant.strategy == QuantizationStrategy.TOKEN,
self.weight_quant.strategy == QuantizationStrategy.CHANNEL,
)
else:
logger.debug("CutlassExpertsFp8(%s)", self.__class__.__name__)
experts = CutlassExpertsFp8(
moe.in_dtype,
self.input_quant.strategy == QuantizationStrategy.TOKEN,
self.weight_quant.strategy == QuantizationStrategy.CHANNEL,
num_dispatchers=num_dispatchers,
use_batched_format=use_batched_format,
)
self.disable_expert_map = (num_dispatchers > 1
@ -835,8 +876,10 @@ class CompressedTensorsW8A8Int8MoEMethod(CompressedTensorsMoEMethod):
def __init__(
self,
quant_config: "CompressedTensorsConfig" # type: ignore # noqa E501
quant_config: "CompressedTensorsConfig", # type: ignore # noqa E501
moe: FusedMoEConfig,
):
super().__init__(moe)
self.quant_config = quant_config
self.weight_quant = self.quant_config.target_scheme_map["Linear"].get(
"weights")
@ -934,6 +977,8 @@ class CompressedTensorsW8A8Int8MoEMethod(CompressedTensorsMoEMethod):
logical_to_physical_map: Optional[torch.Tensor] = None,
logical_replica_count: Optional[torch.Tensor] = None,
) -> torch.Tensor:
assert self.fused_experts is None
if enable_eplb:
raise NotImplementedError(
"EPLB not supported for "
@ -951,7 +996,8 @@ class CompressedTensorsW8A8Int8MoEMethod(CompressedTensorsMoEMethod):
num_expert_group=num_expert_group,
custom_routing_function=custom_routing_function,
scoring_func=scoring_func,
e_score_correction_bias=e_score_correction_bias)
e_score_correction_bias=e_score_correction_bias,
indices_type=self.topk_indices_dtype)
return fused_experts(
hidden_states=x,
@ -976,8 +1022,10 @@ class CompressedTensorsWNA16MarlinMoEMethod(CompressedTensorsMoEMethod):
def __init__(
self,
quant_config: "CompressedTensorsConfig" # type: ignore # noqa E501
quant_config: "CompressedTensorsConfig", # type: ignore # noqa E501
moe: FusedMoEConfig,
):
super().__init__(moe)
self.quant_config = quant_config
# TODO: @dsikka: refactor this to use schemes as other kernels
# are supported + check if the layer is being ignored.
@ -1233,6 +1281,8 @@ class CompressedTensorsWNA16MarlinMoEMethod(CompressedTensorsMoEMethod):
logical_to_physical_map: Optional[torch.Tensor] = None,
logical_replica_count: Optional[torch.Tensor] = None,
) -> torch.Tensor:
assert self.fused_experts is None
if enable_eplb:
raise NotImplementedError(
"EPLB not supported for "
@ -1251,7 +1301,8 @@ class CompressedTensorsWNA16MarlinMoEMethod(CompressedTensorsMoEMethod):
num_expert_group=num_expert_group,
custom_routing_function=custom_routing_function,
scoring_func=scoring_func,
e_score_correction_bias=e_score_correction_bias)
e_score_correction_bias=e_score_correction_bias,
indices_type=self.topk_indices_dtype)
return torch.ops.vllm.fused_marlin_moe(
x,
@ -1280,8 +1331,10 @@ class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod):
def __init__(
self,
quant_config: "CompressedTensorsConfig" # type: ignore # noqa E501
quant_config: "CompressedTensorsConfig", # type: ignore # noqa E501
moe: FusedMoEConfig,
):
super().__init__(moe)
self.quant_config = quant_config
# TODO: @dsikka: refactor this to use schemes as other kernels
# are supported + check if the layer is being ignored.
@ -1459,6 +1512,8 @@ class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod):
logical_to_physical_map: Optional[torch.Tensor] = None,
logical_replica_count: Optional[torch.Tensor] = None,
) -> torch.Tensor:
assert self.fused_experts is None
if enable_eplb:
raise NotImplementedError("EPLB not supported for "
"`CompressedTensorsWNA16MoEMethod` yet.")
@ -1475,7 +1530,8 @@ class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod):
num_expert_group=num_expert_group,
custom_routing_function=custom_routing_function,
scoring_func=scoring_func,
e_score_correction_bias=e_score_correction_bias)
e_score_correction_bias=e_score_correction_bias,
indices_type=self.topk_indices_dtype)
return fused_experts(
x,

View File

@ -6,7 +6,8 @@ from typing import Any, Callable, Optional
import torch
from vllm.distributed import get_tensor_model_parallel_rank, get_tp_group
from vllm.model_executor.layers.fused_moe import FusedMoE, FusedMoEMethodBase
from vllm.model_executor.layers.fused_moe import (FusedMoE, FusedMoEConfig,
FusedMoEMethodBase)
from vllm.model_executor.layers.linear import (LinearBase,
UnquantizedLinearMethod)
from vllm.model_executor.layers.quantization import QuantizationMethods
@ -46,13 +47,18 @@ class ExpertsInt8Config(QuantizationConfig):
if isinstance(layer, LinearBase):
return UnquantizedLinearMethod()
elif isinstance(layer, FusedMoE):
return ExpertsInt8MoEMethod(self)
return ExpertsInt8MoEMethod(self, layer.moe_config)
return None
class ExpertsInt8MoEMethod(FusedMoEMethodBase):
def __init__(self, quant_config: ExpertsInt8Config):
def __init__(
self,
quant_config: ExpertsInt8Config,
moe: FusedMoEConfig,
):
super().__init__(moe)
self.quant_config = quant_config
def create_weights(self, layer: torch.nn.Module, num_experts: int,
@ -122,6 +128,8 @@ class ExpertsInt8MoEMethod(FusedMoEMethodBase):
logical_to_physical_map: Optional[torch.Tensor] = None,
logical_replica_count: Optional[torch.Tensor] = None,
) -> torch.Tensor:
assert self.fused_experts is None
if enable_eplb:
raise NotImplementedError(
"EPLB not supported for `ExpertsInt8MoEMethod` yet.")
@ -138,7 +146,8 @@ class ExpertsInt8MoEMethod(FusedMoEMethodBase):
num_expert_group=num_expert_group,
custom_routing_function=custom_routing_function,
scoring_func=scoring_func,
e_score_correction_bias=e_score_correction_bias)
e_score_correction_bias=e_score_correction_bias,
indices_type=self.topk_indices_dtype)
return fused_experts(
x,

View File

@ -1,7 +1,6 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import functools
from typing import TYPE_CHECKING, Any, Callable, Optional
import torch
@ -142,7 +141,7 @@ class Fp8Config(QuantizationConfig):
return UnquantizedLinearMethod()
return Fp8LinearMethod(self)
elif isinstance(layer, FusedMoE):
return Fp8MoEMethod(self)
return Fp8MoEMethod(self, layer.moe_config)
elif isinstance(layer, Attention):
return Fp8KVCacheMethod(self)
return None
@ -479,9 +478,8 @@ class Fp8MoEMethod(FusedMoEMethodBase):
quant_config: The quantization config.
"""
def __init__(self, quant_config: Fp8Config):
from vllm.model_executor.layers.fused_moe import fused_experts
def __init__(self, quant_config: Fp8Config, moe: FusedMoEConfig):
super().__init__(moe)
self.quant_config = quant_config
self.block_quant = self.quant_config.weight_block_size is not None
@ -529,15 +527,6 @@ class Fp8MoEMethod(FusedMoEMethodBase):
"CutlassBlockScaledGroupedGemm not supported on the current "
"platform.")
self.topk_indices_dtype = None
self.fused_experts = functools.partial( # type: ignore
fused_experts,
use_fp8_w8a8=True,
block_shape=self.quant_config.weight_block_size,
allow_deep_gemm=self.allow_deep_gemm,
allow_cutlass_block_scaled_grouped_gemm=(
self.allow_cutlass_block_scaled_grouped_gemm))
def create_weights(self, layer: Module, num_experts: int, hidden_size: int,
intermediate_size_per_partition: int,
params_dtype: torch.dtype, **extra_weight_attrs):
@ -1033,7 +1022,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
num_expert_group=num_expert_group,
topk_group=topk_group,
apply_router_weight_on_input=apply_router_weight_on_input)
else:
elif self.fused_experts is not None:
return self.fused_experts(
hidden_states=x,
w1=layer.w13_weight,
@ -1052,6 +1041,30 @@ class Fp8MoEMethod(FusedMoEMethodBase):
a1_scale=layer.w13_input_scale,
a2_scale=layer.w2_input_scale,
)
else:
from vllm.model_executor.layers.fused_moe import fused_experts
return fused_experts(
hidden_states=x,
w1=layer.w13_weight,
w2=layer.w2_weight,
topk_weights=topk_weights,
topk_ids=topk_ids,
inplace=True,
activation=activation,
global_num_experts=global_num_experts,
apply_router_weight_on_input=apply_router_weight_on_input,
expert_map=expert_map,
w1_scale=(layer.w13_weight_scale_inv
if self.block_quant else layer.w13_weight_scale),
w2_scale=(layer.w2_weight_scale_inv
if self.block_quant else layer.w2_weight_scale),
a1_scale=layer.w13_input_scale,
a2_scale=layer.w2_input_scale,
use_fp8_w8a8=True,
block_shape=self.quant_config.weight_block_size,
allow_deep_gemm=self.allow_deep_gemm,
allow_cutlass_block_scaled_grouped_gemm=(
self.allow_cutlass_block_scaled_grouped_gemm))
class Fp8KVCacheMethod(BaseKVCacheMethod):

View File

@ -11,6 +11,7 @@ from torch.nn.parameter import Parameter, UninitializedParameter
from vllm import _custom_ops as ops
from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe.layer import (FusedMoE,
FusedMoEConfig,
FusedMoEMethodBase)
from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase
from vllm.model_executor.layers.quantization import QuantizationMethods
@ -58,7 +59,7 @@ class GGUFConfig(QuantizationConfig):
elif isinstance(layer, VocabParallelEmbedding):
return GGUFEmbeddingMethod(self)
elif isinstance(layer, FusedMoE):
return GGUFMoEMethod(self)
return GGUFMoEMethod(self, layer.moe_config)
return None
@ -445,7 +446,12 @@ class GGUFMoEMethod(FusedMoEMethodBase):
quant_config: The GGUF quantization config.
"""
def __init__(self, quant_config: GGUFConfig):
def __init__(
self,
quant_config: GGUFConfig,
moe: FusedMoEConfig,
):
super().__init__(moe)
self.quant_config = quant_config
def create_weights(self, layer: torch.nn.Module, num_experts: int,
@ -525,6 +531,8 @@ class GGUFMoEMethod(FusedMoEMethodBase):
logical_to_physical_map: Optional[torch.Tensor] = None,
logical_replica_count: Optional[torch.Tensor] = None,
):
assert self.fused_experts is None
if enable_eplb:
raise NotImplementedError(
"EPLB not supported for `GGUFMoEMethod` yet.")
@ -545,7 +553,8 @@ class GGUFMoEMethod(FusedMoEMethodBase):
num_expert_group=num_expert_group,
custom_routing_function=custom_routing_function,
scoring_func=scoring_func,
e_score_correction_bias=e_score_correction_bias)
e_score_correction_bias=e_score_correction_bias,
indices_type=self.topk_indices_dtype)
return fused_moe_gguf(x, layer.w13_qweight, layer.w2_qweight,
topk_weights, topk_ids,
layer.w13_qweight_type.weight_type,

View File

@ -10,7 +10,7 @@ import vllm.model_executor.layers.fused_moe # noqa
from vllm import _custom_ops as ops
from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe.layer import (
FusedMoE, FusedMoEMethodBase, FusedMoeWeightScaleSupported,
FusedMoE, FusedMoEConfig, FusedMoEMethodBase, FusedMoeWeightScaleSupported,
UnquantizedFusedMoEMethod)
from vllm.model_executor.layers.linear import (LinearMethodBase,
set_weight_attrs)
@ -375,7 +375,12 @@ class GPTQMarlinLinearMethod(LinearMethodBase):
class GPTQMarlinMoEMethod(FusedMoEMethodBase):
"""MoE Marlin method with quantization."""
def __init__(self, quant_config: GPTQMarlinConfig) -> None:
def __init__(
self,
quant_config: GPTQMarlinConfig,
moe: FusedMoEConfig,
) -> None:
super().__init__(moe)
self.quant_config = quant_config
if self.quant_config.quant_type.size_bits == 4:
self.quant_type = scalar_types.uint4b8
@ -646,6 +651,8 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase):
logical_to_physical_map: Optional[torch.Tensor] = None,
logical_replica_count: Optional[torch.Tensor] = None,
) -> torch.Tensor:
assert self.fused_experts is None
if enable_eplb:
raise NotImplementedError(
"EPLB not supported for `GPTQMarlinMoEMethod` yet.")
@ -662,7 +669,8 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase):
num_expert_group=num_expert_group,
custom_routing_function=custom_routing_function,
scoring_func=scoring_func,
e_score_correction_bias=e_score_correction_bias)
e_score_correction_bias=e_score_correction_bias,
indices_type=self.topk_indices_dtype)
return torch.ops.vllm.fused_marlin_moe(
x,

View File

@ -12,7 +12,9 @@ import vllm.envs as envs
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from vllm._custom_ops import cutlass_scaled_fp4_mm, scaled_fp4_quant
from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe.config import FusedMoEParallelConfig
from vllm.model_executor.layers.fused_moe.config import FusedMoEConfig
from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe import (
is_valid_flashinfer_cutlass_fused_moe)
from vllm.model_executor.layers.fused_moe.layer import (
FusedMoE, FusedMoEMethodBase, FusedMoeWeightScaleSupported)
from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase,
@ -22,8 +24,8 @@ from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig, QuantizeMethodBase)
from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod
from vllm.model_executor.layers.quantization.utils.flashinfer_fp4_moe import (
build_flashinfer_fp4_cutlass_moe_kernel,
flashinfer_fp4_cutlass_moe_forward, reorder_w1w3_to_w3w1)
build_flashinfer_fp4_cutlass_moe_prepare_finalize, reorder_w1w3_to_w3w1,
select_nvfp4_gemm_impl)
from vllm.model_executor.layers.quantization.utils.flashinfer_utils import (
apply_flashinfer_per_tensor_scale_fp8, register_moe_scaling_factors,
rotate_flashinfer_fp8_moe_weights, swap_w13_to_w31)
@ -177,7 +179,7 @@ class ModelOptFp8Config(QuantizationConfig):
elif isinstance(layer, Attention):
return ModelOptFp8KVCacheMethod(self)
elif isinstance(layer, FusedMoE):
return ModelOptFp8MoEMethod(self)
return ModelOptFp8MoEMethod(self, layer.moe_config)
return None
@ -273,7 +275,12 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase):
quant_config: The ModelOpt quantization config.
"""
def __init__(self, quant_config: ModelOptFp8Config) -> None:
def __init__(
self,
quant_config: ModelOptFp8Config,
moe: FusedMoEConfig,
) -> None:
super().__init__(moe)
self.quant_config = quant_config
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
cutlass_fp8_supported)
@ -454,6 +461,8 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase):
logical_to_physical_map: Optional[torch.Tensor] = None,
logical_replica_count: Optional[torch.Tensor] = None,
) -> torch.Tensor:
assert self.fused_experts is None
if enable_eplb:
raise NotImplementedError(
"EPLB not supported for `ModelOptFp8MoEMethod` yet.")
@ -484,6 +493,7 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase):
custom_routing_function=custom_routing_function,
scoring_func=scoring_func,
e_score_correction_bias=e_score_correction_bias,
indices_type=self.topk_indices_dtype,
)
from vllm.model_executor.layers.fused_moe.fused_moe import (
fused_experts)
@ -699,7 +709,7 @@ class ModelOptNvFp4Config(QuantizationConfig):
elif isinstance(layer, Attention):
return ModelOptFp8KVCacheMethod(self)
elif isinstance(layer, FusedMoE):
return ModelOptNvFp4FusedMoE(self)
return ModelOptNvFp4FusedMoE(self, layer.moe_config, layer)
return None
@ -923,10 +933,17 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
quant_config: NVFP4 Quant Config
"""
def __init__(self, quant_config: ModelOptNvFp4Config) -> None:
self.quant_config = quant_config
def __init__(
self,
quant_config: ModelOptNvFp4Config,
moe: FusedMoEConfig,
layer: torch.nn.Module,
) -> None:
from vllm.model_executor.layers.quantization.utils.nvfp4_moe_support import ( # noqa: E501
detect_nvfp4_moe_support)
super().__init__(moe)
self.quant_config = quant_config
self.layer = layer
_nvfp4 = detect_nvfp4_moe_support(self.__class__.__name__)
self.cutlass_nvfp4_supported = _nvfp4.cutlass_supported
self.allow_flashinfer = _nvfp4.allow_flashinfer
@ -952,27 +969,35 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
self.fused_experts: Optional[
mk.FusedMoEModularKernel] = None # type: ignore[assignment]
def maybe_swap_experts_impl(
def maybe_make_prepare_finalize(
self,
moe_parallel_config: FusedMoEParallelConfig,
):
moe: FusedMoEConfig,
) -> Optional[mk.FusedMoEPrepareAndFinalize]:
if not self.allow_flashinfer:
return
self.fused_experts = build_flashinfer_fp4_cutlass_moe_kernel(
moe_parallel_config)
return super().maybe_make_prepare_finalize(moe)
# This method update self.fused_experts
# only prepare_finalize is not None call select_gemm_impl
# so when native cutlass fp4, fused_expert is in fuse_moe.py fused_expert
# when it's not called(TP case), we still have 2 kernels to use.
def select_gemm_impl(self, prepare_finalize,
moe) -> mk.FusedMoEPermuteExpertsUnpermute:
prepare_finalize = build_flashinfer_fp4_cutlass_moe_prepare_finalize(
moe,
a1_gscale=self.layer.w13_input_scale_quant,
)
logger.debug_once("%s", prepare_finalize.__class__.__name__)
return prepare_finalize
assert moe is not None and prepare_finalize is not None
from vllm.model_executor.layers.quantization.utils.flashinfer_fp4_moe import ( # noqa: E501
select_nvfp4_gemm_impl)
return select_nvfp4_gemm_impl(self.allow_flashinfer, moe, logger)
def select_gemm_impl(
self,
prepare_finalize: mk.FusedMoEPrepareAndFinalize,
moe: FusedMoEConfig,
) -> mk.FusedMoEPermuteExpertsUnpermute:
experts = select_nvfp4_gemm_impl(
moe,
g1_alphas=self.layer.g1_alphas,
g2_alphas=self.layer.g2_alphas,
a1_gscale=self.layer.w13_input_scale_quant,
a2_gscale=self.layer.w2_input_scale_quant,
allow_flashinfer=self.allow_flashinfer,
)
logger.debug_once("Using %s", experts.__class__.__name__)
return experts
def uses_weight_scale_2_pattern(self) -> bool:
"""
@ -1362,7 +1387,8 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
num_expert_group=num_expert_group,
custom_routing_function=custom_routing_function,
scoring_func=scoring_func,
e_score_correction_bias=e_score_correction_bias)
e_score_correction_bias=e_score_correction_bias,
indices_type=self.topk_indices_dtype)
if self.use_marlin:
return torch.ops.vllm.fused_marlin_moe(
@ -1404,21 +1430,28 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
n=layer.w2_weight.shape[2] * 2,
k=x.shape[1],
e=layer.w13_weight.shape[0],
device=x.device,
expert_map=expert_map,
apply_router_weight_on_input=apply_router_weight_on_input)
else:
assert self.allow_flashinfer and \
self.flashinfer_moe_backend == FlashinferMoeBackend.CUTLASS
out = flashinfer_fp4_cutlass_moe_forward(
self.fused_experts,
layer,
x,
topk_weights,
topk_ids,
assert is_valid_flashinfer_cutlass_fused_moe(
x, layer.w13_weight, layer.w2_weight), (
"Flashinfer CUTLASS Fused MoE not applicable!")
out = self.fused_experts(
hidden_states=x,
w1=layer.w13_weight,
w2=layer.w2_weight,
topk_weights=topk_weights,
topk_ids=topk_ids,
inplace=False, # TODO(shuw): fix later, now output is high prec
activation=activation,
global_num_experts=global_num_experts,
expert_map=expert_map,
w1_scale=layer.w13_blockscale_swizzled,
w2_scale=layer.w2_blockscale_swizzled,
apply_router_weight_on_input=apply_router_weight_on_input,
)

View File

@ -7,7 +7,7 @@ import torch
from vllm.distributed import get_tensor_model_parallel_rank, get_tp_group
from vllm.model_executor.layers.fused_moe.layer import (
FusedMoE, FusedMoEMethodBase, FusedMoeWeightScaleSupported)
FusedMoE, FusedMoEConfig, FusedMoEMethodBase, FusedMoeWeightScaleSupported)
from vllm.model_executor.layers.linear import (LinearBase,
UnquantizedLinearMethod)
from vllm.model_executor.layers.quantization import QuantizationMethods
@ -160,7 +160,7 @@ class MoeWNA16Config(QuantizationConfig):
else:
raise ValueError("moe_wna16 only support gptq and awq.")
elif isinstance(layer, FusedMoE):
return MoeWNA16Method(self)
return MoeWNA16Method(self, layer.moe_config)
return None
@ -175,7 +175,12 @@ class MoeWNA16Method(FusedMoEMethodBase):
quant_config: The MOE WNA16 (W8A16/W4A16) quantization config.
"""
def __init__(self, quant_config: MoeWNA16Config):
def __init__(
self,
quant_config: MoeWNA16Config,
moe: FusedMoEConfig,
):
super().__init__(moe)
self.quant_config = quant_config
def create_weights(self, layer: torch.nn.Module, num_experts: int,
@ -302,6 +307,8 @@ class MoeWNA16Method(FusedMoEMethodBase):
logical_to_physical_map: Optional[torch.Tensor] = None,
logical_replica_count: Optional[torch.Tensor] = None,
) -> torch.Tensor:
assert self.fused_experts is None
if enable_eplb:
raise NotImplementedError(
"EPLB not supported for `MoeWNA16Method` yet.")
@ -318,7 +325,8 @@ class MoeWNA16Method(FusedMoEMethodBase):
num_expert_group=num_expert_group,
custom_routing_function=custom_routing_function,
scoring_func=scoring_func,
e_score_correction_bias=e_score_correction_bias)
e_score_correction_bias=e_score_correction_bias,
indices_type=self.topk_indices_dtype)
weight_bits = self.quant_config.weight_bits
has_zp = self.quant_config.has_zp

View File

@ -82,7 +82,7 @@ class Mxfp4Config(QuantizationConfig):
class Mxfp4MoEMethod(FusedMoEMethodBase):
def __init__(self, moe: FusedMoEConfig):
super().__init__()
super().__init__(moe)
self.topk_indices_dtype = None
self.moe = moe
self.use_marlin = self._should_use_marlin()

View File

@ -7,7 +7,8 @@ import torch
from vllm import _custom_ops as ops
from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe import (FusedMoE, FusedMoEMethodBase,
from vllm.model_executor.layers.fused_moe import (FusedMoE, FusedMoEConfig,
FusedMoEMethodBase,
FusedMoeWeightScaleSupported)
from vllm.model_executor.layers.quantization.utils.mxfp4_utils import (
OCP_MX_BLOCK_SIZE)
@ -25,6 +26,9 @@ __all__ = [
class QuarkMoEMethod(FusedMoEMethodBase):
def __init__(self, moe: FusedMoEConfig):
super().__init__(moe)
@staticmethod
def get_moe_method(
quant_config: "QuarkConfig", # type: ignore # noqa E501 # noqa F821
@ -42,17 +46,24 @@ class QuarkMoEMethod(FusedMoEMethodBase):
input_config = layer_quant_config.get("input_tensors")
if quant_config._is_fp8_w8a8(weight_config, input_config):
return QuarkW8A8Fp8MoEMethod(weight_config, input_config)
return QuarkW8A8Fp8MoEMethod(weight_config, input_config,
module.moe_config)
elif quant_config._is_mx_fp4(weight_config, input_config):
return QuarkW4A4MXFp4MoEMethod(weight_config, input_config)
return QuarkW4A4MXFp4MoEMethod(weight_config, input_config,
module.moe_config)
else:
raise RuntimeError("Unsupported FusedMoe scheme")
class QuarkW8A8Fp8MoEMethod(QuarkMoEMethod):
def __init__(self, weight_config: dict[str, Any], input_config: dict[str,
Any]):
def __init__(
self,
weight_config: dict[str, Any],
input_config: dict[str, Any],
moe: FusedMoEConfig,
):
super().__init__(moe)
self.weight_quant = weight_config
self.input_quant = input_config
@ -215,6 +226,8 @@ class QuarkW8A8Fp8MoEMethod(QuarkMoEMethod):
logical_to_physical_map: Optional[torch.Tensor] = None,
logical_replica_count: Optional[torch.Tensor] = None,
) -> torch.Tensor:
assert self.fused_experts is None
if enable_eplb:
raise NotImplementedError(
"EPLB not supported for `QuarkW8A8Fp8MoEMethod` yet.")
@ -231,7 +244,8 @@ class QuarkW8A8Fp8MoEMethod(QuarkMoEMethod):
num_expert_group=num_expert_group,
custom_routing_function=custom_routing_function,
scoring_func=scoring_func,
e_score_correction_bias=e_score_correction_bias)
e_score_correction_bias=e_score_correction_bias,
indices_type=self.topk_indices_dtype)
return fused_experts(
x,
@ -253,8 +267,13 @@ class QuarkW8A8Fp8MoEMethod(QuarkMoEMethod):
class QuarkW4A4MXFp4MoEMethod(QuarkMoEMethod):
def __init__(self, weight_config: dict[str, Any], input_config: dict[str,
Any]):
def __init__(
self,
weight_config: dict[str, Any],
input_config: dict[str, Any],
moe: FusedMoEConfig,
):
super().__init__(moe)
self.weight_quant = weight_config
self.input_quant = input_config
@ -369,6 +388,7 @@ class QuarkW4A4MXFp4MoEMethod(QuarkMoEMethod):
logical_to_physical_map: Optional[torch.Tensor] = None,
logical_replica_count: Optional[torch.Tensor] = None,
) -> torch.Tensor:
assert self.fused_experts is None
if enable_eplb:
raise NotImplementedError(
@ -386,7 +406,8 @@ class QuarkW4A4MXFp4MoEMethod(QuarkMoEMethod):
num_expert_group=num_expert_group,
custom_routing_function=custom_routing_function,
scoring_func=scoring_func,
e_score_correction_bias=e_score_correction_bias)
e_score_correction_bias=e_score_correction_bias,
indices_type=self.topk_indices_dtype)
out = fused_experts(
x,

View File

@ -10,7 +10,8 @@ import torch.nn.functional as F
from torch.nn.parameter import Parameter
from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe import FusedMoE, FusedMoEMethodBase
from vllm.model_executor.layers.fused_moe import (FusedMoE, FusedMoEConfig,
FusedMoEMethodBase)
from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase,
set_weight_attrs)
from vllm.model_executor.layers.quantization import QuantizationMethods
@ -76,7 +77,7 @@ class RTNConfig(QuantizationConfig):
if isinstance(layer, LinearBase):
return RTNLinearMethod(self)
elif isinstance(layer, FusedMoE):
return RTNMoEMethod(self)
return RTNMoEMethod(self, layer.moe_config)
return None
@ -210,7 +211,8 @@ class RTNLinearMethod(LinearMethodBase):
class RTNMoEMethod(FusedMoEMethodBase):
def __init__(self, quant_config: RTNConfig):
def __init__(self, quant_config: RTNConfig, moe: FusedMoEConfig):
super().__init__(moe)
self.quant_config = quant_config
def create_weights(self, layer: torch.nn.Module, num_experts: int,
@ -289,6 +291,8 @@ class RTNMoEMethod(FusedMoEMethodBase):
logical_to_physical_map: Optional[torch.Tensor] = None,
logical_replica_count: Optional[torch.Tensor] = None,
) -> torch.Tensor:
assert self.fused_experts is None
if enable_eplb:
raise NotImplementedError(
"EPLB not supported for `RTNMoEMethod` yet.")
@ -305,7 +309,8 @@ class RTNMoEMethod(FusedMoEMethodBase):
num_expert_group=num_expert_group,
custom_routing_function=custom_routing_function,
scoring_func=scoring_func,
e_score_correction_bias=e_score_correction_bias)
e_score_correction_bias=e_score_correction_bias,
indices_type=self.topk_indices_dtype)
weight_bits = self.quant_config.weight_bits
group_size = self.quant_config.group_size

View File

@ -3,33 +3,30 @@
"""Utility helpers for NVFP4 + FlashInfer fused-MoE path"""
from __future__ import annotations
from typing import Optional
import torch
import vllm.envs as envs
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe.config import FusedMoEParallelConfig
from vllm.model_executor.layers.fused_moe.config import FusedMoEConfig
from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe import (
FlashInferExperts, is_valid_flashinfer_cutlass_fused_moe)
FlashInferExperts)
from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_prepare_finalize import ( # noqa: E501
FlashInferCutlassMoEPrepareAndFinalize)
from vllm.platforms import current_platform
logger = init_logger(__name__)
from vllm.utils.flashinfer import has_flashinfer_cutlass_fused_moe
__all__ = [
"is_flashinfer_fp4_cutlass_moe_available",
"reorder_w1w3_to_w3w1",
"build_flashinfer_fp4_cutlass_moe_kernel",
"flashinfer_fp4_cutlass_moe_forward",
"build_flashinfer_fp4_cutlass_moe_prepare_finalize",
]
def is_flashinfer_fp4_cutlass_moe_available() -> bool:
"""Return ``True`` when FlashInfer CUTLASS NV-FP4 kernels can be used."""
return (envs.VLLM_USE_FLASHINFER_MOE_FP4 and current_platform.is_cuda()
return (envs.VLLM_USE_FLASHINFER_MOE_FP4
and has_flashinfer_cutlass_fused_moe()
and current_platform.is_cuda()
and current_platform.is_device_capability(100))
@ -49,105 +46,33 @@ def reorder_w1w3_to_w3w1(weight: torch.Tensor,
dim=dim).contiguous())
def build_flashinfer_fp4_cutlass_moe_kernel(
moe_parallel_config: FusedMoEParallelConfig, ) -> mk.FusedMoEModularKernel:
"""Create *and return* a FlashInfer CUTLASS fused-MoE modular kernel"""
experts = FlashInferExperts(
use_nvfp4_w4a4=True,
use_dp=moe_parallel_config.dp_size > 1,
ep_rank=moe_parallel_config.ep_rank,
ep_size=moe_parallel_config.ep_size,
tp_rank=moe_parallel_config.tp_rank,
tp_size=moe_parallel_config.tp_size,
)
logger.debug_once("FlashInferExperts (util)")
return mk.FusedMoEModularKernel(
FlashInferCutlassMoEPrepareAndFinalize(quant_dtype=torch.uint8),
experts,
)
def flashinfer_fp4_cutlass_moe_forward(
fused_experts: mk.FusedMoEModularKernel,
layer: torch.nn.Module,
x: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
activation: str,
global_num_experts: int,
expert_map: Optional[torch.Tensor],
apply_router_weight_on_input: bool,
) -> torch.Tensor:
"""Common forward wrapper for FlashInfer NV-FP4 fused-MoE"""
assert is_valid_flashinfer_cutlass_fused_moe(
x, layer.w13_weight,
layer.w2_weight), ("FlashInfer CUTLASS fused-MoE not applicable!")
a1_gscale = layer.w13_input_scale_quant
a2_gscale = layer.w2_input_scale_quant
extra_expert_args = {
"g1_alphas": layer.g1_alphas,
"g2_alphas": layer.g2_alphas,
# Avoid confusion with a1_scale and a2_scale
# where are batch size related.
"a1_gscale": a1_gscale,
"a2_gscale": a2_gscale,
"out_dtype": x.dtype,
}
extra_prepare_args = {
"use_dp": layer.dp_size > 1,
"local_tokens": x.shape[0],
"a1_gscale": a1_gscale,
}
extra_finalize_args = {
"use_dp": layer.dp_size > 1,
"local_tokens": x.shape[0],
}
return fused_experts(
hidden_states=x,
w1=layer.w13_weight,
w2=layer.w2_weight,
topk_weights=topk_weights,
topk_ids=topk_ids,
inplace=False, # TODO(shuw): fix later, now output is high prec
activation=activation,
global_num_experts=global_num_experts,
expert_map=expert_map,
w1_scale=layer.w13_blockscale_swizzled,
w2_scale=layer.w2_blockscale_swizzled,
apply_router_weight_on_input=apply_router_weight_on_input,
extra_expert_args=extra_expert_args,
extra_prepare_args=extra_prepare_args,
extra_finalize_args=extra_finalize_args,
)
def build_flashinfer_fp4_cutlass_moe_prepare_finalize(
moe: FusedMoEConfig,
a1_gscale: torch.Tensor,
) -> mk.FusedMoEPrepareAndFinalize:
"""Create a FlashInfer CUTLASS fused-MoE prepare finalize kernel"""
use_dp = moe.moe_parallel_config.dp_size > 1
return FlashInferCutlassMoEPrepareAndFinalize(use_dp, a1_gscale=a1_gscale)
def select_nvfp4_gemm_impl(
moe: FusedMoEConfig,
g1_alphas: torch.Tensor,
g2_alphas: torch.Tensor,
a1_gscale: torch.Tensor,
a2_gscale: torch.Tensor,
allow_flashinfer: bool,
moe, # FusedMoEConfig
logger):
) -> mk.FusedMoEPermuteExpertsUnpermute:
"""Return a GEMM *experts* implementation for NV-FP4 fused-MoE layers"""
# lazy import
from vllm.distributed import get_ep_group
all2all_manager = get_ep_group().device_communicator.all2all_manager
assert all2all_manager is not None
if allow_flashinfer:
flashinfer_backend = envs.VLLM_FLASHINFER_MOE_BACKEND
if flashinfer_backend != "throughput":
raise ValueError(
f"Only throughput backend is supported for FlashInferExperts, "
f"but got {flashinfer_backend}.")
logger.debug_once(
"Initializing FlashInferExperts with throughput backend.")
return FlashInferExperts(
use_nvfp4_w4a4=True,
use_dp=moe.moe_parallel_config.dp_size > 1,
g1_alphas=g1_alphas,
g2_alphas=g2_alphas,
a1_gscale=a1_gscale,
a2_gscale=a2_gscale,
out_dtype=moe.in_dtype,
quant_dtype="nvfp4",
ep_rank=moe.moe_parallel_config.ep_rank,
ep_size=moe.moe_parallel_config.ep_size,
tp_rank=moe.moe_parallel_config.tp_rank,