[Kernels] Clean up FusedMoeMethodBase and modular kernel setup. Remove extra arguments from modular kernel methods. (#22035)

Signed-off-by: Bill Nell <bnell@redhat.com>
Co-authored-by: Michael Goin <mgoin64@gmail.com>
bnellnm 2025-08-15 14:46:00 -04:00 committed by GitHub
parent 48b01fd4d4
commit 8ad7285ea2
54 changed files with 2010 additions and 1293 deletions

View File

@@ -399,6 +399,7 @@ steps:
 - label: Kernels MoE Test %N
   mirror_hardwares: [amdexperimental]
   source_file_dependencies:
+  - csrc/quantization/cutlass_w8a8/moe/
   - csrc/moe/
   - tests/kernels/moe
   - vllm/model_executor/layers/fused_moe/

View File

@@ -175,11 +175,19 @@ implementations that input `FusedMoEActivationFormat.Standard` support chunking
 ### FusedMoEModularKernel Initialization
-`FusedMoEMethodBase` class has 2 methods that are collectively responsible in creating the `FusedMoEModularKernel` object. They are,
+The `FusedMoEMethodBase` class has 3 methods that are collectively responsible for creating the `FusedMoEModularKernel` object. They are:
+* maybe_make_prepare_finalize,
 * select_gemm_impl, and
 * init_prepare_finalize
+
+#### maybe_make_prepare_finalize
+The `maybe_make_prepare_finalize` method is responsible for constructing an instance of `FusedMoEPrepareAndFinalize` when appropriate, based on the current all2all backend, e.g. when EP + DP is enabled. The base class method currently constructs all the `FusedMoEPrepareAndFinalize` objects for the EP+DP case. Derived classes can override this method to construct prepare/finalize objects for other scenarios, e.g. `ModelOptNvFp4FusedMoE` can construct a `FlashInferCutlassMoEPrepareAndFinalize` for the EP+TP case.
+
+Please refer to the implementations in:
+* `ModelOptNvFp4FusedMoE`
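Editor's note: a minimal hedged sketch of the override pattern described above, not the actual `ModelOptNvFp4FusedMoE` implementation. It assumes the method receives the layer's `FusedMoEConfig`; the class name `NvFp4MoEMethodSketch` and the `a1_gscale` attribute are hypothetical.

from typing import Optional

import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from vllm.model_executor.layers.fused_moe.config import FusedMoEConfig
from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_prepare_finalize import (  # noqa: E501
    FlashInferCutlassMoEPrepareAndFinalize)
from vllm.model_executor.layers.fused_moe.layer import FusedMoEMethodBase


class NvFp4MoEMethodSketch(FusedMoEMethodBase):
    """Hypothetical quant method showing only the override pattern."""

    def maybe_make_prepare_finalize(
            self, moe: FusedMoEConfig
    ) -> Optional[mk.FusedMoEPrepareAndFinalize]:
        # Let the base class cover the EP+DP all2all backends.
        prepare_finalize = super().maybe_make_prepare_finalize(moe)
        if prepare_finalize is not None:
            return prepare_finalize
        # Otherwise (e.g. EP+TP) build a backend-specific object, similar to
        # what ModelOptNvFp4FusedMoE does with FlashInfer.
        return FlashInferCutlassMoEPrepareAndFinalize(
            use_dp=moe.moe_parallel_config.dp_size > 1,
            a1_gscale=self.a1_gscale,  # assumed attribute (activation global scale)
        )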
 #### select_gemm_impl
 The `select_gemm_impl` method is undefined in the base class. It is the responsibility of the derived class to implement a method that constructs a valid/appropriate `FusedMoEPermuteExpertsUnpermute` object.
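Editor's note: a hedged sketch of a derived-class `select_gemm_impl`; the exact signature and the `activation_format` attribute on the prepare/finalize object are assumptions based on the surrounding docs, and `Fp8MoEMethodSketch` is a hypothetical class. Constructor arguments mirror those used by the tests in this commit.

import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from vllm.model_executor.layers.fused_moe.config import FusedMoEConfig
from vllm.model_executor.layers.fused_moe.fused_batched_moe import (
    BatchedTritonExperts)
from vllm.model_executor.layers.fused_moe.layer import (FusedMoEMethodBase,
                                                        TritonExperts)


class Fp8MoEMethodSketch(FusedMoEMethodBase):
    """Hypothetical fp8 quant method; only select_gemm_impl is sketched."""

    def select_gemm_impl(
        self,
        prepare_finalize: mk.FusedMoEPrepareAndFinalize,
        moe: FusedMoEConfig,
    ) -> mk.FusedMoEPermuteExpertsUnpermute:
        quant_kwargs = dict(
            use_fp8_w8a8=True,
            use_int8_w8a8=False,
            use_int8_w8a16=False,
            use_int4_w4a16=False,
            block_shape=moe.block_shape,
            per_act_token_quant=moe.per_act_token_quant,
        )
        # Batched prepare/finalize backends (e.g. pplx, DeepEP low-latency)
        # require a batched experts implementation.
        if (prepare_finalize.activation_format ==
                mk.FusedMoEActivationFormat.BatchedExperts):
            return BatchedTritonExperts(
                max_num_tokens=moe.max_num_tokens,
                num_dispatchers=prepare_finalize.num_dispatchers(),
                **quant_kwargs,
            )
        return TritonExperts(**quant_kwargs)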

View File

@@ -70,12 +70,27 @@ def parse_args():
         default=64,
         help=("Maximum number of sequences to be processed in a single iteration."),
     )
+    parser.add_argument(
+        "--max-model-len",
+        type=int,
+        help=("Maximum number of tokens to be processed in a single iteration."),
+    )
+    parser.add_argument(
+        "--timeout",
+        type=int,
+        default=300,
+        help=("Number of seconds before unresponsive process is killed."),
+    )
     parser.add_argument(
         "--gpu-memory-utilization",
         type=float,
         default=0.8,
         help=("Fraction of GPU memory vLLM is allowed to allocate (0.0, 1.0]."),
     )
+    parser.add_argument(
+        "--quantization",
+        type=str,
+    )
     return parser.parse_args()
@@ -90,7 +105,9 @@ def main(
     enforce_eager,
     trust_remote_code,
     max_num_seqs,
+    max_model_len,
     gpu_memory_utilization,
+    quantization,
 ):
     os.environ["VLLM_DP_RANK"] = str(global_dp_rank)
     os.environ["VLLM_DP_RANK_LOCAL"] = str(local_dp_rank)
@@ -142,7 +159,9 @@ def main(
         enable_expert_parallel=True,
         trust_remote_code=trust_remote_code,
         max_num_seqs=max_num_seqs,
+        max_model_len=max_model_len,
        gpu_memory_utilization=gpu_memory_utilization,
+        quantization=quantization,
     )
     outputs = llm.generate(prompts, sampling_params)
     # Print the outputs.
@@ -198,14 +217,16 @@ if __name__ == "__main__":
                 args.enforce_eager,
                 args.trust_remote_code,
                 args.max_num_seqs,
+                args.max_model_len,
                 args.gpu_memory_utilization,
+                args.quantization,
             ),
         )
         proc.start()
         procs.append(proc)
     exit_code = 0
     for proc in procs:
-        proc.join(timeout=300)
+        proc.join(timeout=args.timeout)
         if proc.exitcode is None:
             print(f"Killing process {proc.pid} that didn't stop within 5 minutes.")
             proc.kill()

View File

@@ -7,41 +7,22 @@ import torch

 import vllm._custom_ops as ops
 import vllm.model_executor.layers.fused_moe.modular_kernel as mk
+from tests.kernels.moe.utils import make_test_weights, per_token_cast_to_fp8
+from tests.kernels.quantization.nvfp4_utils import (FLOAT4_E2M1_MAX,
+                                                    FLOAT8_E4M3_MAX,
+                                                    dequantize_nvfp4_to_dtype)
 from tests.kernels.utils import torch_experts
 from vllm.config import VllmConfig
 from vllm.distributed import get_dp_group, get_tensor_model_parallel_world_size
-# Fused experts and PrepareFinalize imports
-from vllm.model_executor.layers.fused_moe.batched_deep_gemm_moe import (
-    BatchedDeepGemmExperts)
-from vllm.model_executor.layers.fused_moe.batched_triton_or_deep_gemm_moe import (  # noqa: E501
-    BatchedTritonOrDeepGemmExperts)
+from vllm.forward_context import set_forward_context
 from vllm.model_executor.layers.fused_moe.config import (
     FusedMoEConfig, FusedMoEParallelConfig, FusedMoEQuantConfig)
-from vllm.model_executor.layers.fused_moe.cutlass_moe import CutlassExpertsFp8
-from vllm.model_executor.layers.fused_moe.deep_gemm_moe import DeepGemmExperts
-from vllm.model_executor.layers.fused_moe.fused_batched_moe import (
-    BatchedTritonExperts, NaiveBatchedExperts)
 from vllm.model_executor.layers.fused_moe.fused_moe import fused_topk
-from vllm.model_executor.layers.fused_moe.layer import (FusedMoEMethodBase,
-                                                        TritonExperts)
-from vllm.model_executor.layers.fused_moe.prepare_finalize import (
-    MoEPrepareAndFinalizeNoEP)
-from vllm.model_executor.layers.fused_moe.triton_deep_gemm_moe import (
-    TritonOrDeepGemmExperts)
 from vllm.utils import has_deep_ep, has_deep_gemm, has_pplx

+from .mk_objects import (expert_info, make_fused_experts,
+                         make_prepare_finalize, prepare_finalize_info)
 from .parallel_utils import ProcessGroupInfo
-from .utils import (make_block_quant_fp8_weights, make_non_quant_weights,
-                    make_quant_fp8_weights, per_token_cast_to_fp8)
-
-if has_pplx():
-    from vllm.model_executor.layers.fused_moe.pplx_prepare_finalize import (
-        PplxPrepareAndFinalize)
-if has_deep_ep():
-    from vllm.model_executor.layers.fused_moe.deepep_ht_prepare_finalize import (  # noqa: E501
-        DeepEPHTPrepareAndFinalize)
-    from vllm.model_executor.layers.fused_moe.deepep_ll_prepare_finalize import (  # noqa: E501
-        DeepEPLLPrepareAndFinalize)


 def _describe_tensor(t: Optional[torch.Tensor], name: str) -> str:
def _describe_tensor(t: Optional[torch.Tensor], name: str) -> str: def _describe_tensor(t: Optional[torch.Tensor], name: str) -> str:
@ -69,24 +50,31 @@ class Config:
torch_trace_dir_path: Optional[str] = None torch_trace_dir_path: Optional[str] = None
def __post_init__(self):
if self.quant_config is None:
self.quant_config = FusedMoEQuantConfig()
def describe(self) -> str: def describe(self) -> str:
s = "" s = ""
s += "== Config: \n" s += "== Config:\n"
s += f" world_size={self.world_size} \n" s += f" world_size={self.world_size}\n"
s += f" PF={self.prepare_finalize_type.__name__} \n" s += f" PF={self.prepare_finalize_type.__name__}\n"
s += f" FE={self.fused_experts_type.__name__} \n" s += f" FE={self.fused_experts_type.__name__}\n"
s += f" topk={self.topks} \n" s += f" E={self.E}\n"
s += f" dtype={self.dtype} \n" s += f" Ms={self.Ms}\n"
s += f" fused_moe_chunk_size={self.fused_moe_chunk_size} \n" s += f" N={self.N}\n"
s += " Quant: \n" s += f" K={self.K}\n"
s += f" fused_moe_chunk_size={self.fused_moe_chunk_size} \n " s += f" topk={self.topks}\n"
s += f" dtype={self.dtype}\n"
s += f" fused_moe_chunk_size={self.fused_moe_chunk_size}\n"
s += " Quant:\n"
if self.quant_config is not None: if self.quant_config is not None:
s += f" q_dtype={self.quant_dtype} \n" s += f" q_dtype={self.quant_dtype}\n"
s += f" q_block_shape={self.quant_block_shape} \n" s += f" q_block_shape={self.quant_block_shape}\n"
s += f" q_per_out_ch_quant={self.is_per_out_ch_quant} \n" s += f" q_per_out_ch_quant={self.is_per_out_ch_quant}\n"
s += f" q_per_act_token={self.is_per_act_token_quant} \n" s += f" q_per_act_token={self.is_per_act_token_quant}\n"
else: else:
s += " quant=None \n" s += " quant=None\n"
return s return s
@property @property
@ -95,34 +83,28 @@ class Config:
return self.Ms return self.Ms
@property @property
def quant_dtype(self) -> Optional[torch.dtype]: def quant_dtype(self) -> Union[torch.dtype, str, None]:
if self.quant_config is None: assert self.quant_config is not None
return None
return self.quant_config.quant_dtype return self.quant_config.quant_dtype
@property @property
def is_per_act_token_quant(self) -> bool: def is_per_act_token_quant(self) -> bool:
if self.quant_config is None: assert self.quant_config is not None
return False
return self.quant_config.per_act_token_quant return self.quant_config.per_act_token_quant
@property @property
def is_per_tensor_act_quant(self) -> bool: def is_per_tensor_act_quant(self) -> bool:
if self.quant_config is None:
return False
return (not self.is_per_act_token_quant return (not self.is_per_act_token_quant
and self.quant_block_shape is None) and self.quant_block_shape is None)
@property @property
def is_per_out_ch_quant(self) -> bool: def is_per_out_ch_quant(self) -> bool:
if self.quant_config is None: assert self.quant_config is not None
return False
return self.quant_config.per_out_ch_quant return self.quant_config.per_out_ch_quant
@property @property
def quant_block_shape(self) -> Optional[list[int]]: def quant_block_shape(self) -> Optional[list[int]]:
if self.quant_config is None: assert self.quant_config is not None
return None
return self.quant_config.block_shape return self.quant_config.block_shape
@property @property
@ -130,36 +112,30 @@ class Config:
assert isinstance(self.topks, int) assert isinstance(self.topks, int)
return self.topks return self.topks
@property
def topk_ids_dtype(self) -> Optional[torch.dtype]:
topk_ids_dtype = None
if self.prepare_finalize_type == PplxPrepareAndFinalize:
topk_ids_dtype = torch.uint32
elif self.prepare_finalize_type in [
DeepEPHTPrepareAndFinalize, DeepEPLLPrepareAndFinalize
]:
topk_ids_dtype = torch.int64
return topk_ids_dtype
@property @property
def num_local_experts(self) -> int: def num_local_experts(self) -> int:
return self.E // self.world_size return self.E // self.world_size
def make_env_data(self) -> tuple[VllmConfig, dict[Any, Any]]: def make_env_data(self) -> tuple[VllmConfig, dict[Any, Any]]:
""" """
make env data for vllm launch. make env data for vllm launch.
""" """
vllm_config = VllmConfig() vllm_config = VllmConfig()
vllm_config.parallel_config.data_parallel_size = self.world_size vllm_config.parallel_config.data_parallel_size = self.world_size
vllm_config.parallel_config.enable_expert_parallel = True vllm_config.parallel_config.enable_expert_parallel = True
env_dict = { env_dict = {
"VLLM_ALL2ALL_BACKEND": self.all2all_backend(),
"VLLM_USE_DEEP_GEMM": str(int(self.needs_deep_gemm())), "VLLM_USE_DEEP_GEMM": str(int(self.needs_deep_gemm())),
} }
backend = self.all2all_backend()
if backend is not None:
env_dict.update({"VLLM_ALL2ALL_BACKEND": backend})
if self.fused_moe_chunk_size is not None: if self.fused_moe_chunk_size is not None:
env_dict.update( env_dict.update(
{"VLLM_FUSED_MOE_CHUNK_SIZE": str(self.fused_moe_chunk_size)}) {"VLLM_FUSED_MOE_CHUNK_SIZE": str(self.fused_moe_chunk_size)})
return vllm_config, env_dict return vllm_config, env_dict
def is_fp8_block_quantized(self): def is_fp8_block_quantized(self):
@ -167,85 +143,59 @@ class Config:
and self.quant_block_shape is not None) and self.quant_block_shape is not None)
def is_batched_prepare_finalize(self): def is_batched_prepare_finalize(self):
return self.prepare_finalize_type in [ info = prepare_finalize_info(self.prepare_finalize_type)
PplxPrepareAndFinalize, DeepEPLLPrepareAndFinalize return (mk.FusedMoEActivationFormat.BatchedExperts ==
] info.activation_format)
def is_batched_fused_experts(self): def is_batched_fused_experts(self):
return self.fused_experts_type in [ info = expert_info(self.fused_experts_type)
CutlassExpertsFp8, BatchedDeepGemmExperts, BatchedTritonExperts, return (mk.FusedMoEActivationFormat.BatchedExperts ==
NaiveBatchedExperts, BatchedTritonOrDeepGemmExperts info.activation_format)
]
def is_standard_fused_experts(self): def is_standard_fused_experts(self):
return self.fused_experts_type in [ info = expert_info(self.fused_experts_type)
CutlassExpertsFp8, DeepGemmExperts, TritonOrDeepGemmExperts, return mk.FusedMoEActivationFormat.Standard == info.activation_format
TritonExperts
]
def is_fe_16bit_supported(self): def fe_supported_types(self):
return self.fused_experts_type in [ info = expert_info(self.fused_experts_type)
BatchedTritonExperts, BatchedTritonOrDeepGemmExperts, return info.supported_dtypes
NaiveBatchedExperts, TritonExperts
]
def is_fe_fp8_supported(self): def pf_supported_types(self):
return self.fused_experts_type in [ info = prepare_finalize_info(self.prepare_finalize_type)
BatchedDeepGemmExperts, return info.supported_dtypes
BatchedTritonExperts,
BatchedTritonOrDeepGemmExperts,
CutlassExpertsFp8,
DeepGemmExperts,
TritonExperts,
TritonOrDeepGemmExperts,
NaiveBatchedExperts,
]
def is_fe_block_fp8_supported(self): def is_block_quant_supported(self):
return self.fused_experts_type in [ info = expert_info(self.fused_experts_type)
BatchedDeepGemmExperts, return info.blocked_quantization_support
BatchedTritonOrDeepGemmExperts,
DeepGemmExperts,
TritonExperts,
TritonOrDeepGemmExperts,
BatchedTritonExperts,
NaiveBatchedExperts,
]
def is_fe_supports_chunking(self): def is_fe_supports_chunking(self):
return self.fused_experts_type in [ info = expert_info(self.fused_experts_type)
CutlassExpertsFp8, DeepGemmExperts, TritonOrDeepGemmExperts, return info.supports_chunking
TritonExperts
] def supports_expert_map(self):
info = expert_info(self.fused_experts_type)
return info.supports_expert_map
def supports_apply_weight_on_input(self):
info = prepare_finalize_info(self.prepare_finalize_type)
return info.supports_apply_weight_on_input
def needs_deep_gemm(self): def needs_deep_gemm(self):
return self.fused_experts_type in [ info = expert_info(self.fused_experts_type)
BatchedDeepGemmExperts, return info.needs_deep_gemm
DeepGemmExperts,
]
def needs_pplx(self): def needs_pplx(self):
return self.prepare_finalize_type in [PplxPrepareAndFinalize] info = prepare_finalize_info(self.prepare_finalize_type)
return info.backend == "pplx"
def needs_deep_ep(self): def needs_deep_ep(self):
return self.prepare_finalize_type in [ info = prepare_finalize_info(self.prepare_finalize_type)
DeepEPHTPrepareAndFinalize, DeepEPLLPrepareAndFinalize return (info.backend == "deepep_high_throughput"
] or info.backend == "deepep_low_latency")
def all2all_backend(self): def all2all_backend(self):
if self.needs_pplx(): info = prepare_finalize_info(self.prepare_finalize_type)
return "pplx" return info.backend
if self.prepare_finalize_type == DeepEPHTPrepareAndFinalize:
return "deepep_high_throughput"
if self.prepare_finalize_type == DeepEPLLPrepareAndFinalize:
return "deepep_low_latency"
return "naive"
def needs_all2all(self):
return self.prepare_finalize_type in [
PplxPrepareAndFinalize, DeepEPHTPrepareAndFinalize,
DeepEPLLPrepareAndFinalize
]
def is_valid(self): def is_valid(self):
# Check prepare-finalize and fused-experts compatibility # Check prepare-finalize and fused-experts compatibility
@ -267,28 +217,28 @@ class Config:
# invalid quant config # invalid quant config
return False return False
# check bf16 / fp16 support # check type support
is_16bit = (self.dtype.itemsize == 2 and self.quant_dtype is None) if self.quant_dtype is None:
if is_16bit and not self.is_fe_16bit_supported(): if (self.dtype not in self.pf_supported_types()
return False or self.dtype not in self.fe_supported_types()):
return False
else:
if (self.quant_dtype not in self.pf_supported_types()
or self.quant_dtype not in self.fe_supported_types()):
return False
# Check fp8 support # Check block quanization support
is_fp8 = self.quant_dtype == torch.float8_e4m3fn
if is_fp8 and not self.is_fe_fp8_supported():
return False
# Check fp8 block quanization support
is_block_quatized = self.quant_block_shape is not None is_block_quatized = self.quant_block_shape is not None
if is_block_quatized and not is_fp8: if is_block_quatized and self.quant_dtype is None:
return False return False
if is_block_quatized and not self.is_fe_block_fp8_supported(): if is_block_quatized and not self.is_block_quant_supported():
return False return False
# deep_gemm only works with block-quantized # deep_gemm only works with block-quantized
if self.needs_deep_gemm() and not is_block_quatized: if self.needs_deep_gemm() and not is_block_quatized:
return False return False
# Check dependencies # Check dependencies (turn into asserts?)
if self.needs_deep_ep() and not has_deep_ep(): if self.needs_deep_ep() and not has_deep_ep():
return False return False
if self.needs_deep_gemm() and not has_deep_gemm(): if self.needs_deep_gemm() and not has_deep_gemm():
@ -305,6 +255,8 @@ class WeightTensors:
w2: torch.Tensor w2: torch.Tensor
w1_scale: Optional[torch.Tensor] w1_scale: Optional[torch.Tensor]
w2_scale: Optional[torch.Tensor] w2_scale: Optional[torch.Tensor]
w1_gs: Optional[torch.Tensor] = None
w2_gs: Optional[torch.Tensor] = None
def describe(self): def describe(self):
s = "" s = ""
@ -313,13 +265,20 @@ class WeightTensors:
s += f' - {_describe_tensor(self.w2, "w2")} \n' s += f' - {_describe_tensor(self.w2, "w2")} \n'
s += f' - {_describe_tensor(self.w1_scale, "w1_scale")} \n' s += f' - {_describe_tensor(self.w1_scale, "w1_scale")} \n'
s += f' - {_describe_tensor(self.w2_scale, "w2_scale")} \n' s += f' - {_describe_tensor(self.w2_scale, "w2_scale")} \n'
s += f' - {_describe_tensor(self.w1_gs, "w1_gs")} \n'
s += f' - {_describe_tensor(self.w2_gs, "w2_gs")} \n'
return s return s
def is_quantized(self) -> bool:
# or w1_scale is not None?
return (self.w1.dtype == torch.float8_e4m3fn
or self.w1.dtype == torch.uint8 or self.w1.dtype == torch.int8)
def to_current_device(self): def to_current_device(self):
self.w1 = self.w1.to(device=torch.cuda.current_device()) self.w1 = self.w1.to(device=torch.cuda.current_device())
self.w2 = self.w2.to(device=torch.cuda.current_device()) self.w2 = self.w2.to(device=torch.cuda.current_device())
is_quantized = self.w1.dtype == torch.float8_e4m3fn
if is_quantized: if self.is_quantized():
assert self.w1_scale is not None assert self.w1_scale is not None
assert self.w2_scale is not None assert self.w2_scale is not None
self.w1_scale = self.w1_scale.to( self.w1_scale = self.w1_scale.to(
@ -327,56 +286,51 @@ class WeightTensors:
self.w2_scale = self.w2_scale.to( self.w2_scale = self.w2_scale.to(
device=torch.cuda.current_device()) device=torch.cuda.current_device())
if self.w1_gs is not None:
assert self.w2_gs is not None
self.w1_gs = self.w1_gs.to(device=torch.cuda.current_device())
self.w2_gs = self.w2_gs.to(device=torch.cuda.current_device())
def slice_weights(self, rank: int, def slice_weights(self, rank: int,
num_local_experts: int) -> "WeightTensors": num_local_experts: int) -> "WeightTensors":
s = rank * num_local_experts s = rank * num_local_experts
e = s + num_local_experts e = s + num_local_experts
w1 = self.w1[s:e, :, :] w1 = self.w1[s:e, :, :]
w2 = self.w2[s:e, :, :] w2 = self.w2[s:e, :, :]
is_quantized = self.w1.dtype == torch.float8_e4m3fn
w1_scale, w2_scale = (None, None) w1_scale, w2_scale = (None, None)
if is_quantized: if self.is_quantized():
assert self.w1_scale is not None assert self.w1_scale is not None
assert self.w2_scale is not None assert self.w2_scale is not None
w1_scale = self.w1_scale[s:e, :, :] w1_scale = self.w1_scale[s:e, :, :]
w2_scale = self.w2_scale[s:e, :, :] w2_scale = self.w2_scale[s:e, :, :]
return WeightTensors(w1, w2, w1_scale, w2_scale)
w1_gs = self.w1_gs
w2_gs = self.w2_gs
if w1_gs is not None:
assert w2_gs is not None
w1_gs = w1_gs[s:e]
w2_gs = w2_gs[s:e]
return WeightTensors(w1, w2, w1_scale, w2_scale, w1_gs, w2_gs)
@staticmethod @staticmethod
def make(config: Config) -> "WeightTensors": def make(config: Config) -> "WeightTensors":
(_, w1, w1_scale, w1_gs), (_, w2, w2_scale, w2_gs) = make_test_weights(
if config.quant_dtype is None:
# just make normal dtype weights
w1, w2 = make_non_quant_weights(e=config.E,
n=config.N,
k=config.K,
dtype=config.dtype)
return WeightTensors(w1=w1, w2=w2, w1_scale=None, w2_scale=None)
assert config.quant_dtype == torch.float8_e4m3fn
if not config.is_fp8_block_quantized():
w1, w2, w1_scale, w2_scale = make_quant_fp8_weights(
e=config.E,
n=config.N,
k=config.K,
per_out_channel_quant=config.is_per_out_ch_quant,
)
return WeightTensors(w1=w1,
w2=w2,
w1_scale=w1_scale,
w2_scale=w2_scale)
assert config.quant_block_shape is not None
w1, w2, w1_scale, w2_scale = make_block_quant_fp8_weights(
e=config.E, e=config.E,
n=config.N, n=config.N,
k=config.K, k=config.K,
block_size=config.quant_block_shape, in_dtype=config.dtype,
quant_dtype=config.quant_dtype,
block_shape=config.quant_block_shape,
per_act_token_quant=config.is_per_out_ch_quant,
) )
return WeightTensors(w1=w1, return WeightTensors(w1=w1,
w2=w2, w2=w2,
w1_scale=w1_scale, w1_scale=w1_scale,
w2_scale=w2_scale) w2_scale=w2_scale,
w1_gs=w1_gs,
w2_gs=w2_gs)
@dataclass @dataclass
@ -449,7 +403,6 @@ class RankTensors:
dtype=dtype) dtype=dtype)
topk_weights, topk_ids, _ = fused_topk(hidden_states, score, topk, topk_weights, topk_ids, _ = fused_topk(hidden_states, score, topk,
False) False)
topk_ids = topk_ids.to(config.topk_ids_dtype)
# distribute topk_ids evenly # distribute topk_ids evenly
for mi in range(m): for mi in range(m):
@ -457,7 +410,7 @@ class RankTensors:
topk_ids = topk_ids.to(device=torch.cuda.current_device()) topk_ids = topk_ids.to(device=torch.cuda.current_device())
expert_map = None expert_map = None
if config.world_size > 1: if config.world_size > 1 and config.supports_expert_map():
expert_map = torch.full((global_num_experts, ), expert_map = torch.full((global_num_experts, ),
fill_value=-1, fill_value=-1,
dtype=torch.int32) dtype=torch.int32)
@ -480,92 +433,100 @@ class RankTensors:
def reference_moe_impl(config: Config, weights: WeightTensors, def reference_moe_impl(config: Config, weights: WeightTensors,
rank_tensors: RankTensors) -> torch.Tensor: rank_tensors: RankTensors) -> torch.Tensor:
return torch_experts(a=rank_tensors.hidden_states, if config.quant_dtype == "nvfp4":
w1=weights.w1, quant_blocksize = 16
w2=weights.w2, dtype = config.dtype
w1_q = weights.w1
w1_blockscale = weights.w1_scale
w1_gs = weights.w1_gs
w2_q = weights.w2
w2_blockscale = weights.w2_scale
w2_gs = weights.w2_gs
a_global_scale = ((FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX) / torch.amax(
rank_tensors.hidden_states.flatten(), dim=-1)).to(torch.float32)
assert w1_gs is not None
assert w2_gs is not None
assert w1_blockscale is not None
assert w2_blockscale is not None
assert w1_blockscale.shape[1] % 128 == 0
assert w1_blockscale.shape[2] % 4 == 0
assert w2_blockscale.shape[1] % 128 == 0
assert w2_blockscale.shape[2] % 4 == 0
a_fp4, a_scale_interleaved = ops.scaled_fp4_quant(
rank_tensors.hidden_states, a_global_scale)
a = dequantize_nvfp4_to_dtype(a_fp4,
a_scale_interleaved,
a_global_scale,
dtype=dtype,
device=a_fp4.device,
block_size=quant_blocksize)
e = w1_q.shape[0]
n = w1_q.shape[1] // 2
k = w2_q.shape[1]
w1 = torch.zeros((e, 2 * n, k), device="cuda", dtype=dtype)
w2 = torch.zeros((e, k, n), device="cuda", dtype=dtype)
for idx in range(0, e):
w1[idx] = dequantize_nvfp4_to_dtype(w1_q[idx],
w1_blockscale[idx],
w1_gs[idx],
dtype=dtype,
device=w1_q.device,
block_size=quant_blocksize)
w2[idx] = dequantize_nvfp4_to_dtype(w2_q[idx],
w2_blockscale[idx],
w2_gs[idx],
dtype=dtype,
device=w2_q.device,
block_size=quant_blocksize)
a_scale = None
w1_scale = None
w2_scale = None
quant_dtype = None
per_act_token_quant = False
block_shape = None
else:
a = rank_tensors.hidden_states
a_scale = rank_tensors.hidden_states_scale
w1 = weights.w1
w1_scale = weights.w1_scale
w2 = weights.w2
w2_scale = weights.w2_scale
quant_dtype = config.quant_dtype
per_act_token_quant = config.is_per_act_token_quant
block_shape = config.quant_block_shape
return torch_experts(a=a,
w1=w1,
w2=w2,
topk_weight=rank_tensors.topk_weights, topk_weight=rank_tensors.topk_weights,
topk_ids=rank_tensors.topk_ids, topk_ids=rank_tensors.topk_ids,
global_num_experts=config.E, global_num_experts=config.E,
expert_map=None, expert_map=None,
w1_scale=weights.w1_scale, w1_scale=w1_scale,
w2_scale=weights.w2_scale, w2_scale=w2_scale,
a1_scale=rank_tensors.hidden_states_scale, a1_scale=a_scale,
quant_dtype=config.quant_dtype, quant_dtype=quant_dtype,
per_act_token_quant=config.is_per_act_token_quant, per_act_token_quant=per_act_token_quant,
block_shape=config.quant_block_shape, block_shape=block_shape,
apply_router_weights_on_input=config.topk == 1) apply_router_weights_on_input=config.topk == 1
and config.supports_apply_weight_on_input())
def make_fused_experts( def make_modular_kernel(
config: Config, moe: FusedMoEConfig, config: Config,
num_dispatchers: int) -> mk.FusedMoEPermuteExpertsUnpermute: vllm_config: VllmConfig,
weights: WeightTensors,
use_fp8 = config.quant_dtype == torch.float8_e4m3fn ) -> mk.FusedMoEModularKernel:
batch_kwargs = {
"max_num_tokens": moe.max_num_tokens,
"num_dispatchers": num_dispatchers,
}
quant_kwargs = {
"use_fp8_w8a8": use_fp8,
"use_int8_w8a8": False,
"use_int8_w8a16": False,
"use_int4_w4a16": False,
"block_shape": config.quant_block_shape,
"per_act_token_quant": config.is_per_act_token_quant,
}
deepgemm_kwargs = {"allow_deep_gemm": has_deep_gemm()}
if config.fused_experts_type == BatchedDeepGemmExperts:
kwargs = batch_kwargs | {
"block_shape": config.quant_block_shape,
"per_act_token_quant": config.is_per_act_token_quant,
}
print(f"Making BatchedDeepGemmExperts {kwargs} ...")
experts = BatchedDeepGemmExperts(**kwargs)
elif config.fused_experts_type == BatchedTritonExperts:
kwargs = batch_kwargs | quant_kwargs
print(f"Making BatchedTritonExperts {kwargs} ...")
experts = BatchedTritonExperts(**kwargs)
elif config.fused_experts_type == BatchedTritonOrDeepGemmExperts:
kwargs = batch_kwargs | quant_kwargs | deepgemm_kwargs
print(f"Making BatchedTritonOrDeepGemmExperts {kwargs} ...")
experts = BatchedTritonOrDeepGemmExperts(**kwargs)
elif config.fused_experts_type == DeepGemmExperts:
print("Making DeepGemmExperts () ...")
experts = DeepGemmExperts()
elif config.fused_experts_type == TritonExperts:
kwargs = quant_kwargs
print(f"Making TritonExperts {kwargs} ...")
experts = TritonExperts(**kwargs)
elif config.fused_experts_type == TritonOrDeepGemmExperts:
kwargs = quant_kwargs | deepgemm_kwargs
print(f"Making TritonOrDeepGemmExperts {kwargs} ...")
experts = TritonOrDeepGemmExperts(**kwargs)
elif config.fused_experts_type == NaiveBatchedExperts:
kwargs = batch_kwargs | quant_kwargs
print(f"Making NaiveBatchedExperts {kwargs} ...")
experts = NaiveBatchedExperts(**kwargs)
elif config.fused_experts_type == CutlassExpertsFp8:
use_batched_format = config.is_batched_prepare_finalize()
num_experts = (moe.num_local_experts
if use_batched_format else moe.num_experts)
kwargs = {
"max_experts_per_worker": num_experts,
"out_dtype": moe.in_dtype,
"per_act_token_quant": config.is_per_act_token_quant,
"per_out_ch_quant": config.is_per_out_ch_quant,
"block_shape": config.quant_block_shape,
"num_dispatchers": num_dispatchers,
"use_batched_format": use_batched_format
}
print(f"Making CutlassExpertsFp8 {kwargs} ...")
experts = CutlassExpertsFp8(**kwargs)
return experts
def make_modular_kernel(config: Config,
vllm_config: VllmConfig) -> mk.FusedMoEModularKernel:
def next_power_of_2(x): def next_power_of_2(x):
import math import math
@ -579,6 +540,7 @@ def make_modular_kernel(config: Config,
dp_size_=get_dp_group().world_size, dp_size_=get_dp_group().world_size,
vllm_parallel_config=vllm_config.parallel_config, vllm_parallel_config=vllm_config.parallel_config,
) )
moe = FusedMoEConfig( moe = FusedMoEConfig(
num_experts=config.E, num_experts=config.E,
experts_per_token=config.topk, experts_per_token=config.topk,
@ -591,15 +553,16 @@ def make_modular_kernel(config: Config,
) )
# make modular kernel # make modular kernel
prepare_finalize = None prepare_finalize = make_prepare_finalize(config.prepare_finalize_type,
if config.needs_all2all(): config.all2all_backend(), moe)
prepare_finalize = FusedMoEMethodBase.maybe_make_prepare_finalize(moe)
assert prepare_finalize is not None
else:
prepare_finalize = MoEPrepareAndFinalizeNoEP()
fused_experts = make_fused_experts(config, moe, fused_experts = make_fused_experts(
prepare_finalize.num_dispatchers()) config.fused_experts_type,
moe,
prepare_finalize.num_dispatchers(),
weights.w1_gs,
weights.w2_gs,
)
modular_kernel = mk.FusedMoEModularKernel( modular_kernel = mk.FusedMoEModularKernel(
prepare_finalize=prepare_finalize, fused_experts=fused_experts) prepare_finalize=prepare_finalize, fused_experts=fused_experts)
@ -620,22 +583,45 @@ def run_modular_kernel(
# weights for rank # weights for rank
rank_weights = weights.slice_weights(pgi.rank, config.num_local_experts) rank_weights = weights.slice_weights(pgi.rank, config.num_local_experts)
mk = make_modular_kernel(config, vllm_config) mk = make_modular_kernel(config, vllm_config, weights)
mk_kwargs = { mk_kwargs = {
"hidden_states": rank_tensors.hidden_states.clone( "hidden_states":
rank_tensors.hidden_states.clone(
), # impls might update the tensor in place ), # impls might update the tensor in place
"w1": rank_weights.w1, "w1":
"w2": rank_weights.w2, rank_weights.w1,
"topk_weights": rank_tensors.topk_weights, "w2":
"topk_ids": rank_tensors.topk_ids, rank_weights.w2,
"expert_map": rank_tensors.expert_map, "topk_weights":
"w1_scale": rank_weights.w1_scale, rank_tensors.topk_weights,
"w2_scale": rank_weights.w2_scale, "topk_ids":
"a1_scale": rank_tensors.hidden_states_scale, rank_tensors.topk_ids.to(mk.prepare_finalize.topk_indices_dtype()),
"global_num_experts": config.E, "expert_map":
"apply_router_weight_on_input": config.topk == 1, rank_tensors.expert_map,
"w1_scale":
rank_weights.w1_scale,
"w2_scale":
rank_weights.w2_scale,
"a1_scale":
rank_tensors.hidden_states_scale,
"global_num_experts":
config.E,
"apply_router_weight_on_input":
config.topk == 1 and config.supports_apply_weight_on_input(),
} }
out = mk.forward(**mk_kwargs)
num_tokens = rank_tensors.hidden_states.shape[0]
num_tokens_across_dp = torch.tensor([num_tokens] * config.world_size,
device="cuda",
dtype=torch.int)
with set_forward_context(
None,
vllm_config,
num_tokens=num_tokens,
num_tokens_across_dp=num_tokens_across_dp,
):
out = mk.forward(**mk_kwargs)
return out return out

View File

@ -1,58 +1,316 @@
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from dataclasses import dataclass
from typing import Optional, Union
import torch import torch
# Fused experts and PrepareFinalize imports # Fused experts and PrepareFinalize imports
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from vllm.model_executor.layers.fused_moe.batched_deep_gemm_moe import ( from vllm.model_executor.layers.fused_moe.batched_deep_gemm_moe import (
BatchedDeepGemmExperts) BatchedDeepGemmExperts)
from vllm.model_executor.layers.fused_moe.batched_triton_or_deep_gemm_moe import ( # noqa: E501 from vllm.model_executor.layers.fused_moe.batched_triton_or_deep_gemm_moe import ( # noqa: E501
BatchedTritonOrDeepGemmExperts) BatchedTritonOrDeepGemmExperts)
from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig from vllm.model_executor.layers.fused_moe.config import (FusedMoEConfig,
from vllm.model_executor.layers.fused_moe.cutlass_moe import CutlassExpertsFp8 FusedMoEQuantConfig)
from vllm.model_executor.layers.fused_moe.deep_gemm_moe import DeepGemmExperts from vllm.model_executor.layers.fused_moe.deep_gemm_moe import DeepGemmExperts
from vllm.model_executor.layers.fused_moe.fused_batched_moe import ( from vllm.model_executor.layers.fused_moe.fused_batched_moe import (
BatchedTritonExperts, NaiveBatchedExperts) BatchedTritonExperts, NaiveBatchedExperts)
from vllm.model_executor.layers.fused_moe.layer import TritonExperts from vllm.model_executor.layers.fused_moe.layer import (FusedMoEMethodBase,
TritonExperts)
from vllm.model_executor.layers.fused_moe.prepare_finalize import ( from vllm.model_executor.layers.fused_moe.prepare_finalize import (
MoEPrepareAndFinalizeNoEP) MoEPrepareAndFinalizeNoEP)
from vllm.model_executor.layers.fused_moe.triton_deep_gemm_moe import ( from vllm.model_executor.layers.fused_moe.triton_deep_gemm_moe import (
TritonOrDeepGemmExperts) TritonOrDeepGemmExperts)
from vllm.utils import has_deep_ep, has_pplx from vllm.model_executor.layers.quantization.utils.quant_utils import (
cutlass_fp4_supported)
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
cutlass_fp8_supported)
from vllm.platforms import current_platform
from vllm.utils import has_deep_ep, has_deep_gemm, has_pplx
from vllm.utils.deep_gemm import is_deep_gemm_supported
from vllm.utils.flashinfer import has_flashinfer_cutlass_fused_moe
if has_deep_ep():
@dataclass
class PrepareFinalizeInfo:
activation_format: mk.FusedMoEActivationFormat
supported_dtypes: list[Union[torch.dtype, str]]
blocked_quantization_support: bool
backend: Optional[str]
supports_apply_weight_on_input: bool = True
@dataclass
class ExpertInfo:
activation_format: mk.FusedMoEActivationFormat
supported_dtypes: list[Union[torch.dtype, str]]
blocked_quantization_support: bool
supports_chunking: bool
supports_expert_map: bool
needs_matching_quant: bool = False
needs_deep_gemm: bool = False
PREPARE_FINALIZE_INFO: dict[mk.FusedMoEPrepareAndFinalize,
PrepareFinalizeInfo] = {}
EXPERT_INFO: dict[mk.FusedMoEPermuteExpertsUnpermute, ExpertInfo] = {}
MK_ALL_PREPARE_FINALIZE_TYPES: list[mk.FusedMoEPrepareAndFinalize] = []
MK_MULTI_GPU_PREPARE_FINALIZE_TYPES: list[mk.FusedMoEPrepareAndFinalize] = []
MK_SINGLE_GPU_PREPARE_FINALIZE_TYPES: list[mk.FusedMoEPrepareAndFinalize] = []
MK_FUSED_EXPERT_TYPES: list[mk.FusedMoEPermuteExpertsUnpermute] = []
standard_format = mk.FusedMoEActivationFormat.Standard
batched_format = mk.FusedMoEActivationFormat.BatchedExperts
common_float_types: list[Union[torch.dtype, str]] = [
torch.float8_e4m3fn, torch.bfloat16, torch.float16, torch.float32
]
common_float_and_int_types = common_float_types + [torch.int8]
nv_fp4_types = ["nvfp4"]
fp8_types = [torch.float8_e4m3fn]
def register_prepare_and_finalize(
kind,
activation_format: mk.FusedMoEActivationFormat,
supported_dtypes: list[Union[torch.dtype, str]],
blocked_quantization_support: bool,
backend: Optional[str],
force_multigpu: bool = False,
supports_apply_weight_on_input: bool = True,
):
global PREPARE_FINALIZE_INFO
global MK_ALL_PREPARE_FINALIZE_TYPES
global MK_MULTI_GPU_PREPARE_FINALIZE_TYPES
global MK_SINGLE_GPU_PREPARE_FINALIZE_TYPES
assert kind not in PREPARE_FINALIZE_INFO
PREPARE_FINALIZE_INFO[kind] = PrepareFinalizeInfo(
activation_format,
supported_dtypes,
blocked_quantization_support,
backend,
supports_apply_weight_on_input,
)
MK_ALL_PREPARE_FINALIZE_TYPES.append(kind)
if backend is not None or force_multigpu:
MK_MULTI_GPU_PREPARE_FINALIZE_TYPES.append(kind)
else:
MK_SINGLE_GPU_PREPARE_FINALIZE_TYPES.append(kind)
def register_experts(
kind,
activation_format: mk.FusedMoEActivationFormat,
supported_dtypes: list[Union[torch.dtype, str]],
blocked_quantization_support: bool,
supports_chunking: bool,
supports_expert_map: bool,
needs_matching_quant: bool = False,
needs_deep_gemm: bool = False,
):
global EXPERT_INFO
global MK_FUSED_EXPERT_TYPES
assert kind not in EXPERT_INFO
EXPERT_INFO[kind] = ExpertInfo(
activation_format,
supported_dtypes,
blocked_quantization_support,
supports_chunking,
supports_expert_map,
needs_matching_quant,
needs_deep_gemm,
)
MK_FUSED_EXPERT_TYPES.append(kind)
def prepare_finalize_info(kind) -> PrepareFinalizeInfo:
info = PREPARE_FINALIZE_INFO.get(kind)
assert info is not None
return info
def expert_info(kind) -> ExpertInfo:
info = EXPERT_INFO.get(kind)
assert info is not None
return info
register_prepare_and_finalize(
MoEPrepareAndFinalizeNoEP,
standard_format,
common_float_types,
blocked_quantization_support=True,
backend=None,
)
register_experts(
BatchedTritonExperts,
batched_format,
common_float_types,
blocked_quantization_support=True,
supports_chunking=False,
supports_expert_map=False,
needs_matching_quant=True,
)
register_experts(
TritonExperts,
standard_format,
common_float_and_int_types,
blocked_quantization_support=True,
supports_chunking=True,
supports_expert_map=True,
needs_matching_quant=True,
)
register_experts(
NaiveBatchedExperts,
batched_format,
common_float_and_int_types,
blocked_quantization_support=True,
supports_chunking=False,
supports_expert_map=True,
)
# Disable on blackwell for now
if has_deep_ep() and not current_platform.has_device_capability(100):
from vllm.model_executor.layers.fused_moe.deepep_ht_prepare_finalize import ( # noqa: E501 from vllm.model_executor.layers.fused_moe.deepep_ht_prepare_finalize import ( # noqa: E501
DeepEPHTPrepareAndFinalize) DeepEPHTPrepareAndFinalize)
from vllm.model_executor.layers.fused_moe.deepep_ll_prepare_finalize import ( # noqa: E501 from vllm.model_executor.layers.fused_moe.deepep_ll_prepare_finalize import ( # noqa: E501
DeepEPLLPrepareAndFinalize) DeepEPLLPrepareAndFinalize)
register_prepare_and_finalize(
DeepEPHTPrepareAndFinalize,
standard_format,
common_float_types,
blocked_quantization_support=True,
backend="deepep_high_throughput",
)
register_prepare_and_finalize(
DeepEPLLPrepareAndFinalize,
batched_format,
common_float_types,
blocked_quantization_support=True,
backend="deepep_low_latency",
)
if has_pplx(): if has_pplx():
from vllm.model_executor.layers.fused_moe.pplx_prepare_finalize import ( from vllm.model_executor.layers.fused_moe.pplx_prepare_finalize import (
PplxPrepareAndFinalize) PplxPrepareAndFinalize)
register_prepare_and_finalize(
PplxPrepareAndFinalize,
batched_format,
common_float_and_int_types,
blocked_quantization_support=True,
backend="pplx",
)
MK_MULTI_GPU_PREPARE_FINALIZE_TYPES = [] if (has_flashinfer_cutlass_fused_moe()
if has_pplx(): and current_platform.has_device_capability(100)):
MK_MULTI_GPU_PREPARE_FINALIZE_TYPES += [PplxPrepareAndFinalize] from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe import ( # noqa: E501
if has_deep_ep(): FlashInferExperts)
MK_MULTI_GPU_PREPARE_FINALIZE_TYPES += [ from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_prepare_finalize import ( # noqa: E501
DeepEPHTPrepareAndFinalize, DeepEPLLPrepareAndFinalize FlashInferCutlassMoEPrepareAndFinalize)
]
MK_SINGLE_GPU_PREPARE_FINALIZE_TYPES = [MoEPrepareAndFinalizeNoEP] register_prepare_and_finalize(
FlashInferCutlassMoEPrepareAndFinalize,
standard_format,
nv_fp4_types,
blocked_quantization_support=True,
backend=None,
force_multigpu=True,
supports_apply_weight_on_input=False,
)
MK_ALL_PREPARE_FINALIZE_TYPES = (MK_MULTI_GPU_PREPARE_FINALIZE_TYPES + register_experts(
MK_SINGLE_GPU_PREPARE_FINALIZE_TYPES) FlashInferExperts,
standard_format,
nv_fp4_types,
blocked_quantization_support=True,
supports_chunking=True,
# Note: this is a hack to get it to run for now
supports_expert_map=True,
)
else:
FlashInferCutlassMoEPrepareAndFinalize = None
MK_FUSED_EXPERT_TYPES = [ if has_deep_gemm() and is_deep_gemm_supported():
BatchedDeepGemmExperts, register_experts(
BatchedTritonExperts, BatchedDeepGemmExperts,
NaiveBatchedExperts, batched_format,
BatchedTritonOrDeepGemmExperts, fp8_types,
CutlassExpertsFp8, blocked_quantization_support=True,
DeepGemmExperts, supports_chunking=False,
TritonOrDeepGemmExperts, supports_expert_map=False,
TritonExperts, needs_matching_quant=False,
] needs_deep_gemm=True,
)
register_experts(
DeepGemmExperts,
standard_format,
fp8_types,
blocked_quantization_support=True,
supports_chunking=True,
supports_expert_map=True,
needs_matching_quant=False,
needs_deep_gemm=True,
),
register_experts(
BatchedTritonOrDeepGemmExperts,
batched_format,
common_float_and_int_types,
blocked_quantization_support=True,
supports_chunking=False,
supports_expert_map=False,
needs_matching_quant=True,
needs_deep_gemm=True,
)
register_experts(
TritonOrDeepGemmExperts,
standard_format,
common_float_and_int_types,
blocked_quantization_support=True,
supports_chunking=True,
supports_expert_map=True,
needs_matching_quant=True,
needs_deep_gemm=True,
)
if cutlass_fp8_supported():
from vllm.model_executor.layers.fused_moe import (CutlassBatchedExpertsFp8,
CutlassExpertsFp8)
register_experts(
CutlassExpertsFp8,
standard_format,
fp8_types,
blocked_quantization_support=False,
supports_chunking=True,
supports_expert_map=False,
)
register_experts(
CutlassBatchedExpertsFp8,
batched_format,
fp8_types,
blocked_quantization_support=False,
supports_chunking=False,
supports_expert_map=False,
)
if cutlass_fp4_supported():
from vllm.model_executor.layers.fused_moe.cutlass_moe import (
CutlassExpertsFp4)
register_experts(
CutlassExpertsFp4,
standard_format,
nv_fp4_types,
blocked_quantization_support=True,
supports_chunking=True,
supports_expert_map=False,
)
MK_QUANT_CONFIGS = [ MK_QUANT_CONFIGS = [
None, None,
@ -85,3 +343,156 @@ MK_QUANT_CONFIGS = [
# block-quantized weights and per-token activations # block-quantized weights and per-token activations
# block-quantized weights and per-tensor activations # block-quantized weights and per-tensor activations
] ]
if cutlass_fp4_supported() or has_flashinfer_cutlass_fused_moe():
MK_QUANT_CONFIGS += [
FusedMoEQuantConfig(quant_dtype="nvfp4",
per_out_ch_quant=False,
per_act_token_quant=False,
block_shape=None),
]
def _make_gscale(num_experts: int) -> torch.Tensor:
return torch.ones((num_experts, ),
device=torch.cuda.current_device(),
dtype=torch.float32)
def make_prepare_finalize(
prepare_finalize_type: mk.FusedMoEPrepareAndFinalize,
backend: Optional[str],
moe: FusedMoEConfig,
) -> mk.FusedMoEPrepareAndFinalize:
if backend != "naive" and backend is not None:
prepare_finalize = FusedMoEMethodBase._maybe_make_prepare_finalize(moe)
assert prepare_finalize is not None
return prepare_finalize
elif prepare_finalize_type == FlashInferCutlassMoEPrepareAndFinalize:
return FlashInferCutlassMoEPrepareAndFinalize(
use_dp=moe.moe_parallel_config.dp_size > 1,
a1_gscale=_make_gscale(moe.num_local_experts),
)
else:
return MoEPrepareAndFinalizeNoEP()
def _slice(rank: int, num_local_experts: int, t: torch.Tensor) -> torch.Tensor:
s = rank * num_local_experts
e = s + num_local_experts
return t[s:e]
def make_fused_experts(
fused_experts_type: mk.FusedMoEPermuteExpertsUnpermute,
moe: FusedMoEConfig,
num_dispatchers: int,
w1_gs: Optional[torch.Tensor],
w2_gs: Optional[torch.Tensor],
) -> mk.FusedMoEPermuteExpertsUnpermute:
use_fp8 = moe.quant_dtype == torch.float8_e4m3fn
batch_kwargs = {
"max_num_tokens": moe.max_num_tokens,
"num_dispatchers": num_dispatchers,
}
quant_kwargs = {
"use_fp8_w8a8": use_fp8,
"use_int8_w8a8": False,
"use_int8_w8a16": False,
"use_int4_w4a16": False,
"block_shape": moe.block_shape,
"per_act_token_quant": moe.per_act_token_quant,
}
deepgemm_kwargs = {"allow_deep_gemm": has_deep_gemm()}
if fused_experts_type == BatchedDeepGemmExperts:
kwargs = batch_kwargs | {
"block_shape": moe.block_shape,
"per_act_token_quant": moe.per_act_token_quant,
}
print(f"Making BatchedDeepGemmExperts {kwargs} ...")
experts = BatchedDeepGemmExperts(**kwargs)
elif fused_experts_type == BatchedTritonExperts:
kwargs = batch_kwargs | quant_kwargs
print(f"Making BatchedTritonExperts {kwargs} ...")
experts = BatchedTritonExperts(**kwargs)
elif fused_experts_type == BatchedTritonOrDeepGemmExperts:
kwargs = batch_kwargs | quant_kwargs | deepgemm_kwargs
print(f"Making BatchedTritonOrDeepGemmExperts {kwargs} ...")
experts = BatchedTritonOrDeepGemmExperts(**kwargs)
elif fused_experts_type == DeepGemmExperts:
print("Making DeepGemmExperts () ...")
experts = DeepGemmExperts()
elif fused_experts_type == TritonExperts:
kwargs = quant_kwargs
print(f"Making TritonExperts {kwargs} ...")
experts = TritonExperts(**kwargs)
elif fused_experts_type == TritonOrDeepGemmExperts:
kwargs = quant_kwargs | deepgemm_kwargs
print(f"Making TritonOrDeepGemmExperts {kwargs} ...")
experts = TritonOrDeepGemmExperts(**kwargs)
elif fused_experts_type == NaiveBatchedExperts:
kwargs = batch_kwargs | quant_kwargs
print(f"Making NaiveBatchedExperts {kwargs} ...")
experts = NaiveBatchedExperts(**kwargs)
elif fused_experts_type == CutlassExpertsFp8:
kwargs = {
"out_dtype": moe.in_dtype,
"per_act_token_quant": moe.per_act_token_quant,
"per_out_ch_quant": moe.per_out_ch_quant,
"block_shape": moe.block_shape,
}
print(f"Making CutlassExpertsFp8 {kwargs} ...")
experts = CutlassExpertsFp8(**kwargs)
elif fused_experts_type == CutlassBatchedExpertsFp8:
kwargs = {
"max_experts_per_worker": moe.num_local_experts,
"num_dispatchers": num_dispatchers,
"out_dtype": moe.in_dtype,
"per_act_token_quant": moe.per_act_token_quant,
"per_out_ch_quant": moe.per_out_ch_quant,
"block_shape": moe.block_shape,
}
print(f"Making CutlassBatchedExpertsFp8 {kwargs} ...")
experts = CutlassBatchedExpertsFp8(**kwargs)
elif fused_experts_type == CutlassExpertsFp4:
assert w1_gs is not None and w2_gs is not None
num_experts = moe.num_local_experts
rank = moe.moe_parallel_config.dp_rank
kwargs = {
"g1_alphas": _slice(rank, num_experts, (1 / w1_gs)),
"g2_alphas": _slice(rank, num_experts, (1 / w2_gs)),
"a1_gscale": _make_gscale(num_experts),
"a2_gscale": _make_gscale(num_experts),
"max_experts_per_worker": num_experts,
"out_dtype": moe.in_dtype,
"per_act_token_quant": moe.per_act_token_quant,
"per_out_ch_quant": moe.per_out_ch_quant,
"block_shape": moe.block_shape,
"num_dispatchers": num_dispatchers,
}
print(f"Making CutlassExpertsFp4 {kwargs} ...")
experts = CutlassExpertsFp4(**kwargs)
elif fused_experts_type == FlashInferExperts:
assert w1_gs is not None and w2_gs is not None
num_experts = moe.num_local_experts
rank = moe.moe_parallel_config.dp_rank
kwargs = {
"g1_alphas": _slice(rank, num_experts, (1 / w1_gs)),
"g2_alphas": _slice(rank, num_experts, (1 / w2_gs)),
"a1_gscale": _make_gscale(num_experts),
"a2_gscale": _make_gscale(num_experts),
"out_dtype": moe.in_dtype,
"quant_dtype": "nvfp4",
"ep_rank": moe.ep_rank,
"ep_size": moe.ep_size,
"tp_rank": moe.tp_rank,
"tp_size": moe.tp_size,
}
print(f"Making FlashInferExperts {kwargs} ...")
experts = FlashInferExperts(**kwargs)
else:
raise RuntimeError(f"Unknown fused experts type: {fused_experts_type}")
return experts

View File

@@ -52,7 +52,7 @@ def profile_modular_kernel(
     rank_weights = weights.slice_weights(pgi.rank, config.num_local_experts)

     # make modular kernel
-    mk = make_modular_kernel(config, vllm_config)
+    mk = make_modular_kernel(config, vllm_config, weights)

     mk_kwargs = {
         "hidden_states": rank_tensors.hidden_states,
@@ -83,7 +83,7 @@ def rank_worker(
     # sanity check
     from vllm import envs
     if config.fused_moe_chunk_size is not None:
-        assert (config.fused_moe_chunk_size == envs.VLLM_FUSED_MOE_CHUNK_SIZE)
+        assert config.fused_moe_chunk_size == envs.VLLM_FUSED_MOE_CHUNK_SIZE

     # get weights to this device
     weights.to_current_device()

View File

@ -1,117 +0,0 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import torch
import vllm._custom_ops as ops
from vllm.utils.deep_gemm import per_block_cast_to_fp8
def per_token_cast_to_fp8(
x: torch.Tensor, block_size: int) -> tuple[torch.Tensor, torch.Tensor]:
assert x.dim() == 2
m, n = x.shape
pad_size = (block_size - (n % block_size)) % block_size
x = torch.nn.functional.pad(x,
(0, pad_size), value=0) if pad_size > 0 else x
x_view = x.view(m, -1, block_size)
x_amax = x_view.abs().float().amax(dim=2).view(m, -1).clamp(1e-4)
fp8_data = (x_view * (448.0 / x_amax.unsqueeze(2))).to(torch.float8_e4m3fn)
return fp8_data.view(m, n + pad_size)[:, :n], (x_amax / 448.0).view(m, -1)
def make_non_quant_weights(
e: int,
n: int,
k: int,
dtype: torch.dtype,
) -> tuple[torch.Tensor, torch.Tensor]:
"""
Return weights w1, w2
"""
device = torch.cuda.current_device()
w1 = torch.randn((e, 2 * n, k), device=device, dtype=dtype) / 15
w2 = torch.randn((e, k, n), device=device, dtype=dtype) / 15
return w1, w2
def make_block_quant_fp8_weights(
e: int,
n: int,
k: int,
block_size: list[int],
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
"""
Return weights w1, w2, w1_scale, w2_scale
"""
dtype = torch.bfloat16
device = torch.cuda.current_device()
fp8_info = torch.finfo(torch.float8_e4m3fn)
fp8_max, fp8_min = fp8_info.max, fp8_info.min
w1_bf16, w2_bf16 = make_non_quant_weights(e, n, k, dtype)
w1_bf16 = w1_bf16.clamp(min=fp8_min, max=fp8_max).to(dtype=dtype)
w2_bf16 = w2_bf16.clamp(min=fp8_min, max=fp8_max).to(dtype=dtype)
block_n, block_k = block_size[0], block_size[1]
n_tiles_w1 = ((2 * n) + block_n - 1) // block_n
k_tiles_w1 = (k + block_k - 1) // block_k
n_tiles_w2 = (k + block_n - 1) // block_n
k_tiles_w2 = (n + block_k - 1) // block_k
w1 = torch.empty_like(w1_bf16, dtype=torch.float8_e4m3fn, device=device)
w2 = torch.empty_like(w2_bf16, dtype=torch.float8_e4m3fn, device=device)
w1_s = torch.empty((e, n_tiles_w1, k_tiles_w1),
device=device,
dtype=torch.float32)
w2_s = torch.empty((e, n_tiles_w2, k_tiles_w2),
device=device,
dtype=torch.float32)
assert w1_s.shape == (e, (2 * n + (block_n - 1)) // block_n,
(k + (block_k - 1)) // block_k)
assert (w2.shape[-2] + block_n - 1) // block_n == w2_s.shape[-2]
for i in range(e):
w1[i], w1_s[i] = per_block_cast_to_fp8(w1_bf16[i],
block_size=[block_k, block_n])
w2[i], w2_s[i] = per_block_cast_to_fp8(w2_bf16[i],
block_size=[block_k, block_n])
return w1, w2, w1_s, w2_s
def make_quant_fp8_weights(
e: int,
n: int,
k: int,
per_out_channel_quant: bool,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
"""
Return w1, w2, w1_scale, w2_scale
"""
q_dtype = torch.float8_e4m3fn
w1, w2 = make_non_quant_weights(e, n, k, dtype=torch.bfloat16)
# w1 -> w1_q, w2 -> w2_q
w1_q = torch.empty((e, 2 * n, k), device="cuda", dtype=q_dtype)
w2_q = torch.empty((e, k, n), device="cuda", dtype=q_dtype)
n_b_scales = 2 * n if per_out_channel_quant else 1
k_b_scales = k if per_out_channel_quant else 1
w1_scale = torch.empty((e, n_b_scales, 1),
device="cuda",
dtype=torch.float32)
w2_scale = torch.empty((e, k_b_scales, 1),
device="cuda",
dtype=torch.float32)
for expert in range(e):
w1_q[expert], w1_scale[expert] = ops.scaled_fp8_quant(
w1[expert], use_per_token_if_dynamic=per_out_channel_quant)
w2_q[expert], w2_scale[expert] = ops.scaled_fp8_quant(
w2[expert], use_per_token_if_dynamic=per_out_channel_quant)
return w1_q, w2_q, w1_scale, w2_scale

View File

@@ -133,7 +133,7 @@ def test_batched_mm(num_experts: int, max_tokens_per_expert: int, K: int,
         per_act_token_quant=per_act_token_quant,
     )

-    B, B_q, B_scale, _, _, _ = make_test_weights(
+    (B, B_q, B_scale, _), _ = make_test_weights(
         num_experts,
         N // 2,
         K,
@@ -243,7 +243,7 @@ def test_fused_moe_batched_experts(
     act_dtype = dtype
     quant_dtype = None

-    w1_16, w1, w1_s, w2_16, w2, w2_s = make_test_weights(
+    (w1_16, w1, w1_s, _), (w2_16, w2, w2_s, _) = make_test_weights(
         e,
         n,
         k,

View File

@ -161,18 +161,20 @@ def test_w8a8_block_fp8_fused_moe(M, N, K, E, topk, block_size, dtype, seed,
a = torch.randn((M, K), dtype=dtype) / 10 a = torch.randn((M, K), dtype=dtype) / 10
score = torch.randn((M, E), dtype=dtype) score = torch.randn((M, E), dtype=dtype)
_, w1, w1_s, _, w2, w2_s = make_test_weights(E, (_, w1, w1_s, _), (_, w2, w2_s,
N, _) = make_test_weights(E,
K, N,
dtype, K,
torch.float8_e4m3fn, dtype,
per_act_token_quant=False, torch.float8_e4m3fn,
block_shape=block_size) per_act_token_quant=False,
block_shape=block_size)
m_fused_moe = modular_triton_fused_moe(use_fp8_w8a8=True, m_fused_moe = modular_triton_fused_moe(use_fp8_w8a8=True,
use_int8_w8a8=False, use_int8_w8a8=False,
use_int8_w8a16=False, use_int8_w8a16=False,
use_int4_w4a16=False, use_int4_w4a16=False,
use_mxfp4_w4a4=False,
per_act_token_quant=False, per_act_token_quant=False,
block_shape=block_size) block_shape=block_size)
@ -247,13 +249,14 @@ def test_w8a8_block_fp8_deep_gemm_fused_moe(M, N, K, E, topk, seed,
a = torch.randn((M, K), dtype=dtype) / 10 a = torch.randn((M, K), dtype=dtype) / 10
score = torch.randn((M, E), dtype=dtype) score = torch.randn((M, E), dtype=dtype)
_, w1, w1_s, _, w2, w2_s = make_test_weights(E, (_, w1, w1_s, _), (_, w2, w2_s,
N, _) = make_test_weights(E,
K, N,
dtype, K,
torch.float8_e4m3fn, dtype,
per_act_token_quant=False, torch.float8_e4m3fn,
block_shape=block_size) per_act_token_quant=False,
block_shape=block_size)
# Note: for now use_compile will error out if the problem size is # Note: for now use_compile will error out if the problem size is
# large enough to trigger chunking. I'm leaving the flag and # large enough to trigger chunking. I'm leaving the flag and

View File

@ -118,13 +118,14 @@ def test_w8a8_block_int8_fused_moe(M, N, K, E, topk, block_size, dtype, seed):
a = torch.randn((M, K), dtype=dtype) / 10 a = torch.randn((M, K), dtype=dtype) / 10
score = torch.randn((M, E), dtype=dtype) score = torch.randn((M, E), dtype=dtype)
_, w1, w1_s, _, w2, w2_s = make_test_weights(E, (_, w1, w1_s, _), (_, w2, w2_s,
N, _) = make_test_weights(E,
K, N,
dtype, K,
torch.int8, dtype,
per_act_token_quant=False, torch.int8,
block_shape=block_size) per_act_token_quant=False,
block_shape=block_size)
# Set the context to avoid lots of warning spam. # Set the context to avoid lots of warning spam.
with set_current_vllm_config(vllm_config): with set_current_vllm_config(vllm_config):

View File

@ -9,6 +9,7 @@ import random
import pytest import pytest
import torch import torch
from tests.kernels.moe.utils import per_token_cast_to_fp8
from tests.kernels.utils import baseline_scaled_mm from tests.kernels.utils import baseline_scaled_mm
from vllm import _custom_ops as ops from vllm import _custom_ops as ops
from vllm.platforms import current_platform from vllm.platforms import current_platform
@ -16,20 +17,6 @@ from vllm.utils import cdiv
from vllm.utils.deep_gemm import per_block_cast_to_fp8 from vllm.utils.deep_gemm import per_block_cast_to_fp8
def per_token_cast_to_fp8(
x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
assert x.dim() == 2
m, n = x.shape
pad_size = (128 - (n % 128)) % 128
x = torch.nn.functional.pad(x,
(0, pad_size), value=0) if pad_size > 0 else x
x_view = x.view(m, -1, 128)
x_amax = x_view.abs().float().amax(dim=2).view(m, -1).clamp(1e-4)
fp8_data = (x_view *
(448.0 / x_amax.unsqueeze(2))).to(dtype=torch.float8_e4m3fn)
return fp8_data.view(m, n + pad_size)[:, :n], (x_amax / 448.0).view(m, -1)
@pytest.mark.parametrize("num_groups, expected_m_per_group, k, n", [ @pytest.mark.parametrize("num_groups, expected_m_per_group, k, n", [
(4, 8192, 7168, 4096), (4, 8192, 7168, 4096),
(4, 8192, 2048, 7168), (4, 8192, 2048, 7168),
@ -76,7 +63,7 @@ def test_cutlass_grouped_gemm(
device=device, device=device,
dtype=torch.float)) dtype=torch.float))
for i in range(num_groups): for i in range(num_groups):
y_fp8[0][i], y_fp8[1][i] = per_block_cast_to_fp8(y[i]) y_fp8[0][i], y_fp8[1][i] = per_block_cast_to_fp8(y[i], [128, 128])
for i in range(num_groups): for i in range(num_groups):
a = x_fp8[0][ep_offset[i]:ep_offset[i + 1]] a = x_fp8[0][ep_offset[i]:ep_offset[i + 1]]
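Note that per_block_cast_to_fp8 now takes the block shape explicitly instead of assuming 128x128. A minimal usage sketch, not part of the diff, assuming a CUDA build of vLLM that exposes vllm.utils.deep_gemm.per_block_cast_to_fp8 with this signature:

import torch
from vllm.utils.deep_gemm import per_block_cast_to_fp8

# Quantize a bfloat16 weight in 128x128 blocks; returns the fp8 tensor plus
# per-block scales. The shapes here are purely illustrative.
w = torch.randn(512, 2048, device="cuda", dtype=torch.bfloat16)
w_q, w_scale = per_block_cast_to_fp8(w, [128, 128])
print(w_q.dtype, w_scale.shape)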
@ -70,8 +70,10 @@ def make_block_quant_fp8_weights(
""" """
Return weights w1q, w2q, w1_scale, w2_scale Return weights w1q, w2q, w1_scale, w2_scale
""" """
w1, w1q, w1_scale, w2, w2q, w2_scale = make_test_weights( (_, w1q, w1_scale, _), (_, w2q, w2_scale,
e, n, k, torch.bfloat16, torch.float8_e4m3fn, block_size) _) = make_test_weights(e, n, k, torch.bfloat16,
torch.float8_e4m3fn,
block_size)
return w1q, w2q, w1_scale, w2_scale return w1q, w2q, w1_scale, w2_scale
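The helper above shows the new return convention: make_test_weights now yields two 4-tuples, one per weight, each holding the reference weight, the quantized weight, the scales, and an optional global scale. A hedged sketch of a call site, assuming the test utilities in tests/kernels/moe/utils.py and a CUDA device:

import torch
from tests.kernels.moe.utils import make_test_weights

e, n, k = 8, 256, 512
(w1_ref, w1_q, w1_scale, w1_gs), (w2_ref, w2_q, w2_scale, w2_gs) = make_test_weights(
    e, n, k,
    in_dtype=torch.bfloat16,
    quant_dtype=torch.float8_e4m3fn,
    block_shape=[128, 128],
)
# w1_gs / w2_gs are only populated for nvfp4; callers that do not need a field
# unpack it with an underscore, as in the tests in this commit.
assert w1_q.shape == (e, 2 * n, k) and w2_q.shape == (e, k, n)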
@ -132,9 +132,9 @@ def run_single_case(m, n, k, topk, num_experts, block_size):
# Note: W1 has shape (E, 2N, K), so N = 512 # Note: W1 has shape (E, 2N, K), so N = 512
# can trigger the deepgemm path. # can trigger the deepgemm path.
MNKs = [ MNKs = [
(1024, 512, 128), (1024, 768, 128),
(1024, 512, 512), (1024, 768, 512),
(2048, 512, 512), (2048, 768, 512),
(512, 1024, 1024), (512, 1024, 1024),
(512, 2048, 2048), (512, 2048, 2048),
(4096, 4096, 1024), (4096, 4096, 1024),
@ -0,0 +1,147 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import pytest
import torch
from tests.kernels.moe.utils import make_test_weights
from tests.kernels.quantization.nvfp4_utils import (FLOAT4_E2M1_MAX,
FLOAT8_E4M3_MAX,
dequantize_nvfp4_to_dtype)
from tests.kernels.utils import torch_moe
from vllm import _custom_ops as ops
from vllm.config import ParallelConfig, VllmConfig, set_current_vllm_config
from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe import (
FlashInferExperts, is_valid_flashinfer_cutlass_fused_moe)
from vllm.model_executor.layers.fused_moe.fused_moe import fused_topk
from vllm.model_executor.layers.fused_moe.modular_kernel import (
FusedMoEModularKernel)
from vllm.model_executor.layers.fused_moe.prepare_finalize import (
MoEPrepareAndFinalizeNoEP)
from vllm.platforms import current_platform
from vllm.utils.flashinfer import has_flashinfer_cutlass_fused_moe
if not has_flashinfer_cutlass_fused_moe(
) or not current_platform.has_device_capability(100):
pytest.skip("Requires flashinfer_cutlass_fused_moe and nvfp4 support",
allow_module_level=True)
MNK_FACTORS = [
(2, 1024, 1024),
(2, 1024, 1536),
(2, 3072, 1024),
(2, 3072, 1536),
(64, 1024, 1024),
(64, 1024, 1536),
(64, 3072, 1024),
(64, 2048, 1536),
(224, 1024, 1024),
(224, 1024, 1536),
]
@pytest.mark.parametrize("m,n,k", MNK_FACTORS)
@pytest.mark.parametrize("e", [40, 64, 256])
#@pytest.mark.parametrize("e", [128, 256])
@pytest.mark.parametrize("topk", [1, 6, 8])
@pytest.mark.parametrize("dtype", [torch.half, torch.bfloat16])
@torch.inference_mode()
def test_flashinfer_fp4_moe_no_graph(m: int, n: int, k: int, e: int, topk: int,
dtype: torch.dtype):
current_platform.seed_everything(7)
with set_current_vllm_config(
VllmConfig(parallel_config=ParallelConfig(
pipeline_parallel_size=1))):
a = torch.randn((m, k), device="cuda", dtype=dtype) / 10
quant_blocksize = 16
(_, w1_q, w1_blockscale,
w1_gs), (_, w2_q, w2_blockscale, w2_gs) = make_test_weights(
e,
n,
k,
in_dtype=dtype,
quant_dtype="nvfp4",
block_shape=None, # use quant_blocksize?
per_act_token_quant=False,
)
score = torch.randn((m, e), device="cuda", dtype=dtype)
topk_weights, topk_ids, _ = fused_topk(a,
score,
topk,
renormalize=False)
a1_gs = torch.ones((e, ), device="cuda", dtype=torch.float32)
a2_gs = torch.ones((e, ), device="cuda", dtype=torch.float32)
assert is_valid_flashinfer_cutlass_fused_moe(a, w1_q, w2_q)
assert w1_gs is not None
assert w2_gs is not None
assert w1_blockscale is not None
assert w2_blockscale is not None
flashinfer_experts = FusedMoEModularKernel(
MoEPrepareAndFinalizeNoEP(),
FlashInferExperts(
a1_gscale=a1_gs,
g1_alphas=(1 / w1_gs),
a2_gscale=a2_gs,
g2_alphas=(1 / w2_gs),
out_dtype=dtype,
quant_dtype="nvfp4",
))
flashinfer_output = flashinfer_experts(
hidden_states=a,
w1=w1_q,
w1_scale=w1_blockscale,
w2=w2_q,
w2_scale=w2_blockscale,
a1_scale=a1_gs,
a2_scale=a2_gs,
topk_weights=topk_weights,
topk_ids=topk_ids,
)
# Reference check:
a_global_scale = ((FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX) /
torch.amax(a.flatten(), dim=-1)).to(torch.float32)
a_fp4, a_scale_interleaved = ops.scaled_fp4_quant(a, a_global_scale)
_, m_k = a_fp4.shape
a_in_dtype = dequantize_nvfp4_to_dtype(a_fp4,
a_scale_interleaved,
a_global_scale,
dtype=a.dtype,
device=a.device,
block_size=quant_blocksize)
w1_d = torch.empty((e, 2 * n, k), device="cuda", dtype=dtype)
w2_d = torch.empty((e, k, n), device="cuda", dtype=dtype)
for idx in range(0, e):
w1_d[idx] = dequantize_nvfp4_to_dtype(w1_q[idx],
w1_blockscale[idx],
w1_gs[idx],
dtype=dtype,
device=w1_q.device,
block_size=quant_blocksize)
w2_d[idx] = dequantize_nvfp4_to_dtype(w2_q[idx],
w2_blockscale[idx],
w2_gs[idx],
dtype=dtype,
device=w2_q.device,
block_size=quant_blocksize)
torch_output = torch_moe(a_in_dtype, w1_d, w2_d, score, topk)
torch.testing.assert_close(torch_output,
flashinfer_output,
atol=1e-1,
rtol=1e-1)
if __name__ == "__main__":
    test_flashinfer_fp4_moe_no_graph(2, 1024, 1024, 40, 1, torch.half)
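For orientation, the global-scale convention used in the reference check above can be exercised in isolation. A standalone numeric sketch; the constant values are assumptions matching the usual nvfp4 conventions (FLOAT8_E4M3_MAX = 448.0, FLOAT4_E2M1_MAX = 6.0):

import torch

FLOAT8_E4M3_MAX = 448.0   # assumed max finite value of float8 e4m3
FLOAT4_E2M1_MAX = 6.0     # assumed max finite value of float4 e2m1

a = torch.randn(64, 1024) / 10
# The global scale is chosen so that the per-block scales stay representable
# in fp8 e4m3 while the elements stay representable in fp4 e2m1.
a_global_scale = (FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX) / a.abs().amax().float()
print(float(a_global_scale))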
@ -2,6 +2,8 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import copy import copy
import textwrap
import traceback
from itertools import product from itertools import product
from typing import Optional from typing import Optional
@ -10,41 +12,51 @@ import torch
import vllm.model_executor.layers.fused_moe.modular_kernel as mk import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from vllm.config import VllmConfig, current_platform, set_current_vllm_config from vllm.config import VllmConfig, current_platform, set_current_vllm_config
from vllm.model_executor.layers.fused_moe.batched_triton_or_deep_gemm_moe import ( # noqa: E501
BatchedTritonOrDeepGemmExperts)
from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig
from vllm.model_executor.layers.fused_moe.cutlass_moe import CutlassExpertsFp8
from vllm.model_executor.layers.fused_moe.fused_batched_moe import (
BatchedTritonExperts)
from vllm.model_executor.layers.fused_moe.layer import TritonExperts
from vllm.model_executor.layers.fused_moe.triton_deep_gemm_moe import (
TritonOrDeepGemmExperts)
from vllm.utils import has_deep_ep, has_deep_gemm, has_pplx from vllm.utils import has_deep_ep, has_deep_gemm, has_pplx
from vllm.utils.flashinfer import has_flashinfer_cutlass_fused_moe
from .modular_kernel_tools.common import (Config, RankTensors, WeightTensors, from .modular_kernel_tools.common import (Config, RankTensors, WeightTensors,
reference_moe_impl, reference_moe_impl,
run_modular_kernel) run_modular_kernel)
from .modular_kernel_tools.mk_objects import ( from .modular_kernel_tools.mk_objects import (
MK_FUSED_EXPERT_TYPES, MK_MULTI_GPU_PREPARE_FINALIZE_TYPES, MK_FUSED_EXPERT_TYPES, MK_MULTI_GPU_PREPARE_FINALIZE_TYPES,
MK_QUANT_CONFIGS, MK_SINGLE_GPU_PREPARE_FINALIZE_TYPES) MK_QUANT_CONFIGS, MK_SINGLE_GPU_PREPARE_FINALIZE_TYPES, expert_info)
from .modular_kernel_tools.parallel_utils import (ProcessGroupInfo, from .modular_kernel_tools.parallel_utils import (ProcessGroupInfo,
parallel_launch_with_config) parallel_launch_with_config)
# TODO (varun): These requirements are very strict and could be relaxed. has_any_multi_gpu_package = (has_deep_ep() or has_deep_gemm() or has_pplx()
has_all_packages = (has_deep_ep() and has_deep_gemm() and has_pplx()) or has_flashinfer_cutlass_fused_moe())
meets_package_requirements = pytest.mark.skipif( meets_multi_gpu_requirements = pytest.mark.skipif(
not has_all_packages, not has_any_multi_gpu_package,
reason="Requires deep_ep & deep_gemm & pplx packages", reason="Requires deep_ep or deep_gemm or pplx or flashinfer packages",
) )
def format_result(verbose, msg, ex=None):
if ex is not None:
x = str(ex)
newx = x.strip(" \n\t")[:16]
if len(newx) < len(x):
newx = newx + " ..."
prefix = "E\t"
print(f"{textwrap.indent(traceback.format_exc(), prefix)}")
print(f"FAILED {msg} - {newx}\n")
elif verbose:
print(f"PASSED {msg}")
else:
print(".", end="")
def rank_worker( def rank_worker(
pgi: ProcessGroupInfo, pgi: ProcessGroupInfo,
vllm_config: VllmConfig, vllm_config: VllmConfig,
cpu_group, cpu_group,
config: Config, config: Config,
weights: WeightTensors, weights: WeightTensors,
verbose: bool,
): ):
current_platform.seed_everything(pgi.rank) current_platform.seed_everything(pgi.rank)
@ -61,39 +73,64 @@ def rank_worker(
TOPKs = config.topks TOPKs = config.topks
assert isinstance(TOPKs, list) assert isinstance(TOPKs, list)
exceptions = []
count = 0
for m, topk in product(Ms, TOPKs): for m, topk in product(Ms, TOPKs):
print(f"Running m={m}, topk={topk} ...") try:
# override m and topk print(f"Running[{pgi.rank}]: m={m}, topk={topk} ...")
cfgx = copy.deepcopy(config) count = count + 1
cfgx.Ms = m # override m and topk
cfgx.topks = topk cfgx = copy.deepcopy(config)
cfgx.Ms = m
cfgx.topks = topk
# inputs for rank # inputs for rank
rank_tensors = RankTensors.make(cfgx, pgi) rank_tensors = RankTensors.make(cfgx, pgi)
# modular kernel out # modular kernel out
mk_out = run_modular_kernel(pgi, vllm_config, cfgx, weights, mk_out = run_modular_kernel(pgi, vllm_config, cfgx, weights,
rank_tensors) rank_tensors)
with set_current_vllm_config(vllm_config): with set_current_vllm_config(vllm_config):
ref_out = reference_moe_impl(cfgx, weights, rank_tensors) ref_out = reference_moe_impl(cfgx, weights, rank_tensors)
torch.testing.assert_close(ref_out, mk_out, atol=3e-2, rtol=3e-2) if config.quant_dtype == "nvfp4":
atol = 1e-1
rtol = 1e-1
else:
atol = 3e-2
rtol = 3e-2
torch.testing.assert_close(ref_out, mk_out, atol=atol, rtol=rtol)
format_result(verbose, config.describe())
except Exception as ex:
format_result(verbose, config.describe(), ex)
exceptions.append(ex)
if len(exceptions) > 0:
raise RuntimeError(
f"{len(exceptions)} of {count} tests failed in child process, "
f"rank={pgi.rank}.")
else:
print(f"{count} of {count} tests passed in child process, "
f"rank={pgi.rank}.")
def run(config: Config): def run(config: Config, verbose: bool):
assert config.is_valid() assert config.is_valid()
print(f"Testing config \n{config.describe()} ...")
weights: WeightTensors = WeightTensors.make(config) weights: WeightTensors = WeightTensors.make(config)
vllm_config, env_dict = config.make_env_data() vllm_config, env_dict = config.make_env_data()
parallel_launch_with_config(config.world_size, rank_worker, vllm_config, parallel_launch_with_config(config.world_size, rank_worker, vllm_config,
env_dict, config, weights) env_dict, config, weights, verbose)
Ms = [32, 64] Ms = [32, 64]
Ks = [7168] # hidden sizes # hidden sizes, making this too large will cause fp4 tests to fail.
# Also needs to be a multiple of 1024 for deep_gemm.
Ks = [2048]
Ns = [2048] Ns = [2048]
TOPKs = [4, 1] TOPKs = [4, 1]
Es = [32] Es = [32]
@ -103,19 +140,16 @@ FUSED_MOE_CHUNK_SIZEs = [None, 16]
def is_nyi_config(config: Config) -> bool: def is_nyi_config(config: Config) -> bool:
# We know these configs to be legitimate. but still fail. # We know these configs to be legitimate. but still fail.
info = expert_info(config.fused_experts_type)
if (config.fused_experts_type in [ if info.needs_matching_quant:
BatchedTritonExperts, BatchedTritonOrDeepGemmExperts,
TritonExperts, TritonOrDeepGemmExperts
]):
# The triton kernels expect both per-act-token-quant and # The triton kernels expect both per-act-token-quant and
# per-out-ch-quant or neither. # per-out-ch-quant or neither.
unsupported_quant_config = ((config.is_per_act_token_quant + unsupported_quant_config = ((config.is_per_act_token_quant +
config.is_per_out_ch_quant) == 1) config.is_per_out_ch_quant) == 1)
return unsupported_quant_config return unsupported_quant_config
# cutlass kernels dont support expert_maps yet. return not info.supports_expert_map
return config.fused_experts_type == CutlassExpertsFp8
@pytest.mark.parametrize("k", Ks) @pytest.mark.parametrize("k", Ks)
@ -128,13 +162,13 @@ def is_nyi_config(config: Config) -> bool:
product(MK_MULTI_GPU_PREPARE_FINALIZE_TYPES, MK_FUSED_EXPERT_TYPES)) product(MK_MULTI_GPU_PREPARE_FINALIZE_TYPES, MK_FUSED_EXPERT_TYPES))
@pytest.mark.parametrize("fused_moe_chunk_size", FUSED_MOE_CHUNK_SIZEs) @pytest.mark.parametrize("fused_moe_chunk_size", FUSED_MOE_CHUNK_SIZEs)
@pytest.mark.parametrize("world_size", [2]) @pytest.mark.parametrize("world_size", [2])
@meets_package_requirements @meets_multi_gpu_requirements
def test_modular_kernel_combinations_multigpu( def test_modular_kernel_combinations_multigpu(
k: int, n: int, e: int, dtype: torch.dtype, k: int, n: int, e: int, dtype: torch.dtype,
quant_config: FusedMoEQuantConfig, quant_config: Optional[FusedMoEQuantConfig],
combination: tuple[mk.FusedMoEPrepareAndFinalize, combination: tuple[mk.FusedMoEPrepareAndFinalize,
mk.FusedMoEPermuteExpertsUnpermute], mk.FusedMoEPermuteExpertsUnpermute],
fused_moe_chunk_size: Optional[int], world_size: int): fused_moe_chunk_size: Optional[int], world_size: int, pytestconfig):
config = Config( config = Config(
Ms=Ms, Ms=Ms,
@ -149,14 +183,15 @@ def test_modular_kernel_combinations_multigpu(
fused_moe_chunk_size=fused_moe_chunk_size, fused_moe_chunk_size=fused_moe_chunk_size,
world_size=world_size, world_size=world_size,
) )
if not config.is_valid(): if not config.is_valid():
pytest.skip(f"Tests config {config} is not valid. Skipping ...") pytest.skip(f"Tests config {config} is not valid. Skipping ...")
if is_nyi_config(config): if is_nyi_config(config):
pytest.skip(f"Tests config {config} is nyi. Skipping ...") pytest.skip(f"Tests config {config} is nyi. Skipping ...")
print(f"{config.describe()}") verbosity = pytestconfig.getoption('verbose')
run(config) run(config, verbosity > 0)
@pytest.mark.parametrize("k", Ks) @pytest.mark.parametrize("k", Ks)
@ -169,13 +204,12 @@ def test_modular_kernel_combinations_multigpu(
product(MK_SINGLE_GPU_PREPARE_FINALIZE_TYPES, MK_FUSED_EXPERT_TYPES)) product(MK_SINGLE_GPU_PREPARE_FINALIZE_TYPES, MK_FUSED_EXPERT_TYPES))
@pytest.mark.parametrize("fused_moe_chunk_size", FUSED_MOE_CHUNK_SIZEs) @pytest.mark.parametrize("fused_moe_chunk_size", FUSED_MOE_CHUNK_SIZEs)
@pytest.mark.parametrize("world_size", [1]) @pytest.mark.parametrize("world_size", [1])
@meets_package_requirements
def test_modular_kernel_combinations_singlegpu( def test_modular_kernel_combinations_singlegpu(
k: int, n: int, e: int, dtype: torch.dtype, k: int, n: int, e: int, dtype: torch.dtype,
quant_config: FusedMoEQuantConfig, quant_config: Optional[FusedMoEQuantConfig],
combination: tuple[mk.FusedMoEPrepareAndFinalize, combination: tuple[mk.FusedMoEPrepareAndFinalize,
mk.FusedMoEPermuteExpertsUnpermute], mk.FusedMoEPermuteExpertsUnpermute],
fused_moe_chunk_size: Optional[int], world_size: int): fused_moe_chunk_size: Optional[int], world_size: int, pytestconfig):
config = Config( config = Config(
Ms=Ms, Ms=Ms,
K=k, K=k,
@ -196,7 +230,8 @@ def test_modular_kernel_combinations_singlegpu(
if is_nyi_config(config): if is_nyi_config(config):
pytest.skip(f"Tests config {config} is nyi. Skipping ...") pytest.skip(f"Tests config {config} is nyi. Skipping ...")
run(config) verbosity = pytestconfig.getoption('verbose')
run(config, verbosity > 0)
if __name__ == '__main__': if __name__ == '__main__':
@ -211,4 +246,4 @@ if __name__ == '__main__':
args = parser.parse_args() args = parser.parse_args()
config = make_config(args) config = make_config(args)
run(config) run(config, True)
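The verbosity plumbing added above is driven by pytest's own -v flag via pytestconfig. A small sketch of invoking it programmatically; the test file path here is an assumption, since file names are not visible in this extract:

import pytest

pytest.main(["-v", "tests/kernels/moe/test_modular_kernel_combinations.py",
             "-k", "singlegpu"])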
@ -3,6 +3,7 @@
import pytest import pytest
import torch import torch
from tests.kernels.moe.utils import make_test_weights
from tests.kernels.quantization.nvfp4_utils import (FLOAT4_E2M1_MAX, from tests.kernels.quantization.nvfp4_utils import (FLOAT4_E2M1_MAX,
FLOAT8_E4M3_MAX, FLOAT8_E4M3_MAX,
dequantize_nvfp4_to_dtype) dequantize_nvfp4_to_dtype)
@ -43,41 +44,20 @@ def test_cutlass_fp4_moe_no_graph(m: int, n: int, k: int, e: int, topk: int,
VllmConfig(parallel_config=ParallelConfig( VllmConfig(parallel_config=ParallelConfig(
pipeline_parallel_size=1))): pipeline_parallel_size=1))):
a = torch.randn((m, k), device="cuda", dtype=dtype) / 10
w1 = torch.randn((e, 2 * n, k), device="cuda", dtype=dtype) / 10
quant_blocksize = 16 quant_blocksize = 16
round_up = lambda x, y: (x + y - 1) // y * y
sf_w1_2n = round_up(2 * n, 128)
sf_w1_k = round_up(k // quant_blocksize, 4)
w1_blockscale = torch.empty((e, sf_w1_2n, sf_w1_k),
device="cuda",
dtype=torch.float8_e4m3fn)
w2 = torch.randn((e, k, n), device="cuda", dtype=dtype) / 10 a = torch.randn((m, k), device="cuda", dtype=dtype) / 10
sf_w2_k = round_up(k, 128)
sf_w2_n = round_up(n // quant_blocksize, 4)
w2_blockscale = torch.empty((e, sf_w2_k, sf_w2_n),
device="cuda",
dtype=torch.float8_e4m3fn)
w1_q = torch.empty((e, 2 * n, k // 2), (_, w1_q, w1_blockscale,
device="cuda", w1_gs), (_, w2_q, w2_blockscale, w2_gs) = make_test_weights(
dtype=torch.uint8) e,
w2_q = torch.empty((e, k, n // 2), device="cuda", dtype=torch.uint8) n,
w1_gs = torch.empty((e, ), device="cuda", dtype=torch.float32) k,
w2_gs = torch.empty((e, ), device="cuda", dtype=torch.float32) in_dtype=dtype,
quant_dtype="nvfp4",
for expert in range(e): block_shape=None, # use quant_blocksize?
w1_amax = torch.abs(w1).max().to(torch.float32) per_act_token_quant=False,
w2_amax = torch.abs(w2).max().to(torch.float32) )
w1_gs[expert] = FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX / w1_amax
w2_gs[expert] = FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX / w2_amax
w1_q[expert], w1_blockscale[expert] = ops.scaled_fp4_quant(
w1[expert], w1_gs[expert])
w2_q[expert], w2_blockscale[expert] = ops.scaled_fp4_quant(
w2[expert], w2_gs[expert])
score = torch.randn((m, e), device="cuda", dtype=dtype) score = torch.randn((m, e), device="cuda", dtype=dtype)
topk_weights, topk_ids, _ = fused_topk(a, topk_weights, topk_ids, _ = fused_topk(a,
@ -88,6 +68,11 @@ def test_cutlass_fp4_moe_no_graph(m: int, n: int, k: int, e: int, topk: int,
a1_gs = torch.ones((e, ), device="cuda", dtype=torch.float32) a1_gs = torch.ones((e, ), device="cuda", dtype=torch.float32)
a2_gs = torch.ones((e, ), device="cuda", dtype=torch.float32) a2_gs = torch.ones((e, ), device="cuda", dtype=torch.float32)
assert w1_gs is not None
assert w2_gs is not None
assert w1_blockscale is not None
assert w2_blockscale is not None
cutlass_output = cutlass_moe_fp4( cutlass_output = cutlass_moe_fp4(
a=a, a=a,
a1_gscale=a1_gs, a1_gscale=a1_gs,
@ -104,14 +89,13 @@ def test_cutlass_fp4_moe_no_graph(m: int, n: int, k: int, e: int, topk: int,
n=n, n=n,
k=k, k=k,
e=e, e=e,
device=a.device,
) )
# Reference check: # Reference check:
a_global_scale = ((FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX) / a_global_scale = ((FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX) /
torch.amax(a.flatten(), dim=-1)).to(torch.float32) torch.amax(a.flatten(), dim=-1)).to(torch.float32)
a_fp4, a_scale_interleaved = ops.scaled_fp4_quant(a, a_global_scale) a_fp4, a_scale_interleaved = ops.scaled_fp4_quant(a, a_global_scale)
_, m_k = a_fp4.shape
a_in_dtype = dequantize_nvfp4_to_dtype(a_fp4, a_in_dtype = dequantize_nvfp4_to_dtype(a_fp4,
a_scale_interleaved, a_scale_interleaved,
a_global_scale, a_global_scale,
@ -126,14 +110,14 @@ def test_cutlass_fp4_moe_no_graph(m: int, n: int, k: int, e: int, topk: int,
w1_d[idx] = dequantize_nvfp4_to_dtype(w1_q[idx], w1_d[idx] = dequantize_nvfp4_to_dtype(w1_q[idx],
w1_blockscale[idx], w1_blockscale[idx],
w1_gs[idx], w1_gs[idx],
dtype=w1.dtype, dtype=dtype,
device=w1.device, device=w1_q.device,
block_size=quant_blocksize) block_size=quant_blocksize)
w2_d[idx] = dequantize_nvfp4_to_dtype(w2_q[idx], w2_d[idx] = dequantize_nvfp4_to_dtype(w2_q[idx],
w2_blockscale[idx], w2_blockscale[idx],
w2_gs[idx], w2_gs[idx],
dtype=w2.dtype, dtype=dtype,
device=w2.device, device=w2_q.device,
block_size=quant_blocksize) block_size=quant_blocksize)
torch_output = torch_moe(a_in_dtype, w1_d, w2_d, score, topk) torch_output = torch_moe(a_in_dtype, w1_d, w2_d, score, topk)
@ -9,7 +9,8 @@ import torch
from tests.kernels.utils import torch_experts from tests.kernels.utils import torch_experts
from vllm import _custom_ops as ops from vllm import _custom_ops as ops
from vllm.config import VllmConfig, set_current_vllm_config from vllm.config import VllmConfig, set_current_vllm_config
from vllm.model_executor.layers.fused_moe.cutlass_moe import CutlassExpertsFp8 from vllm.model_executor.layers.fused_moe.cutlass_moe import (
CutlassBatchedExpertsFp8)
from vllm.model_executor.layers.fused_moe.fused_moe import fused_topk from vllm.model_executor.layers.fused_moe.fused_moe import fused_topk
from vllm.model_executor.layers.fused_moe.modular_kernel import ( from vllm.model_executor.layers.fused_moe.modular_kernel import (
FusedMoEModularKernel) FusedMoEModularKernel)
@ -123,12 +124,8 @@ def pplx_cutlass_moe(
num_local_experts=num_local_experts, num_local_experts=num_local_experts,
num_dispatchers=num_dispatchers) num_dispatchers=num_dispatchers)
experts = CutlassExpertsFp8(num_local_experts, experts = CutlassBatchedExpertsFp8(num_local_experts, num_dispatchers,
out_dtype, out_dtype, per_act_token, per_out_ch)
per_act_token,
per_out_ch,
num_dispatchers=num_dispatchers,
use_batched_format=True)
fused_cutlass_experts = FusedMoEModularKernel( fused_cutlass_experts = FusedMoEModularKernel(
prepare_finalize, prepare_finalize,
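As the call site above shows, the batched CUTLASS FP8 path now has its own class. A construction-only sketch, assuming a CUDA build of vLLM with these classes; the numeric values are illustrative:

import torch
from vllm.model_executor.layers.fused_moe.cutlass_moe import (
    CutlassBatchedExpertsFp8, CutlassExpertsFp8)

# Standard activation format: no batching-related arguments needed anymore.
experts = CutlassExpertsFp8(
    out_dtype=torch.half,
    per_act_token_quant=False,
    per_out_ch_quant=False,
)

# Batched activation format: expert count and dispatcher count are explicit
# constructor arguments instead of use_batched_format=True.
batched_experts = CutlassBatchedExpertsFp8(
    32,          # max_experts_per_worker
    2,           # num_dispatchers
    torch.half,  # out_dtype
    False,       # per_act_token_quant
    False,       # per_out_ch_quant
)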
@ -770,7 +770,7 @@ def test_pplx_moe_slow(
a = torch.randn((m, k), device="cuda", dtype=torch.bfloat16) / 10 a = torch.randn((m, k), device="cuda", dtype=torch.bfloat16) / 10
score = torch.randn((m, e), device="cuda", dtype=torch.bfloat16) score = torch.randn((m, e), device="cuda", dtype=torch.bfloat16)
_, w1, w1_s, _, w2, w2_s = make_test_weights( (_, w1, w1_s, _), (_, w2, w2_s, _) = make_test_weights(
e, e,
n, n,
k, k,
@ -836,7 +836,7 @@ def _pplx_test_loop(pgi: ProcessGroupInfo, dp_size: int, use_internode: bool,
args = dict() args = dict()
if make_weights: if make_weights:
_, w1, w1_s, _, w2, w2_s = make_test_weights( (_, w1, w1_s, _), (_, w2, w2_s, _) = make_test_weights(
e, e,
n, n,
k, k,
@ -1,11 +1,13 @@
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Optional from typing import Optional, Union
import torch import torch
import vllm._custom_ops as ops import vllm._custom_ops as ops
from tests.kernels.quant_utils import per_block_cast_to_int8 from tests.kernels.quant_utils import per_block_cast_to_int8
from tests.kernels.quantization.nvfp4_utils import (FLOAT4_E2M1_MAX,
FLOAT8_E4M3_MAX)
from vllm.model_executor.layers.fused_moe import fused_experts from vllm.model_executor.layers.fused_moe import fused_experts
from vllm.model_executor.layers.fused_moe.fused_batched_moe import ( from vllm.model_executor.layers.fused_moe.fused_batched_moe import (
BatchedPrepareAndFinalize, BatchedTritonExperts, NaiveBatchedExperts) BatchedPrepareAndFinalize, BatchedTritonExperts, NaiveBatchedExperts)
@ -169,28 +171,41 @@ def make_quantized_test_activations(
def moe_quantize_weights( def moe_quantize_weights(
w: torch.Tensor, w: torch.Tensor,
w_s: Optional[torch.Tensor], w_s: Optional[torch.Tensor],
quant_dtype: Optional[torch.dtype], quant_dtype: Union[torch.dtype, str, None],
per_token_quant: bool, per_token_quant: bool,
block_shape: Optional[list[int]], block_shape: Optional[list[int]],
) -> tuple[torch.Tensor, Optional[torch.Tensor]]: ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]:
assert (quant_dtype == torch.float8_e4m3fn assert (quant_dtype == torch.float8_e4m3fn or quant_dtype == torch.int8
or quant_dtype == torch.int8), "only fp8/int8 supported" or quant_dtype == "nvfp4"), "only fp8/int8/nvfp4 supported"
w_gs = None
if block_shape is not None: if block_shape is not None:
assert not per_token_quant assert not per_token_quant
if quant_dtype == torch.int8: if quant_dtype == torch.int8:
w, w_s = per_block_cast_to_int8(w, block_shape) w, w_s = per_block_cast_to_int8(w, block_shape)
else: elif quant_dtype == torch.float8_e4m3fn:
w, w_s = per_block_cast_to_fp8(w, block_shape) w, w_s = per_block_cast_to_fp8(w, block_shape)
elif quant_dtype == "nvfp4":
raise RuntimeError("blocked quantization not supported for nvfp4")
else:
raise RuntimeError(f"Unsupported quant type {quant_dtype}")
else: else:
if quant_dtype == torch.int8: if quant_dtype == torch.int8:
w, w_s = ops.scaled_int8_quant( w, w_s = ops.scaled_int8_quant(
w, w_s, use_per_token_if_dynamic=per_token_quant) w, w_s, use_per_token_if_dynamic=per_token_quant)
else: elif quant_dtype == torch.float8_e4m3fn:
w, w_s = ops.scaled_fp8_quant( w, w_s = ops.scaled_fp8_quant(
w, w_s, use_per_token_if_dynamic=per_token_quant) w, w_s, use_per_token_if_dynamic=per_token_quant)
elif quant_dtype == "nvfp4":
assert not per_token_quant
w_amax = torch.abs(w).max().to(torch.float32)
w_gs = FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX / w_amax
w, w_s = ops.scaled_fp4_quant(w, w_gs)
else:
raise RuntimeError(f"Unsupported quant type {quant_dtype}")
return w, w_s return w, w_s, w_gs
def make_test_weight( def make_test_weight(
@ -198,21 +213,26 @@ def make_test_weight(
rows: int, rows: int,
cols: int, cols: int,
in_dtype: torch.dtype = torch.bfloat16, in_dtype: torch.dtype = torch.bfloat16,
quant_dtype: Optional[torch.dtype] = None, quant_dtype: Union[torch.dtype, str, None] = None,
block_shape: Optional[list[int]] = None, block_shape: Optional[list[int]] = None,
per_act_token_quant: bool = False, per_act_token_quant: bool = False,
) -> tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]: ) -> tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor],
Optional[torch.Tensor]]:
w_16 = torch.randn((e, rows, cols), device="cuda", dtype=in_dtype) / 15 w_16 = torch.randn((e, rows, cols), device="cuda", dtype=in_dtype) / 15
w_gs = None
if quant_dtype is not None: if quant_dtype is not None:
w_l = [None] * e w_l = [None] * e
w_s_l = [None] * e w_s_l = [None] * e
w_gs_l = [None] * e
for idx in range(e): for idx in range(e):
w_l[idx], w_s_l[idx] = moe_quantize_weights( w_l[idx], w_s_l[idx], w_gs_l[idx] = moe_quantize_weights(
w_16[idx], None, quant_dtype, per_act_token_quant, block_shape) w_16[idx], None, quant_dtype, per_act_token_quant, block_shape)
w = torch.stack(w_l) w = torch.stack(w_l)
w_s = torch.stack(w_s_l) w_s = torch.stack(w_s_l)
if e > 0 and w_gs_l[0] is not None:
w_gs = torch.stack(w_gs_l)
if w_s.ndim == 2: if w_s.ndim == 2:
assert w_s.shape[-1] == 1 assert w_s.shape[-1] == 1
w_s = w_s.view(-1, 1, 1) w_s = w_s.view(-1, 1, 1)
@ -225,8 +245,9 @@ def make_test_weight(
else: else:
w = w_16 w = w_16
w_s = None w_s = None
w_gs = None
return w_16, w, w_s return w_16, w, w_s, w_gs
def make_test_weights( def make_test_weights(
@ -234,14 +255,30 @@ def make_test_weights(
n: int, n: int,
k: int, k: int,
in_dtype: torch.dtype = torch.bfloat16, in_dtype: torch.dtype = torch.bfloat16,
quant_dtype: Optional[torch.dtype] = None, quant_dtype: Union[torch.dtype, str, None] = None,
block_shape: Optional[list[int]] = None, block_shape: Optional[list[int]] = None,
per_act_token_quant: bool = False, per_act_token_quant: bool = False,
) -> tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor], torch.Tensor, ) -> tuple[tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor],
torch.Tensor, Optional[torch.Tensor]]: Optional[torch.Tensor]],
tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor],
Optional[torch.Tensor]]]:
return ( return (
*make_test_weight(e, 2 * n, k, in_dtype, quant_dtype, block_shape, make_test_weight(e, 2 * n, k, in_dtype, quant_dtype, block_shape,
per_act_token_quant), per_act_token_quant),
*make_test_weight(e, k, n, in_dtype, quant_dtype, block_shape, make_test_weight(e, k, n, in_dtype, quant_dtype, block_shape,
per_act_token_quant), per_act_token_quant),
) )
def per_token_cast_to_fp8(
x: torch.Tensor,
block_size: int = 128) -> tuple[torch.Tensor, torch.Tensor]:
assert x.dim() == 2
m, n = x.shape
pad_size = (block_size - (n % block_size)) % block_size
x = torch.nn.functional.pad(x,
(0, pad_size), value=0) if pad_size > 0 else x
x_view = x.view(m, -1, block_size)
x_amax = x_view.abs().float().amax(dim=2).view(m, -1).clamp(1e-4)
fp8_data = (x_view * (448.0 / x_amax.unsqueeze(2))).to(torch.float8_e4m3fn)
return fp8_data.view(m, n + pad_size)[:, :n], (x_amax / 448.0).view(m, -1)
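A short usage sketch for the relocated per_token_cast_to_fp8 helper, assuming a PyTorch build with float8_e4m3fn support; the column count is a multiple of the 128-wide group size so the round trip reshapes exactly:

import torch
from tests.kernels.moe.utils import per_token_cast_to_fp8

x = torch.randn(4, 256, device="cuda", dtype=torch.float32)
x_fp8, x_scale = per_token_cast_to_fp8(x)  # block_size defaults to 128
# One scale per (row, 128-column group).
assert x_fp8.shape == x.shape and x_scale.shape == (4, 2)
# Approximate round trip: rescale each 128-wide group by its scale.
x_deq = (x_fp8.to(torch.float32).view(4, 2, 128) * x_scale.unsqueeze(-1)).view(4, 256)
print((x_deq - x).abs().max())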
@ -105,7 +105,8 @@ class DeviceCommunicatorBase:
# we initialize the all2all manager used in expert parallel. # we initialize the all2all manager used in expert parallel.
use_ep = config.parallel_config.data_parallel_size > 1 use_ep = config.parallel_config.data_parallel_size > 1
self.use_all2all = "ep" in unique_name and use_ep self.is_ep_communicator = "ep" in unique_name
self.use_all2all = self.is_ep_communicator and use_ep
self.all2all_manager: Optional[All2AllManagerBase] = None self.all2all_manager: Optional[All2AllManagerBase] = None
def all_reduce(self, input_: torch.Tensor) -> torch.Tensor: def all_reduce(self, input_: torch.Tensor) -> torch.Tensor:
@ -246,7 +247,7 @@ class DeviceCommunicatorBase:
""" """
Prepare the communication buffer for the model. Prepare the communication buffer for the model.
""" """
if not self.use_all2all: if not self.is_ep_communicator:
return return
moe_modules = [ moe_modules = [
@ -254,7 +255,7 @@ class DeviceCommunicatorBase:
if module.__class__.__name__ == "FusedMoE" if module.__class__.__name__ == "FusedMoE"
] ]
for module in moe_modules: for module in moe_modules:
module.quant_method.init_prepare_finalize(module.moe_config) module.quant_method.init_prepare_finalize()
def dispatch( def dispatch(
self, hidden_states: torch.Tensor, self, hidden_states: torch.Tensor,
@ -49,7 +49,8 @@ if HAS_TRITON:
from vllm.model_executor.layers.fused_moe.batched_triton_or_deep_gemm_moe import ( # noqa: E501 from vllm.model_executor.layers.fused_moe.batched_triton_or_deep_gemm_moe import ( # noqa: E501
BatchedTritonOrDeepGemmExperts) BatchedTritonOrDeepGemmExperts)
from vllm.model_executor.layers.fused_moe.cutlass_moe import ( from vllm.model_executor.layers.fused_moe.cutlass_moe import (
CutlassExpertsFp8, cutlass_moe_fp4, cutlass_moe_fp8) CutlassBatchedExpertsFp8, CutlassExpertsFp8, cutlass_moe_fp4,
cutlass_moe_fp8)
from vllm.model_executor.layers.fused_moe.deep_gemm_moe import ( from vllm.model_executor.layers.fused_moe.deep_gemm_moe import (
DeepGemmExperts) DeepGemmExperts)
from vllm.model_executor.layers.fused_moe.fused_batched_moe import ( from vllm.model_executor.layers.fused_moe.fused_batched_moe import (
@ -69,6 +70,7 @@ if HAS_TRITON:
"cutlass_moe_fp8", "cutlass_moe_fp8",
"cutlass_moe_fp4", "cutlass_moe_fp4",
"CutlassExpertsFp8", "CutlassExpertsFp8",
"CutlassBatchedExpertsFp8",
"TritonExperts", "TritonExperts",
"BatchedTritonExperts", "BatchedTritonExperts",
"DeepGemmExperts", "DeepGemmExperts",
@ -1,6 +1,6 @@
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Any, Optional from typing import Optional
import torch import torch
@ -254,18 +254,28 @@ class BatchedDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
output = (num_experts, max_num_tokens * num_dispatchers, K) output = (num_experts, max_num_tokens * num_dispatchers, K)
return (workspace13, workspace2, output, a.dtype) return (workspace13, workspace2, output, a.dtype)
def apply(self, output: torch.Tensor, hidden_states: torch.Tensor, def apply(
w1: torch.Tensor, w2: torch.Tensor, topk_weights: torch.Tensor, self,
topk_ids: torch.Tensor, activation: str, global_num_experts: int, output: torch.Tensor,
expert_map: Optional[torch.Tensor], hidden_states: torch.Tensor,
w1_scale: Optional[torch.Tensor], w1: torch.Tensor,
w2_scale: Optional[torch.Tensor], w1_zp: Optional[torch.Tensor], w2: torch.Tensor,
w2_zp: Optional[torch.Tensor], a1q_scale: Optional[torch.Tensor], topk_weights: torch.Tensor,
a2_scale: Optional[torch.Tensor], workspace13: torch.Tensor, topk_ids: torch.Tensor,
workspace2: torch.Tensor, activation: str,
expert_tokens_meta: Optional[mk.ExpertTokensMetadata], global_num_experts: int,
apply_router_weight_on_input: bool, expert_map: Optional[torch.Tensor],
extra_expert_args: Optional[dict[str, Any]]): w1_scale: Optional[torch.Tensor],
w2_scale: Optional[torch.Tensor],
w1_zp: Optional[torch.Tensor],
w2_zp: Optional[torch.Tensor],
a1q_scale: Optional[torch.Tensor],
a2_scale: Optional[torch.Tensor],
workspace13: torch.Tensor,
workspace2: torch.Tensor,
expert_tokens_meta: Optional[mk.ExpertTokensMetadata],
apply_router_weight_on_input: bool,
):
assert expert_tokens_meta is not None assert expert_tokens_meta is not None
expert_num_tokens = expert_tokens_meta.expert_num_tokens expert_num_tokens = expert_tokens_meta.expert_num_tokens
@ -1,6 +1,6 @@
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Any, Optional from typing import Optional
import torch import torch
@ -132,18 +132,28 @@ class BatchedTritonOrDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
a, aq, M, N, K, topk, global_num_experts, local_num_experts, a, aq, M, N, K, topk, global_num_experts, local_num_experts,
expert_tokens_metadata) expert_tokens_metadata)
def apply(self, output: torch.Tensor, hidden_states: torch.Tensor, def apply(
w1: torch.Tensor, w2: torch.Tensor, topk_weights: torch.Tensor, self,
topk_ids: torch.Tensor, activation: str, global_num_experts: int, output: torch.Tensor,
expert_map: Optional[torch.Tensor], hidden_states: torch.Tensor,
w1_scale: Optional[torch.Tensor], w1: torch.Tensor,
w2_scale: Optional[torch.Tensor], w1_zp: Optional[torch.Tensor], w2: torch.Tensor,
w2_zp: Optional[torch.Tensor], a1q_scale: Optional[torch.Tensor], topk_weights: torch.Tensor,
a2_scale: Optional[torch.Tensor], workspace13: torch.Tensor, topk_ids: torch.Tensor,
workspace2: torch.Tensor, activation: str,
expert_tokens_meta: Optional[mk.ExpertTokensMetadata], global_num_experts: int,
apply_router_weight_on_input: bool, expert_map: Optional[torch.Tensor],
extra_expert_args: Optional[dict[str, Any]]): w1_scale: Optional[torch.Tensor],
w2_scale: Optional[torch.Tensor],
w1_zp: Optional[torch.Tensor],
w2_zp: Optional[torch.Tensor],
a1q_scale: Optional[torch.Tensor],
a2_scale: Optional[torch.Tensor],
workspace13: torch.Tensor,
workspace2: torch.Tensor,
expert_tokens_meta: Optional[mk.ExpertTokensMetadata],
apply_router_weight_on_input: bool,
):
experts = (self.batched_deep_gemm_experts experts = (self.batched_deep_gemm_experts
if self.allow_deep_gemm else self.batched_triton_experts) if self.allow_deep_gemm else self.batched_triton_experts)
assert experts is not None assert experts is not None
@ -151,4 +161,4 @@ class BatchedTritonOrDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
activation, global_num_experts, expert_map, w1_scale, activation, global_num_experts, expert_map, w1_scale,
w2_scale, w1_zp, w2_zp, a1q_scale, a2_scale, workspace13, w2_scale, w1_zp, w2_zp, a1q_scale, a2_scale, workspace13,
workspace2, expert_tokens_meta, workspace2, expert_tokens_meta,
apply_router_weight_on_input, extra_expert_args) apply_router_weight_on_input)
@ -45,7 +45,6 @@ def get_quant_config_weight_quant(
return _get_quant_config_quantization_args(quant_config, "weights") return _get_quant_config_quantization_args(quant_config, "weights")
# TODO (bnell): use scalar_type instead of bools?
def get_config_quant_dtype( def get_config_quant_dtype(
use_fp8_w8a8: bool, use_fp8_w8a8: bool,
use_int8_w8a8: bool, use_int8_w8a8: bool,
@ -65,7 +64,8 @@ def get_config_quant_dtype(
@dataclass @dataclass
class FusedMoEQuantConfig: class FusedMoEQuantConfig:
# The post quantization activation type. # The post quantization activation type.
quant_dtype: Optional[torch.dtype] = None # TODO (bnell): use scalar_type instead of Union.
quant_dtype: Union[torch.dtype, str, None] = None
per_act_token_quant: bool = False per_act_token_quant: bool = False
per_out_ch_quant: bool = False per_out_ch_quant: bool = False
block_shape: Optional[list[int]] = None block_shape: Optional[list[int]] = None
@ -141,6 +141,7 @@ class FusedMoEQuantConfig:
use_int8_w8a8, use_int8_w8a8,
use_int8_w8a16, use_int8_w8a16,
use_int4_w4a16, use_int4_w4a16,
use_mxfp4_w4a4,
] ]
]) <= 1, "Quantization flags are mutually exclusive." ]) <= 1, "Quantization flags are mutually exclusive."
@ -334,7 +335,7 @@ class FusedMoEConfig:
assert self.max_num_tokens > 0 assert self.max_num_tokens > 0
@property @property
def quant_dtype(self) -> Optional[torch.dtype]: def quant_dtype(self) -> Union[torch.dtype, str, None]:
if self.quant_config is not None: if self.quant_config is not None:
return self.quant_config.quant_dtype return self.quant_config.quant_dtype
else: else:
@ -429,7 +430,7 @@ class FusedMoEConfig:
block_shape = None block_shape = None
per_act_token_quant = False per_act_token_quant = False
per_out_ch_quant = False per_out_ch_quant = False
quant_dtype: Optional[torch.dtype] = None quant_dtype: Union[torch.dtype, str, None] = None
input_quant = get_quant_config_input_quant(quant_config) input_quant = get_quant_config_input_quant(quant_config)
weight_quant = get_quant_config_weight_quant(quant_config) weight_quant = get_quant_config_weight_quant(quant_config)
@ -453,7 +454,7 @@ class FusedMoEConfig:
ModelOptNvFp4Config) ModelOptNvFp4Config)
if quant_dtype is None and isinstance(quant_config, if quant_dtype is None and isinstance(quant_config,
ModelOptNvFp4Config): ModelOptNvFp4Config):
quant_dtype = torch.uint8 quant_dtype = "nvfp4"
if weight_quant is not None: if weight_quant is not None:
per_out_ch_quant = ( per_out_ch_quant = (
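A small sketch of the widened quant_dtype field, under the assumption that FusedMoEQuantConfig can be constructed directly as a plain dataclass; "nvfp4" is the only string value introduced by this change:

import torch
from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig

# quant_dtype is now Union[torch.dtype, str, None].
fp8_cfg = FusedMoEQuantConfig(quant_dtype=torch.float8_e4m3fn)
nvfp4_cfg = FusedMoEQuantConfig(quant_dtype="nvfp4")
print(fp8_cfg.quant_dtype, nvfp4_cfg.quant_dtype)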
@ -1,7 +1,7 @@
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
""" CUTLASS based Fused MoE kernels.""" """ CUTLASS based Fused MoE kernels."""
from typing import Any, Callable, Optional from typing import Callable, Optional
import torch import torch
@ -12,11 +12,10 @@ from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig
from vllm.model_executor.layers.fused_moe.prepare_finalize import ( from vllm.model_executor.layers.fused_moe.prepare_finalize import (
MoEPrepareAndFinalizeNoEP) MoEPrepareAndFinalizeNoEP)
from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import ( from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import (
TopKWeightAndReduceDelegate) TopKWeightAndReduceDelegate, TopKWeightAndReduceNoOP)
from vllm.model_executor.layers.fused_moe.utils import (_fp8_perm, from vllm.model_executor.layers.fused_moe.utils import (_fp8_perm,
_fp8_quantize, _fp8_quantize,
_resize_cache, _resize_cache)
extract_required_args)
from vllm.scalar_type import scalar_types from vllm.scalar_type import scalar_types
logger = init_logger(__name__) logger = init_logger(__name__)
@ -213,19 +212,14 @@ def run_cutlass_moe_fp8(
output.copy_(c3[c_map].view(M * topk, K), non_blocking=True) output.copy_(c3[c_map].view(M * topk, K), non_blocking=True)
# TODO (bnell): split class batched vs. non-batched? class CutlassExpertsFp8Base(mk.FusedMoEPermuteExpertsUnpermute):
# maybe remove need for passing aq to workspace_shapes
class CutlassExpertsFp8(mk.FusedMoEPermuteExpertsUnpermute):
def __init__( def __init__(
self, self,
max_experts_per_worker: int,
out_dtype: Optional[torch.dtype], out_dtype: Optional[torch.dtype],
per_act_token_quant: bool, per_act_token_quant: bool,
per_out_ch_quant: bool, per_out_ch_quant: bool,
block_shape: Optional[list[int]] = None, block_shape: Optional[list[int]] = None,
num_dispatchers: Optional[int] = None,
use_batched_format: bool = False,
): ):
super().__init__( super().__init__(
FusedMoEQuantConfig( FusedMoEQuantConfig(
@ -234,33 +228,84 @@ class CutlassExpertsFp8(mk.FusedMoEPermuteExpertsUnpermute):
per_out_ch_quant=per_out_ch_quant, per_out_ch_quant=per_out_ch_quant,
block_shape=block_shape, block_shape=block_shape,
)) ))
assert max_experts_per_worker > 0
assert not use_batched_format or num_dispatchers is not None
self.max_experts_per_worker = max_experts_per_worker
self.num_dispatchers = num_dispatchers
self.out_dtype = out_dtype self.out_dtype = out_dtype
self.use_batched_format = use_batched_format
def finalize_weight_and_reduce_impl(self) -> mk.TopKWeightAndReduce:
# Let PrepareAndFinalize::finalize() decide the impl.
return TopKWeightAndReduceDelegate()
def apply(
self,
output: torch.Tensor,
hidden_states: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
activation: str,
global_num_experts: int,
expert_map: Optional[torch.Tensor],
w1_scale: Optional[torch.Tensor],
w2_scale: Optional[torch.Tensor],
w1_zp: Optional[torch.Tensor],
w2_zp: Optional[torch.Tensor],
a1q_scale: Optional[torch.Tensor],
a2_scale: Optional[torch.Tensor],
workspace13: torch.Tensor,
workspace2: torch.Tensor,
expert_tokens_meta: Optional[mk.ExpertTokensMetadata],
apply_router_weight_on_input: bool,
):
assert w1_zp is None, "w1_zp is not supported in CUTLASS MoE"
assert w2_zp is None, "w2_zp is not supported in CUTLASS MoE"
expert_num_tokens = None
if expert_tokens_meta is not None:
expert_num_tokens = expert_tokens_meta.expert_num_tokens
activation_callable = lambda o, i: self.activation(activation, o, i)
use_batched_format = self.activation_formats[
0] == mk.FusedMoEActivationFormat.BatchedExperts
in_dtype = hidden_states.dtype
run_cutlass_moe_fp8(
output, hidden_states, w1, w2, topk_ids, activation_callable,
global_num_experts, expert_map, w1_scale, w2_scale, a1q_scale,
a2_scale, workspace13, workspace2, expert_num_tokens,
self.out_dtype if self.out_dtype is not None else in_dtype,
self.per_act_token_quant, self.per_out_ch_quant,
use_batched_format)
class CutlassExpertsFp8(CutlassExpertsFp8Base):
def __init__(
self,
out_dtype: Optional[torch.dtype],
per_act_token_quant: bool,
per_out_ch_quant: bool,
block_shape: Optional[list[int]] = None,
):
super().__init__(
out_dtype,
per_act_token_quant,
per_out_ch_quant,
block_shape,
)
@property @property
def activation_formats( def activation_formats(
self self
) -> tuple[mk.FusedMoEActivationFormat, mk.FusedMoEActivationFormat]: ) -> tuple[mk.FusedMoEActivationFormat, mk.FusedMoEActivationFormat]:
if self.use_batched_format: return (mk.FusedMoEActivationFormat.Standard,
return (mk.FusedMoEActivationFormat.BatchedExperts, mk.FusedMoEActivationFormat.Standard)
mk.FusedMoEActivationFormat.BatchedExperts)
else:
return (mk.FusedMoEActivationFormat.Standard,
mk.FusedMoEActivationFormat.Standard)
def supports_chunking(self) -> bool: def supports_chunking(self) -> bool:
return not self.use_batched_format return True
def supports_expert_map(self) -> bool: def supports_expert_map(self) -> bool:
return not self.use_batched_format return True
def finalize_weight_and_reduce_impl(self) -> mk.TopKWeightAndReduce:
# Let PrepareAndFinalize::finalize() decide the impl.
return TopKWeightAndReduceDelegate()
def workspace_shapes( def workspace_shapes(
self, self,
@ -274,54 +319,69 @@ class CutlassExpertsFp8(mk.FusedMoEPermuteExpertsUnpermute):
local_num_experts: int, local_num_experts: int,
expert_tokens_meta: Optional[mk.ExpertTokensMetadata], expert_tokens_meta: Optional[mk.ExpertTokensMetadata],
) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...], torch.dtype]: ) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...], torch.dtype]:
workspace1: tuple[int, ...] = () workspace1 = (M * topk, max(N, K))
workspace2: tuple[int, ...] = () workspace2 = (M * topk, N // 2)
output: tuple[int, ...] = () output = (M * topk, K)
if self.use_batched_format:
padded_M = aq.size(1)
num_dp = self.num_dispatchers
assert num_dp is not None
workspace1 = (self.max_experts_per_worker, padded_M * num_dp,
max(N, K))
workspace2 = (self.max_experts_per_worker, padded_M * num_dp,
(N // 2))
output = (self.max_experts_per_worker, padded_M, K)
else:
workspace1 = (M * topk, max(N, K))
workspace2 = (M * topk, N // 2)
output = (M * topk, K)
return (workspace1, workspace2, output, return (workspace1, workspace2, output,
self.out_dtype if self.out_dtype is not None else a.dtype) self.out_dtype if self.out_dtype is not None else a.dtype)
def apply(self, output: torch.Tensor, hidden_states: torch.Tensor,
w1: torch.Tensor, w2: torch.Tensor, topk_weights: torch.Tensor,
topk_ids: torch.Tensor, activation: str, global_num_experts: int,
expert_map: Optional[torch.Tensor],
w1_scale: Optional[torch.Tensor],
w2_scale: Optional[torch.Tensor], w1_zp: Optional[torch.Tensor],
w2_zp: Optional[torch.Tensor], a1q_scale: Optional[torch.Tensor],
a2_scale: Optional[torch.Tensor], workspace13: torch.Tensor,
workspace2: torch.Tensor,
expert_tokens_meta: Optional[mk.ExpertTokensMetadata],
apply_router_weight_on_input: bool,
extra_expert_args: Optional[dict[str, Any]]):
assert w1_zp is None, "w1_zp is not supported in CUTLASS MoE"
assert w2_zp is None, "w2_zp is not supported in CUTLASS MoE"
expert_num_tokens = None class CutlassBatchedExpertsFp8(CutlassExpertsFp8Base):
if expert_tokens_meta is not None:
expert_num_tokens = expert_tokens_meta.expert_num_tokens
activation_callable = lambda o, i: self.activation(activation, o, i) def __init__(
self,
max_experts_per_worker: int,
num_dispatchers: int,
out_dtype: Optional[torch.dtype],
per_act_token_quant: bool,
per_out_ch_quant: bool,
block_shape: Optional[list[int]] = None,
):
super().__init__(
out_dtype,
per_act_token_quant,
per_out_ch_quant,
block_shape,
)
assert max_experts_per_worker > 0
self.max_experts_per_worker = max_experts_per_worker
self.num_dispatchers = num_dispatchers
in_dtype = hidden_states.dtype @property
run_cutlass_moe_fp8( def activation_formats(
output, hidden_states, w1, w2, topk_ids, activation_callable, self
global_num_experts, expert_map, w1_scale, w2_scale, a1q_scale, ) -> tuple[mk.FusedMoEActivationFormat, mk.FusedMoEActivationFormat]:
a2_scale, workspace13, workspace2, expert_num_tokens, return (mk.FusedMoEActivationFormat.BatchedExperts,
self.out_dtype if self.out_dtype is not None else in_dtype, mk.FusedMoEActivationFormat.BatchedExperts)
self.per_act_token_quant, self.per_out_ch_quant,
self.use_batched_format) def supports_chunking(self) -> bool:
return False
def supports_expert_map(self) -> bool:
return False
# TODO(bnell): maybe remove need for passing aq to workspace_shapes
def workspace_shapes(
self,
a: torch.Tensor,
aq: torch.Tensor,
M: int,
N: int,
K: int,
topk: int,
global_num_experts: int,
local_num_experts: int,
expert_tokens_meta: Optional[mk.ExpertTokensMetadata],
) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...], torch.dtype]:
padded_M = aq.size(1)
num_dp = self.num_dispatchers
assert num_dp is not None
workspace1 = (self.max_experts_per_worker, padded_M * num_dp,
max(N, K))
workspace2 = (self.max_experts_per_worker, padded_M * num_dp, (N // 2))
output = (self.max_experts_per_worker, padded_M, K)
return (workspace1, workspace2, output,
self.out_dtype if self.out_dtype is not None else a.dtype)
def cutlass_moe_fp8( def cutlass_moe_fp8(
@ -387,11 +447,9 @@ def cutlass_moe_fp8(
fn = mk.FusedMoEModularKernel( fn = mk.FusedMoEModularKernel(
MoEPrepareAndFinalizeNoEP(), MoEPrepareAndFinalizeNoEP(),
CutlassExpertsFp8( CutlassExpertsFp8(
max_experts_per_worker=num_experts,
out_dtype=a.dtype, out_dtype=a.dtype,
per_act_token_quant=per_act_token, per_act_token_quant=per_act_token,
per_out_ch_quant=per_out_ch, per_out_ch_quant=per_out_ch,
use_batched_format=False,
), ),
) )
@ -476,8 +534,9 @@ def run_cutlass_moe_fp4(
e_w1, nx2_w1, half_k_w1 = w1_fp4.shape e_w1, nx2_w1, half_k_w1 = w1_fp4.shape
e_w2, k_w2, half_n_w2 = w2_fp4.shape e_w2, k_w2, half_n_w2 = w2_fp4.shape
assert (e_w1 == e_w2 and e_w1 == e), ("Number of experts must match", assert (e_w1 == e_w2
" between weights.") and e_w1 == e), ("Number of experts must match",
f" between weights. {e_w1}, {e_w2}, {e}")
assert (k_a == half_k_w1 * 2 assert (k_a == half_k_w1 * 2
and k == k_w2), ("Hidden size mismatch between a, w1 and w2") and k == k_w2), ("Hidden size mismatch between a, w1 and w2")
assert (nx2_w1 == n * 2 and half_n_w2 * 2 == n), ("mismatch in " assert (nx2_w1 == n * 2 and half_n_w2 * 2 == n), ("mismatch in "
@ -554,6 +613,10 @@ class CutlassExpertsFp4(mk.FusedMoEPermuteExpertsUnpermute):
def __init__( def __init__(
self, self,
g1_alphas: torch.Tensor,
g2_alphas: torch.Tensor,
a1_gscale: torch.Tensor,
a2_gscale: torch.Tensor,
max_experts_per_worker: int, max_experts_per_worker: int,
out_dtype: torch.dtype, out_dtype: torch.dtype,
per_act_token_quant: bool, per_act_token_quant: bool,
@ -562,8 +625,12 @@ class CutlassExpertsFp4(mk.FusedMoEPermuteExpertsUnpermute):
use_batched_format: bool = False, use_batched_format: bool = False,
): ):
super().__init__( super().__init__(
# NVFP4 requires two levels of quantization, which involves
# computing some scaling factors dynamically. This makes it
# incompatible with the typical prepare -> MoE -> finalize
# pipeline. Move the quantization logic into the MoE body.
FusedMoEQuantConfig( FusedMoEQuantConfig(
quant_dtype=torch.uint8, quant_dtype=None, # skip quantization in prepare/finalize
per_act_token_quant=per_act_token_quant, per_act_token_quant=per_act_token_quant,
per_out_ch_quant=per_out_ch_quant, per_out_ch_quant=per_out_ch_quant,
block_shape=block_shape, block_shape=block_shape,
@ -572,6 +639,12 @@ class CutlassExpertsFp4(mk.FusedMoEPermuteExpertsUnpermute):
self.out_dtype = out_dtype self.out_dtype = out_dtype
self.use_batched_format = use_batched_format self.use_batched_format = use_batched_format
# TODO(bnell): put this stuff into quant config?
self.g1_alphas = g1_alphas
self.g2_alphas = g2_alphas
self.a1_gscale = a1_gscale
self.a2_gscale = a2_gscale
@property @property
def activation_formats( def activation_formats(
self self
@ -590,8 +663,7 @@ class CutlassExpertsFp4(mk.FusedMoEPermuteExpertsUnpermute):
return True return True
def finalize_weight_and_reduce_impl(self) -> mk.TopKWeightAndReduce: def finalize_weight_and_reduce_impl(self) -> mk.TopKWeightAndReduce:
# Let PrepareAndFinalize::finalize() decide the impl. return TopKWeightAndReduceNoOP()
return TopKWeightAndReduceDelegate()
def workspace_shapes( def workspace_shapes(
self, self,
@ -620,34 +692,42 @@ class CutlassExpertsFp4(mk.FusedMoEPermuteExpertsUnpermute):
return (workspace1, workspace2, output, return (workspace1, workspace2, output,
self.out_dtype if self.out_dtype is not None else a.dtype) self.out_dtype if self.out_dtype is not None else a.dtype)
def apply(self, output: torch.Tensor, hidden_states: torch.Tensor, def apply(
w1: torch.Tensor, w2: torch.Tensor, topk_weights: torch.Tensor, self,
topk_ids: torch.Tensor, activation: str, global_num_experts: int, output: torch.Tensor,
expert_map: Optional[torch.Tensor], w1_scale: torch.Tensor, hidden_states: torch.Tensor,
w2_scale: torch.Tensor, w1_zp: Optional[torch.Tensor], w1: torch.Tensor,
w2_zp: Optional[torch.Tensor], a1q_scale: Optional[torch.Tensor], w2: torch.Tensor,
a2_scale: torch.Tensor, workspace13: Optional[torch.Tensor], topk_weights: torch.Tensor,
workspace2: Optional[torch.Tensor], topk_ids: torch.Tensor,
expert_tokens_meta: Optional[mk.ExpertTokensMetadata], activation: str,
apply_router_weight_on_input: bool, global_num_experts: int,
extra_expert_args: Optional[dict[str, Any]]): expert_map: Optional[torch.Tensor],
required_keys = [ w1_scale: torch.Tensor,
"g1_alphas", "g2_alphas", "a1_gscale", "a2_gscale", "m", "n", "k", w2_scale: torch.Tensor,
"e", "device" w1_zp: Optional[torch.Tensor],
] w2_zp: Optional[torch.Tensor],
(g1_alphas, g2_alphas, a1_gscale, a2_gscale, m, n, k, e, a1q_scale: Optional[torch.Tensor],
device) = extract_required_args(extra_expert_args, required_keys) a2_scale: torch.Tensor,
workspace13: Optional[torch.Tensor],
workspace2: Optional[torch.Tensor],
expert_tokens_meta: Optional[mk.ExpertTokensMetadata],
apply_router_weight_on_input: bool,
):
e, m, n, k, _ = mk._moe_problem_size(hidden_states, w1, w2, topk_ids)
n = w2.shape[2] * 2
run_cutlass_moe_fp4( run_cutlass_moe_fp4(
output=output, output=output,
a=hidden_states, a=hidden_states,
a1_gscale=a1_gscale, a1_gscale=self.a1_gscale,
w1_fp4=w1, w1_fp4=w1,
w1_blockscale=w1_scale, w1_blockscale=w1_scale,
w1_alphas=g1_alphas, w1_alphas=self.g1_alphas,
a2_gscale=a2_gscale, a2_gscale=self.a2_gscale,
w2_fp4=w2, w2_fp4=w2,
w2_blockscale=w2_scale, w2_blockscale=w2_scale,
w2_alphas=g2_alphas, w2_alphas=self.g2_alphas,
topk_weights=topk_weights, topk_weights=topk_weights,
topk_ids=topk_ids, topk_ids=topk_ids,
workspace13=workspace13, workspace13=workspace13,
@ -656,7 +736,7 @@ class CutlassExpertsFp4(mk.FusedMoEPermuteExpertsUnpermute):
n=n, n=n,
k=k, k=k,
e=e, e=e,
device=device, device=hidden_states.device,
apply_router_weight_on_input=apply_router_weight_on_input, apply_router_weight_on_input=apply_router_weight_on_input,
) )
@ -677,7 +757,6 @@ def cutlass_moe_fp4(
n: int, n: int,
k: int, k: int,
e: int, e: int,
device: torch.device,
expert_map: Optional[torch.Tensor] = None, expert_map: Optional[torch.Tensor] = None,
apply_router_weight_on_input: bool = False) -> torch.Tensor: apply_router_weight_on_input: bool = False) -> torch.Tensor:
assert expert_map is None, ("Expert Parallelism / expert_map " assert expert_map is None, ("Expert Parallelism / expert_map "
@ -686,6 +765,10 @@ def cutlass_moe_fp4(
fn = mk.FusedMoEModularKernel( fn = mk.FusedMoEModularKernel(
MoEPrepareAndFinalizeNoEP(), MoEPrepareAndFinalizeNoEP(),
CutlassExpertsFp4( CutlassExpertsFp4(
g1_alphas,
g2_alphas,
a1_gscale,
a2_gscale,
max_experts_per_worker=e, max_experts_per_worker=e,
out_dtype=a.dtype, out_dtype=a.dtype,
per_act_token_quant=False, per_act_token_quant=False,
@ -693,29 +776,7 @@ def cutlass_moe_fp4(
use_batched_format=False, use_batched_format=False,
), ),
) )
extra_expert_args = {
'g1_alphas': g1_alphas,
'g2_alphas': g2_alphas,
'a1_gscale': a1_gscale,
'a2_gscale': a2_gscale,
'm': m,
'n': n,
'k': k,
'e': e,
'device': device,
}
# NVFP4 requires two levels of quantization, which involves computing some
# scaling factors dynamically. This makes it incompatible with the typical
# prepare -> MoE -> finalize pipeline. Move the quantization logic into the
# MoE body.
extra_prepare_args = {
'skip_quant': True,
}
# Similar reason as above.
extra_finalize_args = {
'skip_weight_reduce': True,
}
return fn( return fn(
hidden_states=a, hidden_states=a,
w1=w1_fp4, w1=w1_fp4,
@ -731,9 +792,6 @@ def cutlass_moe_fp4(
a1_scale=None, a1_scale=None,
a2_scale=None, a2_scale=None,
apply_router_weight_on_input=apply_router_weight_on_input, apply_router_weight_on_input=apply_router_weight_on_input,
extra_expert_args=extra_expert_args,
extra_prepare_args=extra_prepare_args,
extra_finalize_args=extra_finalize_args,
) )
@ -824,16 +882,6 @@ def run_cutlass_block_scaled_fused_experts(
k = w1_q.size(1) k = w1_q.size(1)
n = w2_q.size(1) n = w2_q.size(1)
expert_offsets = torch.empty((num_experts + 1, ),
dtype=torch.int32,
device="cuda")
problem_sizes1 = torch.empty((num_experts, 3),
dtype=torch.int32,
device="cuda")
problem_sizes2 = torch.empty((num_experts, 3),
dtype=torch.int32,
device="cuda")
topk = topk_ids.size(1) topk = topk_ids.size(1)
a_q, a1_scale = _fp8_quantize(a, a_q, a1_scale = _fp8_quantize(a,
@ -842,6 +890,16 @@ def run_cutlass_block_scaled_fused_experts(
block_shape=[128, 128]) block_shape=[128, 128])
device = a_q.device device = a_q.device
expert_offsets = torch.empty((num_experts + 1, ),
dtype=torch.int32,
device=device)
problem_sizes1 = torch.empty((num_experts, 3),
dtype=torch.int32,
device=device)
problem_sizes2 = torch.empty((num_experts, 3),
dtype=torch.int32,
device=device)
a_map = torch.empty((topk_ids.numel()), dtype=torch.int32, device=device) a_map = torch.empty((topk_ids.numel()), dtype=torch.int32, device=device)
c_map = torch.empty((topk_ids.numel()), dtype=torch.int32, device=device) c_map = torch.empty((topk_ids.numel()), dtype=torch.int32, device=device)
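Taken together, the FP4 changes move the per-expert scales from extra_expert_args into the experts object itself. A construction-only sketch, assuming a CUDA build of vLLM with nvfp4 support; the all-ones scales are placeholders:

import torch
from vllm.model_executor.layers.fused_moe.cutlass_moe import CutlassExpertsFp4

e = 8
ones = lambda: torch.ones(e, device="cuda", dtype=torch.float32)

experts = CutlassExpertsFp4(
    ones(),  # g1_alphas
    ones(),  # g2_alphas
    ones(),  # a1_gscale
    ones(),  # a2_gscale
    max_experts_per_worker=e,
    out_dtype=torch.half,
    per_act_token_quant=False,
    per_out_ch_quant=False,
    use_batched_format=False,
)
# m, n, k and the device are now derived inside apply() via mk._moe_problem_size,
# which is why cutlass_moe_fp4() no longer takes a device argument.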
@@ -1,7 +1,7 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 import functools
-from typing import Any, Optional
+from typing import Optional
 import torch
 from tqdm import tqdm
@@ -230,7 +230,6 @@ class DeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
         workspace2: torch.Tensor,
         expert_tokens_meta: Optional[mk.ExpertTokensMetadata],
         apply_router_weight_on_input: bool,
-        extra_expert_args: Optional[dict[str, Any]],
     ):
         assert self.block_shape is not None
         assert a1q_scale is not None

View File

@@ -1,6 +1,6 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
-from typing import Any, Optional
+from typing import Optional
 import deep_ep
 import torch
@@ -127,12 +127,16 @@ class DeepEPHTPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
                 expert_topk_weights)
     def prepare(
-        self, a1: torch.Tensor, a1_scale: Optional[torch.Tensor],
-        a2_scale: Optional[torch.Tensor], topk_weights: torch.Tensor,
-        topk_ids: torch.Tensor, num_experts: int,
-        expert_map: Optional[torch.Tensor], apply_router_weight_on_input: bool,
+        self,
+        a1: torch.Tensor,
+        a1_scale: Optional[torch.Tensor],
+        a2_scale: Optional[torch.Tensor],
+        topk_weights: torch.Tensor,
+        topk_ids: torch.Tensor,
+        num_experts: int,
+        expert_map: Optional[torch.Tensor],
+        apply_router_weight_on_input: bool,
        quant_config: FusedMoEQuantConfig,
-        extra_prepare_args: Optional[dict[str, Any]]
    ) -> tuple[torch.Tensor, Optional[torch.Tensor],
               Optional[mk.ExpertTokensMetadata], Optional[torch.Tensor],
               Optional[torch.Tensor]]:
@@ -187,11 +191,15 @@ class DeepEPHTPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
         return (expert_x, expert_x_scale, expert_tokens_meta, expert_topk_ids,
                 expert_topk_weights)
-    def finalize(self, output: torch.Tensor, fused_expert_output: torch.Tensor,
-                 topk_weights: torch.Tensor, topk_ids: torch.Tensor,
-                 apply_router_weight_on_input: bool,
-                 weight_and_reduce_impl: mk.TopKWeightAndReduce,
-                 extra_finalize_args: Optional[dict[str, Any]]) -> None:
+    def finalize(
+        self,
+        output: torch.Tensor,
+        fused_expert_output: torch.Tensor,
+        topk_weights: torch.Tensor,
+        topk_ids: torch.Tensor,
+        apply_router_weight_on_input: bool,
+        weight_and_reduce_impl: mk.TopKWeightAndReduce,
+    ) -> None:
         assert self.handle is not None

View File

@@ -1,6 +1,6 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
-from typing import Any, Optional, Union
+from typing import Optional, Union
 import deep_ep
 import torch
@@ -77,7 +77,7 @@ class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
        a1_scale: Optional[torch.Tensor],
        a2_scale: Optional[torch.Tensor],
        a1_dtype: torch.dtype,
-        quant_dtype: Optional[torch.dtype],
+        quant_dtype: Union[torch.dtype, str, None],
        per_act_token_quant: bool,
        block_shape: Optional[list[int]],
    ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
@@ -111,12 +111,16 @@ class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
         return x, x_scales
     def prepare(
-        self, a1: torch.Tensor, a1_scale: Optional[torch.Tensor],
-        a2_scale: Optional[torch.Tensor], topk_weights: torch.Tensor,
-        topk_ids: torch.Tensor, num_experts: int,
-        expert_map: Optional[torch.Tensor], apply_router_weight_on_input: bool,
+        self,
+        a1: torch.Tensor,
+        a1_scale: Optional[torch.Tensor],
+        a2_scale: Optional[torch.Tensor],
+        topk_weights: torch.Tensor,
+        topk_ids: torch.Tensor,
+        num_experts: int,
+        expert_map: Optional[torch.Tensor],
+        apply_router_weight_on_input: bool,
        quant_config: FusedMoEQuantConfig,
-        extra_prepare_args: Optional[dict[str, Any]]
    ) -> tuple[torch.Tensor, Optional[torch.Tensor],
               Optional[mk.ExpertTokensMetadata], Optional[torch.Tensor],
               Optional[torch.Tensor]]:
@@ -162,11 +166,15 @@ class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
         return (expert_x, expert_x_scale, expert_tokens_meta, None, None)
-    def finalize(self, output: torch.Tensor, fused_expert_output: torch.Tensor,
-                 topk_weights: torch.Tensor, topk_ids: torch.Tensor,
-                 apply_router_weight_on_input: bool,
-                 weight_and_reduce_impl: mk.TopKWeightAndReduce,
-                 extra_finalize_args: Optional[dict[str, Any]]) -> None:
+    def finalize(
+        self,
+        output: torch.Tensor,
+        fused_expert_output: torch.Tensor,
+        topk_weights: torch.Tensor,
+        topk_ids: torch.Tensor,
+        apply_router_weight_on_input: bool,
+        weight_and_reduce_impl: mk.TopKWeightAndReduce,
+    ) -> None:
         assert isinstance(
             weight_and_reduce_impl, TopKWeightAndReduceDelegate
         ), ("Weight application and reduction happens in the combine kernel.")

View File

@@ -1,6 +1,6 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
-from typing import Any, Optional
+from typing import Optional, Union
 import torch
@@ -8,8 +8,7 @@ import vllm.model_executor.layers.fused_moe.modular_kernel as mk
 from vllm.logger import init_logger
 from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig
 from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import (
-    TopKWeightAndReduceDelegate)
-from vllm.model_executor.layers.fused_moe.utils import extract_required_args
+    TopKWeightAndReduceNoOP)
 from vllm.utils.flashinfer import (flashinfer_cutlass_fused_moe,
                                    has_flashinfer_cutlass_fused_moe)
@@ -20,7 +19,7 @@ def is_valid_flashinfer_cutlass_fused_moe(hidden_states: torch.Tensor,
                                           w1: torch.Tensor,
                                           w2: torch.Tensor) -> bool:
     """
     Check if the given problem size is supported by the FlashInfer CUTLASS MoE
    kernel.
     """
     if not has_flashinfer_cutlass_fused_moe():
@@ -43,31 +42,34 @@ class FlashInferExperts(mk.FusedMoEPermuteExpertsUnpermute):
     def __init__(
         self,
-        use_nvfp4_w4a4: bool = False,
-        use_fp8_w8a8: bool = False,
-        use_dp: bool = False,
+        g1_alphas: torch.Tensor,
+        g2_alphas: torch.Tensor,
+        a1_gscale: torch.Tensor,
+        a2_gscale: torch.Tensor,
+        out_dtype: torch.dtype,
+        quant_dtype: Union[torch.dtype, str, None],
         ep_rank: int = 0,
         ep_size: int = 1,
         tp_rank: int = 0,
         tp_size: int = 1,
-        num_dispatchers: Optional[int] = None,
-        use_batched_format: bool = False,
     ):
         super().__init__(
             FusedMoEQuantConfig(
-                quant_dtype=torch.uint8,
+                quant_dtype=quant_dtype,
                 per_act_token_quant=False,
                 block_shape=None,
             ))
-        self.use_nvfp4_w4a4 = use_nvfp4_w4a4
-        self.use_fp8_w8a8 = use_fp8_w8a8
+        assert quant_dtype == "nvfp4", ("Only nvfp4 quantization is "
+                                        "currently supported.")
         self.ep_rank = ep_rank
         self.ep_size = ep_size
         self.tp_rank = tp_rank
         self.tp_size = tp_size
-        self.use_dp = use_dp
-        assert not use_batched_format or num_dispatchers is not None
-        self.num_dispatchers = num_dispatchers
+        self.g1_alphas = g1_alphas
+        self.g2_alphas = g2_alphas
+        self.a1_gscale = a1_gscale
+        self.a2_gscale = a2_gscale
+        self.out_dtype = out_dtype
     @property
     def activation_formats(
@@ -84,8 +86,7 @@ class FlashInferExperts(mk.FusedMoEPermuteExpertsUnpermute):
         return True
     def finalize_weight_and_reduce_impl(self) -> mk.TopKWeightAndReduce:
-        # Let PrepareAndFinalize::finalize() decide the impl.
-        return TopKWeightAndReduceDelegate()
+        return TopKWeightAndReduceNoOP()
     def workspace_shapes(
         self,
@@ -117,8 +118,6 @@ class FlashInferExperts(mk.FusedMoEPermuteExpertsUnpermute):
         - Note: in order for activation chunking to work, the first dimension
           of each tuple must be the number of tokens.
         """
-        assert self.use_nvfp4_w4a4 is True, ("Only nvfp4 quantization is "
-                                             "currently supported.")
         aq_m, aq_n = aq.shape
         workspace2 = ()
         output_shape = (aq_m, aq_n * 2)
@@ -149,21 +148,9 @@ class FlashInferExperts(mk.FusedMoEPermuteExpertsUnpermute):
         workspace2: Optional[torch.Tensor],
         expert_tokens_meta: Optional[mk.ExpertTokensMetadata],
         apply_router_weight_on_input: Optional[bool],
-        extra_expert_args: Optional[dict[str, Any]],
     ):
-        assert extra_expert_args is not None, \
-            "extra_expert_args must be provided"
-        required_keys = [
-            'g1_alphas', 'g2_alphas', 'a1_gscale', 'a2_gscale', 'out_dtype'
-        ]
-        g1_alphas, g2_alphas, a1_gscale, a2_gscale, out_dtype = (
-            extract_required_args(extra_expert_args, required_keys))
         # Flashinfer CUTLASS kernel takes scalar global scales,
         # min because inv_scale.
-        assert self.use_nvfp4_w4a4 is True, ("Only nvfp4 quantization is "
-                                             "currently supported.")
         # Ensure w1_scale and w2_scale are not None before calling view
         assert w1_scale is not None and w2_scale is not None, (
@@ -171,12 +158,12 @@ class FlashInferExperts(mk.FusedMoEPermuteExpertsUnpermute):
             "be None for FlashInferExperts")
         quant_scales = [
-            a1_gscale,
+            self.a1_gscale,
             w1_scale.view(torch.int32),
-            g1_alphas,
-            a2_gscale,
+            self.g1_alphas,
+            self.a2_gscale,
             w2_scale.view(torch.int32),
-            g2_alphas,
+            self.g2_alphas,
         ]
         _ = flashinfer_cutlass_fused_moe(
             input=hidden_states,
@@ -185,7 +172,7 @@ class FlashInferExperts(mk.FusedMoEPermuteExpertsUnpermute):
             # FlashInfer API requires weight to be long for nvfp4
             fc1_expert_weights=w1.view(torch.long),
             fc2_expert_weights=w2.view(torch.long),
-            output_dtype=out_dtype,
+            output_dtype=self.out_dtype,
             quant_scales=quant_scales,
             input_sf=a1q_scale,
             tp_size=self.tp_size,
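Editor's note: with the scales held on the experts object, `apply` can assemble the `quant_scales` list without unpacking a dict. A rough, self-contained illustration of that assembly using plain tensors (the actual call into `flashinfer_cutlass_fused_moe` is omitted; shapes are made up for the example):

```python
import torch

# Hypothetical per-layer scale tensors, standing in for the values stored on
# the object in the hunk above.
a1_gscale = torch.tensor(1.0)
a2_gscale = torch.tensor(1.0)
g1_alphas = torch.ones(8)
g2_alphas = torch.ones(8)
w1_scale = torch.zeros(8, 16, dtype=torch.uint8)
w2_scale = torch.zeros(8, 16, dtype=torch.uint8)

# Mirrors the ordering used above: activation global scale, weight scale
# (reinterpreted as int32 blocks), then the per-expert alphas for each GEMM.
quant_scales = [
    a1_gscale,
    w1_scale.view(torch.int32),
    g1_alphas,
    a2_gscale,
    w2_scale.view(torch.int32),
    g2_alphas,
]
```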

View File

@@ -1,6 +1,6 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
-from typing import Any, Optional
+from typing import Optional
 import torch
@@ -9,7 +9,7 @@ from vllm.distributed import get_dp_group
 from vllm.forward_context import get_forward_context
 from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig
 from vllm.model_executor.layers.fused_moe.utils import (
-    extract_required_args, moe_kernel_quantize_input)
+    moe_kernel_quantize_input)
 from vllm.utils.flashinfer import nvfp4_block_scale_interleave
@@ -21,16 +21,15 @@ class FlashInferCutlassMoEPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
     def __init__(
         self,
-        quant_dtype: Optional[torch.dtype] = None,
-        per_channel_quant: bool = False,
-        block_shape: Optional[list[int]] = None,
+        use_dp: bool,
+        a1_gscale: Optional[torch.Tensor],
         num_dispatchers: int = 1,
     ):
         super().__init__()
-        self.per_channel_quant = per_channel_quant
-        self.block_shape = block_shape
-        self.quant_dtype = quant_dtype
         self.num_dispatchers_ = num_dispatchers
+        self.use_dp = use_dp
+        self.a1_gscale = a1_gscale
+        self.local_tokens = None
     @property
     def activation_format(self) -> mk.FusedMoEActivationFormat:
@@ -55,10 +54,11 @@ class FlashInferCutlassMoEPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
         num_experts: int,
         expert_map: Optional[torch.Tensor],
         apply_router_weight_on_input: bool,
+        # TODO(bnell): use quant_config + scales instead of ctor args
         quant_config: FusedMoEQuantConfig,
-        extra_prepare_args: Optional[dict[str, Any]]
-    ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor],
-               Optional[torch.Tensor], Optional[torch.Tensor]]:
+    ) -> tuple[torch.Tensor, Optional[torch.Tensor],
+               Optional[mk.ExpertTokensMetadata], Optional[torch.Tensor],
+               Optional[torch.Tensor]]:
         if apply_router_weight_on_input:
             topk = topk_ids.size(1)
@@ -67,22 +67,22 @@ class FlashInferCutlassMoEPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
                 "apply_router_weight_on_input is only implemented for topk=1"
             a1.mul_(topk_weights.to(a1.dtype))
-        (a1_gscale, use_dp, local_tokens) = extract_required_args(
-            extra_prepare_args, ['a1_gscale', 'use_dp', 'local_tokens'])
         a1q, a1q_scale = moe_kernel_quantize_input(
             a1,
-            a1_gscale,
+            self.a1_gscale,
             quant_config.quant_dtype,
-            self.per_channel_quant,
-            self.block_shape,
-            is_fp4_scale_swizzled=not use_dp,  # Swizzling after communication
+            quant_config.per_act_token_quant,
+            quant_config.block_shape,
+            # Swizzling after communication
+            is_fp4_scale_swizzled=not self.use_dp,
         )
-        if use_dp:
+        if self.use_dp:
             topk_weights, topk_ids, a1q, a1q_scale = \
-                get_dp_group().all_gatherv([topk_weights, topk_ids, a1q, a1q_scale],  # noqa: E501
-                                           dim=0,
-                                           sizes=get_local_sizes())
+                get_dp_group().all_gatherv(
+                    [topk_weights, topk_ids, a1q, a1q_scale],
+                    dim=0,
+                    sizes=get_local_sizes(),
+                )
         a1_m, a1_n = a1q.shape
         a1q_scale = nvfp4_block_scale_interleave(a1q_scale)
@@ -91,13 +91,9 @@ class FlashInferCutlassMoEPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
     def finalize(self, output: torch.Tensor, fused_expert_output: torch.Tensor,
                  topk_weights: torch.Tensor, topk_ids: torch.Tensor,
                  apply_router_weight_on_input: bool,
-                 weight_and_reduce_impl: mk.TopKWeightAndReduce,
-                 extra_finalize_args: Optional[dict[str, Any]]) -> None:
-        (use_dp,
-         local_tokens) = extract_required_args(extra_finalize_args,
-                                               ['use_dp', 'local_tokens'])
-        if use_dp:
+                 weight_and_reduce_impl: mk.TopKWeightAndReduce) -> None:
+        if self.use_dp:
             fused_expert_output = get_dp_group().reduce_scatterv(
                 fused_expert_output, dim=0, sizes=get_local_sizes())
         output.copy_(fused_expert_output)

View File

@@ -1,7 +1,7 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 """Fused batched MoE kernel."""
-from typing import Any, Optional
+from typing import Optional
 import torch
@@ -496,12 +496,16 @@ class BatchedPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
         return self.num_dispatchers_
     def prepare(
-        self, a1: torch.Tensor, a1_scale: Optional[torch.Tensor],
-        a2_scale: Optional[torch.Tensor], topk_weights: torch.Tensor,
-        topk_ids: torch.Tensor, num_experts: int,
-        expert_map: Optional[torch.Tensor], apply_router_weight_on_input: bool,
+        self,
+        a1: torch.Tensor,
+        a1_scale: Optional[torch.Tensor],
+        a2_scale: Optional[torch.Tensor],
+        topk_weights: torch.Tensor,
+        topk_ids: torch.Tensor,
+        num_experts: int,
+        expert_map: Optional[torch.Tensor],
+        apply_router_weight_on_input: bool,
        quant_config: FusedMoEQuantConfig,
-        extra_prepare_args: Optional[dict[str, Any]]
    ) -> tuple[torch.Tensor, Optional[torch.Tensor],
               Optional[mk.ExpertTokensMetadata], Optional[torch.Tensor],
               Optional[torch.Tensor]]:
@@ -590,11 +594,15 @@ class BatchedPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
         return b_a1, b_a1_scale, expert_tokens_meta, None, None
-    def finalize(self, output: torch.Tensor, fused_expert_output: torch.Tensor,
-                 topk_weights: torch.Tensor, topk_ids: torch.Tensor,
-                 apply_router_weight_on_input: bool,
-                 weight_and_reduce_impl: mk.TopKWeightAndReduce,
-                 extra_finalize_args: Optional[dict[str, Any]]) -> None:
+    def finalize(
+        self,
+        output: torch.Tensor,
+        fused_expert_output: torch.Tensor,
+        topk_weights: torch.Tensor,
+        topk_ids: torch.Tensor,
+        apply_router_weight_on_input: bool,
+        weight_and_reduce_impl: mk.TopKWeightAndReduce,
+    ) -> None:
         if isinstance(weight_and_reduce_impl, TopKWeightAndReduceDelegate):
             weight_and_reduce_impl = TopKWeightAndReduceNaiveBatched(self.rank)
         weight_and_reduce_impl.apply(
@@ -688,18 +696,28 @@ class NaiveBatchedExperts(mk.FusedMoEPermuteExpertsUnpermute):
         else:
             return t.to(f32) * group_broadcast(scale, t.shape)
-    def apply(self, output: torch.Tensor, hidden_states: torch.Tensor,
-              w1: torch.Tensor, w2: torch.Tensor, topk_weights: torch.Tensor,
-              topk_ids: torch.Tensor, activation: str, global_num_experts: int,
-              expert_map: Optional[torch.Tensor],
-              w1_scale: Optional[torch.Tensor],
-              w2_scale: Optional[torch.Tensor], w1_zp: Optional[torch.Tensor],
-              w2_zp: Optional[torch.Tensor], a1q_scale: Optional[torch.Tensor],
-              a2_scale: Optional[torch.Tensor], workspace13: torch.Tensor,
-              workspace2: torch.Tensor,
-              expert_tokens_meta: Optional[mk.ExpertTokensMetadata],
-              apply_router_weight_on_input: bool,
-              extra_expert_args: Optional[dict[str, Any]]):
+    def apply(
+        self,
+        output: torch.Tensor,
+        hidden_states: torch.Tensor,
+        w1: torch.Tensor,
+        w2: torch.Tensor,
+        topk_weights: torch.Tensor,
+        topk_ids: torch.Tensor,
+        activation: str,
+        global_num_experts: int,
+        expert_map: Optional[torch.Tensor],
+        w1_scale: Optional[torch.Tensor],
+        w2_scale: Optional[torch.Tensor],
+        w1_zp: Optional[torch.Tensor],
+        w2_zp: Optional[torch.Tensor],
+        a1q_scale: Optional[torch.Tensor],
+        a2_scale: Optional[torch.Tensor],
+        workspace13: torch.Tensor,
+        workspace2: torch.Tensor,
+        expert_tokens_meta: Optional[mk.ExpertTokensMetadata],
+        apply_router_weight_on_input: bool,
+    ):
         assert hidden_states.dim() == 3
         assert expert_tokens_meta is not None
         expert_num_tokens = expert_tokens_meta.expert_num_tokens
@@ -894,18 +912,28 @@ class BatchedTritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
         output = (num_experts, max_num_tokens * num_dp, K)
         return (workspace13, workspace2, output, a.dtype)
-    def apply(self, output: torch.Tensor, hidden_states: torch.Tensor,
-              w1: torch.Tensor, w2: torch.Tensor, topk_weights: torch.Tensor,
-              topk_ids: torch.Tensor, activation: str, global_num_experts: int,
-              expert_map: Optional[torch.Tensor],
-              w1_scale: Optional[torch.Tensor],
-              w2_scale: Optional[torch.Tensor], w1_zp: Optional[torch.Tensor],
-              w2_zp: Optional[torch.Tensor], a1q_scale: Optional[torch.Tensor],
-              a2_scale: Optional[torch.Tensor], workspace13: torch.Tensor,
-              workspace2: torch.Tensor,
-              expert_tokens_meta: Optional[mk.ExpertTokensMetadata],
-              apply_router_weight_on_input: bool,
-              extra_expert_args: Optional[dict[str, Any]]):
+    def apply(
+        self,
+        output: torch.Tensor,
+        hidden_states: torch.Tensor,
+        w1: torch.Tensor,
+        w2: torch.Tensor,
+        topk_weights: torch.Tensor,
+        topk_ids: torch.Tensor,
+        activation: str,
+        global_num_experts: int,
+        expert_map: Optional[torch.Tensor],
+        w1_scale: Optional[torch.Tensor],
+        w2_scale: Optional[torch.Tensor],
+        w1_zp: Optional[torch.Tensor],
+        w2_zp: Optional[torch.Tensor],
+        a1q_scale: Optional[torch.Tensor],
+        a2_scale: Optional[torch.Tensor],
+        workspace13: torch.Tensor,
+        workspace2: torch.Tensor,
+        expert_tokens_meta: Optional[mk.ExpertTokensMetadata],
+        apply_router_weight_on_input: bool,
+    ):
         # Check constraints.
         if self.use_int4_w4a16:
             assert hidden_states.size(-1) // 2 == w1.size(2), (

View File

@@ -1394,9 +1394,9 @@ def fused_experts(hidden_states: torch.Tensor,
     # E8M0 scale, which means we requantize the weight and input to the specific
     # scale. Fallen back to cutlass or triton for some cases would cause
     # accuracy issue.
-    should_use_deep_gemm = is_blackwell_deep_gemm_e8m0_used(
-    ) or _valid_deep_gemm(hidden_states, w1, w2)
-    if (allow_deep_gemm and use_fp8_w8a8 and should_use_deep_gemm):
+    if (allow_deep_gemm and use_fp8_w8a8
+            and (is_blackwell_deep_gemm_e8m0_used()
+                 or _valid_deep_gemm(hidden_states, w1, w2))):
         assert apply_router_weight_on_input is False
         assert is_act_and_mul, (
             "DeepGemm only supports is_act_and_mul=True for now.")
@@ -1905,7 +1905,6 @@ class TritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
         workspace2: torch.Tensor,
         expert_tokens_meta: Optional[mk.ExpertTokensMetadata],
         apply_router_weight_on_input: bool,
-        extra_expert_args: Optional[dict[str, Any]],
     ):
         # Check constraints.
         if self.use_int4_w4a16:

View File

@@ -1,6 +1,6 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
-from typing import TYPE_CHECKING, Any, Optional
+from typing import TYPE_CHECKING, Optional
 import torch
@@ -8,7 +8,6 @@ import vllm.model_executor.layers.fused_moe.modular_kernel as mk
 from vllm.logger import init_logger
 from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import (
     TopKWeightAndReduceDelegate)
-from vllm.model_executor.layers.fused_moe.utils import extract_required_args
 from vllm.utils import has_triton_kernels
 logger = init_logger(__name__)
@@ -160,12 +159,16 @@ class BatchedOAITritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
         num_dispatchers: int,
         w1_precision: "PrecisionConfig",
         w2_precision: "PrecisionConfig",
+        w1_bias: Optional[torch.Tensor],
+        w2_bias: Optional[torch.Tensor],
     ):
         super().__init__(quant_config)
         self.max_num_tokens = max_num_tokens
         self.num_dispatchers = num_dispatchers
         self.w1_precision = w1_precision
         self.w2_precision = w2_precision
+        self.w1_bias = w1_bias
+        self.w2_bias = w2_bias
     @property
     def activation_formats(
@@ -219,11 +222,7 @@ class BatchedOAITritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
         workspace2: torch.Tensor,
         expert_tokens_meta: Optional[mk.ExpertTokensMetadata],
         apply_router_weight_on_input: bool,
-        extra_expert_args: Optional[dict[str, Any]],
     ):
-        w1_bias, w2_bias = (extract_required_args(extra_expert_args,
-                                                  ["w1_bias", "w2_bias"]))
         return triton_kernel_fused_experts(
             output,
             hidden_states,
@@ -240,8 +239,8 @@ class BatchedOAITritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
             expert_map=expert_map,
             w1_scale=w1_scale,
             w2_scale=w2_scale,
-            w1_bias=w1_bias,
-            w2_bias=w2_bias,
+            w1_bias=self.w1_bias,
+            w2_bias=self.w2_bias,
             w1_precision=self.w1_precision,
             w2_precision=self.w2_precision,
             a1_scale=a1q_scale,

View File

@@ -37,7 +37,6 @@ from vllm.platforms import current_platform
 from vllm.platforms.interface import CpuArchEnum
 from vllm.utils import (direct_register_custom_op, has_deep_ep, has_pplx,
                         round_up)
-from vllm.utils.flashinfer import has_flashinfer
 if current_platform.is_cuda_alike():
     from .fused_batched_moe import BatchedTritonExperts
@@ -49,9 +48,6 @@ if current_platform.is_cuda_alike():
     from .deepep_ht_prepare_finalize import DeepEPHTPrepareAndFinalize
     from .deepep_ll_prepare_finalize import (DEEPEP_QUANT_BLOCK_SHAPE,
                                              DeepEPLLPrepareAndFinalize)
-    if has_flashinfer():
-        from .flashinfer_cutlass_prepare_finalize import (
-            FlashInferCutlassMoEPrepareAndFinalize)
 else:
     fused_experts = None  # type: ignore
     FusedMoEPermuteExpertsUnpermute = None  # type: ignore
@@ -80,7 +76,12 @@ class FusedMoeWeightScaleSupported(Enum):
 class FusedMoEMethodBase(QuantizeMethodBase):
-    moe: FusedMoEConfig
+    # TODO(bnell): also pass quant_config?
+    def __init__(self, moe: FusedMoEConfig):
+        super().__init__()
+        self.moe = moe
+        self.fused_experts: Optional[Callable] = None
+        self.topk_indices_dtype = None
     @abstractmethod
     def create_weights(self, layer: torch.nn.Module, num_experts: int,
@@ -99,16 +100,16 @@ class FusedMoEMethodBase(QuantizeMethodBase):
         return False
     @staticmethod
-    def maybe_make_prepare_finalize(
-            moe: FusedMoEConfig) -> Optional[FusedMoEPrepareAndFinalize]:
+    def _maybe_make_prepare_finalize(
+            moe: FusedMoEConfig, ) -> Optional[FusedMoEPrepareAndFinalize]:
         all2all_manager = get_ep_group().device_communicator.all2all_manager
         assert all2all_manager is not None
         prepare_finalize: Optional[FusedMoEPrepareAndFinalize] = None
-        if moe.use_flashinfer_cutlass_kernels:
-            prepare_finalize = FlashInferCutlassMoEPrepareAndFinalize(
-                quant_dtype=moe.quant_dtype, )
+        assert not moe.use_flashinfer_cutlass_kernels, \
+            "Must be created in modelopt.py"
         if moe.use_pplx_kernels:
             hidden_dim_bytes, hidden_scale_bytes = pplx_hidden_dim_scale_bytes(
                 moe.max_num_tokens,
@@ -188,14 +189,25 @@ class FusedMoEMethodBase(QuantizeMethodBase):
         return prepare_finalize
-    def init_prepare_finalize(self, moe: FusedMoEConfig):
-        self.moe = moe
-        prepare_finalize = FusedMoEMethodBase.maybe_make_prepare_finalize(
-            self.moe)
-        self.topk_indices_dtype = None
+    def maybe_make_prepare_finalize(
+        self,
+        moe: FusedMoEConfig,
+    ) -> Optional[FusedMoEPrepareAndFinalize]:
+        if moe.moe_parallel_config.use_all2all_kernels:
+            return FusedMoEMethodBase._maybe_make_prepare_finalize(moe)
+        else:
+            return None
+    def init_prepare_finalize(self):
+        assert self.moe is not None
+        prepare_finalize = self.maybe_make_prepare_finalize(self.moe)
         if prepare_finalize is not None:
-            logger.debug("%s", prepare_finalize.__class__.__name__)
+            logger.debug("%s for %s(%s)", prepare_finalize.__class__.__name__,
+                         self, id(self))
+            assert self.topk_indices_dtype is None
+            assert self.fused_experts is None, \
+                f"Attempt to override experts for {id(self)}!"
             self.topk_indices_dtype = prepare_finalize.topk_indices_dtype()
             experts = self.select_gemm_impl(prepare_finalize, self.moe)
             self.fused_experts = FusedMoEModularKernel(
@@ -214,12 +226,6 @@ class FusedMoEMethodBase(QuantizeMethodBase):
             f"{self.__class__.__name__} must select appropriate gemm "
             "implementation based on the prepare_finalize")
-    def maybe_swap_experts_impl(
-        self,
-        moe_parallel_config: FusedMoEParallelConfig,
-    ):
-        pass
     @abstractmethod
     def apply(
         self,
@@ -251,10 +257,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
     """MoE method without quantization."""
     def __init__(self, moe: FusedMoEConfig):
-        super().__init__()
-        self.fused_experts = fused_experts  # type: ignore
-        self.topk_indices_dtype = None
-        self.moe = moe
+        super().__init__(moe)
         self.has_bias = self.moe.has_bias
         self.rocm_aiter_moe_enabled = is_rocm_aiter_moe_enabled()
         if self.rocm_aiter_moe_enabled:
@@ -266,6 +269,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
     def select_gemm_impl(
         self,
         prepare_finalize: FusedMoEPrepareAndFinalize,
+        # TODO(bnell): Remove. Every layer should have an moe config object.
         moe: FusedMoEConfig,
     ) -> FusedMoEPermuteExpertsUnpermute:
         if (prepare_finalize.activation_format ==
@@ -474,9 +478,11 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
                 expert_map=expert_map,
                 activation=activation,
                 apply_router_weight_on_input=apply_router_weight_on_input)
-        else:
-            # add w1_bias/w2_bias to kwargs if they exist
-            kwargs = dict(
+        elif self.fused_experts is not None:
+            if self.has_bias:
+                raise ValueError(
+                    "FusedMoEModularKernel does not support bias.")
+            return self.fused_experts(
                 hidden_states=x,
                 w1=layer.w13_weight,
                 w2=layer.w2_weight,
@@ -488,17 +494,22 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
                 global_num_experts=global_num_experts,
                 expert_map=expert_map,
             )
-            if isinstance(self.fused_experts,
-                          FusedMoEModularKernel) and self.has_bias:
-                raise ValueError(
-                    "FusedMoEModularKernel does not support bias.")
-            if self.has_bias:
-                kwargs.update({
-                    "w1_bias": getattr(layer, "w13_bias", None),
-                    "w2_bias": getattr(layer, "w2_bias", None),
-                })
-            return self.fused_experts(**kwargs)
+        else:
+            assert fused_experts is not None
+            return fused_experts(
+                hidden_states=x,
+                w1=layer.w13_weight,
+                w2=layer.w2_weight,
+                w1_bias=layer.w13_bias if self.has_bias else None,
+                w2_bias=layer.w2_bias if self.has_bias else None,
+                topk_weights=topk_weights,
+                topk_ids=topk_ids,
+                inplace=True,
+                activation=activation,
+                apply_router_weight_on_input=apply_router_weight_on_input,
+                global_num_experts=global_num_experts,
+                expert_map=expert_map,
+            )
     def forward_cpu(
         self,
@@ -868,8 +879,6 @@ class FusedMoE(CustomOp):
             moe_quant_params["intermediate_size_full"] = intermediate_size
         self.quant_method.create_weights(layer=self, **moe_quant_params)
-        if isinstance(self.quant_method, FusedMoEMethodBase):
-            self.quant_method.maybe_swap_experts_impl(self.moe_parallel_config)
         # Chunked all2all staging tensor
         self.batched_hidden_states: Optional[torch.Tensor] = None
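Editor's note: the hunks above make `maybe_make_prepare_finalize` an instance method that derived quantization methods can override when they need a backend-specific prepare/finalize object, while the EP+DP default stays in the static `_maybe_make_prepare_finalize`. A skeletal illustration of the override pattern, written against stand-in classes rather than the actual vLLM base class:

```python
from typing import Optional


class PrepareFinalizeSketch:
    """Stand-in for a FusedMoEPrepareAndFinalize implementation."""


class MethodBaseSketch:
    """Stand-in base: build the default all2all object, or return None."""

    def maybe_make_prepare_finalize(self, moe) -> Optional[PrepareFinalizeSketch]:
        return PrepareFinalizeSketch() if getattr(moe, "use_all2all", False) else None


class NvFp4MethodSketch(MethodBaseSketch):
    """Derived method supplying its own prepare/finalize for another scenario."""

    def maybe_make_prepare_finalize(self, moe) -> Optional[PrepareFinalizeSketch]:
        if getattr(moe, "use_flashinfer", False):
            # e.g. return a FlashInfer-specific prepare/finalize object here.
            return PrepareFinalizeSketch()
        return super().maybe_make_prepare_finalize(moe)
```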

View File

@@ -4,7 +4,7 @@ from abc import ABC, abstractmethod
 from dataclasses import dataclass
 from enum import Enum
 from math import prod
-from typing import Any, Optional, final
+from typing import Optional, final
 import torch
@@ -150,15 +150,23 @@ class FusedMoEPrepareAndFinalize(ABC):
     @abstractmethod
     def prepare(
-        self, a1: torch.Tensor, a1_scale: Optional[torch.Tensor],
-        a2_scale: Optional[torch.Tensor], topk_weights: torch.Tensor,
-        topk_ids: torch.Tensor, num_experts: int,
-        expert_map: Optional[torch.Tensor], apply_router_weight_on_input: bool,
+        self,
+        a1: torch.Tensor,
+        a1_scale: Optional[torch.Tensor],
+        a2_scale: Optional[torch.Tensor],
+        topk_weights: torch.Tensor,
+        topk_ids: torch.Tensor,
+        num_experts: int,
+        expert_map: Optional[torch.Tensor],
+        apply_router_weight_on_input: bool,
        quant_config: FusedMoEQuantConfig,
-        extra_prepare_args: Optional[dict[str, Any]]
-    ) -> tuple[torch.Tensor, Optional[torch.Tensor],
-               Optional[ExpertTokensMetadata], Optional[torch.Tensor],
-               Optional[torch.Tensor]]:
+    ) -> tuple[
+            torch.Tensor,
+            Optional[torch.Tensor],
+            Optional[ExpertTokensMetadata],
+            Optional[torch.Tensor],
+            Optional[torch.Tensor],
+    ]:
         """
         Perform any quantization (and/or) dispatching needed
         for this kernel.
@@ -186,11 +194,15 @@ class FusedMoEPrepareAndFinalize(ABC):
         raise NotImplementedError
     @abstractmethod
-    def finalize(self, output: torch.Tensor, fused_expert_output: torch.Tensor,
-                 topk_weights: torch.Tensor, topk_ids: torch.Tensor,
-                 apply_router_weight_on_input: bool,
-                 weight_and_reduce_impl: TopKWeightAndReduce,
-                 extra_finalize_args: Optional[dict[str, Any]]) -> None:
+    def finalize(
+        self,
+        output: torch.Tensor,
+        fused_expert_output: torch.Tensor,
+        topk_weights: torch.Tensor,
+        topk_ids: torch.Tensor,
+        apply_router_weight_on_input: bool,
+        weight_and_reduce_impl: TopKWeightAndReduce,
+    ) -> None:
         """
         Perform any combine plus apply weights and perform a reduction on the
         fused experts output.
@@ -368,7 +380,6 @@ class FusedMoEPermuteExpertsUnpermute(ABC):
         workspace2: torch.Tensor,
         expert_tokens_meta: Optional[ExpertTokensMetadata],
         apply_router_weight_on_input: bool,
-        extra_expert_args: Optional[dict[str, Any]],
     ):
         """
         This function computes the intermediate result of a Mixture of Experts
@@ -454,18 +465,27 @@ class FusedMoEModularKernel(torch.nn.Module):
                 f"{fused_experts.activation_formats[0]}")
     def _do_fused_experts(
-            self, fused_out: Optional[torch.Tensor], a1: torch.Tensor,
-            a1q: torch.Tensor, w1: torch.Tensor, w2: torch.Tensor,
-            topk_weights: torch.Tensor, topk_ids: torch.Tensor,
-            activation: str, global_num_experts: int, local_num_experts: int,
-            expert_map: Optional[torch.Tensor],
-            w1_scale: Optional[torch.Tensor], w2_scale: Optional[torch.Tensor],
-            w1_zp: Optional[torch.Tensor], w2_zp: Optional[torch.Tensor],
-            a1q_scale: Optional[torch.Tensor],
-            a2_scale: Optional[torch.Tensor],
-            expert_tokens_meta: Optional[ExpertTokensMetadata],
-            apply_router_weight_on_input: bool,
-            extra_expert_args: Optional[dict[str, Any]]) -> torch.Tensor:
+        self,
+        fused_out: Optional[torch.Tensor],
+        a1: torch.Tensor,
+        a1q: torch.Tensor,
+        w1: torch.Tensor,
+        w2: torch.Tensor,
+        topk_weights: torch.Tensor,
+        topk_ids: torch.Tensor,
+        activation: str,
+        global_num_experts: int,
+        local_num_experts: int,
+        expert_map: Optional[torch.Tensor],
+        w1_scale: Optional[torch.Tensor],
+        w2_scale: Optional[torch.Tensor],
+        w1_zp: Optional[torch.Tensor],
+        w2_zp: Optional[torch.Tensor],
+        a1q_scale: Optional[torch.Tensor],
+        a2_scale: Optional[torch.Tensor],
+        expert_tokens_meta: Optional[ExpertTokensMetadata],
+        apply_router_weight_on_input: bool,
+    ) -> torch.Tensor:
         _, M, N, K, top_k = _moe_problem_size(a1q, w1, w2, topk_ids)
@@ -509,7 +529,7 @@ class FusedMoEModularKernel(torch.nn.Module):
             workspace2=workspace2,
             expert_tokens_meta=expert_tokens_meta,
             apply_router_weight_on_input=apply_router_weight_on_input,
-            extra_expert_args=extra_expert_args)
+        )
         return fused_out
@@ -533,7 +553,6 @@ class FusedMoEModularKernel(torch.nn.Module):
         a2_scale: Optional[torch.Tensor],
         expert_tokens_meta: Optional[ExpertTokensMetadata],
         apply_router_weight_on_input: bool,
-        extra_expert_args: Optional[dict[str, Any]],
     ) -> torch.Tensor:
         _, M, N, K, top_k = _moe_problem_size(a1q, w1, w2, topk_ids)
@@ -541,6 +560,9 @@ class FusedMoEModularKernel(torch.nn.Module):
         CHUNK_SIZE = envs.VLLM_FUSED_MOE_CHUNK_SIZE
         num_chunks = cdiv(M, CHUNK_SIZE)
+        # TODO(bnell): get rid of one level here, update slice functions
+        # to nops on num_chunks==1
         if not self.fused_experts.supports_chunking() or num_chunks == 1:
             return self._do_fused_experts(
                 fused_out=None,
@@ -562,7 +584,7 @@ class FusedMoEModularKernel(torch.nn.Module):
                 a2_scale=a2_scale,
                 expert_tokens_meta=expert_tokens_meta,
                 apply_router_weight_on_input=apply_router_weight_on_input,
-                extra_expert_args=extra_expert_args)
+            )
         # Chunking required case
         assert num_chunks > 1
@@ -618,15 +640,6 @@ class FusedMoEModularKernel(torch.nn.Module):
                 expert_num_tokens=c_expert_num_tokens,
                 expert_num_tokens_cpu=c_expert_num_tokens_cpu)
-        m = None
-        if extra_expert_args is not None and 'm' in extra_expert_args:
-            m = extra_expert_args.get('m')
-        if extra_expert_args is not None:
-            chunked_extra_expert_args = extra_expert_args
-        else:
-            chunked_extra_expert_args = {}
         for chunk_idx in range(num_chunks):
             c_a1q, c_a1q_scale, c_a2_scale, c_topk_ids, c_topk_weights = (
                 slice_input_tensors(chunk_idx))
@@ -637,11 +650,6 @@ class FusedMoEModularKernel(torch.nn.Module):
                     expert_tokens_meta, c_topk_ids, local_num_experts,
                     expert_map)
-            s = chunk_idx * CHUNK_SIZE
-            e = min(s + CHUNK_SIZE, M)
-            if m is not None:
-                chunked_extra_expert_args['m'] = e - s
             self._do_fused_experts(
                 fused_out=slice_output_tensor(chunk_idx),
                 a1=a1,
@@ -662,7 +670,7 @@ class FusedMoEModularKernel(torch.nn.Module):
                 a2_scale=c_a2_scale,
                 expert_tokens_meta=c_expert_tokens_meta,
                 apply_router_weight_on_input=apply_router_weight_on_input,
-                extra_expert_args=chunked_extra_expert_args)
+            )
         return fused_out
@@ -684,9 +692,6 @@ class FusedMoEModularKernel(torch.nn.Module):
         a1_scale: Optional[torch.Tensor] = None,
         a2_scale: Optional[torch.Tensor] = None,
         apply_router_weight_on_input: bool = False,
-        extra_expert_args: Optional[dict] = None,
-        extra_prepare_args: Optional[dict] = None,
-        extra_finalize_args: Optional[dict] = None,
     ) -> torch.Tensor:
         """
         This function computes a Mixture of Experts (MoE) layer using two sets
@@ -719,12 +724,6 @@ class FusedMoEModularKernel(torch.nn.Module):
         - apply_router_weight_on_input (bool): When true, the topk weights are
           applied directly on the inputs. This is only applicable when topk is
          1.
-        - extra_expert_args (Optional[dict]): Extra keyword arguments to pass to
-          fused_experts.apply.
-        - extra_prepare_args (Optional[dict]): Extra keyword arguments to pass
-          to prepare.
-        - extra_finalize_args (Optional[dict]): Extra keyword arguments to pass
-          to finalize.
         Returns:
         - torch.Tensor: The output tensor after applying the MoE layer.
@@ -748,7 +747,6 @@ class FusedMoEModularKernel(torch.nn.Module):
             expert_map,
             apply_router_weight_on_input,
             self.fused_experts.quant_config,
-            extra_prepare_args,
         )
         # Maybe prepare gathered topk_ids and topk_weights from other EP ranks.
@@ -786,12 +784,15 @@ class FusedMoEModularKernel(torch.nn.Module):
             a2_scale=a2_scale,
             expert_tokens_meta=expert_tokens_meta,
             apply_router_weight_on_input=apply_router_weight_on_input,
-            extra_expert_args=extra_expert_args)
+        )
         self.prepare_finalize.finalize(
-            output, fused_out, topk_weights, topk_ids,
+            output,
+            fused_out,
+            topk_weights,
+            topk_ids,
             apply_router_weight_on_input,
             self.fused_experts.finalize_weight_and_reduce_impl(),
-            extra_finalize_args)
+        )
         return output
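Editor's note: after this change the `prepare`/`finalize` contract carries no catch-all dict arguments, so a new `FusedMoEPrepareAndFinalize` implementation only handles the tensors it is given. Below is a minimal no-op sketch of that trimmed surface using stand-in types; the real abstract base lives in `modular_kernel.py`, and the signatures here simply mirror the hunk above:

```python
from typing import Optional

import torch


class NoOpPrepareFinalizeSketch:
    """Sketch of the trimmed prepare/finalize surface (no extra_* dicts)."""

    def prepare(
        self,
        a1: torch.Tensor,
        a1_scale: Optional[torch.Tensor],
        a2_scale: Optional[torch.Tensor],
        topk_weights: torch.Tensor,
        topk_ids: torch.Tensor,
        num_experts: int,
        expert_map: Optional[torch.Tensor],
        apply_router_weight_on_input: bool,
        quant_config: object,
    ):
        # No quantization and no dispatch: hand the activations straight through.
        return a1, a1_scale, None, None, None

    def finalize(
        self,
        output: torch.Tensor,
        fused_expert_output: torch.Tensor,
        topk_weights: torch.Tensor,
        topk_ids: torch.Tensor,
        apply_router_weight_on_input: bool,
        weight_and_reduce_impl: object,
    ) -> None:
        # Weight application/reduction is assumed to have happened upstream.
        output.copy_(fused_expert_output)
```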

View File

@@ -1,6 +1,6 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
-from typing import Any, Optional
+from typing import Optional, Union
 import pplx_kernels as pplx
 import torch
@@ -21,7 +21,7 @@ def pplx_hidden_dim_scale_bytes(
     max_num_tokens: int,
     hidden_dim: int,
     in_dtype: torch.dtype,
-    quant_dtype: Optional[torch.dtype],
+    quant_dtype: Union[torch.dtype, str, None],
     per_act_token_quant: bool,
     block_shape: Optional[list[int]],
 ):
@@ -32,6 +32,7 @@ def pplx_hidden_dim_scale_bytes(
     #   ceil_div(hidden_dim, block_size) * sizeof(float32)
     # For per-token: set to 4 * sizeof(float32) (x4 for alignment)
     if quant_dtype is not None:
+        assert isinstance(quant_dtype, torch.dtype)
         assert quant_dtype.itemsize == 1
         hidden_dim_bytes = hidden_dim * quant_dtype.itemsize
         elem_size = torch.float32.itemsize
@@ -89,12 +90,16 @@ class PplxPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
         return self.num_dispatchers_
     def prepare(
-        self, a1: torch.Tensor, a1_scale: Optional[torch.Tensor],
-        a2_scale: Optional[torch.Tensor], topk_weights: torch.Tensor,
-        topk_ids: torch.Tensor, num_experts: int,
-        expert_map: Optional[torch.Tensor], apply_router_weight_on_input: bool,
+        self,
+        a1: torch.Tensor,
+        a1_scale: Optional[torch.Tensor],
+        a2_scale: Optional[torch.Tensor],
+        topk_weights: torch.Tensor,
+        topk_ids: torch.Tensor,
+        num_experts: int,
+        expert_map: Optional[torch.Tensor],
+        apply_router_weight_on_input: bool,
        quant_config: FusedMoEQuantConfig,
-        extra_prepare_args: Optional[dict[str, Any]]
    ) -> tuple[torch.Tensor, Optional[torch.Tensor],
               Optional[mk.ExpertTokensMetadata], Optional[torch.Tensor],
               Optional[torch.Tensor]]:
@@ -213,11 +218,15 @@ class PplxPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
         return expert_x, expert_x_scale, expert_tokens_meta, None, None
-    def finalize(self, output: torch.Tensor, fused_expert_output: torch.Tensor,
-                 topk_weights: torch.Tensor, topk_ids: torch.Tensor,
-                 apply_router_weight_on_input: bool,
-                 weight_and_reduce_impl: mk.TopKWeightAndReduce,
-                 extra_finalize_args: Optional[dict[str, Any]]) -> None:
+    def finalize(
+        self,
+        output: torch.Tensor,
+        fused_expert_output: torch.Tensor,
+        topk_weights: torch.Tensor,
+        topk_ids: torch.Tensor,
+        apply_router_weight_on_input: bool,
+        weight_and_reduce_impl: mk.TopKWeightAndReduce,
+    ) -> None:
         assert isinstance(
             weight_and_reduce_impl, TopKWeightAndReduceDelegate
         ), ("Weight application and reduction happens in the combine kernel.")

View File

@@ -1,6 +1,6 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
-from typing import Any, Optional
+from typing import Optional
 import torch
@@ -38,7 +38,6 @@ class MoEPrepareAndFinalizeNoEP(mk.FusedMoEPrepareAndFinalize):
         expert_map: Optional[torch.Tensor],
         apply_router_weight_on_input: bool,
         quant_config: FusedMoEQuantConfig,
-        extra_prepare_args: Optional[dict[str, Any]],
    ) -> tuple[torch.Tensor, Optional[torch.Tensor],
               Optional[mk.ExpertTokensMetadata], Optional[torch.Tensor],
               Optional[torch.Tensor]]:
@@ -50,32 +49,26 @@ class MoEPrepareAndFinalizeNoEP(mk.FusedMoEPrepareAndFinalize):
                 "apply_router_weight_on_input is only implemented for topk=1"
             a1.mul_(topk_weights.to(a1.dtype))
-        if (extra_prepare_args is not None
-                and extra_prepare_args.get("skip_quant", True)):
-            # Skip quantization if explicitly requested
-            return a1, None, None, None, None
         a1q, a1q_scale = moe_kernel_quantize_input(
             a1, a1_scale, quant_config.quant_dtype,
             quant_config.per_act_token_quant, quant_config.block_shape)
         return a1q, a1q_scale, None, None, None
-    def finalize(self, output: torch.Tensor, fused_expert_output: torch.Tensor,
-                 topk_weights: torch.Tensor, topk_ids: torch.Tensor,
-                 apply_router_weight_on_input: bool,
-                 weight_and_reduce_impl: mk.TopKWeightAndReduce,
-                 extra_finalize_args: Optional[dict[str, Any]]) -> None:
-        if (extra_finalize_args is not None
-                and extra_finalize_args.get("skip_weight_reduce", True)):
-            assert output.shape == fused_expert_output.shape
-            output.copy_(fused_expert_output)
-        else:
-            if isinstance(weight_and_reduce_impl, TopKWeightAndReduceDelegate):
-                weight_and_reduce_impl = TopKWeightAndReduceContiguous()
-            weight_and_reduce_impl.apply(
-                output=output,
-                fused_expert_output=fused_expert_output,
-                topk_weights=topk_weights,
-                topk_ids=topk_ids,
-                apply_router_weight_on_input=apply_router_weight_on_input)
+    def finalize(
+        self,
+        output: torch.Tensor,
+        fused_expert_output: torch.Tensor,
+        topk_weights: torch.Tensor,
+        topk_ids: torch.Tensor,
+        apply_router_weight_on_input: bool,
+        weight_and_reduce_impl: mk.TopKWeightAndReduce,
+    ) -> None:
+        if isinstance(weight_and_reduce_impl, TopKWeightAndReduceDelegate):
+            weight_and_reduce_impl = TopKWeightAndReduceContiguous()
+        weight_and_reduce_impl.apply(
+            output=output,
+            fused_expert_output=fused_expert_output,
+            topk_weights=topk_weights,
+            topk_ids=topk_ids,
+            apply_router_weight_on_input=apply_router_weight_on_input)

View File

@@ -1,6 +1,6 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
-from typing import Any, Optional
+from typing import Optional
 import torch
@@ -119,18 +119,28 @@ class TritonOrDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
                                     local_num_experts,
                                     expert_tokens_meta)
-    def apply(self, output: torch.Tensor, hidden_states: torch.Tensor,
-              w1: torch.Tensor, w2: torch.Tensor, topk_weights: torch.Tensor,
-              topk_ids: torch.Tensor, activation: str, global_num_experts: int,
-              expert_map: Optional[torch.Tensor],
-              w1_scale: Optional[torch.Tensor],
-              w2_scale: Optional[torch.Tensor], w1_zp: Optional[torch.Tensor],
-              w2_zp: Optional[torch.Tensor], a1q_scale: Optional[torch.Tensor],
-              a2_scale: Optional[torch.Tensor], workspace13: torch.Tensor,
-              workspace2: torch.Tensor,
-              expert_tokens_meta: Optional[mk.ExpertTokensMetadata],
-              apply_router_weight_on_input: bool,
-              extra_expert_args: Optional[dict[str, Any]]):
+    def apply(
+        self,
+        output: torch.Tensor,
+        hidden_states: torch.Tensor,
+        w1: torch.Tensor,
+        w2: torch.Tensor,
+        topk_weights: torch.Tensor,
+        topk_ids: torch.Tensor,
+        activation: str,
+        global_num_experts: int,
+        expert_map: Optional[torch.Tensor],
+        w1_scale: Optional[torch.Tensor],
+        w2_scale: Optional[torch.Tensor],
+        w1_zp: Optional[torch.Tensor],
+        w2_zp: Optional[torch.Tensor],
+        a1q_scale: Optional[torch.Tensor],
+        a2_scale: Optional[torch.Tensor],
+        workspace13: torch.Tensor,
+        workspace2: torch.Tensor,
+        expert_tokens_meta: Optional[mk.ExpertTokensMetadata],
+        apply_router_weight_on_input: bool,
+    ):
         use_deep_gemm = (self.allow_deep_gemm
                          and (_valid_deep_gemm(hidden_states, w1, w2)
                               or is_blackwell_deep_gemm_e8m0_used()))
@@ -158,5 +168,4 @@ class TritonOrDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
             workspace2,
             expert_tokens_meta,
             apply_router_weight_on_input,
-            extra_expert_args,
         )

View File

@@ -1,7 +1,7 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 from math import prod
-from typing import Any, Optional, Union
+from typing import Optional, Union
 import torch
@@ -189,7 +189,7 @@ def moe_kernel_quantize_input(
         return _fp8_quantize(A, A_scale, per_act_token_quant, block_shape)
     elif quant_dtype == torch.int8:
         return _int8_quantize(A, A_scale, per_act_token_quant, block_shape)
-    elif quant_dtype == torch.uint8:  # nvfp4
+    elif quant_dtype == "nvfp4":
         return _fp4_quantize(A,
                              A_scale,
                              is_sf_swizzled_layout=is_fp4_scale_swizzled)
@@ -252,17 +252,3 @@ def _validate_scale_shape(
         assert block_shape is not None
         expected = (a.shape[0], cdiv(a.shape[1], block_shape[1]))
         assert a_scale.shape == expected, f"{a_scale.shape} == {expected}"
-def extract_required_args(
-    extra_args: Optional[dict[str, Any]],
-    required_keys: list[str],
-) -> tuple[Any, ...]:
-    if extra_args is None:
-        raise ValueError("`extra_args` must be provided.")
-    missing_keys = [k for k in required_keys if k not in extra_args]
-    if missing_keys:
-        raise ValueError(f"Missing keys in `extra_args`: {missing_keys}")
-    return tuple(extra_args[k] for k in required_keys)
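Editor's note: `moe_kernel_quantize_input` now keys NVFP4 off the string sentinel `"nvfp4"` instead of overloading `torch.uint8`, so `quant_dtype` is effectively `Union[torch.dtype, str, None]`. A toy dispatcher showing the shape of that branch; the bodies are placeholders, not the vLLM `_fp8/_int8/_fp4` helpers:

```python
from typing import Union

import torch


def quantize_sketch(a: torch.Tensor,
                    quant_dtype: Union[torch.dtype, str, None]) -> torch.Tensor:
    if quant_dtype is None:
        return a  # no quantization requested
    if quant_dtype == torch.int8:
        return a.clamp(-128, 127).to(torch.int8)  # stand-in for _int8_quantize
    if quant_dtype == "nvfp4":
        # Stand-in for _fp4_quantize: real code returns packed uint8 data
        # (two fp4 values per byte) plus block scales.
        return torch.empty(a.shape[0], a.shape[1] // 2, dtype=torch.uint8)
    raise ValueError(f"unsupported quant_dtype: {quant_dtype!r}")


x = torch.randn(4, 8)
print(quantize_sketch(x, "nvfp4").shape)  # torch.Size([4, 4])
```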

View File

@@ -241,7 +241,7 @@ class AutoRoundConfig(QuantizationConfig):
         if isinstance(layer, FusedMoE):
             if use_marlin:
-                return AWQMoEMethod(quant_args_marlin)
+                return AWQMoEMethod(quant_args_marlin, layer.moe)
             from vllm.model_executor.layers.quantization.moe_wna16 import (
                 MoeWNA16Config)
@@ -339,7 +339,7 @@ class AutoRoundConfig(QuantizationConfig):
                 }
                 return MoeWNA16Config.from_config(config).get_quant_method(
                     layer, prefix)
-            return GPTQMarlinMoEMethod(quant_args_marlin)
+            return GPTQMarlinMoEMethod(quant_args_marlin, layer.moe)
         if isinstance(layer, (LinearBase, ParallelLMHead)):
             if use_marlin:

View File

@@ -113,7 +113,7 @@ class AWQConfig(QuantizationConfig):
             }
             awq_marlin_config = AWQMarlinConfig.from_config(
                 marlin_compatible_config_dict)
-            return AWQMoEMethod(awq_marlin_config)
+            return AWQMoEMethod(awq_marlin_config, layer.moe_config)
         return None

View File

@ -10,7 +10,7 @@ import vllm.model_executor.layers.fused_moe # noqa
from vllm import _custom_ops as ops from vllm import _custom_ops as ops
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe.layer import ( from vllm.model_executor.layers.fused_moe.layer import (
FusedMoE, FusedMoEMethodBase, FusedMoeWeightScaleSupported, FusedMoE, FusedMoEConfig, FusedMoEMethodBase, FusedMoeWeightScaleSupported,
UnquantizedFusedMoEMethod) UnquantizedFusedMoEMethod)
from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase, from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase,
UnquantizedLinearMethod, UnquantizedLinearMethod,
@ -151,7 +151,7 @@ class AWQMarlinConfig(QuantizationConfig):
"Falling back to Moe WNA16 kernels.") "Falling back to Moe WNA16 kernels.")
return MoeWNA16Config.from_config( return MoeWNA16Config.from_config(
self.full_config).get_quant_method(layer, prefix) self.full_config).get_quant_method(layer, prefix)
return AWQMoEMethod(self) return AWQMoEMethod(self, layer.moe_config)
return None return None
@classmethod @classmethod
@ -328,7 +328,12 @@ class AWQMarlinLinearMethod(LinearMethodBase):
class AWQMoEMethod(FusedMoEMethodBase): class AWQMoEMethod(FusedMoEMethodBase):
def __init__(self, quant_config: AWQMarlinConfig): def __init__(
self,
quant_config: AWQMarlinConfig,
moe: FusedMoEConfig,
):
super().__init__(moe)
self.quant_config = quant_config self.quant_config = quant_config
if self.quant_config.weight_bits != 4: if self.quant_config.weight_bits != 4:
raise ValueError("AWQMoEMethod only supports 4bit now.") raise ValueError("AWQMoEMethod only supports 4bit now.")
@ -500,6 +505,8 @@ class AWQMoEMethod(FusedMoEMethodBase):
logical_to_physical_map: Optional[torch.Tensor] = None, logical_to_physical_map: Optional[torch.Tensor] = None,
logical_replica_count: Optional[torch.Tensor] = None, logical_replica_count: Optional[torch.Tensor] = None,
) -> torch.Tensor: ) -> torch.Tensor:
assert self.fused_experts is None
if enable_eplb: if enable_eplb:
raise NotImplementedError( raise NotImplementedError(
"EPLB not supported for `AWQMoEMethod` yet.") "EPLB not supported for `AWQMoEMethod` yet.")
@ -516,7 +523,8 @@ class AWQMoEMethod(FusedMoEMethodBase):
num_expert_group=num_expert_group, num_expert_group=num_expert_group,
custom_routing_function=custom_routing_function, custom_routing_function=custom_routing_function,
scoring_func=scoring_func, scoring_func=scoring_func,
e_score_correction_bias=e_score_correction_bias) e_score_correction_bias=e_score_correction_bias,
indices_type=self.topk_indices_dtype)
return torch.ops.vllm.fused_marlin_moe( return torch.ops.vllm.fused_marlin_moe(
x, x,
@ -535,4 +543,4 @@ class AWQMoEMethod(FusedMoEMethodBase):
expert_map=expert_map, expert_map=expert_map,
w1_zeros=layer.w13_qzeros, w1_zeros=layer.w13_qzeros,
w2_zeros=layer.w2_qzeros, w2_zeros=layer.w2_qzeros,
workspace=layer.workspace) workspace=layer.workspace)
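The AWQ-Marlin hunks above show the pattern this commit applies across the quantized MoE methods: the constructor now takes the layer's FusedMoEConfig and forwards it to FusedMoEMethodBase, apply() asserts that no modular kernel was installed (self.fused_experts is None) before running the legacy kernel, and select_experts() receives indices_type=self.topk_indices_dtype. A framework-free sketch of that shape; the *Stub classes are stand-ins for the vLLM types, not their real interfaces:

    from typing import Optional
    import torch

    class MoeConfigStub:
        """Stand-in for FusedMoEConfig; only what this sketch touches."""
        num_experts: int = 8

    class MoeMethodBaseStub:
        """Stand-in for FusedMoEMethodBase after this change."""

        def __init__(self, moe: MoeConfigStub):
            self.moe = moe
            # Set later by init_prepare_finalize() if a modular kernel is built;
            # methods that only have a legacy path expect it to stay None.
            self.fused_experts = None
            self.topk_indices_dtype: Optional[torch.dtype] = None

    class AwqLikeMoEMethod(MoeMethodBaseStub):

        def __init__(self, quant_config: dict, moe: MoeConfigStub):
            super().__init__(moe)  # config is passed up instead of being dropped
            self.quant_config = quant_config

        def apply(self, router_logits: torch.Tensor) -> torch.Tensor:
            assert self.fused_experts is None
            topk_weights, topk_ids = router_logits.topk(2, dim=-1)
            # The real code forwards indices_type=self.topk_indices_dtype to
            # FusedMoE.select_experts() so the kernel gets the ids it expects.
            if self.topk_indices_dtype is not None:
                topk_ids = topk_ids.to(self.topk_indices_dtype)
            return topk_weights  # expert computation elided in this sketch

    method = AwqLikeMoEMethod({"bits": 4}, MoeConfigStub())
    print(method.apply(torch.randn(2, 8)).shape)  # torch.Size([2, 2])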

View File

@ -7,6 +7,7 @@ import torch
from packaging import version from packaging import version
from vllm.model_executor.layers.fused_moe.layer import (FusedMoE, from vllm.model_executor.layers.fused_moe.layer import (FusedMoE,
FusedMoEConfig,
FusedMoEMethodBase) FusedMoEMethodBase)
from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase, from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase,
UnquantizedLinearMethod, UnquantizedLinearMethod,
@ -132,7 +133,7 @@ class BitsAndBytesConfig(QuantizationConfig):
return UnquantizedLinearMethod() return UnquantizedLinearMethod()
return BitsAndBytesLinearMethod(self) return BitsAndBytesLinearMethod(self)
elif isinstance(layer, FusedMoE): elif isinstance(layer, FusedMoE):
return BitsAndBytesMoEMethod(self) return BitsAndBytesMoEMethod(self, layer.moe_config)
return None return None
@ -411,7 +412,12 @@ class BitsAndBytesMoEMethod(FusedMoEMethodBase):
quant_config: The BitsAndBytes quantization config. quant_config: The BitsAndBytes quantization config.
""" """
def __init__(self, quant_config: BitsAndBytesConfig): def __init__(
self,
quant_config: BitsAndBytesConfig,
moe: FusedMoEConfig,
):
super().__init__(moe)
try: try:
import bitsandbytes import bitsandbytes
if version.parse( if version.parse(
@ -422,7 +428,6 @@ class BitsAndBytesMoEMethod(FusedMoEMethodBase):
raise ImportError("Please install bitsandbytes>=0.46.1 via " raise ImportError("Please install bitsandbytes>=0.46.1 via "
"`pip install bitsandbytes>=0.46.1` to use " "`pip install bitsandbytes>=0.46.1` to use "
"bitsandbytes quantizer.") from err "bitsandbytes quantizer.") from err
self.topk_indices_dtype = None
self.quant_config = quant_config self.quant_config = quant_config
def create_weights( def create_weights(
@ -470,6 +475,7 @@ class BitsAndBytesMoEMethod(FusedMoEMethodBase):
logical_replica_count: Optional[torch.Tensor] = None, logical_replica_count: Optional[torch.Tensor] = None,
) -> torch.Tensor: ) -> torch.Tensor:
from vllm.model_executor.layers.fused_moe import fused_experts from vllm.model_executor.layers.fused_moe import fused_experts
assert self.fused_experts is None
if enable_eplb: if enable_eplb:
raise NotImplementedError( raise NotImplementedError(

View File

@ -11,20 +11,21 @@ from compressed_tensors.quantization import (ActivationOrdering,
QuantizationStrategy) QuantizationStrategy)
import vllm.envs as envs import vllm.envs as envs
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from vllm import _custom_ops as ops from vllm import _custom_ops as ops
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe import ( from vllm.model_executor.layers.fused_moe import (
FusedMoE, FusedMoEActivationFormat, FusedMoEConfig, FusedMoEMethodBase, FusedMoE, FusedMoEActivationFormat, FusedMoEConfig, FusedMoEMethodBase,
FusedMoEPermuteExpertsUnpermute, FusedMoEPrepareAndFinalize, FusedMoEPermuteExpertsUnpermute, FusedMoEPrepareAndFinalize,
FusedMoeWeightScaleSupported) FusedMoeWeightScaleSupported)
from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_prepare_finalize import ( # noqa from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe import (
FlashInferCutlassMoEPrepareAndFinalize) is_valid_flashinfer_cutlass_fused_moe)
from vllm.model_executor.layers.quantization.compressed_tensors.schemes.compressed_tensors_wNa16 import ( # noqa from vllm.model_executor.layers.quantization.compressed_tensors.schemes.compressed_tensors_wNa16 import ( # noqa
WNA16_SUPPORTED_BITS, WNA16_SUPPORTED_TYPES_MAP) WNA16_SUPPORTED_BITS, WNA16_SUPPORTED_TYPES_MAP)
from vllm.model_executor.layers.quantization.utils import replace_parameter from vllm.model_executor.layers.quantization.utils import replace_parameter
from vllm.model_executor.layers.quantization.utils.flashinfer_fp4_moe import ( from vllm.model_executor.layers.quantization.utils.flashinfer_fp4_moe import (
build_flashinfer_fp4_cutlass_moe_kernel, build_flashinfer_fp4_cutlass_moe_prepare_finalize, reorder_w1w3_to_w3w1,
flashinfer_fp4_cutlass_moe_forward, reorder_w1w3_to_w3w1) select_nvfp4_gemm_impl)
from vllm.model_executor.layers.quantization.utils.marlin_utils import ( from vllm.model_executor.layers.quantization.utils.marlin_utils import (
check_moe_marlin_supports_layer, marlin_make_workspace_new, check_moe_marlin_supports_layer, marlin_make_workspace_new,
marlin_moe_permute_scales) marlin_moe_permute_scales)
@ -58,6 +59,9 @@ __all__ = [
class CompressedTensorsMoEMethod(FusedMoEMethodBase): class CompressedTensorsMoEMethod(FusedMoEMethodBase):
def __init__(self, moe: FusedMoEConfig):
super().__init__(moe)
@staticmethod @staticmethod
def get_moe_method( def get_moe_method(
quant_config: "CompressedTensorsConfig", # type: ignore # noqa E501 quant_config: "CompressedTensorsConfig", # type: ignore # noqa E501
@ -81,18 +85,22 @@ class CompressedTensorsMoEMethod(FusedMoEMethodBase):
"WNA16MoE is not supported with actorder=group/dynamic." "WNA16MoE is not supported with actorder=group/dynamic."
) )
logger.info_once("Using CompressedTensorsWNA16MoEMethod") logger.info_once("Using CompressedTensorsWNA16MoEMethod")
return CompressedTensorsWNA16MoEMethod(quant_config) return CompressedTensorsWNA16MoEMethod(quant_config,
layer.moe_config)
else: else:
logger.info_once("Using CompressedTensorsWNA16MarlinMoEMethod") logger.info_once("Using CompressedTensorsWNA16MarlinMoEMethod")
return CompressedTensorsWNA16MarlinMoEMethod(quant_config) return CompressedTensorsWNA16MarlinMoEMethod(
quant_config, layer.moe_config)
elif quant_config._is_fp4a4_nvfp4(weight_quant, input_quant): elif quant_config._is_fp4a4_nvfp4(weight_quant, input_quant):
return CompressedTensorsW4A4MoeMethod() return CompressedTensorsW4A4MoeMethod(layer.moe_config, layer)
elif (quant_config._is_fp8_w8a8_sm90(weight_quant, input_quant) elif (quant_config._is_fp8_w8a8_sm90(weight_quant, input_quant)
or quant_config._is_fp8_w8a8_sm100(weight_quant, input_quant) or quant_config._is_fp8_w8a8_sm100(weight_quant, input_quant)
or quant_config._is_fp8_w8a8(weight_quant, input_quant)): or quant_config._is_fp8_w8a8(weight_quant, input_quant)):
return CompressedTensorsW8A8Fp8MoEMethod(quant_config) return CompressedTensorsW8A8Fp8MoEMethod(quant_config,
layer.moe_config)
elif quant_config._is_dynamic_token_w8a8(weight_quant, input_quant): elif quant_config._is_dynamic_token_w8a8(weight_quant, input_quant):
return CompressedTensorsW8A8Int8MoEMethod(quant_config) return CompressedTensorsW8A8Int8MoEMethod(quant_config,
layer.moe_config)
else: else:
raise RuntimeError( raise RuntimeError(
f"Unsupported FusedMoe scheme: {weight_quant}, {input_quant}") f"Unsupported FusedMoe scheme: {weight_quant}, {input_quant}")
@ -100,15 +108,16 @@ class CompressedTensorsMoEMethod(FusedMoEMethodBase):
class CompressedTensorsW4A4MoeMethod(CompressedTensorsMoEMethod): class CompressedTensorsW4A4MoeMethod(CompressedTensorsMoEMethod):
def __init__(self): def __init__(self, moe: FusedMoEConfig, layer: torch.nn.Module):
from vllm.model_executor.layers.quantization.utils.nvfp4_moe_support import ( # noqa: E501 from vllm.model_executor.layers.quantization.utils.nvfp4_moe_support import ( # noqa: E501
detect_nvfp4_moe_support) detect_nvfp4_moe_support)
super().__init__(moe)
_nvfp4 = detect_nvfp4_moe_support(self.__class__.__name__) _nvfp4 = detect_nvfp4_moe_support(self.__class__.__name__)
self.cutlass_nvfp4_supported = _nvfp4.cutlass_supported self.cutlass_nvfp4_supported = _nvfp4.cutlass_supported
self.allow_flashinfer = _nvfp4.allow_flashinfer self.allow_flashinfer = _nvfp4.allow_flashinfer
self.use_marlin = _nvfp4.use_marlin self.use_marlin = _nvfp4.use_marlin
self.group_size = 16 self.group_size = 16
self.fused_experts = None # type: ignore[assignment] self.layer = layer
def create_weights(self, layer: torch.nn.Module, num_experts: int, def create_weights(self, layer: torch.nn.Module, num_experts: int,
hidden_size: int, intermediate_size_per_partition: int, hidden_size: int, intermediate_size_per_partition: int,
@ -265,19 +274,36 @@ class CompressedTensorsW4A4MoeMethod(CompressedTensorsMoEMethod):
layer.w2_input_scale_quant = torch.nn.Parameter( layer.w2_input_scale_quant = torch.nn.Parameter(
(layer.w2_input_global_scale), requires_grad=False) (layer.w2_input_global_scale), requires_grad=False)
def maybe_swap_experts_impl(self, moe_parallel_config): def maybe_make_prepare_finalize(
self,
moe: FusedMoEConfig,
) -> Optional[mk.FusedMoEPrepareAndFinalize]:
if not self.allow_flashinfer: if not self.allow_flashinfer:
return return super().maybe_make_prepare_finalize(moe)
self.fused_experts = build_flashinfer_fp4_cutlass_moe_kernel(
moe_parallel_config)
def select_gemm_impl(self, prepare_finalize, moe): prepare_finalize = build_flashinfer_fp4_cutlass_moe_prepare_finalize(
moe,
a1_gscale=self.layer.w13_input_scale_quant,
)
logger.debug_once("%s", prepare_finalize.__class__.__name__)
return prepare_finalize
def select_gemm_impl(
self,
prepare_finalize: mk.FusedMoEPrepareAndFinalize,
moe: FusedMoEConfig,
) -> mk.FusedMoEPermuteExpertsUnpermute:
"""Return the appropriate GEMM experts implementation.""" """Return the appropriate GEMM experts implementation."""
assert moe is not None and prepare_finalize is not None experts = select_nvfp4_gemm_impl(
from vllm.model_executor.layers.quantization.utils.flashinfer_fp4_moe import ( # noqa: E501 moe,
select_nvfp4_gemm_impl) g1_alphas=self.layer.g1_alphas,
g2_alphas=self.layer.g2_alphas,
return select_nvfp4_gemm_impl(self.allow_flashinfer, moe, logger) a1_gscale=self.layer.w13_input_scale_quant,
a2_gscale=self.layer.w2_input_scale_quant,
allow_flashinfer=self.allow_flashinfer,
)
logger.debug_once("Using %s", experts.__class__.__name__)
return experts
def apply( def apply(
self, self,
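Taken together, the two overrides above follow the hook structure used throughout this commit: maybe_make_prepare_finalize() either defers to the base class or builds a FlashInferCutlassMoEPrepareAndFinalize with the activation scale bound at construction, and select_gemm_impl() returns the matching experts object. A condensed, framework-free sketch of that control flow; all *Stub classes and the literal scale values are stand-ins:

    from typing import Optional

    class PrepareFinalizeStub:
        """Stand-in for FlashInferCutlassMoEPrepareAndFinalize."""
        def __init__(self, a1_gscale: float):
            self.a1_gscale = a1_gscale  # bound once, no per-call extra_args dict

    class ExpertsStub:
        """Stand-in for an nvfp4 experts implementation."""
        def __init__(self, g1_alphas: float, g2_alphas: float,
                     allow_flashinfer: bool):
            self.g1_alphas = g1_alphas
            self.g2_alphas = g2_alphas
            self.allow_flashinfer = allow_flashinfer

    class BaseMethodStub:
        def maybe_make_prepare_finalize(self) -> Optional[PrepareFinalizeStub]:
            return None  # base class covers the generic all2all setup

    class W4A4LikeMethod(BaseMethodStub):
        def __init__(self, allow_flashinfer: bool, w13_input_scale: float):
            self.allow_flashinfer = allow_flashinfer
            self.w13_input_scale = w13_input_scale  # real code reads the layer

        def maybe_make_prepare_finalize(self) -> Optional[PrepareFinalizeStub]:
            if not self.allow_flashinfer:
                return super().maybe_make_prepare_finalize()
            return PrepareFinalizeStub(a1_gscale=self.w13_input_scale)

        def select_gemm_impl(self, prepare_finalize) -> ExpertsStub:
            # Alphas and scales would likewise come from the owning layer.
            return ExpertsStub(g1_alphas=1.0, g2_alphas=1.0,
                               allow_flashinfer=self.allow_flashinfer)

    m = W4A4LikeMethod(allow_flashinfer=True, w13_input_scale=0.5)
    pf = m.maybe_make_prepare_finalize()
    print(type(m.select_gemm_impl(pf)).__name__)  # ExpertsStub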
@ -301,6 +327,8 @@ class CompressedTensorsW4A4MoeMethod(CompressedTensorsMoEMethod):
logical_to_physical_map: Optional[torch.Tensor] = None, logical_to_physical_map: Optional[torch.Tensor] = None,
logical_replica_count: Optional[torch.Tensor] = None, logical_replica_count: Optional[torch.Tensor] = None,
) -> torch.Tensor: ) -> torch.Tensor:
assert self.fused_experts is None
if enable_eplb: if enable_eplb:
raise NotImplementedError("EPLB not supported for " raise NotImplementedError("EPLB not supported for "
"`CompressedTensorsW4A4MoeMethod` yet.") "`CompressedTensorsW4A4MoeMethod` yet.")
@ -317,6 +345,7 @@ class CompressedTensorsW4A4MoeMethod(CompressedTensorsMoEMethod):
custom_routing_function=custom_routing_function, custom_routing_function=custom_routing_function,
scoring_func=scoring_func, scoring_func=scoring_func,
e_score_correction_bias=e_score_correction_bias, e_score_correction_bias=e_score_correction_bias,
indices_type=self.topk_indices_dtype,
) )
if self.use_marlin: if self.use_marlin:
@ -340,15 +369,22 @@ class CompressedTensorsW4A4MoeMethod(CompressedTensorsMoEMethod):
# FlashInfer fused experts path # FlashInfer fused experts path
if self.fused_experts is not None: if self.fused_experts is not None:
return flashinfer_fp4_cutlass_moe_forward( assert is_valid_flashinfer_cutlass_fused_moe(
self.fused_experts, x, layer.w13_weight, layer.w2_weight), (
layer, "Flashinfer CUTLASS Fused MoE not applicable!")
x,
topk_weights, return self.fused_experts(
topk_ids, hidden_states=x,
w1=layer.w13_weight,
w2=layer.w2_weight,
topk_weights=topk_weights,
topk_ids=topk_ids,
inplace=False, # TODO(shuw): fix later, now output is high prec
activation=activation, activation=activation,
global_num_experts=global_num_experts, global_num_experts=global_num_experts,
expert_map=expert_map, expert_map=expert_map,
w1_scale=layer.w13_blockscale_swizzled,
w2_scale=layer.w2_blockscale_swizzled,
apply_router_weight_on_input=apply_router_weight_on_input, apply_router_weight_on_input=apply_router_weight_on_input,
) )
@ -376,7 +412,6 @@ class CompressedTensorsW4A4MoeMethod(CompressedTensorsMoEMethod):
n=layer.w2_weight.shape[2] * 2, n=layer.w2_weight.shape[2] * 2,
k=x.shape[1], k=x.shape[1],
e=layer.w13_weight.shape[0], e=layer.w13_weight.shape[0],
device=x.device,
apply_router_weight_on_input=apply_router_weight_on_input).to( apply_router_weight_on_input=apply_router_weight_on_input).to(
x.dtype) x.dtype)
@ -384,15 +419,16 @@ class CompressedTensorsW4A4MoeMethod(CompressedTensorsMoEMethod):
class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod): class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
def __init__( def __init__(
self, self,
quant_config: "CompressedTensorsConfig" # type: ignore # noqa E501 quant_config: "CompressedTensorsConfig", # type: ignore # noqa E501
moe: FusedMoEConfig,
): ):
super().__init__(moe)
self.quant_config = quant_config self.quant_config = quant_config
self.weight_quant = self.quant_config.target_scheme_map["Linear"].get( self.weight_quant = self.quant_config.target_scheme_map["Linear"].get(
"weights") "weights")
self.input_quant = self.quant_config.target_scheme_map["Linear"].get( self.input_quant = self.quant_config.target_scheme_map["Linear"].get(
"input_activations") "input_activations")
self.topk_indices_dtype = None
per_tensor = (self.weight_quant.strategy == QuantizationStrategy.TENSOR per_tensor = (self.weight_quant.strategy == QuantizationStrategy.TENSOR
and self.input_quant.strategy and self.input_quant.strategy
@ -429,7 +465,6 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
self.weight_quant, self.input_quant) self.weight_quant, self.input_quant)
self.use_cutlass = (quant_config._is_fp8_w8a8_sm90( self.use_cutlass = (quant_config._is_fp8_w8a8_sm90(
self.weight_quant, self.input_quant) or self.is_fp8_w8a8_sm100) self.weight_quant, self.input_quant) or self.is_fp8_w8a8_sm100)
self.fused_experts = None # type: ignore[assignment]
self.disable_expert_map = False self.disable_expert_map = False
def create_weights(self, layer: torch.nn.Module, num_experts: int, def create_weights(self, layer: torch.nn.Module, num_experts: int,
@ -614,25 +649,31 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
) -> FusedMoEPermuteExpertsUnpermute: ) -> FusedMoEPermuteExpertsUnpermute:
# cutlass path # cutlass path
if self.use_cutlass: if self.use_cutlass:
from vllm.model_executor.layers.fused_moe import CutlassExpertsFp8 from vllm.model_executor.layers.fused_moe import (
CutlassBatchedExpertsFp8, CutlassExpertsFp8)
use_batched_format = (prepare_finalize.activation_format == experts: FusedMoEPermuteExpertsUnpermute
FusedMoEActivationFormat.BatchedExperts)
num_dispatchers = prepare_finalize.num_dispatchers() num_dispatchers = prepare_finalize.num_dispatchers()
num_experts = (moe.num_local_experts
if use_batched_format else moe.num_experts)
logger.debug("CutlassExpertsFp8(%s)", self.__class__.__name__) if (prepare_finalize.activation_format ==
FusedMoEActivationFormat.BatchedExperts):
experts = CutlassExpertsFp8( logger.debug("CutlassBatchedExpertsFp8(%s)",
num_experts, self.__class__.__name__)
moe.in_dtype, experts = CutlassBatchedExpertsFp8(
self.input_quant.strategy == QuantizationStrategy.TOKEN, moe.num_local_experts,
self.weight_quant.strategy == QuantizationStrategy.CHANNEL, num_dispatchers,
num_dispatchers=num_dispatchers, moe.in_dtype,
use_batched_format=use_batched_format, self.input_quant.strategy == QuantizationStrategy.TOKEN,
) self.weight_quant.strategy == QuantizationStrategy.CHANNEL,
)
else:
logger.debug("CutlassExpertsFp8(%s)", self.__class__.__name__)
experts = CutlassExpertsFp8(
moe.in_dtype,
self.input_quant.strategy == QuantizationStrategy.TOKEN,
self.weight_quant.strategy == QuantizationStrategy.CHANNEL,
)
self.disable_expert_map = (num_dispatchers > 1 self.disable_expert_map = (num_dispatchers > 1
or not experts.supports_expert_map()) or not experts.supports_expert_map())
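The hunk above splits the old CutlassExpertsFp8(..., use_batched_format=...) constructor into two classes and picks between them from the prepare/finalize object's activation format. A stand-alone sketch of that selection logic, with a stub enum and stub experts classes in place of the vLLM types:

    from enum import Enum, auto

    class ActivationFormat(Enum):
        Standard = auto()
        BatchedExperts = auto()

    class StandardExpertsStub:
        def supports_expert_map(self) -> bool:
            return True

    class BatchedExpertsStub:
        def __init__(self, num_local_experts: int, num_dispatchers: int):
            self.num_local_experts = num_local_experts
            self.num_dispatchers = num_dispatchers

        def supports_expert_map(self) -> bool:
            return False

    def select_fp8_experts(activation_format: ActivationFormat,
                           num_local_experts: int, num_dispatchers: int):
        # Batched prepare/finalize backends get the batched experts kernel;
        # everything else keeps the standard contiguous implementation.
        if activation_format == ActivationFormat.BatchedExperts:
            experts = BatchedExpertsStub(num_local_experts, num_dispatchers)
        else:
            experts = StandardExpertsStub()
        # Mirrors the disable_expert_map bookkeeping after the selection.
        disable_expert_map = (num_dispatchers > 1
                              or not experts.supports_expert_map())
        return experts, disable_expert_map

    print(select_fp8_experts(ActivationFormat.BatchedExperts, 4, 2))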
@ -834,9 +875,11 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
class CompressedTensorsW8A8Int8MoEMethod(CompressedTensorsMoEMethod): class CompressedTensorsW8A8Int8MoEMethod(CompressedTensorsMoEMethod):
def __init__( def __init__(
self, self,
quant_config: "CompressedTensorsConfig" # type: ignore # noqa E501 quant_config: "CompressedTensorsConfig", # type: ignore # noqa E501
moe: FusedMoEConfig,
): ):
super().__init__(moe)
self.quant_config = quant_config self.quant_config = quant_config
self.weight_quant = self.quant_config.target_scheme_map["Linear"].get( self.weight_quant = self.quant_config.target_scheme_map["Linear"].get(
"weights") "weights")
@ -934,6 +977,8 @@ class CompressedTensorsW8A8Int8MoEMethod(CompressedTensorsMoEMethod):
logical_to_physical_map: Optional[torch.Tensor] = None, logical_to_physical_map: Optional[torch.Tensor] = None,
logical_replica_count: Optional[torch.Tensor] = None, logical_replica_count: Optional[torch.Tensor] = None,
) -> torch.Tensor: ) -> torch.Tensor:
assert self.fused_experts is None
if enable_eplb: if enable_eplb:
raise NotImplementedError( raise NotImplementedError(
"EPLB not supported for " "EPLB not supported for "
@ -951,7 +996,8 @@ class CompressedTensorsW8A8Int8MoEMethod(CompressedTensorsMoEMethod):
num_expert_group=num_expert_group, num_expert_group=num_expert_group,
custom_routing_function=custom_routing_function, custom_routing_function=custom_routing_function,
scoring_func=scoring_func, scoring_func=scoring_func,
e_score_correction_bias=e_score_correction_bias) e_score_correction_bias=e_score_correction_bias,
indices_type=self.topk_indices_dtype)
return fused_experts( return fused_experts(
hidden_states=x, hidden_states=x,
@ -975,9 +1021,11 @@ class CompressedTensorsW8A8Int8MoEMethod(CompressedTensorsMoEMethod):
class CompressedTensorsWNA16MarlinMoEMethod(CompressedTensorsMoEMethod): class CompressedTensorsWNA16MarlinMoEMethod(CompressedTensorsMoEMethod):
def __init__( def __init__(
self, self,
quant_config: "CompressedTensorsConfig" # type: ignore # noqa E501 quant_config: "CompressedTensorsConfig", # type: ignore # noqa E501
moe: FusedMoEConfig,
): ):
super().__init__(moe)
self.quant_config = quant_config self.quant_config = quant_config
# TODO: @dsikka: refactor this to use schemes as other kernels # TODO: @dsikka: refactor this to use schemes as other kernels
# are supported + check if the layer is being ignored. # are supported + check if the layer is being ignored.
@ -1233,6 +1281,8 @@ class CompressedTensorsWNA16MarlinMoEMethod(CompressedTensorsMoEMethod):
logical_to_physical_map: Optional[torch.Tensor] = None, logical_to_physical_map: Optional[torch.Tensor] = None,
logical_replica_count: Optional[torch.Tensor] = None, logical_replica_count: Optional[torch.Tensor] = None,
) -> torch.Tensor: ) -> torch.Tensor:
assert self.fused_experts is None
if enable_eplb: if enable_eplb:
raise NotImplementedError( raise NotImplementedError(
"EPLB not supported for " "EPLB not supported for "
@ -1251,7 +1301,8 @@ class CompressedTensorsWNA16MarlinMoEMethod(CompressedTensorsMoEMethod):
num_expert_group=num_expert_group, num_expert_group=num_expert_group,
custom_routing_function=custom_routing_function, custom_routing_function=custom_routing_function,
scoring_func=scoring_func, scoring_func=scoring_func,
e_score_correction_bias=e_score_correction_bias) e_score_correction_bias=e_score_correction_bias,
indices_type=self.topk_indices_dtype)
return torch.ops.vllm.fused_marlin_moe( return torch.ops.vllm.fused_marlin_moe(
x, x,
@ -1279,9 +1330,11 @@ class CompressedTensorsWNA16MarlinMoEMethod(CompressedTensorsMoEMethod):
class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod): class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod):
def __init__( def __init__(
self, self,
quant_config: "CompressedTensorsConfig" # type: ignore # noqa E501 quant_config: "CompressedTensorsConfig", # type: ignore # noqa E501
moe: FusedMoEConfig,
): ):
super().__init__(moe)
self.quant_config = quant_config self.quant_config = quant_config
# TODO: @dsikka: refactor this to use schemes as other kernels # TODO: @dsikka: refactor this to use schemes as other kernels
# are supported + check if the layer is being ignored. # are supported + check if the layer is being ignored.
@ -1459,6 +1512,8 @@ class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod):
logical_to_physical_map: Optional[torch.Tensor] = None, logical_to_physical_map: Optional[torch.Tensor] = None,
logical_replica_count: Optional[torch.Tensor] = None, logical_replica_count: Optional[torch.Tensor] = None,
) -> torch.Tensor: ) -> torch.Tensor:
assert self.fused_experts is None
if enable_eplb: if enable_eplb:
raise NotImplementedError("EPLB not supported for " raise NotImplementedError("EPLB not supported for "
"`CompressedTensorsWNA16MoEMethod` yet.") "`CompressedTensorsWNA16MoEMethod` yet.")
@ -1475,7 +1530,8 @@ class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod):
num_expert_group=num_expert_group, num_expert_group=num_expert_group,
custom_routing_function=custom_routing_function, custom_routing_function=custom_routing_function,
scoring_func=scoring_func, scoring_func=scoring_func,
e_score_correction_bias=e_score_correction_bias) e_score_correction_bias=e_score_correction_bias,
indices_type=self.topk_indices_dtype)
return fused_experts( return fused_experts(
x, x,

View File

@ -6,7 +6,8 @@ from typing import Any, Callable, Optional
import torch import torch
from vllm.distributed import get_tensor_model_parallel_rank, get_tp_group from vllm.distributed import get_tensor_model_parallel_rank, get_tp_group
from vllm.model_executor.layers.fused_moe import FusedMoE, FusedMoEMethodBase from vllm.model_executor.layers.fused_moe import (FusedMoE, FusedMoEConfig,
FusedMoEMethodBase)
from vllm.model_executor.layers.linear import (LinearBase, from vllm.model_executor.layers.linear import (LinearBase,
UnquantizedLinearMethod) UnquantizedLinearMethod)
from vllm.model_executor.layers.quantization import QuantizationMethods from vllm.model_executor.layers.quantization import QuantizationMethods
@ -46,13 +47,18 @@ class ExpertsInt8Config(QuantizationConfig):
if isinstance(layer, LinearBase): if isinstance(layer, LinearBase):
return UnquantizedLinearMethod() return UnquantizedLinearMethod()
elif isinstance(layer, FusedMoE): elif isinstance(layer, FusedMoE):
return ExpertsInt8MoEMethod(self) return ExpertsInt8MoEMethod(self, layer.moe_config)
return None return None
class ExpertsInt8MoEMethod(FusedMoEMethodBase): class ExpertsInt8MoEMethod(FusedMoEMethodBase):
def __init__(self, quant_config: ExpertsInt8Config): def __init__(
self,
quant_config: ExpertsInt8Config,
moe: FusedMoEConfig,
):
super().__init__(moe)
self.quant_config = quant_config self.quant_config = quant_config
def create_weights(self, layer: torch.nn.Module, num_experts: int, def create_weights(self, layer: torch.nn.Module, num_experts: int,
@ -122,6 +128,8 @@ class ExpertsInt8MoEMethod(FusedMoEMethodBase):
logical_to_physical_map: Optional[torch.Tensor] = None, logical_to_physical_map: Optional[torch.Tensor] = None,
logical_replica_count: Optional[torch.Tensor] = None, logical_replica_count: Optional[torch.Tensor] = None,
) -> torch.Tensor: ) -> torch.Tensor:
assert self.fused_experts is None
if enable_eplb: if enable_eplb:
raise NotImplementedError( raise NotImplementedError(
"EPLB not supported for `ExpertsInt8MoEMethod` yet.") "EPLB not supported for `ExpertsInt8MoEMethod` yet.")
@ -138,7 +146,8 @@ class ExpertsInt8MoEMethod(FusedMoEMethodBase):
num_expert_group=num_expert_group, num_expert_group=num_expert_group,
custom_routing_function=custom_routing_function, custom_routing_function=custom_routing_function,
scoring_func=scoring_func, scoring_func=scoring_func,
e_score_correction_bias=e_score_correction_bias) e_score_correction_bias=e_score_correction_bias,
indices_type=self.topk_indices_dtype)
return fused_experts( return fused_experts(
x, x,

View File

@ -1,7 +1,6 @@
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import functools
from typing import TYPE_CHECKING, Any, Callable, Optional from typing import TYPE_CHECKING, Any, Callable, Optional
import torch import torch
@ -142,7 +141,7 @@ class Fp8Config(QuantizationConfig):
return UnquantizedLinearMethod() return UnquantizedLinearMethod()
return Fp8LinearMethod(self) return Fp8LinearMethod(self)
elif isinstance(layer, FusedMoE): elif isinstance(layer, FusedMoE):
return Fp8MoEMethod(self) return Fp8MoEMethod(self, layer.moe_config)
elif isinstance(layer, Attention): elif isinstance(layer, Attention):
return Fp8KVCacheMethod(self) return Fp8KVCacheMethod(self)
return None return None
@ -479,9 +478,8 @@ class Fp8MoEMethod(FusedMoEMethodBase):
quant_config: The quantization config. quant_config: The quantization config.
""" """
def __init__(self, quant_config: Fp8Config): def __init__(self, quant_config: Fp8Config, moe: FusedMoEConfig):
super().__init__(moe)
from vllm.model_executor.layers.fused_moe import fused_experts
self.quant_config = quant_config self.quant_config = quant_config
self.block_quant = self.quant_config.weight_block_size is not None self.block_quant = self.quant_config.weight_block_size is not None
@ -529,15 +527,6 @@ class Fp8MoEMethod(FusedMoEMethodBase):
"CutlassBlockScaledGroupedGemm not supported on the current " "CutlassBlockScaledGroupedGemm not supported on the current "
"platform.") "platform.")
self.topk_indices_dtype = None
self.fused_experts = functools.partial( # type: ignore
fused_experts,
use_fp8_w8a8=True,
block_shape=self.quant_config.weight_block_size,
allow_deep_gemm=self.allow_deep_gemm,
allow_cutlass_block_scaled_grouped_gemm=(
self.allow_cutlass_block_scaled_grouped_gemm))
def create_weights(self, layer: Module, num_experts: int, hidden_size: int, def create_weights(self, layer: Module, num_experts: int, hidden_size: int,
intermediate_size_per_partition: int, intermediate_size_per_partition: int,
params_dtype: torch.dtype, **extra_weight_attrs): params_dtype: torch.dtype, **extra_weight_attrs):
@ -1033,7 +1022,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
num_expert_group=num_expert_group, num_expert_group=num_expert_group,
topk_group=topk_group, topk_group=topk_group,
apply_router_weight_on_input=apply_router_weight_on_input) apply_router_weight_on_input=apply_router_weight_on_input)
else: elif self.fused_experts is not None:
return self.fused_experts( return self.fused_experts(
hidden_states=x, hidden_states=x,
w1=layer.w13_weight, w1=layer.w13_weight,
@ -1052,6 +1041,30 @@ class Fp8MoEMethod(FusedMoEMethodBase):
a1_scale=layer.w13_input_scale, a1_scale=layer.w13_input_scale,
a2_scale=layer.w2_input_scale, a2_scale=layer.w2_input_scale,
) )
else:
from vllm.model_executor.layers.fused_moe import fused_experts
return fused_experts(
hidden_states=x,
w1=layer.w13_weight,
w2=layer.w2_weight,
topk_weights=topk_weights,
topk_ids=topk_ids,
inplace=True,
activation=activation,
global_num_experts=global_num_experts,
apply_router_weight_on_input=apply_router_weight_on_input,
expert_map=expert_map,
w1_scale=(layer.w13_weight_scale_inv
if self.block_quant else layer.w13_weight_scale),
w2_scale=(layer.w2_weight_scale_inv
if self.block_quant else layer.w2_weight_scale),
a1_scale=layer.w13_input_scale,
a2_scale=layer.w2_input_scale,
use_fp8_w8a8=True,
block_shape=self.quant_config.weight_block_size,
allow_deep_gemm=self.allow_deep_gemm,
allow_cutlass_block_scaled_grouped_gemm=(
self.allow_cutlass_block_scaled_grouped_gemm))
class Fp8KVCacheMethod(BaseKVCacheMethod): class Fp8KVCacheMethod(BaseKVCacheMethod):
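With the functools.partial removed from __init__, Fp8MoEMethod.apply() now dispatches explicitly: the FlashInfer TRT-LLM path first, then an installed modular kernel (self.fused_experts), and otherwise a direct fused_experts(...) call that carries the fp8 arguments (use_fp8_w8a8, block_shape, allow_deep_gemm, ...) at the call site. A toy version of that three-way dispatch with plain callables standing in for the kernels; the keyword values are illustrative only:

    from typing import Callable, Optional

    def apply_fp8_moe(
        use_flashinfer: bool,
        modular_kernel: Optional[Callable[..., str]],
        default_fused_experts: Callable[..., str],
    ) -> str:
        if use_flashinfer:
            return "flashinfer_trtllm_fp8_per_tensor_moe"
        elif modular_kernel is not None:
            # Modular kernel: fp8 details were baked in when it was built.
            return modular_kernel()
        else:
            # Default path: the quantization arguments travel with the call
            # instead of being partially applied in __init__.
            return default_fused_experts(use_fp8_w8a8=True,
                                         block_shape=[128, 128],
                                         allow_deep_gemm=False)

    print(apply_fp8_moe(False, None,
                        lambda **kw: f"fused_experts({sorted(kw)})"))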

View File

@ -11,6 +11,7 @@ from torch.nn.parameter import Parameter, UninitializedParameter
from vllm import _custom_ops as ops from vllm import _custom_ops as ops
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe.layer import (FusedMoE, from vllm.model_executor.layers.fused_moe.layer import (FusedMoE,
FusedMoEConfig,
FusedMoEMethodBase) FusedMoEMethodBase)
from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase
from vllm.model_executor.layers.quantization import QuantizationMethods from vllm.model_executor.layers.quantization import QuantizationMethods
@ -58,7 +59,7 @@ class GGUFConfig(QuantizationConfig):
elif isinstance(layer, VocabParallelEmbedding): elif isinstance(layer, VocabParallelEmbedding):
return GGUFEmbeddingMethod(self) return GGUFEmbeddingMethod(self)
elif isinstance(layer, FusedMoE): elif isinstance(layer, FusedMoE):
return GGUFMoEMethod(self) return GGUFMoEMethod(self, layer.moe_config)
return None return None
@ -445,7 +446,12 @@ class GGUFMoEMethod(FusedMoEMethodBase):
quant_config: The GGUF quantization config. quant_config: The GGUF quantization config.
""" """
def __init__(self, quant_config: GGUFConfig): def __init__(
self,
quant_config: GGUFConfig,
moe: FusedMoEConfig,
):
super().__init__(moe)
self.quant_config = quant_config self.quant_config = quant_config
def create_weights(self, layer: torch.nn.Module, num_experts: int, def create_weights(self, layer: torch.nn.Module, num_experts: int,
@ -525,6 +531,8 @@ class GGUFMoEMethod(FusedMoEMethodBase):
logical_to_physical_map: Optional[torch.Tensor] = None, logical_to_physical_map: Optional[torch.Tensor] = None,
logical_replica_count: Optional[torch.Tensor] = None, logical_replica_count: Optional[torch.Tensor] = None,
): ):
assert self.fused_experts is None
if enable_eplb: if enable_eplb:
raise NotImplementedError( raise NotImplementedError(
"EPLB not supported for `GGUFMoEMethod` yet.") "EPLB not supported for `GGUFMoEMethod` yet.")
@ -545,7 +553,8 @@ class GGUFMoEMethod(FusedMoEMethodBase):
num_expert_group=num_expert_group, num_expert_group=num_expert_group,
custom_routing_function=custom_routing_function, custom_routing_function=custom_routing_function,
scoring_func=scoring_func, scoring_func=scoring_func,
e_score_correction_bias=e_score_correction_bias) e_score_correction_bias=e_score_correction_bias,
indices_type=self.topk_indices_dtype)
return fused_moe_gguf(x, layer.w13_qweight, layer.w2_qweight, return fused_moe_gguf(x, layer.w13_qweight, layer.w2_qweight,
topk_weights, topk_ids, topk_weights, topk_ids,
layer.w13_qweight_type.weight_type, layer.w13_qweight_type.weight_type,

View File

@ -10,7 +10,7 @@ import vllm.model_executor.layers.fused_moe # noqa
from vllm import _custom_ops as ops from vllm import _custom_ops as ops
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe.layer import ( from vllm.model_executor.layers.fused_moe.layer import (
FusedMoE, FusedMoEMethodBase, FusedMoeWeightScaleSupported, FusedMoE, FusedMoEConfig, FusedMoEMethodBase, FusedMoeWeightScaleSupported,
UnquantizedFusedMoEMethod) UnquantizedFusedMoEMethod)
from vllm.model_executor.layers.linear import (LinearMethodBase, from vllm.model_executor.layers.linear import (LinearMethodBase,
set_weight_attrs) set_weight_attrs)
@ -375,7 +375,12 @@ class GPTQMarlinLinearMethod(LinearMethodBase):
class GPTQMarlinMoEMethod(FusedMoEMethodBase): class GPTQMarlinMoEMethod(FusedMoEMethodBase):
"""MoE Marlin method with quantization.""" """MoE Marlin method with quantization."""
def __init__(self, quant_config: GPTQMarlinConfig) -> None: def __init__(
self,
quant_config: GPTQMarlinConfig,
moe: FusedMoEConfig,
) -> None:
super().__init__(moe)
self.quant_config = quant_config self.quant_config = quant_config
if self.quant_config.quant_type.size_bits == 4: if self.quant_config.quant_type.size_bits == 4:
self.quant_type = scalar_types.uint4b8 self.quant_type = scalar_types.uint4b8
@ -646,6 +651,8 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase):
logical_to_physical_map: Optional[torch.Tensor] = None, logical_to_physical_map: Optional[torch.Tensor] = None,
logical_replica_count: Optional[torch.Tensor] = None, logical_replica_count: Optional[torch.Tensor] = None,
) -> torch.Tensor: ) -> torch.Tensor:
assert self.fused_experts is None
if enable_eplb: if enable_eplb:
raise NotImplementedError( raise NotImplementedError(
"EPLB not supported for `GPTQMarlinMoEMethod` yet.") "EPLB not supported for `GPTQMarlinMoEMethod` yet.")
@ -662,7 +669,8 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase):
num_expert_group=num_expert_group, num_expert_group=num_expert_group,
custom_routing_function=custom_routing_function, custom_routing_function=custom_routing_function,
scoring_func=scoring_func, scoring_func=scoring_func,
e_score_correction_bias=e_score_correction_bias) e_score_correction_bias=e_score_correction_bias,
indices_type=self.topk_indices_dtype)
return torch.ops.vllm.fused_marlin_moe( return torch.ops.vllm.fused_marlin_moe(
x, x,

View File

@ -12,7 +12,9 @@ import vllm.envs as envs
import vllm.model_executor.layers.fused_moe.modular_kernel as mk import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from vllm._custom_ops import cutlass_scaled_fp4_mm, scaled_fp4_quant from vllm._custom_ops import cutlass_scaled_fp4_mm, scaled_fp4_quant
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe.config import FusedMoEParallelConfig from vllm.model_executor.layers.fused_moe.config import FusedMoEConfig
from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe import (
is_valid_flashinfer_cutlass_fused_moe)
from vllm.model_executor.layers.fused_moe.layer import ( from vllm.model_executor.layers.fused_moe.layer import (
FusedMoE, FusedMoEMethodBase, FusedMoeWeightScaleSupported) FusedMoE, FusedMoEMethodBase, FusedMoeWeightScaleSupported)
from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase, from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase,
@ -22,8 +24,8 @@ from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig, QuantizeMethodBase) QuantizationConfig, QuantizeMethodBase)
from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod
from vllm.model_executor.layers.quantization.utils.flashinfer_fp4_moe import ( from vllm.model_executor.layers.quantization.utils.flashinfer_fp4_moe import (
build_flashinfer_fp4_cutlass_moe_kernel, build_flashinfer_fp4_cutlass_moe_prepare_finalize, reorder_w1w3_to_w3w1,
flashinfer_fp4_cutlass_moe_forward, reorder_w1w3_to_w3w1) select_nvfp4_gemm_impl)
from vllm.model_executor.layers.quantization.utils.flashinfer_utils import ( from vllm.model_executor.layers.quantization.utils.flashinfer_utils import (
apply_flashinfer_per_tensor_scale_fp8, register_moe_scaling_factors, apply_flashinfer_per_tensor_scale_fp8, register_moe_scaling_factors,
rotate_flashinfer_fp8_moe_weights, swap_w13_to_w31) rotate_flashinfer_fp8_moe_weights, swap_w13_to_w31)
@ -177,7 +179,7 @@ class ModelOptFp8Config(QuantizationConfig):
elif isinstance(layer, Attention): elif isinstance(layer, Attention):
return ModelOptFp8KVCacheMethod(self) return ModelOptFp8KVCacheMethod(self)
elif isinstance(layer, FusedMoE): elif isinstance(layer, FusedMoE):
return ModelOptFp8MoEMethod(self) return ModelOptFp8MoEMethod(self, layer.moe_config)
return None return None
@ -273,7 +275,12 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase):
quant_config: The ModelOpt quantization config. quant_config: The ModelOpt quantization config.
""" """
def __init__(self, quant_config: ModelOptFp8Config) -> None: def __init__(
self,
quant_config: ModelOptFp8Config,
moe: FusedMoEConfig,
) -> None:
super().__init__(moe)
self.quant_config = quant_config self.quant_config = quant_config
from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
cutlass_fp8_supported) cutlass_fp8_supported)
@ -454,6 +461,8 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase):
logical_to_physical_map: Optional[torch.Tensor] = None, logical_to_physical_map: Optional[torch.Tensor] = None,
logical_replica_count: Optional[torch.Tensor] = None, logical_replica_count: Optional[torch.Tensor] = None,
) -> torch.Tensor: ) -> torch.Tensor:
assert self.fused_experts is None
if enable_eplb: if enable_eplb:
raise NotImplementedError( raise NotImplementedError(
"EPLB not supported for `ModelOptFp8MoEMethod` yet.") "EPLB not supported for `ModelOptFp8MoEMethod` yet.")
@ -484,6 +493,7 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase):
custom_routing_function=custom_routing_function, custom_routing_function=custom_routing_function,
scoring_func=scoring_func, scoring_func=scoring_func,
e_score_correction_bias=e_score_correction_bias, e_score_correction_bias=e_score_correction_bias,
indices_type=self.topk_indices_dtype,
) )
from vllm.model_executor.layers.fused_moe.fused_moe import ( from vllm.model_executor.layers.fused_moe.fused_moe import (
fused_experts) fused_experts)
@ -699,7 +709,7 @@ class ModelOptNvFp4Config(QuantizationConfig):
elif isinstance(layer, Attention): elif isinstance(layer, Attention):
return ModelOptFp8KVCacheMethod(self) return ModelOptFp8KVCacheMethod(self)
elif isinstance(layer, FusedMoE): elif isinstance(layer, FusedMoE):
return ModelOptNvFp4FusedMoE(self) return ModelOptNvFp4FusedMoE(self, layer.moe_config, layer)
return None return None
@ -923,10 +933,17 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
quant_config: NVFP4 Quant Config quant_config: NVFP4 Quant Config
""" """
def __init__(self, quant_config: ModelOptNvFp4Config) -> None: def __init__(
self.quant_config = quant_config self,
quant_config: ModelOptNvFp4Config,
moe: FusedMoEConfig,
layer: torch.nn.Module,
) -> None:
from vllm.model_executor.layers.quantization.utils.nvfp4_moe_support import ( # noqa: E501 from vllm.model_executor.layers.quantization.utils.nvfp4_moe_support import ( # noqa: E501
detect_nvfp4_moe_support) detect_nvfp4_moe_support)
super().__init__(moe)
self.quant_config = quant_config
self.layer = layer
_nvfp4 = detect_nvfp4_moe_support(self.__class__.__name__) _nvfp4 = detect_nvfp4_moe_support(self.__class__.__name__)
self.cutlass_nvfp4_supported = _nvfp4.cutlass_supported self.cutlass_nvfp4_supported = _nvfp4.cutlass_supported
self.allow_flashinfer = _nvfp4.allow_flashinfer self.allow_flashinfer = _nvfp4.allow_flashinfer
@ -952,27 +969,35 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
self.fused_experts: Optional[ self.fused_experts: Optional[
mk.FusedMoEModularKernel] = None # type: ignore[assignment] mk.FusedMoEModularKernel] = None # type: ignore[assignment]
def maybe_swap_experts_impl( def maybe_make_prepare_finalize(
self, self,
moe_parallel_config: FusedMoEParallelConfig, moe: FusedMoEConfig,
): ) -> Optional[mk.FusedMoEPrepareAndFinalize]:
if not self.allow_flashinfer: if not self.allow_flashinfer:
return return super().maybe_make_prepare_finalize(moe)
self.fused_experts = build_flashinfer_fp4_cutlass_moe_kernel(
moe_parallel_config)
# This method update self.fused_experts prepare_finalize = build_flashinfer_fp4_cutlass_moe_prepare_finalize(
# only prepare_finalize is not None call select_gemm_impl moe,
# so when native cutlass fp4, fused_expert is in fuse_moe.py fused_expert a1_gscale=self.layer.w13_input_scale_quant,
# when it's not called(TP case), we still have 2 kernels to use. )
def select_gemm_impl(self, prepare_finalize, logger.debug_once("%s", prepare_finalize.__class__.__name__)
moe) -> mk.FusedMoEPermuteExpertsUnpermute: return prepare_finalize
assert moe is not None and prepare_finalize is not None def select_gemm_impl(
from vllm.model_executor.layers.quantization.utils.flashinfer_fp4_moe import ( # noqa: E501 self,
select_nvfp4_gemm_impl) prepare_finalize: mk.FusedMoEPrepareAndFinalize,
moe: FusedMoEConfig,
return select_nvfp4_gemm_impl(self.allow_flashinfer, moe, logger) ) -> mk.FusedMoEPermuteExpertsUnpermute:
experts = select_nvfp4_gemm_impl(
moe,
g1_alphas=self.layer.g1_alphas,
g2_alphas=self.layer.g2_alphas,
a1_gscale=self.layer.w13_input_scale_quant,
a2_gscale=self.layer.w2_input_scale_quant,
allow_flashinfer=self.allow_flashinfer,
)
logger.debug_once("Using %s", experts.__class__.__name__)
return experts
def uses_weight_scale_2_pattern(self) -> bool: def uses_weight_scale_2_pattern(self) -> bool:
""" """
@ -1362,7 +1387,8 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
num_expert_group=num_expert_group, num_expert_group=num_expert_group,
custom_routing_function=custom_routing_function, custom_routing_function=custom_routing_function,
scoring_func=scoring_func, scoring_func=scoring_func,
e_score_correction_bias=e_score_correction_bias) e_score_correction_bias=e_score_correction_bias,
indices_type=self.topk_indices_dtype)
if self.use_marlin: if self.use_marlin:
return torch.ops.vllm.fused_marlin_moe( return torch.ops.vllm.fused_marlin_moe(
@ -1404,21 +1430,28 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
n=layer.w2_weight.shape[2] * 2, n=layer.w2_weight.shape[2] * 2,
k=x.shape[1], k=x.shape[1],
e=layer.w13_weight.shape[0], e=layer.w13_weight.shape[0],
device=x.device,
expert_map=expert_map, expert_map=expert_map,
apply_router_weight_on_input=apply_router_weight_on_input) apply_router_weight_on_input=apply_router_weight_on_input)
else: else:
assert self.allow_flashinfer and \ assert self.allow_flashinfer and \
self.flashinfer_moe_backend == FlashinferMoeBackend.CUTLASS self.flashinfer_moe_backend == FlashinferMoeBackend.CUTLASS
out = flashinfer_fp4_cutlass_moe_forward(
self.fused_experts, assert is_valid_flashinfer_cutlass_fused_moe(
layer, x, layer.w13_weight, layer.w2_weight), (
x, "Flashinfer CUTLASS Fused MoE not applicable!")
topk_weights,
topk_ids, out = self.fused_experts(
hidden_states=x,
w1=layer.w13_weight,
w2=layer.w2_weight,
topk_weights=topk_weights,
topk_ids=topk_ids,
inplace=False, # TODO(shuw): fix later, now output is high prec
activation=activation, activation=activation,
global_num_experts=global_num_experts, global_num_experts=global_num_experts,
expert_map=expert_map, expert_map=expert_map,
w1_scale=layer.w13_blockscale_swizzled,
w2_scale=layer.w2_blockscale_swizzled,
apply_router_weight_on_input=apply_router_weight_on_input, apply_router_weight_on_input=apply_router_weight_on_input,
) )
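The FlashInfer CUTLASS branch above now checks is_valid_flashinfer_cutlass_fused_moe() and then calls the modular kernel directly, passing the swizzled block scales as ordinary w1_scale/w2_scale keyword arguments instead of routing them through extra_expert_args. A compact sketch of that guard-then-call shape; the validator and kernel below are stand-ins, not the vLLM functions:

    from typing import Callable
    import torch

    def toy_is_valid(x: torch.Tensor, w13: torch.Tensor,
                     w2: torch.Tensor) -> bool:
        # The real validator inspects dtypes/shapes of the activations and the
        # packed fp4 weights; this stub only checks the tensor ranks.
        return x.dim() == 2 and w13.dim() == 3 and w2.dim() == 3

    def flashinfer_branch(fused_experts: Callable[..., torch.Tensor],
                          x: torch.Tensor, w13: torch.Tensor, w2: torch.Tensor,
                          w13_blockscale: torch.Tensor,
                          w2_blockscale: torch.Tensor) -> torch.Tensor:
        assert toy_is_valid(x, w13, w2), (
            "Flashinfer CUTLASS Fused MoE not applicable!")
        # The swizzled block scales are now plain keyword arguments.
        return fused_experts(hidden_states=x, w1=w13, w2=w2,
                             w1_scale=w13_blockscale, w2_scale=w2_blockscale,
                             inplace=False)

    out = flashinfer_branch(lambda **kw: kw["hidden_states"],
                            torch.randn(2, 8),
                            torch.randn(4, 16, 8), torch.randn(4, 8, 8),
                            torch.ones(4), torch.ones(4))
    print(out.shape)  # torch.Size([2, 8])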

View File

@ -7,7 +7,7 @@ import torch
from vllm.distributed import get_tensor_model_parallel_rank, get_tp_group from vllm.distributed import get_tensor_model_parallel_rank, get_tp_group
from vllm.model_executor.layers.fused_moe.layer import ( from vllm.model_executor.layers.fused_moe.layer import (
FusedMoE, FusedMoEMethodBase, FusedMoeWeightScaleSupported) FusedMoE, FusedMoEConfig, FusedMoEMethodBase, FusedMoeWeightScaleSupported)
from vllm.model_executor.layers.linear import (LinearBase, from vllm.model_executor.layers.linear import (LinearBase,
UnquantizedLinearMethod) UnquantizedLinearMethod)
from vllm.model_executor.layers.quantization import QuantizationMethods from vllm.model_executor.layers.quantization import QuantizationMethods
@ -160,7 +160,7 @@ class MoeWNA16Config(QuantizationConfig):
else: else:
raise ValueError("moe_wna16 only support gptq and awq.") raise ValueError("moe_wna16 only support gptq and awq.")
elif isinstance(layer, FusedMoE): elif isinstance(layer, FusedMoE):
return MoeWNA16Method(self) return MoeWNA16Method(self, layer.moe_config)
return None return None
@ -175,7 +175,12 @@ class MoeWNA16Method(FusedMoEMethodBase):
quant_config: The MOE WNA16 (W8A16/W4A16) quantization config. quant_config: The MOE WNA16 (W8A16/W4A16) quantization config.
""" """
def __init__(self, quant_config: MoeWNA16Config): def __init__(
self,
quant_config: MoeWNA16Config,
moe: FusedMoEConfig,
):
super().__init__(moe)
self.quant_config = quant_config self.quant_config = quant_config
def create_weights(self, layer: torch.nn.Module, num_experts: int, def create_weights(self, layer: torch.nn.Module, num_experts: int,
@ -302,6 +307,8 @@ class MoeWNA16Method(FusedMoEMethodBase):
logical_to_physical_map: Optional[torch.Tensor] = None, logical_to_physical_map: Optional[torch.Tensor] = None,
logical_replica_count: Optional[torch.Tensor] = None, logical_replica_count: Optional[torch.Tensor] = None,
) -> torch.Tensor: ) -> torch.Tensor:
assert self.fused_experts is None
if enable_eplb: if enable_eplb:
raise NotImplementedError( raise NotImplementedError(
"EPLB not supported for `MoeWNA16Method` yet.") "EPLB not supported for `MoeWNA16Method` yet.")
@ -318,7 +325,8 @@ class MoeWNA16Method(FusedMoEMethodBase):
num_expert_group=num_expert_group, num_expert_group=num_expert_group,
custom_routing_function=custom_routing_function, custom_routing_function=custom_routing_function,
scoring_func=scoring_func, scoring_func=scoring_func,
e_score_correction_bias=e_score_correction_bias) e_score_correction_bias=e_score_correction_bias,
indices_type=self.topk_indices_dtype)
weight_bits = self.quant_config.weight_bits weight_bits = self.quant_config.weight_bits
has_zp = self.quant_config.has_zp has_zp = self.quant_config.has_zp

View File

@ -82,7 +82,7 @@ class Mxfp4Config(QuantizationConfig):
class Mxfp4MoEMethod(FusedMoEMethodBase): class Mxfp4MoEMethod(FusedMoEMethodBase):
def __init__(self, moe: FusedMoEConfig): def __init__(self, moe: FusedMoEConfig):
super().__init__() super().__init__(moe)
self.topk_indices_dtype = None self.topk_indices_dtype = None
self.moe = moe self.moe = moe
self.use_marlin = self._should_use_marlin() self.use_marlin = self._should_use_marlin()

View File

@ -7,7 +7,8 @@ import torch
from vllm import _custom_ops as ops from vllm import _custom_ops as ops
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe import (FusedMoE, FusedMoEMethodBase, from vllm.model_executor.layers.fused_moe import (FusedMoE, FusedMoEConfig,
FusedMoEMethodBase,
FusedMoeWeightScaleSupported) FusedMoeWeightScaleSupported)
from vllm.model_executor.layers.quantization.utils.mxfp4_utils import ( from vllm.model_executor.layers.quantization.utils.mxfp4_utils import (
OCP_MX_BLOCK_SIZE) OCP_MX_BLOCK_SIZE)
@ -25,6 +26,9 @@ __all__ = [
class QuarkMoEMethod(FusedMoEMethodBase): class QuarkMoEMethod(FusedMoEMethodBase):
def __init__(self, moe: FusedMoEConfig):
super().__init__(moe)
@staticmethod @staticmethod
def get_moe_method( def get_moe_method(
quant_config: "QuarkConfig", # type: ignore # noqa E501 # noqa F821 quant_config: "QuarkConfig", # type: ignore # noqa E501 # noqa F821
@ -42,17 +46,24 @@ class QuarkMoEMethod(FusedMoEMethodBase):
input_config = layer_quant_config.get("input_tensors") input_config = layer_quant_config.get("input_tensors")
if quant_config._is_fp8_w8a8(weight_config, input_config): if quant_config._is_fp8_w8a8(weight_config, input_config):
return QuarkW8A8Fp8MoEMethod(weight_config, input_config) return QuarkW8A8Fp8MoEMethod(weight_config, input_config,
module.moe_config)
elif quant_config._is_mx_fp4(weight_config, input_config): elif quant_config._is_mx_fp4(weight_config, input_config):
return QuarkW4A4MXFp4MoEMethod(weight_config, input_config) return QuarkW4A4MXFp4MoEMethod(weight_config, input_config,
module.moe_config)
else: else:
raise RuntimeError("Unsupported FusedMoe scheme") raise RuntimeError("Unsupported FusedMoe scheme")
class QuarkW8A8Fp8MoEMethod(QuarkMoEMethod): class QuarkW8A8Fp8MoEMethod(QuarkMoEMethod):
def __init__(self, weight_config: dict[str, Any], input_config: dict[str, def __init__(
Any]): self,
weight_config: dict[str, Any],
input_config: dict[str, Any],
moe: FusedMoEConfig,
):
super().__init__(moe)
self.weight_quant = weight_config self.weight_quant = weight_config
self.input_quant = input_config self.input_quant = input_config
@ -215,6 +226,8 @@ class QuarkW8A8Fp8MoEMethod(QuarkMoEMethod):
logical_to_physical_map: Optional[torch.Tensor] = None, logical_to_physical_map: Optional[torch.Tensor] = None,
logical_replica_count: Optional[torch.Tensor] = None, logical_replica_count: Optional[torch.Tensor] = None,
) -> torch.Tensor: ) -> torch.Tensor:
assert self.fused_experts is None
if enable_eplb: if enable_eplb:
raise NotImplementedError( raise NotImplementedError(
"EPLB not supported for `QuarkW8A8Fp8MoEMethod` yet.") "EPLB not supported for `QuarkW8A8Fp8MoEMethod` yet.")
@ -231,7 +244,8 @@ class QuarkW8A8Fp8MoEMethod(QuarkMoEMethod):
num_expert_group=num_expert_group, num_expert_group=num_expert_group,
custom_routing_function=custom_routing_function, custom_routing_function=custom_routing_function,
scoring_func=scoring_func, scoring_func=scoring_func,
e_score_correction_bias=e_score_correction_bias) e_score_correction_bias=e_score_correction_bias,
indices_type=self.topk_indices_dtype)
return fused_experts( return fused_experts(
x, x,
@ -253,8 +267,13 @@ class QuarkW8A8Fp8MoEMethod(QuarkMoEMethod):
class QuarkW4A4MXFp4MoEMethod(QuarkMoEMethod): class QuarkW4A4MXFp4MoEMethod(QuarkMoEMethod):
def __init__(self, weight_config: dict[str, Any], input_config: dict[str, def __init__(
Any]): self,
weight_config: dict[str, Any],
input_config: dict[str, Any],
moe: FusedMoEConfig,
):
super().__init__(moe)
self.weight_quant = weight_config self.weight_quant = weight_config
self.input_quant = input_config self.input_quant = input_config
@ -369,6 +388,7 @@ class QuarkW4A4MXFp4MoEMethod(QuarkMoEMethod):
logical_to_physical_map: Optional[torch.Tensor] = None, logical_to_physical_map: Optional[torch.Tensor] = None,
logical_replica_count: Optional[torch.Tensor] = None, logical_replica_count: Optional[torch.Tensor] = None,
) -> torch.Tensor: ) -> torch.Tensor:
assert self.fused_experts is None
if enable_eplb: if enable_eplb:
raise NotImplementedError( raise NotImplementedError(
@ -386,7 +406,8 @@ class QuarkW4A4MXFp4MoEMethod(QuarkMoEMethod):
num_expert_group=num_expert_group, num_expert_group=num_expert_group,
custom_routing_function=custom_routing_function, custom_routing_function=custom_routing_function,
scoring_func=scoring_func, scoring_func=scoring_func,
e_score_correction_bias=e_score_correction_bias) e_score_correction_bias=e_score_correction_bias,
indices_type=self.topk_indices_dtype)
out = fused_experts( out = fused_experts(
x, x,

View File

@ -10,7 +10,8 @@ import torch.nn.functional as F
from torch.nn.parameter import Parameter from torch.nn.parameter import Parameter
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe import FusedMoE, FusedMoEMethodBase from vllm.model_executor.layers.fused_moe import (FusedMoE, FusedMoEConfig,
FusedMoEMethodBase)
from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase, from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase,
set_weight_attrs) set_weight_attrs)
from vllm.model_executor.layers.quantization import QuantizationMethods from vllm.model_executor.layers.quantization import QuantizationMethods
@ -76,7 +77,7 @@ class RTNConfig(QuantizationConfig):
if isinstance(layer, LinearBase): if isinstance(layer, LinearBase):
return RTNLinearMethod(self) return RTNLinearMethod(self)
elif isinstance(layer, FusedMoE): elif isinstance(layer, FusedMoE):
return RTNMoEMethod(self) return RTNMoEMethod(self, layer.moe_config)
return None return None
@ -210,7 +211,8 @@ class RTNLinearMethod(LinearMethodBase):
class RTNMoEMethod(FusedMoEMethodBase): class RTNMoEMethod(FusedMoEMethodBase):
def __init__(self, quant_config: RTNConfig): def __init__(self, quant_config: RTNConfig, moe: FusedMoEConfig):
super().__init__(moe)
self.quant_config = quant_config self.quant_config = quant_config
def create_weights(self, layer: torch.nn.Module, num_experts: int, def create_weights(self, layer: torch.nn.Module, num_experts: int,
@ -289,6 +291,8 @@ class RTNMoEMethod(FusedMoEMethodBase):
logical_to_physical_map: Optional[torch.Tensor] = None, logical_to_physical_map: Optional[torch.Tensor] = None,
logical_replica_count: Optional[torch.Tensor] = None, logical_replica_count: Optional[torch.Tensor] = None,
) -> torch.Tensor: ) -> torch.Tensor:
assert self.fused_experts is None
if enable_eplb: if enable_eplb:
raise NotImplementedError( raise NotImplementedError(
"EPLB not supported for `RTNMoEMethod` yet.") "EPLB not supported for `RTNMoEMethod` yet.")
@ -305,7 +309,8 @@ class RTNMoEMethod(FusedMoEMethodBase):
num_expert_group=num_expert_group, num_expert_group=num_expert_group,
custom_routing_function=custom_routing_function, custom_routing_function=custom_routing_function,
scoring_func=scoring_func, scoring_func=scoring_func,
e_score_correction_bias=e_score_correction_bias) e_score_correction_bias=e_score_correction_bias,
indices_type=self.topk_indices_dtype)
weight_bits = self.quant_config.weight_bits weight_bits = self.quant_config.weight_bits
group_size = self.quant_config.group_size group_size = self.quant_config.group_size

View File

@ -3,33 +3,30 @@
 """Utility helpers for NVFP4 + FlashInfer fused-MoE path"""
 from __future__ import annotations
-from typing import Optional
 import torch
 import vllm.envs as envs
 import vllm.model_executor.layers.fused_moe.modular_kernel as mk
-from vllm.logger import init_logger
-from vllm.model_executor.layers.fused_moe.config import FusedMoEParallelConfig
+from vllm.model_executor.layers.fused_moe.config import FusedMoEConfig
 from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe import (
-    FlashInferExperts, is_valid_flashinfer_cutlass_fused_moe)
+    FlashInferExperts)
 from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_prepare_finalize import (  # noqa: E501
     FlashInferCutlassMoEPrepareAndFinalize)
 from vllm.platforms import current_platform
+from vllm.utils.flashinfer import has_flashinfer_cutlass_fused_moe
-logger = init_logger(__name__)

 __all__ = [
     "is_flashinfer_fp4_cutlass_moe_available",
     "reorder_w1w3_to_w3w1",
-    "build_flashinfer_fp4_cutlass_moe_kernel",
-    "flashinfer_fp4_cutlass_moe_forward",
+    "build_flashinfer_fp4_cutlass_moe_prepare_finalize",
 ]

 def is_flashinfer_fp4_cutlass_moe_available() -> bool:
     """Return ``True`` when FlashInfer CUTLASS NV-FP4 kernels can be used."""
-    return (envs.VLLM_USE_FLASHINFER_MOE_FP4 and current_platform.is_cuda()
+    return (envs.VLLM_USE_FLASHINFER_MOE_FP4
+            and has_flashinfer_cutlass_fused_moe()
+            and current_platform.is_cuda()
             and current_platform.is_device_capability(100))
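For context, a hypothetical call-site sketch of this stricter availability check; the variable name and import path are assumed, since the diff view does not show file names.

```python
# Hypothetical usage sketch; `allow_flashinfer` is an illustrative local name
# and the import path below is assumed, not taken from this commit.
from vllm.model_executor.layers.quantization.utils.flashinfer_fp4_moe import (
    is_flashinfer_fp4_cutlass_moe_available)

allow_flashinfer = is_flashinfer_fp4_cutlass_moe_available()
if not allow_flashinfer:
    # Fall back to a non-FlashInfer NV-FP4 path (e.g. CUTLASS experts).
    pass
```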
@ -49,105 +46,33 @@ def reorder_w1w3_to_w3w1(weight: torch.Tensor,
                      dim=dim).contiguous())

-def build_flashinfer_fp4_cutlass_moe_kernel(
-        moe_parallel_config: FusedMoEParallelConfig, ) -> mk.FusedMoEModularKernel:
-    """Create *and return* a FlashInfer CUTLASS fused-MoE modular kernel"""
-    experts = FlashInferExperts(
-        use_nvfp4_w4a4=True,
-        use_dp=moe_parallel_config.dp_size > 1,
-        ep_rank=moe_parallel_config.ep_rank,
-        ep_size=moe_parallel_config.ep_size,
-        tp_rank=moe_parallel_config.tp_rank,
-        tp_size=moe_parallel_config.tp_size,
-    )
-    logger.debug_once("FlashInferExperts (util)")
-    return mk.FusedMoEModularKernel(
-        FlashInferCutlassMoEPrepareAndFinalize(quant_dtype=torch.uint8),
-        experts,
-    )
+def build_flashinfer_fp4_cutlass_moe_prepare_finalize(
+        moe: FusedMoEConfig,
+        a1_gscale: torch.Tensor,
+) -> mk.FusedMoEPrepareAndFinalize:
+    """Create a FlashInfer CUTLASS fused-MoE prepare finalize kernel"""
+    use_dp = moe.moe_parallel_config.dp_size > 1
+    return FlashInferCutlassMoEPrepareAndFinalize(use_dp, a1_gscale=a1_gscale)
-def flashinfer_fp4_cutlass_moe_forward(
-    fused_experts: mk.FusedMoEModularKernel,
-    layer: torch.nn.Module,
-    x: torch.Tensor,
-    topk_weights: torch.Tensor,
-    topk_ids: torch.Tensor,
-    activation: str,
-    global_num_experts: int,
-    expert_map: Optional[torch.Tensor],
-    apply_router_weight_on_input: bool,
-) -> torch.Tensor:
-    """Common forward wrapper for FlashInfer NV-FP4 fused-MoE"""
-    assert is_valid_flashinfer_cutlass_fused_moe(
-        x, layer.w13_weight,
-        layer.w2_weight), ("FlashInfer CUTLASS fused-MoE not applicable!")
-    a1_gscale = layer.w13_input_scale_quant
-    a2_gscale = layer.w2_input_scale_quant
-    extra_expert_args = {
-        "g1_alphas": layer.g1_alphas,
-        "g2_alphas": layer.g2_alphas,
-        # Avoid confusion with a1_scale and a2_scale
-        # where are batch size related.
-        "a1_gscale": a1_gscale,
-        "a2_gscale": a2_gscale,
-        "out_dtype": x.dtype,
-    }
-    extra_prepare_args = {
-        "use_dp": layer.dp_size > 1,
-        "local_tokens": x.shape[0],
-        "a1_gscale": a1_gscale,
-    }
-    extra_finalize_args = {
-        "use_dp": layer.dp_size > 1,
-        "local_tokens": x.shape[0],
-    }
-    return fused_experts(
-        hidden_states=x,
-        w1=layer.w13_weight,
-        w2=layer.w2_weight,
-        topk_weights=topk_weights,
-        topk_ids=topk_ids,
-        inplace=False,  # TODO(shuw): fix later, now output is high prec
-        activation=activation,
-        global_num_experts=global_num_experts,
-        expert_map=expert_map,
-        w1_scale=layer.w13_blockscale_swizzled,
-        w2_scale=layer.w2_blockscale_swizzled,
-        apply_router_weight_on_input=apply_router_weight_on_input,
-        extra_expert_args=extra_expert_args,
-        extra_prepare_args=extra_prepare_args,
-        extra_finalize_args=extra_finalize_args,
-    )
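With the forward wrapper removed, callers compose the new prepare/finalize builder above with the reworked `select_nvfp4_gemm_impl` shown below through the modular-kernel setup hooks on `FusedMoEMethodBase`. A rough wiring sketch follows; the helper name and the import path are illustrative assumptions, not code from this commit.

```python
# Illustrative wiring only: `make_nvfp4_modular_kernel` is a hypothetical
# helper, and the import path is assumed (the diff view hides file names).
# In vLLM this composition happens inside the method's setup hooks
# (prepare/finalize creation and select_gemm_impl), not in user code.
import torch

import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from vllm.model_executor.layers.fused_moe.config import FusedMoEConfig
from vllm.model_executor.layers.quantization.utils.flashinfer_fp4_moe import (
    build_flashinfer_fp4_cutlass_moe_prepare_finalize, select_nvfp4_gemm_impl)


def make_nvfp4_modular_kernel(
        moe: FusedMoEConfig,
        g1_alphas: torch.Tensor,
        g2_alphas: torch.Tensor,
        a1_gscale: torch.Tensor,
        a2_gscale: torch.Tensor,
) -> mk.FusedMoEModularKernel:
    # Prepare/finalize now receives the input global scale up front instead
    # of through `extra_prepare_args` at call time.
    prepare_finalize = build_flashinfer_fp4_cutlass_moe_prepare_finalize(
        moe, a1_gscale=a1_gscale)
    # Likewise, the experts implementation owns the per-layer alphas/scales
    # that used to travel through `extra_expert_args`.
    experts = select_nvfp4_gemm_impl(moe,
                                     g1_alphas=g1_alphas,
                                     g2_alphas=g2_alphas,
                                     a1_gscale=a1_gscale,
                                     a2_gscale=a2_gscale,
                                     allow_flashinfer=True)
    return mk.FusedMoEModularKernel(prepare_finalize, experts)
```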
 def select_nvfp4_gemm_impl(
-        allow_flashinfer: bool,
-        moe, # FusedMoEConfig
-        logger):
+        moe: FusedMoEConfig,
+        g1_alphas: torch.Tensor,
+        g2_alphas: torch.Tensor,
+        a1_gscale: torch.Tensor,
+        a2_gscale: torch.Tensor,
+        allow_flashinfer: bool,
+) -> mk.FusedMoEPermuteExpertsUnpermute:
     """Return a GEMM *experts* implementation for NV-FP4 fused-MoE layers"""
-    # lazy import
-    from vllm.distributed import get_ep_group
-    all2all_manager = get_ep_group().device_communicator.all2all_manager
-    assert all2all_manager is not None
     if allow_flashinfer:
-        flashinfer_backend = envs.VLLM_FLASHINFER_MOE_BACKEND
-        if flashinfer_backend != "throughput":
-            raise ValueError(
-                f"Only throughput backend is supported for FlashInferExperts, "
-                f"but got {flashinfer_backend}.")
-        logger.debug_once(
-            "Initializing FlashInferExperts with throughput backend.")
         return FlashInferExperts(
-            use_nvfp4_w4a4=True,
-            use_dp=moe.moe_parallel_config.dp_size > 1,
+            g1_alphas=g1_alphas,
+            g2_alphas=g2_alphas,
+            a1_gscale=a1_gscale,
+            a2_gscale=a2_gscale,
+            out_dtype=moe.in_dtype,
+            quant_dtype="nvfp4",
             ep_rank=moe.moe_parallel_config.ep_rank,
             ep_size=moe.moe_parallel_config.ep_size,
             tp_rank=moe.moe_parallel_config.tp_rank,