[Minor] Fused experts refactor (#15914)
Signed-off-by: Bill Nell <bnell@redhat.com>
parent d2b58ca203
commit 15ba07ef25
@@ -9,8 +9,11 @@ import torch
 from vllm.config import VllmConfig, set_current_vllm_config
 from vllm.model_executor.layers.activation import SiluAndMul
 from vllm.model_executor.layers.fused_moe import fused_moe
-from vllm.model_executor.layers.fused_moe.fused_moe import (
-    deep_gemm_moe_fp8, fused_topk, moe_align_block_size)
+from vllm.model_executor.layers.fused_moe.deep_gemm_moe import (
+    deep_gemm_moe_fp8)
+from vllm.model_executor.layers.fused_moe.fused_moe import fused_topk
+from vllm.model_executor.layers.fused_moe.moe_align_block_size import (
+    moe_align_block_size)
 from vllm.model_executor.layers.quantization.utils.fp8_utils import (
     per_token_group_quant_fp8, w8a8_block_fp8_matmul)
 from vllm.platforms import current_platform
@@ -437,7 +440,7 @@ def test_w8a8_block_fp8_deep_gemm_fused_moe(M, N, K, E, topk, seed):
         pytest.skip(
             f"Skipping test; bad size m={M}, n={N}, k={K}, topk={topk}, E={E}")

-    if (N <= 512):
+    if N <= 512:
         pytest.skip("Skipping N <= 512 until performance issues solved.")

     vllm_config = VllmConfig()
@@ -4,8 +4,8 @@ import torch

 from vllm import _custom_ops as ops
 from vllm.config import ParallelConfig, VllmConfig, set_current_vllm_config
-from vllm.model_executor.layers.fused_moe.fused_moe import (cutlass_moe_fp8,
-                                                             fused_experts,
+from vllm.model_executor.layers.fused_moe.cutlass_moe import cutlass_moe_fp8
+from vllm.model_executor.layers.fused_moe.fused_moe import (fused_experts,
                                                              fused_topk)
 from vllm.platforms import current_platform

@@ -131,9 +131,9 @@ def test_cutlass_moe_no_graph(
                                     c_strides2,
                                     a1_scale=a_scale1)

-    print(triton_output)
-    print(cutlass_output)
-    print("*")
+    #print(triton_output)
+    #print(cutlass_output)
+    #print("*")

     torch.testing.assert_close(triton_output,
                                cutlass_output,
@@ -234,9 +234,9 @@ def test_cutlass_moe_cuda_graph(
     graph.replay()
     torch.cuda.synchronize()

-    print(triton_output)
-    print(cutlass_output)
-    print("*")
+    #print(triton_output)
+    #print(cutlass_output)
+    #print("*")

     torch.testing.assert_close(triton_output,
                                cutlass_output,
@@ -35,9 +35,11 @@ if HAS_TRITON:
     # import to register the custom ops
     import vllm.model_executor.layers.fused_moe.fused_marlin_moe  # noqa
     import vllm.model_executor.layers.fused_moe.fused_moe  # noqa
+    from vllm.model_executor.layers.fused_moe.cutlass_moe import (
+        cutlass_moe_fp8)
     from vllm.model_executor.layers.fused_moe.fused_moe import (
-        cutlass_moe_fp8, fused_experts, fused_moe, fused_topk,
-        get_config_file_name, grouped_topk)
+        fused_experts, fused_moe, fused_topk, get_config_file_name,
+        grouped_topk)

     __all__ += [
         "fused_moe",
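
The hunk above only relocates entry points between modules. As a reading aid (not part of the commit), the new import paths introduced across this diff can be collected as:

from vllm.model_executor.layers.fused_moe.cutlass_moe import cutlass_moe_fp8
from vllm.model_executor.layers.fused_moe.deep_gemm_moe import deep_gemm_moe_fp8
from vllm.model_executor.layers.fused_moe.fused_moe import fused_experts, fused_topk
from vllm.model_executor.layers.fused_moe.moe_align_block_size import (
    moe_align_block_size)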
New file: vllm/model_executor/layers/fused_moe/cutlass_moe.py (144 lines)

# SPDX-License-Identifier: Apache-2.0
"""Fused MoE kernel."""
from typing import Optional

import torch

from vllm import _custom_ops as ops


#TODO make the grouped gemm kernel consistent with scaled gemm kernel
def cutlass_moe_fp8(
    a: torch.Tensor,
    w1_q: torch.Tensor,
    w2_q: torch.Tensor,
    w1_scale: torch.Tensor,
    w2_scale: torch.Tensor,
    topk_weights: torch.Tensor,
    topk_ids: torch.Tensor,
    ab_strides1: torch.Tensor,
    c_strides1: torch.Tensor,
    ab_strides2: torch.Tensor,
    c_strides2: torch.Tensor,
    a1_scale: Optional[torch.Tensor] = None,
    a2_scale: Optional[torch.Tensor] = None,
    out_dtype: torch.dtype = torch.half,
) -> torch.Tensor:
    """
    This function computes a a8w8-quantized Mixture of Experts (MoE) layer
    using two sets of quantized weights, w1_q and w2_q, and top-k gating
    mechanism. The matrix multiplications are implemented with CUTLASS
    grouped gemm.

    Parameters:
    - a (torch.Tensor): The input tensor to the MoE layer.
        Shape: [M, K]
    - w1_q (torch.Tensor): The first set of fp8-quantized expert weights.
        Shape: [num_experts, K, 2N] (the weights are passed transposed)
    - w2_q (torch.Tensor): The second set of fp8-quantized expert weights.
        Shape: [num_experts, N, K] (the weights are passed transposed)
    - w1_scale (torch.Tensor): The fp32 scale to dequantize w1_q.
        Shape: [num_experts] or [num_experts, 2N]
    - w2_scale (torch.Tensor): The fp32 scale to dequantize w2_q.
        Shape: [num_experts] or [num_experts, K]
    - gating_output (torch.Tensor): The output of the gating operation
        (before softmax).
    - topk_weights (torch.Tensor): The weights of each token->expert mapping.
    - ab_strides1 (torch.Tensor): The input and weights strides of the first
        grouped gemm.
    - c_strides1 (torch.Tensor): The output strides of the first grouped gemm.
    - ab_strides2 (torch.Tensor): The input and weights strides of the second
        grouped gemm.
    - c_strides2 (torch.Tensor): The output strides of the second grouped gemm.
    - a1_scale (Optional[torch.Tensor]): The optional fp32 scale to quantize a.
        Shape: scalar or [M]
    - a2_scale (Optional[torch.Tensor]): The optional fp32 scale to
        quantize the intermediate result between the gemms.
        Shape: scalar or [M]
    - out_dtype (torch.Tensor): The output tensor type.

    Returns:
    - torch.Tensor: The fp16 output tensor after applying the MoE layer.
    """

    assert topk_weights.shape == topk_ids.shape, "topk shape mismatch"
    assert w1_q.dtype == torch.float8_e4m3fn
    assert w2_q.dtype == torch.float8_e4m3fn
    assert a.shape[1] == w1_q.shape[1], "Hidden size mismatch w1"
    assert w1_q.shape[2] == w2_q.shape[1] * 2, "Hidden size mismatch w2"
    assert w1_q.shape[0] == w2_q.shape[0], "Expert number mismatch"
    assert a1_scale is None or a1_scale.dim(
    ) == 0 or a1_scale.shape[0] == 1 or a1_scale.shape[0] == a.shape[
        0], "Input scale shape mismatch"
    assert w1_scale.dim() == 1 or w1_scale.shape[1] == 1 or w1_scale.shape[
        1] == w1_q.shape[2], "W1 scale shape mismatch"
    assert w2_scale.dim() == 1 or w2_scale.shape[1] == 1 or w2_scale.shape[
        1] == w2_q.shape[2], "W2 scale shape mismatch"
    assert w1_q.shape[0] == w2_q.shape[0], "Weights expert number mismatch"
    assert w1_q.shape[0] == w1_scale.shape[
        0], "w1 scales expert number mismatch"
    assert w1_q.shape[0] == w2_scale.shape[
        0], "w2 scales expert number mismatch"
    assert a2_scale is None or a1_scale is None or a2_scale.shape == a1_scale.shape, "Intermediate scale shape mismatch"  # noqa: E501
    assert ab_strides1.shape[0] == w1_q.shape[
        0], "AB Strides 1 expert number mismatch"
    assert c_strides1.shape[0] == w1_q.shape[
        0], "C Strides 1 expert number mismatch"
    assert ab_strides2.shape[0] == w2_q.shape[
        0], "AB Strides 2 expert number mismatch"
    assert c_strides2.shape[0] == w2_q.shape[
        0], "C Strides 2 expert number mismatch"
    assert out_dtype in [torch.half, torch.bfloat16], "Invalid output dtype"

    num_experts = w1_q.size(0)
    m = a.size(0)
    k = w1_q.size(1)
    n = w2_q.size(1)

    topk = topk_ids.size(1)
    per_act_token = a1_scale.numel() != 1 if a1_scale is not None else (
        a2_scale.numel() != 1 if a2_scale is not None else False)

    a_q, a1_scale = ops.scaled_fp8_quant(
        a, a1_scale, use_per_token_if_dynamic=per_act_token)
    device = a_q.device

    expert_offsets = torch.empty((num_experts + 1),
                                 dtype=torch.int32,
                                 device=device)
    problem_sizes1 = torch.empty((num_experts, 3),
                                 dtype=torch.int32,
                                 device=device)
    problem_sizes2 = torch.empty((num_experts, 3),
                                 dtype=torch.int32,
                                 device=device)

    a_map = torch.empty((topk_ids.numel()), dtype=torch.int32, device=device)
    c_map = torch.empty((topk_ids.numel()), dtype=torch.int32, device=device)

    ops.get_cutlass_moe_mm_data(topk_ids, expert_offsets, problem_sizes1,
                                problem_sizes2, a_map, c_map, num_experts, n,
                                k)

    rep_a_q = a_q.view(dtype=torch.uint8)[a_map].view(dtype=a_q.dtype)
    rep_a1_scales = a1_scale[a_map] if per_act_token else a1_scale

    c1 = torch.empty((m * topk, n * 2), device=device, dtype=out_dtype)
    c2 = torch.empty((m * topk, k), device=device, dtype=out_dtype)

    ops.cutlass_moe_mm(c1, rep_a_q, w1_q, rep_a1_scales, w1_scale,
                       expert_offsets[:-1], problem_sizes1, ab_strides1,
                       ab_strides1, c_strides1)

    intermediate = torch.empty((m * topk, n), device=device, dtype=out_dtype)
    torch.ops._C.silu_and_mul(intermediate, c1)

    intemediate_q, a2_scale = ops.scaled_fp8_quant(
        intermediate, a2_scale, use_per_token_if_dynamic=per_act_token)

    ops.cutlass_moe_mm(c2, intemediate_q, w2_q, a2_scale, w2_scale,
                       expert_offsets[:-1], problem_sizes2, ab_strides2,
                       ab_strides2, c_strides2)

    return (c2[c_map].view(m, topk, k) *
            topk_weights.view(m, topk, 1).to(out_dtype)).sum(dim=1)
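
As a reading aid only (not part of this commit): a minimal sketch of how cutlass_moe_fp8 might be wired up end to end, with fused_topk providing the routing. The tensor shapes, the per-expert scalar weight scales, and the per-expert stride values are assumptions chosen to satisfy the asserts above; it needs a CUDA build of vLLM with the CUTLASS grouped-gemm MoE kernels compiled in (Hopper-class GPUs at the time of this commit).

# Illustrative sketch; see the caveats in the paragraph above.
import torch

from vllm.model_executor.layers.fused_moe.cutlass_moe import cutlass_moe_fp8
from vllm.model_executor.layers.fused_moe.fused_moe import fused_topk

num_experts, topk, m, k, n = 8, 2, 16, 256, 512
device = "cuda"

a = torch.randn(m, k, device=device, dtype=torch.half)
# Pre-quantized expert weights, passed transposed: w1_q is [E, K, 2N], w2_q is [E, N, K].
w1_q = torch.randn(num_experts, k, 2 * n, device=device).to(torch.float8_e4m3fn)
w2_q = torch.randn(num_experts, n, k, device=device).to(torch.float8_e4m3fn)
# Per-expert dequantization scales (shape [E, 1] passes the scale-shape asserts).
w1_scale = torch.ones(num_experts, 1, device=device, dtype=torch.float32)
w2_scale = torch.ones(num_experts, 1, device=device, dtype=torch.float32)

# Routing: top-k over a random gating output.
gating = torch.randn(m, num_experts, device=device, dtype=torch.half)
topk_weights, topk_ids = fused_topk(a, gating, topk, renormalize=True)

# Per-expert stride tensors; the K / 2N / N / K pattern mirrors the tests in this PR
# and is an assumption here.
ab_strides1 = torch.full((num_experts, ), k, device=device, dtype=torch.int64)
c_strides1 = torch.full((num_experts, ), 2 * n, device=device, dtype=torch.int64)
ab_strides2 = torch.full((num_experts, ), n, device=device, dtype=torch.int64)
c_strides2 = torch.full((num_experts, ), k, device=device, dtype=torch.int64)

out = cutlass_moe_fp8(a, w1_q, w2_q, w1_scale, w2_scale, topk_weights, topk_ids,
                      ab_strides1, c_strides1, ab_strides2, c_strides2)
print(out.shape)  # torch.Size([16, 256])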
New file: vllm/model_executor/layers/fused_moe/deep_gemm_moe.py (294 lines)

# SPDX-License-Identifier: Apache-2.0
import importlib.util
from typing import Optional, Tuple

import torch

import vllm.envs as envs
from vllm import _custom_ops as ops
from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe.moe_align_block_size import (
    moe_align_block_size)
from vllm.model_executor.layers.fused_moe.utils import (_fp8_perm,
                                                        _fp8_quantize,
                                                        _resize_cache)
from vllm.utils import round_up

logger = init_logger(__name__)

has_deep_gemm = importlib.util.find_spec("deep_gemm") is not None


def _valid_deep_gemm(hidden_states: torch.Tensor,
                     w1: torch.Tensor,
                     w2: torch.Tensor,
                     expert_map: Optional[torch.Tensor] = None) -> bool:
    """
    Check if the given problem size is supported by the DeepGemm grouped
    gemm kernel. All of M, N, K and the quantization block_shape must be
    aligned by `dg.get_m_alignment_for_contiguous_layout()`.
    """
    if not has_deep_gemm:
        return False

    # Lazy import to avoid CUDA initialization problems.
    import deep_gemm as dg

    # Expert maps not supported yet.
    if expert_map is not None:
        return False

    align = dg.get_m_alignment_for_contiguous_layout()
    M = hidden_states.shape[0]
    _, K, N = w2.shape

    # For now, disable DeepGemm for small N until better permute/unpermute
    # ops are available.
    if N <= 512:
        return False

    if align > M or N % align != 0 or K % align != 0:
        return False

    return (hidden_states.is_contiguous() and w1.is_contiguous()
            and w2.is_contiguous())


def _moe_permute(
    curr_hidden_states: torch.Tensor,
    a1q_scale: Optional[torch.Tensor],
    curr_topk_ids: torch.Tensor,
    global_num_experts: int,
    expert_map: Optional[torch.Tensor],
    block_m: int,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], torch.Tensor, torch.Tensor,
           Optional[torch.Tensor]]:
    """
    Determine the sorted_token_ids, expert_ids for the given problem size.
    Permute the hidden states and scales according to `sorted_token_ids`.
    """
    top_k_num = curr_topk_ids.shape[1]

    tokens_in_chunk, _ = curr_hidden_states.shape

    sorted_token_ids, expert_ids, num_tokens_post_padded = (
        moe_align_block_size(curr_topk_ids,
                             block_m,
                             global_num_experts,
                             expert_map,
                             pad_sorted_ids=True))

    inv_perm: Optional[torch.Tensor] = None

    num_tokens = top_k_num * tokens_in_chunk
    sorted_token_ids = sorted_token_ids.clamp(max=num_tokens - 1)
    expert_ids = torch.repeat_interleave(expert_ids, block_m, dim=0)
    inv_perm = torch.argsort(sorted_token_ids)[:num_tokens]

    # Permute according to sorted token ids.
    curr_hidden_states = _fp8_perm(curr_hidden_states,
                                   sorted_token_ids // top_k_num)

    if a1q_scale is not None:
        a1q_scale = a1q_scale[sorted_token_ids // top_k_num]

    return (curr_hidden_states, a1q_scale, sorted_token_ids, expert_ids,
            inv_perm)


def _moe_unpermute_and_reduce(
    out: torch.Tensor,
    curr_hidden: torch.Tensor,
    inv_perm: Optional[torch.Tensor],
    topk_weight: torch.Tensor,
) -> None:
    """
    Unpermute the final result and apply topk_weights, then perform the final
    reduction on the hidden states.
    """
    M, topk = topk_weight.shape
    K = curr_hidden.shape[1]
    curr_hidden = curr_hidden[inv_perm, ...]
    curr_hidden = curr_hidden.view(-1, topk, K)
    curr_hidden.mul_(topk_weight.view(M, -1, 1))
    ops.moe_sum(curr_hidden, out)


def deep_gemm_moe_fp8(
    hidden_states: torch.Tensor,
    w1: torch.Tensor,
    w2: torch.Tensor,
    w1_scale: torch.Tensor,
    w2_scale: torch.Tensor,
    topk_weights: torch.Tensor,
    topk_ids: torch.Tensor,
    inplace: bool = False,
    activation: str = "silu",
    global_num_experts: int = -1,
    expert_map: Optional[torch.Tensor] = None,
    a1_scale: Optional[torch.Tensor] = None,
    a2_scale: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """
    This function computes a a8w8-quantized Mixture of Experts (MoE) layer
    using two sets of quantized weights, w1_q and w2_q, and top-k gating
    mechanism. The matrix multiplications are implemented with DeepGemm
    grouped gemm.

    Parameters:
    - hidden_states (torch.Tensor): The input tensor to the MoE layer.
        Shape: [M, K]
    - w1 (torch.Tensor): The first set of fp8 quantized expert weights.
        Shape: [num_experts, K, 2N] (the weights are passed transposed)
    - w2 (torch.Tensor): The second set of fp8 quantized expert weights.
        Shape: [num_experts, N, K] (the weights are passed transposed)
    - w1_scale (torch.Tensor): The fp32 scale to dequantize w1_q.
        Shape: [num_experts] or [num_experts, 2N]
    - w2_scale (torch.Tensor): The fp32 scale to dequantize w2_q.
        Shape: [num_experts] or [num_experts, K]
    - topk_weights (torch.Tensor): The weights of each token->expert mapping.
    - topk_ids (torch.Tensor): The token->expert mapping for topk_weights.
    - inplace (bool): If True, perform the operation in-place.
        Defaults to False.
    - activation (str): The activation function to apply after the first
        MoE layer.
    - global_num_experts (int): The total number of experts in the global
        expert space.
    - expert_map (Optional[torch.Tensor]): A tensor mapping expert indices
        from the global expert space to the local expert space of the expert
        parallel shard.
    - a1_scale (Optional[torch.Tensor]): The optional fp32 scale to quantize a.
        Shape: scalar or [M]
    - a2_scale (Optional[torch.Tensor]): The optional fp32 scale to
        quantize the intermediate result between the gemms.
        Shape: scalar or [M]

    Returns:
    - torch.Tensor: The bfloat16 output tensor after applying the MoE layer.
    """
    # Lazy import to avoid CUDA initialization problems.
    import deep_gemm as dg

    assert expert_map is None, "Expert maps not supported yet"

    assert hidden_states.shape[1] == w1.shape[2], "Hidden size mismatch"

    assert topk_weights.shape == topk_ids.shape, "topk shape mismatch"
    assert hidden_states.is_contiguous(), "Hidden_states must be contiguous"
    assert w1.stride(-1) == 1, "Stride of last dimension must be 1"
    assert w2.stride(-1) == 1, "Stride of last dimension must be 1"
    assert hidden_states.dtype in [
        torch.float32, torch.float16, torch.bfloat16
    ]
    assert w1.dtype == torch.float8_e4m3fn
    assert w2.dtype == torch.float8_e4m3fn
    assert w1.shape[0] == w2.shape[0], "Expert number mismatch"
    assert w1.shape[0] == w1_scale.shape[0], "w1 scales expert number mismatch"
    assert w1.shape[0] == w2_scale.shape[0], "w2 scales expert number mismatch"
    assert a1_scale is None or a1_scale.dim(
    ) == 0 or a1_scale.shape[0] == 1 or a1_scale.shape[
        0] == hidden_states.shape[0], "Input scale shape mismatch"
    assert a2_scale is None or a1_scale is None or a2_scale.shape == a1_scale.shape, "Intermediate scale shape mismatch"  # noqa: E501

    num_tokens, _ = hidden_states.shape
    E, N, _ = w1.shape
    K = w2.shape[1]
    if global_num_experts == -1:
        global_num_experts = E

    # We execute the fused_moe kernel in chunks to circumvent this issue:
    # https://github.com/vllm-project/vllm/issues/5938
    CHUNK_SIZE = envs.VLLM_FUSED_MOE_CHUNK_SIZE

    assert _valid_deep_gemm(hidden_states, w1, w2, expert_map)

    if inplace:
        out_hidden_states = hidden_states
    else:
        out_hidden_states = torch.empty_like(hidden_states)

    block_m = dg.get_m_alignment_for_contiguous_layout()
    block_shape = [block_m, block_m]

    assert w1_scale is not None
    assert w2_scale is not None

    # We attempt to transpose and align offline in Fp8MoEMethod, in which
    # case these calls will be nops. Otherwise, they'll be performed every
    # time the layer is executed.
    w1_scale = dg.get_col_major_tma_aligned_tensor(w1_scale).contiguous()
    w2_scale = dg.get_col_major_tma_aligned_tensor(w2_scale).contiguous()

    M_sum = topk_ids.numel() + global_num_experts * (block_m - 1)
    M_sum = round_up(M_sum, block_m)

    num_chunks = (num_tokens // CHUNK_SIZE) + 1

    # We can reuse the memory between cache1 and cache3 because by the time
    # we need cache3, we're done with cache1
    workspace13 = torch.empty(M_sum * max(N, K),
                              device=hidden_states.device,
                              dtype=hidden_states.dtype)

    workspace1 = workspace13[:M_sum * N].view(M_sum, N)
    workspace2 = torch.empty((M_sum, N // 2),
                             device=hidden_states.device,
                             dtype=hidden_states.dtype)
    workspace3 = workspace13[:M_sum * K].view(M_sum, K)

    for chunk in range(num_chunks):
        begin_chunk_idx, end_chunk_idx = (chunk * CHUNK_SIZE,
                                          min((chunk + 1) * CHUNK_SIZE,
                                              num_tokens))
        curr_hidden_states = hidden_states[begin_chunk_idx:end_chunk_idx]
        tokens_in_chunk, _ = curr_hidden_states.shape

        if tokens_in_chunk == 0:
            break

        curr_topk_ids = topk_ids[begin_chunk_idx:end_chunk_idx]
        curr_topk_weights = topk_weights[begin_chunk_idx:end_chunk_idx]

        a1q_scale: Optional[torch.Tensor] = None

        qcurr_hidden_states, a1q_scale = _fp8_quantize(curr_hidden_states,
                                                       a1_scale, block_shape)

        (qcurr_hidden_states, a1q_scale, sorted_token_ids, expert_ids,
         inv_perm) = _moe_permute(qcurr_hidden_states, a1q_scale,
                                  curr_topk_ids, global_num_experts,
                                  expert_map, block_m)

        # Adjust the intermediate cache size and config for the last chunk.
        # Note that in most cases we only have one chunk so the cache size
        # and config are already set correctly and do not need to be adjusted.
        if tokens_in_chunk < CHUNK_SIZE and chunk > 0:
            curr_M = sorted_token_ids.numel()
            workspace1 = _resize_cache(workspace1, (curr_M, N))
            workspace2 = _resize_cache(workspace2, (curr_M, N // 2))
            workspace3 = _resize_cache(workspace3, (curr_M, K))

        dg.m_grouped_gemm_fp8_fp8_bf16_nt_contiguous(
            (qcurr_hidden_states, a1q_scale), (w1, w1_scale), workspace1,
            expert_ids)

        if activation == "silu":
            torch.ops._C.silu_and_mul(workspace2, workspace1.view(-1, N))
        elif activation == "gelu":
            torch.ops._C.gelu_and_mul(workspace2, workspace1.view(-1, N))
        else:
            raise ValueError(f"Unsupported FusedMoe activation: {activation}")

        a2q_scale: Optional[torch.Tensor] = None

        qworkspace2, a2q_scale = _fp8_quantize(workspace2, a2_scale,
                                               block_shape)

        dg.m_grouped_gemm_fp8_fp8_bf16_nt_contiguous(
            (qworkspace2, a2q_scale), (w2, w2_scale), workspace3, expert_ids)

        _moe_unpermute_and_reduce(
            out_hidden_states[begin_chunk_idx:end_chunk_idx],
            workspace3.view(*workspace3.shape), inv_perm, curr_topk_weights)

    return out_hidden_states
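
Again as a reading aid (not part of this commit): a sketch of how the new DeepGemm module might be driven, with fused_topk for routing and _valid_deep_gemm as the dispatch guard. The shapes and the 128x128 block-scale layout are assumptions chosen to satisfy the asserts above; actually running it requires the optional deep_gemm package and a supported GPU.

# Illustrative sketch; see the caveats in the paragraph above.
import torch

from vllm.model_executor.layers.fused_moe.deep_gemm_moe import (
    _valid_deep_gemm, deep_gemm_moe_fp8)
from vllm.model_executor.layers.fused_moe.fused_moe import fused_topk

E, topk, m, k, n = 8, 2, 256, 1024, 1024  # all dims multiples of the 128 block
block = 128
device = "cuda"

a = torch.randn(m, k, device=device, dtype=torch.bfloat16)
# Block-quantized fp8 weights; per the asserts above, w1 is [E, 2*n, k] and w2 is [E, k, n].
w1 = torch.randn(E, 2 * n, k, device=device).to(torch.float8_e4m3fn)
w2 = torch.randn(E, k, n, device=device).to(torch.float8_e4m3fn)
# Assumed 128x128 block scale layout: one fp32 scale per weight block.
w1_scale = torch.ones(E, 2 * n // block, k // block, device=device)
w2_scale = torch.ones(E, k // block, n // block, device=device)

gating = torch.randn(m, E, device=device, dtype=torch.bfloat16)
topk_weights, topk_ids = fused_topk(a, gating, topk, renormalize=True)

# _valid_deep_gemm returns False when deep_gemm is not installed, when N <= 512,
# or when M/N/K are not aligned to the block size, so it doubles as a dispatch guard.
if _valid_deep_gemm(a, w1, w2):
    out = deep_gemm_moe_fp8(a, w1, w2, w1_scale, w2_scale, topk_weights, topk_ids)
    print(out.shape)  # torch.Size([256, 1024])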
@@ -1,10 +1,8 @@
 # SPDX-License-Identifier: Apache-2.0
 """Fused MoE kernel."""
 import functools
-import importlib.util
 import json
 import os
-from math import prod
 from typing import Any, Callable, Dict, List, Optional, Tuple

 import torch
@@ -14,10 +12,13 @@ import triton.language as tl
 import vllm.envs as envs
 from vllm import _custom_ops as ops
 from vllm.logger import init_logger
-from vllm.model_executor.layers.quantization.utils.fp8_utils import (
-    per_token_group_quant_fp8)
+from vllm.model_executor.layers.fused_moe.deep_gemm_moe import (
+    _valid_deep_gemm, deep_gemm_moe_fp8)
+from vllm.model_executor.layers.fused_moe.moe_align_block_size import (
+    moe_align_block_size)
+from vllm.model_executor.layers.fused_moe.utils import _fp8_quantize
 from vllm.platforms import current_platform
-from vllm.utils import direct_register_custom_op, round_up
+from vllm.utils import direct_register_custom_op

 from .rocm_aiter_fused_moe import (is_rocm_aiter_moe_enabled,
                                    rocm_aiter_fused_experts,
@@ -25,8 +26,6 @@ from .rocm_aiter_fused_moe import (is_rocm_aiter_moe_enabled,

 logger = init_logger(__name__)

-has_deep_gemm = importlib.util.find_spec("deep_gemm") is not None
-

 @triton.jit
 def write_zeros_to_output(c_ptr, stride_cm, stride_cn, pid_n, N, offs_token,
@@ -443,300 +442,13 @@ def fused_moe_kernel(
         tl.store(c_ptrs, accumulator, mask=c_mask)


-def ceil_div(a, b):
-    return (a + b - 1) // b
-
-
-@triton.jit
-def moe_align_block_size_stage1(
-    topk_ids_ptr,
-    tokens_cnts_ptr,
-    num_experts: tl.constexpr,
-    numel: tl.constexpr,
-    tokens_per_thread: tl.constexpr,
-):
-    pid = tl.program_id(0)
-
-    start_idx = pid * tokens_per_thread
-
-    off_c = (pid + 1) * num_experts
-
-    for i in range(tokens_per_thread):
-        if start_idx + i < numel:
-            idx = tl.load(topk_ids_ptr + start_idx + i)
-            token_cnt = tl.load(tokens_cnts_ptr + off_c + idx)
-            tl.store(tokens_cnts_ptr + off_c + idx, token_cnt + 1)
-
-
-@triton.jit
-def moe_align_block_size_stage2(
-    tokens_cnts_ptr,
-    num_experts: tl.constexpr,
-):
-    pid = tl.program_id(0)
-
-    last_cnt = 0
-    for i in range(1, num_experts + 1):
-        token_cnt = tl.load(tokens_cnts_ptr + i * num_experts + pid)
-        last_cnt = last_cnt + token_cnt
-        tl.store(tokens_cnts_ptr + i * num_experts + pid, last_cnt)
-
-
-@triton.jit
-def moe_align_block_size_stage3(
-    total_tokens_post_pad_ptr,
-    tokens_cnts_ptr,
-    cumsum_ptr,
-    num_experts: tl.constexpr,
-    block_size: tl.constexpr,
-):
-    last_cumsum = 0
-    off_cnt = num_experts * num_experts
-    for i in range(1, num_experts + 1):
-        token_cnt = tl.load(tokens_cnts_ptr + off_cnt + i - 1)
-        last_cumsum = last_cumsum + tl.cdiv(token_cnt, block_size) * block_size
-        tl.store(cumsum_ptr + i, last_cumsum)
-    tl.store(total_tokens_post_pad_ptr, last_cumsum)
-
-
-@triton.jit
-def moe_align_block_size_stage4(
-    topk_ids_ptr,
-    sorted_token_ids_ptr,
-    expert_ids_ptr,
-    tokens_cnts_ptr,
-    cumsum_ptr,
-    num_experts: tl.constexpr,
-    block_size: tl.constexpr,
-    numel: tl.constexpr,
-    tokens_per_thread: tl.constexpr,
-):
-    pid = tl.program_id(0)
-    start_idx = tl.load(cumsum_ptr + pid)
-    end_idx = tl.load(cumsum_ptr + pid + 1)
-
-    for i in range(start_idx, end_idx, block_size):
-        tl.store(expert_ids_ptr + i // block_size, pid)
-
-    start_idx = pid * tokens_per_thread
-    off_t = pid * num_experts
-
-    for i in range(start_idx, tl.minimum(start_idx + tokens_per_thread,
-                                         numel)):
-        expert_id = tl.load(topk_ids_ptr + i)
-        token_cnt = tl.load(tokens_cnts_ptr + off_t + expert_id)
-        rank_post_pad = token_cnt + tl.load(cumsum_ptr + expert_id)
-        tl.store(sorted_token_ids_ptr + rank_post_pad, i)
-        tl.store(tokens_cnts_ptr + off_t + expert_id, token_cnt + 1)
-
-
-# Triton implementation based on:
-# https://github.com/sgl-project/sglang/commit/ba5112ff691d791a9e38c6c71f59324a5fcb49d0
-def moe_align_block_size_triton(
-    topk_ids: torch.Tensor,
-    num_experts: int,
-    block_size: int,
-    sorted_token_ids: torch.Tensor,
-    expert_ids: torch.Tensor,
-    num_tokens_post_pad: torch.Tensor,
-) -> None:
-    numel = topk_ids.numel()
-    grid = (num_experts, )
-    tokens_cnts = torch.zeros((num_experts + 1, num_experts),
-                              dtype=torch.int32,
-                              device=topk_ids.device)
-    cumsum = torch.zeros((num_experts + 1, ),
-                         dtype=torch.int32,
-                         device=topk_ids.device)
-    tokens_per_thread = ceil_div(numel, num_experts)
-
-    moe_align_block_size_stage1[grid](
-        topk_ids,
-        tokens_cnts,
-        num_experts,
-        numel,
-        tokens_per_thread,
-    )
-    moe_align_block_size_stage2[grid](
-        tokens_cnts,
-        num_experts,
-    )
-    moe_align_block_size_stage3[(1, )](
-        num_tokens_post_pad,
-        tokens_cnts,
-        cumsum,
-        num_experts,
-        block_size,
-    )
-    moe_align_block_size_stage4[grid](
-        topk_ids,
-        sorted_token_ids,
-        expert_ids,
-        tokens_cnts,
-        cumsum,
-        num_experts,
-        block_size,
-        numel,
-        tokens_per_thread,
-    )
-
-
-def moe_align_block_size(
-    topk_ids: torch.Tensor,
-    block_size: int,
-    num_experts: int,
-    expert_map: Optional[torch.Tensor] = None,
-    pad_sorted_ids: bool = False
-) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
-    """
-    Aligns the token distribution across experts to be compatible with block
-    size for matrix multiplication.
-
-    Parameters:
-    - topk_ids: A tensor of shape [total_tokens, top_k] representing the
-        top-k expert indices for each token.
-    - block_size: The block size used in block matrix multiplication.
-    - num_experts: The total number of experts.
-    - expert_map: A tensor of shape [num_experts] that maps the expert index
-        from the global space to the local index space of the current
-        expert parallel shard. If the expert is not in the current expert
-        parallel shard, the mapping is set to -1.
-    - pad_sorted_ids: A flag indicating whether the sorted_token_ids length
-      should be padded to a multiple of block_size,
-
-    Returns:
-    - sorted_token_ids: A tensor containing the sorted token indices according
-        to their allocated expert.
-    - expert_ids: A tensor indicating the assigned expert index for each block.
-    - num_tokens_post_padded: The total number of tokens after padding,
-        ensuring divisibility by block_size.
-
-    This function pads the number of tokens that each expert needs to process
-    so that it is divisible by block_size.
-    Padding ensures that during block matrix multiplication, the dimensions
-    align correctly.
-
-    Example:
-    Given topk_ids = [[2, 3, 4], [1, 2, 4], [1, 3, 4], [1, 2, 3]],
-    block_size = 4, and num_experts = 4:
-    - We initially have 12 tokens (after repeating 'top_k' times) and 4 experts,
-        with each expert needing to process 3 tokens.
-    - As block_size is 4, we pad 1 token for each expert.
-    - First, flatten topk_ids to [2, 3, 4, 1, 2, 4, 1, 3, 4, 1, 2, 3].
-    - Then append padding tokens [12, 12, 12, 12] for each block.
-    - After sorting by expert index, we obtain token_ids
-        [3, 6, 9, 12, 0, 4, 10, 12, 1, 7, 11, 12, 2, 5, 8, 12].
-        Tokens 12 are non-existent (padding) and are ignored in
-        the subsequent matrix multiplication.
-    - The padding ensures that the total number of tokens is now divisible
-        by block_size for proper block matrix operations.
-    """
-    max_num_tokens_padded = topk_ids.numel() + num_experts * (block_size - 1)
-    if pad_sorted_ids:
-        max_num_tokens_padded = round_up(max_num_tokens_padded, block_size)
-    sorted_ids = torch.empty((max_num_tokens_padded, ),
-                             dtype=torch.int32,
-                             device=topk_ids.device)
-    sorted_ids.fill_(topk_ids.numel())
-    max_num_m_blocks = triton.cdiv(max_num_tokens_padded, block_size)
-    # Expert ids must be zeroed out to prevent index out of bounds error while
-    # mapping global expert ids to local expert ids in expert parallelism.
-    expert_ids = torch.zeros((max_num_m_blocks, ),
-                             dtype=torch.int32,
-                             device=topk_ids.device)
-    num_tokens_post_pad = torch.empty((1),
-                                      dtype=torch.int32,
-                                      device=topk_ids.device)
-    if num_experts >= 224:
-        if envs.VLLM_ENABLE_MOE_ALIGN_BLOCK_SIZE_TRITON or num_experts != 256:
-            moe_align_block_size_triton(
-                topk_ids,
-                num_experts,
-                block_size,
-                sorted_ids,
-                expert_ids,
-                num_tokens_post_pad,
-            )
-        else:
-            # Currently requires num_experts=256
-            ops.sgl_moe_align_block_size(
-                topk_ids,
-                num_experts,
-                block_size,
-                sorted_ids,
-                expert_ids,
-                num_tokens_post_pad,
-            )
-    else:
-        ops.moe_align_block_size(topk_ids, num_experts, block_size, sorted_ids,
-                                 expert_ids, num_tokens_post_pad)
-    if expert_map is not None:
-        expert_ids = expert_map[expert_ids]
-
-    return sorted_ids, expert_ids, num_tokens_post_pad
-
-
-def _valid_deep_gemm(hidden_states: torch.Tensor, w1: torch.Tensor,
-                     w2: torch.Tensor,
-                     expert_map: Optional[torch.Tensor]) -> bool:
-    """
-    Check if the given problem size is supported by the DeepGemm grouped
-    gemm kernel. All of M, N, K and the quantization block_shape must be
-    aligned by `dg.get_m_alignment_for_contiguous_layout()`.
-    """
-    if not has_deep_gemm:
-        return False
-
-    # Lazy import to avoid CUDA initialization problems.
-    import deep_gemm as dg
-
-    # Expert maps not supported yet.
-    if expert_map is not None:
-        return False
-
-    align = dg.get_m_alignment_for_contiguous_layout()
-    M = hidden_states.shape[0]
-    _, K, N = w2.shape
-
-    # For now, disable DeepGemm for small N until better permute/unpermute
-    # ops are available.
-    if N <= 512:
-        return False
-
-    if align > M or N % align != 0 or K % align != 0:
-        return False
-
-    return (hidden_states.is_contiguous() and w1.is_contiguous()
-            and w2.is_contiguous())
-
-
-def _fp8_quantize(
-    A: torch.Tensor,
-    A_scale: Optional[torch.Tensor],
-    block_shape: Optional[List[int]],
-) -> Tuple[torch.Tensor, torch.Tensor]:
-    """
-    Perform fp8 quantization on the inputs. If a block_shape
-    is provided, the output will be blocked.
-    """
-    if block_shape is None:
-        A, A_scale = ops.scaled_fp8_quant(A, A_scale)
-    else:
-        assert len(block_shape) == 2
-        _, block_k = block_shape[0], block_shape[1]
-        A, A_scale = per_token_group_quant_fp8(A, block_k)
-        assert triton.cdiv(A.shape[-1], block_k) == A_scale.shape[-1]
-    return A, A_scale
-
-
 def invoke_fused_moe_kernel(A: torch.Tensor,
                             B: torch.Tensor,
                             C: torch.Tensor,
                             A_scale: Optional[torch.Tensor],
                             B_scale: Optional[torch.Tensor],
                             B_zp: Optional[torch.Tensor],
-                            topk_weights: torch.Tensor,
-                            topk_ids: torch.Tensor,
+                            topk_weights: Optional[torch.Tensor],
                             sorted_token_ids: torch.Tensor,
                             expert_ids: torch.Tensor,
                             num_tokens_post_padded: torch.Tensor,
@@ -748,7 +460,8 @@ def invoke_fused_moe_kernel(A: torch.Tensor,
                             use_int8_w8a16: bool,
                             use_int4_w4a16: bool,
                             block_shape: Optional[List[int]] = None) -> None:
-    assert topk_weights.stride(1) == 1
+    assert topk_weights is not None or not mul_routed_weight
+    assert topk_weights is None or topk_weights.stride(1) == 1
     assert sorted_token_ids.stride(0) == 1

     if use_fp8_w8a8:
@@ -765,6 +478,9 @@ def invoke_fused_moe_kernel(A: torch.Tensor,
         assert A_scale is None
         assert B_scale is None

+    M = A.shape[0]
+    num_tokens = M * top_k
+
     EM = sorted_token_ids.shape[0]
     if A.shape[0] < config["BLOCK_SIZE_M"]:
         # optimize for small batch_size.
@@ -782,7 +498,7 @@ def invoke_fused_moe_kernel(A: torch.Tensor,
         assert B_zp is None or B_zp.ndim == 3

         use_moe_wna16_cuda = should_moe_wna16_use_cuda(
-            num_valid_tokens=topk_ids.numel(),
+            num_valid_tokens=num_tokens,
             group_size=block_shape[1],
             num_experts=B.shape[0],
             bit=4 if use_int4_w4a16 else 8)
@@ -790,12 +506,12 @@ def invoke_fused_moe_kernel(A: torch.Tensor,
         config.update(
             get_moe_wna16_block_config(config=config,
                                        use_moe_wna16_cuda=use_moe_wna16_cuda,
-                                       num_valid_tokens=topk_ids.numel(),
+                                       num_valid_tokens=num_tokens,
                                        size_k=A.shape[1],
                                        size_n=B.shape[1],
                                        num_experts=B.shape[1],
                                        group_size=block_shape[1],
-                                       real_top_k=topk_ids.shape[1],
+                                       real_top_k=top_k,
                                        block_size_m=config["BLOCK_SIZE_M"]))

         if use_moe_wna16_cuda:
@@ -821,7 +537,7 @@ def invoke_fused_moe_kernel(A: torch.Tensor,
             B.shape[1],
             A.shape[1],
             EM,
-            topk_ids.numel(),
+            num_tokens,
             A.stride(0),
             A.stride(1),
             B.stride(0),
@@ -864,7 +580,7 @@ def invoke_fused_moe_kernel(A: torch.Tensor,
             B.shape[1],
             B.shape[2],
             EM,
-            topk_ids.numel(),
+            num_tokens,
             A.stride(0),
             A.stride(1),
             B.stride(0),
@@ -1389,6 +1105,7 @@ def fused_experts(hidden_states: torch.Tensor,
                                  w2=w2,
                                  topk_weights=topk_weights,
                                  topk_ids=topk_ids,
+                                 inplace=inplace,
                                  activation=activation,
                                  global_num_experts=global_num_experts,
                                  expert_map=expert_map,
@@ -1419,85 +1136,6 @@ def fused_experts(hidden_states: torch.Tensor,
                         block_shape=block_shape)


-def _fp8_perm(m: torch.Tensor, idx: torch.Tensor) -> torch.Tensor:
-    """
-    A permutation routine that works on fp8 types.
-    """
-    if torch.is_floating_point(m) and torch.finfo(m.dtype).bits == 8:
-        return m.view(dtype=torch.uint8)[idx, ...].view(dtype=m.dtype)
-    else:
-        return m[idx, ...]
-
-
-def _moe_permute(
-    curr_hidden_states: torch.Tensor,
-    a1q_scale: Optional[torch.Tensor],
-    curr_topk_ids: torch.Tensor,
-    global_num_experts: int,
-    expert_map: Optional[torch.Tensor],
-    top_k_num: int,
-    block_m: int,
-) -> Tuple[torch.Tensor, Optional[torch.Tensor], torch.Tensor, torch.Tensor,
-           torch.Tensor]:
-    """
-    Determine the sorted_token_ids, expert_ids for the given problem size.
-    Permute the hidden states and scales according to `sorted_token_ids`.
-    """
-    tokens_in_chunk, _ = curr_hidden_states.shape
-
-    sorted_token_ids, expert_ids, num_tokens_post_padded = (
-        moe_align_block_size(curr_topk_ids,
-                             block_m,
-                             global_num_experts,
-                             expert_map,
-                             pad_sorted_ids=True))
-
-    inv_perm: Optional[torch.Tensor] = None
-
-    num_tokens = top_k_num * tokens_in_chunk
-    sorted_token_ids = sorted_token_ids.clamp(max=num_tokens - 1)
-    expert_ids = torch.repeat_interleave(expert_ids, block_m, dim=0)
-    inv_perm = torch.argsort(sorted_token_ids)[:num_tokens]
-
-    # Permute according to sorted token ids.
-    curr_hidden_states = _fp8_perm(curr_hidden_states,
-                                   sorted_token_ids // top_k_num)
-
-    if a1q_scale is not None:
-        a1q_scale = a1q_scale[sorted_token_ids // top_k_num]
-
-    return (curr_hidden_states, a1q_scale, sorted_token_ids, expert_ids,
-            inv_perm)
-
-
-def _moe_unpermute_and_reduce(
-    out: torch.Tensor,
-    curr_hidden: torch.Tensor,
-    inv_perm: Optional[torch.Tensor],
-    topk: int,
-    K: int,
-    topk_weight: torch.Tensor,
-) -> None:
-    """
-    Unpermute the final result and apply topk_weights, then perform the final
-    reduction on the hidden states.
-    """
-    M = topk_weight.shape[0]
-    curr_hidden = curr_hidden[inv_perm, ...]
-    curr_hidden = curr_hidden.view(-1, topk, K)
-    curr_hidden.mul_(topk_weight.view(M, -1, 1))
-    ops.moe_sum(curr_hidden, out)
-
-
-def _resize_cache(x: torch.Tensor, v: Tuple[int, ...]) -> torch.Tensor:
-    """
-    Shrink the given tensor and apply the given view to it. This is
-    used to resize the intermediate fused_moe caches.
-    """
-    assert prod(v) <= x.numel()
-    return x.flatten()[:prod(v)].view(*v)
-
-
 def fused_experts_impl(hidden_states: torch.Tensor,
                        w1: torch.Tensor,
                        w2: torch.Tensor,
@@ -1629,7 +1267,6 @@ def fused_experts_impl(hidden_states: torch.Tensor,
                                 w1_scale,
                                 w1_zp,
                                 curr_topk_weights,
-                                curr_topk_ids,
                                 sorted_token_ids,
                                 expert_ids,
                                 num_tokens_post_padded,
@@ -1660,28 +1297,34 @@ def fused_experts_impl(hidden_states: torch.Tensor,
             qintermediate_cache2 = intermediate_cache2
             a2q_scale = a2_scale

-        invoke_fused_moe_kernel(qintermediate_cache2,
-                                w2,
-                                intermediate_cache3,
-                                a2q_scale,
-                                w2_scale,
-                                w2_zp,
-                                curr_topk_weights,
-                                curr_topk_ids,
-                                sorted_token_ids,
-                                expert_ids,
-                                num_tokens_post_padded,
-                                True,
-                                1,
-                                config,
-                                compute_type=compute_type,
-                                use_fp8_w8a8=use_fp8_w8a8,
-                                use_int8_w8a16=use_int8_w8a16,
-                                use_int4_w4a16=use_int4_w4a16,
-                                block_shape=block_shape)
+        invoke_fused_moe_kernel(
+            qintermediate_cache2,
+            w2,
+            intermediate_cache3,
+            a2q_scale,
+            w2_scale,
+            w2_zp,
+            curr_topk_weights,
+            sorted_token_ids,
+            expert_ids,
+            num_tokens_post_padded,
+            False,  #True,
+            1,
+            config,
+            compute_type=compute_type,
+            use_fp8_w8a8=use_fp8_w8a8,
+            use_int8_w8a16=use_int8_w8a16,
+            use_int4_w4a16=use_int4_w4a16,
+            block_shape=block_shape)

+        if True:
+            intermediate_cache3 = intermediate_cache3.view(-1, top_k_num, K)
+            intermediate_cache3.mul_(
+                curr_topk_weights.view(tokens_in_chunk, -1, 1))
+
         ops.moe_sum(intermediate_cache3.view(*intermediate_cache3.shape),
                     out_hidden_states[begin_chunk_idx:end_chunk_idx])

     return out_hidden_states

@ -1790,327 +1433,3 @@ def fused_moe(
|
|||||||
a1_scale=a1_scale,
|
a1_scale=a1_scale,
|
||||||
a2_scale=a2_scale,
|
a2_scale=a2_scale,
|
||||||
block_shape=block_shape)
|
block_shape=block_shape)
|
||||||
|
|
||||||
|
|
||||||
def deep_gemm_moe_fp8(
|
|
||||||
hidden_states: torch.Tensor,
|
|
||||||
w1: torch.Tensor,
|
|
||||||
w2: torch.Tensor,
|
|
||||||
w1_scale: torch.Tensor,
|
|
||||||
w2_scale: torch.Tensor,
|
|
||||||
topk_weights: torch.Tensor,
|
|
||||||
topk_ids: torch.Tensor,
|
|
||||||
inplace: bool = False,
|
|
||||||
activation: str = "silu",
|
|
||||||
global_num_experts: int = -1,
|
|
||||||
expert_map: Optional[torch.Tensor] = None,
|
|
||||||
a1_scale: Optional[torch.Tensor] = None,
|
|
||||||
a2_scale: Optional[torch.Tensor] = None,
|
|
||||||
) -> torch.Tensor:
|
|
||||||
"""
|
|
||||||
This function computes a a8w8-quantized Mixture of Experts (MoE) layer
|
|
||||||
using two sets of quantized weights, w1_q and w2_q, and top-k gating
|
|
||||||
mechanism. The matrix multiplications are implemented with DeepGemm
|
|
||||||
grouped gemm.
|
|
||||||
|
|
||||||
Parameters:
|
|
||||||
- hidden_states (torch.Tensor): The input tensor to the MoE layer.
|
|
||||||
Shape: [M, K]
|
|
||||||
- w1 (torch.Tensor): The first set of fp8 quantized expert weights.
|
|
||||||
Shape: [num_experts, K, 2N] (the weights are passed transposed)
|
|
||||||
- w2 (torch.Tensor): The second set of fp8 quantized expert weights.
|
|
||||||
Shape: [num_experts, N, K] (the weights are passed transposed)
|
|
||||||
- w1_scale (torch.Tensor): The fp32 scale to dequantize w1_q.
|
|
||||||
Shape: [num_experts] or [num_experts, 2N]
|
|
||||||
- w2_scale (torch.Tensor): The fp32 scale to dequantize w2_q.
|
|
||||||
Shape: [num_experts] or [num_experts, K]
|
|
||||||
- topk_weights (torch.Tensor): The weights of each token->expert mapping.
|
|
||||||
- topk_ids (torch.Tensor): The token->expert mapping for topk_weights.
|
|
||||||
- inplace (bool): If True, perform the operation in-place.
|
|
||||||
Defaults to False.
|
|
||||||
- activation (str): The activation function to apply after the first
|
|
||||||
MoE layer.
|
|
||||||
- global_num_experts (int): The total number of experts in the global
|
|
||||||
expert space.
|
|
||||||
- expert_map (Optional[torch.Tensor]): A tensor mapping expert indices
|
|
||||||
from the global expert space to the local expert space of the expert
|
|
||||||
parallel shard.
|
|
||||||
- a1_scale (Optional[torch.Tensor]): The optional fp32 scale to quantize a.
|
|
||||||
Shape: scalar or [M]
|
|
||||||
- a2_scale (Optional[torch.Tensor]): The optional fp32 scale to
|
|
||||||
quantize the intermediate result between the gemms.
|
|
||||||
Shape: scalar or [M]
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
- torch.Tensor: The bfloat16 output tensor after applying the MoE layer.
|
|
||||||
"""
|
|
||||||
# Lazy import to avoid CUDA initialization problems.
|
|
||||||
import deep_gemm as dg
|
|
||||||
|
|
||||||
assert expert_map is None, "Expert maps not supported yet"
|
|
||||||
|
|
||||||
assert hidden_states.shape[1] == w1.shape[2], "Hidden size mismatch"
|
|
||||||
|
|
||||||
assert topk_weights.shape == topk_ids.shape, "topk shape mismatch"
|
|
||||||
assert hidden_states.is_contiguous(), "Hidden_states must be contiguous"
|
|
||||||
assert w1.stride(-1) == 1, "Stride of last dimension must be 1"
|
|
||||||
assert w2.stride(-1) == 1, "Stride of last dimension must be 1"
|
|
||||||
assert hidden_states.dtype in [
|
|
||||||
torch.float32, torch.float16, torch.bfloat16
|
|
||||||
]
|
|
||||||
assert w1.dtype == torch.float8_e4m3fn
|
|
||||||
assert w2.dtype == torch.float8_e4m3fn
|
|
||||||
assert w1.shape[0] == w2.shape[0], "Expert number mismatch"
|
|
||||||
assert w1.shape[0] == w1_scale.shape[0], "w1 scales expert number mismatch"
|
|
||||||
assert w1.shape[0] == w2_scale.shape[0], "w2 scales expert number mismatch"
|
|
||||||
assert a1_scale is None or a1_scale.dim(
|
|
||||||
) == 0 or a1_scale.shape[0] == 1 or a1_scale.shape[
|
|
||||||
0] == hidden_states.shape[0], "Input scale shape mismatch"
|
|
||||||
assert a2_scale is None or a1_scale is None or a2_scale.shape == a1_scale.shape, "Intermediate scale shape mismatch" # noqa: E501
|
|
||||||
|
|
||||||
num_tokens, _ = hidden_states.shape
|
|
||||||
E, N, _ = w1.shape
|
|
||||||
K = w2.shape[1]
|
|
||||||
if global_num_experts == -1:
|
|
||||||
global_num_experts = E
|
|
||||||
top_k_num = topk_ids.shape[1]
|
|
||||||
# We execute the fused_moe kernel in chunks to circumvent this issue:
|
|
||||||
# https://github.com/vllm-project/vllm/issues/5938
|
|
||||||
CHUNK_SIZE = envs.VLLM_FUSED_MOE_CHUNK_SIZE
|
|
||||||
|
|
||||||
assert _valid_deep_gemm(hidden_states, w1, w2, expert_map)
|
|
||||||
|
|
||||||
if inplace:
|
|
||||||
out_hidden_states = hidden_states
|
|
||||||
else:
|
|
||||||
out_hidden_states = torch.empty_like(hidden_states)
|
|
||||||
|
|
||||||
block_m = dg.get_m_alignment_for_contiguous_layout()
|
|
||||||
block_shape = [block_m, block_m]
|
|
||||||
|
|
||||||
assert w1_scale is not None
|
|
||||||
assert w2_scale is not None
|
|
||||||
|
|
||||||
# We attempt to transpose and align offline in Fp8MoEMethod, in which
|
|
||||||
    # case these calls will be nops.  Otherwise, they'll be performed every
    # time the layer is executed.
    w1_scale = dg.get_col_major_tma_aligned_tensor(w1_scale).contiguous()
    w2_scale = dg.get_col_major_tma_aligned_tensor(w2_scale).contiguous()

    M_sum = topk_ids.numel() + global_num_experts * (block_m - 1)
    M_sum = round_up(M_sum, block_m)

    num_chunks = (num_tokens // CHUNK_SIZE) + 1

    # We can reuse the memory between cache1 and cache3 because by the time
    # we need cache3, we're done with cache1.
    cache13 = torch.empty(M_sum * max(N, K),
                          device=hidden_states.device,
                          dtype=hidden_states.dtype)

    intermediate_cache1 = cache13[:M_sum * N].view(M_sum, N)
    intermediate_cache2 = torch.empty((M_sum, N // 2),
                                      device=hidden_states.device,
                                      dtype=hidden_states.dtype)
    intermediate_cache3 = cache13[:M_sum * K].view(M_sum, K)

    for chunk in range(num_chunks):
        begin_chunk_idx, end_chunk_idx = (chunk * CHUNK_SIZE,
                                          min((chunk + 1) * CHUNK_SIZE,
                                              num_tokens))
        curr_hidden_states = hidden_states[begin_chunk_idx:end_chunk_idx]
        tokens_in_chunk, _ = curr_hidden_states.shape

        if tokens_in_chunk == 0:
            break

        curr_topk_ids = topk_ids[begin_chunk_idx:end_chunk_idx]
        curr_topk_weights = topk_weights[begin_chunk_idx:end_chunk_idx]

        a1q_scale: Optional[torch.Tensor] = None

        qcurr_hidden_states, a1q_scale = _fp8_quantize(curr_hidden_states,
                                                       a1_scale, block_shape)

        (qcurr_hidden_states, a1q_scale, sorted_token_ids, expert_ids,
         inv_perm) = _moe_permute(qcurr_hidden_states, a1q_scale,
                                  curr_topk_ids, global_num_experts,
                                  expert_map, top_k_num, block_m)

        # Adjust the intermediate cache size and config for the last chunk.
        # Note that in most cases we only have one chunk so the cache size
        # and config are already set correctly and do not need to be adjusted.
        if tokens_in_chunk < CHUNK_SIZE and chunk > 0:
            curr_M = sorted_token_ids.numel()
            intermediate_cache1 = _resize_cache(intermediate_cache1,
                                                (curr_M, N))
            intermediate_cache2 = _resize_cache(intermediate_cache2,
                                                (curr_M, N // 2))
            intermediate_cache3 = _resize_cache(intermediate_cache3,
                                                (curr_M, K))

        dg.m_grouped_gemm_fp8_fp8_bf16_nt_contiguous(
            (qcurr_hidden_states, a1q_scale), (w1, w1_scale),
            intermediate_cache1, expert_ids)

        if activation == "silu":
            torch.ops._C.silu_and_mul(intermediate_cache2,
                                      intermediate_cache1.view(-1, N))
        elif activation == "gelu":
            torch.ops._C.gelu_and_mul(intermediate_cache2,
                                      intermediate_cache1.view(-1, N))
        else:
            raise ValueError(f"Unsupported FusedMoe activation: {activation}")

        a2q_scale: Optional[torch.Tensor] = None

        qintermediate_cache2, a2q_scale = _fp8_quantize(
            intermediate_cache2, a2_scale, block_shape)

        dg.m_grouped_gemm_fp8_fp8_bf16_nt_contiguous(
            (qintermediate_cache2, a2q_scale), (w2, w2_scale),
            intermediate_cache3, expert_ids)

        _moe_unpermute_and_reduce(
            out_hidden_states[begin_chunk_idx:end_chunk_idx],
            intermediate_cache3.view(*intermediate_cache3.shape), inv_perm,
            top_k_num, K, curr_topk_weights)

    return out_hidden_states
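

# A torch-only sketch of the cache13 aliasing used above (the helper name and
# the sizes below are made up for illustration; nothing calls it): one flat
# buffer backs both the [M_sum, N] and [M_sum, K] workspaces, which is safe
# because intermediate_cache1 is no longer needed by the time
# intermediate_cache3 is written.
def _cache13_aliasing_sketch() -> None:
    M_sum, N, K = 256, 1024, 512
    cache13 = torch.empty(M_sum * max(N, K), dtype=torch.bfloat16)
    cache1 = cache13[:M_sum * N].view(M_sum, N)  # first grouped gemm output
    cache3 = cache13[:M_sum * K].view(M_sum, K)  # second grouped gemm output
    # Both views share the same storage, so no extra memory is allocated.
    assert cache1.data_ptr() == cache3.data_ptr()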

# TODO: make the grouped gemm kernel consistent with scaled gemm kernel
def cutlass_moe_fp8(
    a: torch.Tensor,
    w1_q: torch.Tensor,
    w2_q: torch.Tensor,
    w1_scale: torch.Tensor,
    w2_scale: torch.Tensor,
    topk_weights: torch.Tensor,
    topk_ids: torch.Tensor,
    ab_strides1: torch.Tensor,
    c_strides1: torch.Tensor,
    ab_strides2: torch.Tensor,
    c_strides2: torch.Tensor,
    a1_scale: Optional[torch.Tensor] = None,
    a2_scale: Optional[torch.Tensor] = None,
    out_dtype: torch.dtype = torch.half,
) -> torch.Tensor:
"""
|
|
||||||
This function computes a a8w8-quantized Mixture of Experts (MoE) layer
|
|
||||||
using two sets of quantized weights, w1_q and w2_q, and top-k gating
|
|
||||||
mechanism. The matrix multiplications are implemented with CUTLASS
|
|
||||||
grouped gemm.
|
|
||||||
|
|
||||||
Parameters:
|
|
||||||
- a (torch.Tensor): The input tensor to the MoE layer.
|
|
||||||
Shape: [M, K]
|
|
||||||
- w1_q (torch.Tensor): The first set of fp8-quantized expert weights.
|
|
||||||
Shape: [num_experts, K, 2N] (the weights are passed transposed)
|
|
||||||
- w2_q (torch.Tensor): The second set of fp8-quantized expert weights.
|
|
||||||
Shape: [num_experts, N, K] (the weights are passed transposed)
|
|
||||||
- w1_scale (torch.Tensor): The fp32 scale to dequantize w1_q.
|
|
||||||
Shape: [num_experts] or [num_experts, 2N]
|
|
||||||
- w2_scale (torch.Tensor): The fp32 scale to dequantize w2_q.
|
|
||||||
Shape: [num_experts] or [num_experts, K]
|
|
||||||
- gating_output (torch.Tensor): The output of the gating operation
|
|
||||||
(before softmax).
|
|
||||||
- topk_weights (torch.Tensor): The weights of each token->expert mapping.
|
|
||||||
- ab_strides1 (torch.Tensor): The input and weights strides of the first
|
|
||||||
grouped gemm.
|
|
||||||
- c_strides1 (torch.Tensor): The output strides of the first grouped gemm.
|
|
||||||
- ab_strides2 (torch.Tensor): The input and weights strides of the second
|
|
||||||
grouped gemm.
|
|
||||||
- c_strides2 (torch.Tensor): The output strides of the second grouped gemm.
|
|
||||||
- a1_scale (Optional[torch.Tensor]): The optional fp32 scale to quantize a.
|
|
||||||
Shape: scalar or [M]
|
|
||||||
- a2_scale (Optional[torch.Tensor]): The optional fp32 scale to
|
|
||||||
quantize the intermediate result between the gemms.
|
|
||||||
Shape: scalar or [M]
|
|
||||||
- out_dtype (torch.Tensor): The output tensor type.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
- torch.Tensor: The fp16 output tensor after applying the MoE layer.
|
|
||||||
"""

    assert topk_weights.shape == topk_ids.shape, "topk shape mismatch"
    assert w1_q.dtype == torch.float8_e4m3fn
    assert w2_q.dtype == torch.float8_e4m3fn
    assert a.shape[1] == w1_q.shape[1], "Hidden size mismatch w1"
    assert w1_q.shape[2] == w2_q.shape[1] * 2, "Hidden size mismatch w2"
    assert w1_q.shape[0] == w2_q.shape[0], "Expert number mismatch"
    assert (a1_scale is None or a1_scale.dim() == 0 or a1_scale.shape[0] == 1
            or a1_scale.shape[0] == a.shape[0]), "Input scale shape mismatch"
    assert (w1_scale.dim() == 1 or w1_scale.shape[1] == 1
            or w1_scale.shape[1] == w1_q.shape[2]), "W1 scale shape mismatch"
    assert (w2_scale.dim() == 1 or w2_scale.shape[1] == 1
            or w2_scale.shape[1] == w2_q.shape[2]), "W2 scale shape mismatch"
    assert w1_q.shape[0] == w2_q.shape[0], "Weights expert number mismatch"
    assert w1_q.shape[0] == w1_scale.shape[0], (
        "w1 scales expert number mismatch")
    assert w1_q.shape[0] == w2_scale.shape[0], (
        "w2 scales expert number mismatch")
    assert (a2_scale is None or a1_scale is None
            or a2_scale.shape == a1_scale.shape), (
                "Intermediate scale shape mismatch")
    assert ab_strides1.shape[0] == w1_q.shape[0], (
        "AB Strides 1 expert number mismatch")
    assert c_strides1.shape[0] == w1_q.shape[0], (
        "C Strides 1 expert number mismatch")
    assert ab_strides2.shape[0] == w2_q.shape[0], (
        "AB Strides 2 expert number mismatch")
    assert c_strides2.shape[0] == w2_q.shape[0], (
        "C Strides 2 expert number mismatch")
    assert out_dtype in [torch.half, torch.bfloat16], "Invalid output dtype"

    num_experts = w1_q.size(0)
    m = a.size(0)
    k = w1_q.size(1)
    n = w2_q.size(1)

    topk = topk_ids.size(1)
    per_act_token = a1_scale.numel() != 1 if a1_scale is not None else (
        a2_scale.numel() != 1 if a2_scale is not None else False)

    a_q, a1_scale = ops.scaled_fp8_quant(
        a, a1_scale, use_per_token_if_dynamic=per_act_token)
    device = a_q.device

    expert_offsets = torch.empty((num_experts + 1),
                                 dtype=torch.int32,
                                 device=device)
    problem_sizes1 = torch.empty((num_experts, 3),
                                 dtype=torch.int32,
                                 device=device)
    problem_sizes2 = torch.empty((num_experts, 3),
                                 dtype=torch.int32,
                                 device=device)

    a_map = torch.empty((topk_ids.numel()), dtype=torch.int32, device=device)
    c_map = torch.empty((topk_ids.numel()), dtype=torch.int32, device=device)

    ops.get_cutlass_moe_mm_data(topk_ids, expert_offsets, problem_sizes1,
                                problem_sizes2, a_map, c_map, num_experts, n,
                                k)

    rep_a_q = a_q.view(dtype=torch.uint8)[a_map].view(dtype=a_q.dtype)
    rep_a1_scales = a1_scale[a_map] if per_act_token else a1_scale

    c1 = torch.empty((m * topk, n * 2), device=device, dtype=out_dtype)
    c2 = torch.empty((m * topk, k), device=device, dtype=out_dtype)

    ops.cutlass_moe_mm(c1, rep_a_q, w1_q, rep_a1_scales, w1_scale,
                       expert_offsets[:-1], problem_sizes1, ab_strides1,
                       ab_strides1, c_strides1)

    intermediate = torch.empty((m * topk, n), device=device, dtype=out_dtype)
    torch.ops._C.silu_and_mul(intermediate, c1)

    intermediate_q, a2_scale = ops.scaled_fp8_quant(
        intermediate, a2_scale, use_per_token_if_dynamic=per_act_token)

    ops.cutlass_moe_mm(c2, intermediate_q, w2_q, a2_scale, w2_scale,
                       expert_offsets[:-1], problem_sizes2, ab_strides2,
                       ab_strides2, c_strides2)

    return (c2[c_map].view(m, topk, k) *
            topk_weights.view(m, topk, 1).to(out_dtype)).sum(dim=1)
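

# A torch-only sketch of the final reduction above (the helper name and the
# shapes are made up for illustration; nothing calls it): the grouped-gemm
# output rows are gathered back into token order via c_map, then each token's
# top-k expert outputs are combined as a weighted sum.
def _weighted_reduce_sketch() -> torch.Tensor:
    m, topk, k = 4, 2, 8
    c2 = torch.randn(m * topk, k)      # one row per (token, expert) pair
    c_map = torch.randperm(m * topk)   # stand-in for the inverse permutation
    topk_weights = torch.rand(m, topk)
    out = (c2[c_map].view(m, topk, k) *
           topk_weights.view(m, topk, 1)).sum(dim=1)
    assert out.shape == (m, k)
    return out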

243 vllm/model_executor/layers/fused_moe/moe_align_block_size.py Normal file
@ -0,0 +1,243 @@
# SPDX-License-Identifier: Apache-2.0
from typing import Optional, Tuple

import torch
import triton
import triton.language as tl

import vllm.envs as envs
from vllm import _custom_ops as ops
from vllm.utils import round_up


def ceil_div(a, b):
    return (a + b - 1) // b


@triton.jit
def moe_align_block_size_stage1(
    topk_ids_ptr,
    tokens_cnts_ptr,
    num_experts: tl.constexpr,
    numel: tl.constexpr,
    tokens_per_thread: tl.constexpr,
):
    pid = tl.program_id(0)

    start_idx = pid * tokens_per_thread

    off_c = (pid + 1) * num_experts

    for i in range(tokens_per_thread):
        if start_idx + i < numel:
            idx = tl.load(topk_ids_ptr + start_idx + i)
            token_cnt = tl.load(tokens_cnts_ptr + off_c + idx)
            tl.store(tokens_cnts_ptr + off_c + idx, token_cnt + 1)


@triton.jit
def moe_align_block_size_stage2(
    tokens_cnts_ptr,
    num_experts: tl.constexpr,
):
    pid = tl.program_id(0)

    last_cnt = 0
    for i in range(1, num_experts + 1):
        token_cnt = tl.load(tokens_cnts_ptr + i * num_experts + pid)
        last_cnt = last_cnt + token_cnt
        tl.store(tokens_cnts_ptr + i * num_experts + pid, last_cnt)


@triton.jit
def moe_align_block_size_stage3(
    total_tokens_post_pad_ptr,
    tokens_cnts_ptr,
    cumsum_ptr,
    num_experts: tl.constexpr,
    block_size: tl.constexpr,
):
    last_cumsum = 0
    off_cnt = num_experts * num_experts
    for i in range(1, num_experts + 1):
        token_cnt = tl.load(tokens_cnts_ptr + off_cnt + i - 1)
        last_cumsum = last_cumsum + tl.cdiv(token_cnt, block_size) * block_size
        tl.store(cumsum_ptr + i, last_cumsum)
    tl.store(total_tokens_post_pad_ptr, last_cumsum)


@triton.jit
def moe_align_block_size_stage4(
    topk_ids_ptr,
    sorted_token_ids_ptr,
    expert_ids_ptr,
    tokens_cnts_ptr,
    cumsum_ptr,
    num_experts: tl.constexpr,
    block_size: tl.constexpr,
    numel: tl.constexpr,
    tokens_per_thread: tl.constexpr,
):
    pid = tl.program_id(0)
    start_idx = tl.load(cumsum_ptr + pid)
    end_idx = tl.load(cumsum_ptr + pid + 1)

    for i in range(start_idx, end_idx, block_size):
        tl.store(expert_ids_ptr + i // block_size, pid)

    start_idx = pid * tokens_per_thread
    off_t = pid * num_experts

    for i in range(start_idx, tl.minimum(start_idx + tokens_per_thread,
                                         numel)):
        expert_id = tl.load(topk_ids_ptr + i)
        token_cnt = tl.load(tokens_cnts_ptr + off_t + expert_id)
        rank_post_pad = token_cnt + tl.load(cumsum_ptr + expert_id)
        tl.store(sorted_token_ids_ptr + rank_post_pad, i)
        tl.store(tokens_cnts_ptr + off_t + expert_id, token_cnt + 1)
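

# Taken together, the four stages above amount, roughly, to a counting sort of
# topk_ids by expert id:
#   stage1: per-program histogram of expert occurrences into tokens_cnts
#   stage2: per-expert prefix sum of those histograms across programs
#   stage3: block_size-padded cumulative offsets per expert (cumsum) plus the
#           padded total token count
#   stage4: write the expert id for every block and scatter each token index
#           into its padded, expert-sorted slot in sorted_token_ids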


# Triton implementation based on:
# https://github.com/sgl-project/sglang/commit/ba5112ff691d791a9e38c6c71f59324a5fcb49d0
def moe_align_block_size_triton(
    topk_ids: torch.Tensor,
    num_experts: int,
    block_size: int,
    sorted_token_ids: torch.Tensor,
    expert_ids: torch.Tensor,
    num_tokens_post_pad: torch.Tensor,
) -> None:
    numel = topk_ids.numel()
    grid = (num_experts, )
    tokens_cnts = torch.zeros((num_experts + 1, num_experts),
                              dtype=torch.int32,
                              device=topk_ids.device)
    cumsum = torch.zeros((num_experts + 1, ),
                         dtype=torch.int32,
                         device=topk_ids.device)
    tokens_per_thread = ceil_div(numel, num_experts)

    moe_align_block_size_stage1[grid](
        topk_ids,
        tokens_cnts,
        num_experts,
        numel,
        tokens_per_thread,
    )
    moe_align_block_size_stage2[grid](
        tokens_cnts,
        num_experts,
    )
    moe_align_block_size_stage3[(1, )](
        num_tokens_post_pad,
        tokens_cnts,
        cumsum,
        num_experts,
        block_size,
    )
    moe_align_block_size_stage4[grid](
        topk_ids,
        sorted_token_ids,
        expert_ids,
        tokens_cnts,
        cumsum,
        num_experts,
        block_size,
        numel,
        tokens_per_thread,
    )


def moe_align_block_size(
    topk_ids: torch.Tensor,
    block_size: int,
    num_experts: int,
    expert_map: Optional[torch.Tensor] = None,
    pad_sorted_ids: bool = False
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    Aligns the token distribution across experts to be compatible with block
    size for matrix multiplication.

    Parameters:
    - topk_ids: A tensor of shape [total_tokens, top_k] representing the
        top-k expert indices for each token.
    - block_size: The block size used in block matrix multiplication.
    - num_experts: The total number of experts.
    - expert_map: A tensor of shape [num_experts] that maps the expert index
        from the global space to the local index space of the current
        expert parallel shard. If the expert is not in the current expert
        parallel shard, the mapping is set to -1.
    - pad_sorted_ids: A flag indicating whether the length of sorted_token_ids
        should be padded to a multiple of block_size.

    Returns:
    - sorted_token_ids: A tensor containing the sorted token indices according
        to their allocated expert.
    - expert_ids: A tensor indicating the assigned expert index for each block.
    - num_tokens_post_padded: The total number of tokens after padding,
        ensuring divisibility by block_size.

    This function pads the number of tokens that each expert needs to process
    so that it is divisible by block_size.
    Padding ensures that during block matrix multiplication, the dimensions
    align correctly.

    Example:
    Given topk_ids = [[2, 3, 4], [1, 2, 4], [1, 3, 4], [1, 2, 3]],
    block_size = 4, and num_experts = 4:
    - We initially have 12 tokens (after repeating 'top_k' times) and 4
        experts, with each expert needing to process 3 tokens.
    - As block_size is 4, we pad 1 token for each expert.
    - First, flatten topk_ids to [2, 3, 4, 1, 2, 4, 1, 3, 4, 1, 2, 3].
    - Then append padding tokens [12, 12, 12, 12] for each block.
    - After sorting by expert index, we obtain token_ids
        [3, 6, 9, 12, 0, 4, 10, 12, 1, 7, 11, 12, 2, 5, 8, 12].
        Token id 12, the padding value, does not correspond to any real token
        and is ignored in the subsequent matrix multiplication.
    - The padding ensures that the total number of tokens is now divisible
        by block_size for proper block matrix operations.
    """
    max_num_tokens_padded = topk_ids.numel() + num_experts * (block_size - 1)
    if pad_sorted_ids:
        max_num_tokens_padded = round_up(max_num_tokens_padded, block_size)
    sorted_ids = torch.empty((max_num_tokens_padded, ),
                             dtype=torch.int32,
                             device=topk_ids.device)
    sorted_ids.fill_(topk_ids.numel())
    max_num_m_blocks = triton.cdiv(max_num_tokens_padded, block_size)
    # Expert ids must be zeroed out to prevent index out of bounds error while
    # mapping global expert ids to local expert ids in expert parallelism.
    expert_ids = torch.zeros((max_num_m_blocks, ),
                             dtype=torch.int32,
                             device=topk_ids.device)
    num_tokens_post_pad = torch.empty((1),
                                      dtype=torch.int32,
                                      device=topk_ids.device)
    if num_experts >= 224:
        if envs.VLLM_ENABLE_MOE_ALIGN_BLOCK_SIZE_TRITON or num_experts != 256:
            moe_align_block_size_triton(
                topk_ids,
                num_experts,
                block_size,
                sorted_ids,
                expert_ids,
                num_tokens_post_pad,
            )
        else:
            # Currently requires num_experts=256
            ops.sgl_moe_align_block_size(
                topk_ids,
                num_experts,
                block_size,
                sorted_ids,
                expert_ids,
                num_tokens_post_pad,
            )
    else:
        ops.moe_align_block_size(topk_ids, num_experts, block_size, sorted_ids,
                                 expert_ids, num_tokens_post_pad)
    if expert_map is not None:
        expert_ids = expert_map[expert_ids]

    return sorted_ids, expert_ids, num_tokens_post_pad
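

# A torch-only reference sketch of the padding and sorting described in the
# docstring above (the helper name is made up and nothing calls it; the real
# work happens in the Triton kernels or the custom ops).
def _alignment_reference_sketch() -> torch.Tensor:
    topk_ids = torch.tensor([[2, 3, 4], [1, 2, 4], [1, 3, 4], [1, 2, 3]])
    block_size = 4
    flat = topk_ids.flatten()
    pad_id = flat.numel()  # 12, the "ignore me" sentinel from the docstring
    buckets = []
    for expert in range(int(flat.max()) + 1):
        idx = (flat == expert).nonzero(as_tuple=True)[0]
        if idx.numel() == 0:
            continue
        pad = (-idx.numel()) % block_size
        buckets.append(torch.cat([idx, torch.full((pad, ), pad_id)]))
    # -> [3, 6, 9, 12, 0, 4, 10, 12, 1, 7, 11, 12, 2, 5, 8, 12]
    return torch.cat(buckets)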

48 vllm/model_executor/layers/fused_moe/utils.py Normal file
@ -0,0 +1,48 @@
# SPDX-License-Identifier: Apache-2.0
from math import prod
from typing import List, Optional, Tuple

import torch

from vllm import _custom_ops as ops
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
    per_token_group_quant_fp8)
from vllm.utils import cdiv


def _resize_cache(x: torch.Tensor, v: Tuple[int, ...]) -> torch.Tensor:
    """
    Shrink the given tensor and apply the given view to it. This is
    used to resize the intermediate fused_moe caches.
    """
    assert prod(v) <= x.numel()
    return x.flatten()[:prod(v)].view(*v)
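

# A small usage sketch of _resize_cache (the helper name below is made up and
# nothing calls it): carve a (4, 8) view out of a larger workspace without
# allocating new memory.
def _resize_cache_sketch() -> None:
    workspace = torch.empty(16, 8)
    small = _resize_cache(workspace, (4, 8))
    assert small.shape == (4, 8)
    # Same underlying storage, so no extra memory was allocated.
    assert small.data_ptr() == workspace.data_ptr()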


def _fp8_quantize(
    A: torch.Tensor,
    A_scale: Optional[torch.Tensor],
    block_shape: Optional[List[int]],
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Perform fp8 quantization on the inputs. If a block_shape
    is provided, the output will be blocked.
    """
    if block_shape is None:
        A, A_scale = ops.scaled_fp8_quant(A, A_scale)
    else:
        assert len(block_shape) == 2
        _, block_k = block_shape[0], block_shape[1]
        A, A_scale = per_token_group_quant_fp8(A, block_k)
        assert cdiv(A.shape[-1], block_k) == A_scale.shape[-1]
    return A, A_scale
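

# A usage sketch for _fp8_quantize (the helper name below is made up and
# nothing calls it; running it needs a CUDA device with fp8 support).
def _fp8_quantize_sketch() -> None:
    a = torch.randn(64, 256, dtype=torch.bfloat16, device="cuda")
    # Per-tensor scale when no block_shape is given.
    _, a_scale = _fp8_quantize(a, None, block_shape=None)
    assert a_scale.numel() == 1
    # One scale per 128-wide group along the last dim when block_shape is set.
    _, block_scale = _fp8_quantize(a, None, block_shape=[128, 128])
    assert block_scale.shape[-1] == cdiv(a.shape[-1], 128)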


def _fp8_perm(m: torch.Tensor, idx: torch.Tensor) -> torch.Tensor:
    """
    A permutation routine that works on fp8 types.
    """
    if torch.is_floating_point(m) and torch.finfo(m.dtype).bits == 8:
        return m.view(dtype=torch.uint8)[idx, ...].view(dtype=m.dtype)
    else:
        return m[idx, ...]
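

# A small sketch of why the uint8 round-trip above is needed: advanced
# indexing is not implemented for fp8 dtypes in some torch builds, so the
# helper permutes the raw bytes instead (the helper name below is made up and
# nothing calls it).
def _fp8_perm_sketch() -> None:
    m = torch.randn(4, 8).to(torch.float8_e4m3fn)
    idx = torch.tensor([2, 0, 3, 1])
    out = _fp8_perm(m, idx)
    assert out.dtype == torch.float8_e4m3fn
    assert out.shape == m.shape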