[Minor] Fused experts refactor (#15914)

Signed-off-by: Bill Nell <bnell@redhat.com>
This commit is contained in:
bnellnm 2025-04-03 13:19:38 -04:00 committed by GitHub
parent d2b58ca203
commit 15ba07ef25
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
8 changed files with 790 additions and 737 deletions

View File

@ -9,8 +9,11 @@ import torch
from vllm.config import VllmConfig, set_current_vllm_config
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.fused_moe import fused_moe
from vllm.model_executor.layers.fused_moe.fused_moe import (
deep_gemm_moe_fp8, fused_topk, moe_align_block_size)
from vllm.model_executor.layers.fused_moe.deep_gemm_moe import (
deep_gemm_moe_fp8)
from vllm.model_executor.layers.fused_moe.fused_moe import fused_topk
from vllm.model_executor.layers.fused_moe.moe_align_block_size import (
moe_align_block_size)
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
per_token_group_quant_fp8, w8a8_block_fp8_matmul)
from vllm.platforms import current_platform
@ -437,7 +440,7 @@ def test_w8a8_block_fp8_deep_gemm_fused_moe(M, N, K, E, topk, seed):
pytest.skip(
f"Skipping test; bad size m={M}, n={N}, k={K}, topk={topk}, E={E}")
if (N <= 512):
if N <= 512:
pytest.skip("Skipping N <= 512 until performance issues solved.")
vllm_config = VllmConfig()

View File

@ -4,8 +4,8 @@ import torch
from vllm import _custom_ops as ops
from vllm.config import ParallelConfig, VllmConfig, set_current_vllm_config
from vllm.model_executor.layers.fused_moe.fused_moe import (cutlass_moe_fp8,
fused_experts,
from vllm.model_executor.layers.fused_moe.cutlass_moe import cutlass_moe_fp8
from vllm.model_executor.layers.fused_moe.fused_moe import (fused_experts,
fused_topk)
from vllm.platforms import current_platform
@ -131,9 +131,9 @@ def test_cutlass_moe_no_graph(
c_strides2,
a1_scale=a_scale1)
print(triton_output)
print(cutlass_output)
print("*")
#print(triton_output)
#print(cutlass_output)
#print("*")
torch.testing.assert_close(triton_output,
cutlass_output,
@ -234,9 +234,9 @@ def test_cutlass_moe_cuda_graph(
graph.replay()
torch.cuda.synchronize()
print(triton_output)
print(cutlass_output)
print("*")
#print(triton_output)
#print(cutlass_output)
#print("*")
torch.testing.assert_close(triton_output,
cutlass_output,

View File

@ -35,9 +35,11 @@ if HAS_TRITON:
# import to register the custom ops
import vllm.model_executor.layers.fused_moe.fused_marlin_moe # noqa
import vllm.model_executor.layers.fused_moe.fused_moe # noqa
from vllm.model_executor.layers.fused_moe.cutlass_moe import (
cutlass_moe_fp8)
from vllm.model_executor.layers.fused_moe.fused_moe import (
cutlass_moe_fp8, fused_experts, fused_moe, fused_topk,
get_config_file_name, grouped_topk)
fused_experts, fused_moe, fused_topk, get_config_file_name,
grouped_topk)
__all__ += [
"fused_moe",

View File

@ -0,0 +1,144 @@
# SPDX-License-Identifier: Apache-2.0
"""Fused MoE kernel."""
from typing import Optional
import torch
from vllm import _custom_ops as ops
#TODO make the grouped gemm kernel consistent with scaled gemm kernel
def cutlass_moe_fp8(
a: torch.Tensor,
w1_q: torch.Tensor,
w2_q: torch.Tensor,
w1_scale: torch.Tensor,
w2_scale: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
ab_strides1: torch.Tensor,
c_strides1: torch.Tensor,
ab_strides2: torch.Tensor,
c_strides2: torch.Tensor,
a1_scale: Optional[torch.Tensor] = None,
a2_scale: Optional[torch.Tensor] = None,
out_dtype: torch.dtype = torch.half,
) -> torch.Tensor:
"""
This function computes a a8w8-quantized Mixture of Experts (MoE) layer
using two sets of quantized weights, w1_q and w2_q, and top-k gating
mechanism. The matrix multiplications are implemented with CUTLASS
grouped gemm.
Parameters:
- a (torch.Tensor): The input tensor to the MoE layer.
Shape: [M, K]
- w1_q (torch.Tensor): The first set of fp8-quantized expert weights.
Shape: [num_experts, K, 2N] (the weights are passed transposed)
- w2_q (torch.Tensor): The second set of fp8-quantized expert weights.
Shape: [num_experts, N, K] (the weights are passed transposed)
- w1_scale (torch.Tensor): The fp32 scale to dequantize w1_q.
Shape: [num_experts] or [num_experts, 2N]
- w2_scale (torch.Tensor): The fp32 scale to dequantize w2_q.
Shape: [num_experts] or [num_experts, K]
- gating_output (torch.Tensor): The output of the gating operation
(before softmax).
- topk_weights (torch.Tensor): The weights of each token->expert mapping.
- ab_strides1 (torch.Tensor): The input and weights strides of the first
grouped gemm.
- c_strides1 (torch.Tensor): The output strides of the first grouped gemm.
- ab_strides2 (torch.Tensor): The input and weights strides of the second
grouped gemm.
- c_strides2 (torch.Tensor): The output strides of the second grouped gemm.
- a1_scale (Optional[torch.Tensor]): The optional fp32 scale to quantize a.
Shape: scalar or [M]
- a2_scale (Optional[torch.Tensor]): The optional fp32 scale to
quantize the intermediate result between the gemms.
Shape: scalar or [M]
- out_dtype (torch.Tensor): The output tensor type.
Returns:
- torch.Tensor: The fp16 output tensor after applying the MoE layer.
"""
assert topk_weights.shape == topk_ids.shape, "topk shape mismatch"
assert w1_q.dtype == torch.float8_e4m3fn
assert w2_q.dtype == torch.float8_e4m3fn
assert a.shape[1] == w1_q.shape[1], "Hidden size mismatch w1"
assert w1_q.shape[2] == w2_q.shape[1] * 2, "Hidden size mismatch w2"
assert w1_q.shape[0] == w2_q.shape[0], "Expert number mismatch"
assert a1_scale is None or a1_scale.dim(
) == 0 or a1_scale.shape[0] == 1 or a1_scale.shape[0] == a.shape[
0], "Input scale shape mismatch"
assert w1_scale.dim() == 1 or w1_scale.shape[1] == 1 or w1_scale.shape[
1] == w1_q.shape[2], "W1 scale shape mismatch"
assert w2_scale.dim() == 1 or w2_scale.shape[1] == 1 or w2_scale.shape[
1] == w2_q.shape[2], "W2 scale shape mismatch"
assert w1_q.shape[0] == w2_q.shape[0], "Weights expert number mismatch"
assert w1_q.shape[0] == w1_scale.shape[
0], "w1 scales expert number mismatch"
assert w1_q.shape[0] == w2_scale.shape[
0], "w2 scales expert number mismatch"
assert a2_scale is None or a1_scale is None or a2_scale.shape == a1_scale.shape, "Intermediate scale shape mismatch" # noqa: E501
assert ab_strides1.shape[0] == w1_q.shape[
0], "AB Strides 1 expert number mismatch"
assert c_strides1.shape[0] == w1_q.shape[
0], "C Strides 1 expert number mismatch"
assert ab_strides2.shape[0] == w2_q.shape[
0], "AB Strides 2 expert number mismatch"
assert c_strides2.shape[0] == w2_q.shape[
0], "C Strides 2 expert number mismatch"
assert out_dtype in [torch.half, torch.bfloat16], "Invalid output dtype"
num_experts = w1_q.size(0)
m = a.size(0)
k = w1_q.size(1)
n = w2_q.size(1)
topk = topk_ids.size(1)
per_act_token = a1_scale.numel() != 1 if a1_scale is not None else (
a2_scale.numel() != 1 if a2_scale is not None else False)
a_q, a1_scale = ops.scaled_fp8_quant(
a, a1_scale, use_per_token_if_dynamic=per_act_token)
device = a_q.device
expert_offsets = torch.empty((num_experts + 1),
dtype=torch.int32,
device=device)
problem_sizes1 = torch.empty((num_experts, 3),
dtype=torch.int32,
device=device)
problem_sizes2 = torch.empty((num_experts, 3),
dtype=torch.int32,
device=device)
a_map = torch.empty((topk_ids.numel()), dtype=torch.int32, device=device)
c_map = torch.empty((topk_ids.numel()), dtype=torch.int32, device=device)
ops.get_cutlass_moe_mm_data(topk_ids, expert_offsets, problem_sizes1,
problem_sizes2, a_map, c_map, num_experts, n,
k)
rep_a_q = a_q.view(dtype=torch.uint8)[a_map].view(dtype=a_q.dtype)
rep_a1_scales = a1_scale[a_map] if per_act_token else a1_scale
c1 = torch.empty((m * topk, n * 2), device=device, dtype=out_dtype)
c2 = torch.empty((m * topk, k), device=device, dtype=out_dtype)
ops.cutlass_moe_mm(c1, rep_a_q, w1_q, rep_a1_scales, w1_scale,
expert_offsets[:-1], problem_sizes1, ab_strides1,
ab_strides1, c_strides1)
intermediate = torch.empty((m * topk, n), device=device, dtype=out_dtype)
torch.ops._C.silu_and_mul(intermediate, c1)
intemediate_q, a2_scale = ops.scaled_fp8_quant(
intermediate, a2_scale, use_per_token_if_dynamic=per_act_token)
ops.cutlass_moe_mm(c2, intemediate_q, w2_q, a2_scale, w2_scale,
expert_offsets[:-1], problem_sizes2, ab_strides2,
ab_strides2, c_strides2)
return (c2[c_map].view(m, topk, k) *
topk_weights.view(m, topk, 1).to(out_dtype)).sum(dim=1)

View File

@ -0,0 +1,294 @@
# SPDX-License-Identifier: Apache-2.0
import importlib.util
from typing import Optional, Tuple
import torch
import vllm.envs as envs
from vllm import _custom_ops as ops
from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe.moe_align_block_size import (
moe_align_block_size)
from vllm.model_executor.layers.fused_moe.utils import (_fp8_perm,
_fp8_quantize,
_resize_cache)
from vllm.utils import round_up
logger = init_logger(__name__)
has_deep_gemm = importlib.util.find_spec("deep_gemm") is not None
def _valid_deep_gemm(hidden_states: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
expert_map: Optional[torch.Tensor] = None) -> bool:
"""
Check if the given problem size is supported by the DeepGemm grouped
gemm kernel. All of M, N, K and the quantization block_shape must be
aligned by `dg.get_m_alignment_for_contiguous_layout()`.
"""
if not has_deep_gemm:
return False
# Lazy import to avoid CUDA initialization problems.
import deep_gemm as dg
# Expert maps not supported yet.
if expert_map is not None:
return False
align = dg.get_m_alignment_for_contiguous_layout()
M = hidden_states.shape[0]
_, K, N = w2.shape
# For now, disable DeepGemm for small N until better permute/unpermute
# ops are available.
if N <= 512:
return False
if align > M or N % align != 0 or K % align != 0:
return False
return (hidden_states.is_contiguous() and w1.is_contiguous()
and w2.is_contiguous())
def _moe_permute(
curr_hidden_states: torch.Tensor,
a1q_scale: Optional[torch.Tensor],
curr_topk_ids: torch.Tensor,
global_num_experts: int,
expert_map: Optional[torch.Tensor],
block_m: int,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], torch.Tensor, torch.Tensor,
Optional[torch.Tensor]]:
"""
Determine the sorted_token_ids, expert_ids for the given problem size.
Permute the hidden states and scales according to `sorted_token_ids`.
"""
top_k_num = curr_topk_ids.shape[1]
tokens_in_chunk, _ = curr_hidden_states.shape
sorted_token_ids, expert_ids, num_tokens_post_padded = (
moe_align_block_size(curr_topk_ids,
block_m,
global_num_experts,
expert_map,
pad_sorted_ids=True))
inv_perm: Optional[torch.Tensor] = None
num_tokens = top_k_num * tokens_in_chunk
sorted_token_ids = sorted_token_ids.clamp(max=num_tokens - 1)
expert_ids = torch.repeat_interleave(expert_ids, block_m, dim=0)
inv_perm = torch.argsort(sorted_token_ids)[:num_tokens]
# Permute according to sorted token ids.
curr_hidden_states = _fp8_perm(curr_hidden_states,
sorted_token_ids // top_k_num)
if a1q_scale is not None:
a1q_scale = a1q_scale[sorted_token_ids // top_k_num]
return (curr_hidden_states, a1q_scale, sorted_token_ids, expert_ids,
inv_perm)
def _moe_unpermute_and_reduce(
out: torch.Tensor,
curr_hidden: torch.Tensor,
inv_perm: Optional[torch.Tensor],
topk_weight: torch.Tensor,
) -> None:
"""
Unpermute the final result and apply topk_weights, then perform the final
reduction on the hidden states.
"""
M, topk = topk_weight.shape
K = curr_hidden.shape[1]
curr_hidden = curr_hidden[inv_perm, ...]
curr_hidden = curr_hidden.view(-1, topk, K)
curr_hidden.mul_(topk_weight.view(M, -1, 1))
ops.moe_sum(curr_hidden, out)
def deep_gemm_moe_fp8(
hidden_states: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
w1_scale: torch.Tensor,
w2_scale: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
inplace: bool = False,
activation: str = "silu",
global_num_experts: int = -1,
expert_map: Optional[torch.Tensor] = None,
a1_scale: Optional[torch.Tensor] = None,
a2_scale: Optional[torch.Tensor] = None,
) -> torch.Tensor:
"""
This function computes a a8w8-quantized Mixture of Experts (MoE) layer
using two sets of quantized weights, w1_q and w2_q, and top-k gating
mechanism. The matrix multiplications are implemented with DeepGemm
grouped gemm.
Parameters:
- hidden_states (torch.Tensor): The input tensor to the MoE layer.
Shape: [M, K]
- w1 (torch.Tensor): The first set of fp8 quantized expert weights.
Shape: [num_experts, K, 2N] (the weights are passed transposed)
- w2 (torch.Tensor): The second set of fp8 quantized expert weights.
Shape: [num_experts, N, K] (the weights are passed transposed)
- w1_scale (torch.Tensor): The fp32 scale to dequantize w1_q.
Shape: [num_experts] or [num_experts, 2N]
- w2_scale (torch.Tensor): The fp32 scale to dequantize w2_q.
Shape: [num_experts] or [num_experts, K]
- topk_weights (torch.Tensor): The weights of each token->expert mapping.
- topk_ids (torch.Tensor): The token->expert mapping for topk_weights.
- inplace (bool): If True, perform the operation in-place.
Defaults to False.
- activation (str): The activation function to apply after the first
MoE layer.
- global_num_experts (int): The total number of experts in the global
expert space.
- expert_map (Optional[torch.Tensor]): A tensor mapping expert indices
from the global expert space to the local expert space of the expert
parallel shard.
- a1_scale (Optional[torch.Tensor]): The optional fp32 scale to quantize a.
Shape: scalar or [M]
- a2_scale (Optional[torch.Tensor]): The optional fp32 scale to
quantize the intermediate result between the gemms.
Shape: scalar or [M]
Returns:
- torch.Tensor: The bfloat16 output tensor after applying the MoE layer.
"""
# Lazy import to avoid CUDA initialization problems.
import deep_gemm as dg
assert expert_map is None, "Expert maps not supported yet"
assert hidden_states.shape[1] == w1.shape[2], "Hidden size mismatch"
assert topk_weights.shape == topk_ids.shape, "topk shape mismatch"
assert hidden_states.is_contiguous(), "Hidden_states must be contiguous"
assert w1.stride(-1) == 1, "Stride of last dimension must be 1"
assert w2.stride(-1) == 1, "Stride of last dimension must be 1"
assert hidden_states.dtype in [
torch.float32, torch.float16, torch.bfloat16
]
assert w1.dtype == torch.float8_e4m3fn
assert w2.dtype == torch.float8_e4m3fn
assert w1.shape[0] == w2.shape[0], "Expert number mismatch"
assert w1.shape[0] == w1_scale.shape[0], "w1 scales expert number mismatch"
assert w1.shape[0] == w2_scale.shape[0], "w2 scales expert number mismatch"
assert a1_scale is None or a1_scale.dim(
) == 0 or a1_scale.shape[0] == 1 or a1_scale.shape[
0] == hidden_states.shape[0], "Input scale shape mismatch"
assert a2_scale is None or a1_scale is None or a2_scale.shape == a1_scale.shape, "Intermediate scale shape mismatch" # noqa: E501
num_tokens, _ = hidden_states.shape
E, N, _ = w1.shape
K = w2.shape[1]
if global_num_experts == -1:
global_num_experts = E
# We execute the fused_moe kernel in chunks to circumvent this issue:
# https://github.com/vllm-project/vllm/issues/5938
CHUNK_SIZE = envs.VLLM_FUSED_MOE_CHUNK_SIZE
assert _valid_deep_gemm(hidden_states, w1, w2, expert_map)
if inplace:
out_hidden_states = hidden_states
else:
out_hidden_states = torch.empty_like(hidden_states)
block_m = dg.get_m_alignment_for_contiguous_layout()
block_shape = [block_m, block_m]
assert w1_scale is not None
assert w2_scale is not None
# We attempt to transpose and align offline in Fp8MoEMethod, in which
# case these calls will be nops. Otherwise, they'll be performed every
# time the layer is executed.
w1_scale = dg.get_col_major_tma_aligned_tensor(w1_scale).contiguous()
w2_scale = dg.get_col_major_tma_aligned_tensor(w2_scale).contiguous()
M_sum = topk_ids.numel() + global_num_experts * (block_m - 1)
M_sum = round_up(M_sum, block_m)
num_chunks = (num_tokens // CHUNK_SIZE) + 1
# We can reuse the memory between cache1 and cache3 because by the time
# we need cache3, we're done with cache1
workspace13 = torch.empty(M_sum * max(N, K),
device=hidden_states.device,
dtype=hidden_states.dtype)
workspace1 = workspace13[:M_sum * N].view(M_sum, N)
workspace2 = torch.empty((M_sum, N // 2),
device=hidden_states.device,
dtype=hidden_states.dtype)
workspace3 = workspace13[:M_sum * K].view(M_sum, K)
for chunk in range(num_chunks):
begin_chunk_idx, end_chunk_idx = (chunk * CHUNK_SIZE,
min((chunk + 1) * CHUNK_SIZE,
num_tokens))
curr_hidden_states = hidden_states[begin_chunk_idx:end_chunk_idx]
tokens_in_chunk, _ = curr_hidden_states.shape
if tokens_in_chunk == 0:
break
curr_topk_ids = topk_ids[begin_chunk_idx:end_chunk_idx]
curr_topk_weights = topk_weights[begin_chunk_idx:end_chunk_idx]
a1q_scale: Optional[torch.Tensor] = None
qcurr_hidden_states, a1q_scale = _fp8_quantize(curr_hidden_states,
a1_scale, block_shape)
(qcurr_hidden_states, a1q_scale, sorted_token_ids, expert_ids,
inv_perm) = _moe_permute(qcurr_hidden_states, a1q_scale,
curr_topk_ids, global_num_experts,
expert_map, block_m)
# Adjust the intermediate cache size and config for the last chunk.
# Note that in most cases we only have one chunk so the cache size
# and config are already set correctly and do not need to be adjusted.
if tokens_in_chunk < CHUNK_SIZE and chunk > 0:
curr_M = sorted_token_ids.numel()
workspace1 = _resize_cache(workspace1, (curr_M, N))
workspace2 = _resize_cache(workspace2, (curr_M, N // 2))
workspace3 = _resize_cache(workspace3, (curr_M, K))
dg.m_grouped_gemm_fp8_fp8_bf16_nt_contiguous(
(qcurr_hidden_states, a1q_scale), (w1, w1_scale), workspace1,
expert_ids)
if activation == "silu":
torch.ops._C.silu_and_mul(workspace2, workspace1.view(-1, N))
elif activation == "gelu":
torch.ops._C.gelu_and_mul(workspace2, workspace1.view(-1, N))
else:
raise ValueError(f"Unsupported FusedMoe activation: {activation}")
a2q_scale: Optional[torch.Tensor] = None
qworkspace2, a2q_scale = _fp8_quantize(workspace2, a2_scale,
block_shape)
dg.m_grouped_gemm_fp8_fp8_bf16_nt_contiguous(
(qworkspace2, a2q_scale), (w2, w2_scale), workspace3, expert_ids)
_moe_unpermute_and_reduce(
out_hidden_states[begin_chunk_idx:end_chunk_idx],
workspace3.view(*workspace3.shape), inv_perm, curr_topk_weights)
return out_hidden_states

View File

@ -1,10 +1,8 @@
# SPDX-License-Identifier: Apache-2.0
"""Fused MoE kernel."""
import functools
import importlib.util
import json
import os
from math import prod
from typing import Any, Callable, Dict, List, Optional, Tuple
import torch
@ -14,10 +12,13 @@ import triton.language as tl
import vllm.envs as envs
from vllm import _custom_ops as ops
from vllm.logger import init_logger
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
per_token_group_quant_fp8)
from vllm.model_executor.layers.fused_moe.deep_gemm_moe import (
_valid_deep_gemm, deep_gemm_moe_fp8)
from vllm.model_executor.layers.fused_moe.moe_align_block_size import (
moe_align_block_size)
from vllm.model_executor.layers.fused_moe.utils import _fp8_quantize
from vllm.platforms import current_platform
from vllm.utils import direct_register_custom_op, round_up
from vllm.utils import direct_register_custom_op
from .rocm_aiter_fused_moe import (is_rocm_aiter_moe_enabled,
rocm_aiter_fused_experts,
@ -25,8 +26,6 @@ from .rocm_aiter_fused_moe import (is_rocm_aiter_moe_enabled,
logger = init_logger(__name__)
has_deep_gemm = importlib.util.find_spec("deep_gemm") is not None
@triton.jit
def write_zeros_to_output(c_ptr, stride_cm, stride_cn, pid_n, N, offs_token,
@ -443,300 +442,13 @@ def fused_moe_kernel(
tl.store(c_ptrs, accumulator, mask=c_mask)
def ceil_div(a, b):
return (a + b - 1) // b
@triton.jit
def moe_align_block_size_stage1(
topk_ids_ptr,
tokens_cnts_ptr,
num_experts: tl.constexpr,
numel: tl.constexpr,
tokens_per_thread: tl.constexpr,
):
pid = tl.program_id(0)
start_idx = pid * tokens_per_thread
off_c = (pid + 1) * num_experts
for i in range(tokens_per_thread):
if start_idx + i < numel:
idx = tl.load(topk_ids_ptr + start_idx + i)
token_cnt = tl.load(tokens_cnts_ptr + off_c + idx)
tl.store(tokens_cnts_ptr + off_c + idx, token_cnt + 1)
@triton.jit
def moe_align_block_size_stage2(
tokens_cnts_ptr,
num_experts: tl.constexpr,
):
pid = tl.program_id(0)
last_cnt = 0
for i in range(1, num_experts + 1):
token_cnt = tl.load(tokens_cnts_ptr + i * num_experts + pid)
last_cnt = last_cnt + token_cnt
tl.store(tokens_cnts_ptr + i * num_experts + pid, last_cnt)
@triton.jit
def moe_align_block_size_stage3(
total_tokens_post_pad_ptr,
tokens_cnts_ptr,
cumsum_ptr,
num_experts: tl.constexpr,
block_size: tl.constexpr,
):
last_cumsum = 0
off_cnt = num_experts * num_experts
for i in range(1, num_experts + 1):
token_cnt = tl.load(tokens_cnts_ptr + off_cnt + i - 1)
last_cumsum = last_cumsum + tl.cdiv(token_cnt, block_size) * block_size
tl.store(cumsum_ptr + i, last_cumsum)
tl.store(total_tokens_post_pad_ptr, last_cumsum)
@triton.jit
def moe_align_block_size_stage4(
topk_ids_ptr,
sorted_token_ids_ptr,
expert_ids_ptr,
tokens_cnts_ptr,
cumsum_ptr,
num_experts: tl.constexpr,
block_size: tl.constexpr,
numel: tl.constexpr,
tokens_per_thread: tl.constexpr,
):
pid = tl.program_id(0)
start_idx = tl.load(cumsum_ptr + pid)
end_idx = tl.load(cumsum_ptr + pid + 1)
for i in range(start_idx, end_idx, block_size):
tl.store(expert_ids_ptr + i // block_size, pid)
start_idx = pid * tokens_per_thread
off_t = pid * num_experts
for i in range(start_idx, tl.minimum(start_idx + tokens_per_thread,
numel)):
expert_id = tl.load(topk_ids_ptr + i)
token_cnt = tl.load(tokens_cnts_ptr + off_t + expert_id)
rank_post_pad = token_cnt + tl.load(cumsum_ptr + expert_id)
tl.store(sorted_token_ids_ptr + rank_post_pad, i)
tl.store(tokens_cnts_ptr + off_t + expert_id, token_cnt + 1)
# Triton implementation based on:
# https://github.com/sgl-project/sglang/commit/ba5112ff691d791a9e38c6c71f59324a5fcb49d0
def moe_align_block_size_triton(
topk_ids: torch.Tensor,
num_experts: int,
block_size: int,
sorted_token_ids: torch.Tensor,
expert_ids: torch.Tensor,
num_tokens_post_pad: torch.Tensor,
) -> None:
numel = topk_ids.numel()
grid = (num_experts, )
tokens_cnts = torch.zeros((num_experts + 1, num_experts),
dtype=torch.int32,
device=topk_ids.device)
cumsum = torch.zeros((num_experts + 1, ),
dtype=torch.int32,
device=topk_ids.device)
tokens_per_thread = ceil_div(numel, num_experts)
moe_align_block_size_stage1[grid](
topk_ids,
tokens_cnts,
num_experts,
numel,
tokens_per_thread,
)
moe_align_block_size_stage2[grid](
tokens_cnts,
num_experts,
)
moe_align_block_size_stage3[(1, )](
num_tokens_post_pad,
tokens_cnts,
cumsum,
num_experts,
block_size,
)
moe_align_block_size_stage4[grid](
topk_ids,
sorted_token_ids,
expert_ids,
tokens_cnts,
cumsum,
num_experts,
block_size,
numel,
tokens_per_thread,
)
def moe_align_block_size(
topk_ids: torch.Tensor,
block_size: int,
num_experts: int,
expert_map: Optional[torch.Tensor] = None,
pad_sorted_ids: bool = False
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""
Aligns the token distribution across experts to be compatible with block
size for matrix multiplication.
Parameters:
- topk_ids: A tensor of shape [total_tokens, top_k] representing the
top-k expert indices for each token.
- block_size: The block size used in block matrix multiplication.
- num_experts: The total number of experts.
- expert_map: A tensor of shape [num_experts] that maps the expert index
from the global space to the local index space of the current
expert parallel shard. If the expert is not in the current expert
parallel shard, the mapping is set to -1.
- pad_sorted_ids: A flag indicating whether the sorted_token_ids length
should be padded to a multiple of block_size,
Returns:
- sorted_token_ids: A tensor containing the sorted token indices according
to their allocated expert.
- expert_ids: A tensor indicating the assigned expert index for each block.
- num_tokens_post_padded: The total number of tokens after padding,
ensuring divisibility by block_size.
This function pads the number of tokens that each expert needs to process
so that it is divisible by block_size.
Padding ensures that during block matrix multiplication, the dimensions
align correctly.
Example:
Given topk_ids = [[2, 3, 4], [1, 2, 4], [1, 3, 4], [1, 2, 3]],
block_size = 4, and num_experts = 4:
- We initially have 12 tokens (after repeating 'top_k' times) and 4 experts,
with each expert needing to process 3 tokens.
- As block_size is 4, we pad 1 token for each expert.
- First, flatten topk_ids to [2, 3, 4, 1, 2, 4, 1, 3, 4, 1, 2, 3].
- Then append padding tokens [12, 12, 12, 12] for each block.
- After sorting by expert index, we obtain token_ids
[3, 6, 9, 12, 0, 4, 10, 12, 1, 7, 11, 12, 2, 5, 8, 12].
Tokens 12 are non-existent (padding) and are ignored in
the subsequent matrix multiplication.
- The padding ensures that the total number of tokens is now divisible
by block_size for proper block matrix operations.
"""
max_num_tokens_padded = topk_ids.numel() + num_experts * (block_size - 1)
if pad_sorted_ids:
max_num_tokens_padded = round_up(max_num_tokens_padded, block_size)
sorted_ids = torch.empty((max_num_tokens_padded, ),
dtype=torch.int32,
device=topk_ids.device)
sorted_ids.fill_(topk_ids.numel())
max_num_m_blocks = triton.cdiv(max_num_tokens_padded, block_size)
# Expert ids must be zeroed out to prevent index out of bounds error while
# mapping global expert ids to local expert ids in expert parallelism.
expert_ids = torch.zeros((max_num_m_blocks, ),
dtype=torch.int32,
device=topk_ids.device)
num_tokens_post_pad = torch.empty((1),
dtype=torch.int32,
device=topk_ids.device)
if num_experts >= 224:
if envs.VLLM_ENABLE_MOE_ALIGN_BLOCK_SIZE_TRITON or num_experts != 256:
moe_align_block_size_triton(
topk_ids,
num_experts,
block_size,
sorted_ids,
expert_ids,
num_tokens_post_pad,
)
else:
# Currently requires num_experts=256
ops.sgl_moe_align_block_size(
topk_ids,
num_experts,
block_size,
sorted_ids,
expert_ids,
num_tokens_post_pad,
)
else:
ops.moe_align_block_size(topk_ids, num_experts, block_size, sorted_ids,
expert_ids, num_tokens_post_pad)
if expert_map is not None:
expert_ids = expert_map[expert_ids]
return sorted_ids, expert_ids, num_tokens_post_pad
def _valid_deep_gemm(hidden_states: torch.Tensor, w1: torch.Tensor,
w2: torch.Tensor,
expert_map: Optional[torch.Tensor]) -> bool:
"""
Check if the given problem size is supported by the DeepGemm grouped
gemm kernel. All of M, N, K and the quantization block_shape must be
aligned by `dg.get_m_alignment_for_contiguous_layout()`.
"""
if not has_deep_gemm:
return False
# Lazy import to avoid CUDA initialization problems.
import deep_gemm as dg
# Expert maps not supported yet.
if expert_map is not None:
return False
align = dg.get_m_alignment_for_contiguous_layout()
M = hidden_states.shape[0]
_, K, N = w2.shape
# For now, disable DeepGemm for small N until better permute/unpermute
# ops are available.
if N <= 512:
return False
if align > M or N % align != 0 or K % align != 0:
return False
return (hidden_states.is_contiguous() and w1.is_contiguous()
and w2.is_contiguous())
def _fp8_quantize(
A: torch.Tensor,
A_scale: Optional[torch.Tensor],
block_shape: Optional[List[int]],
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Perform fp8 quantization on the inputs. If a block_shape
is provided, the output will be blocked.
"""
if block_shape is None:
A, A_scale = ops.scaled_fp8_quant(A, A_scale)
else:
assert len(block_shape) == 2
_, block_k = block_shape[0], block_shape[1]
A, A_scale = per_token_group_quant_fp8(A, block_k)
assert triton.cdiv(A.shape[-1], block_k) == A_scale.shape[-1]
return A, A_scale
def invoke_fused_moe_kernel(A: torch.Tensor,
B: torch.Tensor,
C: torch.Tensor,
A_scale: Optional[torch.Tensor],
B_scale: Optional[torch.Tensor],
B_zp: Optional[torch.Tensor],
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
topk_weights: Optional[torch.Tensor],
sorted_token_ids: torch.Tensor,
expert_ids: torch.Tensor,
num_tokens_post_padded: torch.Tensor,
@ -748,7 +460,8 @@ def invoke_fused_moe_kernel(A: torch.Tensor,
use_int8_w8a16: bool,
use_int4_w4a16: bool,
block_shape: Optional[List[int]] = None) -> None:
assert topk_weights.stride(1) == 1
assert topk_weights is not None or not mul_routed_weight
assert topk_weights is None or topk_weights.stride(1) == 1
assert sorted_token_ids.stride(0) == 1
if use_fp8_w8a8:
@ -765,6 +478,9 @@ def invoke_fused_moe_kernel(A: torch.Tensor,
assert A_scale is None
assert B_scale is None
M = A.shape[0]
num_tokens = M * top_k
EM = sorted_token_ids.shape[0]
if A.shape[0] < config["BLOCK_SIZE_M"]:
# optimize for small batch_size.
@ -782,7 +498,7 @@ def invoke_fused_moe_kernel(A: torch.Tensor,
assert B_zp is None or B_zp.ndim == 3
use_moe_wna16_cuda = should_moe_wna16_use_cuda(
num_valid_tokens=topk_ids.numel(),
num_valid_tokens=num_tokens,
group_size=block_shape[1],
num_experts=B.shape[0],
bit=4 if use_int4_w4a16 else 8)
@ -790,12 +506,12 @@ def invoke_fused_moe_kernel(A: torch.Tensor,
config.update(
get_moe_wna16_block_config(config=config,
use_moe_wna16_cuda=use_moe_wna16_cuda,
num_valid_tokens=topk_ids.numel(),
num_valid_tokens=num_tokens,
size_k=A.shape[1],
size_n=B.shape[1],
num_experts=B.shape[1],
group_size=block_shape[1],
real_top_k=topk_ids.shape[1],
real_top_k=top_k,
block_size_m=config["BLOCK_SIZE_M"]))
if use_moe_wna16_cuda:
@ -821,7 +537,7 @@ def invoke_fused_moe_kernel(A: torch.Tensor,
B.shape[1],
A.shape[1],
EM,
topk_ids.numel(),
num_tokens,
A.stride(0),
A.stride(1),
B.stride(0),
@ -864,7 +580,7 @@ def invoke_fused_moe_kernel(A: torch.Tensor,
B.shape[1],
B.shape[2],
EM,
topk_ids.numel(),
num_tokens,
A.stride(0),
A.stride(1),
B.stride(0),
@ -1389,6 +1105,7 @@ def fused_experts(hidden_states: torch.Tensor,
w2=w2,
topk_weights=topk_weights,
topk_ids=topk_ids,
inplace=inplace,
activation=activation,
global_num_experts=global_num_experts,
expert_map=expert_map,
@ -1419,85 +1136,6 @@ def fused_experts(hidden_states: torch.Tensor,
block_shape=block_shape)
def _fp8_perm(m: torch.Tensor, idx: torch.Tensor) -> torch.Tensor:
"""
A permutation routine that works on fp8 types.
"""
if torch.is_floating_point(m) and torch.finfo(m.dtype).bits == 8:
return m.view(dtype=torch.uint8)[idx, ...].view(dtype=m.dtype)
else:
return m[idx, ...]
def _moe_permute(
curr_hidden_states: torch.Tensor,
a1q_scale: Optional[torch.Tensor],
curr_topk_ids: torch.Tensor,
global_num_experts: int,
expert_map: Optional[torch.Tensor],
top_k_num: int,
block_m: int,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], torch.Tensor, torch.Tensor,
torch.Tensor]:
"""
Determine the sorted_token_ids, expert_ids for the given problem size.
Permute the hidden states and scales according to `sorted_token_ids`.
"""
tokens_in_chunk, _ = curr_hidden_states.shape
sorted_token_ids, expert_ids, num_tokens_post_padded = (
moe_align_block_size(curr_topk_ids,
block_m,
global_num_experts,
expert_map,
pad_sorted_ids=True))
inv_perm: Optional[torch.Tensor] = None
num_tokens = top_k_num * tokens_in_chunk
sorted_token_ids = sorted_token_ids.clamp(max=num_tokens - 1)
expert_ids = torch.repeat_interleave(expert_ids, block_m, dim=0)
inv_perm = torch.argsort(sorted_token_ids)[:num_tokens]
# Permute according to sorted token ids.
curr_hidden_states = _fp8_perm(curr_hidden_states,
sorted_token_ids // top_k_num)
if a1q_scale is not None:
a1q_scale = a1q_scale[sorted_token_ids // top_k_num]
return (curr_hidden_states, a1q_scale, sorted_token_ids, expert_ids,
inv_perm)
def _moe_unpermute_and_reduce(
out: torch.Tensor,
curr_hidden: torch.Tensor,
inv_perm: Optional[torch.Tensor],
topk: int,
K: int,
topk_weight: torch.Tensor,
) -> None:
"""
Unpermute the final result and apply topk_weights, then perform the final
reduction on the hidden states.
"""
M = topk_weight.shape[0]
curr_hidden = curr_hidden[inv_perm, ...]
curr_hidden = curr_hidden.view(-1, topk, K)
curr_hidden.mul_(topk_weight.view(M, -1, 1))
ops.moe_sum(curr_hidden, out)
def _resize_cache(x: torch.Tensor, v: Tuple[int, ...]) -> torch.Tensor:
"""
Shrink the given tensor and apply the given view to it. This is
used to resize the intermediate fused_moe caches.
"""
assert prod(v) <= x.numel()
return x.flatten()[:prod(v)].view(*v)
def fused_experts_impl(hidden_states: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
@ -1629,7 +1267,6 @@ def fused_experts_impl(hidden_states: torch.Tensor,
w1_scale,
w1_zp,
curr_topk_weights,
curr_topk_ids,
sorted_token_ids,
expert_ids,
num_tokens_post_padded,
@ -1660,18 +1297,18 @@ def fused_experts_impl(hidden_states: torch.Tensor,
qintermediate_cache2 = intermediate_cache2
a2q_scale = a2_scale
invoke_fused_moe_kernel(qintermediate_cache2,
invoke_fused_moe_kernel(
qintermediate_cache2,
w2,
intermediate_cache3,
a2q_scale,
w2_scale,
w2_zp,
curr_topk_weights,
curr_topk_ids,
sorted_token_ids,
expert_ids,
num_tokens_post_padded,
True,
False, #True,
1,
config,
compute_type=compute_type,
@ -1680,8 +1317,14 @@ def fused_experts_impl(hidden_states: torch.Tensor,
use_int4_w4a16=use_int4_w4a16,
block_shape=block_shape)
if True:
intermediate_cache3 = intermediate_cache3.view(-1, top_k_num, K)
intermediate_cache3.mul_(
curr_topk_weights.view(tokens_in_chunk, -1, 1))
ops.moe_sum(intermediate_cache3.view(*intermediate_cache3.shape),
out_hidden_states[begin_chunk_idx:end_chunk_idx])
return out_hidden_states
@ -1790,327 +1433,3 @@ def fused_moe(
a1_scale=a1_scale,
a2_scale=a2_scale,
block_shape=block_shape)
def deep_gemm_moe_fp8(
hidden_states: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
w1_scale: torch.Tensor,
w2_scale: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
inplace: bool = False,
activation: str = "silu",
global_num_experts: int = -1,
expert_map: Optional[torch.Tensor] = None,
a1_scale: Optional[torch.Tensor] = None,
a2_scale: Optional[torch.Tensor] = None,
) -> torch.Tensor:
"""
This function computes a a8w8-quantized Mixture of Experts (MoE) layer
using two sets of quantized weights, w1_q and w2_q, and top-k gating
mechanism. The matrix multiplications are implemented with DeepGemm
grouped gemm.
Parameters:
- hidden_states (torch.Tensor): The input tensor to the MoE layer.
Shape: [M, K]
- w1 (torch.Tensor): The first set of fp8 quantized expert weights.
Shape: [num_experts, K, 2N] (the weights are passed transposed)
- w2 (torch.Tensor): The second set of fp8 quantized expert weights.
Shape: [num_experts, N, K] (the weights are passed transposed)
- w1_scale (torch.Tensor): The fp32 scale to dequantize w1_q.
Shape: [num_experts] or [num_experts, 2N]
- w2_scale (torch.Tensor): The fp32 scale to dequantize w2_q.
Shape: [num_experts] or [num_experts, K]
- topk_weights (torch.Tensor): The weights of each token->expert mapping.
- topk_ids (torch.Tensor): The token->expert mapping for topk_weights.
- inplace (bool): If True, perform the operation in-place.
Defaults to False.
- activation (str): The activation function to apply after the first
MoE layer.
- global_num_experts (int): The total number of experts in the global
expert space.
- expert_map (Optional[torch.Tensor]): A tensor mapping expert indices
from the global expert space to the local expert space of the expert
parallel shard.
- a1_scale (Optional[torch.Tensor]): The optional fp32 scale to quantize a.
Shape: scalar or [M]
- a2_scale (Optional[torch.Tensor]): The optional fp32 scale to
quantize the intermediate result between the gemms.
Shape: scalar or [M]
Returns:
- torch.Tensor: The bfloat16 output tensor after applying the MoE layer.
"""
# Lazy import to avoid CUDA initialization problems.
import deep_gemm as dg
assert expert_map is None, "Expert maps not supported yet"
assert hidden_states.shape[1] == w1.shape[2], "Hidden size mismatch"
assert topk_weights.shape == topk_ids.shape, "topk shape mismatch"
assert hidden_states.is_contiguous(), "Hidden_states must be contiguous"
assert w1.stride(-1) == 1, "Stride of last dimension must be 1"
assert w2.stride(-1) == 1, "Stride of last dimension must be 1"
assert hidden_states.dtype in [
torch.float32, torch.float16, torch.bfloat16
]
assert w1.dtype == torch.float8_e4m3fn
assert w2.dtype == torch.float8_e4m3fn
assert w1.shape[0] == w2.shape[0], "Expert number mismatch"
assert w1.shape[0] == w1_scale.shape[0], "w1 scales expert number mismatch"
assert w1.shape[0] == w2_scale.shape[0], "w2 scales expert number mismatch"
assert a1_scale is None or a1_scale.dim(
) == 0 or a1_scale.shape[0] == 1 or a1_scale.shape[
0] == hidden_states.shape[0], "Input scale shape mismatch"
assert a2_scale is None or a1_scale is None or a2_scale.shape == a1_scale.shape, "Intermediate scale shape mismatch" # noqa: E501
num_tokens, _ = hidden_states.shape
E, N, _ = w1.shape
K = w2.shape[1]
if global_num_experts == -1:
global_num_experts = E
top_k_num = topk_ids.shape[1]
# We execute the fused_moe kernel in chunks to circumvent this issue:
# https://github.com/vllm-project/vllm/issues/5938
CHUNK_SIZE = envs.VLLM_FUSED_MOE_CHUNK_SIZE
assert _valid_deep_gemm(hidden_states, w1, w2, expert_map)
if inplace:
out_hidden_states = hidden_states
else:
out_hidden_states = torch.empty_like(hidden_states)
block_m = dg.get_m_alignment_for_contiguous_layout()
block_shape = [block_m, block_m]
assert w1_scale is not None
assert w2_scale is not None
# We attempt to transpose and align offline in Fp8MoEMethod, in which
# case these calls will be nops. Otherwise, they'll be performed every
# time the layer is executed.
w1_scale = dg.get_col_major_tma_aligned_tensor(w1_scale).contiguous()
w2_scale = dg.get_col_major_tma_aligned_tensor(w2_scale).contiguous()
M_sum = topk_ids.numel() + global_num_experts * (block_m - 1)
M_sum = round_up(M_sum, block_m)
num_chunks = (num_tokens // CHUNK_SIZE) + 1
# We can reuse the memory between cache1 and cache3 because by the time
# we need cache3, we're done with cache1
cache13 = torch.empty(M_sum * max(N, K),
device=hidden_states.device,
dtype=hidden_states.dtype)
intermediate_cache1 = cache13[:M_sum * N].view(M_sum, N)
intermediate_cache2 = torch.empty((M_sum, N // 2),
device=hidden_states.device,
dtype=hidden_states.dtype)
intermediate_cache3 = cache13[:M_sum * K].view(M_sum, K)
for chunk in range(num_chunks):
begin_chunk_idx, end_chunk_idx = (chunk * CHUNK_SIZE,
min((chunk + 1) * CHUNK_SIZE,
num_tokens))
curr_hidden_states = hidden_states[begin_chunk_idx:end_chunk_idx]
tokens_in_chunk, _ = curr_hidden_states.shape
if tokens_in_chunk == 0:
break
curr_topk_ids = topk_ids[begin_chunk_idx:end_chunk_idx]
curr_topk_weights = topk_weights[begin_chunk_idx:end_chunk_idx]
a1q_scale: Optional[torch.Tensor] = None
qcurr_hidden_states, a1q_scale = _fp8_quantize(curr_hidden_states,
a1_scale, block_shape)
(qcurr_hidden_states, a1q_scale, sorted_token_ids, expert_ids,
inv_perm) = _moe_permute(qcurr_hidden_states, a1q_scale,
curr_topk_ids, global_num_experts,
expert_map, top_k_num, block_m)
# Adjust the intermediate cache size and config for the last chunk.
# Note that in most cases we only have one chunk so the cache size
# and config are already set correctly and do not need to be adjusted.
if tokens_in_chunk < CHUNK_SIZE and chunk > 0:
curr_M = sorted_token_ids.numel()
intermediate_cache1 = _resize_cache(intermediate_cache1,
(curr_M, N))
intermediate_cache2 = _resize_cache(intermediate_cache2,
(curr_M, N // 2))
intermediate_cache3 = _resize_cache(intermediate_cache3,
(curr_M, K))
dg.m_grouped_gemm_fp8_fp8_bf16_nt_contiguous(
(qcurr_hidden_states, a1q_scale), (w1, w1_scale),
intermediate_cache1, expert_ids)
if activation == "silu":
torch.ops._C.silu_and_mul(intermediate_cache2,
intermediate_cache1.view(-1, N))
elif activation == "gelu":
torch.ops._C.gelu_and_mul(intermediate_cache2,
intermediate_cache1.view(-1, N))
else:
raise ValueError(f"Unsupported FusedMoe activation: {activation}")
a2q_scale: Optional[torch.Tensor] = None
qintermediate_cache2, a2q_scale = _fp8_quantize(
intermediate_cache2, a2_scale, block_shape)
dg.m_grouped_gemm_fp8_fp8_bf16_nt_contiguous(
(qintermediate_cache2, a2q_scale), (w2, w2_scale),
intermediate_cache3, expert_ids)
_moe_unpermute_and_reduce(
out_hidden_states[begin_chunk_idx:end_chunk_idx],
intermediate_cache3.view(*intermediate_cache3.shape), inv_perm,
top_k_num, K, curr_topk_weights)
return out_hidden_states
#TODO make the grouped gemm kernel consistent with scaled gemm kernel
def cutlass_moe_fp8(
a: torch.Tensor,
w1_q: torch.Tensor,
w2_q: torch.Tensor,
w1_scale: torch.Tensor,
w2_scale: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
ab_strides1: torch.Tensor,
c_strides1: torch.Tensor,
ab_strides2: torch.Tensor,
c_strides2: torch.Tensor,
a1_scale: Optional[torch.Tensor] = None,
a2_scale: Optional[torch.Tensor] = None,
out_dtype: torch.dtype = torch.half,
) -> torch.Tensor:
"""
This function computes a a8w8-quantized Mixture of Experts (MoE) layer
using two sets of quantized weights, w1_q and w2_q, and top-k gating
mechanism. The matrix multiplications are implemented with CUTLASS
grouped gemm.
Parameters:
- a (torch.Tensor): The input tensor to the MoE layer.
Shape: [M, K]
- w1_q (torch.Tensor): The first set of fp8-quantized expert weights.
Shape: [num_experts, K, 2N] (the weights are passed transposed)
- w2_q (torch.Tensor): The second set of fp8-quantized expert weights.
Shape: [num_experts, N, K] (the weights are passed transposed)
- w1_scale (torch.Tensor): The fp32 scale to dequantize w1_q.
Shape: [num_experts] or [num_experts, 2N]
- w2_scale (torch.Tensor): The fp32 scale to dequantize w2_q.
Shape: [num_experts] or [num_experts, K]
- gating_output (torch.Tensor): The output of the gating operation
(before softmax).
- topk_weights (torch.Tensor): The weights of each token->expert mapping.
- ab_strides1 (torch.Tensor): The input and weights strides of the first
grouped gemm.
- c_strides1 (torch.Tensor): The output strides of the first grouped gemm.
- ab_strides2 (torch.Tensor): The input and weights strides of the second
grouped gemm.
- c_strides2 (torch.Tensor): The output strides of the second grouped gemm.
- a1_scale (Optional[torch.Tensor]): The optional fp32 scale to quantize a.
Shape: scalar or [M]
- a2_scale (Optional[torch.Tensor]): The optional fp32 scale to
quantize the intermediate result between the gemms.
Shape: scalar or [M]
- out_dtype (torch.Tensor): The output tensor type.
Returns:
- torch.Tensor: The fp16 output tensor after applying the MoE layer.
"""
assert topk_weights.shape == topk_ids.shape, "topk shape mismatch"
assert w1_q.dtype == torch.float8_e4m3fn
assert w2_q.dtype == torch.float8_e4m3fn
assert a.shape[1] == w1_q.shape[1], "Hidden size mismatch w1"
assert w1_q.shape[2] == w2_q.shape[1] * 2, "Hidden size mismatch w2"
assert w1_q.shape[0] == w2_q.shape[0], "Expert number mismatch"
assert a1_scale is None or a1_scale.dim(
) == 0 or a1_scale.shape[0] == 1 or a1_scale.shape[0] == a.shape[
0], "Input scale shape mismatch"
assert w1_scale.dim() == 1 or w1_scale.shape[1] == 1 or w1_scale.shape[
1] == w1_q.shape[2], "W1 scale shape mismatch"
assert w2_scale.dim() == 1 or w2_scale.shape[1] == 1 or w2_scale.shape[
1] == w2_q.shape[2], "W2 scale shape mismatch"
assert w1_q.shape[0] == w2_q.shape[0], "Weights expert number mismatch"
assert w1_q.shape[0] == w1_scale.shape[
0], "w1 scales expert number mismatch"
assert w1_q.shape[0] == w2_scale.shape[
0], "w2 scales expert number mismatch"
assert a2_scale is None or a1_scale is None or a2_scale.shape == a1_scale.shape, "Intermediate scale shape mismatch" # noqa: E501
assert ab_strides1.shape[0] == w1_q.shape[
0], "AB Strides 1 expert number mismatch"
assert c_strides1.shape[0] == w1_q.shape[
0], "C Strides 1 expert number mismatch"
assert ab_strides2.shape[0] == w2_q.shape[
0], "AB Strides 2 expert number mismatch"
assert c_strides2.shape[0] == w2_q.shape[
0], "C Strides 2 expert number mismatch"
assert out_dtype in [torch.half, torch.bfloat16], "Invalid output dtype"
num_experts = w1_q.size(0)
m = a.size(0)
k = w1_q.size(1)
n = w2_q.size(1)
topk = topk_ids.size(1)
per_act_token = a1_scale.numel() != 1 if a1_scale is not None else (
a2_scale.numel() != 1 if a2_scale is not None else False)
a_q, a1_scale = ops.scaled_fp8_quant(
a, a1_scale, use_per_token_if_dynamic=per_act_token)
device = a_q.device
expert_offsets = torch.empty((num_experts + 1),
dtype=torch.int32,
device=device)
problem_sizes1 = torch.empty((num_experts, 3),
dtype=torch.int32,
device=device)
problem_sizes2 = torch.empty((num_experts, 3),
dtype=torch.int32,
device=device)
a_map = torch.empty((topk_ids.numel()), dtype=torch.int32, device=device)
c_map = torch.empty((topk_ids.numel()), dtype=torch.int32, device=device)
ops.get_cutlass_moe_mm_data(topk_ids, expert_offsets, problem_sizes1,
problem_sizes2, a_map, c_map, num_experts, n,
k)
rep_a_q = a_q.view(dtype=torch.uint8)[a_map].view(dtype=a_q.dtype)
rep_a1_scales = a1_scale[a_map] if per_act_token else a1_scale
c1 = torch.empty((m * topk, n * 2), device=device, dtype=out_dtype)
c2 = torch.empty((m * topk, k), device=device, dtype=out_dtype)
ops.cutlass_moe_mm(c1, rep_a_q, w1_q, rep_a1_scales, w1_scale,
expert_offsets[:-1], problem_sizes1, ab_strides1,
ab_strides1, c_strides1)
intermediate = torch.empty((m * topk, n), device=device, dtype=out_dtype)
torch.ops._C.silu_and_mul(intermediate, c1)
intemediate_q, a2_scale = ops.scaled_fp8_quant(
intermediate, a2_scale, use_per_token_if_dynamic=per_act_token)
ops.cutlass_moe_mm(c2, intemediate_q, w2_q, a2_scale, w2_scale,
expert_offsets[:-1], problem_sizes2, ab_strides2,
ab_strides2, c_strides2)
return (c2[c_map].view(m, topk, k) *
topk_weights.view(m, topk, 1).to(out_dtype)).sum(dim=1)

View File

@ -0,0 +1,243 @@
# SPDX-License-Identifier: Apache-2.0
from typing import Optional, Tuple
import torch
import triton
import triton.language as tl
import vllm.envs as envs
from vllm import _custom_ops as ops
from vllm.utils import round_up
def ceil_div(a, b):
return (a + b - 1) // b
@triton.jit
def moe_align_block_size_stage1(
topk_ids_ptr,
tokens_cnts_ptr,
num_experts: tl.constexpr,
numel: tl.constexpr,
tokens_per_thread: tl.constexpr,
):
pid = tl.program_id(0)
start_idx = pid * tokens_per_thread
off_c = (pid + 1) * num_experts
for i in range(tokens_per_thread):
if start_idx + i < numel:
idx = tl.load(topk_ids_ptr + start_idx + i)
token_cnt = tl.load(tokens_cnts_ptr + off_c + idx)
tl.store(tokens_cnts_ptr + off_c + idx, token_cnt + 1)
@triton.jit
def moe_align_block_size_stage2(
tokens_cnts_ptr,
num_experts: tl.constexpr,
):
pid = tl.program_id(0)
last_cnt = 0
for i in range(1, num_experts + 1):
token_cnt = tl.load(tokens_cnts_ptr + i * num_experts + pid)
last_cnt = last_cnt + token_cnt
tl.store(tokens_cnts_ptr + i * num_experts + pid, last_cnt)
@triton.jit
def moe_align_block_size_stage3(
total_tokens_post_pad_ptr,
tokens_cnts_ptr,
cumsum_ptr,
num_experts: tl.constexpr,
block_size: tl.constexpr,
):
last_cumsum = 0
off_cnt = num_experts * num_experts
for i in range(1, num_experts + 1):
token_cnt = tl.load(tokens_cnts_ptr + off_cnt + i - 1)
last_cumsum = last_cumsum + tl.cdiv(token_cnt, block_size) * block_size
tl.store(cumsum_ptr + i, last_cumsum)
tl.store(total_tokens_post_pad_ptr, last_cumsum)
@triton.jit
def moe_align_block_size_stage4(
topk_ids_ptr,
sorted_token_ids_ptr,
expert_ids_ptr,
tokens_cnts_ptr,
cumsum_ptr,
num_experts: tl.constexpr,
block_size: tl.constexpr,
numel: tl.constexpr,
tokens_per_thread: tl.constexpr,
):
pid = tl.program_id(0)
start_idx = tl.load(cumsum_ptr + pid)
end_idx = tl.load(cumsum_ptr + pid + 1)
for i in range(start_idx, end_idx, block_size):
tl.store(expert_ids_ptr + i // block_size, pid)
start_idx = pid * tokens_per_thread
off_t = pid * num_experts
for i in range(start_idx, tl.minimum(start_idx + tokens_per_thread,
numel)):
expert_id = tl.load(topk_ids_ptr + i)
token_cnt = tl.load(tokens_cnts_ptr + off_t + expert_id)
rank_post_pad = token_cnt + tl.load(cumsum_ptr + expert_id)
tl.store(sorted_token_ids_ptr + rank_post_pad, i)
tl.store(tokens_cnts_ptr + off_t + expert_id, token_cnt + 1)
# Triton implementation based on:
# https://github.com/sgl-project/sglang/commit/ba5112ff691d791a9e38c6c71f59324a5fcb49d0
def moe_align_block_size_triton(
topk_ids: torch.Tensor,
num_experts: int,
block_size: int,
sorted_token_ids: torch.Tensor,
expert_ids: torch.Tensor,
num_tokens_post_pad: torch.Tensor,
) -> None:
numel = topk_ids.numel()
grid = (num_experts, )
tokens_cnts = torch.zeros((num_experts + 1, num_experts),
dtype=torch.int32,
device=topk_ids.device)
cumsum = torch.zeros((num_experts + 1, ),
dtype=torch.int32,
device=topk_ids.device)
tokens_per_thread = ceil_div(numel, num_experts)
moe_align_block_size_stage1[grid](
topk_ids,
tokens_cnts,
num_experts,
numel,
tokens_per_thread,
)
moe_align_block_size_stage2[grid](
tokens_cnts,
num_experts,
)
moe_align_block_size_stage3[(1, )](
num_tokens_post_pad,
tokens_cnts,
cumsum,
num_experts,
block_size,
)
moe_align_block_size_stage4[grid](
topk_ids,
sorted_token_ids,
expert_ids,
tokens_cnts,
cumsum,
num_experts,
block_size,
numel,
tokens_per_thread,
)
def moe_align_block_size(
topk_ids: torch.Tensor,
block_size: int,
num_experts: int,
expert_map: Optional[torch.Tensor] = None,
pad_sorted_ids: bool = False
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""
Aligns the token distribution across experts to be compatible with block
size for matrix multiplication.
Parameters:
- topk_ids: A tensor of shape [total_tokens, top_k] representing the
top-k expert indices for each token.
- block_size: The block size used in block matrix multiplication.
- num_experts: The total number of experts.
- expert_map: A tensor of shape [num_experts] that maps the expert index
from the global space to the local index space of the current
expert parallel shard. If the expert is not in the current expert
parallel shard, the mapping is set to -1.
- pad_sorted_ids: A flag indicating whether the sorted_token_ids length
should be padded to a multiple of block_size,
Returns:
- sorted_token_ids: A tensor containing the sorted token indices according
to their allocated expert.
- expert_ids: A tensor indicating the assigned expert index for each block.
- num_tokens_post_padded: The total number of tokens after padding,
ensuring divisibility by block_size.
This function pads the number of tokens that each expert needs to process
so that it is divisible by block_size.
Padding ensures that during block matrix multiplication, the dimensions
align correctly.
Example:
Given topk_ids = [[2, 3, 4], [1, 2, 4], [1, 3, 4], [1, 2, 3]],
block_size = 4, and num_experts = 4:
- We initially have 12 tokens (after repeating 'top_k' times) and 4 experts,
with each expert needing to process 3 tokens.
- As block_size is 4, we pad 1 token for each expert.
- First, flatten topk_ids to [2, 3, 4, 1, 2, 4, 1, 3, 4, 1, 2, 3].
- Then append padding tokens [12, 12, 12, 12] for each block.
- After sorting by expert index, we obtain token_ids
[3, 6, 9, 12, 0, 4, 10, 12, 1, 7, 11, 12, 2, 5, 8, 12].
Tokens 12 are non-existent (padding) and are ignored in
the subsequent matrix multiplication.
- The padding ensures that the total number of tokens is now divisible
by block_size for proper block matrix operations.
"""
max_num_tokens_padded = topk_ids.numel() + num_experts * (block_size - 1)
if pad_sorted_ids:
max_num_tokens_padded = round_up(max_num_tokens_padded, block_size)
sorted_ids = torch.empty((max_num_tokens_padded, ),
dtype=torch.int32,
device=topk_ids.device)
sorted_ids.fill_(topk_ids.numel())
max_num_m_blocks = triton.cdiv(max_num_tokens_padded, block_size)
# Expert ids must be zeroed out to prevent index out of bounds error while
# mapping global expert ids to local expert ids in expert parallelism.
expert_ids = torch.zeros((max_num_m_blocks, ),
dtype=torch.int32,
device=topk_ids.device)
num_tokens_post_pad = torch.empty((1),
dtype=torch.int32,
device=topk_ids.device)
if num_experts >= 224:
if envs.VLLM_ENABLE_MOE_ALIGN_BLOCK_SIZE_TRITON or num_experts != 256:
moe_align_block_size_triton(
topk_ids,
num_experts,
block_size,
sorted_ids,
expert_ids,
num_tokens_post_pad,
)
else:
# Currently requires num_experts=256
ops.sgl_moe_align_block_size(
topk_ids,
num_experts,
block_size,
sorted_ids,
expert_ids,
num_tokens_post_pad,
)
else:
ops.moe_align_block_size(topk_ids, num_experts, block_size, sorted_ids,
expert_ids, num_tokens_post_pad)
if expert_map is not None:
expert_ids = expert_map[expert_ids]
return sorted_ids, expert_ids, num_tokens_post_pad

View File

@ -0,0 +1,48 @@
# SPDX-License-Identifier: Apache-2.0
from math import prod
from typing import List, Optional, Tuple
import torch
from vllm import _custom_ops as ops
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
per_token_group_quant_fp8)
from vllm.utils import cdiv
def _resize_cache(x: torch.Tensor, v: Tuple[int, ...]) -> torch.Tensor:
"""
Shrink the given tensor and apply the given view to it. This is
used to resize the intermediate fused_moe caches.
"""
assert prod(v) <= x.numel()
return x.flatten()[:prod(v)].view(*v)
def _fp8_quantize(
A: torch.Tensor,
A_scale: Optional[torch.Tensor],
block_shape: Optional[List[int]],
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Perform fp8 quantization on the inputs. If a block_shape
is provided, the output will be blocked.
"""
if block_shape is None:
A, A_scale = ops.scaled_fp8_quant(A, A_scale)
else:
assert len(block_shape) == 2
_, block_k = block_shape[0], block_shape[1]
A, A_scale = per_token_group_quant_fp8(A, block_k)
assert cdiv(A.shape[-1], block_k) == A_scale.shape[-1]
return A, A_scale
def _fp8_perm(m: torch.Tensor, idx: torch.Tensor) -> torch.Tensor:
"""
A permutation routine that works on fp8 types.
"""
if torch.is_floating_point(m) and torch.finfo(m.dtype).bits == 8:
return m.view(dtype=torch.uint8)[idx, ...].view(dtype=m.dtype)
else:
return m[idx, ...]