[Minor] Fused experts refactor (#15914)
Signed-off-by: Bill Nell <bnell@redhat.com>
parent d2b58ca203
commit 15ba07ef25
@@ -9,8 +9,11 @@ import torch
|
||||
from vllm.config import VllmConfig, set_current_vllm_config
|
||||
from vllm.model_executor.layers.activation import SiluAndMul
|
||||
from vllm.model_executor.layers.fused_moe import fused_moe
|
||||
from vllm.model_executor.layers.fused_moe.fused_moe import (
|
||||
deep_gemm_moe_fp8, fused_topk, moe_align_block_size)
|
||||
from vllm.model_executor.layers.fused_moe.deep_gemm_moe import (
|
||||
deep_gemm_moe_fp8)
|
||||
from vllm.model_executor.layers.fused_moe.fused_moe import fused_topk
|
||||
from vllm.model_executor.layers.fused_moe.moe_align_block_size import (
|
||||
moe_align_block_size)
|
||||
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
|
||||
per_token_group_quant_fp8, w8a8_block_fp8_matmul)
|
||||
from vllm.platforms import current_platform
|
||||
@@ -437,7 +440,7 @@ def test_w8a8_block_fp8_deep_gemm_fused_moe(M, N, K, E, topk, seed):
|
||||
pytest.skip(
|
||||
f"Skipping test; bad size m={M}, n={N}, k={K}, topk={topk}, E={E}")
|
||||
|
||||
if (N <= 512):
|
||||
if N <= 512:
|
||||
pytest.skip("Skipping N <= 512 until performance issues solved.")
|
||||
|
||||
vllm_config = VllmConfig()
|
||||
|
||||
@@ -4,8 +4,8 @@ import torch
|
||||
|
||||
from vllm import _custom_ops as ops
|
||||
from vllm.config import ParallelConfig, VllmConfig, set_current_vllm_config
|
||||
from vllm.model_executor.layers.fused_moe.fused_moe import (cutlass_moe_fp8,
|
||||
fused_experts,
|
||||
from vllm.model_executor.layers.fused_moe.cutlass_moe import cutlass_moe_fp8
|
||||
from vllm.model_executor.layers.fused_moe.fused_moe import (fused_experts,
|
||||
fused_topk)
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
@@ -131,9 +131,9 @@ def test_cutlass_moe_no_graph(
|
||||
c_strides2,
|
||||
a1_scale=a_scale1)
|
||||
|
||||
print(triton_output)
|
||||
print(cutlass_output)
|
||||
print("*")
|
||||
#print(triton_output)
|
||||
#print(cutlass_output)
|
||||
#print("*")
|
||||
|
||||
torch.testing.assert_close(triton_output,
|
||||
cutlass_output,
|
||||
@@ -234,9 +234,9 @@ def test_cutlass_moe_cuda_graph(
|
||||
graph.replay()
|
||||
torch.cuda.synchronize()
|
||||
|
||||
print(triton_output)
|
||||
print(cutlass_output)
|
||||
print("*")
|
||||
#print(triton_output)
|
||||
#print(cutlass_output)
|
||||
#print("*")
|
||||
|
||||
torch.testing.assert_close(triton_output,
|
||||
cutlass_output,
|
||||
|
||||
@@ -35,9 +35,11 @@ if HAS_TRITON:
|
||||
# import to register the custom ops
|
||||
import vllm.model_executor.layers.fused_moe.fused_marlin_moe # noqa
|
||||
import vllm.model_executor.layers.fused_moe.fused_moe # noqa
|
||||
from vllm.model_executor.layers.fused_moe.cutlass_moe import (
|
||||
cutlass_moe_fp8)
|
||||
from vllm.model_executor.layers.fused_moe.fused_moe import (
|
||||
cutlass_moe_fp8, fused_experts, fused_moe, fused_topk,
|
||||
get_config_file_name, grouped_topk)
|
||||
fused_experts, fused_moe, fused_topk, get_config_file_name,
|
||||
grouped_topk)
|
||||
|
||||
__all__ += [
|
||||
"fused_moe",
|
||||
|
||||
vllm/model_executor/layers/fused_moe/cutlass_moe.py (new file, 144 lines)
@@ -0,0 +1,144 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
"""Fused MoE kernel."""
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
|
||||
from vllm import _custom_ops as ops
|
||||
|
||||
|
||||
#TODO make the grouped gemm kernel consistent with scaled gemm kernel
|
||||
def cutlass_moe_fp8(
|
||||
a: torch.Tensor,
|
||||
w1_q: torch.Tensor,
|
||||
w2_q: torch.Tensor,
|
||||
w1_scale: torch.Tensor,
|
||||
w2_scale: torch.Tensor,
|
||||
topk_weights: torch.Tensor,
|
||||
topk_ids: torch.Tensor,
|
||||
ab_strides1: torch.Tensor,
|
||||
c_strides1: torch.Tensor,
|
||||
ab_strides2: torch.Tensor,
|
||||
c_strides2: torch.Tensor,
|
||||
a1_scale: Optional[torch.Tensor] = None,
|
||||
a2_scale: Optional[torch.Tensor] = None,
|
||||
out_dtype: torch.dtype = torch.half,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
This function computes an a8w8-quantized Mixture of Experts (MoE) layer
|
||||
using two sets of quantized weights, w1_q and w2_q, and top-k gating
|
||||
mechanism. The matrix multiplications are implemented with CUTLASS
|
||||
grouped gemm.
|
||||
|
||||
Parameters:
|
||||
- a (torch.Tensor): The input tensor to the MoE layer.
|
||||
Shape: [M, K]
|
||||
- w1_q (torch.Tensor): The first set of fp8-quantized expert weights.
|
||||
Shape: [num_experts, K, 2N] (the weights are passed transposed)
|
||||
- w2_q (torch.Tensor): The second set of fp8-quantized expert weights.
|
||||
Shape: [num_experts, N, K] (the weights are passed transposed)
|
||||
- w1_scale (torch.Tensor): The fp32 scale to dequantize w1_q.
|
||||
Shape: [num_experts] or [num_experts, 2N]
|
||||
- w2_scale (torch.Tensor): The fp32 scale to dequantize w2_q.
|
||||
Shape: [num_experts] or [num_experts, K]
|
||||
- topk_ids (torch.Tensor): The token->expert mapping for topk_weights.
|
||||
- topk_weights (torch.Tensor): The weights of each token->expert mapping.
|
||||
- ab_strides1 (torch.Tensor): The input and weights strides of the first
|
||||
grouped gemm.
|
||||
- c_strides1 (torch.Tensor): The output strides of the first grouped gemm.
|
||||
- ab_strides2 (torch.Tensor): The input and weights strides of the second
|
||||
grouped gemm.
|
||||
- c_strides2 (torch.Tensor): The output strides of the second grouped gemm.
|
||||
- a1_scale (Optional[torch.Tensor]): The optional fp32 scale to quantize a.
|
||||
Shape: scalar or [M]
|
||||
- a2_scale (Optional[torch.Tensor]): The optional fp32 scale to
|
||||
quantize the intermediate result between the gemms.
|
||||
Shape: scalar or [M]
|
||||
- out_dtype (torch.dtype): The output tensor type.
|
||||
|
||||
Returns:
|
||||
- torch.Tensor: The fp16 output tensor after applying the MoE layer.
|
||||
"""
|
||||
|
||||
assert topk_weights.shape == topk_ids.shape, "topk shape mismatch"
|
||||
assert w1_q.dtype == torch.float8_e4m3fn
|
||||
assert w2_q.dtype == torch.float8_e4m3fn
|
||||
assert a.shape[1] == w1_q.shape[1], "Hidden size mismatch w1"
|
||||
assert w1_q.shape[2] == w2_q.shape[1] * 2, "Hidden size mismatch w2"
|
||||
assert w1_q.shape[0] == w2_q.shape[0], "Expert number mismatch"
|
||||
assert a1_scale is None or a1_scale.dim(
|
||||
) == 0 or a1_scale.shape[0] == 1 or a1_scale.shape[0] == a.shape[
|
||||
0], "Input scale shape mismatch"
|
||||
assert w1_scale.dim() == 1 or w1_scale.shape[1] == 1 or w1_scale.shape[
|
||||
1] == w1_q.shape[2], "W1 scale shape mismatch"
|
||||
assert w2_scale.dim() == 1 or w2_scale.shape[1] == 1 or w2_scale.shape[
|
||||
1] == w2_q.shape[2], "W2 scale shape mismatch"
|
||||
assert w1_q.shape[0] == w2_q.shape[0], "Weights expert number mismatch"
|
||||
assert w1_q.shape[0] == w1_scale.shape[
|
||||
0], "w1 scales expert number mismatch"
|
||||
assert w1_q.shape[0] == w2_scale.shape[
|
||||
0], "w2 scales expert number mismatch"
|
||||
assert a2_scale is None or a1_scale is None or a2_scale.shape == a1_scale.shape, "Intermediate scale shape mismatch" # noqa: E501
|
||||
assert ab_strides1.shape[0] == w1_q.shape[
|
||||
0], "AB Strides 1 expert number mismatch"
|
||||
assert c_strides1.shape[0] == w1_q.shape[
|
||||
0], "C Strides 1 expert number mismatch"
|
||||
assert ab_strides2.shape[0] == w2_q.shape[
|
||||
0], "AB Strides 2 expert number mismatch"
|
||||
assert c_strides2.shape[0] == w2_q.shape[
|
||||
0], "C Strides 2 expert number mismatch"
|
||||
assert out_dtype in [torch.half, torch.bfloat16], "Invalid output dtype"
|
||||
|
||||
num_experts = w1_q.size(0)
|
||||
m = a.size(0)
|
||||
k = w1_q.size(1)
|
||||
n = w2_q.size(1)
|
||||
|
||||
topk = topk_ids.size(1)
|
||||
per_act_token = a1_scale.numel() != 1 if a1_scale is not None else (
|
||||
a2_scale.numel() != 1 if a2_scale is not None else False)
|
||||
|
||||
a_q, a1_scale = ops.scaled_fp8_quant(
|
||||
a, a1_scale, use_per_token_if_dynamic=per_act_token)
|
||||
device = a_q.device
|
||||
|
||||
expert_offsets = torch.empty((num_experts + 1),
|
||||
dtype=torch.int32,
|
||||
device=device)
|
||||
problem_sizes1 = torch.empty((num_experts, 3),
|
||||
dtype=torch.int32,
|
||||
device=device)
|
||||
problem_sizes2 = torch.empty((num_experts, 3),
|
||||
dtype=torch.int32,
|
||||
device=device)
|
||||
|
||||
a_map = torch.empty((topk_ids.numel()), dtype=torch.int32, device=device)
|
||||
c_map = torch.empty((topk_ids.numel()), dtype=torch.int32, device=device)
|
||||
|
||||
ops.get_cutlass_moe_mm_data(topk_ids, expert_offsets, problem_sizes1,
|
||||
problem_sizes2, a_map, c_map, num_experts, n,
|
||||
k)
|
||||
|
||||
rep_a_q = a_q.view(dtype=torch.uint8)[a_map].view(dtype=a_q.dtype)
|
||||
rep_a1_scales = a1_scale[a_map] if per_act_token else a1_scale
|
||||
|
||||
c1 = torch.empty((m * topk, n * 2), device=device, dtype=out_dtype)
|
||||
c2 = torch.empty((m * topk, k), device=device, dtype=out_dtype)
|
||||
|
||||
ops.cutlass_moe_mm(c1, rep_a_q, w1_q, rep_a1_scales, w1_scale,
|
||||
expert_offsets[:-1], problem_sizes1, ab_strides1,
|
||||
ab_strides1, c_strides1)
|
||||
|
||||
intermediate = torch.empty((m * topk, n), device=device, dtype=out_dtype)
|
||||
torch.ops._C.silu_and_mul(intermediate, c1)
|
||||
|
||||
intemediate_q, a2_scale = ops.scaled_fp8_quant(
|
||||
intermediate, a2_scale, use_per_token_if_dynamic=per_act_token)
|
||||
|
||||
ops.cutlass_moe_mm(c2, intemediate_q, w2_q, a2_scale, w2_scale,
|
||||
expert_offsets[:-1], problem_sizes2, ab_strides2,
|
||||
ab_strides2, c_strides2)
|
||||
|
||||
return (c2[c_map].view(m, topk, k) *
|
||||
topk_weights.view(m, topk, 1).to(out_dtype)).sum(dim=1)
|
||||
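The tail of cutlass_moe_fp8 above gathers the second grouped-gemm output back into token order and applies the router weights before summing over the top-k experts. The following is a minimal plain-PyTorch sketch of just that unpermute-and-reduce step, on made-up sizes and with a random permutation standing in for c_map; it is illustrative only and calls no CUTLASS kernels.

import torch

# Hypothetical sizes for illustration only.
m, topk, k = 4, 2, 8
out_dtype = torch.half

# c2 stands in for the second grouped-gemm output in expert-sorted order;
# c_map maps each expert-sorted row back to its (token, topk) slot.
c2 = torch.randn(m * topk, k).to(out_dtype)
c_map = torch.randperm(m * topk)
topk_weights = torch.rand(m, topk)

# Same reduction as the return statement above: gather rows back into token
# order, weight each expert's contribution, and sum over the top-k experts.
out = (c2[c_map].view(m, topk, k) *
       topk_weights.view(m, topk, 1).to(out_dtype)).sum(dim=1)
assert out.shape == (m, k)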
vllm/model_executor/layers/fused_moe/deep_gemm_moe.py (new file, 294 lines)
@@ -0,0 +1,294 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
import importlib.util
|
||||
from typing import Optional, Tuple
|
||||
|
||||
import torch
|
||||
|
||||
import vllm.envs as envs
|
||||
from vllm import _custom_ops as ops
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.fused_moe.moe_align_block_size import (
|
||||
moe_align_block_size)
|
||||
from vllm.model_executor.layers.fused_moe.utils import (_fp8_perm,
|
||||
_fp8_quantize,
|
||||
_resize_cache)
|
||||
from vllm.utils import round_up
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
has_deep_gemm = importlib.util.find_spec("deep_gemm") is not None
|
||||
|
||||
|
||||
def _valid_deep_gemm(hidden_states: torch.Tensor,
|
||||
w1: torch.Tensor,
|
||||
w2: torch.Tensor,
|
||||
expert_map: Optional[torch.Tensor] = None) -> bool:
|
||||
"""
|
||||
Check if the given problem size is supported by the DeepGemm grouped
|
||||
gemm kernel. All of M, N, K and the quantization block_shape must be
|
||||
aligned by `dg.get_m_alignment_for_contiguous_layout()`.
|
||||
"""
|
||||
if not has_deep_gemm:
|
||||
return False
|
||||
|
||||
# Lazy import to avoid CUDA initialization problems.
|
||||
import deep_gemm as dg
|
||||
|
||||
# Expert maps not supported yet.
|
||||
if expert_map is not None:
|
||||
return False
|
||||
|
||||
align = dg.get_m_alignment_for_contiguous_layout()
|
||||
M = hidden_states.shape[0]
|
||||
_, K, N = w2.shape
|
||||
|
||||
# For now, disable DeepGemm for small N until better permute/unpermute
|
||||
# ops are available.
|
||||
if N <= 512:
|
||||
return False
|
||||
|
||||
if align > M or N % align != 0 or K % align != 0:
|
||||
return False
|
||||
|
||||
return (hidden_states.is_contiguous() and w1.is_contiguous()
|
||||
and w2.is_contiguous())
|
||||
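To make the size gate above concrete, here is a small sketch of the same alignment arithmetic. The alignment value of 128 is only an assumption for illustration; the real value comes from dg.get_m_alignment_for_contiguous_layout().

# Hypothetical stand-in for dg.get_m_alignment_for_contiguous_layout().
align = 128

def valid_for_deep_gemm(M: int, N: int, K: int) -> bool:
    # Mirrors the checks in _valid_deep_gemm: small N is rejected, M must
    # reach the alignment, and both N and K must be multiples of it.
    if N <= 512:
        return False
    return align <= M and N % align == 0 and K % align == 0

print(valid_for_deep_gemm(M=256, N=2048, K=7168))  # True
print(valid_for_deep_gemm(M=64, N=2048, K=7168))   # False: M < align
print(valid_for_deep_gemm(M=256, N=2050, K=7168))  # False: N not aligned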
|
||||
|
||||
def _moe_permute(
|
||||
curr_hidden_states: torch.Tensor,
|
||||
a1q_scale: Optional[torch.Tensor],
|
||||
curr_topk_ids: torch.Tensor,
|
||||
global_num_experts: int,
|
||||
expert_map: Optional[torch.Tensor],
|
||||
block_m: int,
|
||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor], torch.Tensor, torch.Tensor,
|
||||
Optional[torch.Tensor]]:
|
||||
"""
|
||||
Determine the sorted_token_ids, expert_ids for the given problem size.
|
||||
Permute the hidden states and scales according to `sorted_token_ids`.
|
||||
"""
|
||||
top_k_num = curr_topk_ids.shape[1]
|
||||
|
||||
tokens_in_chunk, _ = curr_hidden_states.shape
|
||||
|
||||
sorted_token_ids, expert_ids, num_tokens_post_padded = (
|
||||
moe_align_block_size(curr_topk_ids,
|
||||
block_m,
|
||||
global_num_experts,
|
||||
expert_map,
|
||||
pad_sorted_ids=True))
|
||||
|
||||
inv_perm: Optional[torch.Tensor] = None
|
||||
|
||||
num_tokens = top_k_num * tokens_in_chunk
|
||||
sorted_token_ids = sorted_token_ids.clamp(max=num_tokens - 1)
|
||||
expert_ids = torch.repeat_interleave(expert_ids, block_m, dim=0)
|
||||
inv_perm = torch.argsort(sorted_token_ids)[:num_tokens]
|
||||
|
||||
# Permute according to sorted token ids.
|
||||
curr_hidden_states = _fp8_perm(curr_hidden_states,
|
||||
sorted_token_ids // top_k_num)
|
||||
|
||||
if a1q_scale is not None:
|
||||
a1q_scale = a1q_scale[sorted_token_ids // top_k_num]
|
||||
|
||||
return (curr_hidden_states, a1q_scale, sorted_token_ids, expert_ids,
|
||||
inv_perm)
|
||||
|
||||
|
||||
def _moe_unpermute_and_reduce(
|
||||
out: torch.Tensor,
|
||||
curr_hidden: torch.Tensor,
|
||||
inv_perm: Optional[torch.Tensor],
|
||||
topk_weight: torch.Tensor,
|
||||
) -> None:
|
||||
"""
|
||||
Unpermute the final result and apply topk_weights, then perform the final
|
||||
reduction on the hidden states.
|
||||
"""
|
||||
M, topk = topk_weight.shape
|
||||
K = curr_hidden.shape[1]
|
||||
curr_hidden = curr_hidden[inv_perm, ...]
|
||||
curr_hidden = curr_hidden.view(-1, topk, K)
|
||||
curr_hidden.mul_(topk_weight.view(M, -1, 1))
|
||||
ops.moe_sum(curr_hidden, out)
|
||||
|
||||
|
||||
def deep_gemm_moe_fp8(
|
||||
hidden_states: torch.Tensor,
|
||||
w1: torch.Tensor,
|
||||
w2: torch.Tensor,
|
||||
w1_scale: torch.Tensor,
|
||||
w2_scale: torch.Tensor,
|
||||
topk_weights: torch.Tensor,
|
||||
topk_ids: torch.Tensor,
|
||||
inplace: bool = False,
|
||||
activation: str = "silu",
|
||||
global_num_experts: int = -1,
|
||||
expert_map: Optional[torch.Tensor] = None,
|
||||
a1_scale: Optional[torch.Tensor] = None,
|
||||
a2_scale: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
This function computes an a8w8-quantized Mixture of Experts (MoE) layer
|
||||
using two sets of quantized weights, w1_q and w2_q, and top-k gating
|
||||
mechanism. The matrix multiplications are implemented with DeepGemm
|
||||
grouped gemm.
|
||||
|
||||
Parameters:
|
||||
- hidden_states (torch.Tensor): The input tensor to the MoE layer.
|
||||
Shape: [M, K]
|
||||
- w1 (torch.Tensor): The first set of fp8 quantized expert weights.
|
||||
Shape: [num_experts, K, 2N] (the weights are passed transposed)
|
||||
- w2 (torch.Tensor): The second set of fp8 quantized expert weights.
|
||||
Shape: [num_experts, N, K] (the weights are passed transposed)
|
||||
- w1_scale (torch.Tensor): The fp32 scale to dequantize w1_q.
|
||||
Shape: [num_experts] or [num_experts, 2N]
|
||||
- w2_scale (torch.Tensor): The fp32 scale to dequantize w2_q.
|
||||
Shape: [num_experts] or [num_experts, K]
|
||||
- topk_weights (torch.Tensor): The weights of each token->expert mapping.
|
||||
- topk_ids (torch.Tensor): The token->expert mapping for topk_weights.
|
||||
- inplace (bool): If True, perform the operation in-place.
|
||||
Defaults to False.
|
||||
- activation (str): The activation function to apply after the first
|
||||
MoE layer.
|
||||
- global_num_experts (int): The total number of experts in the global
|
||||
expert space.
|
||||
- expert_map (Optional[torch.Tensor]): A tensor mapping expert indices
|
||||
from the global expert space to the local expert space of the expert
|
||||
parallel shard.
|
||||
- a1_scale (Optional[torch.Tensor]): The optional fp32 scale to quantize a.
|
||||
Shape: scalar or [M]
|
||||
- a2_scale (Optional[torch.Tensor]): The optional fp32 scale to
|
||||
quantize the intermediate result between the gemms.
|
||||
Shape: scalar or [M]
|
||||
|
||||
Returns:
|
||||
- torch.Tensor: The bfloat16 output tensor after applying the MoE layer.
|
||||
"""
|
||||
# Lazy import to avoid CUDA initialization problems.
|
||||
import deep_gemm as dg
|
||||
|
||||
assert expert_map is None, "Expert maps not supported yet"
|
||||
|
||||
assert hidden_states.shape[1] == w1.shape[2], "Hidden size mismatch"
|
||||
|
||||
assert topk_weights.shape == topk_ids.shape, "topk shape mismatch"
|
||||
assert hidden_states.is_contiguous(), "Hidden_states must be contiguous"
|
||||
assert w1.stride(-1) == 1, "Stride of last dimension must be 1"
|
||||
assert w2.stride(-1) == 1, "Stride of last dimension must be 1"
|
||||
assert hidden_states.dtype in [
|
||||
torch.float32, torch.float16, torch.bfloat16
|
||||
]
|
||||
assert w1.dtype == torch.float8_e4m3fn
|
||||
assert w2.dtype == torch.float8_e4m3fn
|
||||
assert w1.shape[0] == w2.shape[0], "Expert number mismatch"
|
||||
assert w1.shape[0] == w1_scale.shape[0], "w1 scales expert number mismatch"
|
||||
assert w1.shape[0] == w2_scale.shape[0], "w2 scales expert number mismatch"
|
||||
assert a1_scale is None or a1_scale.dim(
|
||||
) == 0 or a1_scale.shape[0] == 1 or a1_scale.shape[
|
||||
0] == hidden_states.shape[0], "Input scale shape mismatch"
|
||||
assert a2_scale is None or a1_scale is None or a2_scale.shape == a1_scale.shape, "Intermediate scale shape mismatch" # noqa: E501
|
||||
|
||||
num_tokens, _ = hidden_states.shape
|
||||
E, N, _ = w1.shape
|
||||
K = w2.shape[1]
|
||||
if global_num_experts == -1:
|
||||
global_num_experts = E
|
||||
|
||||
# We execute the fused_moe kernel in chunks to circumvent this issue:
|
||||
# https://github.com/vllm-project/vllm/issues/5938
|
||||
CHUNK_SIZE = envs.VLLM_FUSED_MOE_CHUNK_SIZE
|
||||
|
||||
assert _valid_deep_gemm(hidden_states, w1, w2, expert_map)
|
||||
|
||||
if inplace:
|
||||
out_hidden_states = hidden_states
|
||||
else:
|
||||
out_hidden_states = torch.empty_like(hidden_states)
|
||||
|
||||
block_m = dg.get_m_alignment_for_contiguous_layout()
|
||||
block_shape = [block_m, block_m]
|
||||
|
||||
assert w1_scale is not None
|
||||
assert w2_scale is not None
|
||||
|
||||
# We attempt to transpose and align offline in Fp8MoEMethod, in which
|
||||
# case these calls will be nops. Otherwise, they'll be performed every
|
||||
# time the layer is executed.
|
||||
w1_scale = dg.get_col_major_tma_aligned_tensor(w1_scale).contiguous()
|
||||
w2_scale = dg.get_col_major_tma_aligned_tensor(w2_scale).contiguous()
|
||||
|
||||
M_sum = topk_ids.numel() + global_num_experts * (block_m - 1)
|
||||
M_sum = round_up(M_sum, block_m)
|
||||
|
||||
num_chunks = (num_tokens // CHUNK_SIZE) + 1
|
||||
|
||||
# We can reuse the memory between cache1 and cache3 because by the time
|
||||
# we need cache3, we're done with cache1
|
||||
workspace13 = torch.empty(M_sum * max(N, K),
|
||||
device=hidden_states.device,
|
||||
dtype=hidden_states.dtype)
|
||||
|
||||
workspace1 = workspace13[:M_sum * N].view(M_sum, N)
|
||||
workspace2 = torch.empty((M_sum, N // 2),
|
||||
device=hidden_states.device,
|
||||
dtype=hidden_states.dtype)
|
||||
workspace3 = workspace13[:M_sum * K].view(M_sum, K)
|
||||
|
||||
for chunk in range(num_chunks):
|
||||
begin_chunk_idx, end_chunk_idx = (chunk * CHUNK_SIZE,
|
||||
min((chunk + 1) * CHUNK_SIZE,
|
||||
num_tokens))
|
||||
curr_hidden_states = hidden_states[begin_chunk_idx:end_chunk_idx]
|
||||
tokens_in_chunk, _ = curr_hidden_states.shape
|
||||
|
||||
if tokens_in_chunk == 0:
|
||||
break
|
||||
|
||||
curr_topk_ids = topk_ids[begin_chunk_idx:end_chunk_idx]
|
||||
curr_topk_weights = topk_weights[begin_chunk_idx:end_chunk_idx]
|
||||
|
||||
a1q_scale: Optional[torch.Tensor] = None
|
||||
|
||||
qcurr_hidden_states, a1q_scale = _fp8_quantize(curr_hidden_states,
|
||||
a1_scale, block_shape)
|
||||
|
||||
(qcurr_hidden_states, a1q_scale, sorted_token_ids, expert_ids,
|
||||
inv_perm) = _moe_permute(qcurr_hidden_states, a1q_scale,
|
||||
curr_topk_ids, global_num_experts,
|
||||
expert_map, block_m)
|
||||
|
||||
# Adjust the intermediate cache size and config for the last chunk.
|
||||
# Note that in most cases we only have one chunk so the cache size
|
||||
# and config are already set correctly and do not need to be adjusted.
|
||||
if tokens_in_chunk < CHUNK_SIZE and chunk > 0:
|
||||
curr_M = sorted_token_ids.numel()
|
||||
workspace1 = _resize_cache(workspace1, (curr_M, N))
|
||||
workspace2 = _resize_cache(workspace2, (curr_M, N // 2))
|
||||
workspace3 = _resize_cache(workspace3, (curr_M, K))
|
||||
|
||||
dg.m_grouped_gemm_fp8_fp8_bf16_nt_contiguous(
|
||||
(qcurr_hidden_states, a1q_scale), (w1, w1_scale), workspace1,
|
||||
expert_ids)
|
||||
|
||||
if activation == "silu":
|
||||
torch.ops._C.silu_and_mul(workspace2, workspace1.view(-1, N))
|
||||
elif activation == "gelu":
|
||||
torch.ops._C.gelu_and_mul(workspace2, workspace1.view(-1, N))
|
||||
else:
|
||||
raise ValueError(f"Unsupported FusedMoe activation: {activation}")
|
||||
|
||||
a2q_scale: Optional[torch.Tensor] = None
|
||||
|
||||
qworkspace2, a2q_scale = _fp8_quantize(workspace2, a2_scale,
|
||||
block_shape)
|
||||
|
||||
dg.m_grouped_gemm_fp8_fp8_bf16_nt_contiguous(
|
||||
(qworkspace2, a2q_scale), (w2, w2_scale), workspace3, expert_ids)
|
||||
|
||||
_moe_unpermute_and_reduce(
|
||||
out_hidden_states[begin_chunk_idx:end_chunk_idx],
|
||||
workspace3.view(*workspace3.shape), inv_perm, curr_topk_weights)
|
||||
|
||||
return out_hidden_states
|
||||
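The permute/unpermute pair used above (_moe_permute and _moe_unpermute_and_reduce) inverts the token permutation with torch.argsort. A tiny self-contained sketch of that round trip, with toy sizes and no fp8 involved:

import torch

# Toy problem: M tokens, top-k experts per token, hidden size K.
M, topk, K = 3, 2, 4
x = torch.arange(M * topk * K, dtype=torch.float32).view(M * topk, K)

# Pretend this is the expert-sorted order produced by moe_align_block_size.
sorted_token_ids = torch.randperm(M * topk)

# Forward permutation: rows laid out in expert-sorted order.
permuted = x[sorted_token_ids]

# Inverse permutation, as in _moe_permute: argsort of the sort order.
inv_perm = torch.argsort(sorted_token_ids)

# Unpermute, as in _moe_unpermute_and_reduce, and check the round trip.
assert torch.equal(permuted[inv_perm], x)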
@@ -1,10 +1,8 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
"""Fused MoE kernel."""
|
||||
import functools
|
||||
import importlib.util
|
||||
import json
|
||||
import os
|
||||
from math import prod
|
||||
from typing import Any, Callable, Dict, List, Optional, Tuple
|
||||
|
||||
import torch
|
||||
@@ -14,10 +12,13 @@ import triton.language as tl
|
||||
import vllm.envs as envs
|
||||
from vllm import _custom_ops as ops
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
|
||||
per_token_group_quant_fp8)
|
||||
from vllm.model_executor.layers.fused_moe.deep_gemm_moe import (
|
||||
_valid_deep_gemm, deep_gemm_moe_fp8)
|
||||
from vllm.model_executor.layers.fused_moe.moe_align_block_size import (
|
||||
moe_align_block_size)
|
||||
from vllm.model_executor.layers.fused_moe.utils import _fp8_quantize
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.utils import direct_register_custom_op, round_up
|
||||
from vllm.utils import direct_register_custom_op
|
||||
|
||||
from .rocm_aiter_fused_moe import (is_rocm_aiter_moe_enabled,
|
||||
rocm_aiter_fused_experts,
|
||||
@@ -25,8 +26,6 @@ from .rocm_aiter_fused_moe import (is_rocm_aiter_moe_enabled,
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
has_deep_gemm = importlib.util.find_spec("deep_gemm") is not None
|
||||
|
||||
|
||||
@triton.jit
|
||||
def write_zeros_to_output(c_ptr, stride_cm, stride_cn, pid_n, N, offs_token,
|
||||
@@ -443,300 +442,13 @@ def fused_moe_kernel(
|
||||
tl.store(c_ptrs, accumulator, mask=c_mask)
|
||||
|
||||
|
||||
def ceil_div(a, b):
|
||||
return (a + b - 1) // b
|
||||
|
||||
|
||||
@triton.jit
|
||||
def moe_align_block_size_stage1(
|
||||
topk_ids_ptr,
|
||||
tokens_cnts_ptr,
|
||||
num_experts: tl.constexpr,
|
||||
numel: tl.constexpr,
|
||||
tokens_per_thread: tl.constexpr,
|
||||
):
|
||||
pid = tl.program_id(0)
|
||||
|
||||
start_idx = pid * tokens_per_thread
|
||||
|
||||
off_c = (pid + 1) * num_experts
|
||||
|
||||
for i in range(tokens_per_thread):
|
||||
if start_idx + i < numel:
|
||||
idx = tl.load(topk_ids_ptr + start_idx + i)
|
||||
token_cnt = tl.load(tokens_cnts_ptr + off_c + idx)
|
||||
tl.store(tokens_cnts_ptr + off_c + idx, token_cnt + 1)
|
||||
|
||||
|
||||
@triton.jit
|
||||
def moe_align_block_size_stage2(
|
||||
tokens_cnts_ptr,
|
||||
num_experts: tl.constexpr,
|
||||
):
|
||||
pid = tl.program_id(0)
|
||||
|
||||
last_cnt = 0
|
||||
for i in range(1, num_experts + 1):
|
||||
token_cnt = tl.load(tokens_cnts_ptr + i * num_experts + pid)
|
||||
last_cnt = last_cnt + token_cnt
|
||||
tl.store(tokens_cnts_ptr + i * num_experts + pid, last_cnt)
|
||||
|
||||
|
||||
@triton.jit
|
||||
def moe_align_block_size_stage3(
|
||||
total_tokens_post_pad_ptr,
|
||||
tokens_cnts_ptr,
|
||||
cumsum_ptr,
|
||||
num_experts: tl.constexpr,
|
||||
block_size: tl.constexpr,
|
||||
):
|
||||
last_cumsum = 0
|
||||
off_cnt = num_experts * num_experts
|
||||
for i in range(1, num_experts + 1):
|
||||
token_cnt = tl.load(tokens_cnts_ptr + off_cnt + i - 1)
|
||||
last_cumsum = last_cumsum + tl.cdiv(token_cnt, block_size) * block_size
|
||||
tl.store(cumsum_ptr + i, last_cumsum)
|
||||
tl.store(total_tokens_post_pad_ptr, last_cumsum)
|
||||
|
||||
|
||||
@triton.jit
|
||||
def moe_align_block_size_stage4(
|
||||
topk_ids_ptr,
|
||||
sorted_token_ids_ptr,
|
||||
expert_ids_ptr,
|
||||
tokens_cnts_ptr,
|
||||
cumsum_ptr,
|
||||
num_experts: tl.constexpr,
|
||||
block_size: tl.constexpr,
|
||||
numel: tl.constexpr,
|
||||
tokens_per_thread: tl.constexpr,
|
||||
):
|
||||
pid = tl.program_id(0)
|
||||
start_idx = tl.load(cumsum_ptr + pid)
|
||||
end_idx = tl.load(cumsum_ptr + pid + 1)
|
||||
|
||||
for i in range(start_idx, end_idx, block_size):
|
||||
tl.store(expert_ids_ptr + i // block_size, pid)
|
||||
|
||||
start_idx = pid * tokens_per_thread
|
||||
off_t = pid * num_experts
|
||||
|
||||
for i in range(start_idx, tl.minimum(start_idx + tokens_per_thread,
|
||||
numel)):
|
||||
expert_id = tl.load(topk_ids_ptr + i)
|
||||
token_cnt = tl.load(tokens_cnts_ptr + off_t + expert_id)
|
||||
rank_post_pad = token_cnt + tl.load(cumsum_ptr + expert_id)
|
||||
tl.store(sorted_token_ids_ptr + rank_post_pad, i)
|
||||
tl.store(tokens_cnts_ptr + off_t + expert_id, token_cnt + 1)
|
||||
|
||||
|
||||
# Triton implementation based on:
|
||||
# https://github.com/sgl-project/sglang/commit/ba5112ff691d791a9e38c6c71f59324a5fcb49d0
|
||||
def moe_align_block_size_triton(
|
||||
topk_ids: torch.Tensor,
|
||||
num_experts: int,
|
||||
block_size: int,
|
||||
sorted_token_ids: torch.Tensor,
|
||||
expert_ids: torch.Tensor,
|
||||
num_tokens_post_pad: torch.Tensor,
|
||||
) -> None:
|
||||
numel = topk_ids.numel()
|
||||
grid = (num_experts, )
|
||||
tokens_cnts = torch.zeros((num_experts + 1, num_experts),
|
||||
dtype=torch.int32,
|
||||
device=topk_ids.device)
|
||||
cumsum = torch.zeros((num_experts + 1, ),
|
||||
dtype=torch.int32,
|
||||
device=topk_ids.device)
|
||||
tokens_per_thread = ceil_div(numel, num_experts)
|
||||
|
||||
moe_align_block_size_stage1[grid](
|
||||
topk_ids,
|
||||
tokens_cnts,
|
||||
num_experts,
|
||||
numel,
|
||||
tokens_per_thread,
|
||||
)
|
||||
moe_align_block_size_stage2[grid](
|
||||
tokens_cnts,
|
||||
num_experts,
|
||||
)
|
||||
moe_align_block_size_stage3[(1, )](
|
||||
num_tokens_post_pad,
|
||||
tokens_cnts,
|
||||
cumsum,
|
||||
num_experts,
|
||||
block_size,
|
||||
)
|
||||
moe_align_block_size_stage4[grid](
|
||||
topk_ids,
|
||||
sorted_token_ids,
|
||||
expert_ids,
|
||||
tokens_cnts,
|
||||
cumsum,
|
||||
num_experts,
|
||||
block_size,
|
||||
numel,
|
||||
tokens_per_thread,
|
||||
)
|
||||
|
||||
|
||||
def moe_align_block_size(
|
||||
topk_ids: torch.Tensor,
|
||||
block_size: int,
|
||||
num_experts: int,
|
||||
expert_map: Optional[torch.Tensor] = None,
|
||||
pad_sorted_ids: bool = False
|
||||
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
Aligns the token distribution across experts to be compatible with block
|
||||
size for matrix multiplication.
|
||||
|
||||
Parameters:
|
||||
- topk_ids: A tensor of shape [total_tokens, top_k] representing the
|
||||
top-k expert indices for each token.
|
||||
- block_size: The block size used in block matrix multiplication.
|
||||
- num_experts: The total number of experts.
|
||||
- expert_map: A tensor of shape [num_experts] that maps the expert index
|
||||
from the global space to the local index space of the current
|
||||
expert parallel shard. If the expert is not in the current expert
|
||||
parallel shard, the mapping is set to -1.
|
||||
- pad_sorted_ids: A flag indicating whether the sorted_token_ids length
|
||||
should be padded to a multiple of block_size.
|
||||
|
||||
Returns:
|
||||
- sorted_token_ids: A tensor containing the sorted token indices according
|
||||
to their allocated expert.
|
||||
- expert_ids: A tensor indicating the assigned expert index for each block.
|
||||
- num_tokens_post_padded: The total number of tokens after padding,
|
||||
ensuring divisibility by block_size.
|
||||
|
||||
This function pads the number of tokens that each expert needs to process
|
||||
so that it is divisible by block_size.
|
||||
Padding ensures that during block matrix multiplication, the dimensions
|
||||
align correctly.
|
||||
|
||||
Example:
|
||||
Given topk_ids = [[2, 3, 4], [1, 2, 4], [1, 3, 4], [1, 2, 3]],
|
||||
block_size = 4, and num_experts = 4:
|
||||
- We initially have 12 tokens (after repeating 'top_k' times) and 4 experts,
|
||||
with each expert needing to process 3 tokens.
|
||||
- As block_size is 4, we pad 1 token for each expert.
|
||||
- First, flatten topk_ids to [2, 3, 4, 1, 2, 4, 1, 3, 4, 1, 2, 3].
|
||||
- Then append padding tokens [12, 12, 12, 12] for each block.
|
||||
- After sorting by expert index, we obtain token_ids
|
||||
[3, 6, 9, 12, 0, 4, 10, 12, 1, 7, 11, 12, 2, 5, 8, 12].
|
||||
Tokens 12 are non-existent (padding) and are ignored in
|
||||
the subsequent matrix multiplication.
|
||||
- The padding ensures that the total number of tokens is now divisible
|
||||
by block_size for proper block matrix operations.
|
||||
"""
|
||||
max_num_tokens_padded = topk_ids.numel() + num_experts * (block_size - 1)
|
||||
if pad_sorted_ids:
|
||||
max_num_tokens_padded = round_up(max_num_tokens_padded, block_size)
|
||||
sorted_ids = torch.empty((max_num_tokens_padded, ),
|
||||
dtype=torch.int32,
|
||||
device=topk_ids.device)
|
||||
sorted_ids.fill_(topk_ids.numel())
|
||||
max_num_m_blocks = triton.cdiv(max_num_tokens_padded, block_size)
|
||||
# Expert ids must be zeroed out to prevent index out of bounds error while
|
||||
# mapping global expert ids to local expert ids in expert parallelism.
|
||||
expert_ids = torch.zeros((max_num_m_blocks, ),
|
||||
dtype=torch.int32,
|
||||
device=topk_ids.device)
|
||||
num_tokens_post_pad = torch.empty((1),
|
||||
dtype=torch.int32,
|
||||
device=topk_ids.device)
|
||||
if num_experts >= 224:
|
||||
if envs.VLLM_ENABLE_MOE_ALIGN_BLOCK_SIZE_TRITON or num_experts != 256:
|
||||
moe_align_block_size_triton(
|
||||
topk_ids,
|
||||
num_experts,
|
||||
block_size,
|
||||
sorted_ids,
|
||||
expert_ids,
|
||||
num_tokens_post_pad,
|
||||
)
|
||||
else:
|
||||
# Currently requires num_experts=256
|
||||
ops.sgl_moe_align_block_size(
|
||||
topk_ids,
|
||||
num_experts,
|
||||
block_size,
|
||||
sorted_ids,
|
||||
expert_ids,
|
||||
num_tokens_post_pad,
|
||||
)
|
||||
else:
|
||||
ops.moe_align_block_size(topk_ids, num_experts, block_size, sorted_ids,
|
||||
expert_ids, num_tokens_post_pad)
|
||||
if expert_map is not None:
|
||||
expert_ids = expert_map[expert_ids]
|
||||
|
||||
return sorted_ids, expert_ids, num_tokens_post_pad
|
||||
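The worked example in the moe_align_block_size docstring can be rechecked with a few lines of plain PyTorch. This only recomputes the per-expert padding and the padded token count, not the sorted layout the kernel produces; note that the toy expert ids happen to start at 1, so expert 0 receives no tokens here.

import torch

# Same toy inputs as the docstring example above.
topk_ids = torch.tensor([[2, 3, 4], [1, 2, 4], [1, 3, 4], [1, 2, 3]])
block_size = 4

# Tokens routed to each expert id.
counts = torch.bincount(topk_ids.flatten())

# Pad each expert's token count up to a multiple of block_size.
padded = (counts + block_size - 1) // block_size * block_size

print(counts.tolist())    # [0, 3, 3, 3, 3]
print(int(padded.sum()))  # 16, matching num_tokens_post_padded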
|
||||
|
||||
def _valid_deep_gemm(hidden_states: torch.Tensor, w1: torch.Tensor,
|
||||
w2: torch.Tensor,
|
||||
expert_map: Optional[torch.Tensor]) -> bool:
|
||||
"""
|
||||
Check if the given problem size is supported by the DeepGemm grouped
|
||||
gemm kernel. All of M, N, K and the quantization block_shape must be
|
||||
aligned by `dg.get_m_alignment_for_contiguous_layout()`.
|
||||
"""
|
||||
if not has_deep_gemm:
|
||||
return False
|
||||
|
||||
# Lazy import to avoid CUDA initialization problems.
|
||||
import deep_gemm as dg
|
||||
|
||||
# Expert maps not supported yet.
|
||||
if expert_map is not None:
|
||||
return False
|
||||
|
||||
align = dg.get_m_alignment_for_contiguous_layout()
|
||||
M = hidden_states.shape[0]
|
||||
_, K, N = w2.shape
|
||||
|
||||
# For now, disable DeepGemm for small N until better permute/unpermute
|
||||
# ops are available.
|
||||
if N <= 512:
|
||||
return False
|
||||
|
||||
if align > M or N % align != 0 or K % align != 0:
|
||||
return False
|
||||
|
||||
return (hidden_states.is_contiguous() and w1.is_contiguous()
|
||||
and w2.is_contiguous())
|
||||
|
||||
|
||||
def _fp8_quantize(
|
||||
A: torch.Tensor,
|
||||
A_scale: Optional[torch.Tensor],
|
||||
block_shape: Optional[List[int]],
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
Perform fp8 quantization on the inputs. If a block_shape
|
||||
is provided, the output will be blocked.
|
||||
"""
|
||||
if block_shape is None:
|
||||
A, A_scale = ops.scaled_fp8_quant(A, A_scale)
|
||||
else:
|
||||
assert len(block_shape) == 2
|
||||
_, block_k = block_shape[0], block_shape[1]
|
||||
A, A_scale = per_token_group_quant_fp8(A, block_k)
|
||||
assert triton.cdiv(A.shape[-1], block_k) == A_scale.shape[-1]
|
||||
return A, A_scale
|
||||
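For intuition, the block-wise branch of _fp8_quantize amounts to one scale per group of block_k columns. The sketch below only simulates that scaling in float32, using 448.0 (the e4m3 maximum) as the representable range, instead of calling the real per_token_group_quant_fp8 kernel.

import torch

def simulate_group_quant(A: torch.Tensor, block_k: int):
    # One scale per (row, column group): group amax divided by the fp8 max.
    M, K = A.shape
    assert K % block_k == 0
    groups = A.view(M, K // block_k, block_k)
    fp8_max = 448.0  # torch.finfo(torch.float8_e4m3fn).max
    scales = groups.abs().amax(dim=-1, keepdim=True).clamp(min=1e-12) / fp8_max
    # "Quantize" by scaling into the fp8 range (kept in float32 here).
    q = (groups / scales).view(M, K)
    return q, scales.squeeze(-1)

A = torch.randn(4, 256)
q, s = simulate_group_quant(A, block_k=128)
print(q.shape, s.shape)  # torch.Size([4, 256]) torch.Size([4, 2])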
|
||||
|
||||
def invoke_fused_moe_kernel(A: torch.Tensor,
|
||||
B: torch.Tensor,
|
||||
C: torch.Tensor,
|
||||
A_scale: Optional[torch.Tensor],
|
||||
B_scale: Optional[torch.Tensor],
|
||||
B_zp: Optional[torch.Tensor],
|
||||
topk_weights: torch.Tensor,
|
||||
topk_ids: torch.Tensor,
|
||||
topk_weights: Optional[torch.Tensor],
|
||||
sorted_token_ids: torch.Tensor,
|
||||
expert_ids: torch.Tensor,
|
||||
num_tokens_post_padded: torch.Tensor,
|
||||
@@ -748,7 +460,8 @@ def invoke_fused_moe_kernel(A: torch.Tensor,
|
||||
use_int8_w8a16: bool,
|
||||
use_int4_w4a16: bool,
|
||||
block_shape: Optional[List[int]] = None) -> None:
|
||||
assert topk_weights.stride(1) == 1
|
||||
assert topk_weights is not None or not mul_routed_weight
|
||||
assert topk_weights is None or topk_weights.stride(1) == 1
|
||||
assert sorted_token_ids.stride(0) == 1
|
||||
|
||||
if use_fp8_w8a8:
|
||||
@@ -765,6 +478,9 @@ def invoke_fused_moe_kernel(A: torch.Tensor,
|
||||
assert A_scale is None
|
||||
assert B_scale is None
|
||||
|
||||
M = A.shape[0]
|
||||
num_tokens = M * top_k
|
||||
|
||||
EM = sorted_token_ids.shape[0]
|
||||
if A.shape[0] < config["BLOCK_SIZE_M"]:
|
||||
# optimize for small batch_size.
|
||||
@@ -782,7 +498,7 @@ def invoke_fused_moe_kernel(A: torch.Tensor,
|
||||
assert B_zp is None or B_zp.ndim == 3
|
||||
|
||||
use_moe_wna16_cuda = should_moe_wna16_use_cuda(
|
||||
num_valid_tokens=topk_ids.numel(),
|
||||
num_valid_tokens=num_tokens,
|
||||
group_size=block_shape[1],
|
||||
num_experts=B.shape[0],
|
||||
bit=4 if use_int4_w4a16 else 8)
|
||||
@@ -790,12 +506,12 @@ def invoke_fused_moe_kernel(A: torch.Tensor,
|
||||
config.update(
|
||||
get_moe_wna16_block_config(config=config,
|
||||
use_moe_wna16_cuda=use_moe_wna16_cuda,
|
||||
num_valid_tokens=topk_ids.numel(),
|
||||
num_valid_tokens=num_tokens,
|
||||
size_k=A.shape[1],
|
||||
size_n=B.shape[1],
|
||||
num_experts=B.shape[1],
|
||||
group_size=block_shape[1],
|
||||
real_top_k=topk_ids.shape[1],
|
||||
real_top_k=top_k,
|
||||
block_size_m=config["BLOCK_SIZE_M"]))
|
||||
|
||||
if use_moe_wna16_cuda:
|
||||
@@ -821,7 +537,7 @@ def invoke_fused_moe_kernel(A: torch.Tensor,
|
||||
B.shape[1],
|
||||
A.shape[1],
|
||||
EM,
|
||||
topk_ids.numel(),
|
||||
num_tokens,
|
||||
A.stride(0),
|
||||
A.stride(1),
|
||||
B.stride(0),
|
||||
@@ -864,7 +580,7 @@ def invoke_fused_moe_kernel(A: torch.Tensor,
|
||||
B.shape[1],
|
||||
B.shape[2],
|
||||
EM,
|
||||
topk_ids.numel(),
|
||||
num_tokens,
|
||||
A.stride(0),
|
||||
A.stride(1),
|
||||
B.stride(0),
|
||||
@@ -1389,6 +1105,7 @@ def fused_experts(hidden_states: torch.Tensor,
|
||||
w2=w2,
|
||||
topk_weights=topk_weights,
|
||||
topk_ids=topk_ids,
|
||||
inplace=inplace,
|
||||
activation=activation,
|
||||
global_num_experts=global_num_experts,
|
||||
expert_map=expert_map,
|
||||
@@ -1419,85 +1136,6 @@ def fused_experts(hidden_states: torch.Tensor,
|
||||
block_shape=block_shape)
|
||||
|
||||
|
||||
def _fp8_perm(m: torch.Tensor, idx: torch.Tensor) -> torch.Tensor:
|
||||
"""
|
||||
A permutation routine that works on fp8 types.
|
||||
"""
|
||||
if torch.is_floating_point(m) and torch.finfo(m.dtype).bits == 8:
|
||||
return m.view(dtype=torch.uint8)[idx, ...].view(dtype=m.dtype)
|
||||
else:
|
||||
return m[idx, ...]
|
||||
|
||||
|
||||
def _moe_permute(
|
||||
curr_hidden_states: torch.Tensor,
|
||||
a1q_scale: Optional[torch.Tensor],
|
||||
curr_topk_ids: torch.Tensor,
|
||||
global_num_experts: int,
|
||||
expert_map: Optional[torch.Tensor],
|
||||
top_k_num: int,
|
||||
block_m: int,
|
||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor], torch.Tensor, torch.Tensor,
|
||||
torch.Tensor]:
|
||||
"""
|
||||
Determine the sorted_token_ids, expert_ids for the given problem size.
|
||||
Permute the hidden states and scales according to `sorted_token_ids`.
|
||||
"""
|
||||
tokens_in_chunk, _ = curr_hidden_states.shape
|
||||
|
||||
sorted_token_ids, expert_ids, num_tokens_post_padded = (
|
||||
moe_align_block_size(curr_topk_ids,
|
||||
block_m,
|
||||
global_num_experts,
|
||||
expert_map,
|
||||
pad_sorted_ids=True))
|
||||
|
||||
inv_perm: Optional[torch.Tensor] = None
|
||||
|
||||
num_tokens = top_k_num * tokens_in_chunk
|
||||
sorted_token_ids = sorted_token_ids.clamp(max=num_tokens - 1)
|
||||
expert_ids = torch.repeat_interleave(expert_ids, block_m, dim=0)
|
||||
inv_perm = torch.argsort(sorted_token_ids)[:num_tokens]
|
||||
|
||||
# Permute according to sorted token ids.
|
||||
curr_hidden_states = _fp8_perm(curr_hidden_states,
|
||||
sorted_token_ids // top_k_num)
|
||||
|
||||
if a1q_scale is not None:
|
||||
a1q_scale = a1q_scale[sorted_token_ids // top_k_num]
|
||||
|
||||
return (curr_hidden_states, a1q_scale, sorted_token_ids, expert_ids,
|
||||
inv_perm)
|
||||
|
||||
|
||||
def _moe_unpermute_and_reduce(
|
||||
out: torch.Tensor,
|
||||
curr_hidden: torch.Tensor,
|
||||
inv_perm: Optional[torch.Tensor],
|
||||
topk: int,
|
||||
K: int,
|
||||
topk_weight: torch.Tensor,
|
||||
) -> None:
|
||||
"""
|
||||
Unpermute the final result and apply topk_weights, then perform the final
|
||||
reduction on the hidden states.
|
||||
"""
|
||||
M = topk_weight.shape[0]
|
||||
curr_hidden = curr_hidden[inv_perm, ...]
|
||||
curr_hidden = curr_hidden.view(-1, topk, K)
|
||||
curr_hidden.mul_(topk_weight.view(M, -1, 1))
|
||||
ops.moe_sum(curr_hidden, out)
|
||||
|
||||
|
||||
def _resize_cache(x: torch.Tensor, v: Tuple[int, ...]) -> torch.Tensor:
|
||||
"""
|
||||
Shrink the given tensor and apply the given view to it. This is
|
||||
used to resize the intermediate fused_moe caches.
|
||||
"""
|
||||
assert prod(v) <= x.numel()
|
||||
return x.flatten()[:prod(v)].view(*v)
|
||||
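_resize_cache is just a view over the front of a preallocated flat buffer, which is what lets a smaller final chunk reuse the same intermediate caches without reallocating. A quick sketch of that idea:

import torch
from math import prod

def resize_cache(x: torch.Tensor, v):
    # Take the first prod(v) elements of the flat buffer and view them as v.
    assert prod(v) <= x.numel()
    return x.flatten()[:prod(v)].view(*v)

buf = torch.empty(1024)             # preallocated workspace
full = resize_cache(buf, (64, 16))  # full-size chunk: 64 x 16
tail = resize_cache(buf, (8, 16))   # smaller final chunk reuses the buffer
print(full.shape, tail.shape)       # torch.Size([64, 16]) torch.Size([8, 16])
assert tail.data_ptr() == buf.data_ptr()  # both are views into buf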
|
||||
|
||||
def fused_experts_impl(hidden_states: torch.Tensor,
|
||||
w1: torch.Tensor,
|
||||
w2: torch.Tensor,
|
||||
@@ -1629,7 +1267,6 @@ def fused_experts_impl(hidden_states: torch.Tensor,
|
||||
w1_scale,
|
||||
w1_zp,
|
||||
curr_topk_weights,
|
||||
curr_topk_ids,
|
||||
sorted_token_ids,
|
||||
expert_ids,
|
||||
num_tokens_post_padded,
|
||||
@@ -1660,28 +1297,34 @@ def fused_experts_impl(hidden_states: torch.Tensor,
|
||||
qintermediate_cache2 = intermediate_cache2
|
||||
a2q_scale = a2_scale
|
||||
|
||||
invoke_fused_moe_kernel(qintermediate_cache2,
|
||||
w2,
|
||||
intermediate_cache3,
|
||||
a2q_scale,
|
||||
w2_scale,
|
||||
w2_zp,
|
||||
curr_topk_weights,
|
||||
curr_topk_ids,
|
||||
sorted_token_ids,
|
||||
expert_ids,
|
||||
num_tokens_post_padded,
|
||||
True,
|
||||
1,
|
||||
config,
|
||||
compute_type=compute_type,
|
||||
use_fp8_w8a8=use_fp8_w8a8,
|
||||
use_int8_w8a16=use_int8_w8a16,
|
||||
use_int4_w4a16=use_int4_w4a16,
|
||||
block_shape=block_shape)
|
||||
invoke_fused_moe_kernel(
|
||||
qintermediate_cache2,
|
||||
w2,
|
||||
intermediate_cache3,
|
||||
a2q_scale,
|
||||
w2_scale,
|
||||
w2_zp,
|
||||
curr_topk_weights,
|
||||
sorted_token_ids,
|
||||
expert_ids,
|
||||
num_tokens_post_padded,
|
||||
False, #True,
|
||||
1,
|
||||
config,
|
||||
compute_type=compute_type,
|
||||
use_fp8_w8a8=use_fp8_w8a8,
|
||||
use_int8_w8a16=use_int8_w8a16,
|
||||
use_int4_w4a16=use_int4_w4a16,
|
||||
block_shape=block_shape)
|
||||
|
||||
if True:
|
||||
intermediate_cache3 = intermediate_cache3.view(-1, top_k_num, K)
|
||||
intermediate_cache3.mul_(
|
||||
curr_topk_weights.view(tokens_in_chunk, -1, 1))
|
||||
|
||||
ops.moe_sum(intermediate_cache3.view(*intermediate_cache3.shape),
|
||||
out_hidden_states[begin_chunk_idx:end_chunk_idx])
|
||||
|
||||
return out_hidden_states
|
||||
|
||||
|
||||
@@ -1790,327 +1433,3 @@ def fused_moe(
|
||||
a1_scale=a1_scale,
|
||||
a2_scale=a2_scale,
|
||||
block_shape=block_shape)
|
||||
|
||||
|
||||
def deep_gemm_moe_fp8(
|
||||
hidden_states: torch.Tensor,
|
||||
w1: torch.Tensor,
|
||||
w2: torch.Tensor,
|
||||
w1_scale: torch.Tensor,
|
||||
w2_scale: torch.Tensor,
|
||||
topk_weights: torch.Tensor,
|
||||
topk_ids: torch.Tensor,
|
||||
inplace: bool = False,
|
||||
activation: str = "silu",
|
||||
global_num_experts: int = -1,
|
||||
expert_map: Optional[torch.Tensor] = None,
|
||||
a1_scale: Optional[torch.Tensor] = None,
|
||||
a2_scale: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
This function computes an a8w8-quantized Mixture of Experts (MoE) layer
|
||||
using two sets of quantized weights, w1_q and w2_q, and top-k gating
|
||||
mechanism. The matrix multiplications are implemented with DeepGemm
|
||||
grouped gemm.
|
||||
|
||||
Parameters:
|
||||
- hidden_states (torch.Tensor): The input tensor to the MoE layer.
|
||||
Shape: [M, K]
|
||||
- w1 (torch.Tensor): The first set of fp8 quantized expert weights.
|
||||
Shape: [num_experts, K, 2N] (the weights are passed transposed)
|
||||
- w2 (torch.Tensor): The second set of fp8 quantized expert weights.
|
||||
Shape: [num_experts, N, K] (the weights are passed transposed)
|
||||
- w1_scale (torch.Tensor): The fp32 scale to dequantize w1_q.
|
||||
Shape: [num_experts] or [num_experts, 2N]
|
||||
- w2_scale (torch.Tensor): The fp32 scale to dequantize w2_q.
|
||||
Shape: [num_experts] or [num_experts, K]
|
||||
- topk_weights (torch.Tensor): The weights of each token->expert mapping.
|
||||
- topk_ids (torch.Tensor): The token->expert mapping for topk_weights.
|
||||
- inplace (bool): If True, perform the operation in-place.
|
||||
Defaults to False.
|
||||
- activation (str): The activation function to apply after the first
|
||||
MoE layer.
|
||||
- global_num_experts (int): The total number of experts in the global
|
||||
expert space.
|
||||
- expert_map (Optional[torch.Tensor]): A tensor mapping expert indices
|
||||
from the global expert space to the local expert space of the expert
|
||||
parallel shard.
|
||||
- a1_scale (Optional[torch.Tensor]): The optional fp32 scale to quantize a.
|
||||
Shape: scalar or [M]
|
||||
- a2_scale (Optional[torch.Tensor]): The optional fp32 scale to
|
||||
quantize the intermediate result between the gemms.
|
||||
Shape: scalar or [M]
|
||||
|
||||
Returns:
|
||||
- torch.Tensor: The bfloat16 output tensor after applying the MoE layer.
|
||||
"""
|
||||
# Lazy import to avoid CUDA initialization problems.
|
||||
import deep_gemm as dg
|
||||
|
||||
assert expert_map is None, "Expert maps not supported yet"
|
||||
|
||||
assert hidden_states.shape[1] == w1.shape[2], "Hidden size mismatch"
|
||||
|
||||
assert topk_weights.shape == topk_ids.shape, "topk shape mismatch"
|
||||
assert hidden_states.is_contiguous(), "Hidden_states must be contiguous"
|
||||
assert w1.stride(-1) == 1, "Stride of last dimension must be 1"
|
||||
assert w2.stride(-1) == 1, "Stride of last dimension must be 1"
|
||||
assert hidden_states.dtype in [
|
||||
torch.float32, torch.float16, torch.bfloat16
|
||||
]
|
||||
assert w1.dtype == torch.float8_e4m3fn
|
||||
assert w2.dtype == torch.float8_e4m3fn
|
||||
assert w1.shape[0] == w2.shape[0], "Expert number mismatch"
|
||||
assert w1.shape[0] == w1_scale.shape[0], "w1 scales expert number mismatch"
|
||||
assert w1.shape[0] == w2_scale.shape[0], "w2 scales expert number mismatch"
|
||||
assert a1_scale is None or a1_scale.dim(
|
||||
) == 0 or a1_scale.shape[0] == 1 or a1_scale.shape[
|
||||
0] == hidden_states.shape[0], "Input scale shape mismatch"
|
||||
assert a2_scale is None or a1_scale is None or a2_scale.shape == a1_scale.shape, "Intermediate scale shape mismatch" # noqa: E501
|
||||
|
||||
num_tokens, _ = hidden_states.shape
|
||||
E, N, _ = w1.shape
|
||||
K = w2.shape[1]
|
||||
if global_num_experts == -1:
|
||||
global_num_experts = E
|
||||
top_k_num = topk_ids.shape[1]
|
||||
# We execute the fused_moe kernel in chunks to circumvent this issue:
|
||||
# https://github.com/vllm-project/vllm/issues/5938
|
||||
CHUNK_SIZE = envs.VLLM_FUSED_MOE_CHUNK_SIZE
|
||||
|
||||
assert _valid_deep_gemm(hidden_states, w1, w2, expert_map)
|
||||
|
||||
if inplace:
|
||||
out_hidden_states = hidden_states
|
||||
else:
|
||||
out_hidden_states = torch.empty_like(hidden_states)
|
||||
|
||||
block_m = dg.get_m_alignment_for_contiguous_layout()
|
||||
block_shape = [block_m, block_m]
|
||||
|
||||
assert w1_scale is not None
|
||||
assert w2_scale is not None
|
||||
|
||||
# We attempt to transpose and align offline in Fp8MoEMethod, in which
|
||||
# case these calls will be nops. Otherwise, they'll be performed every
|
||||
# time the layer is executed.
|
||||
w1_scale = dg.get_col_major_tma_aligned_tensor(w1_scale).contiguous()
|
||||
w2_scale = dg.get_col_major_tma_aligned_tensor(w2_scale).contiguous()
|
||||
|
||||
M_sum = topk_ids.numel() + global_num_experts * (block_m - 1)
|
||||
M_sum = round_up(M_sum, block_m)
|
||||
|
||||
num_chunks = (num_tokens // CHUNK_SIZE) + 1
|
||||
|
||||
# We can reuse the memory between cache1 and cache3 because by the time
|
||||
# we need cache3, we're done with cache1
|
||||
cache13 = torch.empty(M_sum * max(N, K),
|
||||
device=hidden_states.device,
|
||||
dtype=hidden_states.dtype)
|
||||
|
||||
intermediate_cache1 = cache13[:M_sum * N].view(M_sum, N)
|
||||
intermediate_cache2 = torch.empty((M_sum, N // 2),
|
||||
device=hidden_states.device,
|
||||
dtype=hidden_states.dtype)
|
||||
intermediate_cache3 = cache13[:M_sum * K].view(M_sum, K)
|
||||
|
||||
for chunk in range(num_chunks):
|
||||
begin_chunk_idx, end_chunk_idx = (chunk * CHUNK_SIZE,
|
||||
min((chunk + 1) * CHUNK_SIZE,
|
||||
num_tokens))
|
||||
curr_hidden_states = hidden_states[begin_chunk_idx:end_chunk_idx]
|
||||
tokens_in_chunk, _ = curr_hidden_states.shape
|
||||
|
||||
if tokens_in_chunk == 0:
|
||||
break
|
||||
|
||||
curr_topk_ids = topk_ids[begin_chunk_idx:end_chunk_idx]
|
||||
curr_topk_weights = topk_weights[begin_chunk_idx:end_chunk_idx]
|
||||
|
||||
a1q_scale: Optional[torch.Tensor] = None
|
||||
|
||||
qcurr_hidden_states, a1q_scale = _fp8_quantize(curr_hidden_states,
|
||||
a1_scale, block_shape)
|
||||
|
||||
(qcurr_hidden_states, a1q_scale, sorted_token_ids, expert_ids,
|
||||
inv_perm) = _moe_permute(qcurr_hidden_states, a1q_scale,
|
||||
curr_topk_ids, global_num_experts,
|
||||
expert_map, top_k_num, block_m)
|
||||
|
||||
# Adjust the intermediate cache size and config for the last chunk.
|
||||
# Note that in most cases we only have one chunk so the cache size
|
||||
# and config are already set correctly and do not need to be adjusted.
|
||||
if tokens_in_chunk < CHUNK_SIZE and chunk > 0:
|
||||
curr_M = sorted_token_ids.numel()
|
||||
intermediate_cache1 = _resize_cache(intermediate_cache1,
|
||||
(curr_M, N))
|
||||
intermediate_cache2 = _resize_cache(intermediate_cache2,
|
||||
(curr_M, N // 2))
|
||||
intermediate_cache3 = _resize_cache(intermediate_cache3,
|
||||
(curr_M, K))
|
||||
|
||||
dg.m_grouped_gemm_fp8_fp8_bf16_nt_contiguous(
|
||||
(qcurr_hidden_states, a1q_scale), (w1, w1_scale),
|
||||
intermediate_cache1, expert_ids)
|
||||
|
||||
if activation == "silu":
|
||||
torch.ops._C.silu_and_mul(intermediate_cache2,
|
||||
intermediate_cache1.view(-1, N))
|
||||
elif activation == "gelu":
|
||||
torch.ops._C.gelu_and_mul(intermediate_cache2,
|
||||
intermediate_cache1.view(-1, N))
|
||||
else:
|
||||
raise ValueError(f"Unsupported FusedMoe activation: {activation}")
|
||||
|
||||
a2q_scale: Optional[torch.Tensor] = None
|
||||
|
||||
qintermediate_cache2, a2q_scale = _fp8_quantize(
|
||||
intermediate_cache2, a2_scale, block_shape)
|
||||
|
||||
dg.m_grouped_gemm_fp8_fp8_bf16_nt_contiguous(
|
||||
(qintermediate_cache2, a2q_scale), (w2, w2_scale),
|
||||
intermediate_cache3, expert_ids)
|
||||
|
||||
_moe_unpermute_and_reduce(
|
||||
out_hidden_states[begin_chunk_idx:end_chunk_idx],
|
||||
intermediate_cache3.view(*intermediate_cache3.shape), inv_perm,
|
||||
top_k_num, K, curr_topk_weights)
|
||||
|
||||
return out_hidden_states
|
||||
|
||||
|
||||
#TODO make the grouped gemm kernel consistent with scaled gemm kernel
|
||||
def cutlass_moe_fp8(
|
||||
a: torch.Tensor,
|
||||
w1_q: torch.Tensor,
|
||||
w2_q: torch.Tensor,
|
||||
w1_scale: torch.Tensor,
|
||||
w2_scale: torch.Tensor,
|
||||
topk_weights: torch.Tensor,
|
||||
topk_ids: torch.Tensor,
|
||||
ab_strides1: torch.Tensor,
|
||||
c_strides1: torch.Tensor,
|
||||
ab_strides2: torch.Tensor,
|
||||
c_strides2: torch.Tensor,
|
||||
a1_scale: Optional[torch.Tensor] = None,
|
||||
a2_scale: Optional[torch.Tensor] = None,
|
||||
out_dtype: torch.dtype = torch.half,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
This function computes an a8w8-quantized Mixture of Experts (MoE) layer
|
||||
using two sets of quantized weights, w1_q and w2_q, and top-k gating
|
||||
mechanism. The matrix multiplications are implemented with CUTLASS
|
||||
grouped gemm.
|
||||
|
||||
Parameters:
|
||||
- a (torch.Tensor): The input tensor to the MoE layer.
|
||||
Shape: [M, K]
|
||||
- w1_q (torch.Tensor): The first set of fp8-quantized expert weights.
|
||||
Shape: [num_experts, K, 2N] (the weights are passed transposed)
|
||||
- w2_q (torch.Tensor): The second set of fp8-quantized expert weights.
|
||||
Shape: [num_experts, N, K] (the weights are passed transposed)
|
||||
- w1_scale (torch.Tensor): The fp32 scale to dequantize w1_q.
|
||||
Shape: [num_experts] or [num_experts, 2N]
|
||||
- w2_scale (torch.Tensor): The fp32 scale to dequantize w2_q.
|
||||
Shape: [num_experts] or [num_experts, K]
|
||||
- topk_ids (torch.Tensor): The token->expert mapping for topk_weights.
|
||||
- topk_weights (torch.Tensor): The weights of each token->expert mapping.
|
||||
- ab_strides1 (torch.Tensor): The input and weights strides of the first
|
||||
grouped gemm.
|
||||
- c_strides1 (torch.Tensor): The output strides of the first grouped gemm.
|
||||
- ab_strides2 (torch.Tensor): The input and weights strides of the second
|
||||
grouped gemm.
|
||||
- c_strides2 (torch.Tensor): The output strides of the second grouped gemm.
|
||||
- a1_scale (Optional[torch.Tensor]): The optional fp32 scale to quantize a.
|
||||
Shape: scalar or [M]
|
||||
- a2_scale (Optional[torch.Tensor]): The optional fp32 scale to
|
||||
quantize the intermediate result between the gemms.
|
||||
Shape: scalar or [M]
|
||||
- out_dtype (torch.dtype): The output tensor type.
|
||||
|
||||
Returns:
|
||||
- torch.Tensor: The fp16 output tensor after applying the MoE layer.
|
||||
"""
|
||||
|
||||
assert topk_weights.shape == topk_ids.shape, "topk shape mismatch"
|
||||
assert w1_q.dtype == torch.float8_e4m3fn
|
||||
assert w2_q.dtype == torch.float8_e4m3fn
|
||||
assert a.shape[1] == w1_q.shape[1], "Hidden size mismatch w1"
|
||||
assert w1_q.shape[2] == w2_q.shape[1] * 2, "Hidden size mismatch w2"
|
||||
assert w1_q.shape[0] == w2_q.shape[0], "Expert number mismatch"
|
||||
assert a1_scale is None or a1_scale.dim(
|
||||
) == 0 or a1_scale.shape[0] == 1 or a1_scale.shape[0] == a.shape[
|
||||
0], "Input scale shape mismatch"
|
||||
assert w1_scale.dim() == 1 or w1_scale.shape[1] == 1 or w1_scale.shape[
|
||||
1] == w1_q.shape[2], "W1 scale shape mismatch"
|
||||
assert w2_scale.dim() == 1 or w2_scale.shape[1] == 1 or w2_scale.shape[
|
||||
1] == w2_q.shape[2], "W2 scale shape mismatch"
|
||||
assert w1_q.shape[0] == w2_q.shape[0], "Weights expert number mismatch"
|
||||
assert w1_q.shape[0] == w1_scale.shape[
|
||||
0], "w1 scales expert number mismatch"
|
||||
assert w1_q.shape[0] == w2_scale.shape[
|
||||
0], "w2 scales expert number mismatch"
|
||||
assert a2_scale is None or a1_scale is None or a2_scale.shape == a1_scale.shape, "Intermediate scale shape mismatch" # noqa: E501
|
||||
assert ab_strides1.shape[0] == w1_q.shape[
|
||||
0], "AB Strides 1 expert number mismatch"
|
||||
assert c_strides1.shape[0] == w1_q.shape[
|
||||
0], "C Strides 1 expert number mismatch"
|
||||
assert ab_strides2.shape[0] == w2_q.shape[
|
||||
0], "AB Strides 2 expert number mismatch"
|
||||
assert c_strides2.shape[0] == w2_q.shape[
|
||||
0], "C Strides 2 expert number mismatch"
|
||||
assert out_dtype in [torch.half, torch.bfloat16], "Invalid output dtype"
|
||||
|
||||
num_experts = w1_q.size(0)
|
||||
m = a.size(0)
|
||||
k = w1_q.size(1)
|
||||
n = w2_q.size(1)
|
||||
|
||||
topk = topk_ids.size(1)
|
||||
per_act_token = a1_scale.numel() != 1 if a1_scale is not None else (
|
||||
a2_scale.numel() != 1 if a2_scale is not None else False)
|
||||
|
||||
a_q, a1_scale = ops.scaled_fp8_quant(
|
||||
a, a1_scale, use_per_token_if_dynamic=per_act_token)
|
||||
device = a_q.device
|
||||
|
||||
expert_offsets = torch.empty((num_experts + 1),
|
||||
dtype=torch.int32,
|
||||
device=device)
|
||||
problem_sizes1 = torch.empty((num_experts, 3),
|
||||
dtype=torch.int32,
|
||||
device=device)
|
||||
problem_sizes2 = torch.empty((num_experts, 3),
|
||||
dtype=torch.int32,
|
||||
device=device)
|
||||
|
||||
a_map = torch.empty((topk_ids.numel()), dtype=torch.int32, device=device)
|
||||
c_map = torch.empty((topk_ids.numel()), dtype=torch.int32, device=device)
|
||||
|
||||
ops.get_cutlass_moe_mm_data(topk_ids, expert_offsets, problem_sizes1,
|
||||
problem_sizes2, a_map, c_map, num_experts, n,
|
||||
k)
|
||||
|
||||
rep_a_q = a_q.view(dtype=torch.uint8)[a_map].view(dtype=a_q.dtype)
|
||||
rep_a1_scales = a1_scale[a_map] if per_act_token else a1_scale
|
||||
|
||||
c1 = torch.empty((m * topk, n * 2), device=device, dtype=out_dtype)
|
||||
c2 = torch.empty((m * topk, k), device=device, dtype=out_dtype)
|
||||
|
||||
ops.cutlass_moe_mm(c1, rep_a_q, w1_q, rep_a1_scales, w1_scale,
|
||||
expert_offsets[:-1], problem_sizes1, ab_strides1,
|
||||
ab_strides1, c_strides1)
|
||||
|
||||
intermediate = torch.empty((m * topk, n), device=device, dtype=out_dtype)
|
||||
torch.ops._C.silu_and_mul(intermediate, c1)
|
||||
|
||||
intemediate_q, a2_scale = ops.scaled_fp8_quant(
|
||||
intermediate, a2_scale, use_per_token_if_dynamic=per_act_token)
|
||||
|
||||
ops.cutlass_moe_mm(c2, intemediate_q, w2_q, a2_scale, w2_scale,
|
||||
expert_offsets[:-1], problem_sizes2, ab_strides2,
|
||||
ab_strides2, c_strides2)
|
||||
|
||||
return (c2[c_map].view(m, topk, k) *
|
||||
topk_weights.view(m, topk, 1).to(out_dtype)).sum(dim=1)
|
||||
|
||||
vllm/model_executor/layers/fused_moe/moe_align_block_size.py (new file, 243 lines)
@@ -0,0 +1,243 @@
# SPDX-License-Identifier: Apache-2.0
from typing import Optional, Tuple

import torch
import triton
import triton.language as tl

import vllm.envs as envs
from vllm import _custom_ops as ops
from vllm.utils import round_up


def ceil_div(a, b):
    return (a + b - 1) // b


@triton.jit
def moe_align_block_size_stage1(
    topk_ids_ptr,
    tokens_cnts_ptr,
    num_experts: tl.constexpr,
    numel: tl.constexpr,
    tokens_per_thread: tl.constexpr,
):
    pid = tl.program_id(0)

    start_idx = pid * tokens_per_thread

    off_c = (pid + 1) * num_experts

    for i in range(tokens_per_thread):
        if start_idx + i < numel:
            idx = tl.load(topk_ids_ptr + start_idx + i)
            token_cnt = tl.load(tokens_cnts_ptr + off_c + idx)
            tl.store(tokens_cnts_ptr + off_c + idx, token_cnt + 1)


@triton.jit
def moe_align_block_size_stage2(
    tokens_cnts_ptr,
    num_experts: tl.constexpr,
):
    pid = tl.program_id(0)

    last_cnt = 0
    for i in range(1, num_experts + 1):
        token_cnt = tl.load(tokens_cnts_ptr + i * num_experts + pid)
        last_cnt = last_cnt + token_cnt
        tl.store(tokens_cnts_ptr + i * num_experts + pid, last_cnt)


@triton.jit
def moe_align_block_size_stage3(
    total_tokens_post_pad_ptr,
    tokens_cnts_ptr,
    cumsum_ptr,
    num_experts: tl.constexpr,
    block_size: tl.constexpr,
):
    last_cumsum = 0
    off_cnt = num_experts * num_experts
    for i in range(1, num_experts + 1):
        token_cnt = tl.load(tokens_cnts_ptr + off_cnt + i - 1)
        last_cumsum = last_cumsum + tl.cdiv(token_cnt, block_size) * block_size
        tl.store(cumsum_ptr + i, last_cumsum)
    tl.store(total_tokens_post_pad_ptr, last_cumsum)


@triton.jit
def moe_align_block_size_stage4(
    topk_ids_ptr,
    sorted_token_ids_ptr,
    expert_ids_ptr,
    tokens_cnts_ptr,
    cumsum_ptr,
    num_experts: tl.constexpr,
    block_size: tl.constexpr,
    numel: tl.constexpr,
    tokens_per_thread: tl.constexpr,
):
    pid = tl.program_id(0)
    start_idx = tl.load(cumsum_ptr + pid)
    end_idx = tl.load(cumsum_ptr + pid + 1)

    for i in range(start_idx, end_idx, block_size):
        tl.store(expert_ids_ptr + i // block_size, pid)

    start_idx = pid * tokens_per_thread
    off_t = pid * num_experts

    for i in range(start_idx, tl.minimum(start_idx + tokens_per_thread,
                                         numel)):
        expert_id = tl.load(topk_ids_ptr + i)
        token_cnt = tl.load(tokens_cnts_ptr + off_t + expert_id)
        rank_post_pad = token_cnt + tl.load(cumsum_ptr + expert_id)
        tl.store(sorted_token_ids_ptr + rank_post_pad, i)
        tl.store(tokens_cnts_ptr + off_t + expert_id, token_cnt + 1)


# Triton implementation based on:
# https://github.com/sgl-project/sglang/commit/ba5112ff691d791a9e38c6c71f59324a5fcb49d0
def moe_align_block_size_triton(
    topk_ids: torch.Tensor,
    num_experts: int,
    block_size: int,
    sorted_token_ids: torch.Tensor,
    expert_ids: torch.Tensor,
    num_tokens_post_pad: torch.Tensor,
) -> None:
    numel = topk_ids.numel()
    grid = (num_experts, )
    tokens_cnts = torch.zeros((num_experts + 1, num_experts),
                              dtype=torch.int32,
                              device=topk_ids.device)
    cumsum = torch.zeros((num_experts + 1, ),
                         dtype=torch.int32,
                         device=topk_ids.device)
    tokens_per_thread = ceil_div(numel, num_experts)

    moe_align_block_size_stage1[grid](
        topk_ids,
        tokens_cnts,
        num_experts,
        numel,
        tokens_per_thread,
    )
    moe_align_block_size_stage2[grid](
        tokens_cnts,
        num_experts,
    )
    moe_align_block_size_stage3[(1, )](
        num_tokens_post_pad,
        tokens_cnts,
        cumsum,
        num_experts,
        block_size,
    )
    moe_align_block_size_stage4[grid](
        topk_ids,
        sorted_token_ids,
        expert_ids,
        tokens_cnts,
        cumsum,
        num_experts,
        block_size,
        numel,
        tokens_per_thread,
    )


def moe_align_block_size(
    topk_ids: torch.Tensor,
    block_size: int,
    num_experts: int,
    expert_map: Optional[torch.Tensor] = None,
    pad_sorted_ids: bool = False
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    Aligns the token distribution across experts to be compatible with block
    size for matrix multiplication.

    Parameters:
    - topk_ids: A tensor of shape [total_tokens, top_k] representing the
        top-k expert indices for each token.
    - block_size: The block size used in block matrix multiplication.
    - num_experts: The total number of experts.
    - expert_map: A tensor of shape [num_experts] that maps the expert index
        from the global space to the local index space of the current
        expert parallel shard. If the expert is not in the current expert
        parallel shard, the mapping is set to -1.
    - pad_sorted_ids: A flag indicating whether the sorted_token_ids length
        should be padded to a multiple of block_size.

    Returns:
    - sorted_token_ids: A tensor containing the sorted token indices according
        to their allocated expert.
    - expert_ids: A tensor indicating the assigned expert index for each block.
    - num_tokens_post_padded: The total number of tokens after padding,
        ensuring divisibility by block_size.

    This function pads the number of tokens that each expert needs to process
    so that it is divisible by block_size.
    Padding ensures that during block matrix multiplication, the dimensions
    align correctly.

    Example:
    Given topk_ids = [[2, 3, 4], [1, 2, 4], [1, 3, 4], [1, 2, 3]],
    block_size = 4, and num_experts = 4:
    - We initially have 12 tokens (after repeating 'top_k' times) and 4
        experts, with each expert needing to process 3 tokens.
    - As block_size is 4, we pad 1 token for each expert.
    - First, flatten topk_ids to [2, 3, 4, 1, 2, 4, 1, 3, 4, 1, 2, 3].
    - Then append padding tokens [12, 12, 12, 12] for each block.
    - After sorting by expert index, we obtain token_ids
        [3, 6, 9, 12, 0, 4, 10, 12, 1, 7, 11, 12, 2, 5, 8, 12].
        Tokens 12 are non-existent (padding) and are ignored in
        the subsequent matrix multiplication.
    - The padding ensures that the total number of tokens is now divisible
        by block_size for proper block matrix operations.
    """
    max_num_tokens_padded = topk_ids.numel() + num_experts * (block_size - 1)
    if pad_sorted_ids:
        max_num_tokens_padded = round_up(max_num_tokens_padded, block_size)
    sorted_ids = torch.empty((max_num_tokens_padded, ),
                             dtype=torch.int32,
                             device=topk_ids.device)
    sorted_ids.fill_(topk_ids.numel())
    max_num_m_blocks = triton.cdiv(max_num_tokens_padded, block_size)
    # Expert ids must be zeroed out to prevent index out of bounds error while
    # mapping global expert ids to local expert ids in expert parallelism.
    expert_ids = torch.zeros((max_num_m_blocks, ),
                             dtype=torch.int32,
                             device=topk_ids.device)
    num_tokens_post_pad = torch.empty((1),
                                      dtype=torch.int32,
                                      device=topk_ids.device)
    if num_experts >= 224:
        if envs.VLLM_ENABLE_MOE_ALIGN_BLOCK_SIZE_TRITON or num_experts != 256:
            moe_align_block_size_triton(
                topk_ids,
                num_experts,
                block_size,
                sorted_ids,
                expert_ids,
                num_tokens_post_pad,
            )
        else:
            # Currently requires num_experts = 256.
            ops.sgl_moe_align_block_size(
                topk_ids,
                num_experts,
                block_size,
                sorted_ids,
                expert_ids,
                num_tokens_post_pad,
            )
    else:
        ops.moe_align_block_size(topk_ids, num_experts, block_size, sorted_ids,
                                 expert_ids, num_tokens_post_pad)
    if expert_map is not None:
        expert_ids = expert_map[expert_ids]

    return sorted_ids, expert_ids, num_tokens_post_pad
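The docstring's example can be reproduced with a small pure-Python sketch (illustrative only; the real work is done by the Triton kernels or custom ops above). Expert 0 does not appear in the example, so it contributes no blocks.

topk_ids = [[2, 3, 4], [1, 2, 4], [1, 3, 4], [1, 2, 3]]
block_size = 4
flat = [e for row in topk_ids for e in row]  # 12 flattened (token, expert) pairs
pad_id = len(flat)                           # 12 marks a padding slot

# Bucket flattened token positions by expert, then pad each bucket up to a
# multiple of block_size, mirroring what the kernels compute.
buckets = {}
for pos, expert in enumerate(flat):
    buckets.setdefault(expert, []).append(pos)

sorted_token_ids, expert_ids = [], []
for expert in sorted(buckets):
    toks = buckets[expert]
    padded_len = -(-len(toks) // block_size) * block_size
    sorted_token_ids += toks + [pad_id] * (padded_len - len(toks))
    expert_ids += [expert] * (padded_len // block_size)

print(sorted_token_ids)  # [3, 6, 9, 12, 0, 4, 10, 12, 1, 7, 11, 12, 2, 5, 8, 12]
print(expert_ids)        # [1, 2, 3, 4]
print(len(sorted_token_ids) % block_size)  # 0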
48
vllm/model_executor/layers/fused_moe/utils.py
Normal file
@ -0,0 +1,48 @@
# SPDX-License-Identifier: Apache-2.0
from math import prod
from typing import List, Optional, Tuple

import torch

from vllm import _custom_ops as ops
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
    per_token_group_quant_fp8)
from vllm.utils import cdiv


def _resize_cache(x: torch.Tensor, v: Tuple[int, ...]) -> torch.Tensor:
    """
    Shrink the given tensor and apply the given view to it. This is
    used to resize the intermediate fused_moe caches.
    """
    assert prod(v) <= x.numel()
    return x.flatten()[:prod(v)].view(*v)
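A quick usage sketch of _resize_cache: one over-allocated scratch buffer is re-viewed to whatever intermediate shape the current batch needs. The buffer size and shapes below are arbitrary assumptions for illustration.

import torch

workspace = torch.empty(1024)            # one over-allocated scratch buffer
a = _resize_cache(workspace, (8, 16))    # view the first 128 elements as 8x16
b = _resize_cache(workspace, (4, 4, 4))  # later reuse: first 64 elements as 4x4x4
print(a.shape, b.shape)                  # torch.Size([8, 16]) torch.Size([4, 4, 4])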


def _fp8_quantize(
    A: torch.Tensor,
    A_scale: Optional[torch.Tensor],
    block_shape: Optional[List[int]],
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Perform fp8 quantization on the inputs. If a block_shape
    is provided, the output will be blocked.
    """
    if block_shape is None:
        A, A_scale = ops.scaled_fp8_quant(A, A_scale)
    else:
        assert len(block_shape) == 2
        _, block_k = block_shape[0], block_shape[1]
        A, A_scale = per_token_group_quant_fp8(A, block_k)
        assert cdiv(A.shape[-1], block_k) == A_scale.shape[-1]
    return A, A_scale
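To illustrate the blocked branch's shape contract (the cdiv assertion above), here is a rough pure-PyTorch stand-in for group quantization along the last dimension. It is not the fused per_token_group_quant_fp8 kernel, only a shape-level sketch under the assumption that K divides evenly into groups.

import torch

def naive_group_quant_fp8(A: torch.Tensor, block_k: int):
    # Quantize each contiguous group of block_k elements with its own scale.
    M, K = A.shape
    assert K % block_k == 0
    groups = A.view(M, K // block_k, block_k)
    fp8_max = torch.finfo(torch.float8_e4m3fn).max
    scale = groups.abs().amax(dim=-1, keepdim=True).clamp(min=1e-12) / fp8_max
    q = (groups / scale).to(torch.float8_e4m3fn).view(M, K)
    return q, scale.squeeze(-1)  # scale shape: [M, K // block_k]

A = torch.randn(4, 256)
q, s = naive_group_quant_fp8(A, block_k=128)
print(q.shape, s.shape)  # torch.Size([4, 256]) torch.Size([4, 2])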


def _fp8_perm(m: torch.Tensor, idx: torch.Tensor) -> torch.Tensor:
    """
    A permutation routine that works on fp8 types.
    """
    if torch.is_floating_point(m) and torch.finfo(m.dtype).bits == 8:
        return m.view(dtype=torch.uint8)[idx, ...].view(dtype=m.dtype)
    else:
        return m[idx, ...]
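A small sketch of what the uint8 reinterpretation buys: presumably fancy indexing on fp8 dtypes is not available everywhere, so the bytes are permuted and then viewed back. The dtype and shapes below are illustrative assumptions.

import torch

x = torch.randn(4, 8).to(torch.float8_e4m3fn)  # pretend these are quantized activations
idx = torch.tensor([2, 0, 3, 1])               # row permutation

# This mirrors the fp8 branch of _fp8_perm above.
y = x.view(dtype=torch.uint8)[idx, ...].view(dtype=torch.float8_e4m3fn)
print(y.dtype, y.shape)  # torch.float8_e4m3fn torch.Size([4, 8])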