diff --git a/tests/kernels/moe/deepep_utils.py b/tests/kernels/moe/deepep_utils.py index 2bc9b657da859..117f1babdf62a 100644 --- a/tests/kernels/moe/deepep_utils.py +++ b/tests/kernels/moe/deepep_utils.py @@ -162,12 +162,14 @@ def make_deepep_ll_a2a(pg: ProcessGroup, low_latency_mode=True, num_qps_per_rank=deepep_ll_args.num_experts // pgi.world_size) + return DeepEPLLPrepareAndFinalize( buffer=buffer, world_size=pgi.world_size, dp_size=dp_size, max_tokens_per_rank=deepep_ll_args.max_tokens_per_rank, quant_dtype=q_dtype, + block_shape=block_shape, use_fp8_dispatch=deepep_ll_args.use_fp8_dispatch, ) @@ -185,4 +187,5 @@ def make_deepep_a2a(pg: ProcessGroup, block_shape) assert deepep_ll_args is not None - return make_deepep_ll_a2a(pg, pgi, dp_size, deepep_ll_args, q_dtype) + return make_deepep_ll_a2a(pg, pgi, dp_size, deepep_ll_args, q_dtype, + block_shape) diff --git a/tests/kernels/moe/test_deepep_deepgemm_moe.py b/tests/kernels/moe/test_deepep_deepgemm_moe.py index a1fdc1d5ff47b..2d7cf39a8cca5 100644 --- a/tests/kernels/moe/test_deepep_deepgemm_moe.py +++ b/tests/kernels/moe/test_deepep_deepgemm_moe.py @@ -1,6 +1,8 @@ # SPDX-License-Identifier: Apache-2.0 """ Test DeepEP + DeepGEMM integration +DeepGEMM are gemm kernels specialized for the +fp8 block-quantized case. """ import dataclasses @@ -33,10 +35,14 @@ except ImportError: if has_deep_ep: from vllm.model_executor.layers.fused_moe.deepep_ht_prepare_finalize import ( # noqa: E501 DeepEPHTPrepareAndFinalize) + from vllm.model_executor.layers.fused_moe.deepep_ll_prepare_finalize import ( # noqa: E501 + DeepEPLLPrepareAndFinalize) - from .deepep_utils import DeepEPHTArgs, make_deepep_a2a + from .deepep_utils import DeepEPHTArgs, DeepEPLLArgs, make_deepep_a2a if has_deep_gemm: + from vllm.model_executor.layers.fused_moe.batched_deep_gemm_moe import ( + BatchedDeepGemmExperts) from vllm.model_executor.layers.fused_moe.deep_gemm_moe import ( DeepGemmExperts) @@ -53,6 +59,13 @@ requires_deep_gemm = pytest.mark.skipif( P = ParamSpec("P") +def next_power_of_2(x): + import math + if x == 0: + return 1 + return 2**math.ceil(math.log2(x)) + + def per_block_cast_to_fp8( x: torch.Tensor, block_size_n: int = 128) -> tuple[torch.Tensor, torch.Tensor]: @@ -126,6 +139,9 @@ class TestConfig: n: int num_experts: int block_size: list[int] + # configs for testing low-latency kernels + low_latency: bool + use_fp8_dispatch: Optional[bool] = False @dataclasses.dataclass @@ -170,9 +186,43 @@ class TestTensors: config=config) -def make_modular_kernel(pg: ProcessGroup, pgi: ProcessGroupInfo, dp_size: int, - num_local_experts: int, q_dtype: Optional[torch.dtype], - block_shape: list[int]) -> FusedMoEModularKernel: +def make_ll_modular_kernel(pg: ProcessGroup, pgi: ProcessGroupInfo, + max_tokens_per_rank: int, dp_size: int, + hidden_size: int, q_dtype: Optional[torch.dtype], + test_config: TestConfig) -> FusedMoEModularKernel: + + assert test_config.low_latency + assert test_config.use_fp8_dispatch is not None + + a2a: DeepEPLLPrepareAndFinalize = make_deepep_a2a( + pg=pg, + pgi=pgi, + dp_size=dp_size, + deepep_ht_args=None, + deepep_ll_args=DeepEPLLArgs( + max_tokens_per_rank=max_tokens_per_rank, + hidden_size=hidden_size, + num_experts=test_config.num_experts, + use_fp8_dispatch=test_config.use_fp8_dispatch), + q_dtype=q_dtype, + block_shape=test_config.block_size) + + fused_experts = BatchedDeepGemmExperts(max_num_tokens=max_tokens_per_rank, + world_size=pgi.world_size, + dp_size=dp_size, + block_shape=test_config.block_size) + mk = 
FusedMoEModularKernel(prepare_finalize=a2a, + fused_experts=fused_experts) + return mk + + +def make_ht_modular_kernel(pg: ProcessGroup, pgi: ProcessGroupInfo, + dp_size: int, num_local_experts: int, + q_dtype: Optional[torch.dtype], + test_config: TestConfig) -> FusedMoEModularKernel: + + assert not test_config.low_latency + assert test_config.use_fp8_dispatch is None a2a: DeepEPHTPrepareAndFinalize = make_deepep_a2a( pg=pg, @@ -181,7 +231,7 @@ def make_modular_kernel(pg: ProcessGroup, pgi: ProcessGroupInfo, dp_size: int, deepep_ht_args=DeepEPHTArgs(num_local_experts=num_local_experts), deepep_ll_args=None, q_dtype=q_dtype, - block_shape=block_shape) + block_shape=test_config.block_size) fused_experts = DeepGemmExperts() mk = FusedMoEModularKernel(prepare_finalize=a2a, @@ -189,12 +239,42 @@ def make_modular_kernel(pg: ProcessGroup, pgi: ProcessGroupInfo, dp_size: int, return mk -def deep_ep_moe_impl(pg: ProcessGroup, pgi: ProcessGroupInfo, dp_size: int, - test_tensors: TestTensors, w1: torch.Tensor, - w2: torch.Tensor, w1_scale: Optional[torch.Tensor], - w2_scale: Optional[torch.Tensor], - num_experts: int) -> torch.Tensor: +def make_modular_kernel(pg: ProcessGroup, pgi: ProcessGroupInfo, dp_size: int, + num_local_experts: int, + test_tensors: TestTensors) -> FusedMoEModularKernel: + q_dtype = torch.float8_e4m3fn + test_config = test_tensors.config + + mk: FusedMoEModularKernel + # Make modular kernel + if test_config.low_latency: + max_tokens_per_rank = max( + 64, next_power_of_2(test_tensors.rank_tokens.size(0))) + hidden_size = test_tensors.rank_tokens.size(-1) + + mk = make_ll_modular_kernel(pg=pg, + pgi=pgi, + max_tokens_per_rank=max_tokens_per_rank, + dp_size=dp_size, + hidden_size=hidden_size, + q_dtype=q_dtype, + test_config=test_config) + else: + mk = make_ht_modular_kernel(pg, pgi, dp_size, num_local_experts, + q_dtype, test_config) + + return mk + + +def deepep_deepgemm_moe_impl(pg: ProcessGroup, pgi: ProcessGroupInfo, + dp_size: int, test_tensors: TestTensors, + w1: torch.Tensor, w2: torch.Tensor, + w1_scale: Optional[torch.Tensor], + w2_scale: Optional[torch.Tensor]) -> torch.Tensor: + + test_config = test_tensors.config + num_experts = test_config.num_experts num_local_experts = w1.size(0) def build_expert_map(): @@ -208,14 +288,17 @@ def deep_ep_moe_impl(pg: ProcessGroup, pgi: ProcessGroupInfo, dp_size: int, return expert_map.to(device=torch.cuda.current_device(), dtype=torch.int32) - q_dtype = torch.float8_e4m3fn - # Make modular kernel mk: FusedMoEModularKernel = make_modular_kernel( - pg, pgi, dp_size, num_local_experts, q_dtype, - test_tensors.config.block_size) + pg=pg, + pgi=pgi, + dp_size=dp_size, + num_local_experts=num_local_experts, + test_tensors=test_tensors) - a1_scale = test_tensors.rank_token_scales + # Low-Latency kernels can't dispatch scales. 
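+    # Quantization happens inside DeepEPLLPrepareAndFinalize._do_quant after
+    # dispatch, so the precomputed per-token scales are dropped here.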
+ a1_scale = (None + if test_config.low_latency else test_tensors.rank_token_scales) out = mk.forward(hidden_states=test_tensors.rank_tokens, w1=w1, @@ -258,7 +341,7 @@ def triton_impl(a: torch.Tensor, topk_ids: torch.Tensor, allow_deep_gemm=False) -def _deep_ep_moe( +def _test_deepep_deepgemm_moe( pgi: ProcessGroupInfo, dp_size: int, config: TestConfig, @@ -302,7 +385,7 @@ def _deep_ep_moe( w1_scale_ep = w1_scale[e_start:e_end] w2_scale_ep = w2_scale[e_start:e_end] - deepep_moe = deep_ep_moe_impl( + deepep_moe = deepep_deepgemm_moe_impl( pg, pgi, dp_size, @@ -311,7 +394,6 @@ def _deep_ep_moe( w2_ep, w1_scale_ep, w2_scale_ep, - config.num_experts, ) torch.testing.assert_close( @@ -335,15 +417,21 @@ MNKs = [ (222, 1024, 2048), ] +TOPKS = [2, 6] +NUM_EXPERTS = [32] + @pytest.mark.parametrize("mnk", MNKs) -@pytest.mark.parametrize("num_experts", [32]) -@pytest.mark.parametrize("topk", [2, 6]) +@pytest.mark.parametrize("num_experts", NUM_EXPERTS) +@pytest.mark.parametrize("topk", TOPKS) @pytest.mark.parametrize("world_dp_size", [(2, 1)]) @requires_deep_ep @requires_deep_gemm -def test_deep_ep_moe(mnk: tuple[int, int, int], num_experts: int, topk: int, - world_dp_size: tuple[int, int]): +def test_ht_deepep_deepgemm_moe(mnk: tuple[int, int, int], num_experts: int, + topk: int, world_dp_size: tuple[int, int]): + """ + Tests for High-Throughput DeepEP + DeepGemm integration. + """ m, n, k = mnk current_platform.seed_everything(7) @@ -354,6 +442,58 @@ def test_deep_ep_moe(mnk: tuple[int, int, int], num_experts: int, topk: int, block_m = deep_gemm.get_m_alignment_for_contiguous_layout() block_size = [block_m, block_m] + world_size, dp_size = world_dp_size + config = TestConfig(topk=topk, + m=m, + k=k, + n=n, + num_experts=num_experts, + block_size=block_size, + low_latency=False, + use_fp8_dispatch=None) + + w1, w2, w1_scale, w2_scale = make_block_quant_fp8_weights( + num_experts, n, k, block_size) + + parallel_launch(world_size, _test_deepep_deepgemm_moe, dp_size, config, w1, + w2, w1_scale, w2_scale) + + +MNKs = [ + (1, 128, 2560), + (2, 128, 2560), + (3, 1024, 2560), + (32, 128, 2560), + (45, 512, 2560), + (64, 1024, 2560), + (222, 1024, 2560), +] +# Fix tests for USE_FP8_DISPATCH=True +USE_FP8_DISPATCH = [False] + + +@pytest.mark.parametrize("mnk", MNKs) +@pytest.mark.parametrize("num_experts", NUM_EXPERTS) +@pytest.mark.parametrize("topk", TOPKS) +@pytest.mark.parametrize("use_fp8_dispatch", USE_FP8_DISPATCH) +@pytest.mark.parametrize("block_size", [[128, 128]]) +@pytest.mark.parametrize("world_dp_size", [(2, 1)]) +@requires_deep_ep +@requires_deep_gemm +def test_ll_deepep_deepgemm_moe(mnk: tuple[int, int, + int], num_experts: int, topk: int, + use_fp8_dispatch: bool, block_size: list[int], + world_dp_size: tuple[int, int]): + """ + Tests for Low-Latency DeepEP + DeepGemm integration. 
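+    The block_size is fixed to [128, 128], the only block shape
+    BatchedDeepGemmExperts accepts.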
+ """ + + m, n, k = mnk + current_platform.seed_everything(7) + + if topk > num_experts: + pytest.skip(f"Skipping test: topk={topk} > E={num_experts}") + world_size, dp_size = world_dp_size config = TestConfig( topk=topk, @@ -362,10 +502,12 @@ def test_deep_ep_moe(mnk: tuple[int, int, int], num_experts: int, topk: int, n=n, num_experts=num_experts, block_size=block_size, + low_latency=True, + use_fp8_dispatch=use_fp8_dispatch, ) w1, w2, w1_scale, w2_scale = make_block_quant_fp8_weights( num_experts, n, k, block_size) - parallel_launch(world_size, _deep_ep_moe, dp_size, config, w1, w2, - w1_scale, w2_scale) + parallel_launch(world_size, _test_deepep_deepgemm_moe, dp_size, config, w1, + w2, w1_scale, w2_scale) diff --git a/vllm/model_executor/layers/fused_moe/batched_deep_gemm_moe.py b/vllm/model_executor/layers/fused_moe/batched_deep_gemm_moe.py new file mode 100644 index 0000000000000..a541d46b14a92 --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/batched_deep_gemm_moe.py @@ -0,0 +1,124 @@ +# SPDX-License-Identifier: Apache-2.0 +import importlib.util +from typing import Optional + +import torch + +import vllm.model_executor.layers.fused_moe.modular_kernel as mk +from vllm.logger import init_logger +from vllm.model_executor.layers.fused_moe.utils import ( + _resize_cache, per_token_group_quant_fp8) + +logger = init_logger(__name__) + +has_deep_gemm = importlib.util.find_spec("deep_gemm") is not None + + +class BatchedDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute): + + # The Deep Gemm kernels only support block size of 128 + DEEPGEMM_BLOCK_SHAPE = 128 + + def __init__(self, max_num_tokens: int, world_size: int, dp_size: int, + block_shape: list[int]): + """ + max_num_tokens: Maximum number of tokens from a DP Rank + world_size: Number of EP ranks + dp_size: Number of data-parallel ranks + block_shape: Block quantization block shape + """ + super().__init__() + self.max_num_tokens = max_num_tokens + self.world_size = world_size + self.dp_size = dp_size + self.block_shape = block_shape + + assert (len(self.block_shape) == 2 and all( + [v == self.DEEPGEMM_BLOCK_SHAPE for v in self.block_shape])) + + def workspace_shapes( + self, + a: torch.Tensor, + M: int, + N: int, + K: int, + topk: int, + num_experts: int, + ) -> tuple[int, int, torch.dtype]: + assert a.dim() == 2 + num_dp = self.world_size // self.dp_size + max_num_tokens = a.size( + 0) if self.max_num_tokens is None else self.max_num_tokens + workspace13 = num_experts * max_num_tokens * num_dp * max(K, N) + workspace2 = num_experts * max_num_tokens * num_dp * (N // 2) + return (workspace13, workspace2, a.dtype) + + def apply( + self, + hidden_states: torch.Tensor, + w1: torch.Tensor, + w2: torch.Tensor, + topk_ids: torch.Tensor, + activation: str, + global_num_experts: int, + expert_map: Optional[torch.Tensor], + w1_scale: Optional[torch.Tensor], + w2_scale: Optional[torch.Tensor], + w1_zp: Optional[torch.Tensor], + w2_zp: Optional[torch.Tensor], + a1q_scale: Optional[torch.Tensor], + a2_scale: Optional[torch.Tensor], + workspace13: torch.Tensor, + workspace2: torch.Tensor, + expert_num_tokens: Optional[torch.Tensor], + ) -> torch.Tensor: + import deep_gemm as dg + assert hidden_states.ndim == 3 + + a1q = hidden_states + _, N, K = w1.size() + + if global_num_experts == -1: + global_num_experts = w1.size(0) + + assert w2.size(1) == K + + E, max_num_tokens, N, K, top_k_num = mk._moe_problem_size( + hidden_states, w1, w2, topk_ids) + + workspace1 = _resize_cache(workspace13, (E, max_num_tokens, N)) + workspace2 = 
_resize_cache(workspace2, (E, max_num_tokens, N // 2)) + workspace3 = _resize_cache(workspace13, (E, max_num_tokens, K)) + + # (from deepgemm docs) : A value hint (which is a value on CPU) + # for the M expectation of each batch, correctly setting this value + # may lead to better performance. + expected_m = max_num_tokens + + dg.m_grouped_gemm_fp8_fp8_bf16_nt_masked((a1q, a1q_scale), + (w1, w1_scale), + out=workspace1, + masked_m=expert_num_tokens, + expected_m=expected_m) + + # TODO (varun) [Optimization]: Use a batched version of activation. + # Similarly for the quant below. + self.activation(activation, workspace2, workspace1.view(-1, N)) + + w2_hidden_size = workspace2.size(-1) + workspace2 = workspace2.view(-1, w2_hidden_size) + + a2q_scale: Optional[torch.Tensor] = None + a2q, a2q_scale = per_token_group_quant_fp8(workspace2, + self.block_shape[1], + column_major_scales=False) + a2q = a2q.view(E, max_num_tokens, -1) + a2q_scale = a2q_scale.view(E, max_num_tokens, -1) + + dg.m_grouped_gemm_fp8_fp8_bf16_nt_masked((a2q, a2q_scale), + (w2, w2_scale), + out=workspace3, + masked_m=expert_num_tokens, + expected_m=expected_m) + + return workspace3 diff --git a/vllm/model_executor/layers/fused_moe/batched_triton_or_deep_gemm_moe.py b/vllm/model_executor/layers/fused_moe/batched_triton_or_deep_gemm_moe.py new file mode 100644 index 0000000000000..4db6b84e9d5bd --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/batched_triton_or_deep_gemm_moe.py @@ -0,0 +1,116 @@ +# SPDX-License-Identifier: Apache-2.0 +from typing import Optional + +import torch + +import vllm.model_executor.layers.fused_moe.modular_kernel as mk +from vllm.model_executor.layers.fused_moe.batched_deep_gemm_moe import ( + BatchedDeepGemmExperts) +from vllm.model_executor.layers.fused_moe.fused_batched_moe import ( + BatchedTritonExperts) + + +class BatchedTritonOrDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute): + + def __init__(self, + max_num_tokens: int, + world_size: int, + dp_size: int, + use_fp8_w8a8: bool = False, + use_int8_w8a8: bool = False, + use_int8_w8a16: bool = False, + use_int4_w4a16: bool = False, + per_channel_quant: bool = False, + block_shape: Optional[list[int]] = None, + allow_deep_gemm: bool = False): + super().__init__() + assert not use_int8_w8a8, "NYI" + assert not use_int8_w8a16, "NYI" + assert not use_int4_w4a16, "NYI" + + self.max_num_tokens = max_num_tokens + self.world_size = world_size + self.dp_size = dp_size + self.use_fp8_w8a8 = use_fp8_w8a8 + self.use_int8_w8a8 = use_int8_w8a8 + self.use_int8_w8a16 = use_int8_w8a16 + self.use_int4_w4a16 = use_int4_w4a16 + self.per_channel_quant = per_channel_quant + self.block_shape = block_shape + self.allow_deep_gemm = allow_deep_gemm + + # BatchedTritonKernel doesn't support block quantization + # at the moment. 
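+        # Construct the Triton experts only when there is no block
+        # quantization; the fp8 128x128 block-quantized case is routed to
+        # BatchedDeepGemmExperts below.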
+ self.batched_triton_experts = BatchedTritonExperts( + max_num_tokens=self.max_num_tokens, + use_fp8_w8a8=self.use_fp8_w8a8, + use_int8_w8a8=self.use_int8_w8a8, + use_int8_w8a16=self.use_int8_w8a16, + use_int4_w4a16=self.use_int4_w4a16, + per_channel_quant=self.per_channel_quant, + block_shape=self.block_shape, + world_size=self.world_size, + dp_size=self.dp_size) if self.block_shape is None else None + + is_fp8_128_block_quantized = (self.use_fp8_w8a8 + and self.block_shape is not None + and len(self.block_shape) == 2 and all( + [b == 128 + for b in self.block_shape])) + self.batched_deep_gemm_experts = BatchedDeepGemmExperts( + max_num_tokens=self.max_num_tokens, + world_size=self.world_size, + dp_size=self.dp_size, + block_shape=self.block_shape, # type: ignore[arg-type] + ) if (self.allow_deep_gemm and is_fp8_128_block_quantized) else None + + def workspace_shapes( + self, + a: torch.Tensor, + M: int, + N: int, + K: int, + topk: int, + num_experts: int, + ) -> tuple[int, int, torch.dtype]: + # Note: the deep gemm workspaces are strictly larger than the triton + # workspaces so we can be pessimistic here and allocate for DeepGemm + # even if we fall back to triton later, e.g. if expert maps are set. + if self.allow_deep_gemm and self.batched_deep_gemm_experts is not None: + return self.batched_deep_gemm_experts.workspace_shapes( + a, M, N, K, topk, num_experts) + else: + assert self.batched_triton_experts is not None + return self.batched_triton_experts.workspace_shapes( + a, M, N, K, topk, num_experts) + + def apply( + self, + hidden_states: torch.Tensor, + w1: torch.Tensor, + w2: torch.Tensor, + topk_ids: torch.Tensor, + activation: str, + global_num_experts: int, + expert_map: Optional[torch.Tensor], + w1_scale: Optional[torch.Tensor], + w2_scale: Optional[torch.Tensor], + w1_zp: Optional[torch.Tensor], + w2_zp: Optional[torch.Tensor], + a1q_scale: Optional[torch.Tensor], + a2_scale: Optional[torch.Tensor], + workspace13: torch.Tensor, + workspace2: torch.Tensor, + expert_num_tokens: Optional[torch.Tensor], + ) -> torch.Tensor: + use_batched_deep_gemm_experts = (self.allow_deep_gemm + and self.batched_deep_gemm_experts + is not None) + experts = (self.batched_deep_gemm_experts + if use_batched_deep_gemm_experts else + self.batched_triton_experts) + assert experts is not None + return experts.apply(hidden_states, w1, w2, topk_ids, activation, + global_num_experts, expert_map, w1_scale, + w2_scale, w1_zp, w2_zp, a1q_scale, a2_scale, + workspace13, workspace2, expert_num_tokens) diff --git a/vllm/model_executor/layers/fused_moe/deepep_ll_prepare_finalize.py b/vllm/model_executor/layers/fused_moe/deepep_ll_prepare_finalize.py index b9d817a14d57e..3484a7a8a496a 100644 --- a/vllm/model_executor/layers/fused_moe/deepep_ll_prepare_finalize.py +++ b/vllm/model_executor/layers/fused_moe/deepep_ll_prepare_finalize.py @@ -1,5 +1,5 @@ # SPDX-License-Identifier: Apache-2.0 -from typing import Optional +from typing import Optional, Union import deep_ep import torch @@ -65,6 +65,54 @@ class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize): def topk_indices_dtype(self) -> Optional[torch.dtype]: return torch.int64 + def _do_quant( + self, x: Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]], + a1_scale: Optional[torch.Tensor], a2_scale: Optional[torch.Tensor], + a1_dtype: torch.dtype + ) -> tuple[torch.Tensor, Optional[torch.Tensor]]: + + block_k = self.block_shape[1] if self.block_shape is not None else None + if self.use_fp8_dispatch: + if block_k == DEEPEP_QUANT_BLOCK_SIZE: + # 
DeepEP kernels did the quantization for us.
+                x, x_scales = x
+                return x, x_scales
+
+            # Dequant to get back the tokens in the datatype we dispatched in.
+            x_fp8, x_scales = x
+            x = dequant_fp8(x_fp8, x_scales).to(dtype=a1_dtype)
+
+        assert isinstance(x, torch.Tensor)
+
+        # Check if there is a block_shape, or if we can infer the quantization
+        # scheme from the scales.
+        per_token_quant = None
+        if all([v is None for v in [self.block_shape, a1_scale, a2_scale]
+                ]) and self.quant_dtype is not None:
+            # Quantization required despite none of the inputs suggesting
+            # quantization. Fall back to per-token dynamic quant.
+            per_token_quant = True
+        else:
+            per_token_quant = ((self.block_shape is not None) or
+                               (a1_scale is not None and a1_scale.numel() != 1)
+                               or (a2_scale is not None
+                                   and a2_scale.numel() != 1))
+
+        num_experts, max_tokens, hidden_dim = x.size()
+
+        # TODO (varun): Optimization - Use a batched version of quant
+        x = x.view((-1, hidden_dim))
+        x, x_scales = moe_kernel_quantize_input(x, a1_scale, self.quant_dtype,
+                                                per_token_quant,
+                                                self.block_shape)
+        x = x.view((num_experts, -1, hidden_dim))
+
+        if per_token_quant:
+            assert x_scales is not None
+            x_scales = x_scales.view(num_experts, max_tokens, -1)
+
+        return x, x_scales
+
     def prepare(
         self,
         a1: torch.Tensor,
@@ -87,11 +135,11 @@ class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
         assert hidden_size % 128 == 0, \
             "DeepEP kernels quantize the inputs in blocks of shape 128"

-        # Quantize
-        per_act_token = a1_scale.numel() != 1 if a1_scale is not None else (
+        has_per_token_scales = a1_scale.numel(
+        ) != 1 if a1_scale is not None else (
             a2_scale.numel() != 1 if a2_scale is not None else False)
-        assert not per_act_token, (
-            "low_latency kernels don't support per-act-token quant")
+        assert not has_per_token_scales, (
+            "low_latency kernels don't support dispatching per-token scales")

         if apply_router_weight_on_input:
             topk = rank_topk_ids.size(1)
@@ -110,22 +158,8 @@ class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
                                              async_finish=False,
                                              return_recv_hook=False)

-        if self.use_fp8_dispatch:
-            # TODO (varun) : In the case of dynamic quantization, we could
-            # probably skip the quant below and use the results directly.
-            # Although note that the deepep quant is per token 128 elements.
- expert_x_fp8, expert_x_scales = expert_x - expert_x = dequant_fp8(expert_x_fp8, - expert_x_scales).to(dtype=a1.dtype) - - num_experts = expert_x.size(0) - hidden_dim = expert_x.size(-1) - - expert_x = expert_x.view((-1, expert_x.size(-1))) - expert_x, expert_x_scale = moe_kernel_quantize_input( - expert_x, a1_scale, self.quant_dtype, per_act_token, - self.block_shape) - expert_x = expert_x.view((num_experts, -1, hidden_dim)) + expert_x, expert_x_scale = self._do_quant(expert_x, a1_scale, a2_scale, + a1.dtype) return (expert_x, expert_x_scale, expert_num_tokens, None, None) diff --git a/vllm/model_executor/layers/quantization/fp8.py b/vllm/model_executor/layers/quantization/fp8.py index 2438ec30bdd2b..5ac22b6a0aee4 100644 --- a/vllm/model_executor/layers/quantization/fp8.py +++ b/vllm/model_executor/layers/quantization/fp8.py @@ -771,21 +771,21 @@ class Fp8MoEMethod(FusedMoEMethodBase): def select_gemm_impl(self, prepare_finalize): - from vllm.model_executor.layers.fused_moe.fused_batched_moe import ( - BatchedTritonExperts) + from vllm.model_executor.layers.fused_moe.batched_triton_or_deep_gemm_moe import ( # noqa: E501 + BatchedTritonOrDeepGemmExperts) from vllm.model_executor.layers.fused_moe.triton_deep_gemm_moe import ( TritonOrDeepGemmExperts) assert not self.use_marlin and not self.rocm_aiter_moe_enabled, ( "Marlin and ROCm AITER are not supported with all2all yet.") - experts: Optional[Union[BatchedTritonExperts, + experts: Optional[Union[BatchedTritonOrDeepGemmExperts, TritonOrDeepGemmExperts]] = None max_num_tokens_per_rank = prepare_finalize.max_num_tokens_per_rank() use_batched_experts = max_num_tokens_per_rank is not None if use_batched_experts: - experts = BatchedTritonExperts( + experts = BatchedTritonOrDeepGemmExperts( max_num_tokens=max_num_tokens_per_rank, world_size=prepare_finalize.world_size, dp_size=prepare_finalize.dp_size, @@ -793,7 +793,9 @@ class Fp8MoEMethod(FusedMoEMethodBase): use_int8_w8a8=False, use_int8_w8a16=False, use_int4_w4a16=False, - block_shape=None, + per_channel_quant=False, + block_shape=self.quant_config.weight_block_size, + allow_deep_gemm=self.allow_deep_gemm, ) else: experts = TritonOrDeepGemmExperts(