mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-04-12 16:57:09 +08:00
stuff
Signed-off-by: Bill Nell <bnell@redhat.com>
This commit is contained in:
parent
f8510587c2
commit
909f234faa
@ -1,30 +1,26 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
import triton.language as tl
|
||||
from typing import Optional
|
||||
|
||||
import vllm._custom_ops as ops
|
||||
from vllm.config import VllmConfig, set_current_vllm_config
|
||||
from vllm.model_executor.layers.activation import SiluAndMul
|
||||
from vllm.model_executor.layers.fused_moe.fused_batched_moe import (
|
||||
invoke_moe_batched_triton_kernel,
|
||||
BatchedExperts,
|
||||
BatchedPrepareAndFinalize,
|
||||
BatchedTritonExperts)
|
||||
from vllm.model_executor.layers.fused_moe.fused_moe import (fused_topk,
|
||||
get_default_config)
|
||||
BatchedPrepareAndFinalize, BatchedTritonExperts,
|
||||
invoke_moe_batched_triton_kernel)
|
||||
from vllm.model_executor.layers.fused_moe.fused_moe import fused_topk
|
||||
from vllm.model_executor.layers.fused_moe.modular_kernel import (
|
||||
FusedMoEModularKernel)
|
||||
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
|
||||
per_token_group_quant_fp8, w8a8_block_fp8_matmul)
|
||||
per_token_group_quant_fp8)
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.utils import round_up
|
||||
|
||||
|
||||
NUM_EXPERTS = [8, 64]
|
||||
TOP_KS = [1, 2, 6]
|
||||
|
||||
@ -80,10 +76,12 @@ class BatchedMMTensors:
|
||||
return BatchedMMTensors(A, B, C, num_expert_tokens)
|
||||
|
||||
|
||||
def native_w8a8_block_matmul(A: torch.Tensor, B: torch.Tensor,
|
||||
As: torch.Tensor, Bs: torch.Tensor,
|
||||
def native_w8a8_block_matmul(A: torch.Tensor,
|
||||
B: torch.Tensor,
|
||||
As: torch.Tensor,
|
||||
Bs: torch.Tensor,
|
||||
block_size,
|
||||
output_dtype = torch.bfloat16):
|
||||
output_dtype=torch.bfloat16):
|
||||
"""This function performs matrix multiplication with block-wise
|
||||
quantization using native torch.
|
||||
It is agnostic to the input data type and can be used for both int8 and
|
||||
@ -160,16 +158,11 @@ def ref_impl(
|
||||
if A.dtype == torch.torch.float8_e4m3fn:
|
||||
if False:
|
||||
tmp = native_w8a8_block_matmul(A[e, :, :],
|
||||
B[e].transpose(0, 1),
|
||||
A_scale,
|
||||
B_scale,
|
||||
block_shape)
|
||||
B[e].transpose(0, 1), A_scale,
|
||||
B_scale, block_shape)
|
||||
else:
|
||||
tmp = ops.cutlass_scaled_mm(A[e, :, :],
|
||||
B[e].transpose(0, 1),
|
||||
A_scale,
|
||||
B_scale,
|
||||
torch.bfloat16)
|
||||
tmp = ops.cutlass_scaled_mm(A[e, :, :], B[e].transpose(0, 1),
|
||||
A_scale, B_scale, torch.bfloat16)
|
||||
C[e, :num_tokens, :] = tmp[:num_tokens, :]
|
||||
else:
|
||||
C[e, :num_tokens, :] = A[e, :num_tokens, :] @ B[e].transpose(0, 1)
|
||||
@ -195,7 +188,8 @@ def test_batched_mm(num_experts: int, max_tokens_per_expert: int, K: int,
|
||||
in_dtype = dtype
|
||||
out_dtype = dtype
|
||||
|
||||
config = BatchedMMConfig(in_dtype, out_dtype, num_experts, max_tokens_per_expert, K, N)
|
||||
config = BatchedMMConfig(in_dtype, out_dtype, num_experts,
|
||||
max_tokens_per_expert, K, N)
|
||||
tensors = BatchedMMTensors.make_tensors(config)
|
||||
|
||||
test_output = tensors.C
|
||||
@ -209,7 +203,7 @@ def test_batched_mm(num_experts: int, max_tokens_per_expert: int, K: int,
|
||||
}[test_output.dtype]
|
||||
|
||||
use_fp8_w8a8 = dtype == torch.torch.float8_e4m3fn
|
||||
block_shape = [16, 16, 32] # 16 for k if not fp8
|
||||
block_shape = [16, 16, 32] # 16 for k if not fp8
|
||||
|
||||
#print(f"tensors.A {tensors.A.shape}")
|
||||
#print(f"tensors.B {tensors.B.shape}")
|
||||
@ -250,19 +244,12 @@ def test_batched_mm(num_experts: int, max_tokens_per_expert: int, K: int,
|
||||
|
||||
ref_output = ref_output.to(dtype=out_dtype)
|
||||
ref_output = ref_impl(tensors.A.to(dtype=out_dtype),
|
||||
tensors.B.to(dtype=out_dtype),
|
||||
ref_output,
|
||||
tensors.num_expert_tokens,
|
||||
A_scale,
|
||||
B_scale,
|
||||
tensors.B.to(dtype=out_dtype), ref_output,
|
||||
tensors.num_expert_tokens, A_scale, B_scale,
|
||||
block_shape[-2:])
|
||||
|
||||
ref_output2 = ref_impl(tensors.A,
|
||||
tensors.B,
|
||||
ref_output2,
|
||||
tensors.num_expert_tokens,
|
||||
A_scale,
|
||||
B_scale,
|
||||
ref_output2 = ref_impl(tensors.A, tensors.B, ref_output2,
|
||||
tensors.num_expert_tokens, A_scale, B_scale,
|
||||
block_shape[-2:])
|
||||
|
||||
rtol, atol = {
|
||||
@ -286,11 +273,17 @@ def batched_moe(
|
||||
use_fp8_w8a8: bool = False,
|
||||
block_shape: Optional[list[int]] = None,
|
||||
) -> torch.Tensor:
|
||||
max_num_tokens = round_up(a.shape[0], 64) # ?
|
||||
max_num_tokens = round_up(a.shape[0], 64) # ?
|
||||
fused_experts = FusedMoEModularKernel(
|
||||
BatchedPrepareAndFinalize(max_num_tokens, world_size=1, dp_size=1, rank=0, use_fp8_w8a8=use_fp8_w8a8,
|
||||
BatchedPrepareAndFinalize(max_num_tokens,
|
||||
world_size=1,
|
||||
dp_size=1,
|
||||
rank=0,
|
||||
use_fp8_w8a8=use_fp8_w8a8,
|
||||
block_shape=block_shape),
|
||||
BatchedTritonExperts(max_num_tokens=max_num_tokens, dp_size=1, world_size=1,
|
||||
BatchedTritonExperts(max_num_tokens=max_num_tokens,
|
||||
dp_size=1,
|
||||
world_size=1,
|
||||
use_fp8_w8a8=use_fp8_w8a8,
|
||||
block_shape=block_shape))
|
||||
|
||||
@ -322,11 +315,13 @@ def torch_moe2(
|
||||
|
||||
if use_fp8_w8a8:
|
||||
a, a_scale = per_token_group_quant_fp8(a, block_shape[1])
|
||||
#print(f"a_scale {a_scale.shape}")
|
||||
else:
|
||||
a_scale = None
|
||||
|
||||
out = torch.zeros(M * topk, w2.shape[1], dtype=torch.bfloat16, device=a.device)
|
||||
out = torch.zeros(M * topk,
|
||||
w2.shape[1],
|
||||
dtype=torch.bfloat16,
|
||||
device=a.device)
|
||||
num_experts = w1.shape[0]
|
||||
for i in range(num_experts):
|
||||
mask = (topk_ids == i).view(-1)
|
||||
@ -341,11 +336,8 @@ def torch_moe2(
|
||||
# a_scale[mask],
|
||||
# w1_scale[i],
|
||||
# torch.bfloat16)
|
||||
tmp1 = native_w8a8_block_matmul(a[mask],
|
||||
w1[i],
|
||||
a_scale[mask],
|
||||
w1_scale[i],
|
||||
block_shape,
|
||||
tmp1 = native_w8a8_block_matmul(a[mask], w1[i], a_scale[mask],
|
||||
w1_scale[i], block_shape,
|
||||
torch.bfloat16)
|
||||
tmp2 = SiluAndMul()(tmp1)
|
||||
tmp2, b_scale = per_token_group_quant_fp8(tmp2, block_shape[1])
|
||||
@ -355,11 +347,8 @@ def torch_moe2(
|
||||
# b_scale,
|
||||
# w2_scale[i],
|
||||
# torch.bfloat16)
|
||||
out[mask] = native_w8a8_block_matmul(tmp2,
|
||||
w2[i],
|
||||
b_scale,
|
||||
w2_scale[i],
|
||||
block_shape,
|
||||
out[mask] = native_w8a8_block_matmul(tmp2, w2[i], b_scale,
|
||||
w2_scale[i], block_shape,
|
||||
torch.bfloat16)
|
||||
|
||||
return (out.view(M, -1, w2.shape[1]) *
|
||||
@ -406,23 +395,21 @@ def test_fused_moe_batched_experts(
|
||||
|
||||
factor_for_scale = 1e-2
|
||||
w1_s = torch.rand(
|
||||
(e, n_tiles_w1, k_tiles_w1), dtype=torch.float32, device="cuda") * factor_for_scale
|
||||
(e, n_tiles_w1, k_tiles_w1), dtype=torch.float32,
|
||||
device="cuda") * factor_for_scale
|
||||
w2_s = torch.rand(
|
||||
(e, n_tiles_w2, k_tiles_w2), dtype=torch.float32, device="cuda") * factor_for_scale
|
||||
(e, n_tiles_w2, k_tiles_w2), dtype=torch.float32,
|
||||
device="cuda") * factor_for_scale
|
||||
else:
|
||||
w1_s = None
|
||||
w2_s = None
|
||||
|
||||
with set_current_vllm_config(vllm_config):
|
||||
topk_weight, topk_ids, _ = fused_topk(a, score, topk, False)
|
||||
baseline_output = torch_moe2(a, w1, w2, topk_weight, topk_ids, w1_s, w2_s, use_fp8_w8a8, block_shape)
|
||||
batched_output = batched_moe(a, w1, w2, topk_weight, topk_ids, w1_s, w2_s, use_fp8_w8a8, block_shape)
|
||||
# batched_output = batched_moe(a,
|
||||
# w1.to(torch.bfloat16),
|
||||
# w2.to(torch.bfloat16),
|
||||
# topk_weight, topk_ids,
|
||||
# w1_s, w2_s, False,
|
||||
# block_shape)
|
||||
baseline_output = torch_moe2(a, w1, w2, topk_weight, topk_ids, w1_s,
|
||||
w2_s, use_fp8_w8a8, block_shape)
|
||||
batched_output = batched_moe(a, w1, w2, topk_weight, topk_ids, w1_s,
|
||||
w2_s, use_fp8_w8a8, block_shape)
|
||||
|
||||
torch.testing.assert_close(baseline_output,
|
||||
batched_output,
|
||||
|
||||
@ -9,47 +9,44 @@ import triton.language as tl
|
||||
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
|
||||
from vllm.model_executor.layers.fused_moe.fused_moe import (
|
||||
get_config_dtype_str, try_get_optimal_moe_config)
|
||||
from vllm.model_executor.layers.fused_moe.utils import _resize_cache
|
||||
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
|
||||
per_token_group_quant_fp8)
|
||||
from vllm.model_executor.layers.fused_moe.utils import (_fp8_quantize,
|
||||
_resize_cache,
|
||||
cdiv)
|
||||
|
||||
|
||||
@triton.jit
|
||||
def moe_mmk(
|
||||
a_ptrs,
|
||||
b_ptrs,
|
||||
K,
|
||||
expert_id,
|
||||
a_scale_ptr,
|
||||
b_scale_ptr,
|
||||
# The stride variables represent how much to increase the ptr by when
|
||||
# moving by 1 element in a particular dimension. E.g. `stride_am` is
|
||||
# how much to increase `a_ptr` by to get the element one row down
|
||||
# (A has M rows).
|
||||
stride_ak,
|
||||
stride_bk,
|
||||
stride_asm,
|
||||
stride_ask,
|
||||
stride_bse,
|
||||
stride_bsk,
|
||||
stride_bsn,
|
||||
# Offsets and masks
|
||||
offs_m,
|
||||
offs_n,
|
||||
mask_m,
|
||||
# Block size for block-wise quantization
|
||||
group_n: tl.constexpr,
|
||||
group_k: tl.constexpr,
|
||||
# Meta-parameters
|
||||
BLOCK_M: tl.constexpr,
|
||||
BLOCK_N: tl.constexpr,
|
||||
BLOCK_K: tl.constexpr,
|
||||
compute_type: tl.constexpr,
|
||||
use_w8a8: tl.constexpr,
|
||||
use_w8a16: tl.constexpr
|
||||
):
|
||||
a_ptrs,
|
||||
b_ptrs,
|
||||
K,
|
||||
expert_id,
|
||||
a_scale_ptr,
|
||||
b_scale_ptr,
|
||||
# The stride variables represent how much to increase the ptr by when
|
||||
# moving by 1 element in a particular dimension. E.g. `stride_am` is
|
||||
# how much to increase `a_ptr` by to get the element one row down
|
||||
# (A has M rows).
|
||||
stride_ak,
|
||||
stride_bk,
|
||||
stride_asm,
|
||||
stride_ask,
|
||||
stride_bse,
|
||||
stride_bsk,
|
||||
stride_bsn,
|
||||
# Offsets and masks
|
||||
offs_m,
|
||||
offs_n,
|
||||
mask_m,
|
||||
# Block size for block-wise quantization
|
||||
group_n: tl.constexpr,
|
||||
group_k: tl.constexpr,
|
||||
# Meta-parameters
|
||||
BLOCK_M: tl.constexpr,
|
||||
BLOCK_N: tl.constexpr,
|
||||
BLOCK_K: tl.constexpr,
|
||||
compute_type: tl.constexpr,
|
||||
use_w8a8: tl.constexpr,
|
||||
use_w8a16: tl.constexpr):
|
||||
offs_k = tl.arange(0, BLOCK_K)
|
||||
|
||||
if use_w8a16:
|
||||
@ -313,22 +310,21 @@ def batched_triton_kernel(
|
||||
|
||||
|
||||
def invoke_moe_batched_triton_kernel(
|
||||
A: torch.Tensor, # [E, max_tokens, K]
|
||||
B: torch.Tensor, # [E, K, N]
|
||||
C: torch.Tensor, # [E, max_tokens, N]
|
||||
expert_num_tokens: torch.Tensor, # [E]
|
||||
compute_type: tl.dtype,
|
||||
# Quantization data
|
||||
A_scale: Optional[torch.Tensor],
|
||||
B_scale: Optional[torch.Tensor],
|
||||
B_zp: torch.Tensor,
|
||||
# Quantization schemes
|
||||
use_fp8_w8a8: bool,
|
||||
use_int8_w8a16: bool,
|
||||
use_int4_w4a16: bool,
|
||||
config: dict[str, int],
|
||||
block_shape: Optional[list[int]] = None
|
||||
):
|
||||
A: torch.Tensor, # [E, max_tokens, K]
|
||||
B: torch.Tensor, # [E, K, N]
|
||||
C: torch.Tensor, # [E, max_tokens, N]
|
||||
expert_num_tokens: torch.Tensor, # [E]
|
||||
compute_type: tl.dtype,
|
||||
# Quantization data
|
||||
A_scale: Optional[torch.Tensor],
|
||||
B_scale: Optional[torch.Tensor],
|
||||
B_zp: torch.Tensor,
|
||||
# Quantization schemes
|
||||
use_fp8_w8a8: bool,
|
||||
use_int8_w8a16: bool,
|
||||
use_int4_w4a16: bool,
|
||||
config: dict[str, int],
|
||||
block_shape: Optional[list[int]] = None):
|
||||
assert not use_int4_w4a16
|
||||
max_num_tokens = A.size(1)
|
||||
K = A.size(2)
|
||||
@ -392,8 +388,12 @@ class BatchedPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
|
||||
that the PPLX dispatch/combine kernels use.
|
||||
"""
|
||||
|
||||
def __init__(self, max_num_tokens: Optional[int], world_size: int,
|
||||
dp_size: int, rank: int, use_fp8_w8a8: bool = False,
|
||||
def __init__(self,
|
||||
max_num_tokens: Optional[int],
|
||||
world_size: int,
|
||||
dp_size: int,
|
||||
rank: int,
|
||||
use_fp8_w8a8: bool = False,
|
||||
block_shape: Optional[list[int]] = None):
|
||||
super().__init__()
|
||||
self.world_size = world_size
|
||||
@ -463,13 +463,12 @@ class BatchedPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
|
||||
for expert_id in range(first_expert, last_expert):
|
||||
topks = torch.any(topk_ids == expert_id, dim=1).flatten()
|
||||
rows = torch.count_nonzero(topks.flatten())
|
||||
|
||||
rhs = a1[:topks.numel()][topks]
|
||||
idx = expert_id - first_expert
|
||||
if self.use_fp8_w8a8:
|
||||
# TODO: use _fp8_quantize
|
||||
b_a1[idx, :rows, :], tmp_scale = per_token_group_quant_fp8(rhs, block_k)
|
||||
b_a1_scale[idx, :rows] = tmp_scale # inline?
|
||||
b_a1[idx, :rows, :], b_a1_scale[
|
||||
idx, :rows] = per_token_group_quant_fp8(rhs, block_k)
|
||||
else:
|
||||
b_a1[idx, :rows, :] = rhs
|
||||
|
||||
@ -549,7 +548,6 @@ class BatchedExperts(mk.FusedMoEPermuteExpertsUnpermute):
|
||||
num_dp = self.world_size // self.dp_size
|
||||
max_num_tokens = a.size(
|
||||
0) if self.max_num_tokens is None else self.max_num_tokens
|
||||
#print(f"WORKSPACE {max_num_tokens} {num_dp}")
|
||||
workspace13 = num_experts * max_num_tokens * num_dp * K
|
||||
workspace2 = max_num_tokens * num_dp * N
|
||||
return (workspace13, workspace2, a.dtype)
|
||||
@ -607,9 +605,10 @@ class BatchedExperts(mk.FusedMoEPermuteExpertsUnpermute):
|
||||
num = int(expert_num_tokens[expert].item())
|
||||
tmp = _resize_cache(workspace2, (num, N))
|
||||
if self.use_fp8_w8a8:
|
||||
assert False # TBD
|
||||
assert False # TBD
|
||||
else:
|
||||
input = hidden_states[expert, :num, :] @ w1[expert].transpose(0, 1)
|
||||
input = hidden_states[expert, :num, :] @ w1[expert].transpose(
|
||||
0, 1)
|
||||
self.activation(activation, tmp, input)
|
||||
out[expert, :num, :] = tmp @ w2[expert].transpose(0, 1)
|
||||
|
||||
@ -768,12 +767,13 @@ class BatchedTritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
|
||||
#assert not self.use_fp8_w8a8
|
||||
if self.use_fp8_w8a8:
|
||||
per_act_token = False
|
||||
qintermediate_cache2 = torch.zeros_like(intermediate_cache2,
|
||||
# TODO: reuse?
|
||||
qintermediate_cache2 = torch.empty_like(intermediate_cache2,
|
||||
dtype=torch.float8_e4m3fn)
|
||||
block_n = self.block_shape[0]
|
||||
n_tiles = ((N // 2) + block_n - 1) // block_n
|
||||
scale_shape = (E, num_tokens, n_tiles)
|
||||
a2q_scale = torch.zeros(scale_shape,
|
||||
a2q_scale = torch.empty(scale_shape,
|
||||
dtype=torch.float32,
|
||||
device=hidden_states.device)
|
||||
for e in range(E):
|
||||
@ -783,10 +783,10 @@ class BatchedTritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
|
||||
# intermediate_cache2[e],
|
||||
# a2_scale[e] if a2_scale is not None else None,
|
||||
# per_act_token, self.block_shape)
|
||||
qintermediate_cache2[e, :num_tokens, :], tmp_scale = per_token_group_quant_fp8(
|
||||
intermediate_cache2[e, :num_tokens], block_n)
|
||||
#print(a2q_scale[e, :tmp_scale.shape[0]].shape)
|
||||
#print(tmp_scale.shape)
|
||||
qintermediate_cache2[
|
||||
e, :
|
||||
num_tokens, :], tmp_scale = per_token_group_quant_fp8(
|
||||
intermediate_cache2[e, :num_tokens], block_n)
|
||||
a2q_scale[e, :tmp_scale.shape[0]] = tmp_scale
|
||||
else:
|
||||
qintermediate_cache2 = intermediate_cache2
|
||||
|
||||
@ -1240,8 +1240,6 @@ class FusedMoE(torch.nn.Module):
|
||||
if indices_type is not None:
|
||||
topk_ids = topk_ids.to(dtype=indices_type)
|
||||
|
||||
assert topk_ids.dtype == indices_type
|
||||
|
||||
return topk_weights, topk_ids
|
||||
|
||||
def must_reduce_shared_expert_outputs(self) -> bool:
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user