mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-04-12 06:37:03 +08:00
wip
Signed-off-by: Bill Nell <bnell@redhat.com>
This commit is contained in:
parent
e568e401da
commit
25ed6738d4
@ -5,6 +5,7 @@ from dataclasses import dataclass
|
||||
import pytest
|
||||
import torch
|
||||
import triton.language as tl
|
||||
from typing import Optional
|
||||
|
||||
from vllm.model_executor.layers.fused_moe.fused_batched_moe import (
|
||||
invoke_moe_batched_triton_kernel)
|
||||
@ -57,20 +58,94 @@ class BatchedMMTensors:
|
||||
return BatchedMMTensors(A, B, C, num_expert_tokens)
|
||||
|
||||
|
||||
def ref_impl(A: torch.Tensor, B: torch.Tensor, C: torch.Tensor,
|
||||
num_expert_tokens: torch.Tensor) -> torch.Tensor:
|
||||
def native_w8a8_block_matmul(A: torch.Tensor, B: torch.Tensor,
|
||||
As: torch.Tensor, Bs: torch.Tensor,
|
||||
block_size,
|
||||
output_dtype = torch.bfloat16):
|
||||
"""This function performs matrix multiplication with block-wise
|
||||
quantization using native torch.
|
||||
It is agnostic to the input data type and can be used for both int8 and
|
||||
fp8 data types.
|
||||
|
||||
It takes two input tensors `A` and `B` (int8) with scales `As` and
|
||||
`Bs` (float32).
|
||||
The output is returned in the specified `output_dtype`.
|
||||
"""
|
||||
A = A.to(torch.float32)
|
||||
B = B.to(torch.float32).contiguous()
|
||||
assert A.shape[-1] == B.shape[-1]
|
||||
assert B.ndim == 2 and B.is_contiguous() and Bs.ndim == 2
|
||||
assert len(block_size) == 2
|
||||
block_n, block_k = block_size[0], block_size[1]
|
||||
assert (A.shape[-1] + block_k - 1) // block_k == As.shape[-1], (
|
||||
f"{(A.shape[-1] + block_k - 1) // block_k} == {As.shape[-1]}")
|
||||
assert A.shape[:-1] == As.shape[:-1], f"{A.shape} == {As.shape}"
|
||||
|
||||
M = A.numel() // A.shape[-1]
|
||||
N, K = B.shape
|
||||
origin_C_shape = A.shape[:-1] + (N, )
|
||||
A = A.reshape(M, A.shape[-1])
|
||||
As = As.reshape(M, As.shape[-1])
|
||||
n_tiles = (N + block_n - 1) // block_n
|
||||
k_tiles = (K + block_k - 1) // block_k
|
||||
assert n_tiles == Bs.shape[0]
|
||||
assert k_tiles == Bs.shape[1]
|
||||
|
||||
C_shape = (M, N)
|
||||
C = torch.zeros(C_shape, dtype=torch.float32, device=A.device)
|
||||
|
||||
A_tiles = [
|
||||
A[:, i * block_k:min((i + 1) * block_k, K)] for i in range(k_tiles)
|
||||
]
|
||||
B_tiles = [[
|
||||
B[
|
||||
j * block_n:min((j + 1) * block_n, N),
|
||||
i * block_k:min((i + 1) * block_k, K),
|
||||
] for i in range(k_tiles)
|
||||
] for j in range(n_tiles)]
|
||||
C_tiles = [
|
||||
C[:, j * block_n:min((j + 1) * block_n, N)] for j in range(n_tiles)
|
||||
]
|
||||
As_tiles = [As[:, i:i + 1] for i in range(k_tiles)]
|
||||
|
||||
for i in range(k_tiles):
|
||||
for j in range(n_tiles):
|
||||
a = A_tiles[i]
|
||||
b = B_tiles[j][i]
|
||||
c = C_tiles[j]
|
||||
s = As_tiles[i] * Bs[j][i]
|
||||
c[:, :] += torch.matmul(a, b.t()) * s
|
||||
|
||||
C = C.reshape(origin_C_shape).to(output_dtype)
|
||||
return C
|
||||
|
||||
|
||||
def ref_impl(
|
||||
A: torch.Tensor,
|
||||
B: torch.Tensor,
|
||||
C: torch.Tensor,
|
||||
num_expert_tokens: torch.Tensor,
|
||||
A_scale: Optional[torch.Tensor],
|
||||
B_scale: Optional[torch.Tensor],
|
||||
block_shape: Optional[list[int]],
|
||||
) -> torch.Tensor:
|
||||
num_expert_tokens_cpu = num_expert_tokens.clone()
|
||||
num_expert_tokens_cpu = num_expert_tokens_cpu.to(device="cpu")
|
||||
num_experts = num_expert_tokens.size(0)
|
||||
|
||||
for e in range(num_experts):
|
||||
num_tokens = num_expert_tokens_cpu[e]
|
||||
C[e, :num_tokens, :] = A[e, :num_tokens, :] @ B[e].transpose(0, 1)
|
||||
if A.dtype == torch.torch.float8_e4m3fn:
|
||||
C[e, :, :] = native_w8a8_block_matmul(A[e, :, :],
|
||||
B[e].transpose(0, 1),
|
||||
A_scale,
|
||||
B_scale,
|
||||
[1,1])#block_shape)
|
||||
else:
|
||||
C[e, :num_tokens, :] = A[e, :num_tokens, :] @ B[e].transpose(0, 1)
|
||||
|
||||
return C
|
||||
|
||||
|
||||
@pytest.mark.parametrize("num_experts", [16, 32])
|
||||
@pytest.mark.parametrize("max_tokens_per_expert",
|
||||
[32, 64, 128, 192, 224, 256, 512])
|
||||
@ -94,6 +169,20 @@ def test_batched_mm(num_experts: int, max_tokens_per_expert: int, K: int,
|
||||
torch.bfloat16: tl.bfloat16,
|
||||
torch.float32: tl.float32
|
||||
}[test_output.dtype]
|
||||
|
||||
use_fp8_w8a8 = dtype == torch.torch.float8_e4m3fn
|
||||
block_shape = [16, 16, 32] # 16 for k if not fp8
|
||||
|
||||
print(f"tensors.A {tensors.A.shape}")
|
||||
print(f"tensors.B {tensors.B.shape}")
|
||||
|
||||
if use_fp8_w8a8:
|
||||
A_scale = torch.ones((max_tokens_per_expert,K), dtype=torch.float32, device=tensors.A.device)
|
||||
B_scale = torch.ones((N, K), dtype=torch.float32, device=tensors.A.device)
|
||||
else:
|
||||
A_scale = None
|
||||
B_scale = None
|
||||
|
||||
invoke_moe_batched_triton_kernel(
|
||||
tensors.A,
|
||||
tensors.B,
|
||||
@ -101,21 +190,26 @@ def test_batched_mm(num_experts: int, max_tokens_per_expert: int, K: int,
|
||||
tensors.num_expert_tokens,
|
||||
compute_tl_dtype,
|
||||
# Quantization data
|
||||
None,
|
||||
None,
|
||||
A_scale,
|
||||
B_scale,
|
||||
None,
|
||||
# Quantization schemes
|
||||
dtype == torch.torch.float8_e4m3fn,
|
||||
use_fp8_w8a8,
|
||||
False,
|
||||
False,
|
||||
config={
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 16,
|
||||
"BLOCK_SIZE_K": 16
|
||||
"BLOCK_SIZE_M": block_shape[0],
|
||||
"BLOCK_SIZE_N": block_shape[1],
|
||||
"BLOCK_SIZE_K": block_shape[2],
|
||||
})
|
||||
|
||||
ref_output = ref_impl(tensors.A, tensors.B, ref_output,
|
||||
tensors.num_expert_tokens)
|
||||
ref_output = ref_impl(tensors.A,
|
||||
tensors.B,
|
||||
ref_output,
|
||||
tensors.num_expert_tokens,
|
||||
A_scale,
|
||||
B_scale,
|
||||
block_shape[-2:])
|
||||
|
||||
rtol, atol = {
|
||||
torch.torch.float8_e4m3fn: (6e-2, 6e-2),
|
||||
|
||||
@ -316,8 +316,8 @@ def invoke_moe_batched_triton_kernel(
|
||||
expert_num_tokens: torch.Tensor, # [E]
|
||||
compute_type: tl.dtype,
|
||||
# Quantization data
|
||||
A_scale: torch.Tensor,
|
||||
B_scale: torch.Tensor,
|
||||
A_scale: Optional[torch.Tensor],
|
||||
B_scale: Optional[torch.Tensor],
|
||||
B_zp: torch.Tensor,
|
||||
# Quantization schemes
|
||||
use_fp8_w8a8: bool,
|
||||
@ -500,15 +500,16 @@ class BatchedExperts(mk.FusedMoEPermuteExpertsUnpermute):
|
||||
block_m: Optional[int] = None,
|
||||
):
|
||||
super().__init__()
|
||||
assert block_shape is None
|
||||
#assert block_shape is None
|
||||
assert block_m is None
|
||||
assert not use_fp8_w8a8, "NYI"
|
||||
assert not use_int8_w8a8, "NYI"
|
||||
assert not use_int8_w8a16, "NYI"
|
||||
assert not use_int4_w4a16, "NYI"
|
||||
self.max_num_tokens = max_num_tokens
|
||||
self.world_size = world_size
|
||||
self.dp_size = dp_size
|
||||
self.use_fp8_w8a8 = use_fp8_w8a8
|
||||
self.block_shape = block_shape
|
||||
|
||||
def workspace_shapes(
|
||||
self,
|
||||
@ -528,6 +529,66 @@ class BatchedExperts(mk.FusedMoEPermuteExpertsUnpermute):
|
||||
workspace2 = max_num_tokens * num_dp * N
|
||||
return (workspace13, workspace2, a.dtype)
|
||||
|
||||
def native_w8a8_block_matmul(A: torch.Tensor,
|
||||
B: torch.Tensor,
|
||||
As: torch.Tensor,
|
||||
Bs: torch.Tensor):
|
||||
"""This function performs matrix multiplication with block-wise
|
||||
quantization using native torch.
|
||||
It is agnostic to the input data type and can be used for both int8 and
|
||||
fp8 data types.
|
||||
|
||||
It takes two input tensors `A` and `B` (int8) with scales `As` and
|
||||
`Bs` (float32).
|
||||
The output is returned in the specified `output_dtype`.
|
||||
"""
|
||||
A = A.to(torch.float32)
|
||||
B = B.to(torch.float32)
|
||||
assert A.shape[-1] == B.shape[-1]
|
||||
assert B.ndim == 2 and B.is_contiguous() and Bs.ndim == 2
|
||||
assert self.block_shape is not None and len(self.block_shape) == 2
|
||||
block_n, block_k = self.block_shape[0], self.block_shape[1]
|
||||
assert (A.shape[-1] + block_k - 1) // block_k == As.shape[-1]
|
||||
assert A.shape[:-1] == As.shape[:-1]
|
||||
|
||||
M = A.numel() // A.shape[-1]
|
||||
N, K = B.shape
|
||||
origin_C_shape = A.shape[:-1] + (N, )
|
||||
A = A.reshape(M, A.shape[-1])
|
||||
As = As.reshape(M, As.shape[-1])
|
||||
n_tiles = (N + block_n - 1) // block_n
|
||||
k_tiles = (K + block_k - 1) // block_k
|
||||
assert n_tiles == Bs.shape[0]
|
||||
assert k_tiles == Bs.shape[1]
|
||||
|
||||
C_shape = (M, N)
|
||||
C = torch.zeros(C_shape, dtype=torch.float32, device=A.device)
|
||||
|
||||
A_tiles = [
|
||||
A[:, i * block_k:min((i + 1) * block_k, K)] for i in range(k_tiles)
|
||||
]
|
||||
B_tiles = [[
|
||||
B[
|
||||
j * block_n:min((j + 1) * block_n, N),
|
||||
i * block_k:min((i + 1) * block_k, K),
|
||||
] for i in range(k_tiles)
|
||||
] for j in range(n_tiles)]
|
||||
C_tiles = [
|
||||
C[:, j * block_n:min((j + 1) * block_n, N)] for j in range(n_tiles)
|
||||
]
|
||||
As_tiles = [As[:, i:i + 1] for i in range(k_tiles)]
|
||||
|
||||
for i in range(k_tiles):
|
||||
for j in range(n_tiles):
|
||||
a = A_tiles[i]
|
||||
b = B_tiles[j][i]
|
||||
c = C_tiles[j]
|
||||
s = As_tiles[i] * Bs[j][i]
|
||||
c[:, :] += torch.matmul(a, b.t()) * s
|
||||
|
||||
C = C.reshape(origin_C_shape).to(output_dtype)
|
||||
return C
|
||||
|
||||
def apply(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
@ -580,9 +641,15 @@ class BatchedExperts(mk.FusedMoEPermuteExpertsUnpermute):
|
||||
else:
|
||||
num = int(expert_num_tokens[expert].item())
|
||||
tmp = _resize_cache(workspace2, (num, N))
|
||||
input = hidden_states[expert, :num, :] @ w1[expert].transpose(0, 1)
|
||||
self.activation(activation, tmp, input)
|
||||
out[expert, :num, :] = tmp @ w2[expert].transpose(0, 1)
|
||||
if self.use_fp8_w8a8:
|
||||
assert False # TBD
|
||||
input = hidden_states[expert, :num, :] @ w1[expert].transpose(0, 1)
|
||||
self.activation(activation, tmp, input)
|
||||
out[expert, :num, :] = tmp @ w2[expert].transpose(0, 1)
|
||||
else:
|
||||
input = hidden_states[expert, :num, :] @ w1[expert].transpose(0, 1)
|
||||
self.activation(activation, tmp, input)
|
||||
out[expert, :num, :] = tmp @ w2[expert].transpose(0, 1)
|
||||
|
||||
return out
|
||||
|
||||
@ -732,7 +799,7 @@ class BatchedTritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
|
||||
intermediate_cache1.view(-1, N))
|
||||
|
||||
#qintermediate_cache2 = intermediate_cache2
|
||||
a2q_scale = a2_scale
|
||||
|
||||
# TODO (varun) : support w8a8
|
||||
#assert not self.use_fp8_w8a8
|
||||
if self.use_fp8_w8a8:
|
||||
@ -753,6 +820,7 @@ class BatchedTritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
|
||||
per_act_token, self.block_shape)
|
||||
else:
|
||||
qintermediate_cache2 = intermediate_cache2
|
||||
a2q_scale = a2_scale
|
||||
|
||||
invoke_moe_batched_triton_kernel(A=qintermediate_cache2,
|
||||
B=w2,
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user