[NVIDIA] Add SM100 Flashinfer Cutlass MoE fp8 backend (#22357)

Signed-off-by: Amir Klein <203507526+amirkl94@users.noreply.github.com>

parent 21dce80ea9    commit a38b8af4c3
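The commit routes FP8 MoE layers through FlashInfer when VLLM_USE_FLASHINFER_MOE_FP8 is set, and chooses between the new CUTLASS path and the TensorRT-LLM path from VLLM_FLASHINFER_MOE_BACKEND (see get_flashinfer_moe_backend in the flashinfer_utils.py hunk near the end). Below is a minimal standalone sketch of that selection logic, assuming the "throughput"/"latency" values shown in the diff; pick_flashinfer_moe_backend is a hypothetical stand-in for illustration, not vLLM API.

import os

# Hypothetical stand-in for get_flashinfer_moe_backend(): "throughput"
# selects the CUTLASS kernels added in this commit, "latency" the
# TensorRT-LLM kernels. The default value is assumed here.
def pick_flashinfer_moe_backend() -> str:
    backend = os.environ.get("VLLM_FLASHINFER_MOE_BACKEND", "throughput")
    if backend == "throughput":
        return "CUTLASS"
    if backend == "latency":
        return "TensorRT-LLM"
    raise ValueError(f"Unknown flashinfer moe backend: {backend}, "
                     f"expected one of ['throughput', 'latency']")


if __name__ == "__main__":
    # VLLM_USE_FLASHINFER_MOE_FP8=1 must also be set for vLLM to consult
    # this backend choice at all (see Fp8MoEMethod.__init__ below).
    print(pick_flashinfer_moe_backend())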
@@ -630,6 +630,7 @@ steps:
     - vllm/model_executor/layers/fused_moe/cutlass_moe.py
     - vllm/model_executor/layers/fused_moe/flashinfer_cutlass_moe.py
     - vllm/model_executor/layers/fused_moe/flashinfer_cutlass_prepare_finalize.py
+    - vllm/model_executor/layers/quantization/utils/flashinfer_utils.py
     - vllm/v1/attention/backends/flashinfer.py
     - vllm/compilation/fusion.py
     - vllm/compilation/fusion_attn.py
@@ -650,6 +651,7 @@ steps:
     # Fusion
     - pytest -v -s tests/compile/test_fusion_all_reduce.py
     - pytest -v -s tests/compile/test_fusion_attn.py::test_attention_quant_pattern
+    - pytest -v -s tests/kernels/moe/test_flashinfer.py

 ##### 1 GPU test #####
 ##### multi gpus test #####
tests/kernels/moe/test_flashinfer.py (new file, 248 lines)
@@ -0,0 +1,248 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from dataclasses import dataclass

import pytest
import torch

from vllm.config import ParallelConfig, VllmConfig, set_current_vllm_config
from vllm.model_executor.layers.fused_moe.fused_moe import fused_experts
from vllm.model_executor.layers.fused_moe.layer import FusedMoE
from vllm.model_executor.layers.quantization.utils.flashinfer_utils import (
    apply_flashinfer_per_tensor_scale_fp8, flashinfer_cutlass_moe_fp8,
    register_moe_scaling_factors, rotate_flashinfer_fp8_moe_weights,
    swap_w13_to_w31)
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
    input_to_float8)
from vllm.model_executor.models.llama4 import Llama4MoE
from vllm.platforms import current_platform
from vllm.utils.flashinfer import has_flashinfer_cutlass_fused_moe

if not has_flashinfer_cutlass_fused_moe(
) or not current_platform.has_device_capability(100):
    pytest.skip("Requires flashinfer_cutlass_fused_moe and nvfp4 support",
                allow_module_level=True)

NUM_EXPERTS = [16]
TOP_KS = [1]

MNK_FACTORS = [
    (256, 8192, 5120),
    (256, 4096, 5120),
    (127, 8192, 5120),
    (127, 4096, 5120),
    (10, 8192, 5120),
    (10, 4096, 5120),
    (1, 8192, 5120),
    (1, 4096, 5120),
]

vllm_config = VllmConfig(parallel_config=ParallelConfig(
    pipeline_parallel_size=1))
vllm_config.scheduler_config.max_num_seqs = 128
vllm_config.scheduler_config.max_model_len = 8192


def quant_fp8_per_tensor_batches(a):
    num_batches = a.size(0)
    a_quant = []
    a_scales = []

    for i in range(num_batches):
        a_fp8, a_global_sf = input_to_float8(a[i])
        a_global_sf = 1.0 / a_global_sf
        a_quant.append(a_fp8)
        a_scales.append(a_global_sf)

    result_a_quant = torch.stack(a_quant)
    result_a_scales = torch.stack(a_scales)

    return result_a_quant, result_a_scales


@dataclass
class TestData:
    hidden_states: torch.Tensor
    w13_quantized: torch.Tensor
    w2_quantized: torch.Tensor
    a1_scale: torch.Tensor
    a2_scale: torch.Tensor
    w13_weight_scale: torch.Tensor
    w2_weight_scale: torch.Tensor
    layer: torch.nn.Module

    @staticmethod
    def make_moe_tensors_8bit(m: int, k: int, n: int, e: int,
                              reorder: bool) -> "TestData":
        hidden_states = torch.randn(
            (m, k), device="cuda", dtype=torch.bfloat16) / 10
        w13 = torch.randn((e, 2 * n, k), device="cuda", dtype=torch.bfloat16)
        w2 = torch.randn((e, k, n), device="cuda", dtype=torch.bfloat16)

        # Scale to fp8
        _, a1_scale = input_to_float8(hidden_states)
        a1_scale = 1.0 / a1_scale
        a2_scale = torch.scalar_tensor(1.0).to(device="cuda").to(
            dtype=torch.float32)
        w13_quantized, w13_weight_scale = quant_fp8_per_tensor_batches(w13)
        w2_quantized, w2_weight_scale = quant_fp8_per_tensor_batches(w2)

        layer = torch.nn.Module()
        layer.w13_weight = w13_quantized.clone()
        layer.w2_weight = w2_quantized.clone()
        layer.w13_input_scale = a1_scale
        layer.w2_input_scale = a2_scale
        layer.w13_weight_scale = w13_weight_scale
        layer.w2_weight_scale = w2_weight_scale

        register_moe_scaling_factors(layer)

        # flashinfer expects swapped rows for w13
        layer.w13_weight.data = swap_w13_to_w31(layer.w13_weight.data)
        if reorder:
            rotate_flashinfer_fp8_moe_weights(layer.w13_weight,
                                              layer.w2_weight)
        layer.custom_routing_function = Llama4MoE.custom_routing_function
        layer.intermediate_size_per_partition = n
        layer.ep_rank = 0
        layer.local_num_experts = e

        return TestData(
            hidden_states=hidden_states,
            w13_quantized=w13_quantized,
            w2_quantized=w2_quantized,
            a1_scale=a1_scale,
            a2_scale=a2_scale,
            w13_weight_scale=w13_weight_scale,
            w2_weight_scale=w2_weight_scale,
            layer=layer,
        )


@pytest.mark.parametrize("m,n,k", MNK_FACTORS)
@pytest.mark.parametrize("e", NUM_EXPERTS)
@pytest.mark.parametrize("topk", TOP_KS)
def test_flashinfer_per_tensor_moe_fp8_no_graph(
    m: int,
    n: int,
    k: int,
    e: int,
    topk: int,
    monkeypatch,
):
    current_platform.seed_everything(7)
    monkeypatch.setenv("VLLM_FUSED_MOE_CHUNK_SIZE", "8192")
    with set_current_vllm_config(vllm_config):
        td = TestData.make_moe_tensors_8bit(m, k, n, e, reorder=True)

        score = torch.randn((m, e), device="cuda", dtype=torch.bfloat16)
        topk_weights, topk_ids = FusedMoE.select_experts(
            hidden_states=td.hidden_states,
            router_logits=score,
            use_grouped_topk=False,
            top_k=topk,
            renormalize=False,
            custom_routing_function=Llama4MoE.custom_routing_function,
            scoring_func="softmax")

        output = fused_experts(
            td.hidden_states,
            td.w13_quantized,
            td.w2_quantized,
            topk_weights=topk_weights,
            topk_ids=topk_ids,
            inplace=False,
            activation="silu",
            use_fp8_w8a8=True,
            per_channel_quant=False,
            global_num_experts=e,
            expert_map=None,
            w1_scale=td.w13_weight_scale,
            w2_scale=td.w2_weight_scale,
            a1_scale=td.a1_scale,
            a2_scale=td.a2_scale,
            apply_router_weight_on_input=True,
        )

        flashinfer_output = apply_flashinfer_per_tensor_scale_fp8(
            layer=td.layer,
            hidden_states=td.hidden_states,
            router_logits=score,
            routing_bias=None,
            global_num_experts=e,
            top_k=topk,
            num_expert_group=None,
            topk_group=None,
            apply_router_weight_on_input=True)

        torch.testing.assert_close(output,
                                   flashinfer_output,
                                   atol=5.5e-2,
                                   rtol=1e-2)


@pytest.mark.skip(
    "Requires flashinfer version that contains https://github.com/flashinfer-ai/flashinfer/pull/1472"
)
@pytest.mark.parametrize("m,n,k", MNK_FACTORS)
@pytest.mark.parametrize("e", NUM_EXPERTS)
@pytest.mark.parametrize("topk", TOP_KS)
def test_flashinfer_cutlass_moe_fp8_no_graph(
    m: int,
    n: int,
    k: int,
    e: int,
    topk: int,
    monkeypatch,
):
    current_platform.seed_everything(7)
    monkeypatch.setenv("VLLM_FUSED_MOE_CHUNK_SIZE", "8192")
    with set_current_vllm_config(vllm_config):
        td = TestData.make_moe_tensors_8bit(m, k, n, e, reorder=False)

        score = torch.randn((m, e), device="cuda", dtype=torch.bfloat16)
        topk_weights, topk_ids = FusedMoE.select_experts(
            hidden_states=td.hidden_states,
            router_logits=score,
            use_grouped_topk=False,
            top_k=topk,
            renormalize=False,
            custom_routing_function=Llama4MoE.custom_routing_function,
            scoring_func="softmax")

        output = fused_experts(
            td.hidden_states,
            td.w13_quantized,
            td.w2_quantized,
            topk_weights=topk_weights,
            topk_ids=topk_ids,
            inplace=False,
            activation="silu",
            use_fp8_w8a8=True,
            per_channel_quant=False,
            global_num_experts=e,
            expert_map=None,
            w1_scale=td.w13_weight_scale,
            w2_scale=td.w2_weight_scale,
            a1_scale=td.a1_scale,
            a2_scale=td.a2_scale,
            apply_router_weight_on_input=True,
        )

        td.layer.dp_size = 1

        flashinfer_cutlass_output = flashinfer_cutlass_moe_fp8(
            td.hidden_states,
            td.layer,
            topk_weights,
            topk_ids,
            activation="silu",
            global_num_experts=e,
            expert_map=None,
            apply_router_weight_on_input=True,
        )

        torch.testing.assert_close(output,
                                   flashinfer_cutlass_output,
                                   atol=5.5e-2,
                                   rtol=1e-2)
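The new test file is also wired into the CI pipeline above (pytest -v -s tests/kernels/moe/test_flashinfer.py). A minimal local runner, assuming an SM100-capable GPU and a FlashInfer build that exposes the CUTLASS fused-MoE kernel; otherwise the module skips itself at import time.

import pytest

if __name__ == "__main__":
    # Mirrors the CI invocation added to the test pipeline in this commit.
    raise SystemExit(
        pytest.main(["-v", "-s", "tests/kernels/moe/test_flashinfer.py"]))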
@@ -61,8 +61,8 @@ class FlashInferExperts(mk.FusedMoEPermuteExpertsUnpermute):
                 per_act_token_quant=False,
                 block_shape=None,
             ))
-        assert quant_dtype == "nvfp4", ("Only nvfp4 quantization is "
-                                        "currently supported.")
+        assert quant_dtype in ("nvfp4", torch.float8_e4m3fn), (
+            "Only nvfp4,fp8 quantization are currently supported.")
         self.ep_rank = ep_rank
         self.ep_size = ep_size
         self.tp_rank = tp_rank
@@ -122,7 +122,8 @@ class FlashInferExperts(mk.FusedMoEPermuteExpertsUnpermute):
         """
         aq_m, aq_n = aq.shape
         workspace2 = ()
-        output_shape = (aq_m, aq_n * 2)
+        output_shape = (aq_m, aq_n * 2) if self.quant_dtype != \
+            torch.float8_e4m3fn else (aq_m, aq_n)
         workspace_dtype = a.dtype
         workspace1 = output_shape
         # The workspace is determined by `aq`, since it comes after any
@@ -151,29 +152,39 @@ class FlashInferExperts(mk.FusedMoEPermuteExpertsUnpermute):
         expert_tokens_meta: Optional[mk.ExpertTokensMetadata],
         apply_router_weight_on_input: Optional[bool],
     ):
-        # Flashinfer CUTLASS kernel takes scalar global scales,
-        # min because inv_scale.
-        # Ensure w1_scale and w2_scale are not None before calling view
-        assert w1_scale is not None and w2_scale is not None, (
-            "w1_scale and w2_scale must not "
-            "be None for FlashInferExperts")
-
-        quant_scales = [
-            self.a1_gscale,
-            w1_scale.view(torch.int32),
-            self.g1_alphas,
-            self.a2_gscale,
-            w2_scale.view(torch.int32),
-            self.g2_alphas,
-        ]
-
+        if self.quant_dtype == torch.float8_e4m3fn:
+            quant_scales = [
+                self.g1_alphas, self.a2_gscale, self.g2_alphas, self.a1_gscale
+            ]
+
+            a1q_scale = None  # not passing input_sf in fp8
+            fc1_expert_weights = w1
+            fc2_expert_weights = w2
+        else:
+            # Ensure w1_scale and w2_scale are not None before calling view
+            assert w1_scale is not None and w2_scale is not None, (
+                "w1_scale and w2_scale must not "
+                "be None for FlashInferExperts")
+            # Flashinfer CUTLASS kernel takes scalar global scales,
+            # min because inv_scale.
+            quant_scales = [
+                self.a1_gscale,
+                w1_scale.view(torch.int32),
+                self.g1_alphas,
+                self.a2_gscale,
+                w2_scale.view(torch.int32),
+                self.g2_alphas,
+            ]
+            # FlashInfer API requires weight to be long for nvfp4
+            fc1_expert_weights = w1.view(torch.long)
+            fc2_expert_weights = w2.view(torch.long)
+
         _ = flashinfer_cutlass_fused_moe(
             input=hidden_states,
             token_selected_experts=topk_ids.to(torch.int),
             token_final_scales=topk_weights,
-            # FlashInfer API requires weight to be long for nvfp4
-            fc1_expert_weights=w1.view(torch.long),
-            fc2_expert_weights=w2.view(torch.long),
+            fc1_expert_weights=fc1_expert_weights,
+            fc2_expert_weights=fc2_expert_weights,
             output_dtype=self.out_dtype,
             quant_scales=quant_scales,
             input_sf=a1q_scale,
@@ -9,6 +9,7 @@ from torch.nn import Module
 from torch.nn.parameter import Parameter

 import vllm.envs as envs
+import vllm.model_executor.layers.fused_moe.modular_kernel as mk
 from vllm import _custom_ops as ops
 from vllm.distributed import get_tensor_model_parallel_world_size
 from vllm.logger import init_logger
@@ -23,8 +24,11 @@ from vllm.model_executor.layers.quantization.base_config import (
     QuantizationConfig, QuantizeMethodBase)
 from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod
 from vllm.model_executor.layers.quantization.utils.flashinfer_utils import (
-    apply_flashinfer_per_tensor_scale_fp8, register_moe_scaling_factors,
-    rotate_flashinfer_fp8_moe_weights, swap_w13_to_w31)
+    FlashinferMoeBackend, apply_flashinfer_per_tensor_scale_fp8,
+    build_flashinfer_fp8_cutlass_moe_prepare_finalize,
+    flashinfer_cutlass_moe_fp8, get_flashinfer_moe_backend,
+    register_moe_scaling_factors, rotate_flashinfer_fp8_moe_weights,
+    select_cutlass_fp8_gemm_impl, swap_w13_to_w31)
 from vllm.model_executor.layers.quantization.utils.fp8_utils import (
     get_col_major_tma_aligned_tensor, requant_weight_ue8m0_inplace)
 from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import (
@@ -145,7 +149,7 @@ class Fp8Config(QuantizationConfig):
                 return UnquantizedLinearMethod()
             return Fp8LinearMethod(self)
         elif isinstance(layer, FusedMoE):
-            return Fp8MoEMethod(self, layer.moe_config)
+            return Fp8MoEMethod(self, layer)
         elif isinstance(layer, Attention):
             return Fp8KVCacheMethod(self)
         return None
@@ -482,16 +486,20 @@ class Fp8MoEMethod(FusedMoEMethodBase):
         quant_config: The quantization config.
     """

-    def __init__(self, quant_config: Fp8Config, moe: FusedMoEConfig):
-        super().__init__(moe)
+    def __init__(self, quant_config: Fp8Config, layer: torch.nn.Module):
+        super().__init__(layer.moe_config)
+        self.layer = layer
         self.quant_config = quant_config
         self.block_quant = self.quant_config.weight_block_size is not None

-        self.flashinfer_moe_enabled = False
+        self.flashinfer_moe_backend: Optional[FlashinferMoeBackend] = None
+        self.fused_experts: Optional[
+            mk.FusedMoEModularKernel] = None  # type: ignore
         if envs.VLLM_USE_FLASHINFER_MOE_FP8 and has_flashinfer_moe():
+            self.flashinfer_moe_backend = get_flashinfer_moe_backend()
             logger.info_once(
-                "Using FlashInfer MoE FP8 kernels for Fp8MoEMethod.")
-            self.flashinfer_moe_enabled = True
+                f"Using FlashInfer {self.flashinfer_moe_backend.value} kernels"
+            )
         # For GPUs that lack FP8 hardware support, we can leverage the Marlin
         # kernel for fast weight-only FP8 quantization
         self.use_marlin = (not current_platform.has_device_capability(89)
@@ -531,6 +539,20 @@ class Fp8MoEMethod(FusedMoEMethodBase):
                 "CutlassBlockScaledGroupedGemm not supported on the current "
                 "platform.")

+    def maybe_make_prepare_finalize(
+        self,
+        moe: FusedMoEConfig,
+    ) -> Optional[mk.FusedMoEPrepareAndFinalize]:
+        if self.flashinfer_moe_backend != FlashinferMoeBackend.CUTLASS:
+            return super().maybe_make_prepare_finalize(moe)
+
+        prepare_finalize = build_flashinfer_fp8_cutlass_moe_prepare_finalize(
+            moe,
+            layer=self.layer,
+        )
+        logger.debug_once("%s", prepare_finalize.__class__.__name__)
+        return prepare_finalize
+
     def create_weights(self, layer: Module, num_experts: int, hidden_size: int,
                        intermediate_size_per_partition: int,
                        params_dtype: torch.dtype, **extra_weight_attrs):
@@ -678,7 +700,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
                     normalize_e4m3fn_to_e4m3fnuz(
                         layer.w2_weight, layer.w2_weight_scale_inv,
                         layer.w2_input_scale)
-            elif self.flashinfer_moe_enabled:
+            elif self.flashinfer_moe_backend is not None:
                 # NOTE: weights have to be swapped since the activation is
                 # applied on different half for flashinfer vs vllm
                 w13_weight = swap_w13_to_w31(layer.w13_weight.data)
@@ -686,9 +708,6 @@ class Fp8MoEMethod(FusedMoEMethodBase):
                     layer.w13_weight_scale_inv.data)
                 w2_weight = layer.w2_weight.data
                 w2_weight_scale_inv = layer.w2_weight_scale_inv.data
-                if not self.block_quant:
-                    register_moe_scaling_factors(layer)
-                    rotate_flashinfer_fp8_moe_weights(w13_weight, w2_weight)
             else:
                 w13_weight = layer.w13_weight.data
                 w13_weight_scale_inv = layer.w13_weight_scale_inv.data
@@ -834,6 +853,17 @@ class Fp8MoEMethod(FusedMoEMethodBase):
             layer.w13_weight_scale = torch.nn.Parameter(max_w13_scales,
                                                         requires_grad=False)

+            if self.flashinfer_moe_backend is not None:
+                # NOTE: weights have to be swapped since the activation is
+                # applied on different half for flashinfer vs vllm
+                assert not self.block_quant
+                register_moe_scaling_factors(layer)
+                w13_weight = swap_w13_to_w31(layer.w13_weight.data)
+                if self.flashinfer_moe_backend == \
+                        FlashinferMoeBackend.TENSORRT_LLM:
+                    rotate_flashinfer_fp8_moe_weights(w13_weight, w2_weight)
+                layer.w13_weight.data = w13_weight.data
+
         if self.use_marlin:
             prepare_moe_fp8_layer_for_marlin(layer, False)
             # Activations not quantized for marlin.
@@ -892,6 +922,13 @@ class Fp8MoEMethod(FusedMoEMethodBase):
                 per_act_token_quant=False,
                 allow_deep_gemm=self.allow_deep_gemm,
             )
+        elif self.flashinfer_moe_backend == FlashinferMoeBackend.CUTLASS:
+            experts = select_cutlass_fp8_gemm_impl(
+                moe,
+                self.layer,
+            )
+            logger.debug_once("Using %s", experts.__class__.__name__)
+            return experts
         else:
             logger.debug(
                 "TritonOrDeepGemmExperts(%s): block_size=%s, per_act_token=%s",
@@ -930,25 +967,66 @@ class Fp8MoEMethod(FusedMoEMethodBase):
             assert logical_to_physical_map is not None
             assert logical_replica_count is not None
             assert isinstance(layer, FusedMoE)

-        if not self.flashinfer_moe_enabled:
-            topk_weights, topk_ids = FusedMoE.select_experts(
-                hidden_states=x,
-                router_logits=router_logits,
-                use_grouped_topk=use_grouped_topk,
-                top_k=top_k,
-                renormalize=renormalize,
-                topk_group=topk_group,
-                num_expert_group=num_expert_group,
-                custom_routing_function=custom_routing_function,
-                scoring_func=scoring_func,
-                e_score_correction_bias=e_score_correction_bias,
-                indices_type=self.topk_indices_dtype,
-                enable_eplb=enable_eplb,
-                expert_map=expert_map,
-                expert_load_view=expert_load_view,
-                logical_to_physical_map=logical_to_physical_map,
-                logical_replica_count=logical_replica_count,
-            )
+        if self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM:
+            assert activation == 'silu', (
+                f"Expected 'silu' activation but got {activation}")
+            assert scoring_func == 'sigmoid', (
+                f"Expected 'sigmoid' scoring func but got {scoring_func}")
+            if self.block_quant:
+                assert (renormalize and use_grouped_topk
+                        and custom_routing_function is None)
+
+                return torch.ops.vllm.flashinfer_fused_moe_blockscale_fp8(
+                    routing_logits=router_logits.to(torch.float32),
+                    routing_bias=e_score_correction_bias,
+                    x=x,
+                    w13_weight=layer.w13_weight,
+                    w13_weight_scale_inv=layer.w13_weight_scale_inv,
+                    w2_weight=layer.w2_weight,
+                    w2_weight_scale_inv=layer.w2_weight_scale_inv,
+                    global_num_experts=global_num_experts,
+                    top_k=top_k,
+                    num_expert_group=num_expert_group,
+                    topk_group=topk_group,
+                    intermediate_size=layer.intermediate_size_per_partition,
+                    expert_offset=layer.ep_rank * layer.local_num_experts,
+                    local_num_experts=layer.local_num_experts,
+                    block_shape=self.quant_config.weight_block_size,
+                    routed_scaling=1.0,
+                )
+            else:
+                assert (not renormalize
+                        and custom_routing_function is not None)
+                return apply_flashinfer_per_tensor_scale_fp8(
+                    layer=layer,
+                    hidden_states=x,
+                    router_logits=router_logits,
+                    routing_bias=e_score_correction_bias,
+                    global_num_experts=global_num_experts,
+                    top_k=top_k,
+                    num_expert_group=num_expert_group,
+                    topk_group=topk_group,
+                    apply_router_weight_on_input=apply_router_weight_on_input)
+
+        topk_weights, topk_ids = FusedMoE.select_experts(
+            hidden_states=x,
+            router_logits=router_logits,
+            use_grouped_topk=use_grouped_topk,
+            top_k=top_k,
+            renormalize=renormalize,
+            topk_group=topk_group,
+            num_expert_group=num_expert_group,
+            custom_routing_function=custom_routing_function,
+            scoring_func=scoring_func,
+            e_score_correction_bias=e_score_correction_bias,
+            indices_type=self.topk_indices_dtype,
+            enable_eplb=enable_eplb,
+            expert_map=expert_map,
+            expert_load_view=expert_load_view,
+            logical_to_physical_map=logical_to_physical_map,
+            logical_replica_count=logical_replica_count,
+        )

         if self.rocm_aiter_moe_enabled:
             from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (  # noqa: E501
@@ -988,63 +1066,38 @@ class Fp8MoEMethod(FusedMoEMethodBase):
                 apply_router_weight_on_input=apply_router_weight_on_input,
                 global_num_experts=global_num_experts,
                 expert_map=expert_map)
-        elif self.flashinfer_moe_enabled:
-            assert activation == 'silu'
-            assert scoring_func == 'sigmoid'
-            if self.block_quant:
-                assert (renormalize and use_grouped_topk
-                        and custom_routing_function is None)
-
-                return torch.ops.vllm.flashinfer_fused_moe_blockscale_fp8(
-                    routing_logits=router_logits.to(torch.float32),
-                    routing_bias=e_score_correction_bias,
-                    x=x,
-                    w13_weight=layer.w13_weight,
-                    w13_weight_scale_inv=layer.w13_weight_scale_inv,
-                    w2_weight=layer.w2_weight,
-                    w2_weight_scale_inv=layer.w2_weight_scale_inv,
-                    global_num_experts=global_num_experts,
-                    top_k=top_k,
-                    num_expert_group=num_expert_group,
-                    topk_group=topk_group,
-                    intermediate_size=layer.intermediate_size_per_partition,
-                    expert_offset=layer.ep_rank * layer.local_num_experts,
-                    local_num_experts=layer.local_num_experts,
-                    block_shape=self.quant_config.weight_block_size,
-                    routed_scaling=1.0,
-                )
-            else:
-                assert (not renormalize
-                        and custom_routing_function is not None)
-                return apply_flashinfer_per_tensor_scale_fp8(
-                    layer=layer,
-                    hidden_states=x,
-                    router_logits=router_logits,
-                    routing_bias=e_score_correction_bias,
-                    global_num_experts=global_num_experts,
-                    top_k=top_k,
-                    num_expert_group=num_expert_group,
-                    topk_group=topk_group,
-                    apply_router_weight_on_input=apply_router_weight_on_input)
-        elif self.fused_experts is not None:
-            return self.fused_experts(
-                hidden_states=x,
-                w1=layer.w13_weight,
-                w2=layer.w2_weight,
-                topk_weights=topk_weights,
-                topk_ids=topk_ids,
-                inplace=True,
-                activation=activation,
-                global_num_experts=global_num_experts,
-                apply_router_weight_on_input=apply_router_weight_on_input,
-                expert_map=expert_map,
-                w1_scale=(layer.w13_weight_scale_inv
-                          if self.block_quant else layer.w13_weight_scale),
-                w2_scale=(layer.w2_weight_scale_inv
-                          if self.block_quant else layer.w2_weight_scale),
-                a1_scale=layer.w13_input_scale,
-                a2_scale=layer.w2_input_scale,
-            )
+        elif self.flashinfer_moe_backend == FlashinferMoeBackend.CUTLASS:
+            assert self.block_quant is None
+            assert (not renormalize and custom_routing_function is not None)
+            assert activation == 'silu', (
+                f"Expected 'silu' activation but got {activation}")
+            assert scoring_func == 'sigmoid', (
+                f"Expected 'sigmoid' scoring func but got {scoring_func}")
+            if self.fused_experts is not None:
+                return self.fused_experts(
+                    x,
+                    layer.w13_weight,
+                    layer.w2_weight,
+                    topk_weights,
+                    topk_ids,
+                    inplace=False,
+                    activation=activation,
+                    global_num_experts=global_num_experts,
+                    expert_map=expert_map,
+                    apply_router_weight_on_input=apply_router_weight_on_input,
+                )
+            else:
+                return flashinfer_cutlass_moe_fp8(
+                    x,
+                    layer,
+                    topk_weights,
+                    topk_ids,
+                    inplace=False,
+                    activation=activation,
+                    global_num_experts=global_num_experts,
+                    expert_map=expert_map,
+                    apply_router_weight_on_input=apply_router_weight_on_input,
+                )
         else:
             from vllm.model_executor.layers.fused_moe import fused_experts
             return fused_experts(
@@ -1,7 +1,6 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project

-from enum import Enum
 from typing import Any, Callable, Optional, Union

 import torch
@@ -27,8 +26,11 @@ from vllm.model_executor.layers.quantization.utils.flashinfer_fp4_moe import (
     build_flashinfer_fp4_cutlass_moe_prepare_finalize, reorder_w1w3_to_w3w1,
     select_nvfp4_gemm_impl)
 from vllm.model_executor.layers.quantization.utils.flashinfer_utils import (
-    apply_flashinfer_per_tensor_scale_fp8, register_moe_scaling_factors,
-    rotate_flashinfer_fp8_moe_weights, swap_w13_to_w31)
+    FlashinferMoeBackend, apply_flashinfer_per_tensor_scale_fp8,
+    build_flashinfer_fp8_cutlass_moe_prepare_finalize,
+    flashinfer_cutlass_moe_fp8, get_flashinfer_moe_backend,
+    register_moe_scaling_factors, rotate_flashinfer_fp8_moe_weights,
+    select_cutlass_fp8_gemm_impl, swap_w13_to_w31)
 from vllm.model_executor.layers.quantization.utils.marlin_utils_fp4 import (
     apply_fp4_marlin_linear, is_fp4_marlin_supported,
     prepare_fp4_layer_for_marlin, prepare_moe_fp4_layer_for_marlin)
@@ -49,11 +51,6 @@ QUANT_ALGOS = ["FP8", "NVFP4"]
 KV_CACHE_QUANT_ALGOS = ["FP8"]


-class FlashinferMoeBackend(Enum):
-    TENSORRT_LLM = "TensorRT-LLM"
-    CUTLASS = "CUTLASS"
-
-
 class ModelOptFp8Config(QuantizationConfig):
     """Config class for ModelOpt FP8."""

@@ -179,7 +176,7 @@ class ModelOptFp8Config(QuantizationConfig):
         elif isinstance(layer, Attention):
             return ModelOptFp8KVCacheMethod(self)
         elif isinstance(layer, FusedMoE):
-            return ModelOptFp8MoEMethod(self, layer.moe_config)
+            return ModelOptFp8MoEMethod(self, layer)
         return None


@@ -278,18 +275,49 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase):
     def __init__(
         self,
         quant_config: ModelOptFp8Config,
-        moe: FusedMoEConfig,
+        layer: torch.nn.Module,
     ) -> None:
-        super().__init__(moe)
+        super().__init__(layer.moe_config)
+        self.layer = layer
         self.quant_config = quant_config
         from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
             cutlass_fp8_supported)
         self.cutlass_fp8_supported = cutlass_fp8_supported()
-        self.flashinfer_moe_enabled = False
+        self.flashinfer_moe_backend: Optional[FlashinferMoeBackend] = None
+        self.fused_experts: Optional[
+            mk.FusedMoEModularKernel] = None  # type: ignore
         if envs.VLLM_USE_FLASHINFER_MOE_FP8 and has_flashinfer_moe():
+            self.flashinfer_moe_backend = get_flashinfer_moe_backend()
             logger.info_once(
-                "Using FlashInfer MoE FP8 kernels for ModelOptFp8MoEMethod.")
-            self.flashinfer_moe_enabled = True
+                f"Using FlashInfer {self.flashinfer_moe_backend.value} kernels"
+            )
+
+    def maybe_make_prepare_finalize(
+        self,
+        moe: FusedMoEConfig,
+    ) -> Optional[mk.FusedMoEPrepareAndFinalize]:
+        if self.fused_experts is not None or \
+            self.flashinfer_moe_backend != FlashinferMoeBackend.CUTLASS:
+            return super().maybe_make_prepare_finalize(moe)
+
+        prepare_finalize = build_flashinfer_fp8_cutlass_moe_prepare_finalize(
+            moe,
+            layer=self.layer,
+        )
+        logger.debug_once("%s", prepare_finalize.__class__.__name__)
+        return prepare_finalize
+
+    def select_gemm_impl(
+        self,
+        prepare_finalize: mk.FusedMoEPrepareAndFinalize,
+        moe: FusedMoEConfig,
+    ) -> mk.FusedMoEPermuteExpertsUnpermute:
+        experts = select_cutlass_fp8_gemm_impl(
+            moe,
+            self.layer,
+        )
+        logger.debug_once("Using %s", experts.__class__.__name__)
+        return experts
+
     def create_weights(
         self,
@@ -433,11 +461,12 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase):
             layer.w2_input_scale = Parameter(layer.w2_input_scale.max(),
                                              requires_grad=False)

-        if self.flashinfer_moe_enabled:
+        if self.flashinfer_moe_backend is not None:
             layer.w13_weight.data = swap_w13_to_w31(layer.w13_weight.data)
-            rotate_flashinfer_fp8_moe_weights(layer.w13_weight,
-                                              layer.w2_weight)
             register_moe_scaling_factors(layer)
+            if self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM:
+                rotate_flashinfer_fp8_moe_weights(layer.w13_weight,
+                                                  layer.w2_weight)

     def apply(
         self,
@@ -461,14 +490,13 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase):
         logical_to_physical_map: Optional[torch.Tensor] = None,
         logical_replica_count: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
-        assert self.fused_experts is None

         if enable_eplb:
             raise NotImplementedError(
                 "EPLB not supported for `ModelOptFp8MoEMethod` yet.")

-        if self.flashinfer_moe_enabled:
-            assert activation == 'silu'
+        if self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM:
+            assert activation == 'silu', (
+                f"Expected 'silu' activation but got {activation}")
             assert not renormalize
             return apply_flashinfer_per_tensor_scale_fp8(
                 layer=layer,
@@ -495,6 +523,36 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase):
             e_score_correction_bias=e_score_correction_bias,
             indices_type=self.topk_indices_dtype,
         )
+
+        if self.flashinfer_moe_backend == FlashinferMoeBackend.CUTLASS:
+            assert not renormalize
+            assert activation == 'silu', (
+                f"Expected 'silu' activation but got {activation}")
+            if self.fused_experts is not None:
+                return self.fused_experts(
+                    x,
+                    layer.w13_weight,
+                    layer.w2_weight,
+                    topk_weights,
+                    topk_ids,
+                    inplace=False,
+                    activation=activation,
+                    global_num_experts=global_num_experts,
+                    expert_map=expert_map,
+                    apply_router_weight_on_input=apply_router_weight_on_input,
+                )
+            else:
+                return flashinfer_cutlass_moe_fp8(
+                    x,
+                    layer,
+                    topk_weights,
+                    topk_ids,
+                    inplace=False,
+                    activation=activation,
+                    global_num_experts=global_num_experts,
+                    expert_map=expert_map,
+                    apply_router_weight_on_input=apply_router_weight_on_input,
+                )
         from vllm.model_executor.layers.fused_moe.fused_moe import (
             fused_experts)
         return fused_experts(
@@ -951,20 +1009,10 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
         self.flashinfer_moe_backend = None

         if self.allow_flashinfer:
-            flashinfer_moe_backend = envs.VLLM_FLASHINFER_MOE_BACKEND
-            if flashinfer_moe_backend == "throughput":
-                self.flashinfer_moe_backend = FlashinferMoeBackend.CUTLASS
-                logger.info_once("Using FlashInfer CUTLASS kernels for "
-                                 "ModelOptNvFp4FusedMoE.")
-            elif flashinfer_moe_backend == "latency":
-                self.flashinfer_moe_backend = FlashinferMoeBackend.TENSORRT_LLM
-                logger.info_once("Using FlashInfer TensorRT-LLM kernels for "
-                                 "ModelOptNvFp4FusedMoE.")
-            else:
-                allowed_backends = ["throughput", "latency"]
-                raise ValueError(
-                    f"Unknown flashinfer moe backend: {flashinfer_moe_backend}"
-                    f" expected one of {allowed_backends}")
+            self.flashinfer_moe_backend = get_flashinfer_moe_backend()
+            logger.info_once(
+                f"Using FlashInfer {self.flashinfer_moe_backend.value} kernels"
+                " for ModelOptNvFp4FusedMoE.")

     def maybe_make_prepare_finalize(
         self,
@@ -1,9 +1,26 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+from enum import Enum
 from typing import Optional

 import torch

+import vllm.model_executor.layers.fused_moe.modular_kernel as mk
+from vllm import envs
+from vllm.logger import init_logger
+from vllm.model_executor.layers.fused_moe.config import FusedMoEConfig
+from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe import (
+    FlashInferExperts)
+from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_prepare_finalize import (  # noqa: E501
+    FlashInferCutlassMoEPrepareAndFinalize)
+
+logger = init_logger(__name__)
+
+
+class FlashinferMoeBackend(Enum):
+    TENSORRT_LLM = "TensorRT-LLM"
+    CUTLASS = "CUTLASS"
+
+
 def calculate_tile_tokens_dim(num_tokens, top_k, num_experts):
@@ -144,3 +161,98 @@ def register_moe_scaling_factors(layer: torch.nn.Module) -> None:
     layer.register_parameter(
         'output2_scales_scalar',
         torch.nn.Parameter(output2_scales, requires_grad=False))
+    layer.register_parameter(
+        'w2_input_scale_inv',
+        torch.nn.Parameter(1.0 / layer.w2_input_scale, requires_grad=False))
+
+
+def build_flashinfer_fp8_cutlass_moe_prepare_finalize(
+    moe: Optional[FusedMoEConfig],
+    layer: torch.nn.Module,
+) -> mk.FusedMoEPrepareAndFinalize:
+    """Create a FlashInfer CUTLASS fused-MoE prepare finalize kernel"""
+    use_dp = moe.moe_parallel_config.dp_size > 1 if moe is not None else False
+    return FlashInferCutlassMoEPrepareAndFinalize(
+        use_dp, a1_gscale=layer.w13_input_scale)
+
+
+def select_cutlass_fp8_gemm_impl(
+    moe: Optional[FusedMoEConfig],
+    layer: torch.nn.Module,
+    out_dtype: Optional[torch.dtype] = None,
+) -> mk.FusedMoEPermuteExpertsUnpermute:
+    """Return a GEMM *experts* implementation for fused-MoE layers"""
+
+    from vllm.model_executor.models.llama4 import Llama4MoE
+    assert layer.custom_routing_function == Llama4MoE.custom_routing_function, \
+        "FusedMoE flashinfer kernels are only supported for Llama4"
+
+    if moe is not None:
+        return FlashInferExperts(
+            g1_alphas=layer.output1_scales_gate_scalar,
+            g2_alphas=layer.output2_scales_scalar,
+            a1_gscale=layer.w13_input_scale,
+            a2_gscale=layer.w2_input_scale_inv,
+            out_dtype=moe.in_dtype,
+            quant_dtype=torch.float8_e4m3fn,
+            ep_rank=moe.moe_parallel_config.ep_rank,
+            ep_size=moe.moe_parallel_config.ep_size,
+            tp_rank=moe.moe_parallel_config.tp_rank,
+            tp_size=moe.moe_parallel_config.tp_size,
+        )

+    assert out_dtype is not None, (
+        "If moe config is None, out_dtype must be passed")
+    return FlashInferExperts(
+        g1_alphas=layer.output1_scales_gate_scalar,
+        g2_alphas=layer.output2_scales_scalar,
+        a1_gscale=layer.w13_input_scale,
+        a2_gscale=layer.w2_input_scale_inv,
+        out_dtype=out_dtype,
+        quant_dtype=torch.float8_e4m3fn,
+    )
+
+
+def flashinfer_cutlass_moe_fp8(
+    hidden_states: torch.Tensor,
+    layer: torch.nn.Module,
+    topk_weights: torch.Tensor,
+    topk_ids: torch.Tensor,
+    inplace: bool = False,
+    activation: str = "silu",
+    global_num_experts: int = -1,
+    expert_map: Optional[torch.Tensor] = None,
+    apply_router_weight_on_input: bool = False,
+) -> torch.Tensor:
+    fused_experts = mk.FusedMoEModularKernel(
+        build_flashinfer_fp8_cutlass_moe_prepare_finalize(moe=None,
+                                                          layer=layer),
+        select_cutlass_fp8_gemm_impl(moe=None,
+                                     layer=layer,
+                                     out_dtype=hidden_states.dtype))
+
+    return fused_experts(
+        hidden_states,
+        layer.w13_weight,
+        layer.w2_weight,
+        topk_weights,
+        topk_ids,
+        inplace=inplace,
+        activation=activation,
+        global_num_experts=global_num_experts,
+        expert_map=expert_map,
+        apply_router_weight_on_input=apply_router_weight_on_input,
+    )
+
+
+def get_flashinfer_moe_backend() -> FlashinferMoeBackend:
+    flashinfer_moe_backend = envs.VLLM_FLASHINFER_MOE_BACKEND
+    if flashinfer_moe_backend == "throughput":
+        return FlashinferMoeBackend.CUTLASS
+    elif flashinfer_moe_backend == "latency":
+        return FlashinferMoeBackend.TENSORRT_LLM
+
+    allowed_backends = ["throughput", "latency"]
+    raise ValueError(
+        f"Unknown flashinfer moe backend: {flashinfer_moe_backend}"
+        f" expected one of {allowed_backends}")