mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-10 12:05:48 +08:00
Signed-off-by: Bill Nell <bnell@redhat.com> Co-authored-by: Michael Goin <mgoin64@gmail.com>
499 lines
17 KiB
Python
499 lines
17 KiB
Python
# SPDX-License-Identifier: Apache-2.0
|
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
|
from dataclasses import dataclass
|
|
from typing import Optional, Union
|
|
|
|
import torch
|
|
|
|
# Fused experts and PrepareFinalize imports
|
|
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
|
|
from vllm.model_executor.layers.fused_moe.batched_deep_gemm_moe import (
|
|
BatchedDeepGemmExperts)
|
|
from vllm.model_executor.layers.fused_moe.batched_triton_or_deep_gemm_moe import ( # noqa: E501
|
|
BatchedTritonOrDeepGemmExperts)
|
|
from vllm.model_executor.layers.fused_moe.config import (FusedMoEConfig,
|
|
FusedMoEQuantConfig)
|
|
from vllm.model_executor.layers.fused_moe.deep_gemm_moe import DeepGemmExperts
|
|
from vllm.model_executor.layers.fused_moe.fused_batched_moe import (
|
|
BatchedTritonExperts, NaiveBatchedExperts)
|
|
from vllm.model_executor.layers.fused_moe.layer import (FusedMoEMethodBase,
|
|
TritonExperts)
|
|
from vllm.model_executor.layers.fused_moe.prepare_finalize import (
|
|
MoEPrepareAndFinalizeNoEP)
|
|
from vllm.model_executor.layers.fused_moe.triton_deep_gemm_moe import (
|
|
TritonOrDeepGemmExperts)
|
|
from vllm.model_executor.layers.quantization.utils.quant_utils import (
|
|
cutlass_fp4_supported)
|
|
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
|
|
cutlass_fp8_supported)
|
|
from vllm.platforms import current_platform
|
|
from vllm.utils import has_deep_ep, has_deep_gemm, has_pplx
|
|
from vllm.utils.deep_gemm import is_deep_gemm_supported
|
|
from vllm.utils.flashinfer import has_flashinfer_cutlass_fused_moe
|
|
|
|
|
|
@dataclass
class PrepareFinalizeInfo:
    """Capabilities recorded for one registered prepare/finalize class.

    Instances are created by ``register_prepare_and_finalize`` and looked up
    with ``prepare_finalize_info``.
    """

    # Activation layout the class was registered with (standard or batched).
    activation_format: mk.FusedMoEActivationFormat
    # Quantization dtypes the class was registered for; entries are torch
    # dtypes or string tags such as "nvfp4".
    supported_dtypes: list[Union[torch.dtype, str]]
    # Registered flag for block-quantization support.
    blocked_quantization_support: bool
    # Backend name given at registration, or None when no backend applies.
    backend: Optional[str]
    # NOTE(review): presumably whether router weights may be applied to the
    # input activations -- confirm against the modular kernel.
    supports_apply_weight_on_input: bool = True
|
|
|
|
|
|
@dataclass
class ExpertInfo:
    """Capabilities recorded for one registered fused-experts class.

    Instances are created by ``register_experts`` and looked up with
    ``expert_info``.
    """

    # Activation layout the class was registered with (standard or batched).
    activation_format: mk.FusedMoEActivationFormat
    # Quantization dtypes the class was registered for; entries are torch
    # dtypes or string tags such as "nvfp4".
    supported_dtypes: list[Union[torch.dtype, str]]
    # Registered flag for block-quantization support.
    blocked_quantization_support: bool
    # Registered flag for chunked execution support.
    supports_chunking: bool
    # Registered flag for expert-map support.
    supports_expert_map: bool
    # NOTE(review): presumably requires the prepare/finalize stage to use the
    # same quantization as the experts -- confirm with the test harness.
    needs_matching_quant: bool = False
    # Set for every implementation registered under the DeepGEMM guard.
    needs_deep_gemm: bool = False
|
|
|
|
|
|
# Registries populated by register_prepare_and_finalize() and
# register_experts(); the keys/elements are the registered classes.
PREPARE_FINALIZE_INFO: dict[mk.FusedMoEPrepareAndFinalize,
                            PrepareFinalizeInfo] = {}
EXPERT_INFO: dict[mk.FusedMoEPermuteExpertsUnpermute, ExpertInfo] = {}
MK_ALL_PREPARE_FINALIZE_TYPES: list[mk.FusedMoEPrepareAndFinalize] = []
MK_MULTI_GPU_PREPARE_FINALIZE_TYPES: list[mk.FusedMoEPrepareAndFinalize] = []
MK_SINGLE_GPU_PREPARE_FINALIZE_TYPES: list[mk.FusedMoEPrepareAndFinalize] = []
MK_FUSED_EXPERT_TYPES: list[mk.FusedMoEPermuteExpertsUnpermute] = []

# Short aliases for the two activation formats used by the registrations.
standard_format = mk.FusedMoEActivationFormat.Standard
batched_format = mk.FusedMoEActivationFormat.BatchedExperts

# Reusable dtype lists passed as `supported_dtypes` when registering.
common_float_types: list[Union[torch.dtype, str]] = [
    torch.float8_e4m3fn,
    torch.bfloat16,
    torch.float16,
    torch.float32,
]
common_float_and_int_types = [*common_float_types, torch.int8]
nv_fp4_types = ["nvfp4"]
fp8_types = [torch.float8_e4m3fn]
|
|
|
|
|
|
def register_prepare_and_finalize(
    kind,
    activation_format: mk.FusedMoEActivationFormat,
    supported_dtypes: list[Union[torch.dtype, str]],
    blocked_quantization_support: bool,
    backend: Optional[str],
    force_multigpu: bool = False,
    supports_apply_weight_on_input: bool = True,
):
    """Record a prepare/finalize class and its capabilities.

    Stores a ``PrepareFinalizeInfo`` for ``kind`` in ``PREPARE_FINALIZE_INFO``
    and appends ``kind`` to ``MK_ALL_PREPARE_FINALIZE_TYPES`` plus exactly one
    of the multi-GPU / single-GPU lists.

    Args:
        kind: The prepare/finalize class being registered (registry key).
        activation_format: Activation layout associated with ``kind``.
        supported_dtypes: Quantization dtypes associated with ``kind``.
        blocked_quantization_support: Block-quantization support flag.
        backend: Backend name, or None. Any non-None value classifies
            ``kind`` as multi-GPU.
        force_multigpu: Classify ``kind`` as multi-GPU even without a backend.
        supports_apply_weight_on_input: Stored verbatim in the info record.

    Raises:
        AssertionError: If ``kind`` has already been registered.
    """
    # The registries are only mutated in place, never rebound, so no
    # `global` declarations are required.
    assert kind not in PREPARE_FINALIZE_INFO

    PREPARE_FINALIZE_INFO[kind] = PrepareFinalizeInfo(
        activation_format=activation_format,
        supported_dtypes=supported_dtypes,
        blocked_quantization_support=blocked_quantization_support,
        backend=backend,
        supports_apply_weight_on_input=supports_apply_weight_on_input,
    )
    MK_ALL_PREPARE_FINALIZE_TYPES.append(kind)

    is_multi_gpu = backend is not None or force_multigpu
    bucket = (MK_MULTI_GPU_PREPARE_FINALIZE_TYPES
              if is_multi_gpu else MK_SINGLE_GPU_PREPARE_FINALIZE_TYPES)
    bucket.append(kind)
|
|
|
|
|
|
def register_experts(
    kind,
    activation_format: mk.FusedMoEActivationFormat,
    supported_dtypes: list[Union[torch.dtype, str]],
    blocked_quantization_support: bool,
    supports_chunking: bool,
    supports_expert_map: bool,
    needs_matching_quant: bool = False,
    needs_deep_gemm: bool = False,
):
    """Record a fused-experts class and its capabilities.

    Stores an ``ExpertInfo`` for ``kind`` in ``EXPERT_INFO`` and appends
    ``kind`` to ``MK_FUSED_EXPERT_TYPES``.

    Args:
        kind: The fused-experts class being registered (registry key).
        activation_format: Activation layout associated with ``kind``.
        supported_dtypes: Quantization dtypes associated with ``kind``.
        blocked_quantization_support: Block-quantization support flag.
        supports_chunking: Chunking support flag.
        supports_expert_map: Expert-map support flag.
        needs_matching_quant: Stored verbatim in the info record.
        needs_deep_gemm: Stored verbatim in the info record.

    Raises:
        AssertionError: If ``kind`` has already been registered.
    """
    # The registries are only mutated in place, never rebound, so no
    # `global` declarations are required.
    assert kind not in EXPERT_INFO

    EXPERT_INFO[kind] = ExpertInfo(
        activation_format=activation_format,
        supported_dtypes=supported_dtypes,
        blocked_quantization_support=blocked_quantization_support,
        supports_chunking=supports_chunking,
        supports_expert_map=supports_expert_map,
        needs_matching_quant=needs_matching_quant,
        needs_deep_gemm=needs_deep_gemm,
    )
    MK_FUSED_EXPERT_TYPES.append(kind)
|
|
|
|
|
|
def prepare_finalize_info(kind) -> PrepareFinalizeInfo:
    """Return the capabilities registered for prepare/finalize class ``kind``.

    Raises:
        AssertionError: If ``kind`` was never registered.
    """
    registered = PREPARE_FINALIZE_INFO.get(kind)
    assert registered is not None
    return registered
|
|
|
|
|
|
def expert_info(kind) -> ExpertInfo:
    """Return the capabilities registered for fused-experts class ``kind``.

    Raises:
        AssertionError: If ``kind`` was never registered.
    """
    registered = EXPERT_INFO.get(kind)
    assert registered is not None
    return registered
|
|
|
|
|
|
# Implementations imported unconditionally at the top of the file are always
# registered.
register_prepare_and_finalize(
    MoEPrepareAndFinalizeNoEP,
    activation_format=standard_format,
    supported_dtypes=common_float_types,
    blocked_quantization_support=True,
    backend=None,
)

register_experts(
    BatchedTritonExperts,
    activation_format=batched_format,
    supported_dtypes=common_float_types,
    blocked_quantization_support=True,
    supports_chunking=False,
    supports_expert_map=False,
    needs_matching_quant=True,
)

register_experts(
    TritonExperts,
    activation_format=standard_format,
    supported_dtypes=common_float_and_int_types,
    blocked_quantization_support=True,
    supports_chunking=True,
    supports_expert_map=True,
    needs_matching_quant=True,
)

register_experts(
    NaiveBatchedExperts,
    activation_format=batched_format,
    supported_dtypes=common_float_and_int_types,
    blocked_quantization_support=True,
    supports_chunking=False,
    supports_expert_map=True,
)
|
|
|
|
# Disable on blackwell for now
# DeepEP prepare/finalize classes are imported and registered only when
# DeepEP is available and has_device_capability(100) is false.
if has_deep_ep() and not current_platform.has_device_capability(100):
    from vllm.model_executor.layers.fused_moe.deepep_ht_prepare_finalize import (  # noqa: E501
        DeepEPHTPrepareAndFinalize)
    from vllm.model_executor.layers.fused_moe.deepep_ll_prepare_finalize import (  # noqa: E501
        DeepEPLLPrepareAndFinalize)

    # High-throughput variant: standard activation format.
    register_prepare_and_finalize(
        DeepEPHTPrepareAndFinalize,
        activation_format=standard_format,
        supported_dtypes=common_float_types,
        blocked_quantization_support=True,
        backend="deepep_high_throughput",
    )

    # Low-latency variant: batched activation format.
    register_prepare_and_finalize(
        DeepEPLLPrepareAndFinalize,
        activation_format=batched_format,
        supported_dtypes=common_float_types,
        blocked_quantization_support=True,
        backend="deepep_low_latency",
    )
|
|
|
|
# The pplx prepare/finalize class is imported and registered only when pplx
# is available.
if has_pplx():
    from vllm.model_executor.layers.fused_moe.pplx_prepare_finalize import (
        PplxPrepareAndFinalize)

    register_prepare_and_finalize(
        PplxPrepareAndFinalize,
        activation_format=batched_format,
        supported_dtypes=common_float_and_int_types,
        blocked_quantization_support=True,
        backend="pplx",
    )
|
|
|
|
# FlashInfer CUTLASS fused-MoE classes are imported and registered only when
# the FlashInfer kernels are available and has_device_capability(100) is true.
if (has_flashinfer_cutlass_fused_moe()
        and current_platform.has_device_capability(100)):
    from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe import (  # noqa: E501
        FlashInferExperts)
    from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_prepare_finalize import (  # noqa: E501
        FlashInferCutlassMoEPrepareAndFinalize)

    # No named backend, so force_multigpu is what places this class in the
    # multi-GPU list.
    register_prepare_and_finalize(
        FlashInferCutlassMoEPrepareAndFinalize,
        standard_format,
        nv_fp4_types,
        blocked_quantization_support=True,
        backend=None,
        force_multigpu=True,
        supports_apply_weight_on_input=False,
    )

    register_experts(
        FlashInferExperts,
        standard_format,
        nv_fp4_types,
        blocked_quantization_support=True,
        supports_chunking=True,
        # Note: this is a hack to get it to run for now
        supports_expert_map=True,
    )
else:
    # Bind BOTH names so the `==` comparisons in make_prepare_finalize() and
    # make_fused_experts() evaluate to False instead of raising NameError
    # when FlashInfer is unavailable.
    FlashInferCutlassMoEPrepareAndFinalize = None
    FlashInferExperts = None
|
|
|
|
# DeepGEMM-backed expert implementations are registered only when DeepGEMM
# is both installed and supported. All of them set needs_deep_gemm=True.
if has_deep_gemm() and is_deep_gemm_supported():
    register_experts(
        BatchedDeepGemmExperts,
        batched_format,
        fp8_types,
        blocked_quantization_support=True,
        supports_chunking=False,
        supports_expert_map=False,
        needs_matching_quant=False,
        needs_deep_gemm=True,
    )
    # A stray trailing comma after this call used to turn the statement into
    # a discarded one-element tuple; it has been removed.
    register_experts(
        DeepGemmExperts,
        standard_format,
        fp8_types,
        blocked_quantization_support=True,
        supports_chunking=True,
        supports_expert_map=True,
        needs_matching_quant=False,
        needs_deep_gemm=True,
    )
    register_experts(
        BatchedTritonOrDeepGemmExperts,
        batched_format,
        common_float_and_int_types,
        blocked_quantization_support=True,
        supports_chunking=False,
        supports_expert_map=False,
        needs_matching_quant=True,
        needs_deep_gemm=True,
    )
    register_experts(
        TritonOrDeepGemmExperts,
        standard_format,
        common_float_and_int_types,
        blocked_quantization_support=True,
        supports_chunking=True,
        supports_expert_map=True,
        needs_matching_quant=True,
        needs_deep_gemm=True,
    )
|
|
|
|
# CUTLASS fp8 expert implementations are imported and registered only when
# cutlass_fp8_supported() is true. Neither is registered with
# block-quantization or expert-map support.
if cutlass_fp8_supported():
    from vllm.model_executor.layers.fused_moe import (CutlassBatchedExpertsFp8,
                                                      CutlassExpertsFp8)
    register_experts(
        CutlassExpertsFp8,
        standard_format,
        fp8_types,
        blocked_quantization_support=False,
        supports_chunking=True,
        supports_expert_map=False,
    )
    register_experts(
        CutlassBatchedExpertsFp8,
        batched_format,
        fp8_types,
        blocked_quantization_support=False,
        supports_chunking=False,
        supports_expert_map=False,
    )
else:
    # Bind the names so the `==` comparisons in make_fused_experts() evaluate
    # to False instead of raising NameError when CUTLASS fp8 is unsupported
    # (same pattern as FlashInferCutlassMoEPrepareAndFinalize = None).
    CutlassBatchedExpertsFp8 = None
    CutlassExpertsFp8 = None
|
|
|
|
# The CUTLASS fp4 expert implementation is imported and registered only when
# cutlass_fp4_supported() is true.
if cutlass_fp4_supported():
    from vllm.model_executor.layers.fused_moe.cutlass_moe import (
        CutlassExpertsFp4)
    register_experts(
        CutlassExpertsFp4,
        standard_format,
        nv_fp4_types,
        blocked_quantization_support=True,
        supports_chunking=True,
        supports_expert_map=False,
    )
else:
    # Bind the name so the `==` comparison in make_fused_experts() evaluates
    # to False instead of raising NameError when CUTLASS fp4 is unsupported
    # (same pattern as FlashInferCutlassMoEPrepareAndFinalize = None).
    CutlassExpertsFp4 = None
|
|
|
|
# Quantization configurations offered to the modular-kernel tests. The
# leading None entry stands for "no quantization config".
MK_QUANT_CONFIGS = [
    None,
    # The four non-blocked fp8 combinations of weight and activation
    # granularity, in (per_out_ch_quant, per_act_token_quant) order.
    *[
        FusedMoEQuantConfig(quant_dtype=torch.float8_e4m3fn,
                            per_out_ch_quant=per_out_ch,
                            per_act_token_quant=per_act_token,
                            block_shape=None)
        for per_out_ch, per_act_token in (
            # per-channel / per-column weights and per-tensor activations
            (True, False),
            # per-channel / per-column weights and per-token activations
            (True, True),
            # per-tensor weights and per-tensor activations
            (False, False),
            # per-tensor weights and per-token activations
            (False, True),
        )
    ],
    # block-quantized weights and 128 block per-token activations
    FusedMoEQuantConfig(quant_dtype=torch.float8_e4m3fn,
                        per_out_ch_quant=False,
                        per_act_token_quant=False,
                        block_shape=[128, 128]),
    # TODO (varun) : Should we test the following combinations ?
    # block-quantized weights and per-token activations
    # block-quantized weights and per-tensor activations
]

# The nvfp4 configuration is offered when either fp4-capable implementation
# reports availability.
if cutlass_fp4_supported() or has_flashinfer_cutlass_fused_moe():
    MK_QUANT_CONFIGS.append(
        FusedMoEQuantConfig(quant_dtype="nvfp4",
                            per_out_ch_quant=False,
                            per_act_token_quant=False,
                            block_shape=None))
|
|
|
|
|
|
def _make_gscale(num_experts: int) -> torch.Tensor:
    """Return a float32 vector of ones with one entry per expert.

    The tensor is allocated on the current CUDA device.
    """
    cuda_device = torch.cuda.current_device()
    shape = (num_experts, )
    return torch.ones(shape, device=cuda_device, dtype=torch.float32)
|
|
|
|
|
|
def make_prepare_finalize(
    prepare_finalize_type: mk.FusedMoEPrepareAndFinalize,
    backend: Optional[str],
    moe: FusedMoEConfig,
) -> mk.FusedMoEPrepareAndFinalize:
    """Build a prepare/finalize object for the given type and backend.

    Args:
        prepare_finalize_type: The requested prepare/finalize class. Only
            consulted when no real backend is selected.
        backend: Backend name. Any value other than None or "naive" delegates
            construction to ``FusedMoEMethodBase``.
        moe: MoE configuration used to construct the object.

    Returns:
        The constructed prepare/finalize object.

    Raises:
        AssertionError: If a backend is selected but ``FusedMoEMethodBase``
            does not produce an object.
    """
    # A real (non-naive) backend: let the layer code choose the object.
    if backend is not None and backend != "naive":
        made = FusedMoEMethodBase._maybe_make_prepare_finalize(moe)
        assert made is not None
        return made

    # FlashInfer has no named backend, so it is selected by type.
    if prepare_finalize_type == FlashInferCutlassMoEPrepareAndFinalize:
        return FlashInferCutlassMoEPrepareAndFinalize(
            use_dp=moe.moe_parallel_config.dp_size > 1,
            a1_gscale=_make_gscale(moe.num_local_experts),
        )

    # Everything else uses the no-expert-parallel implementation.
    return MoEPrepareAndFinalizeNoEP()
|
|
|
|
|
|
def _slice(rank: int, num_local_experts: int, t: torch.Tensor) -> torch.Tensor:
|
|
s = rank * num_local_experts
|
|
e = s + num_local_experts
|
|
return t[s:e]
|
|
|
|
|
|
def make_fused_experts(
    fused_experts_type: mk.FusedMoEPermuteExpertsUnpermute,
    moe: FusedMoEConfig,
    num_dispatchers: int,
    w1_gs: Optional[torch.Tensor],
    w2_gs: Optional[torch.Tensor],
) -> mk.FusedMoEPermuteExpertsUnpermute:
    """Instantiate the requested fused-experts implementation.

    Builds the constructor keyword arguments each implementation expects
    from ``moe``, prints them, and returns the constructed object.

    Args:
        fused_experts_type: The fused-experts class to instantiate.
        moe: MoE configuration supplying quantization and sizing parameters.
        num_dispatchers: Forwarded to implementations that take it.
        w1_gs: Global scales for the first weight; required by the nvfp4
            implementations (CutlassExpertsFp4, FlashInferExperts).
        w2_gs: Global scales for the second weight; same requirement.

    Returns:
        The constructed fused-experts object.

    Raises:
        AssertionError: If an nvfp4 implementation is requested without both
            ``w1_gs`` and ``w2_gs``.
        RuntimeError: If ``fused_experts_type`` matches no known class.
    """
    is_fp8 = moe.quant_dtype == torch.float8_e4m3fn

    # Keyword-argument groups shared by several implementations.
    batch_kwargs = {
        "max_num_tokens": moe.max_num_tokens,
        "num_dispatchers": num_dispatchers,
    }
    quant_kwargs = {
        "use_fp8_w8a8": is_fp8,
        "use_int8_w8a8": False,
        "use_int8_w8a16": False,
        "use_int4_w4a16": False,
        "block_shape": moe.block_shape,
        "per_act_token_quant": moe.per_act_token_quant,
    }
    deepgemm_kwargs = {"allow_deep_gemm": has_deep_gemm()}

    if fused_experts_type == BatchedDeepGemmExperts:
        kwargs = {
            **batch_kwargs,
            "block_shape": moe.block_shape,
            "per_act_token_quant": moe.per_act_token_quant,
        }
        print(f"Making BatchedDeepGemmExperts {kwargs} ...")
        return BatchedDeepGemmExperts(**kwargs)

    if fused_experts_type == BatchedTritonExperts:
        kwargs = {**batch_kwargs, **quant_kwargs}
        print(f"Making BatchedTritonExperts {kwargs} ...")
        return BatchedTritonExperts(**kwargs)

    if fused_experts_type == BatchedTritonOrDeepGemmExperts:
        kwargs = {**batch_kwargs, **quant_kwargs, **deepgemm_kwargs}
        print(f"Making BatchedTritonOrDeepGemmExperts {kwargs} ...")
        return BatchedTritonOrDeepGemmExperts(**kwargs)

    if fused_experts_type == DeepGemmExperts:
        print("Making DeepGemmExperts () ...")
        return DeepGemmExperts()

    if fused_experts_type == TritonExperts:
        kwargs = quant_kwargs
        print(f"Making TritonExperts {kwargs} ...")
        return TritonExperts(**kwargs)

    if fused_experts_type == TritonOrDeepGemmExperts:
        kwargs = {**quant_kwargs, **deepgemm_kwargs}
        print(f"Making TritonOrDeepGemmExperts {kwargs} ...")
        return TritonOrDeepGemmExperts(**kwargs)

    if fused_experts_type == NaiveBatchedExperts:
        kwargs = {**batch_kwargs, **quant_kwargs}
        print(f"Making NaiveBatchedExperts {kwargs} ...")
        return NaiveBatchedExperts(**kwargs)

    if fused_experts_type == CutlassExpertsFp8:
        kwargs = {
            "out_dtype": moe.in_dtype,
            "per_act_token_quant": moe.per_act_token_quant,
            "per_out_ch_quant": moe.per_out_ch_quant,
            "block_shape": moe.block_shape,
        }
        print(f"Making CutlassExpertsFp8 {kwargs} ...")
        return CutlassExpertsFp8(**kwargs)

    if fused_experts_type == CutlassBatchedExpertsFp8:
        kwargs = {
            "max_experts_per_worker": moe.num_local_experts,
            "num_dispatchers": num_dispatchers,
            "out_dtype": moe.in_dtype,
            "per_act_token_quant": moe.per_act_token_quant,
            "per_out_ch_quant": moe.per_out_ch_quant,
            "block_shape": moe.block_shape,
        }
        print(f"Making CutlassBatchedExpertsFp8 {kwargs} ...")
        return CutlassBatchedExpertsFp8(**kwargs)

    if fused_experts_type == CutlassExpertsFp4:
        assert w1_gs is not None and w2_gs is not None
        local_experts = moe.num_local_experts
        dp_rank = moe.moe_parallel_config.dp_rank
        kwargs = {
            # Reciprocal global scales, restricted to this rank's block.
            "g1_alphas": _slice(dp_rank, local_experts, (1 / w1_gs)),
            "g2_alphas": _slice(dp_rank, local_experts, (1 / w2_gs)),
            "a1_gscale": _make_gscale(local_experts),
            "a2_gscale": _make_gscale(local_experts),
            "max_experts_per_worker": local_experts,
            "out_dtype": moe.in_dtype,
            "per_act_token_quant": moe.per_act_token_quant,
            "per_out_ch_quant": moe.per_out_ch_quant,
            "block_shape": moe.block_shape,
            "num_dispatchers": num_dispatchers,
        }
        print(f"Making CutlassExpertsFp4 {kwargs} ...")
        return CutlassExpertsFp4(**kwargs)

    if fused_experts_type == FlashInferExperts:
        assert w1_gs is not None and w2_gs is not None
        local_experts = moe.num_local_experts
        dp_rank = moe.moe_parallel_config.dp_rank
        kwargs = {
            # Reciprocal global scales, restricted to this rank's block.
            "g1_alphas": _slice(dp_rank, local_experts, (1 / w1_gs)),
            "g2_alphas": _slice(dp_rank, local_experts, (1 / w2_gs)),
            "a1_gscale": _make_gscale(local_experts),
            "a2_gscale": _make_gscale(local_experts),
            "out_dtype": moe.in_dtype,
            "quant_dtype": "nvfp4",
            "ep_rank": moe.ep_rank,
            "ep_size": moe.ep_size,
            "tp_rank": moe.tp_rank,
            "tp_size": moe.tp_size,
        }
        print(f"Making FlashInferExperts {kwargs} ...")
        return FlashInferExperts(**kwargs)

    raise RuntimeError(f"Unknown fused experts type: {fused_experts_type}")
|