add triton_group_gemm_masked

Signed-off-by: Zhuohan Li <zhuohan123@gmail.com>
Zhuohan Li 2025-10-15 14:52:31 -07:00
parent a608dfab45
commit 850876a183
2 changed files with 533 additions and 0 deletions

vllm/model_executor/layers/moe/fused_batched_moe.py

@@ -0,0 +1,463 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Fused batched MoE kernel."""
import torch
from vllm.model_executor.layers.fused_moe.utils import (
normalize_batched_scales_shape,
)
from vllm.triton_utils import tl, triton
@triton.jit
def moe_mmk(
a_ptrs,
b_ptrs,
K,
expert_id,
a_scale_ptr,
b_scale_ptr,
    # The stride variables represent how much to increase the ptr by when
    # moving by 1 element in a particular dimension. E.g. `stride_ak` is
    # how much to advance `a_ptrs` by to move one element along K.
stride_ak: tl.int64,
stride_bk: tl.int64,
stride_ase: tl.int64,
stride_asm: tl.int64,
stride_ask: tl.int64,
stride_bse: tl.int64,
stride_bsk: tl.int64,
stride_bsn: tl.int64,
# Offsets and masks
offs_m,
offs_n,
offs_bn,
mask_m,
# Block size for block-wise quantization
group_n: tl.constexpr,
group_k: tl.constexpr,
# Meta-parameters
BLOCK_M: tl.constexpr,
BLOCK_N: tl.constexpr,
BLOCK_K: tl.constexpr,
compute_type: tl.constexpr,
use_w8a8: tl.constexpr,
use_w8a16: tl.constexpr,
per_act_token_quant: tl.constexpr,
):
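    # Inner K-loop helper: accumulates one [BLOCK_M, BLOCK_N] tile of
    # A @ B for a single expert, applying the configured w8a8 / w8a16
    # dequantization scales along the way.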
offs_k = tl.arange(0, BLOCK_K)
if use_w8a16:
b_scale_ptrs = (
b_scale_ptr + expert_id * stride_bse + offs_n[None, :] * stride_bsn
)
b_scale = tl.load(b_scale_ptrs)
if use_w8a8:
# block-wise
if group_k > 0 and group_n > 0:
a_scale_ptrs = a_scale_ptr + offs_m * stride_asm
offs_bsn = offs_bn // group_n
b_scale_ptrs = b_scale_ptr + offs_bsn * stride_bsn
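            # Scales are loaded inside the K loop below, because each
            # BLOCK_K step can land in a different quantization group.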
# per act token
elif per_act_token_quant:
# Load per-token scale for activations
a_scale_ptrs = a_scale_ptr + offs_m * stride_asm
a_scale = tl.load(a_scale_ptrs, mask=mask_m, other=0.0)[:, None]
b_scale_ptrs = b_scale_ptr + offs_bn[None, :] * stride_bsn
b_scale = tl.load(b_scale_ptrs)
# tensor-wise
else:
a_scale = tl.load(a_scale_ptr)
b_scale = tl.load(b_scale_ptr)
# -----------------------------------------------------------
# Iterate to compute a block of the C matrix.
# We accumulate into a `[BLOCK_SIZE_M, BLOCK_SIZE_N]` block
# of fp32 values for higher accuracy.
# `accumulator` will be converted back to fp16 after the loop.
accumulator = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
for k in range(0, tl.cdiv(K, BLOCK_K)):
# Load the next block of A and B, generate a mask by checking the
# K dimension.
a = tl.load(
a_ptrs,
mask=mask_m[:, None] & (offs_k[None, :] < K - k * BLOCK_K),
other=0.0,
)
b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_K, other=0.0)
# We accumulate along the K dimension.
if use_w8a16:
accumulator = tl.dot(a, b.to(compute_type), acc=accumulator)
elif use_w8a8:
if group_k > 0 and group_n > 0:
k_start = k * BLOCK_K
offs_ks = k_start // group_k
a_scale = tl.load(
a_scale_ptrs + offs_ks * stride_ask, mask=mask_m, other=0.0
)
b_scale = tl.load(b_scale_ptrs + offs_ks * stride_bsk)
accumulator += tl.dot(a, b) * a_scale[:, None] * b_scale[None, :]
else:
# acc used to enable fp8_fast_accum
accumulator = tl.dot(a, b, acc=accumulator)
else:
accumulator += tl.dot(a, b)
# Advance the ptrs to the next K block.
a_ptrs += BLOCK_K * stride_ak
b_ptrs += BLOCK_K * stride_bk
if use_w8a16:
accumulator = (accumulator * b_scale).to(compute_type)
elif use_w8a8:
if group_k > 0 and group_n > 0:
accumulator = accumulator.to(compute_type)
else:
accumulator = (accumulator * a_scale * b_scale).to(compute_type)
else:
accumulator = accumulator.to(compute_type)
return accumulator
@triton.jit
def expert_triton_kernel(
a_ptr, # [max_tokens, K]
b_ptr, # [K, N]
c_ptr, # [max_tokens, N]
expert_id,
compute_type: tl.constexpr,
# Dimensions
M,
N,
K,
# Quantization data
a_scale_ptr,
b_scale_ptr,
b_zp_ptr,
# strides
stride_am: tl.int64,
stride_ak: tl.int64,
stride_bk: tl.int64,
stride_bn: tl.int64,
stride_cm: tl.int64,
stride_cn: tl.int64,
stride_ase: tl.int64,
stride_asm: tl.int64,
stride_ask: tl.int64,
stride_bse: tl.int64,
stride_bsk: tl.int64,
stride_bsn: tl.int64,
# offsets
offs_bn,
# Blockwise quantization data
group_n,
group_k,
# Quantization schemes
use_fp8_w8a8: tl.constexpr,
use_int8_w8a16: tl.constexpr,
per_act_token_quant: tl.constexpr,
# Kernel config
BLOCK_M: tl.constexpr,
BLOCK_N: tl.constexpr,
BLOCK_K: tl.constexpr,
):
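    # Per-expert tile worker: builds the A/B tile pointers, runs the
    # shared moe_mmk inner loop, then writes the masked result tile to C.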
offs_m = tl.arange(0, BLOCK_M)
offs_n = tl.arange(0, BLOCK_N) % N
offs_k = tl.arange(0, BLOCK_K)
mask_m = offs_m < M
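    # M is the number of valid rows in this tile, so mask_m drops both
    # padded tokens and the ragged tail of the last M block.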
# Make grids of a + b pointers
a_ptrs = a_ptr + offs_m[:, None] * stride_am + offs_k[None, :] * stride_ak
b_ptrs = b_ptr + offs_k[:, None] * stride_bk + offs_n[None, :] * stride_bn
accumulator = moe_mmk(
a_ptrs,
b_ptrs,
K,
expert_id,
a_scale_ptr,
b_scale_ptr,
        # Strides
stride_ak,
stride_bk,
stride_ase,
stride_asm,
stride_ask,
stride_bse,
stride_bsk,
stride_bsn,
# Offsets and masks
offs_m,
offs_n,
offs_bn,
mask_m,
# Block size for block-wise quantization
group_n,
group_k,
# Meta-parameters
BLOCK_M,
BLOCK_N,
BLOCK_K,
compute_type,
use_fp8_w8a8,
use_int8_w8a16,
per_act_token_quant,
)
    # Store the accumulator tile to C, masking rows past M and columns past N.
offs_cn = tl.arange(0, BLOCK_N)
c_ptrs = c_ptr + offs_m[:, None] * stride_cm + offs_cn[None, :] * stride_cn
c_mask = mask_m[:, None] & (offs_cn[None, :] < N)
tl.store(c_ptrs, accumulator, mask=c_mask)
@triton.jit
def batched_triton_kernel(
a_ptr, # [E, max_num_tokens, K]
b_ptr, # [E, K, N]
c_ptr, # [E, max_num_tokens, N]
expert_num_tokens, # [E]
compute_type: tl.constexpr,
# Dimensions
max_num_tokens,
K,
N,
# Quantization data
a_scale_ptr,
b_scale_ptr,
b_zp_ptr,
# The stride variables represent how much to increase the ptr by when
# moving by 1 element in a particular dimension. E.g. `stride_am` is
# how much to increase `a_ptr` by to get the element one row down
# (A has M rows).
stride_ae: tl.int64,
stride_am: tl.int64,
stride_ak: tl.int64,
stride_be: tl.int64,
stride_bk: tl.int64,
stride_bn: tl.int64,
stride_ce: tl.int64,
stride_cm: tl.int64,
stride_cn: tl.int64,
stride_ase: tl.int64,
stride_asm: tl.int64,
stride_ask: tl.int64,
stride_bse: tl.int64,
stride_bsk: tl.int64,
stride_bsn: tl.int64,
# Blockwise quantization data
group_n: tl.constexpr,
group_k: tl.constexpr,
# Quantization schemes
use_fp8_w8a8: tl.constexpr,
use_int8_w8a16: tl.constexpr,
per_act_token_quant: tl.constexpr,
# Kernel config
BLOCK_M: tl.constexpr,
BLOCK_N: tl.constexpr,
BLOCK_K: tl.constexpr,
):
expert_id = tl.program_id(axis=0)
e_num_tokens = tl.load(expert_num_tokens + expert_id)
if e_num_tokens == 0:
# Early exit
return
# axis 1 is M_blocks * N_blocks
pid_mn = tl.program_id(axis=1)
# num_pid_m = tl.cdiv(max_num_tokens, BLOCK_M)
num_pid_n = tl.cdiv(N, BLOCK_N)
pid_m = pid_mn // num_pid_n
pid_n = pid_mn % num_pid_n
cta_m_start = pid_m * BLOCK_M
cta_n_start = pid_n * BLOCK_N
if cta_m_start >= e_num_tokens:
# Early exit
return
cta_m_size = min(BLOCK_M, e_num_tokens - cta_m_start)
cta_n_size = min(BLOCK_N, N - cta_n_start)
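    # Clamp the tile to the expert's valid token count and to N, so ragged
    # last tiles are handled by masking rather than by extra programs.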
a_ptr = a_ptr + expert_id * stride_ae + cta_m_start * stride_am
b_ptr = b_ptr + expert_id * stride_be + cta_n_start * stride_bn
c_ptr = (
c_ptr
+ expert_id * stride_ce
+ cta_m_start * stride_cm
+ cta_n_start * stride_cn
)
offs_bn = (pid_n * BLOCK_N + tl.arange(0, BLOCK_N).to(tl.int64)) % N
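    # Column offsets wrap modulo N so per-channel B-scale lookups stay in
    # bounds on the padded tail tile; the output store masks the tail instead.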
if use_fp8_w8a8:
a_scale_ptr = a_scale_ptr + expert_id * stride_ase
b_scale_ptr = b_scale_ptr + expert_id * stride_bse
        # Block-wise or per-token activation scales: advance the
        # activation-scale pointer to this tile's starting row.
        if (group_k > 0 and group_n > 0) or per_act_token_quant:
a_scale_ptr = a_scale_ptr + cta_m_start * stride_asm
expert_triton_kernel(
a_ptr,
b_ptr,
c_ptr,
expert_id,
compute_type,
cta_m_size, # M
cta_n_size, # N
K, # K
a_scale_ptr,
b_scale_ptr,
b_zp_ptr,
# Strides
stride_am,
stride_ak,
stride_bk,
stride_bn,
stride_cm,
stride_cn,
stride_ase,
stride_asm,
stride_ask,
stride_bse,
stride_bsk,
stride_bsn,
# offsets
offs_bn,
# Blockwise quantization data
group_n,
group_k,
# Quantization schemes
use_fp8_w8a8,
use_int8_w8a16,
per_act_token_quant,
# Kernel config
BLOCK_M,
BLOCK_N,
BLOCK_K,
)
def invoke_moe_batched_triton_kernel(
A: torch.Tensor, # [E, max_tokens, K]
B: torch.Tensor, # [E, N, K]
C: torch.Tensor, # [E, max_tokens, N]
expert_num_tokens: torch.Tensor, # [E]
compute_type: tl.dtype,
# Quantization data
A_scale: torch.Tensor | None,
B_scale: torch.Tensor | None,
    B_zp: torch.Tensor | None,
# Quantization schemes
use_fp8_w8a8: bool,
use_int8_w8a16: bool,
use_int4_w4a16: bool,
config: dict[str, int],
per_act_token_quant: bool,
block_shape: list[int] | None = None,
):
assert not use_int4_w4a16
max_num_tokens = A.size(1)
K = A.size(2)
N = C.size(2)
BLOCK_M = config["BLOCK_SIZE_M"]
BLOCK_N = config["BLOCK_SIZE_N"]
BLOCK_K = config["BLOCK_SIZE_K"]
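    # One program per (expert, MN tile): axis 0 enumerates experts, axis 1
    # flattens the M x N tile grid of the worst-case per-expert batch.
    # Tiles past an expert's real token count exit early in the kernel.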
grid = (
expert_num_tokens.size(0),
        triton.cdiv(max_num_tokens, BLOCK_M) * triton.cdiv(N, BLOCK_N),
)
A_scale = normalize_batched_scales_shape(A_scale, expert_num_tokens.shape[0])
if B_scale is not None and B_scale.ndim == 1:
assert B_scale.numel() == expert_num_tokens.shape[0]
B_scale = B_scale.view(-1, 1, 1)
assert A_scale is None or A_scale.ndim == 3, (
f"{0 if A_scale is None else A_scale.shape}"
)
assert B_scale is None or B_scale.ndim == 1 or B_scale.ndim == 3, (
f"{0 if B_scale is None else B_scale.shape}"
)
if B_scale is not None:
if B_scale.ndim == 1:
stride_bse = 1
stride_bsk = 0
stride_bsn = 0
else:
stride_bse = B_scale.stride(0)
stride_bsk = B_scale.stride(2)
stride_bsn = B_scale.stride(1)
else:
stride_bse = 0
stride_bsk = 0
stride_bsn = 0
if A_scale is not None:
stride_ase = A_scale.stride(0)
stride_asm = A_scale.stride(1)
stride_ask = A_scale.stride(2)
else:
stride_ase = 0
stride_asm = 0
stride_ask = 0
batched_triton_kernel[grid](
A,
B,
C,
expert_num_tokens,
compute_type,
# Dimensions
max_num_tokens,
K,
N,
# Quantization data
A_scale,
B_scale,
B_zp,
# Strides
A.stride(0),
A.stride(1),
A.stride(2),
B.stride(0),
B.stride(2),
B.stride(1),
C.stride(0),
C.stride(1),
C.stride(2),
stride_ase,
stride_asm,
stride_ask,
stride_bse,
stride_bsk,
stride_bsn,
# Blockwise quantization data
0 if block_shape is None else block_shape[0],
0 if block_shape is None else block_shape[1],
# Quantization schemes
use_fp8_w8a8,
use_int8_w8a16,
per_act_token_quant,
# Kernel config
BLOCK_M=BLOCK_M,
BLOCK_N=BLOCK_N,
BLOCK_K=BLOCK_K,
)


@@ -12,6 +12,9 @@ import torch
from vllm.model_executor.layers.fused_moe.moe_align_block_size import (
moe_align_block_size,
)
from vllm.model_executor.layers.moe.fused_batched_moe import (
invoke_moe_batched_triton_kernel,
)
from vllm.model_executor.layers.moe.fused_moe import invoke_fused_moe_kernel
from vllm.triton_utils import tl
from vllm.utils import deep_gemm as vllm_deep_gemm
@@ -205,6 +208,73 @@ def run_triton_group_gemm_contiguous_bf16(
torch.testing.assert_close(output, reference_output)
def run_triton_group_gemm_masked_bf16(
expected_group_batch_size: int,
num_groups: int,
output_size: int,
input_size: int,
):
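    # Builds per-group batch sizes jittered around the expected size, runs
    # the masked batched kernel, and checks only the valid rows per group.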
weight = torch.randn(
num_groups,
output_size,
input_size,
dtype=torch.bfloat16,
device="cuda",
)
group_batch_size = [
int(expected_group_batch_size * random.uniform(0.7, 1.3))
for _ in range(num_groups)
]
batch_size = sum(group_batch_size)
group_batch_size = torch.tensor(
group_batch_size,
dtype=torch.int32,
device="cuda",
)
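    # Every group is padded to the same batch_size rows; group_batch_size
    # records how many rows of each group hold real tokens.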
x = torch.randn(
num_groups,
batch_size,
input_size,
dtype=torch.bfloat16,
device="cuda",
)
ground_truth_output = torch.einsum("gmk, gnk -> gmn", x, weight)
output = torch.zeros(
num_groups,
batch_size,
output_size,
dtype=torch.bfloat16,
device="cuda",
)
config = {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 32,
"GROUP_SIZE_M": 8,
}
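    # The batched launcher only reads the BLOCK_SIZE_* entries;
    # GROUP_SIZE_M is carried along from the standard fused-MoE config.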
invoke_moe_batched_triton_kernel(
A=x,
B=weight,
C=output,
expert_num_tokens=group_batch_size,
compute_type=tl.bfloat16,
A_scale=None,
B_scale=None,
B_zp=None,
use_fp8_w8a8=False,
use_int8_w8a16=False,
use_int4_w4a16=False,
config=config,
per_act_token_quant=False,
)
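    # The kernel leaves padded rows untouched, so compare just the valid
    # prefix of each group against the dense einsum reference.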
for i in range(num_groups):
torch.testing.assert_close(
output[i, : group_batch_size[i]],
ground_truth_output[i, : group_batch_size[i]],
)
# run_batched_deepgemm_masked_fp8(512, 8, 1024, 512)
run_batched_deepgemm_masked_bf16(512, 8, 1024, 512)
run_triton_group_gemm_contiguous_bf16(512, 8, 1024, 512, 4)
run_triton_group_gemm_masked_bf16(512, 8, 1024, 512)