diff --git a/tests/kernels/moe/test_batched_moe.py b/tests/kernels/moe/test_batched_moe.py index a3102d0cb8806..868bc8350bf11 100644 --- a/tests/kernels/moe/test_batched_moe.py +++ b/tests/kernels/moe/test_batched_moe.py @@ -268,7 +268,7 @@ def batched_moe( block_shape: Optional[list[int]] = None, per_act_token: bool = False, ) -> torch.Tensor: - max_num_tokens = round_up(a.shape[0], 64) # ? + max_num_tokens = round_up(a.shape[0], 64) fused_experts = FusedMoEModularKernel( BatchedPrepareAndFinalize(max_num_tokens, world_size=1, @@ -342,9 +342,9 @@ def torch_moe2( topk_weight.view(M, -1, 1).to(out.dtype)).sum(dim=1) -@pytest.mark.parametrize("m", [1, 33, 64, 222]) -@pytest.mark.parametrize("n", [128, 1024, 2048]) -@pytest.mark.parametrize("k", [128, 512, 1024]) +@pytest.mark.parametrize("m", [32, 45, 64]) #[1, 33, 64, 222]) +@pytest.mark.parametrize("n", [128, 512, 1024, 2048]) +@pytest.mark.parametrize("k", [128, 512, 1024, 2048]) @pytest.mark.parametrize("e", NUM_EXPERTS) @pytest.mark.parametrize("topk", TOP_KS) @pytest.mark.parametrize("dtype", [torch.float8_e4m3fn, torch.bfloat16]) diff --git a/tests/kernels/moe/test_pplx_moe.py b/tests/kernels/moe/test_pplx_moe.py index 44fb82a5f427a..f667864ca03f3 100644 --- a/tests/kernels/moe/test_pplx_moe.py +++ b/tests/kernels/moe/test_pplx_moe.py @@ -611,15 +611,12 @@ def pplx_moe( num_experts = w1.shape[0] block_size = block_shape[1] if block_shape is not None else 128 topk = topk_ids.shape[1] - if block_shape: - max_num_tokens = round_up(rank_chunk(a.shape[0], 0, world_size), block_shape[0]) - else: - max_num_tokens = rank_chunk(a.shape[0], 0, world_size) + max_num_tokens = round_up(rank_chunk(a.shape[0], 0, world_size), 64) if qtype is not None: a_dtype = qtype - #print(f"SCALE BYTES {hidden_dim} {block_size} {((hidden_dim + block_size - 1) * torch.float32.itemsize) // block_size}") - scale_bytes = 16 + # This is probably not right + scale_bytes = round_up(((hidden_dim + block_size - 1) // block_size) * torch.float32.itemsize, 16) else: a_dtype = a.dtype scale_bytes = 0 diff --git a/vllm/model_executor/layers/fused_moe/fused_batched_moe.py b/vllm/model_executor/layers/fused_moe/fused_batched_moe.py index f0fea396cc8de..624e948a67395 100644 --- a/vllm/model_executor/layers/fused_moe/fused_batched_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_batched_moe.py @@ -11,6 +11,7 @@ from vllm.model_executor.layers.fused_moe.fused_moe import ( get_config_dtype_str, try_get_optimal_moe_config) from vllm.model_executor.layers.fused_moe.utils import ( _resize_cache, moe_kernel_quantize_input) +from vllm.utils import round_up @triton.jit @@ -336,7 +337,7 @@ def invoke_moe_batched_triton_kernel( BLOCK_K = config['BLOCK_SIZE_K'] assert (torch.compiler.is_compiling() or torch.cuda.is_current_stream_capturing() - or max_num_tokens % BLOCK_M == 0) + or max_num_tokens % BLOCK_M == 0), f"{max_num_tokens} {BLOCK_M}" grid = (expert_num_tokens.size(0), triton.cdiv(max_num_tokens, BLOCK_M) * triton.cdiv(B.size(1), BLOCK_N)) @@ -666,7 +667,7 @@ class BatchedTritonExperts(mk.FusedMoEPermuteExpertsUnpermute): def __init__( self, - max_num_tokens: Optional[int] = None, + max_num_tokens: int, use_fp8_w8a8: bool = False, use_int8_w8a8: bool = False, use_int8_w8a16: bool = False, @@ -682,13 +683,13 @@ class BatchedTritonExperts(mk.FusedMoEPermuteExpertsUnpermute): self.use_int4_w4a16 = use_int4_w4a16 self.use_int8_w8a16 = use_int8_w8a16 self.block_shape = block_shape - self.max_num_tokens = max_num_tokens assert not use_int8_w8a8, "NYI" assert not use_int4_w4a16, "NYI" self.world_size = world_size self.dp_size = dp_size self.per_act_token = per_act_token self.qtype = torch.float8_e4m3fn if self.use_fp8_w8a8 else None + self.max_num_tokens = max_num_tokens def workspace_shapes( self, @@ -701,10 +702,8 @@ class BatchedTritonExperts(mk.FusedMoEPermuteExpertsUnpermute): ) -> tuple[int, int, torch.dtype]: assert a.dim() == 2 num_dp = self.world_size // self.dp_size - max_num_tokens = a.size( - 0) if self.max_num_tokens is None else self.max_num_tokens - workspace13 = num_experts * max_num_tokens * num_dp * max(K, N) - workspace2 = num_experts * max_num_tokens * num_dp * (N // 2) + workspace13 = num_experts * self.max_num_tokens * num_dp * max(K, N) + workspace2 = num_experts * self.max_num_tokens * num_dp * (N // 2) return (workspace13, workspace2, a.dtype) def apply( diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py index ff6ecffc7663a..8eef20c75c432 100644 --- a/vllm/model_executor/layers/fused_moe/layer.py +++ b/vllm/model_executor/layers/fused_moe/layer.py @@ -29,7 +29,7 @@ from vllm.model_executor.layers.quantization.base_config import ( from vllm.model_executor.utils import set_weight_attrs from vllm.platforms import current_platform from vllm.platforms.interface import CpuArchEnum -from vllm.utils import direct_register_custom_op +from vllm.utils import direct_register_custom_op, cdiv has_pplx = importlib.util.find_spec("pplx_kernels") is not None @@ -268,6 +268,15 @@ class FusedMoEMethodBase(QuantizeMethodBase): prepare_finalize = None if moe.use_pplx_kernels: + # For blocked per token: set to + # ceil_div(hidden_dim, block_size) * sizeof(float32) + # For per-token: set to sizeof(float32) + if moe.quant_dtype.itemsize == 1: + scale_bytes = (cdiv(moe.hidden_dim, moe.block_size) * + torch.float32.itemsize) + else: + scale_bytes = 0 + all_to_all_args = dict( max_num_tokens=moe.max_num_tokens, num_experts=moe.num_experts, @@ -278,12 +287,7 @@ class FusedMoEMethodBase(QuantizeMethodBase): dp_size=all2all_manager.tp_group.world_size, hidden_dim=moe.hidden_dim, hidden_dim_bytes=moe.hidden_dim * moe.quant_dtype.itemsize, - # For blocked per token: set to - # ceil_div(hidden_dim, block_size) * sizeof(float32) - # For per-token: set to sizeof(float32) - hidden_dim_scale_bytes=(0 if moe.quant_dtype.itemsize != 1 else ( - ((moe.hidden_dim + moe.block_size - 1) // moe.block_size) * - torch.float32.itemsize)), + hidden_dim_scale_bytes=scale_bytes, ) if not all2all_manager.internode: diff --git a/vllm/model_executor/layers/fused_moe/pplx_prepare_finalize.py b/vllm/model_executor/layers/fused_moe/pplx_prepare_finalize.py index 589629dbfe243..c9d24e16806d8 100644 --- a/vllm/model_executor/layers/fused_moe/pplx_prepare_finalize.py +++ b/vllm/model_executor/layers/fused_moe/pplx_prepare_finalize.py @@ -94,7 +94,7 @@ class PplxPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize): float32_size = torch.float32.itemsize block_size = (self.block_shape[0] if self.block_shape is not None else 1) * float32_size - expert_x_scale = torch.empty( + expert_x_scale = torch.zeros( ( num_experts, expert_x.size(1),