diff --git a/tests/lora/test_fused_moe_lora_kernel.py b/tests/lora/test_fused_moe_lora_kernel.py
index f9a66d4d02ea..0ae992ad1110 100644
--- a/tests/lora/test_fused_moe_lora_kernel.py
+++ b/tests/lora/test_fused_moe_lora_kernel.py
@@ -154,6 +154,7 @@ def use_fused_moe_lora_kernel(
         "BLOCK_SIZE_N": 32,
         "BLOCK_SIZE_K": 64,
         "GROUP_SIZE_M": 1,
+        "SPLIT_K": 1,
     }

     mul_routed_weight = False
@@ -175,6 +176,7 @@ def use_fused_moe_lora_kernel(
         config["BLOCK_SIZE_N"],
         config["BLOCK_SIZE_K"],
         config["GROUP_SIZE_M"],
+        config["SPLIT_K"],
         mul_routed_weight,
     )

diff --git a/vllm/lora/ops/triton_ops/fused_moe_lora_op.py b/vllm/lora/ops/triton_ops/fused_moe_lora_op.py
index d8746ebc8e75..2031ade64b5f 100644
--- a/vllm/lora/ops/triton_ops/fused_moe_lora_op.py
+++ b/vllm/lora/ops/triton_ops/fused_moe_lora_op.py
@@ -80,11 +80,13 @@ def _fused_moe_lora_kernel(
     BLOCK_SIZE_N: tl.constexpr,
     BLOCK_SIZE_K: tl.constexpr,
     GROUP_SIZE_M: tl.constexpr,
+    SPLIT_K: tl.constexpr,
 ):
     pid = tl.program_id(axis=0)
     slice_id = tl.program_id(axis=1)
     lora_idx = tl.program_id(axis=2)
     max_loras = tl.num_programs(axis=2)
+    grid_k = tl.cdiv(K, BLOCK_SIZE_K * SPLIT_K)

     # calculate pid_m,pid_n
     num_pid_m = tl.cdiv(EM, BLOCK_SIZE_M)
@@ -102,7 +104,7 @@ def _fused_moe_lora_kernel(

     # get the expert_id to process curr shard
     ind = lora_idx * stride_el + pid_m
-    expert_id = tl.load(expert_ids_ptr + ind)
+    expert_id = tl.load(expert_ids_ptr + ind, ind < max_loras * stride_el, -1)
     if expert_id == -1:
         return

@@ -117,7 +119,7 @@ def _fused_moe_lora_kernel(
     offs_token_id = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M).to(tl.int64)
     token_ind = stride_tl * lora_idx + offs_token_id
     offs_token = tl.load(
-        sorted_token_ids_ptr + token_ind, token_ind < max_loras * stride_tl, 0.0
+        sorted_token_ids_ptr + token_ind, token_ind < max_loras * stride_tl, 0
     )
     token_mask = offs_token < num_valid_tokens

@@ -135,17 +137,18 @@ def _fused_moe_lora_kernel(

     # accumulator
     accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
-    for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
+    for k in range(0, grid_k):
+        k_remaining = K - k * (BLOCK_SIZE_K * SPLIT_K)
         a = tl.load(
             a_ptrs,
-            mask=token_mask[:, None] & (offs_k[None, :] < K - k * BLOCK_SIZE_K),
+            mask=token_mask[:, None] & (offs_k[None, :] < k_remaining),
             other=0.0,
         )
-        b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0)
+        b = tl.load(b_ptrs, mask=offs_k[:, None] < k_remaining, other=0.0)
         accumulator += tl.dot(a, b)
         # Advance the ptrs to the next K block.
-        a_ptrs += BLOCK_SIZE_K * stride_ak
-        b_ptrs += BLOCK_SIZE_K * stride_bk
+        a_ptrs += BLOCK_SIZE_K * SPLIT_K * stride_ak
+        b_ptrs += BLOCK_SIZE_K * SPLIT_K * stride_bk

     if MUL_ROUTED_WEIGHT:
         moe_weight = tl.load(topk_weights_ptr + offs_token, mask=token_mask, other=0)
@@ -156,7 +159,10 @@ def _fused_moe_lora_kernel(
     offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
     c_ptrs = cur_c_ptr + stride_cm * offs_token[:, None] + stride_cn * offs_cn[None, :]
     c_mask = token_mask[:, None] & (offs_cn[None, :] < N)
-    tl.store(c_ptrs, accumulator, mask=c_mask)
+    if SPLIT_K == 1:
+        tl.store(c_ptrs, accumulator, mask=c_mask)
+    else:
+        tl.atomic_add(c_ptrs, accumulator, mask=c_mask, sem="relaxed")


 @torch.inference_mode()
@@ -179,6 +185,7 @@ def _fused_moe_lora(
     block_size_n: int,
     block_size_k: int,
     group_size_m: int,
+    split_k: int,
     mul_routed_weight: bool = False,
 ) -> None:
     assert len(lora_a_stacked) == len(lora_b_stacked) > 0
@@ -206,6 +213,7 @@ def _fused_moe_lora(
         "BLOCK_SIZE_N": block_size_n,
         "BLOCK_SIZE_K": block_size_k,
         "GROUP_SIZE_M": group_size_m,
+        "SPLIT_K": split_k,
     }

     w1_lora_a_stacked = lora_a_stacked[0]
@@ -237,7 +245,9 @@ def _fused_moe_lora(
     b_ptr = _get_ptr(lora_a_stacked, device)

     grid = lambda META: (
-        triton.cdiv(EM, META["BLOCK_SIZE_M"]) * triton.cdiv(N, META["BLOCK_SIZE_N"]),
+        split_k
+        * triton.cdiv(EM, META["BLOCK_SIZE_M"])
+        * triton.cdiv(N, META["BLOCK_SIZE_N"]),
         len(lora_a_stacked),
         lora_a_stacked[0].shape[0],
     )
@@ -286,6 +296,8 @@ def _fused_moe_lora(
         -1, a_intermediate_cache1.shape[3]
     )

+    # Set split_k = 1 for expand calls
+    config["SPLIT_K"] = 1
     grid = lambda META: (
         triton.cdiv(EM, META["BLOCK_SIZE_M"]) * triton.cdiv(N, META["BLOCK_SIZE_N"]),
         len(lora_b_stacked),
diff --git a/vllm/lora/punica_wrapper/punica_gpu.py b/vllm/lora/punica_wrapper/punica_gpu.py
index c2c26a01ee03..0cbf294cf410 100644
--- a/vllm/lora/punica_wrapper/punica_gpu.py
+++ b/vllm/lora/punica_wrapper/punica_gpu.py
@@ -385,5 +385,6 @@ class PunicaWrapperGPU(PunicaWrapperBase):
             config["BLOCK_SIZE_N"],
             config["BLOCK_SIZE_K"],
             config["GROUP_SIZE_M"],
+            config.get("SPLIT_K", 1),
             mul_routed_weight,
         )
diff --git a/vllm/model_executor/layers/fused_moe/fused_moe.py b/vllm/model_executor/layers/fused_moe/fused_moe.py
index 89e92edc8d2b..5f9bfd6d9cf7 100644
--- a/vllm/model_executor/layers/fused_moe/fused_moe.py
+++ b/vllm/model_executor/layers/fused_moe/fused_moe.py
@@ -121,6 +121,7 @@ def fused_moe_kernel_gptq_awq(
     BLOCK_SIZE_N: tl.constexpr,
     BLOCK_SIZE_K: tl.constexpr,
     GROUP_SIZE_M: tl.constexpr,
+    SPLIT_K: tl.constexpr,
     MUL_ROUTED_WEIGHT: tl.constexpr,
     top_k: tl.constexpr,
     compute_type: tl.constexpr,
@@ -356,6 +357,7 @@ def fused_moe_kernel(
     BLOCK_SIZE_N: tl.constexpr,
     BLOCK_SIZE_K: tl.constexpr,
     GROUP_SIZE_M: tl.constexpr,
+    SPLIT_K: tl.constexpr,
     MUL_ROUTED_WEIGHT: tl.constexpr,
     top_k: tl.constexpr,
     compute_type: tl.constexpr,
@@ -646,7 +648,6 @@ def invoke_fused_moe_kernel(
                 bit,
             )
             return
-
         fused_moe_kernel_gptq_awq[grid](
             A,
             B,
@@ -686,6 +687,7 @@ def invoke_fused_moe_kernel(
         )
     else:
         config = config.copy()
+        config["SPLIT_K"] = 1
         BLOCK_SIZE_K = config.pop("BLOCK_SIZE_K")
         if block_shape is not None:
             BLOCK_SIZE_K = min(BLOCK_SIZE_K, min(block_shape[0], block_shape[1]))
@@ -983,6 +985,7 @@ def get_default_config(
             "BLOCK_SIZE_N": 64,
             "BLOCK_SIZE_K": 32,
             "GROUP_SIZE_M": 8,
+            "SPLIT_K": 1,
         }
         return config

@@ -996,6 +999,7 @@ def get_default_config(
             "BLOCK_SIZE_N": block_shape[0],
             "BLOCK_SIZE_K": block_shape[1],
             "GROUP_SIZE_M": 32,
+            "SPLIT_K": 1,
             "num_warps": 4,
             "num_stages": 3 if not current_platform.is_rocm() else 2,
         }
@@ -1006,19 +1010,20 @@ def get_default_config(
         bit = 4 if dtype == "int4_w4a16" else 8
         use_moe_wna16_cuda = should_moe_wna16_use_cuda(M * topk, block_shape[1], E, bit)
         if use_moe_wna16_cuda:
-            config = {"BLOCK_SIZE_M": min(16, M)}
+            config = {"BLOCK_SIZE_M": min(16, M), "SPLIT_K": 1}
         elif M <= 20:
-            config = {"BLOCK_SIZE_M": 16, "GROUP_SIZE_M": 1}
+            config = {"BLOCK_SIZE_M": 16, "GROUP_SIZE_M": 1, "SPLIT_K": 1}
         elif M <= 40:
-            config = {"BLOCK_SIZE_M": 32, "GROUP_SIZE_M": 1}
+            config = {"BLOCK_SIZE_M": 32, "GROUP_SIZE_M": 1, "SPLIT_K": 1}
         else:
-            config = {"BLOCK_SIZE_M": 64, "GROUP_SIZE_M": 1}
+            config = {"BLOCK_SIZE_M": 64, "GROUP_SIZE_M": 1, "SPLIT_K": 1}
     elif M <= E:
         config = {
             "BLOCK_SIZE_M": 16,
             "BLOCK_SIZE_N": 32,
             "BLOCK_SIZE_K": 64,
             "GROUP_SIZE_M": 1,
+            "SPLIT_K": 1,
         }
     else:
         config = {
@@ -1026,6 +1031,7 @@
             "BLOCK_SIZE_N": 64,
             "BLOCK_SIZE_K": 32,
             "GROUP_SIZE_M": 8,
+            "SPLIT_K": 1,
         }
     return config
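Note (not part of the patch): the diff threads a `SPLIT_K` compile-time constant through `_fused_moe_lora_kernel` and its callers so the K dimension of the shrink GEMM can be split across additional grid programs; when `SPLIT_K > 1` each program contributes only a partial sum, which is why the kernel switches from `tl.store` to `tl.atomic_add` for its output tile, and why the expand path forces `SPLIT_K = 1`. The NumPy sketch below is only an illustration of that split-K accumulation idea under the assumption that the destination buffer starts at zero; it is not vLLM code, and the helper name `split_k_matmul` is made up.

```python
import numpy as np


def split_k_matmul(a: np.ndarray, b: np.ndarray, split_k: int) -> np.ndarray:
    """Reference split-K matmul: each K slice adds its partial product."""
    M, K = a.shape
    K2, N = b.shape
    assert K == K2
    # The output must start at zero because every split only *adds* its share,
    # mirroring the tl.atomic_add path taken by the kernel when SPLIT_K > 1.
    out = np.zeros((M, N), dtype=np.float32)
    chunk = (K + split_k - 1) // split_k  # ceil-divide K across the splits
    for sk in range(split_k):  # in Triton this is a grid axis, not a serial loop
        lo, hi = sk * chunk, min((sk + 1) * chunk, K)
        out += a[:, lo:hi] @ b[lo:hi, :]  # serialized stand-in for the atomic add
    return out


rng = np.random.default_rng(0)
a = rng.standard_normal((8, 96)).astype(np.float32)
b = rng.standard_normal((96, 16)).astype(np.float32)
assert np.allclose(split_k_matmul(a, b, split_k=4), a @ b, atol=1e-4)
```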