diff --git a/tests/kernels/test_block_fp8.py b/tests/kernels/test_block_fp8.py index c4488c0c6ff3..347319b303f4 100644 --- a/tests/kernels/test_block_fp8.py +++ b/tests/kernels/test_block_fp8.py @@ -360,7 +360,7 @@ def fp8_perm(m, idx): return m[idx, ...] -def test_moe_permute(a, a_s, topk_ids, num_groups, topk, block_m): +def _moe_permute(a, a_s, topk_ids, num_groups, topk, block_m): M, K = a.shape sorted_token_ids, m_indices, num_pad = moe_align_block_size( @@ -379,7 +379,7 @@ def test_moe_permute(a, a_s, topk_ids, num_groups, topk, block_m): return a, a_s, m_indices, inv_perm -def test_moe_unpermute(out, inv_perm, topk, K, topk_weight): +def _moe_unpermute(out, inv_perm, topk, K, topk_weight): M = topk_weight.shape[0] out = out[inv_perm, ...] tmp_out = out.view(-1, topk, K) @@ -401,8 +401,8 @@ def deep_gemm_w8a8_block_fp8_moe(M, K, a, w1, w2, w1_s, w2_s, score, topk, a_q, a_s = per_token_group_quant_fp8(a, block_m) - a_q, a_s, m_indices, inv_perm = test_moe_permute(a_q, a_s, topk_ids, - num_groups, topk, block_m) + a_q, a_s, m_indices, inv_perm = _moe_permute(a_q, a_s, topk_ids, + num_groups, topk, block_m) inter_out = torch.zeros((a_q.shape[0], N * 2), dtype=torch.bfloat16, @@ -419,7 +419,7 @@ def deep_gemm_w8a8_block_fp8_moe(M, K, a, w1, w2, w1_s, w2_s, score, topk, deep_gemm.m_grouped_gemm_fp8_fp8_bf16_nt_contiguous( (act_out_q, act_out_s), (w2, w2_s), out, m_indices) - final_out = test_moe_unpermute(out, inv_perm, topk, K, topk_weight) + final_out = _moe_unpermute(out, inv_perm, topk, K, topk_weight) return final_out