mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-04-12 08:27:02 +08:00
tests + fix
Signed-off-by: Bill Nell <bnell@redhat.com>
This commit is contained in:
parent
9cfebf51ba
commit
f8510587c2
@ -7,8 +7,30 @@ import torch
|
||||
import triton.language as tl
|
||||
from typing import Optional
|
||||
|
||||
import vllm._custom_ops as ops
|
||||
from vllm.config import VllmConfig, set_current_vllm_config
|
||||
from vllm.model_executor.layers.activation import SiluAndMul
|
||||
from vllm.model_executor.layers.fused_moe.fused_batched_moe import (
|
||||
invoke_moe_batched_triton_kernel)
|
||||
invoke_moe_batched_triton_kernel,
|
||||
BatchedExperts,
|
||||
BatchedPrepareAndFinalize,
|
||||
BatchedTritonExperts)
|
||||
from vllm.model_executor.layers.fused_moe.fused_moe import (fused_topk,
|
||||
get_default_config)
|
||||
from vllm.model_executor.layers.fused_moe.modular_kernel import (
|
||||
FusedMoEModularKernel)
|
||||
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
|
||||
per_token_group_quant_fp8, w8a8_block_fp8_matmul)
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.utils import round_up
|
||||
|
||||
|
||||
NUM_EXPERTS = [8, 64]
|
||||
TOP_KS = [1, 2, 6]
|
||||
|
||||
vllm_config = VllmConfig()
|
||||
vllm_config.scheduler_config.max_num_seqs = 128
|
||||
vllm_config.scheduler_config.max_model_len = 8192
|
||||
|
||||
|
||||
@dataclass
|
||||
@ -141,14 +163,13 @@ def ref_impl(
|
||||
B[e].transpose(0, 1),
|
||||
A_scale,
|
||||
B_scale,
|
||||
[1,1])#block_shape)
|
||||
block_shape)
|
||||
else:
|
||||
import vllm._custom_ops as ops
|
||||
tmp = ops.cutlass_scaled_mm(A[e, :, :],
|
||||
B[e].transpose(0, 1),
|
||||
A_scale,
|
||||
B_scale,
|
||||
C.dtype)
|
||||
torch.bfloat16)
|
||||
C[e, :num_tokens, :] = tmp[:num_tokens, :]
|
||||
else:
|
||||
C[e, :num_tokens, :] = A[e, :num_tokens, :] @ B[e].transpose(0, 1)
|
||||
@ -194,8 +215,9 @@ def test_batched_mm(num_experts: int, max_tokens_per_expert: int, K: int,
|
||||
#print(f"tensors.B {tensors.B.shape}")
|
||||
|
||||
if use_fp8_w8a8:
|
||||
#A_scale = torch.ones((max_tokens_per_expert,K), dtype=torch.float32, device=tensors.A.device)
|
||||
#A_scale = torch.ones((1, K), dtype=torch.float32, device=tensors.A.device)
|
||||
#B_scale = torch.ones((N, K), dtype=torch.float32, device=tensors.A.device)
|
||||
#quant_block_shape = [N, K]
|
||||
A_scale = torch.ones(1, dtype=torch.float32, device=tensors.A.device)
|
||||
B_scale = torch.ones(1, dtype=torch.float32, device=tensors.B.device)
|
||||
quant_block_shape = [1, 1]
|
||||
@ -251,3 +273,158 @@ def test_batched_mm(num_experts: int, max_tokens_per_expert: int, K: int,
|
||||
|
||||
torch.testing.assert_close(ref_output, ref_output2, atol=atol, rtol=rtol)
|
||||
torch.testing.assert_close(test_output, ref_output2, atol=atol, rtol=rtol)
|
||||
|
||||
|
||||
def batched_moe(
    a: torch.Tensor,
    w1: torch.Tensor,
    w2: torch.Tensor,
    topk_weight: torch.Tensor,
    topk_ids: torch.Tensor,
    w1_scale: Optional[torch.Tensor] = None,
    w2_scale: Optional[torch.Tensor] = None,
    use_fp8_w8a8: bool = False,
    block_shape: Optional[list[int]] = None,
) -> torch.Tensor:
    """Run the batched-experts MoE path through a modular kernel.

    Builds a ``FusedMoEModularKernel`` from a ``BatchedPrepareAndFinalize``
    stage and a ``BatchedTritonExperts`` stage, both configured for a single
    rank (``world_size=1``, ``dp_size=1``, ``rank=0``), and invokes it on
    the given activations, weights and routing tensors.

    Args:
        a: Input activations; ``a.shape[0]`` is used as the token count.
        w1: First expert weight tensor, forwarded to the kernel unchanged.
        w2: Second expert weight tensor, forwarded to the kernel unchanged.
        topk_weight: Routing weights, forwarded to the kernel unchanged.
        topk_ids: Selected expert ids, forwarded to the kernel unchanged.
        w1_scale: Optional scale for ``w1``, forwarded as a keyword.
        w2_scale: Optional scale for ``w2``, forwarded as a keyword.
        use_fp8_w8a8: Passed to both kernel stages.
        block_shape: Passed to both kernel stages.

    Returns:
        Whatever the assembled modular kernel returns for these inputs.
    """
    # Per-expert token capacity derived from the token count via
    # round_up(..., 64).
    # NOTE(review): the original author marked this choice with "?";
    # confirm that 64 is the intended granularity.
    token_capacity = round_up(a.shape[0], 64)

    prepare_finalize = BatchedPrepareAndFinalize(
        token_capacity,
        world_size=1,
        dp_size=1,
        rank=0,
        use_fp8_w8a8=use_fp8_w8a8,
        block_shape=block_shape,
    )
    experts = BatchedTritonExperts(
        max_num_tokens=token_capacity,
        dp_size=1,
        world_size=1,
        use_fp8_w8a8=use_fp8_w8a8,
        block_shape=block_shape,
    )
    kernel = FusedMoEModularKernel(prepare_finalize, experts)

    return kernel(
        a,
        w1,
        w2,
        topk_weight,
        topk_ids,
        w1_scale=w1_scale,
        w2_scale=w2_scale,
    )
|
||||
|
||||
|
||||
# Note: same as torch_moe but with fused_topk factored out.
|
||||
def torch_moe2(
    a: torch.Tensor,
    w1: torch.Tensor,
    w2: torch.Tensor,
    topk_weight: torch.Tensor,
    topk_ids: torch.Tensor,
    w1_scale: Optional[torch.Tensor] = None,
    w2_scale: Optional[torch.Tensor] = None,
    use_fp8_w8a8: bool = False,
    block_shape: Optional[list[int]] = None,
) -> torch.Tensor:
    """Reference MoE computed expert-by-expert with precomputed routing.

    Each input row is replicated once per top-k slot. For every expert, the
    rows routed to it go through ``w1``, ``SiluAndMul`` and ``w2``, and the
    per-slot results are combined with ``topk_weight``.

    Args:
        a: Activations of shape ``(M, K)``.
        w1: Per-expert first-layer weights; ``w1.shape[0]`` is the number
            of experts.
        w2: Per-expert second-layer weights; ``w2.shape[1]`` is the output
            width.
        topk_weight: Routing weights, viewed as ``(M, topk, 1)``.
        topk_ids: Expert ids; ``topk_ids.shape[1]`` is ``topk``.
        w1_scale: Per-expert scales for ``w1`` (fp8 path only).
        w2_scale: Per-expert scales for ``w2`` (fp8 path only).
        use_fp8_w8a8: Select the block-quantized fp8 path.
        block_shape: Quantization block shape; ``block_shape[1]`` is used as
            the activation group size on the fp8 path.

    Returns:
        Tensor of shape ``(M, w2.shape[1])`` in bfloat16.
    """
    M, K = a.shape
    topk = topk_ids.shape[1]
    out_width = w2.shape[1]

    # Replicate each token `topk` times: rows [t * topk, (t + 1) * topk)
    # all hold token t, lining up with the flattened topk_ids.
    a = a.view(M, -1, K).repeat(1, topk, 1).reshape(-1, K)

    a_scale = None
    if use_fp8_w8a8:
        a, a_scale = per_token_group_quant_fp8(a, block_shape[1])

    out = torch.zeros(M * topk,
                      out_width,
                      dtype=torch.bfloat16,
                      device=a.device)

    for expert in range(w1.shape[0]):
        mask = (topk_ids == expert).view(-1)
        if not mask.sum():
            # No row was routed to this expert.
            continue

        if use_fp8_w8a8:
            gate_up = native_w8a8_block_matmul(a[mask], w1[expert],
                                               a_scale[mask],
                                               w1_scale[expert], block_shape,
                                               torch.bfloat16)
            activated = SiluAndMul()(gate_up)
            # Re-quantize the activation output before the second matmul.
            activated, act_scale = per_token_group_quant_fp8(
                activated, block_shape[1])
            out[mask] = native_w8a8_block_matmul(activated, w2[expert],
                                                 act_scale, w2_scale[expert],
                                                 block_shape, torch.bfloat16)
        else:
            gate_up = a[mask] @ w1[expert].transpose(0, 1)
            activated = SiluAndMul()(gate_up)
            out[mask] = activated @ w2[expert].transpose(0, 1)

    weights = topk_weight.view(M, -1, 1).to(out.dtype)
    return (out.view(M, -1, out_width) * weights).sum(dim=1)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("m", [1, 33, 64, 222])
@pytest.mark.parametrize("n", [128, 1024, 2048])
@pytest.mark.parametrize("k", [128, 512, 1024])
@pytest.mark.parametrize("e", NUM_EXPERTS)
@pytest.mark.parametrize("topk", TOP_KS)
@pytest.mark.parametrize("dtype", [torch.float8_e4m3fn, torch.bfloat16])
def test_fused_moe_batched_experts(
    m: int,
    n: int,
    k: int,
    e: int,
    topk: int,
    dtype: torch.dtype,
):
    """Compare the batched MoE kernel path with the torch reference.

    Inputs are generated in bfloat16 on CUDA. When ``dtype`` is
    ``torch.float8_e4m3fn`` the weights are clamped and cast to fp8, and
    random per-block weight scales are created for a fixed 128x128 block
    shape. Both ``torch_moe2`` and ``batched_moe`` receive the same routing
    (from ``fused_topk``) and their outputs must agree within ``atol=2e-2``.
    """
    current_platform.seed_everything(7)
    block_shape = [128, 128]

    # The order of the random draws below is kept fixed so the seeded data
    # stays reproducible.
    a = torch.randn((m, k), device="cuda", dtype=torch.bfloat16) / 10
    w1 = torch.randn((e, 2 * n, k), device="cuda", dtype=torch.bfloat16) / 10
    w2 = torch.randn((e, k, n), device="cuda", dtype=torch.bfloat16) / 10
    score = torch.randn((m, e), device="cuda", dtype=torch.bfloat16)

    use_fp8_w8a8 = dtype == torch.float8_e4m3fn

    if use_fp8_w8a8:
        block_n, block_k = block_shape[0], block_shape[1]
        # Number of scale tiles along each weight dimension (ceil division).
        n_tiles_w1 = (2 * n + block_n - 1) // block_n
        n_tiles_w2 = (k + block_n - 1) // block_n
        k_tiles_w1 = (k + block_k - 1) // block_k
        k_tiles_w2 = (n + block_k - 1) // block_k

        finfo = torch.finfo(dtype)
        fp8_min = finfo.min
        fp8_max = finfo.max

        # Clamp to the representable fp8 range before casting.
        w1 = w1.clamp(min=fp8_min, max=fp8_max).to(dtype)
        w2 = w2.clamp(min=fp8_min, max=fp8_max).to(dtype)

        factor_for_scale = 1e-2
        w1_s = torch.rand((e, n_tiles_w1, k_tiles_w1),
                          dtype=torch.float32,
                          device="cuda") * factor_for_scale
        w2_s = torch.rand((e, n_tiles_w2, k_tiles_w2),
                          dtype=torch.float32,
                          device="cuda") * factor_for_scale
    else:
        w1_s = None
        w2_s = None

    with set_current_vllm_config(vllm_config):
        topk_weight, topk_ids, _ = fused_topk(a, score, topk, False)
        baseline_output = torch_moe2(a, w1, w2, topk_weight, topk_ids, w1_s,
                                     w2_s, use_fp8_w8a8, block_shape)
        batched_output = batched_moe(a, w1, w2, topk_weight, topk_ids, w1_s,
                                     w2_s, use_fp8_w8a8, block_shape)

    torch.testing.assert_close(baseline_output,
                               batched_output,
                               atol=2e-2,
                               rtol=0)
|
||||
|
||||
@ -9,8 +9,11 @@ import triton.language as tl
|
||||
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
|
||||
from vllm.model_executor.layers.fused_moe.fused_moe import (
|
||||
get_config_dtype_str, try_get_optimal_moe_config)
|
||||
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
|
||||
per_token_group_quant_fp8)
|
||||
from vllm.model_executor.layers.fused_moe.utils import (_fp8_quantize,
|
||||
_resize_cache)
|
||||
_resize_cache,
|
||||
cdiv)
|
||||
|
||||
|
||||
@triton.jit
|
||||
@ -390,12 +393,15 @@ class BatchedPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
|
||||
"""
|
||||
|
||||
def __init__(self, max_num_tokens: Optional[int], world_size: int,
|
||||
dp_size: int, rank: int):
|
||||
dp_size: int, rank: int, use_fp8_w8a8: bool = False,
|
||||
block_shape: Optional[list[int]] = None):
|
||||
super().__init__()
|
||||
self.world_size = world_size
|
||||
self.dp_size = dp_size
|
||||
self.rank = rank
|
||||
self.max_num_tokens = max_num_tokens
|
||||
self.use_fp8_w8a8 = use_fp8_w8a8
|
||||
self.block_shape = block_shape
|
||||
|
||||
def prepare(
|
||||
self,
|
||||
@ -419,6 +425,8 @@ class BatchedPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
|
||||
"apply_router_weight_on_input is only implemented for topk=1"
|
||||
a1.mul_(topk_weights.to(a1.dtype))
|
||||
|
||||
_, block_k = self.block_shape
|
||||
|
||||
num_tokens, hidden_dim = a1.size()
|
||||
topk = topk_ids.size(1)
|
||||
|
||||
@ -437,20 +445,37 @@ class BatchedPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
|
||||
|
||||
b_a1 = torch.zeros(
|
||||
(num_local_experts, self.max_num_tokens, hidden_dim),
|
||||
dtype=a1.dtype,
|
||||
dtype=torch.float8_e4m3fn if self.use_fp8_w8a8 else a1.dtype,
|
||||
device=a1.device)
|
||||
|
||||
if self.use_fp8_w8a8:
|
||||
k_tiles = (hidden_dim + block_k - 1) // block_k
|
||||
b_a1_scale = torch.zeros(
|
||||
(num_local_experts, self.max_num_tokens, k_tiles),
|
||||
dtype=torch.float32,
|
||||
device=a1.device)
|
||||
else:
|
||||
b_a1_scale = None
|
||||
|
||||
first_expert = num_local_experts * self.rank
|
||||
last_expert = first_expert + num_local_experts
|
||||
|
||||
for expert_id in range(first_expert, last_expert):
|
||||
topks = torch.any(topk_ids == expert_id, dim=1).flatten()
|
||||
rows = torch.count_nonzero(topks.flatten())
|
||||
b_a1[expert_id -
|
||||
first_expert, :rows, :] = a1[:topks.numel()][topks]
|
||||
tokens_per_expert[expert_id - first_expert] = rows
|
||||
|
||||
return b_a1, a1_scale, tokens_per_expert
|
||||
rhs = a1[:topks.numel()][topks]
|
||||
idx = expert_id - first_expert
|
||||
if self.use_fp8_w8a8:
|
||||
# TODO: use _fp8_quantize
|
||||
b_a1[idx, :rows, :], tmp_scale = per_token_group_quant_fp8(rhs, block_k)
|
||||
b_a1_scale[idx, :rows] = tmp_scale # inline?
|
||||
else:
|
||||
b_a1[idx, :rows, :] = rhs
|
||||
|
||||
tokens_per_expert[idx] = rows
|
||||
|
||||
return b_a1, b_a1_scale, tokens_per_expert
|
||||
|
||||
def finalize(
|
||||
self,
|
||||
@ -529,66 +554,6 @@ class BatchedExperts(mk.FusedMoEPermuteExpertsUnpermute):
|
||||
workspace2 = max_num_tokens * num_dp * N
|
||||
return (workspace13, workspace2, a.dtype)
|
||||
|
||||
def native_w8a8_block_matmul(
        A: torch.Tensor,
        B: torch.Tensor,
        As: torch.Tensor,
        Bs: torch.Tensor,
        block_shape: Optional[list[int]] = None,
        output_dtype: torch.dtype = torch.bfloat16) -> torch.Tensor:
    """Block-wise quantized matrix multiplication using native torch.

    Computes ``A @ B.T`` where both operands carry block-wise scales. The
    operands are upcast to float32, so the function is agnostic to the
    quantized input dtype (e.g. int8 or fp8).

    Args:
        A: Quantized activations of shape ``(..., K)``.
        B: Quantized weight of shape ``(N, K)``; must be 2-D and contiguous.
        As: Activation scales of shape ``(..., ceil(K / block_k))``, one
            scale per row and per K-block.
        Bs: Weight scales of shape ``(ceil(N / block_n), ceil(K / block_k))``.
        block_shape: ``[block_n, block_k]`` quantization block sizes.
            Required; the default of ``None`` only keeps the four-argument
            signature importable.
        output_dtype: Dtype of the returned tensor.

    Returns:
        Tensor of shape ``A.shape[:-1] + (N,)`` in ``output_dtype``.

    Raises:
        AssertionError: If ``block_shape`` is missing or any shape
            requirement above is violated.
    """
    assert block_shape is not None and len(block_shape) == 2
    block_n, block_k = block_shape[0], block_shape[1]

    A = A.to(torch.float32)
    B = B.to(torch.float32)
    assert A.shape[-1] == B.shape[-1]
    assert B.ndim == 2 and B.is_contiguous() and Bs.ndim == 2
    assert (A.shape[-1] + block_k - 1) // block_k == As.shape[-1]
    assert A.shape[:-1] == As.shape[:-1]

    # Flatten all leading dimensions of A (and As) into a single row axis.
    M = A.numel() // A.shape[-1]
    N, K = B.shape
    origin_C_shape = A.shape[:-1] + (N, )
    A = A.reshape(M, A.shape[-1])
    As = As.reshape(M, As.shape[-1])

    n_tiles = (N + block_n - 1) // block_n
    k_tiles = (K + block_k - 1) // block_k
    assert n_tiles == Bs.shape[0]
    assert k_tiles == Bs.shape[1]

    # Accumulate in float32; cast to the requested dtype only at the end.
    C = torch.zeros((M, N), dtype=torch.float32, device=A.device)

    for i in range(k_tiles):
        k_lo, k_hi = i * block_k, min((i + 1) * block_k, K)
        a_tile = A[:, k_lo:k_hi]
        a_scale = As[:, i:i + 1]  # (M, 1): broadcasts over the N-block
        for j in range(n_tiles):
            n_lo, n_hi = j * block_n, min((j + 1) * block_n, N)
            b_tile = B[n_lo:n_hi, k_lo:k_hi]
            # Each partial product is scaled by its row's activation scale
            # times the scalar scale of the (j, i) weight block.
            C[:, n_lo:n_hi] += torch.matmul(a_tile,
                                            b_tile.t()) * (a_scale * Bs[j][i])

    return C.reshape(origin_C_shape).to(output_dtype)
|
||||
|
||||
def apply(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
@ -643,9 +608,6 @@ class BatchedExperts(mk.FusedMoEPermuteExpertsUnpermute):
|
||||
tmp = _resize_cache(workspace2, (num, N))
|
||||
if self.use_fp8_w8a8:
|
||||
assert False # TBD
|
||||
input = hidden_states[expert, :num, :] @ w1[expert].transpose(0, 1)
|
||||
self.activation(activation, tmp, input)
|
||||
out[expert, :num, :] = tmp @ w2[expert].transpose(0, 1)
|
||||
else:
|
||||
input = hidden_states[expert, :num, :] @ w1[expert].transpose(0, 1)
|
||||
self.activation(activation, tmp, input)
|
||||
@ -778,6 +740,8 @@ class BatchedTritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
|
||||
(E, num_tokens, N // 2))
|
||||
intermediate_cache3 = _resize_cache(workspace13, (E, num_tokens, K))
|
||||
|
||||
assert not self.use_fp8_w8a8 or a1q_scale is not None
|
||||
|
||||
# MM1
|
||||
invoke_moe_batched_triton_kernel(A=hidden_states,
|
||||
B=w1,
|
||||
@ -804,20 +768,26 @@ class BatchedTritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
|
||||
#assert not self.use_fp8_w8a8
|
||||
if self.use_fp8_w8a8:
|
||||
per_act_token = False
|
||||
qintermediate_cache2 = torch.empty_like(intermediate_cache2,
|
||||
qintermediate_cache2 = torch.zeros_like(intermediate_cache2,
|
||||
dtype=torch.float8_e4m3fn)
|
||||
if per_act_token:
|
||||
scale_shape = (E, num_tokens, 1)
|
||||
else:
|
||||
scale_shape = (E, 1)
|
||||
a2q_scale = torch.empty(scale_shape,
|
||||
block_n = self.block_shape[0]
|
||||
n_tiles = ((N // 2) + block_n - 1) // block_n
|
||||
scale_shape = (E, num_tokens, n_tiles)
|
||||
a2q_scale = torch.zeros(scale_shape,
|
||||
dtype=torch.float32,
|
||||
device=hidden_states.device)
|
||||
for e in range(E):
|
||||
qintermediate_cache2[e], a2q_scale[e] = _fp8_quantize(
|
||||
intermediate_cache2[e, :expert_num_tokens[e]],
|
||||
a2_scale[e] if a2_scale is not None else None,
|
||||
per_act_token, self.block_shape)
|
||||
num_tokens = expert_num_tokens[e]
|
||||
if num_tokens > 0:
|
||||
#qintermediate_cache2[e], tmp_scale = _fp8_quantize(
|
||||
# intermediate_cache2[e],
|
||||
# a2_scale[e] if a2_scale is not None else None,
|
||||
# per_act_token, self.block_shape)
|
||||
qintermediate_cache2[e, :num_tokens, :], tmp_scale = per_token_group_quant_fp8(
|
||||
intermediate_cache2[e, :num_tokens], block_n)
|
||||
#print(a2q_scale[e, :tmp_scale.shape[0]].shape)
|
||||
#print(tmp_scale.shape)
|
||||
a2q_scale[e, :tmp_scale.shape[0]] = tmp_scale
|
||||
else:
|
||||
qintermediate_cache2 = intermediate_cache2
|
||||
a2q_scale = a2_scale
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user