add triton_group_gemm_masked

Signed-off-by: Zhuohan Li <zhuohan123@gmail.com>

parent a608dfab45
commit 850876a183
463  vllm/model_executor/layers/moe/fused_batched_moe.py (new file)

@@ -0,0 +1,463 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Fused batched MoE kernel."""

import torch

from vllm.model_executor.layers.fused_moe.utils import (
    normalize_batched_scales_shape,
)
from vllm.triton_utils import tl, triton


@triton.jit
def moe_mmk(
    a_ptrs,
    b_ptrs,
    K,
    expert_id,
    a_scale_ptr,
    b_scale_ptr,
    # The stride variables represent how much to increase the ptr by when
    # moving by 1 element in a particular dimension. E.g. `stride_am` is
    # how much to increase `a_ptr` by to get the element one row down
    # (A has M rows).
    stride_ak: tl.int64,
    stride_bk: tl.int64,
    stride_ase: tl.int64,
    stride_asm: tl.int64,
    stride_ask: tl.int64,
    stride_bse: tl.int64,
    stride_bsk: tl.int64,
    stride_bsn: tl.int64,
    # Offsets and masks
    offs_m,
    offs_n,
    offs_bn,
    mask_m,
    # Block size for block-wise quantization
    group_n: tl.constexpr,
    group_k: tl.constexpr,
    # Meta-parameters
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
    BLOCK_K: tl.constexpr,
    compute_type: tl.constexpr,
    use_w8a8: tl.constexpr,
    use_w8a16: tl.constexpr,
    per_act_token_quant: tl.constexpr,
):
    offs_k = tl.arange(0, BLOCK_K)

    if use_w8a16:
        b_scale_ptrs = (
            b_scale_ptr + expert_id * stride_bse + offs_n[None, :] * stride_bsn
        )
        b_scale = tl.load(b_scale_ptrs)

    if use_w8a8:
        # block-wise
        if group_k > 0 and group_n > 0:
            a_scale_ptrs = a_scale_ptr + offs_m * stride_asm
            offs_bsn = offs_bn // group_n
            b_scale_ptrs = b_scale_ptr + offs_bsn * stride_bsn

        # per act token
        elif per_act_token_quant:
            # Load per-token scale for activations
            a_scale_ptrs = a_scale_ptr + offs_m * stride_asm
            a_scale = tl.load(a_scale_ptrs, mask=mask_m, other=0.0)[:, None]

            b_scale_ptrs = b_scale_ptr + offs_bn[None, :] * stride_bsn
            b_scale = tl.load(b_scale_ptrs)

        # tensor-wise
        else:
            a_scale = tl.load(a_scale_ptr)
            b_scale = tl.load(b_scale_ptr)

    # -----------------------------------------------------------
    # Iterate to compute a block of the C matrix.
    # We accumulate into a `[BLOCK_SIZE_M, BLOCK_SIZE_N]` block
    # of fp32 values for higher accuracy.
    # `accumulator` will be converted back to fp16 after the loop.
    accumulator = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
    for k in range(0, tl.cdiv(K, BLOCK_K)):
        # Load the next block of A and B, generate a mask by checking the
        # K dimension.
        a = tl.load(
            a_ptrs,
            mask=mask_m[:, None] & (offs_k[None, :] < K - k * BLOCK_K),
            other=0.0,
        )
        b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_K, other=0.0)
        # We accumulate along the K dimension.
        if use_w8a16:
            accumulator = tl.dot(a, b.to(compute_type), acc=accumulator)
        elif use_w8a8:
            if group_k > 0 and group_n > 0:
                k_start = k * BLOCK_K
                offs_ks = k_start // group_k
                a_scale = tl.load(
                    a_scale_ptrs + offs_ks * stride_ask, mask=mask_m, other=0.0
                )
                b_scale = tl.load(b_scale_ptrs + offs_ks * stride_bsk)

                accumulator += tl.dot(a, b) * a_scale[:, None] * b_scale[None, :]
            else:
                # acc used to enable fp8_fast_accum
                accumulator = tl.dot(a, b, acc=accumulator)
        else:
            accumulator += tl.dot(a, b)

        # Advance the ptrs to the next K block.
        a_ptrs += BLOCK_K * stride_ak
        b_ptrs += BLOCK_K * stride_bk

    if use_w8a16:
        accumulator = (accumulator * b_scale).to(compute_type)
    elif use_w8a8:
        if group_k > 0 and group_n > 0:
            accumulator = accumulator.to(compute_type)
        else:
            accumulator = (accumulator * a_scale * b_scale).to(compute_type)
    else:
        accumulator = accumulator.to(compute_type)

    return accumulator


@triton.jit
def expert_triton_kernel(
    a_ptr,  # [max_tokens, K]
    b_ptr,  # [K, N]
    c_ptr,  # [max_tokens, N]
    expert_id,
    compute_type: tl.constexpr,
    # Dimensions
    M,
    N,
    K,
    # Quantization data
    a_scale_ptr,
    b_scale_ptr,
    b_zp_ptr,
    # strides
    stride_am: tl.int64,
    stride_ak: tl.int64,
    stride_bk: tl.int64,
    stride_bn: tl.int64,
    stride_cm: tl.int64,
    stride_cn: tl.int64,
    stride_ase: tl.int64,
    stride_asm: tl.int64,
    stride_ask: tl.int64,
    stride_bse: tl.int64,
    stride_bsk: tl.int64,
    stride_bsn: tl.int64,
    # offsets
    offs_bn,
    # Blockwise quantization data
    group_n,
    group_k,
    # Quantization schemes
    use_fp8_w8a8: tl.constexpr,
    use_int8_w8a16: tl.constexpr,
    per_act_token_quant: tl.constexpr,
    # Kernel config
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
    BLOCK_K: tl.constexpr,
):
    offs_m = tl.arange(0, BLOCK_M)
    offs_n = tl.arange(0, BLOCK_N) % N
    offs_k = tl.arange(0, BLOCK_K)
    mask_m = offs_m < M

    # Make grids of a + b pointers
    a_ptrs = a_ptr + offs_m[:, None] * stride_am + offs_k[None, :] * stride_ak
    b_ptrs = b_ptr + offs_k[:, None] * stride_bk + offs_n[None, :] * stride_bn

    accumulator = moe_mmk(
        a_ptrs,
        b_ptrs,
        K,
        expert_id,
        a_scale_ptr,
        b_scale_ptr,
        # The stride variables represent how much to increase the ptr by when
        # moving by 1 element in a particular dimension. E.g. `stride_am` is
        # how much to increase `a_ptr` by to get the element one row down
        # (A has M rows).
        stride_ak,
        stride_bk,
        stride_ase,
        stride_asm,
        stride_ask,
        stride_bse,
        stride_bsk,
        stride_bsn,
        # Offsets and masks
        offs_m,
        offs_n,
        offs_bn,
        mask_m,
        # Block size for block-wise quantization
        group_n,
        group_k,
        # Meta-parameters
        BLOCK_M,
        BLOCK_N,
        BLOCK_K,
        compute_type,
        use_fp8_w8a8,
        use_int8_w8a16,
        per_act_token_quant,
    )

    # store in C
    offs_cn = tl.arange(0, BLOCK_N)
    c_ptrs = c_ptr + offs_m[:, None] * stride_cm + offs_cn[None, :] * stride_cn
    c_mask = mask_m[:, None] & (offs_cn[None, :] < N)
    tl.store(c_ptrs, accumulator, mask=c_mask)


@triton.jit
def batched_triton_kernel(
    a_ptr,  # [E, max_num_tokens, K]
    b_ptr,  # [E, K, N]
    c_ptr,  # [E, max_num_tokens, N]
    expert_num_tokens,  # [E]
    compute_type: tl.constexpr,
    # Dimensions
    max_num_tokens,
    K,
    N,
    # Quantization data
    a_scale_ptr,
    b_scale_ptr,
    b_zp_ptr,
    # The stride variables represent how much to increase the ptr by when
    # moving by 1 element in a particular dimension. E.g. `stride_am` is
    # how much to increase `a_ptr` by to get the element one row down
    # (A has M rows).
    stride_ae: tl.int64,
    stride_am: tl.int64,
    stride_ak: tl.int64,
    stride_be: tl.int64,
    stride_bk: tl.int64,
    stride_bn: tl.int64,
    stride_ce: tl.int64,
    stride_cm: tl.int64,
    stride_cn: tl.int64,
    stride_ase: tl.int64,
    stride_asm: tl.int64,
    stride_ask: tl.int64,
    stride_bse: tl.int64,
    stride_bsk: tl.int64,
    stride_bsn: tl.int64,
    # Blockwise quantization data
    group_n: tl.constexpr,
    group_k: tl.constexpr,
    # Quantization schemes
    use_fp8_w8a8: tl.constexpr,
    use_int8_w8a16: tl.constexpr,
    per_act_token_quant: tl.constexpr,
    # Kernel config
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
    BLOCK_K: tl.constexpr,
):
    expert_id = tl.program_id(axis=0)
    e_num_tokens = tl.load(expert_num_tokens + expert_id)
    if e_num_tokens == 0:
        # Early exit
        return

    # axis 1 is M_blocks * N_blocks
    pid_mn = tl.program_id(axis=1)
    # num_pid_m = tl.cdiv(max_num_tokens, BLOCK_M)
    num_pid_n = tl.cdiv(N, BLOCK_N)
    pid_m = pid_mn // num_pid_n
    pid_n = pid_mn % num_pid_n

    cta_m_start = pid_m * BLOCK_M
    cta_n_start = pid_n * BLOCK_N
    if cta_m_start >= e_num_tokens:
        # Early exit
        return

    cta_m_size = min(BLOCK_M, e_num_tokens - cta_m_start)
    cta_n_size = min(BLOCK_N, N - cta_n_start)

    a_ptr = a_ptr + expert_id * stride_ae + cta_m_start * stride_am
    b_ptr = b_ptr + expert_id * stride_be + cta_n_start * stride_bn
    c_ptr = (
        c_ptr
        + expert_id * stride_ce
        + cta_m_start * stride_cm
        + cta_n_start * stride_cn
    )

    offs_bn = (pid_n * BLOCK_N + tl.arange(0, BLOCK_N).to(tl.int64)) % N

    if use_fp8_w8a8:
        a_scale_ptr = a_scale_ptr + expert_id * stride_ase
        b_scale_ptr = b_scale_ptr + expert_id * stride_bse

        # block-wise
        if group_k > 0 and group_n > 0 or per_act_token_quant:
            a_scale_ptr = a_scale_ptr + cta_m_start * stride_asm

    expert_triton_kernel(
        a_ptr,
        b_ptr,
        c_ptr,
        expert_id,
        compute_type,
        cta_m_size,  # M
        cta_n_size,  # N
        K,  # K
        a_scale_ptr,
        b_scale_ptr,
        b_zp_ptr,
        # Strides
        stride_am,
        stride_ak,
        stride_bk,
        stride_bn,
        stride_cm,
        stride_cn,
        stride_ase,
        stride_asm,
        stride_ask,
        stride_bse,
        stride_bsk,
        stride_bsn,
        # offsets
        offs_bn,
        # Blockwise quantization data
        group_n,
        group_k,
        # Quantization schemes
        use_fp8_w8a8,
        use_int8_w8a16,
        per_act_token_quant,
        # Kernel config
        BLOCK_M,
        BLOCK_N,
        BLOCK_K,
    )


def invoke_moe_batched_triton_kernel(
    A: torch.Tensor,  # [E, max_tokens, K]
    B: torch.Tensor,  # [E, N, K]
    C: torch.Tensor,  # [E, max_tokens, N]
    expert_num_tokens: torch.Tensor,  # [E]
    compute_type: tl.dtype,
    # Quantization data
    A_scale: torch.Tensor | None,
    B_scale: torch.Tensor | None,
    B_zp: torch.Tensor,
    # Quantization schemes
    use_fp8_w8a8: bool,
    use_int8_w8a16: bool,
    use_int4_w4a16: bool,
    config: dict[str, int],
    per_act_token_quant: bool,
    block_shape: list[int] | None = None,
):
    assert not use_int4_w4a16
    max_num_tokens = A.size(1)
    K = A.size(2)
    N = C.size(2)

    BLOCK_M = config["BLOCK_SIZE_M"]
    BLOCK_N = config["BLOCK_SIZE_N"]
    BLOCK_K = config["BLOCK_SIZE_K"]

    grid = (
        expert_num_tokens.size(0),
        triton.cdiv(max_num_tokens, BLOCK_M) * triton.cdiv(B.size(1), BLOCK_N),
    )

    A_scale = normalize_batched_scales_shape(A_scale, expert_num_tokens.shape[0])

    if B_scale is not None and B_scale.ndim == 1:
        assert B_scale.numel() == expert_num_tokens.shape[0]
        B_scale = B_scale.view(-1, 1, 1)

    assert A_scale is None or A_scale.ndim == 3, (
        f"{0 if A_scale is None else A_scale.shape}"
    )
    assert B_scale is None or B_scale.ndim == 1 or B_scale.ndim == 3, (
        f"{0 if B_scale is None else B_scale.shape}"
    )

    if B_scale is not None:
        if B_scale.ndim == 1:
            stride_bse = 1
            stride_bsk = 0
            stride_bsn = 0
        else:
            stride_bse = B_scale.stride(0)
            stride_bsk = B_scale.stride(2)
            stride_bsn = B_scale.stride(1)

    else:
        stride_bse = 0
        stride_bsk = 0
        stride_bsn = 0

    if A_scale is not None:
        stride_ase = A_scale.stride(0)
        stride_asm = A_scale.stride(1)
        stride_ask = A_scale.stride(2)
    else:
        stride_ase = 0
        stride_asm = 0
        stride_ask = 0

    batched_triton_kernel[grid](
        A,
        B,
        C,
        expert_num_tokens,
        compute_type,
        # Dimensions
        max_num_tokens,
        K,
        N,
        # Quantization data
        A_scale,
        B_scale,
        B_zp,
        # Strides
        A.stride(0),
        A.stride(1),
        A.stride(2),
        B.stride(0),
        B.stride(2),
        B.stride(1),
        C.stride(0),
        C.stride(1),
        C.stride(2),
        stride_ase,
        stride_asm,
        stride_ask,
        stride_bse,
        stride_bsk,
        stride_bsn,
        # Blockwise quantization data
        0 if block_shape is None else block_shape[0],
        0 if block_shape is None else block_shape[1],
        # Quantization schemes
        use_fp8_w8a8,
        use_int8_w8a16,
        per_act_token_quant,
        # Kernel config
        BLOCK_M=BLOCK_M,
        BLOCK_N=BLOCK_N,
        BLOCK_K=BLOCK_K,
    )
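
For reference (not part of the diff): expert_num_tokens[e] acts as a per-expert row mask, so for each expert slot the kernel only computes the first expert_num_tokens[e] rows of C[e] and leaves the remaining rows untouched. A minimal PyTorch sketch of that intended semantics, assuming bf16 inputs and no quantization (the helper name masked_batched_gemm_reference is mine, for illustration only):

import torch


def masked_batched_gemm_reference(
    A: torch.Tensor,                  # [E, max_tokens, K]
    B: torch.Tensor,                  # [E, N, K]
    C: torch.Tensor,                  # [E, max_tokens, N], written in place
    expert_num_tokens: torch.Tensor,  # [E], valid-row count per expert
) -> torch.Tensor:
    # Only the first expert_num_tokens[e] rows of each expert slot are valid;
    # the Triton kernel early-exits for M-tiles past that count, so rows
    # beyond it in C keep whatever value they already had.
    for e in range(A.size(0)):
        n_tokens = int(expert_num_tokens[e])
        C[e, :n_tokens] = A[e, :n_tokens] @ B[e].transpose(0, 1)
    return C

The test added below checks exactly this contract: it compares only output[i, :group_batch_size[i]] against an einsum reference.
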
@@ -12,6 +12,9 @@ import torch
from vllm.model_executor.layers.fused_moe.moe_align_block_size import (
    moe_align_block_size,
)
from vllm.model_executor.layers.moe.fused_batched_moe import (
    invoke_moe_batched_triton_kernel,
)
from vllm.model_executor.layers.moe.fused_moe import invoke_fused_moe_kernel
from vllm.triton_utils import tl
from vllm.utils import deep_gemm as vllm_deep_gemm
@@ -205,6 +208,73 @@ def run_triton_group_gemm_contiguous_bf16(
    torch.testing.assert_close(output, reference_output)


def run_triton_group_gemm_masked_bf16(
    expected_group_batch_size: int,
    num_groups: int,
    output_size: int,
    input_size: int,
):
    weight = torch.randn(
        num_groups,
        output_size,
        input_size,
        dtype=torch.bfloat16,
        device="cuda",
    )
    group_batch_size = [
        int(expected_group_batch_size * random.uniform(0.7, 1.3))
        for _ in range(num_groups)
    ]
    batch_size = sum(group_batch_size)
    group_batch_size = torch.tensor(
        group_batch_size,
        dtype=torch.int32,
        device="cuda",
    )
    x = torch.randn(
        num_groups,
        batch_size,
        input_size,
        dtype=torch.bfloat16,
        device="cuda",
    )
    ground_truth_output = torch.einsum("gmk, gnk -> gmn", x, weight)
    output = torch.zeros(
        num_groups,
        batch_size,
        output_size,
        dtype=torch.bfloat16,
        device="cuda",
    )
    config = {
        "BLOCK_SIZE_M": 64,
        "BLOCK_SIZE_N": 64,
        "BLOCK_SIZE_K": 32,
        "GROUP_SIZE_M": 8,
    }
    invoke_moe_batched_triton_kernel(
        A=x,
        B=weight,
        C=output,
        expert_num_tokens=group_batch_size,
        compute_type=tl.bfloat16,
        A_scale=None,
        B_scale=None,
        B_zp=None,
        use_fp8_w8a8=False,
        use_int8_w8a16=False,
        use_int4_w4a16=False,
        config=config,
        per_act_token_quant=False,
    )
    for i in range(num_groups):
        torch.testing.assert_close(
            output[i, : group_batch_size[i]],
            ground_truth_output[i, : group_batch_size[i]],
        )


# run_batched_deepgemm_masked_fp8(512, 8, 1024, 512)
run_batched_deepgemm_masked_bf16(512, 8, 1024, 512)
run_triton_group_gemm_contiguous_bf16(512, 8, 1024, 512, 4)
run_triton_group_gemm_masked_bf16(512, 8, 1024, 512)
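
For a sense of the launch shape of the masked test above, here is a short sketch that reproduces the grid arithmetic used by invoke_moe_batched_triton_kernel. It is illustrative only; the 4096 figure assumes the random per-group batch sizes sum to roughly 8 x 512.

from math import ceil

# Shapes from run_triton_group_gemm_masked_bf16(512, 8, 1024, 512):
num_groups = 8          # experts E -> grid axis 0
max_num_tokens = 4096   # A.size(1): approx. sum of the per-group batch sizes
N = 1024                # output_size
BLOCK_M, BLOCK_N = 64, 64

# Same formula as the wrapper: one program per (expert, M-tile x N-tile) pair.
grid = (num_groups, ceil(max_num_tokens / BLOCK_M) * ceil(N / BLOCK_N))
print(grid)  # (8, 1024)

# Inside batched_triton_kernel, programs whose M-tile starts at or beyond
# expert_num_tokens[e] (about 512 here) return early, so only roughly
# ceil(512 / 64) = 8 of the 64 M-tiles per expert do any work.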