diff --git a/vllm/model_executor/layers/moe/fused_batched_moe.py b/vllm/model_executor/layers/moe/fused_batched_moe.py
new file mode 100644
index 0000000000000..a5e5c06d864d1
--- /dev/null
+++ b/vllm/model_executor/layers/moe/fused_batched_moe.py
@@ -0,0 +1,463 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+"""Fused batched MoE kernel."""
+
+import torch
+
+from vllm.model_executor.layers.fused_moe.utils import (
+    normalize_batched_scales_shape,
+)
+from vllm.triton_utils import tl, triton
+
+
+@triton.jit
+def moe_mmk(
+    a_ptrs,
+    b_ptrs,
+    K,
+    expert_id,
+    a_scale_ptr,
+    b_scale_ptr,
+    # The stride variables represent how much to increase the ptr by when
+    # moving by 1 element in a particular dimension. E.g. `stride_am` is
+    # how much to increase `a_ptr` by to get the element one row down
+    # (A has M rows).
+    stride_ak: tl.int64,
+    stride_bk: tl.int64,
+    stride_ase: tl.int64,
+    stride_asm: tl.int64,
+    stride_ask: tl.int64,
+    stride_bse: tl.int64,
+    stride_bsk: tl.int64,
+    stride_bsn: tl.int64,
+    # Offsets and masks
+    offs_m,
+    offs_n,
+    offs_bn,
+    mask_m,
+    # Block size for block-wise quantization
+    group_n: tl.constexpr,
+    group_k: tl.constexpr,
+    # Meta-parameters
+    BLOCK_M: tl.constexpr,
+    BLOCK_N: tl.constexpr,
+    BLOCK_K: tl.constexpr,
+    compute_type: tl.constexpr,
+    use_w8a8: tl.constexpr,
+    use_w8a16: tl.constexpr,
+    per_act_token_quant: tl.constexpr,
+):
+    offs_k = tl.arange(0, BLOCK_K)
+
+    if use_w8a16:
+        b_scale_ptrs = (
+            b_scale_ptr + expert_id * stride_bse + offs_n[None, :] * stride_bsn
+        )
+        b_scale = tl.load(b_scale_ptrs)
+
+    if use_w8a8:
+        # block-wise
+        if group_k > 0 and group_n > 0:
+            a_scale_ptrs = a_scale_ptr + offs_m * stride_asm
+            offs_bsn = offs_bn // group_n
+            b_scale_ptrs = b_scale_ptr + offs_bsn * stride_bsn
+
+        # per act token
+        elif per_act_token_quant:
+            # Load per-token scale for activations
+            a_scale_ptrs = a_scale_ptr + offs_m * stride_asm
+            a_scale = tl.load(a_scale_ptrs, mask=mask_m, other=0.0)[:, None]
+
+            b_scale_ptrs = b_scale_ptr + offs_bn[None, :] * stride_bsn
+            b_scale = tl.load(b_scale_ptrs)
+
+        # tensor-wise
+        else:
+            a_scale = tl.load(a_scale_ptr)
+            b_scale = tl.load(b_scale_ptr)
+
+    # -----------------------------------------------------------
+    # Iterate to compute a block of the C matrix.
+    # We accumulate into a `[BLOCK_SIZE_M, BLOCK_SIZE_N]` block
+    # of fp32 values for higher accuracy.
+    # `accumulator` will be converted back to fp16 after the loop.
+    accumulator = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
+    for k in range(0, tl.cdiv(K, BLOCK_K)):
+        # Load the next block of A and B, generate a mask by checking the
+        # K dimension.
+        a = tl.load(
+            a_ptrs,
+            mask=mask_m[:, None] & (offs_k[None, :] < K - k * BLOCK_K),
+            other=0.0,
+        )
+        b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_K, other=0.0)
+        # We accumulate along the K dimension.
+        if use_w8a16:
+            accumulator = tl.dot(a, b.to(compute_type), acc=accumulator)
+        elif use_w8a8:
+            if group_k > 0 and group_n > 0:
+                k_start = k * BLOCK_K
+                offs_ks = k_start // group_k
+                a_scale = tl.load(
+                    a_scale_ptrs + offs_ks * stride_ask, mask=mask_m, other=0.0
+                )
+                b_scale = tl.load(b_scale_ptrs + offs_ks * stride_bsk)
+
+                accumulator += tl.dot(a, b) * a_scale[:, None] * b_scale[None, :]
+            else:
+                # acc used to enable fp8_fast_accum
+                accumulator = tl.dot(a, b, acc=accumulator)
+        else:
+            accumulator += tl.dot(a, b)
+
+        # Advance the ptrs to the next K block.
+        a_ptrs += BLOCK_K * stride_ak
+        b_ptrs += BLOCK_K * stride_bk
+
+    if use_w8a16:
+        accumulator = (accumulator * b_scale).to(compute_type)
+    elif use_w8a8:
+        if group_k > 0 and group_n > 0:
+            accumulator = accumulator.to(compute_type)
+        else:
+            accumulator = (accumulator * a_scale * b_scale).to(compute_type)
+    else:
+        accumulator = accumulator.to(compute_type)
+
+    return accumulator
+
+
+@triton.jit
+def expert_triton_kernel(
+    a_ptr,  # [max_tokens, K]
+    b_ptr,  # [K, N]
+    c_ptr,  # [max_tokens, N]
+    expert_id,
+    compute_type: tl.constexpr,
+    # Dimensions
+    M,
+    N,
+    K,
+    # Quantization data
+    a_scale_ptr,
+    b_scale_ptr,
+    b_zp_ptr,
+    # strides
+    stride_am: tl.int64,
+    stride_ak: tl.int64,
+    stride_bk: tl.int64,
+    stride_bn: tl.int64,
+    stride_cm: tl.int64,
+    stride_cn: tl.int64,
+    stride_ase: tl.int64,
+    stride_asm: tl.int64,
+    stride_ask: tl.int64,
+    stride_bse: tl.int64,
+    stride_bsk: tl.int64,
+    stride_bsn: tl.int64,
+    # offsets
+    offs_bn,
+    # Blockwise quantization data
+    group_n,
+    group_k,
+    # Quantization schemes
+    use_fp8_w8a8: tl.constexpr,
+    use_int8_w8a16: tl.constexpr,
+    per_act_token_quant: tl.constexpr,
+    # Kernel config
+    BLOCK_M: tl.constexpr,
+    BLOCK_N: tl.constexpr,
+    BLOCK_K: tl.constexpr,
+):
+    offs_m = tl.arange(0, BLOCK_M)
+    offs_n = tl.arange(0, BLOCK_N) % N
+    offs_k = tl.arange(0, BLOCK_K)
+    mask_m = offs_m < M
+
+    # Make grids of a + b pointers
+    a_ptrs = a_ptr + offs_m[:, None] * stride_am + offs_k[None, :] * stride_ak
+    b_ptrs = b_ptr + offs_k[:, None] * stride_bk + offs_n[None, :] * stride_bn
+
+    accumulator = moe_mmk(
+        a_ptrs,
+        b_ptrs,
+        K,
+        expert_id,
+        a_scale_ptr,
+        b_scale_ptr,
+        # The stride variables represent how much to increase the ptr by when
+        # moving by 1 element in a particular dimension. E.g. `stride_am` is
+        # how much to increase `a_ptr` by to get the element one row down
+        # (A has M rows).
+        stride_ak,
+        stride_bk,
+        stride_ase,
+        stride_asm,
+        stride_ask,
+        stride_bse,
+        stride_bsk,
+        stride_bsn,
+        # Offsets and masks
+        offs_m,
+        offs_n,
+        offs_bn,
+        mask_m,
+        # Block size for block-wise quantization
+        group_n,
+        group_k,
+        # Meta-parameters
+        BLOCK_M,
+        BLOCK_N,
+        BLOCK_K,
+        compute_type,
+        use_fp8_w8a8,
+        use_int8_w8a16,
+        per_act_token_quant,
+    )
+
+    # store in C
+    offs_cn = tl.arange(0, BLOCK_N)
+    c_ptrs = c_ptr + offs_m[:, None] * stride_cm + offs_cn[None, :] * stride_cn
+    c_mask = mask_m[:, None] & (offs_cn[None, :] < N)
+    tl.store(c_ptrs, accumulator, mask=c_mask)
+
+
+@triton.jit
+def batched_triton_kernel(
+    a_ptr,  # [E, max_num_tokens, K]
+    b_ptr,  # [E, K, N]
+    c_ptr,  # [E, max_num_tokens, N]
+    expert_num_tokens,  # [E]
+    compute_type: tl.constexpr,
+    # Dimensions
+    max_num_tokens,
+    K,
+    N,
+    # Quantization data
+    a_scale_ptr,
+    b_scale_ptr,
+    b_zp_ptr,
+    # The stride variables represent how much to increase the ptr by when
+    # moving by 1 element in a particular dimension. E.g. `stride_am` is
+    # how much to increase `a_ptr` by to get the element one row down
+    # (A has M rows).
+    stride_ae: tl.int64,
+    stride_am: tl.int64,
+    stride_ak: tl.int64,
+    stride_be: tl.int64,
+    stride_bk: tl.int64,
+    stride_bn: tl.int64,
+    stride_ce: tl.int64,
+    stride_cm: tl.int64,
+    stride_cn: tl.int64,
+    stride_ase: tl.int64,
+    stride_asm: tl.int64,
+    stride_ask: tl.int64,
+    stride_bse: tl.int64,
+    stride_bsk: tl.int64,
+    stride_bsn: tl.int64,
+    # Blockwise quantization data
+    group_n: tl.constexpr,
+    group_k: tl.constexpr,
+    # Quantization schemes
+    use_fp8_w8a8: tl.constexpr,
+    use_int8_w8a16: tl.constexpr,
+    per_act_token_quant: tl.constexpr,
+    # Kernel config
+    BLOCK_M: tl.constexpr,
+    BLOCK_N: tl.constexpr,
+    BLOCK_K: tl.constexpr,
+):
+    expert_id = tl.program_id(axis=0)
+    e_num_tokens = tl.load(expert_num_tokens + expert_id)
+    if e_num_tokens == 0:
+        # Early exit
+        return
+
+    # axis 1 is M_blocks * N_blocks
+    pid_mn = tl.program_id(axis=1)
+    # num_pid_m = tl.cdiv(max_num_tokens, BLOCK_M)
+    num_pid_n = tl.cdiv(N, BLOCK_N)
+    pid_m = pid_mn // num_pid_n
+    pid_n = pid_mn % num_pid_n
+
+    cta_m_start = pid_m * BLOCK_M
+    cta_n_start = pid_n * BLOCK_N
+    if cta_m_start >= e_num_tokens:
+        # Early exit
+        return
+
+    cta_m_size = min(BLOCK_M, e_num_tokens - cta_m_start)
+    cta_n_size = min(BLOCK_N, N - cta_n_start)
+
+    a_ptr = a_ptr + expert_id * stride_ae + cta_m_start * stride_am
+    b_ptr = b_ptr + expert_id * stride_be + cta_n_start * stride_bn
+    c_ptr = (
+        c_ptr
+        + expert_id * stride_ce
+        + cta_m_start * stride_cm
+        + cta_n_start * stride_cn
+    )
+
+    offs_bn = (pid_n * BLOCK_N + tl.arange(0, BLOCK_N).to(tl.int64)) % N
+
+    if use_fp8_w8a8:
+        a_scale_ptr = a_scale_ptr + expert_id * stride_ase
+        b_scale_ptr = b_scale_ptr + expert_id * stride_bse
+
+        # block-wise
+        if group_k > 0 and group_n > 0 or per_act_token_quant:
+            a_scale_ptr = a_scale_ptr + cta_m_start * stride_asm
+
+    expert_triton_kernel(
+        a_ptr,
+        b_ptr,
+        c_ptr,
+        expert_id,
+        compute_type,
+        cta_m_size,  # M
+        cta_n_size,  # N
+        K,  # K
+        a_scale_ptr,
+        b_scale_ptr,
+        b_zp_ptr,
+        # Strides
+        stride_am,
+        stride_ak,
+        stride_bk,
+        stride_bn,
+        stride_cm,
+        stride_cn,
+        stride_ase,
+        stride_asm,
+        stride_ask,
+        stride_bse,
+        stride_bsk,
+        stride_bsn,
+        # offsets
+        offs_bn,
+        # Blockwise quantization data
+        group_n,
+        group_k,
+        # Quantization schemes
+        use_fp8_w8a8,
+        use_int8_w8a16,
+        per_act_token_quant,
+        # Kernel config
+        BLOCK_M,
+        BLOCK_N,
+        BLOCK_K,
+    )
+
+
+def invoke_moe_batched_triton_kernel(
+    A: torch.Tensor,  # [E, max_tokens, K]
+    B: torch.Tensor,  # [E, N, K]
+    C: torch.Tensor,  # [E, max_tokens, N]
+    expert_num_tokens: torch.Tensor,  # [E]
+    compute_type: tl.dtype,
+    # Quantization data
+    A_scale: torch.Tensor | None,
+    B_scale: torch.Tensor | None,
+    B_zp: torch.Tensor | None,
+    # Quantization schemes
+    use_fp8_w8a8: bool,
+    use_int8_w8a16: bool,
+    use_int4_w4a16: bool,
+    config: dict[str, int],
+    per_act_token_quant: bool,
+    block_shape: list[int] | None = None,
+):
+    assert not use_int4_w4a16
+    max_num_tokens = A.size(1)
+    K = A.size(2)
+    N = C.size(2)
+
+    BLOCK_M = config["BLOCK_SIZE_M"]
+    BLOCK_N = config["BLOCK_SIZE_N"]
+    BLOCK_K = config["BLOCK_SIZE_K"]
+
+    grid = (
+        expert_num_tokens.size(0),
+        triton.cdiv(max_num_tokens, BLOCK_M) * triton.cdiv(B.size(1), BLOCK_N),
+    )
+
+    A_scale = normalize_batched_scales_shape(A_scale, expert_num_tokens.shape[0])
+
+    if B_scale is not None and B_scale.ndim == 1:
+        assert B_scale.numel() == expert_num_tokens.shape[0]
+        B_scale = B_scale.view(-1, 1, 1)
+
+    assert A_scale is None or A_scale.ndim == 3, (
+        f"{0 if A_scale is None else A_scale.shape}"
+    )
+    assert B_scale is None or B_scale.ndim == 1 or B_scale.ndim == 3, (
+        f"{0 if B_scale is None else B_scale.shape}"
+    )
+
+    if B_scale is not None:
+        if B_scale.ndim == 1:
+            stride_bse = 1
+            stride_bsk = 0
+            stride_bsn = 0
+        else:
+            stride_bse = B_scale.stride(0)
+            stride_bsk = B_scale.stride(2)
+            stride_bsn = B_scale.stride(1)
+
+    else:
+        stride_bse = 0
+        stride_bsk = 0
+        stride_bsn = 0
+
+    if A_scale is not None:
+        stride_ase = A_scale.stride(0)
+        stride_asm = A_scale.stride(1)
+        stride_ask = A_scale.stride(2)
+    else:
+        stride_ase = 0
+        stride_asm = 0
+        stride_ask = 0
+
+    batched_triton_kernel[grid](
+        A,
+        B,
+        C,
+        expert_num_tokens,
+        compute_type,
+        # Dimensions
+        max_num_tokens,
+        K,
+        N,
+        # Quantization data
+        A_scale,
+        B_scale,
+        B_zp,
+        # Strides
+        A.stride(0),
+        A.stride(1),
+        A.stride(2),
+        B.stride(0),
+        B.stride(2),
+        B.stride(1),
+        C.stride(0),
+        C.stride(1),
+        C.stride(2),
+        stride_ase,
+        stride_asm,
+        stride_ask,
+        stride_bse,
+        stride_bsk,
+        stride_bsn,
+        # Blockwise quantization data
+        0 if block_shape is None else block_shape[0],
+        0 if block_shape is None else block_shape[1],
+        # Quantization schemes
+        use_fp8_w8a8,
+        use_int8_w8a16,
+        per_act_token_quant,
+        # Kernel config
+        BLOCK_M=BLOCK_M,
+        BLOCK_N=BLOCK_N,
+        BLOCK_K=BLOCK_K,
+    )
diff --git a/vllm/model_executor/layers/moe/grouped_gemm_no_abstraction.py b/vllm/model_executor/layers/moe/grouped_gemm_no_abstraction.py
index 245fa82a23cef..c23d70bda23e0 100644
--- a/vllm/model_executor/layers/moe/grouped_gemm_no_abstraction.py
+++ b/vllm/model_executor/layers/moe/grouped_gemm_no_abstraction.py
@@ -12,6 +12,9 @@ import torch
 from vllm.model_executor.layers.fused_moe.moe_align_block_size import (
     moe_align_block_size,
 )
+from vllm.model_executor.layers.moe.fused_batched_moe import (
+    invoke_moe_batched_triton_kernel,
+)
 from vllm.model_executor.layers.moe.fused_moe import invoke_fused_moe_kernel
 from vllm.triton_utils import tl
 from vllm.utils import deep_gemm as vllm_deep_gemm
@@ -205,6 +208,73 @@ def run_triton_group_gemm_contiguous_bf16(
     torch.testing.assert_close(output, reference_output)
 
 
+def run_triton_group_gemm_masked_bf16(
+    expected_group_batch_size: int,
+    num_groups: int,
+    output_size: int,
+    input_size: int,
+):
+    weight = torch.randn(
+        num_groups,
+        output_size,
+        input_size,
+        dtype=torch.bfloat16,
+        device="cuda",
+    )
+    group_batch_size = [
+        int(expected_group_batch_size * random.uniform(0.7, 1.3))
+        for _ in range(num_groups)
+    ]
+    batch_size = sum(group_batch_size)
+    group_batch_size = torch.tensor(
+        group_batch_size,
+        dtype=torch.int32,
+        device="cuda",
+    )
+    x = torch.randn(
+        num_groups,
+        batch_size,
+        input_size,
+        dtype=torch.bfloat16,
+        device="cuda",
+    )
+    ground_truth_output = torch.einsum("gmk, gnk -> gmn", x, weight)
+    output = torch.zeros(
+        num_groups,
+        batch_size,
+        output_size,
+        dtype=torch.bfloat16,
+        device="cuda",
+    )
+    config = {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 32,
+        "GROUP_SIZE_M": 8,
+    }
+    invoke_moe_batched_triton_kernel(
+        A=x,
+        B=weight,
+        C=output,
+        expert_num_tokens=group_batch_size,
+        compute_type=tl.bfloat16,
+        A_scale=None,
+        B_scale=None,
+        B_zp=None,
+        use_fp8_w8a8=False,
+        use_int8_w8a16=False,
+        use_int4_w4a16=False,
+        config=config,
+        per_act_token_quant=False,
+    )
+    for i in range(num_groups):
+        torch.testing.assert_close(
+            output[i, : group_batch_size[i]],
+            ground_truth_output[i, : group_batch_size[i]],
+        )
+
+
 # run_batched_deepgemm_masked_fp8(512, 8, 1024, 512)
 run_batched_deepgemm_masked_bf16(512, 8, 1024, 512)
 run_triton_group_gemm_contiguous_bf16(512, 8, 1024, 512, 4)
+run_triton_group_gemm_masked_bf16(512, 8, 1024, 512)
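
For reference, a minimal standalone sketch of driving the new entry point outside the test harness follows. It is not part of the patch itself; it assumes a CUDA device, bfloat16 inputs, and the module path introduced above, and the shapes, per-expert token counts, and block sizes are illustrative only.

# Usage sketch (assumed example, not part of the diff).
import torch

from vllm.model_executor.layers.moe.fused_batched_moe import (
    invoke_moe_batched_triton_kernel,
)
from vllm.triton_utils import tl

# Illustrative shapes: E experts, a fixed per-expert token capacity,
# reduction dim K, output dim N.
E, max_tokens, K, N = 4, 128, 256, 512
A = torch.randn(E, max_tokens, K, dtype=torch.bfloat16, device="cuda")
B = torch.randn(E, N, K, dtype=torch.bfloat16, device="cuda")
C = torch.zeros(E, max_tokens, N, dtype=torch.bfloat16, device="cuda")
# Valid token count per expert; rows past each count are never written to C.
expert_num_tokens = torch.tensor([128, 64, 0, 32], dtype=torch.int32, device="cuda")

invoke_moe_batched_triton_kernel(
    A=A,
    B=B,
    C=C,
    expert_num_tokens=expert_num_tokens,
    compute_type=tl.bfloat16,
    A_scale=None,
    B_scale=None,
    B_zp=None,
    use_fp8_w8a8=False,
    use_int8_w8a16=False,
    use_int4_w4a16=False,
    config={"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32},
    per_act_token_quant=False,
)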