mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-06-01 12:07:08 +08:00
test
Signed-off-by: Bill Nell <bnell@redhat.com>
This commit is contained in:
parent
bbe888d033
commit
77f95b99a6
@ -13,7 +13,8 @@ from vllm.model_executor.layers.fused_moe.fused_batched_moe import (
|
|||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class BatchedMMConfig:
|
class BatchedMMConfig:
|
||||||
dtype: torch.dtype
|
in_dtype: torch.dtype
|
||||||
|
out_dtype: torch.dtype
|
||||||
num_experts: int
|
num_experts: int
|
||||||
max_tokens_per_expert: int
|
max_tokens_per_expert: int
|
||||||
K: int
|
K: int
|
||||||
@ -29,26 +30,25 @@ class BatchedMMTensors:
|
|||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def make_tensors(config: BatchedMMConfig):
|
def make_tensors(config: BatchedMMConfig):
|
||||||
if config.dtype == torch.torch.float8_e4m3fn:
|
if config.in_dtype == torch.torch.float8_e4m3fn:
|
||||||
config_dtype = torch.bfloat16
|
config_in_dtype = torch.bfloat16
|
||||||
else:
|
else:
|
||||||
config_dtype = config.dtype
|
config_in_dtype = config.in_dtype
|
||||||
|
|
||||||
A = torch.randn(
|
A = torch.randn(
|
||||||
(config.num_experts, config.max_tokens_per_expert, config.K),
|
(config.num_experts, config.max_tokens_per_expert, config.K),
|
||||||
device="cuda",
|
device="cuda",
|
||||||
dtype=config_dtype) / 10
|
dtype=config_in_dtype) / 10
|
||||||
B = torch.randn((config.num_experts, config.N, config.K),
|
B = torch.randn((config.num_experts, config.N, config.K),
|
||||||
device="cuda",
|
device="cuda",
|
||||||
dtype=config_dtype)
|
dtype=config_in_dtype)
|
||||||
C = torch.zeros(
|
C = torch.zeros(
|
||||||
(config.num_experts, config.max_tokens_per_expert, config.N),
|
(config.num_experts, config.max_tokens_per_expert, config.N),
|
||||||
device="cuda",
|
device="cuda",
|
||||||
dtype=config_dtype)
|
dtype=config.out_dtype)
|
||||||
|
|
||||||
A = A.to(config.dtype)
|
A = A.to(config.in_dtype)
|
||||||
B = B.to(config.dtype)
|
B = B.to(config.in_dtype)
|
||||||
C = C.to(config.dtype)
|
|
||||||
|
|
||||||
num_expert_tokens = torch.randint(low=0,
|
num_expert_tokens = torch.randint(low=0,
|
||||||
high=config.max_tokens_per_expert,
|
high=config.max_tokens_per_expert,
|
||||||
@ -136,11 +136,19 @@ def ref_impl(
|
|||||||
for e in range(num_experts):
|
for e in range(num_experts):
|
||||||
num_tokens = num_expert_tokens_cpu[e]
|
num_tokens = num_expert_tokens_cpu[e]
|
||||||
if A.dtype == torch.torch.float8_e4m3fn:
|
if A.dtype == torch.torch.float8_e4m3fn:
|
||||||
tmp = native_w8a8_block_matmul(A[e, :, :],
|
if False:
|
||||||
B[e].transpose(0, 1),
|
tmp = native_w8a8_block_matmul(A[e, :, :],
|
||||||
A_scale,
|
B[e].transpose(0, 1),
|
||||||
B_scale,
|
A_scale,
|
||||||
[1,1])#block_shape)
|
B_scale,
|
||||||
|
[1,1])#block_shape)
|
||||||
|
else:
|
||||||
|
import vllm._custom_ops as ops
|
||||||
|
tmp = ops.cutlass_scaled_mm(A[e, :, :],
|
||||||
|
B[e].transpose(0, 1),
|
||||||
|
A_scale,
|
||||||
|
B_scale,
|
||||||
|
C.dtype)
|
||||||
C[e, :num_tokens, :] = tmp[:num_tokens, :]
|
C[e, :num_tokens, :] = tmp[:num_tokens, :]
|
||||||
else:
|
else:
|
||||||
C[e, :num_tokens, :] = A[e, :num_tokens, :] @ B[e].transpose(0, 1)
|
C[e, :num_tokens, :] = A[e, :num_tokens, :] @ B[e].transpose(0, 1)
|
||||||
@ -159,14 +167,21 @@ def ref_impl(
|
|||||||
def test_batched_mm(num_experts: int, max_tokens_per_expert: int, K: int,
|
def test_batched_mm(num_experts: int, max_tokens_per_expert: int, K: int,
|
||||||
N: int, dtype: torch.dtype):
|
N: int, dtype: torch.dtype):
|
||||||
|
|
||||||
config = BatchedMMConfig(dtype, num_experts, max_tokens_per_expert, K, N)
|
if dtype == torch.torch.float8_e4m3fn:
|
||||||
|
in_dtype = dtype
|
||||||
|
out_dtype = torch.bfloat16
|
||||||
|
else:
|
||||||
|
in_dtype = dtype
|
||||||
|
out_dtype = dtype
|
||||||
|
|
||||||
|
config = BatchedMMConfig(in_dtype, out_dtype, num_experts, max_tokens_per_expert, K, N)
|
||||||
tensors = BatchedMMTensors.make_tensors(config)
|
tensors = BatchedMMTensors.make_tensors(config)
|
||||||
|
|
||||||
test_output = tensors.C
|
test_output = tensors.C
|
||||||
ref_output = test_output.clone()
|
ref_output = test_output.clone()
|
||||||
|
ref_output2 = test_output.clone()
|
||||||
|
|
||||||
compute_tl_dtype = {
|
compute_tl_dtype = {
|
||||||
torch.torch.float8_e4m3fn: tl.bfloat16,
|
|
||||||
torch.float16: tl.float16,
|
torch.float16: tl.float16,
|
||||||
torch.bfloat16: tl.bfloat16,
|
torch.bfloat16: tl.bfloat16,
|
||||||
torch.float32: tl.float32
|
torch.float32: tl.float32
|
||||||
@ -175,12 +190,14 @@ def test_batched_mm(num_experts: int, max_tokens_per_expert: int, K: int,
|
|||||||
use_fp8_w8a8 = dtype == torch.torch.float8_e4m3fn
|
use_fp8_w8a8 = dtype == torch.torch.float8_e4m3fn
|
||||||
block_shape = [16, 16, 32] # 16 for k if not fp8
|
block_shape = [16, 16, 32] # 16 for k if not fp8
|
||||||
|
|
||||||
print(f"tensors.A {tensors.A.shape}")
|
#print(f"tensors.A {tensors.A.shape}")
|
||||||
print(f"tensors.B {tensors.B.shape}")
|
#print(f"tensors.B {tensors.B.shape}")
|
||||||
|
|
||||||
if use_fp8_w8a8:
|
if use_fp8_w8a8:
|
||||||
A_scale = torch.ones((max_tokens_per_expert,K), dtype=torch.float32, device=tensors.A.device)
|
#A_scale = torch.ones((max_tokens_per_expert,K), dtype=torch.float32, device=tensors.A.device)
|
||||||
B_scale = torch.ones((N, K), dtype=torch.float32, device=tensors.A.device)
|
#B_scale = torch.ones((N, K), dtype=torch.float32, device=tensors.A.device)
|
||||||
|
A_scale = torch.ones(1, dtype=torch.float32, device=tensors.A.device)
|
||||||
|
B_scale = torch.ones(1, dtype=torch.float32, device=tensors.B.device)
|
||||||
else:
|
else:
|
||||||
A_scale = None
|
A_scale = None
|
||||||
B_scale = None
|
B_scale = None
|
||||||
@ -205,19 +222,29 @@ def test_batched_mm(num_experts: int, max_tokens_per_expert: int, K: int,
|
|||||||
"BLOCK_SIZE_K": block_shape[2],
|
"BLOCK_SIZE_K": block_shape[2],
|
||||||
})
|
})
|
||||||
|
|
||||||
ref_output = ref_impl(tensors.A,
|
ref_output = ref_output.to(dtype=out_dtype)
|
||||||
tensors.B,
|
ref_output = ref_impl(tensors.A.to(dtype=out_dtype),
|
||||||
|
tensors.B.to(dtype=out_dtype),
|
||||||
ref_output,
|
ref_output,
|
||||||
tensors.num_expert_tokens,
|
tensors.num_expert_tokens,
|
||||||
A_scale,
|
A_scale,
|
||||||
B_scale,
|
B_scale,
|
||||||
block_shape[-2:])
|
block_shape[-2:])
|
||||||
|
|
||||||
|
ref_output2 = ref_impl(tensors.A,
|
||||||
|
tensors.B,
|
||||||
|
ref_output2,
|
||||||
|
tensors.num_expert_tokens,
|
||||||
|
A_scale,
|
||||||
|
B_scale,
|
||||||
|
block_shape[-2:])
|
||||||
|
|
||||||
rtol, atol = {
|
rtol, atol = {
|
||||||
torch.torch.float8_e4m3fn: (6e-2, 6e-2),
|
|
||||||
torch.float16: (6e-2, 6e-2),
|
torch.float16: (6e-2, 6e-2),
|
||||||
torch.bfloat16: (6e-2, 6e-2),
|
torch.bfloat16: (6e-2, 6e-2),
|
||||||
torch.float32: (1e-2, 1e-2),
|
torch.float32: (1e-2, 1e-2),
|
||||||
}[test_output.dtype]
|
}[test_output.dtype]
|
||||||
|
|
||||||
torch.testing.assert_close(test_output, ref_output, atol=atol, rtol=rtol)
|
torch.testing.assert_close(ref_output, ref_output2, atol=atol, rtol=rtol)
|
||||||
|
if not use_fp8_w8a8:
|
||||||
|
torch.testing.assert_close(test_output, ref_output2, atol=atol, rtol=rtol)
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user