[Feature] Add Flashinfer MoE Support for Compressed Tensor NVFP4 (#21639)
Signed-off-by: yewentao256 <zhyanwentao@126.com>
Parent: 6e672daf62
Commit: c3e0e9337e
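In short, this change lets compressed-tensors NVFP4 (W4A4) MoE layers run through FlashInfer's CUTLASS fused-MoE kernels, reusing helpers shared with the ModelOpt NVFP4 path. A rough sketch of opting in is below; the environment variable is the one the new helpers check (it only takes effect on CUDA devices with compute capability 10.0), and the model id is a placeholder, not something from this PR:

import os

# Enable the FlashInfer CUTLASS NVFP4 fused-MoE backend before importing vLLM.
os.environ["VLLM_USE_FLASHINFER_MOE_FP4"] = "1"

from vllm import LLM

# Placeholder checkpoint id: any compressed-tensors NVFP4 MoE model applies.
llm = LLM(model="org/nvfp4-moe-checkpoint")
print(llm.generate("Hello, world")[0].outputs[0].text)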
Changes to the compressed-tensors MoE quantization method (CompressedTensorsW4A4MoeMethod):

@@ -17,9 +17,14 @@ from vllm.model_executor.layers.fused_moe import (
     FusedMoE, FusedMoEActivationFormat, FusedMoEConfig, FusedMoEMethodBase,
     FusedMoEPermuteExpertsUnpermute, FusedMoEPrepareAndFinalize,
     FusedMoeWeightScaleSupported)
+from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_prepare_finalize import (  # noqa
+    FlashInferCutlassMoEPrepareAndFinalize)
 from vllm.model_executor.layers.quantization.compressed_tensors.schemes.compressed_tensors_wNa16 import (  # noqa
     WNA16_SUPPORTED_BITS, WNA16_SUPPORTED_TYPES_MAP)
 from vllm.model_executor.layers.quantization.utils import replace_parameter
+from vllm.model_executor.layers.quantization.utils.flashinfer_fp4_moe import (
+    build_flashinfer_fp4_cutlass_moe_kernel,
+    flashinfer_fp4_cutlass_moe_forward, reorder_w1w3_to_w3w1)
 from vllm.model_executor.layers.quantization.utils.marlin_utils import (
     check_moe_marlin_supports_layer, marlin_make_workspace_new,
     marlin_moe_permute_scales)
@@ -28,7 +33,7 @@ from vllm.model_executor.layers.quantization.utils.marlin_utils_fp4 import (
 from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import (
     prepare_moe_fp8_layer_for_marlin)
 from vllm.model_executor.layers.quantization.utils.quant_utils import (
-    cutlass_fp4_supported, swizzle_blockscale)
+    swizzle_blockscale)
 from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
     all_close_1d, normalize_e4m3fn_to_e4m3fnuz, per_tensor_dequantize)
 from vllm.model_executor.utils import set_weight_attrs
@@ -96,8 +101,14 @@ class CompressedTensorsMoEMethod(FusedMoEMethodBase):
 class CompressedTensorsW4A4MoeMethod(CompressedTensorsMoEMethod):
 
     def __init__(self):
-        self.use_marlin = not cutlass_fp4_supported()
+        from vllm.model_executor.layers.quantization.utils.nvfp4_moe_support import (  # noqa: E501
+            detect_nvfp4_moe_support)
+        _nvfp4 = detect_nvfp4_moe_support(self.__class__.__name__)
+        self.cutlass_nvfp4_supported = _nvfp4.cutlass_supported
+        self.allow_flashinfer_cutlass = _nvfp4.allow_flashinfer_cutlass
+        self.use_marlin = _nvfp4.use_marlin
         self.group_size = 16
+        self.fused_experts = None  # type: ignore[assignment]
 
     def create_weights(self, layer: torch.nn.Module, num_experts: int,
                        hidden_size: int, intermediate_size_per_partition: int,
@@ -200,6 +211,14 @@ class CompressedTensorsW4A4MoeMethod(CompressedTensorsMoEMethod):
         layer.w2_weight = torch.nn.Parameter(layer.w2_weight_packed.data,
                                              requires_grad=False)
 
+        # reorder GEMM1 weights and block scales for FlashInfer CUTLASS kernel.
+        if self.allow_flashinfer_cutlass:
+            w, s = reorder_w1w3_to_w3w1(layer.w13_weight.data,
+                                        layer.w13_weight_scale.data,
+                                        dim=-2)
+            layer.w13_weight = torch.nn.Parameter(w, requires_grad=False)
+            layer.w13_weight_scale = torch.nn.Parameter(s, requires_grad=False)
+
         if not torch.allclose(layer.w13_weight_global_scale[:, 0],
                               layer.w13_weight_global_scale[:, 1]):
             logger.warning_once(
@@ -246,6 +265,21 @@ class CompressedTensorsW4A4MoeMethod(CompressedTensorsMoEMethod):
         layer.w2_input_scale_quant = torch.nn.Parameter(
             (layer.w2_input_global_scale), requires_grad=False)
 
+    def maybe_swap_experts_impl(self, moe_parallel_config):
+        if not self.allow_flashinfer_cutlass:
+            return
+        self.fused_experts = build_flashinfer_fp4_cutlass_moe_kernel(
+            moe_parallel_config)
+
+    def select_gemm_impl(self, prepare_finalize, moe):
+        """Return the appropriate GEMM experts implementation."""
+        assert moe is not None and prepare_finalize is not None
+        from vllm.model_executor.layers.quantization.utils.flashinfer_fp4_moe import (  # noqa: E501
+            select_nvfp4_gemm_impl)
+
+        return select_nvfp4_gemm_impl(self.allow_flashinfer_cutlass, moe,
+                                      logger)
+
     def apply(
         self,
         layer: torch.nn.Module,
@@ -303,10 +337,23 @@ class CompressedTensorsW4A4MoeMethod(CompressedTensorsMoEMethod):
             global_num_experts=global_num_experts,
             expert_map=expert_map)
 
+        # FlashInfer fused experts path
+        if self.fused_experts is not None:
+            return flashinfer_fp4_cutlass_moe_forward(
+                self.fused_experts,
+                layer,
+                x,
+                topk_weights,
+                topk_ids,
+                activation=activation,
+                global_num_experts=global_num_experts,
+                expert_map=expert_map,
+                apply_router_weight_on_input=apply_router_weight_on_input,
+            )
+
         assert expert_map is None, ("Expert Parallelism / expert_map "
                                     "is currently not supported for "
                                     "CompressedTensorsW4A4MoeMethod.")
 
         from vllm.model_executor.layers.fused_moe.cutlass_moe import (
             cutlass_moe_fp4)
Changes to the ModelOpt NVFP4 fused-MoE method (ModelOptNvFp4FusedMoE), which now reuses the shared helpers:

@@ -10,11 +10,8 @@ from torch.nn.parameter import Parameter
 import vllm.envs as envs
 import vllm.model_executor.layers.fused_moe.modular_kernel as mk
 from vllm._custom_ops import cutlass_scaled_fp4_mm, scaled_fp4_quant
-from vllm.distributed import get_ep_group
 from vllm.logger import init_logger
 from vllm.model_executor.layers.fused_moe.config import FusedMoEParallelConfig
-from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_prepare_finalize import (  # noqa: E501
-    FlashInferCutlassMoEPrepareAndFinalize)
 from vllm.model_executor.layers.fused_moe.layer import (
     FusedMoE, FusedMoEMethodBase, FusedMoeWeightScaleSupported)
 from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase,
@@ -23,6 +20,9 @@ from vllm.model_executor.layers.quantization import QuantizationMethods
 from vllm.model_executor.layers.quantization.base_config import (
     QuantizationConfig, QuantizeMethodBase)
 from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod
+from vllm.model_executor.layers.quantization.utils.flashinfer_fp4_moe import (
+    build_flashinfer_fp4_cutlass_moe_kernel,
+    flashinfer_fp4_cutlass_moe_forward, reorder_w1w3_to_w3w1)
 from vllm.model_executor.layers.quantization.utils.flashinfer_utils import (
     apply_flashinfer_per_tensor_scale_fp8, rotate_flashinfer_fp8_moe_weights,
     swap_w13_to_w31)
@@ -35,7 +35,6 @@ from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
     Fp8LinearOp, requantize_with_max_scale)
 from vllm.model_executor.parameter import (ModelWeightParameter,
                                            PerTensorScaleParameter)
-from vllm.platforms import current_platform
 from vllm.scalar_type import scalar_types
 from vllm.utils.flashinfer import has_flashinfer_moe
 
@@ -869,28 +868,12 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
 
     def __init__(self, quant_config: ModelOptNvFp4Config):
         self.quant_config = quant_config
-        self.cutlass_nvfp4_supported = cutlass_fp4_supported()
-        self.use_marlin = False
-        self.allow_flashinfer_cutlass = False
-
-        if envs.VLLM_USE_FLASHINFER_MOE_FP4:
-            if self.cutlass_nvfp4_supported and current_platform.is_cuda() \
-                and current_platform.is_device_capability(100):
-                logger.info_once(
-                    "Using FlashInfer kernels for ModelOptNvFp4FusedMoE.")
-                self.allow_flashinfer_cutlass = True
-            else:
-                logger.warning_once(
-                    "Flashinfer CUTLASS Fused MoE not supported "
-                    "or found on the current platform.")
-
-        if not self.cutlass_nvfp4_supported:
-            if is_fp4_marlin_supported():
-                self.use_marlin = True
-            else:
-                raise ValueError("Current platform does not support NVFP4"
-                                 " quantization. Please use Blackwell and"
-                                 " above.")
+        from vllm.model_executor.layers.quantization.utils.nvfp4_moe_support import (  # noqa: E501
+            detect_nvfp4_moe_support)
+        _nvfp4 = detect_nvfp4_moe_support(self.__class__.__name__)
+        self.cutlass_nvfp4_supported = _nvfp4.cutlass_supported
+        self.allow_flashinfer_cutlass = _nvfp4.allow_flashinfer_cutlass
+        self.use_marlin = _nvfp4.use_marlin
 
         self.fused_experts = None  # type: ignore
 
@@ -900,29 +883,8 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
     ):
         if not self.allow_flashinfer_cutlass:
             return
-        logger.debug_once("FlashInferExperts")
-        # default to TP/EP case only
-        experts_kwargs: dict[str, Any] = {
-            "use_nvfp4_w4a4": True,
-            "use_dp": moe_parallel_config.dp_size > 1,
-            "ep_rank": moe_parallel_config.ep_rank,
-            "ep_size": moe_parallel_config.ep_size,
-            "tp_rank": moe_parallel_config.tp_rank,
-            "tp_size": moe_parallel_config.tp_size,
-        }
-
-        from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe import (  # noqa: E501
-            FlashInferExperts)
-        experts = FlashInferExperts(**experts_kwargs)
-        self.fused_experts = mk.FusedMoEModularKernel(
-            FlashInferCutlassMoEPrepareAndFinalize(
-                quant_dtype=torch.uint8,
-                #meaning 2x e2m1 packed in one, kernel requirement
-            ),
-            experts,
-        )
+        self.fused_experts = build_flashinfer_fp4_cutlass_moe_kernel(
+            moe_parallel_config)
 
     # This method update self.fused_experts
     # only prepare_finalize is not None call select_gemm_impl
@@ -931,32 +893,12 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
     def select_gemm_impl(self, prepare_finalize,
                          moe) -> mk.FusedMoEPermuteExpertsUnpermute:
 
-        assert moe is not None
-        assert prepare_finalize is not None
-        experts = None
-        all2all_manager = get_ep_group().device_communicator.all2all_manager
-        assert all2all_manager is not None
-        if self.allow_flashinfer_cutlass:
-            from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe import (  # noqa: E501
-                FlashInferExperts)
-            logger.debug_once("Using FlashInferExperts")
-            experts = FlashInferExperts(
-                use_nvfp4_w4a4=True,
-                use_dp=moe.moe_parallel_config.dp_size > 1,
-                ep_rank=moe.moe_parallel_config.ep_rank,
-                ep_size=moe.moe_parallel_config.ep_size,
-                tp_rank=moe.moe_parallel_config.tp_rank,
-                tp_size=moe.moe_parallel_config.tp_size,
-            )
-        else:
-            assert moe.dp_size > 1
-            logger.debug_once("Using CutlassExpertsFp4")
-            # Currently CutlassExpertsFp4 doesn't support DP
-            raise ValueError("CutlassExpertsFp4 doesn't support DP. "
-                             "Use flashinfer CUTLASS FusedMoE backend instead "
-                             "(set VLLM_USE_FLASHINFER_MOE_FP4=1)")
-
-        return experts
+        assert moe is not None and prepare_finalize is not None
+        from vllm.model_executor.layers.quantization.utils.flashinfer_fp4_moe import (  # noqa: E501
+            select_nvfp4_gemm_impl)
+
+        return select_nvfp4_gemm_impl(self.allow_flashinfer_cutlass, moe,
+                                      logger)
 
     def uses_weight_scale_2_pattern(self) -> bool:
         """
@@ -1062,18 +1004,8 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
         gemm1_weight_scale = layer.w13_weight_scale.data
 
         if self.allow_flashinfer_cutlass:
-            dim = -2
-            size = gemm1_weight.size(dim)
-            assert size % 2 == 0, f"Expected even size in dim {dim}, got {size}"
-            half = size // 2
-
-            # Reorder weight
-            w1, w3 = gemm1_weight.split(half, dim=dim)
-            gemm1_weight = torch.cat([w3, w1], dim=dim).contiguous()
-
-            # Reorder scale
-            s1, s3 = gemm1_weight_scale.split(half, dim=dim)
-            gemm1_weight_scale = torch.cat([s3, s1], dim=dim).contiguous()
+            gemm1_weight, gemm1_weight_scale = reorder_w1w3_to_w3w1(
+                gemm1_weight, gemm1_weight_scale, dim=-2)
 
         layer.w13_weight = Parameter(gemm1_weight, requires_grad=False)
         layer.w13_weight_scale = Parameter(gemm1_weight_scale,
@@ -1217,49 +1149,15 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
                 expert_map=expert_map,
                 apply_router_weight_on_input=apply_router_weight_on_input)
         else:
-            # TP or DP case
-            from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe import (  # noqa: E501
-                is_valid_flashinfer_cutlass_fused_moe)
-            assert is_valid_flashinfer_cutlass_fused_moe(
-                x, layer.w13_weight, layer.w2_weight), (
-                    "Flashinfer CUTLASS Fused MoE not applicable!")
-
-            a1_gscale = layer.w13_input_scale_quant
-            a2_gscale = layer.w2_input_scale_quant
-            extra_expert_args = {
-                'g1_alphas': layer.g1_alphas,
-                'g2_alphas': layer.g2_alphas,
-                'out_dtype': x.dtype,
-                # Avoid confusion with a1_scale and a2_scale
-                # where are batch size related.
-                'a1_gscale': a1_gscale,
-                'a2_gscale': a2_gscale,
-            }
-            extra_prepare_args = {
-                'use_dp': layer.dp_size > 1,
-                'local_tokens': x.shape[0],
-                'a1_gscale': a1_gscale,
-            }
-            extra_finalize_args = {
-                'use_dp': layer.dp_size > 1,
-                'local_tokens': x.shape[0],
-            }
-
-            out = self.fused_experts(
-                hidden_states=x,
-                w1=layer.w13_weight,
-                w2=layer.w2_weight,
-                topk_weights=topk_weights,
-                topk_ids=topk_ids,
-                inplace=False,  # TODO(shuw): fix later, now output is high prec
+            out = flashinfer_fp4_cutlass_moe_forward(
+                self.fused_experts,
+                layer,
+                x,
+                topk_weights,
+                topk_ids,
                 activation=activation,
                 global_num_experts=global_num_experts,
                 expert_map=expert_map,
-                w1_scale=layer.w13_blockscale_swizzled,
-                w2_scale=layer.w2_blockscale_swizzled,
                 apply_router_weight_on_input=apply_router_weight_on_input,
-                extra_expert_args=extra_expert_args,
-                extra_prepare_args=extra_prepare_args,
-                extra_finalize_args=extra_finalize_args,
             )
         return out
New file: vllm/model_executor/layers/quantization/utils/flashinfer_fp4_moe.py (path inferred from the imports above):

@@ -0,0 +1,154 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+"""Utility helpers for NVFP4 + FlashInfer fused-MoE path"""
+from __future__ import annotations
+
+from typing import Optional
+
+import torch
+
+import vllm.envs as envs
+import vllm.model_executor.layers.fused_moe.modular_kernel as mk
+from vllm.logger import init_logger
+from vllm.model_executor.layers.fused_moe.config import FusedMoEParallelConfig
+from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe import (
+    FlashInferExperts, is_valid_flashinfer_cutlass_fused_moe)
+from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_prepare_finalize import (  # noqa: E501
+    FlashInferCutlassMoEPrepareAndFinalize)
+from vllm.platforms import current_platform
+
+logger = init_logger(__name__)
+
+__all__ = [
+    "is_flashinfer_fp4_cutlass_moe_available",
+    "reorder_w1w3_to_w3w1",
+    "build_flashinfer_fp4_cutlass_moe_kernel",
+    "flashinfer_fp4_cutlass_moe_forward",
+]
+
+
+def is_flashinfer_fp4_cutlass_moe_available() -> bool:
+    """Return ``True`` when FlashInfer CUTLASS NV-FP4 kernels can be used."""
+    return (envs.VLLM_USE_FLASHINFER_MOE_FP4 and current_platform.is_cuda()
+            and current_platform.is_device_capability(100))
+
+
+def reorder_w1w3_to_w3w1(weight: torch.Tensor,
+                         scale: torch.Tensor,
+                         dim: int = -2) -> tuple[torch.Tensor, torch.Tensor]:
+    """Re-order the concatenated `[w1, w3]` tensors to `[w3, w1]`"""
+    size = weight.size(dim)
+    assert size % 2 == 0, f"Expected even size in dim {dim}, got {size}"
+    half = size // 2
+
+    w1, w3 = weight.split(half, dim=dim)
+    s1, s3 = scale.split(half, dim=dim)
+
+    return (torch.cat([w3, w1], dim=dim).contiguous(),
+            torch.cat([s3, s1], dim=dim).contiguous())
+
+
+def build_flashinfer_fp4_cutlass_moe_kernel(
+        moe_parallel_config: FusedMoEParallelConfig,
+) -> mk.FusedMoEModularKernel:
+    """Create *and return* a FlashInfer CUTLASS fused-MoE modular kernel"""
+    experts = FlashInferExperts(
+        use_nvfp4_w4a4=True,
+        use_dp=moe_parallel_config.dp_size > 1,
+        ep_rank=moe_parallel_config.ep_rank,
+        ep_size=moe_parallel_config.ep_size,
+        tp_rank=moe_parallel_config.tp_rank,
+        tp_size=moe_parallel_config.tp_size,
+    )
+    logger.debug_once("FlashInferExperts (util)")
+    return mk.FusedMoEModularKernel(
+        FlashInferCutlassMoEPrepareAndFinalize(quant_dtype=torch.uint8),
+        experts,
+    )
+
+
+def flashinfer_fp4_cutlass_moe_forward(
+    fused_experts: mk.FusedMoEModularKernel,
+    layer: torch.nn.Module,
+    x: torch.Tensor,
+    topk_weights: torch.Tensor,
+    topk_ids: torch.Tensor,
+    activation: str,
+    global_num_experts: int,
+    expert_map: Optional[torch.Tensor],
+    apply_router_weight_on_input: bool,
+) -> torch.Tensor:
+    """Common forward wrapper for FlashInfer NV-FP4 fused-MoE"""
+
+    assert is_valid_flashinfer_cutlass_fused_moe(
+        x, layer.w13_weight,
+        layer.w2_weight), ("FlashInfer CUTLASS fused-MoE not applicable!")
+
+    a1_gscale = layer.w13_input_scale_quant
+    a2_gscale = layer.w2_input_scale_quant
+
+    extra_expert_args = {
+        "g1_alphas": layer.g1_alphas,
+        "g2_alphas": layer.g2_alphas,
+        # Avoid confusion with a1_scale and a2_scale
+        # where are batch size related.
+        "a1_gscale": a1_gscale,
+        "a2_gscale": a2_gscale,
+        "out_dtype": x.dtype,
+    }
+    extra_prepare_args = {
+        "use_dp": layer.dp_size > 1,
+        "local_tokens": x.shape[0],
+        "a1_gscale": a1_gscale,
+    }
+    extra_finalize_args = {
+        "use_dp": layer.dp_size > 1,
+        "local_tokens": x.shape[0],
+    }
+
+    return fused_experts(
+        hidden_states=x,
+        w1=layer.w13_weight,
+        w2=layer.w2_weight,
+        topk_weights=topk_weights,
+        topk_ids=topk_ids,
+        inplace=False,  # TODO(shuw): fix later, now output is high prec
+        activation=activation,
+        global_num_experts=global_num_experts,
+        expert_map=expert_map,
+        w1_scale=layer.w13_blockscale_swizzled,
+        w2_scale=layer.w2_blockscale_swizzled,
+        apply_router_weight_on_input=apply_router_weight_on_input,
+        extra_expert_args=extra_expert_args,
+        extra_prepare_args=extra_prepare_args,
+        extra_finalize_args=extra_finalize_args,
+    )
+
+
+def select_nvfp4_gemm_impl(
+    allow_flashinfer_cutlass: bool,
+    moe,  # FusedMoEConfig
+    logger):
+    """Return a GEMM *experts* implementation for NV-FP4 fused-MoE layers"""
+
+    # lazy import
+    from vllm.distributed import get_ep_group
+
+    all2all_manager = get_ep_group().device_communicator.all2all_manager
+    assert all2all_manager is not None
+
+    if allow_flashinfer_cutlass:
+        logger.debug_once("Using FlashInferExperts")
+        return FlashInferExperts(
+            use_nvfp4_w4a4=True,
+            use_dp=moe.moe_parallel_config.dp_size > 1,
+            ep_rank=moe.moe_parallel_config.ep_rank,
+            ep_size=moe.moe_parallel_config.ep_size,
+            tp_rank=moe.moe_parallel_config.tp_rank,
+            tp_size=moe.moe_parallel_config.tp_size,
+        )
+
+    # native cutlass experts currently don't support DP; TP case won't call this
+    raise ValueError(
+        "CutlassExpertsFp4 doesn't support DP. Use flashinfer CUTLASS "
+        "Fused MoE backend instead (set VLLM_USE_FLASHINFER_MOE_FP4=1)")
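For intuition, reorder_w1w3_to_w3w1 only swaps the two stacked halves of the fused GEMM1 weight (and its block scales) along the given dimension; per the comments above, the FlashInfer CUTLASS kernel expects the halves in [w3, w1] order rather than vLLM's usual [w1, w3]. A minimal CPU-only illustration of the same split-and-concat, with toy shapes and values that stand in for the real packed NVFP4 tensors:

import torch

# Toy stand-ins for one expert's fused gate/up weight and its block scales.
w13 = torch.cat([torch.full((2, 4), 1.0),   # "w1" half
                 torch.full((2, 4), 3.0)],  # "w3" half
                dim=-2)
scale = torch.cat([torch.full((2, 1), 0.1), torch.full((2, 1), 0.3)], dim=-2)

def reorder_w1w3_to_w3w1(weight, scale, dim=-2):
    # Same logic as the new helper: split the stacked halves and swap them.
    half = weight.size(dim) // 2
    w1, w3 = weight.split(half, dim=dim)
    s1, s3 = scale.split(half, dim=dim)
    return (torch.cat([w3, w1], dim=dim).contiguous(),
            torch.cat([s3, s1], dim=dim).contiguous())

w31, s31 = reorder_w1w3_to_w3w1(w13, scale)
print(w31[:2, 0])  # tensor([3., 3.]) -- the "w3" half now comes first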
New file: vllm/model_executor/layers/quantization/utils/nvfp4_moe_support.py (path inferred from the imports above):

@@ -0,0 +1,59 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+from dataclasses import dataclass
+
+import vllm.envs as envs
+from vllm.logger import init_logger
+from vllm.model_executor.layers.quantization.utils.flashinfer_fp4_moe import (
+    is_flashinfer_fp4_cutlass_moe_available)
+from vllm.model_executor.layers.quantization.utils.marlin_utils_fp4 import (
+    is_fp4_marlin_supported)
+from vllm.model_executor.layers.quantization.utils.quant_utils import (
+    cutlass_fp4_supported)
+
+__all__ = ["detect_nvfp4_moe_support", "NvFp4Support"]
+
+_logger = init_logger(__name__)
+
+
+@dataclass(frozen=True)
+class NvFp4Support:
+    """Result container for NV-FP4 capability probing."""
+
+    cutlass_supported: bool
+    allow_flashinfer_cutlass: bool
+    use_marlin: bool
+
+
+def detect_nvfp4_moe_support(class_name: str = "") -> NvFp4Support:
+    """Detect platform support for NV-FP4 fused-MoE path"""
+    cutlass_supported = cutlass_fp4_supported()
+
+    allow_flashinfer = (cutlass_supported
+                        and is_flashinfer_fp4_cutlass_moe_available())
+
+    if allow_flashinfer:
+        _logger.info_once("Using FlashInfer kernels for %s.", class_name
+                          or "NVFP4 path")
+    else:
+        if envs.VLLM_USE_FLASHINFER_MOE_FP4:
+            _logger.warning_once(
+                "FlashInfer kernels unavailable for %s on current platform.",
+                class_name or "NVFP4 path",
+            )
+
+    use_marlin = False
+    if not cutlass_supported:
+        if is_fp4_marlin_supported():
+            use_marlin = True
+            _logger.info_once("Falling back to Marlin FP4 MoE kernel.")
+        else:
+            raise ValueError(
+                "Current platform does not support NVFP4 quantization. "
+                "Please use Blackwell GPUs or enable FlashInfer.")
+
+    return NvFp4Support(
+        cutlass_supported=cutlass_supported,
+        allow_flashinfer_cutlass=allow_flashinfer,
+        use_marlin=use_marlin,
+    )
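A rough sketch of how both quantization methods now consume this probe, mirroring the __init__ changes above (the class-name argument is only used for logging; on an unsupported platform the helper raises unless Marlin FP4 is available, so this needs a CUDA build of vLLM to run):

# Sketch: consuming detect_nvfp4_moe_support() the way the updated
# __init__ methods in this PR do.
from vllm.model_executor.layers.quantization.utils.nvfp4_moe_support import (
    detect_nvfp4_moe_support)

support = detect_nvfp4_moe_support("CompressedTensorsW4A4MoeMethod")
print(support.cutlass_supported)         # CUTLASS NVFP4 kernels usable?
print(support.allow_flashinfer_cutlass)  # FlashInfer CUTLASS path enabled?
print(support.use_marlin)                # falling back to Marlin FP4?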