From 468d16654ab1eb3883ed79c78042d9edc6461baa Mon Sep 17 00:00:00 2001 From: Bill Nell Date: Wed, 28 May 2025 20:19:41 +0000 Subject: [PATCH] cleanup quantization Signed-off-by: Bill Nell --- tests/kernels/moe/test_batched_moe.py | 15 ++- .../layers/fused_moe/fused_batched_moe.py | 118 +++++++++++------- .../layers/fused_moe/fused_moe.py | 10 +- vllm/model_executor/layers/fused_moe/layer.py | 25 +--- .../model_executor/layers/quantization/fp8.py | 6 +- 5 files changed, 96 insertions(+), 78 deletions(-) diff --git a/tests/kernels/moe/test_batched_moe.py b/tests/kernels/moe/test_batched_moe.py index 2c3e32ca1d7c8..31991d4e680f5 100644 --- a/tests/kernels/moe/test_batched_moe.py +++ b/tests/kernels/moe/test_batched_moe.py @@ -270,8 +270,9 @@ def batched_moe( topk_ids: torch.Tensor, w1_scale: Optional[torch.Tensor] = None, w2_scale: Optional[torch.Tensor] = None, - use_fp8_w8a8: bool = False, + qtype: Optional[torch.dtype] = None, block_shape: Optional[list[int]] = None, + per_act_token: bool = False, ) -> torch.Tensor: max_num_tokens = round_up(a.shape[0], 64) # ? fused_experts = FusedMoEModularKernel( @@ -279,12 +280,13 @@ def batched_moe( world_size=1, dp_size=1, rank=0, - use_fp8_w8a8=use_fp8_w8a8, - block_shape=block_shape), + qtype=qtype, + block_shape=block_shape, + per_act_token=False), BatchedTritonExperts(max_num_tokens=max_num_tokens, dp_size=1, world_size=1, - use_fp8_w8a8=use_fp8_w8a8, + use_fp8_w8a8=qtype == torch.float8_e4m3fn, block_shape=block_shape)) return fused_experts(a, @@ -360,7 +362,7 @@ def torch_moe2( @pytest.mark.parametrize("k", [128, 512, 1024]) @pytest.mark.parametrize("e", NUM_EXPERTS) @pytest.mark.parametrize("topk", TOP_KS) -@pytest.mark.parametrize("dtype", [torch.torch.float8_e4m3fn, torch.bfloat16]) +@pytest.mark.parametrize("dtype", [torch.float8_e4m3fn, torch.bfloat16]) def test_fused_moe_batched_experts( m: int, n: int, @@ -378,6 +380,7 @@ def test_fused_moe_batched_experts( score = torch.randn((m, e), device="cuda", dtype=torch.bfloat16) use_fp8_w8a8 = dtype == torch.torch.float8_e4m3fn + qtype = dtype if dtype == torch.torch.float8_e4m3fn else None if use_fp8_w8a8: block_n, block_k = block_shape[0], block_shape[1] @@ -409,7 +412,7 @@ def test_fused_moe_batched_experts( baseline_output = torch_moe2(a, w1, w2, topk_weight, topk_ids, w1_s, w2_s, use_fp8_w8a8, block_shape) batched_output = batched_moe(a, w1, w2, topk_weight, topk_ids, w1_s, - w2_s, use_fp8_w8a8, block_shape) + w2_s, qtype, block_shape) torch.testing.assert_close(baseline_output, batched_output, diff --git a/vllm/model_executor/layers/fused_moe/fused_batched_moe.py b/vllm/model_executor/layers/fused_moe/fused_batched_moe.py index b1289ae0f53cd..8c575958b5b1b 100644 --- a/vllm/model_executor/layers/fused_moe/fused_batched_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_batched_moe.py @@ -9,9 +9,9 @@ import triton.language as tl import vllm.model_executor.layers.fused_moe.modular_kernel as mk from vllm.model_executor.layers.fused_moe.fused_moe import ( get_config_dtype_str, try_get_optimal_moe_config) -from vllm.model_executor.layers.fused_moe.utils import _resize_cache -from vllm.model_executor.layers.quantization.utils.fp8_utils import ( - per_token_group_quant_fp8) +from vllm.model_executor.layers.fused_moe.utils import ( + _resize_cache, + moe_kernel_quantize_input) @triton.jit @@ -47,6 +47,7 @@ def moe_mmk( compute_type: tl.constexpr, use_w8a8: tl.constexpr, use_w8a16: tl.constexpr): + offs_k = tl.arange(0, BLOCK_K) if use_w8a16: @@ -325,6 +326,7 @@ def 
invoke_moe_batched_triton_kernel( use_int4_w4a16: bool, config: dict[str, int], block_shape: Optional[list[int]] = None): + assert not use_int4_w4a16 max_num_tokens = A.size(1) K = A.size(2) @@ -393,15 +395,17 @@ class BatchedPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize): world_size: int, dp_size: int, rank: int, - use_fp8_w8a8: bool = False, + qtype: Optional[torch.dtype] = None, + per_act_token: bool = False, block_shape: Optional[list[int]] = None): super().__init__() self.world_size = world_size self.dp_size = dp_size self.rank = rank self.max_num_tokens = max_num_tokens - self.use_fp8_w8a8 = use_fp8_w8a8 + self.per_act_token = per_act_token self.block_shape = block_shape + self.qtype = qtype def prepare( self, @@ -445,10 +449,10 @@ class BatchedPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize): b_a1 = torch.zeros( (num_local_experts, self.max_num_tokens, hidden_dim), - dtype=torch.float8_e4m3fn if self.use_fp8_w8a8 else a1.dtype, + dtype=self.qtype if self.qtype is not None else a1.dtype, device=a1.device) - if self.use_fp8_w8a8: + if self.qtype is not None: k_tiles = (hidden_dim + block_k - 1) // block_k b_a1_scale = torch.zeros( (num_local_experts, self.max_num_tokens, k_tiles), @@ -465,10 +469,20 @@ class BatchedPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize): rows = torch.count_nonzero(topks.flatten()) rhs = a1[:topks.numel()][topks] idx = expert_id - first_expert - if self.use_fp8_w8a8: - # TODO: use _fp8_quantize - b_a1[idx, :rows, :], b_a1_scale[ - idx, :rows] = per_token_group_quant_fp8(rhs, block_k) + if self.qtype is not None: + if a1_scale is not None: + rhs_a1_scale = a1_scale[:topks.numel()][topks] + else: + rhs_a1_scale = None + b_a1[idx, :rows, :], b_a1_scale[idx, :rows] = ( + moe_kernel_quantize_input( + rhs, + rhs_a1_scale, + self.qtype, + self.per_act_token, + self.block_shape, + ) + ) else: b_a1[idx, :rows, :] = rhs @@ -524,7 +538,6 @@ class BatchedExperts(mk.FusedMoEPermuteExpertsUnpermute): block_m: Optional[int] = None, ): super().__init__() - #assert block_shape is None assert block_m is None assert not use_int8_w8a8, "NYI" assert not use_int8_w8a16, "NYI" @@ -615,6 +628,42 @@ class BatchedExperts(mk.FusedMoEPermuteExpertsUnpermute): return out +def batched_moe_kernel_quantize_input( + A: torch.Tensor, + A_scale: Optional[torch.Tensor], + num_tokens: int, + E: int, + N: int, + expert_num_tokens: torch.Tensor, + qtype: Optional[torch.dtype], + per_channel_quant: bool, + block_shape: Optional[list[int]] = None, +) -> tuple[torch.Tensor, Optional[torch.Tensor]]: + if qtype is not None: + assert block_shape is not None + A_q = torch.empty_like(A, dtype=qtype) + block_n, block_k = block_shape + n_tiles = ((N // 2) + block_n - 1) // block_n + scale_shape = (E, num_tokens, n_tiles) + A_q_scale = torch.empty(scale_shape, + dtype=torch.float32, + device=A.device) + for e in range(E): + num_tokens = expert_num_tokens[e] + if num_tokens > 0: + A_q[e, :num_tokens, :], tmp_scale = moe_kernel_quantize_input( + A[e, :num_tokens], + A_scale[e, :num_tokens] if A_scale else None, + qtype, + per_channel_quant, + [block_k, block_n]) + A_q_scale[e, :tmp_scale.shape[0]] = tmp_scale + + return A_q, A_q_scale + else: + return A, A_scale + + class BatchedTritonExperts(mk.FusedMoEPermuteExpertsUnpermute): """ A Triton based MoE expert class that operates on expert batched format, @@ -630,6 +679,7 @@ class BatchedTritonExperts(mk.FusedMoEPermuteExpertsUnpermute): use_int8_w8a16: bool = False, use_int4_w4a16: bool = False, block_shape: Optional[list[int]] = None, + per_act_token: 
bool = False, world_size: int = 1, dp_size: int = 1, ): @@ -644,6 +694,8 @@ class BatchedTritonExperts(mk.FusedMoEPermuteExpertsUnpermute): assert not use_int4_w4a16, "NYI" self.world_size = world_size self.dp_size = dp_size + self.per_act_token = per_act_token + self.qtype = torch.float8_e4m3fn if self.use_fp8_w8a8 else None def workspace_shapes( self, @@ -731,7 +783,6 @@ class BatchedTritonExperts(mk.FusedMoEPermuteExpertsUnpermute): raise ValueError( f"Unsupported compute_type: {hidden_states.dtype}") - #print(f"shape: E={E}, M={num_tokens}, N={N}, K={K}, top_k={top_k_num}") # We can reuse the memory between these because by the time we need # cache3, we're done with cache1 intermediate_cache1 = _resize_cache(workspace13, (E, num_tokens, N)) @@ -761,36 +812,17 @@ class BatchedTritonExperts(mk.FusedMoEPermuteExpertsUnpermute): self.activation(activation, intermediate_cache2.view(-1, N // 2), intermediate_cache1.view(-1, N)) - #qintermediate_cache2 = intermediate_cache2 - - # TODO (varun) : support w8a8 - #assert not self.use_fp8_w8a8 - if self.use_fp8_w8a8: - per_act_token = False - # TODO: reuse? - qintermediate_cache2 = torch.empty_like(intermediate_cache2, - dtype=torch.float8_e4m3fn) - block_n = self.block_shape[0] - n_tiles = ((N // 2) + block_n - 1) // block_n - scale_shape = (E, num_tokens, n_tiles) - a2q_scale = torch.empty(scale_shape, - dtype=torch.float32, - device=hidden_states.device) - for e in range(E): - num_tokens = expert_num_tokens[e] - if num_tokens > 0: - #qintermediate_cache2[e], tmp_scale = _fp8_quantize( - # intermediate_cache2[e], - # a2_scale[e] if a2_scale is not None else None, - # per_act_token, self.block_shape) - qintermediate_cache2[ - e, : - num_tokens, :], tmp_scale = per_token_group_quant_fp8( - intermediate_cache2[e, :num_tokens], block_n) - a2q_scale[e, :tmp_scale.shape[0]] = tmp_scale - else: - qintermediate_cache2 = intermediate_cache2 - a2q_scale = a2_scale + qintermediate_cache2, a2q_scale = batched_moe_kernel_quantize_input( + intermediate_cache2, + a2_scale, + num_tokens, + E, + N, + expert_num_tokens, + self.qtype, + self.per_act_token, + self.block_shape + ) invoke_moe_batched_triton_kernel(A=qintermediate_cache2, B=w2, diff --git a/vllm/model_executor/layers/fused_moe/fused_moe.py b/vllm/model_executor/layers/fused_moe/fused_moe.py index 78f8eb926dc83..b32a54b338445 100644 --- a/vllm/model_executor/layers/fused_moe/fused_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_moe.py @@ -1520,11 +1520,11 @@ class TritonExperts(mk.FusedMoEPermuteExpertsUnpermute): def __init__( self, - use_fp8_w8a8: bool, - use_int8_w8a8: bool, - use_int8_w8a16: bool, - use_int4_w4a16: bool, - per_channel_quant: bool, + use_fp8_w8a8: bool = False, + use_int8_w8a8: bool = False, + use_int8_w8a16: bool = False, + use_int4_w4a16: bool = False, + per_channel_quant: bool = False, block_shape: Optional[list[int]] = None, block_m: Optional[int] = None, ): diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py index 11210aeaebaf5..c03217b790c5e 100644 --- a/vllm/model_executor/layers/fused_moe/layer.py +++ b/vllm/model_executor/layers/fused_moe/layer.py @@ -192,7 +192,7 @@ class MoEConfig: num_local_experts: int moe_parallel_config: FusedMoEParallelConfig - in_dtype: torch.dtype # The activation type. + in_dtype: torch.dtype # The post quantization activation type. # TODO: add more quantization params, blocked, per-token, etc. 
block_size: int = 128 @@ -489,22 +489,10 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp): max_num_tokens=MOE_DP_CHUNK_SIZE, world_size=world_size, dp_size=dp_size, - use_fp8_w8a8=False, #moe.in_dtype == torch.float8_e4m3fn, - use_int8_w8a8=False, - use_int8_w8a16=False, - use_int4_w4a16=False, - block_shape=None, ) else: logger.debug("TritonExperts %s", self.moe) - experts = TritonExperts( - use_fp8_w8a8=False, - use_int8_w8a8=False, - use_int8_w8a16=False, - use_int4_w4a16=False, - block_shape=None, - per_channel_quant=False, - ) + experts = TritonExperts() self.fused_experts = FusedMoEModularKernel( prepare_finalize, @@ -827,8 +815,7 @@ class FusedMoE(torch.nn.Module): from vllm_hpu_extension.ops import DynamicFusedMOE self.hpu_fused_moe = DynamicFusedMOE(self.global_num_experts) - logger.debug(f"PARAM DTYPE = {params_dtype}") - #assert params_dtype.itemsize == 1 + logger.debug("Model dtype = %s", vllm_config.model_config.dtype) moe = MoEConfig( max_num_tokens=MOE_DP_CHUNK_SIZE, @@ -838,7 +825,6 @@ class FusedMoE(torch.nn.Module): num_local_experts=self.local_num_experts, moe_parallel_config=self.moe_parallel_config, in_dtype=moe.in_dtype, - max_num_tokens=MOE_DP_CHUNK_SIZE, ) self.moe_config = moe self.quant_config = quant_config @@ -877,15 +863,14 @@ class FusedMoE(torch.nn.Module): self.batched_hidden_states: Optional[torch.Tensor] = None self.batched_router_logits: Optional[torch.Tensor] = None if self.moe_parallel_config.use_pplx_kernels: - act_dtype = vllm_config.model_config.dtype self.batched_hidden_states = torch.zeros( (MOE_DP_CHUNK_SIZE, self.hidden_size), - dtype=act_dtype, + dtype=moe.in_dtype, device=torch.cuda.current_device()) self.batched_router_logits = torch.zeros( (MOE_DP_CHUNK_SIZE, self.global_num_experts), - dtype=act_dtype, + dtype=moe.in_dtype, device=torch.cuda.current_device()) @property diff --git a/vllm/model_executor/layers/quantization/fp8.py b/vllm/model_executor/layers/quantization/fp8.py index 7effd0e1ad24b..456c84496aba5 100644 --- a/vllm/model_executor/layers/quantization/fp8.py +++ b/vllm/model_executor/layers/quantization/fp8.py @@ -782,11 +782,9 @@ class Fp8MoEMethod(FusedMoEMethodBase): max_num_tokens=MOE_DP_CHUNK_SIZE, world_size=world_size, dp_size=dp_size, - use_fp8_w8a8=True, - use_int8_w8a8=False, - use_int8_w8a16=False, - use_int4_w4a16=False, + qtype=torch.float8_e4m3fn, block_shape=self.quant_config.weight_block_size, + per_act_token=False, #? ) else: logger.debug("TritonOrDeepGemmExperts(fp8)")
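
Usage sketch (illustrative): the hunks above replace the boolean use_fp8_w8a8
flag on BatchedPrepareAndFinalize with an optional quantization dtype (qtype)
plus a per_act_token flag, and route activation quantization through
moe_kernel_quantize_input / batched_moe_kernel_quantize_input. A minimal
construction mirroring the updated tests/kernels/moe/test_batched_moe.py
follows; the import path for FusedMoEModularKernel is an assumption (the diff
only references its module as `mk`), everything else follows the signatures
shown in the hunks:

    from typing import Optional

    import torch

    from vllm.model_executor.layers.fused_moe.fused_batched_moe import (
        BatchedPrepareAndFinalize, BatchedTritonExperts)
    # Assumed module for FusedMoEModularKernel (imported as `mk` in the diff).
    from vllm.model_executor.layers.fused_moe.modular_kernel import (
        FusedMoEModularKernel)


    def make_batched_fused_experts(
            max_num_tokens: int,
            qtype: Optional[torch.dtype] = None,
            block_shape: Optional[list[int]] = None) -> FusedMoEModularKernel:
        # qtype=None keeps activations in their original dtype;
        # torch.float8_e4m3fn selects fp8 w8a8 (the case use_fp8_w8a8 covered).
        prepare_finalize = BatchedPrepareAndFinalize(
            max_num_tokens,
            world_size=1,
            dp_size=1,
            rank=0,
            qtype=qtype,
            per_act_token=False,
            block_shape=block_shape)
        experts = BatchedTritonExperts(
            max_num_tokens=max_num_tokens,
            world_size=1,
            dp_size=1,
            use_fp8_w8a8=qtype == torch.float8_e4m3fn,
            block_shape=block_shape)
        return FusedMoEModularKernel(prepare_finalize, experts)

Passing a dtype rather than a per-format boolean keeps the prepare/finalize
stage format-agnostic, so additional quantized activation dtypes can be
supported later without introducing more flags.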