mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-05-03 08:57:55 +08:00
[Feat] Support non-gated activations in NVFP4 modelopt path (#29004)
This commit is contained in:
parent
cd719de5cb
commit
39d28108f4
@ -16,11 +16,11 @@ from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe import (
|
|||||||
FlashInferExperts,
|
FlashInferExperts,
|
||||||
is_valid_flashinfer_cutlass_fused_moe,
|
is_valid_flashinfer_cutlass_fused_moe,
|
||||||
)
|
)
|
||||||
|
from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_prepare_finalize import (
|
||||||
|
create_flashinfer_prepare_finalize,
|
||||||
|
)
|
||||||
from vllm.model_executor.layers.fused_moe.fused_moe import fused_topk
|
from vllm.model_executor.layers.fused_moe.fused_moe import fused_topk
|
||||||
from vllm.model_executor.layers.fused_moe.modular_kernel import FusedMoEModularKernel
|
from vllm.model_executor.layers.fused_moe.modular_kernel import FusedMoEModularKernel
|
||||||
from vllm.model_executor.layers.fused_moe.prepare_finalize import (
|
|
||||||
MoEPrepareAndFinalizeNoEP,
|
|
||||||
)
|
|
||||||
from vllm.platforms import current_platform
|
from vllm.platforms import current_platform
|
||||||
from vllm.utils.flashinfer import has_flashinfer_cutlass_fused_moe
|
from vllm.utils.flashinfer import has_flashinfer_cutlass_fused_moe
|
||||||
|
|
||||||
@ -48,9 +48,10 @@ MNK_FACTORS = [
|
|||||||
@pytest.mark.parametrize("e", [40, 64, 256])
|
@pytest.mark.parametrize("e", [40, 64, 256])
|
||||||
@pytest.mark.parametrize("topk", [1, 6, 8])
|
@pytest.mark.parametrize("topk", [1, 6, 8])
|
||||||
@pytest.mark.parametrize("dtype", [torch.bfloat16])
|
@pytest.mark.parametrize("dtype", [torch.bfloat16])
|
||||||
|
@pytest.mark.parametrize("activation", ["silu_and_mul", "relu2"])
|
||||||
@torch.inference_mode()
|
@torch.inference_mode()
|
||||||
def test_flashinfer_fp4_moe_no_graph(
|
def test_flashinfer_fp4_moe_no_graph(
|
||||||
m: int, n: int, k: int, e: int, topk: int, dtype: torch.dtype
|
m: int, n: int, k: int, e: int, topk: int, dtype: torch.dtype, activation: str
|
||||||
):
|
):
|
||||||
current_platform.seed_everything(7)
|
current_platform.seed_everything(7)
|
||||||
with set_current_vllm_config(
|
with set_current_vllm_config(
|
||||||
@ -59,6 +60,7 @@ def test_flashinfer_fp4_moe_no_graph(
|
|||||||
a = torch.randn((m, k), device="cuda", dtype=dtype) / 10
|
a = torch.randn((m, k), device="cuda", dtype=dtype) / 10
|
||||||
|
|
||||||
quant_blocksize = 16
|
quant_blocksize = 16
|
||||||
|
is_gated_act = activation == "silu_and_mul"
|
||||||
|
|
||||||
w1_q, w2_q, quant_config = make_test_quant_config(
|
w1_q, w2_q, quant_config = make_test_quant_config(
|
||||||
e,
|
e,
|
||||||
@ -68,6 +70,7 @@ def test_flashinfer_fp4_moe_no_graph(
|
|||||||
quant_dtype="nvfp4",
|
quant_dtype="nvfp4",
|
||||||
block_shape=None,
|
block_shape=None,
|
||||||
per_act_token_quant=False,
|
per_act_token_quant=False,
|
||||||
|
make_gate=is_gated_act,
|
||||||
)
|
)
|
||||||
|
|
||||||
score = torch.randn((m, e), device="cuda", dtype=dtype)
|
score = torch.randn((m, e), device="cuda", dtype=dtype)
|
||||||
@ -76,16 +79,19 @@ def test_flashinfer_fp4_moe_no_graph(
|
|||||||
assert is_valid_flashinfer_cutlass_fused_moe(a, w1_q, w2_q)
|
assert is_valid_flashinfer_cutlass_fused_moe(a, w1_q, w2_q)
|
||||||
|
|
||||||
flashinfer_experts = FusedMoEModularKernel(
|
flashinfer_experts = FusedMoEModularKernel(
|
||||||
MoEPrepareAndFinalizeNoEP(),
|
create_flashinfer_prepare_finalize(use_dp=False, use_nvfp4=True),
|
||||||
FlashInferExperts(out_dtype=dtype, quant_config=quant_config),
|
FlashInferExperts(out_dtype=dtype, quant_config=quant_config),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
fi_activation = {"silu_and_mul": "silu", "relu2": "relu2_no_mul"}[activation]
|
||||||
|
|
||||||
flashinfer_output = flashinfer_experts(
|
flashinfer_output = flashinfer_experts(
|
||||||
hidden_states=a,
|
hidden_states=a,
|
||||||
w1=w1_q,
|
w1=w1_q,
|
||||||
w2=w2_q,
|
w2=w2_q,
|
||||||
topk_weights=topk_weights,
|
topk_weights=topk_weights,
|
||||||
topk_ids=topk_ids,
|
topk_ids=topk_ids,
|
||||||
|
activation=fi_activation,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Reference check:
|
# Reference check:
|
||||||
@ -103,7 +109,9 @@ def test_flashinfer_fp4_moe_no_graph(
|
|||||||
block_size=quant_blocksize,
|
block_size=quant_blocksize,
|
||||||
)
|
)
|
||||||
|
|
||||||
w1_d = torch.empty((e, 2 * n, k), device="cuda", dtype=dtype)
|
w1_d = torch.empty(
|
||||||
|
(e, (2 if is_gated_act else 1) * n, k), device="cuda", dtype=dtype
|
||||||
|
)
|
||||||
w2_d = torch.empty((e, k, n), device="cuda", dtype=dtype)
|
w2_d = torch.empty((e, k, n), device="cuda", dtype=dtype)
|
||||||
|
|
||||||
for idx in range(0, e):
|
for idx in range(0, e):
|
||||||
@ -124,7 +132,9 @@ def test_flashinfer_fp4_moe_no_graph(
|
|||||||
block_size=quant_blocksize,
|
block_size=quant_blocksize,
|
||||||
)
|
)
|
||||||
|
|
||||||
torch_output = torch_moe(a_in_dtype, w1_d, w2_d, score, topk)
|
torch_output = torch_moe(
|
||||||
|
a_in_dtype, w1_d, w2_d, score, topk, activation=activation
|
||||||
|
)
|
||||||
|
|
||||||
torch.testing.assert_close(
|
torch.testing.assert_close(
|
||||||
torch_output, flashinfer_output, atol=1e-1, rtol=1e-1
|
torch_output, flashinfer_output, atol=1e-1, rtol=1e-1
|
||||||
|
|||||||
@ -264,13 +264,20 @@ def make_test_weights(
|
|||||||
quant_dtype: torch.dtype | str | None = None,
|
quant_dtype: torch.dtype | str | None = None,
|
||||||
block_shape: list[int] | None = None,
|
block_shape: list[int] | None = None,
|
||||||
per_out_ch_quant: bool = False,
|
per_out_ch_quant: bool = False,
|
||||||
|
make_gate: bool = True,
|
||||||
) -> tuple[
|
) -> tuple[
|
||||||
tuple[torch.Tensor, torch.Tensor, torch.Tensor | None, torch.Tensor | None],
|
tuple[torch.Tensor, torch.Tensor, torch.Tensor | None, torch.Tensor | None],
|
||||||
tuple[torch.Tensor, torch.Tensor, torch.Tensor | None, torch.Tensor | None],
|
tuple[torch.Tensor, torch.Tensor, torch.Tensor | None, torch.Tensor | None],
|
||||||
]:
|
]:
|
||||||
return (
|
return (
|
||||||
make_test_weight(
|
make_test_weight(
|
||||||
e, 2 * n, k, in_dtype, quant_dtype, block_shape, per_out_ch_quant
|
e,
|
||||||
|
(2 if make_gate else 1) * n,
|
||||||
|
k,
|
||||||
|
in_dtype,
|
||||||
|
quant_dtype,
|
||||||
|
block_shape,
|
||||||
|
per_out_ch_quant,
|
||||||
),
|
),
|
||||||
make_test_weight(e, k, n, in_dtype, quant_dtype, block_shape, per_out_ch_quant),
|
make_test_weight(e, k, n, in_dtype, quant_dtype, block_shape, per_out_ch_quant),
|
||||||
)
|
)
|
||||||
@ -297,6 +304,7 @@ def make_test_quant_config(
|
|||||||
quant_dtype: torch.dtype | str | None = None,
|
quant_dtype: torch.dtype | str | None = None,
|
||||||
per_act_token_quant: bool = False,
|
per_act_token_quant: bool = False,
|
||||||
block_shape: list[int] | None = None,
|
block_shape: list[int] | None = None,
|
||||||
|
make_gate: bool = True,
|
||||||
) -> tuple[torch.Tensor, torch.Tensor, FusedMoEQuantConfig]:
|
) -> tuple[torch.Tensor, torch.Tensor, FusedMoEQuantConfig]:
|
||||||
(_, w1, w1_s, w1_gs), (_, w2, w2_s, w2_gs) = make_test_weights(
|
(_, w1, w1_s, w1_gs), (_, w2, w2_s, w2_gs) = make_test_weights(
|
||||||
e,
|
e,
|
||||||
@ -306,6 +314,7 @@ def make_test_quant_config(
|
|||||||
quant_dtype,
|
quant_dtype,
|
||||||
per_out_ch_quant=per_act_token_quant,
|
per_out_ch_quant=per_act_token_quant,
|
||||||
block_shape=block_shape,
|
block_shape=block_shape,
|
||||||
|
make_gate=make_gate,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Hacky/trivial scales for nvfp4.
|
# Hacky/trivial scales for nvfp4.
|
||||||
|
|||||||
@ -14,6 +14,7 @@ from torch._prims_common import TensorLikeType
|
|||||||
|
|
||||||
from tests.kernels.quant_utils import native_w8a8_block_matmul
|
from tests.kernels.quant_utils import native_w8a8_block_matmul
|
||||||
from vllm.attention.backends.abstract import AttentionType
|
from vllm.attention.backends.abstract import AttentionType
|
||||||
|
from vllm.model_executor.custom_op import CustomOp
|
||||||
from vllm.model_executor.layers.activation import SiluAndMul
|
from vllm.model_executor.layers.activation import SiluAndMul
|
||||||
from vllm.model_executor.layers.fused_moe.utils import moe_kernel_quantize_input
|
from vllm.model_executor.layers.fused_moe.utils import moe_kernel_quantize_input
|
||||||
from vllm.utils.torch_utils import make_tensor_with_pad
|
from vllm.utils.torch_utils import make_tensor_with_pad
|
||||||
@ -839,6 +840,7 @@ def torch_experts(
|
|||||||
per_act_token_quant=False,
|
per_act_token_quant=False,
|
||||||
block_shape: list[int] | None = None,
|
block_shape: list[int] | None = None,
|
||||||
apply_router_weights_on_input: bool = False,
|
apply_router_weights_on_input: bool = False,
|
||||||
|
activation: str = "silu_and_mul",
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
assert (
|
assert (
|
||||||
global_num_experts == -1
|
global_num_experts == -1
|
||||||
@ -881,6 +883,8 @@ def torch_experts(
|
|||||||
|
|
||||||
f32 = torch.float32
|
f32 = torch.float32
|
||||||
|
|
||||||
|
act = CustomOp.op_registry[activation]
|
||||||
|
|
||||||
for i in range(num_experts):
|
for i in range(num_experts):
|
||||||
mask = topk_ids == i
|
mask = topk_ids == i
|
||||||
if mask.sum():
|
if mask.sum():
|
||||||
@ -888,7 +892,7 @@ def torch_experts(
|
|||||||
tmp1 = a[mask] @ w1[i].transpose(0, 1)
|
tmp1 = a[mask] @ w1[i].transpose(0, 1)
|
||||||
if b_bias1 is not None:
|
if b_bias1 is not None:
|
||||||
tmp1 = tmp1 + b_bias1[i].view(1, -1).to(tmp1.dtype)
|
tmp1 = tmp1 + b_bias1[i].view(1, -1).to(tmp1.dtype)
|
||||||
tmp2 = SiluAndMul()(tmp1)
|
tmp2 = act()(tmp1)
|
||||||
out[mask] = tmp2 @ w2[i].transpose(0, 1)
|
out[mask] = tmp2 @ w2[i].transpose(0, 1)
|
||||||
if b_bias2 is not None:
|
if b_bias2 is not None:
|
||||||
out[mask] = out[mask] + b_bias2[i].view(1, -1).to(tmp1.dtype)
|
out[mask] = out[mask] + b_bias2[i].view(1, -1).to(tmp1.dtype)
|
||||||
@ -969,6 +973,7 @@ def torch_moe(
|
|||||||
b_bias2: torch.Tensor | None = None,
|
b_bias2: torch.Tensor | None = None,
|
||||||
global_num_experts: int = -1,
|
global_num_experts: int = -1,
|
||||||
expert_map: torch.Tensor | None = None,
|
expert_map: torch.Tensor | None = None,
|
||||||
|
activation: str = "silu_and_mul",
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
score = torch.softmax(score, dim=-1, dtype=torch.float32)
|
score = torch.softmax(score, dim=-1, dtype=torch.float32)
|
||||||
topk_weight, topk_ids = torch.topk(score, topk)
|
topk_weight, topk_ids = torch.topk(score, topk)
|
||||||
@ -982,6 +987,7 @@ def torch_moe(
|
|||||||
b_bias1,
|
b_bias1,
|
||||||
b_bias2,
|
b_bias2,
|
||||||
expert_map,
|
expert_map,
|
||||||
|
activation=activation,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@ -600,14 +600,20 @@ class FusedMoE(CustomOp):
|
|||||||
# Avoid circular import
|
# Avoid circular import
|
||||||
from vllm.model_executor.layers.quantization.modelopt import (
|
from vllm.model_executor.layers.quantization.modelopt import (
|
||||||
ModelOptFp8MoEMethod,
|
ModelOptFp8MoEMethod,
|
||||||
|
ModelOptNvFp4FusedMoE,
|
||||||
)
|
)
|
||||||
|
|
||||||
if not isinstance(
|
if not isinstance(
|
||||||
self.quant_method, (UnquantizedFusedMoEMethod, ModelOptFp8MoEMethod)
|
self.quant_method,
|
||||||
|
(
|
||||||
|
UnquantizedFusedMoEMethod,
|
||||||
|
ModelOptFp8MoEMethod,
|
||||||
|
ModelOptNvFp4FusedMoE,
|
||||||
|
),
|
||||||
):
|
):
|
||||||
raise NotImplementedError(
|
raise NotImplementedError(
|
||||||
"is_act_and_mul=False is supported only for unquantized "
|
"is_act_and_mul=False is supported only for unquantized "
|
||||||
"and ModelOpt FP8 moe for now"
|
", ModelOpt FP8, and ModelOpt NvFp4 checkpoints"
|
||||||
)
|
)
|
||||||
if not current_platform.is_cuda():
|
if not current_platform.is_cuda():
|
||||||
raise NotImplementedError(
|
raise NotImplementedError(
|
||||||
@ -1277,7 +1283,7 @@ class FusedMoE(CustomOp):
|
|||||||
self._load_combined_w13_weight_scale(
|
self._load_combined_w13_weight_scale(
|
||||||
shard_dim=shard_dim,
|
shard_dim=shard_dim,
|
||||||
loaded_weight=loaded_weight,
|
loaded_weight=loaded_weight,
|
||||||
param=param,
|
param=expert_data,
|
||||||
tp_rank=self.tp_rank,
|
tp_rank=self.tp_rank,
|
||||||
)
|
)
|
||||||
return True if return_success else None
|
return True if return_success else None
|
||||||
|
|||||||
@ -1216,7 +1216,7 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
|
|||||||
w13_weight = ModelWeightParameter(
|
w13_weight = ModelWeightParameter(
|
||||||
data=torch.empty(
|
data=torch.empty(
|
||||||
num_experts,
|
num_experts,
|
||||||
2 * intermediate_size_per_partition,
|
(2 if self.moe.is_act_and_mul else 1) * intermediate_size_per_partition,
|
||||||
# 2 fp4 items are packed in the input dimension
|
# 2 fp4 items are packed in the input dimension
|
||||||
hidden_size // 2,
|
hidden_size // 2,
|
||||||
dtype=weight_dtype,
|
dtype=weight_dtype,
|
||||||
@ -1245,7 +1245,7 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
|
|||||||
w13_weight_scale = ModelWeightParameter(
|
w13_weight_scale = ModelWeightParameter(
|
||||||
data=torch.empty(
|
data=torch.empty(
|
||||||
num_experts,
|
num_experts,
|
||||||
2 * intermediate_size_per_partition,
|
(2 if self.moe.is_act_and_mul else 1) * intermediate_size_per_partition,
|
||||||
# 2 fp4 items are packed in the input dimension
|
# 2 fp4 items are packed in the input dimension
|
||||||
hidden_size // self.quant_config.group_size,
|
hidden_size // self.quant_config.group_size,
|
||||||
dtype=weight_scale_dtype,
|
dtype=weight_scale_dtype,
|
||||||
@ -1275,7 +1275,9 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
|
|||||||
)
|
)
|
||||||
|
|
||||||
w13_weight_scale_2 = PerTensorScaleParameter(
|
w13_weight_scale_2 = PerTensorScaleParameter(
|
||||||
data=torch.empty(num_experts, 2, dtype=torch.float32),
|
data=torch.empty(
|
||||||
|
num_experts, 2 if self.moe.is_act_and_mul else 1, dtype=torch.float32
|
||||||
|
),
|
||||||
weight_loader=weight_loader,
|
weight_loader=weight_loader,
|
||||||
)
|
)
|
||||||
layer.register_parameter("w13_weight_scale_2", w13_weight_scale_2)
|
layer.register_parameter("w13_weight_scale_2", w13_weight_scale_2)
|
||||||
@ -1296,7 +1298,11 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
|
|||||||
global_scale_num_experts = global_num_experts if use_global_sf else num_experts
|
global_scale_num_experts = global_num_experts if use_global_sf else num_experts
|
||||||
|
|
||||||
w13_input_scale = PerTensorScaleParameter(
|
w13_input_scale = PerTensorScaleParameter(
|
||||||
data=torch.empty(global_scale_num_experts, 2, dtype=torch.float32),
|
data=torch.empty(
|
||||||
|
global_scale_num_experts,
|
||||||
|
2 if self.moe.is_act_and_mul else 1,
|
||||||
|
dtype=torch.float32,
|
||||||
|
),
|
||||||
weight_loader=weight_loader,
|
weight_loader=weight_loader,
|
||||||
)
|
)
|
||||||
layer.register_parameter("w13_input_scale", w13_input_scale)
|
layer.register_parameter("w13_input_scale", w13_input_scale)
|
||||||
@ -1312,9 +1318,13 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
|
|||||||
gemm1_weight = layer.w13_weight.data
|
gemm1_weight = layer.w13_weight.data
|
||||||
gemm1_weight_scale = layer.w13_weight_scale.data
|
gemm1_weight_scale = layer.w13_weight_scale.data
|
||||||
|
|
||||||
if self.allow_flashinfer and (
|
if (
|
||||||
self.flashinfer_moe_backend == FlashinferMoeBackend.CUTLASS
|
self.allow_flashinfer
|
||||||
or self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM
|
and (
|
||||||
|
self.flashinfer_moe_backend == FlashinferMoeBackend.CUTLASS
|
||||||
|
or self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM
|
||||||
|
)
|
||||||
|
and self.moe.is_act_and_mul
|
||||||
):
|
):
|
||||||
gemm1_weight, gemm1_weight_scale = reorder_w1w3_to_w3w1(
|
gemm1_weight, gemm1_weight_scale = reorder_w1w3_to_w3w1(
|
||||||
gemm1_weight, gemm1_weight_scale, dim=-2
|
gemm1_weight, gemm1_weight_scale, dim=-2
|
||||||
@ -1324,7 +1334,7 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
|
|||||||
layer.w13_weight_scale = Parameter(gemm1_weight_scale, requires_grad=False)
|
layer.w13_weight_scale = Parameter(gemm1_weight_scale, requires_grad=False)
|
||||||
|
|
||||||
# Common processing for w13_weight_scale_2
|
# Common processing for w13_weight_scale_2
|
||||||
if not torch.allclose(
|
if self.moe.is_act_and_mul and not torch.allclose(
|
||||||
layer.w13_weight_scale_2[:, 0], layer.w13_weight_scale_2[:, 1]
|
layer.w13_weight_scale_2[:, 0], layer.w13_weight_scale_2[:, 1]
|
||||||
):
|
):
|
||||||
logger.warning_once(
|
logger.warning_once(
|
||||||
@ -1437,11 +1447,39 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
|
|||||||
w13_blockscale_swizzled, requires_grad=False
|
w13_blockscale_swizzled, requires_grad=False
|
||||||
)
|
)
|
||||||
|
|
||||||
|
w13_weight = layer.w13_weight
|
||||||
|
intermediate_size_pad = w13_blockscale_swizzled.size(1) - w13_weight.size(1)
|
||||||
|
if intermediate_size_pad:
|
||||||
|
# padding gated activations will require to split w1 and w3
|
||||||
|
# and pad them individually
|
||||||
|
assert not self.moe.is_act_and_mul, (
|
||||||
|
"The intermediate size required padding, "
|
||||||
|
"but padding is not implemented for gated activations"
|
||||||
|
)
|
||||||
|
|
||||||
|
layer.w13_weight = Parameter(
|
||||||
|
torch.nn.functional.pad(
|
||||||
|
w13_weight, (0, 0, 0, intermediate_size_pad)
|
||||||
|
),
|
||||||
|
requires_grad=False,
|
||||||
|
)
|
||||||
|
layer.w2_weight = Parameter(
|
||||||
|
torch.nn.functional.pad(
|
||||||
|
layer.w2_weight, (0, intermediate_size_pad // 2, 0, 0)
|
||||||
|
),
|
||||||
|
requires_grad=False,
|
||||||
|
)
|
||||||
|
layer.w2_weight_scale = Parameter(
|
||||||
|
torch.nn.functional.pad(
|
||||||
|
layer.w2_weight_scale, (0, intermediate_size_pad // 16)
|
||||||
|
),
|
||||||
|
requires_grad=False,
|
||||||
|
)
|
||||||
|
|
||||||
w2_blockscale_swizzled = swizzle_blockscale(layer.w2_weight_scale)
|
w2_blockscale_swizzled = swizzle_blockscale(layer.w2_weight_scale)
|
||||||
layer.w2_weight_scale = Parameter(
|
layer.w2_weight_scale = Parameter(
|
||||||
w2_blockscale_swizzled, requires_grad=False
|
w2_blockscale_swizzled, requires_grad=False
|
||||||
)
|
)
|
||||||
layer.w2_weight = Parameter(layer.w2_weight.data, requires_grad=False)
|
|
||||||
|
|
||||||
def get_fused_moe_quant_config(
|
def get_fused_moe_quant_config(
|
||||||
self, layer: torch.nn.Module
|
self, layer: torch.nn.Module
|
||||||
@ -1484,7 +1522,14 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
|
|||||||
logical_to_physical_map: torch.Tensor | None = None,
|
logical_to_physical_map: torch.Tensor | None = None,
|
||||||
logical_replica_count: torch.Tensor | None = None,
|
logical_replica_count: torch.Tensor | None = None,
|
||||||
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
|
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
|
||||||
assert activation == "silu", "Only SiLU activation is supported."
|
if not self.moe.is_act_and_mul:
|
||||||
|
assert (
|
||||||
|
self.allow_flashinfer
|
||||||
|
and self.flashinfer_moe_backend == FlashinferMoeBackend.CUTLASS
|
||||||
|
), (
|
||||||
|
"Non-gated activations are only supported by the"
|
||||||
|
" flashinfer CUTLASS backend for modelopt checkpoints"
|
||||||
|
)
|
||||||
|
|
||||||
if (
|
if (
|
||||||
self.allow_flashinfer
|
self.allow_flashinfer
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user