Signed-off-by: Bill Nell <bnell@redhat.com>
This commit is contained in:
Bill Nell 2025-05-21 23:13:43 +00:00
parent bbe888d033
commit 77f95b99a6

View File

@ -13,7 +13,8 @@ from vllm.model_executor.layers.fused_moe.fused_batched_moe import (
@dataclass @dataclass
class BatchedMMConfig: class BatchedMMConfig:
dtype: torch.dtype in_dtype: torch.dtype
out_dtype: torch.dtype
num_experts: int num_experts: int
max_tokens_per_expert: int max_tokens_per_expert: int
K: int K: int
@ -29,26 +30,25 @@ class BatchedMMTensors:
@staticmethod @staticmethod
def make_tensors(config: BatchedMMConfig): def make_tensors(config: BatchedMMConfig):
if config.dtype == torch.torch.float8_e4m3fn: if config.in_dtype == torch.torch.float8_e4m3fn:
config_dtype = torch.bfloat16 config_in_dtype = torch.bfloat16
else: else:
config_dtype = config.dtype config_in_dtype = config.in_dtype
A = torch.randn( A = torch.randn(
(config.num_experts, config.max_tokens_per_expert, config.K), (config.num_experts, config.max_tokens_per_expert, config.K),
device="cuda", device="cuda",
dtype=config_dtype) / 10 dtype=config_in_dtype) / 10
B = torch.randn((config.num_experts, config.N, config.K), B = torch.randn((config.num_experts, config.N, config.K),
device="cuda", device="cuda",
dtype=config_dtype) dtype=config_in_dtype)
C = torch.zeros( C = torch.zeros(
(config.num_experts, config.max_tokens_per_expert, config.N), (config.num_experts, config.max_tokens_per_expert, config.N),
device="cuda", device="cuda",
dtype=config_dtype) dtype=config.out_dtype)
A = A.to(config.dtype) A = A.to(config.in_dtype)
B = B.to(config.dtype) B = B.to(config.in_dtype)
C = C.to(config.dtype)
num_expert_tokens = torch.randint(low=0, num_expert_tokens = torch.randint(low=0,
high=config.max_tokens_per_expert, high=config.max_tokens_per_expert,
@ -136,11 +136,19 @@ def ref_impl(
for e in range(num_experts): for e in range(num_experts):
num_tokens = num_expert_tokens_cpu[e] num_tokens = num_expert_tokens_cpu[e]
if A.dtype == torch.torch.float8_e4m3fn: if A.dtype == torch.torch.float8_e4m3fn:
tmp = native_w8a8_block_matmul(A[e, :, :], if False:
B[e].transpose(0, 1), tmp = native_w8a8_block_matmul(A[e, :, :],
A_scale, B[e].transpose(0, 1),
B_scale, A_scale,
[1,1])#block_shape) B_scale,
[1,1])#block_shape)
else:
import vllm._custom_ops as ops
tmp = ops.cutlass_scaled_mm(A[e, :, :],
B[e].transpose(0, 1),
A_scale,
B_scale,
C.dtype)
C[e, :num_tokens, :] = tmp[:num_tokens, :] C[e, :num_tokens, :] = tmp[:num_tokens, :]
else: else:
C[e, :num_tokens, :] = A[e, :num_tokens, :] @ B[e].transpose(0, 1) C[e, :num_tokens, :] = A[e, :num_tokens, :] @ B[e].transpose(0, 1)
@ -159,14 +167,21 @@ def ref_impl(
def test_batched_mm(num_experts: int, max_tokens_per_expert: int, K: int, def test_batched_mm(num_experts: int, max_tokens_per_expert: int, K: int,
N: int, dtype: torch.dtype): N: int, dtype: torch.dtype):
config = BatchedMMConfig(dtype, num_experts, max_tokens_per_expert, K, N) if dtype == torch.torch.float8_e4m3fn:
in_dtype = dtype
out_dtype = torch.bfloat16
else:
in_dtype = dtype
out_dtype = dtype
config = BatchedMMConfig(in_dtype, out_dtype, num_experts, max_tokens_per_expert, K, N)
tensors = BatchedMMTensors.make_tensors(config) tensors = BatchedMMTensors.make_tensors(config)
test_output = tensors.C test_output = tensors.C
ref_output = test_output.clone() ref_output = test_output.clone()
ref_output2 = test_output.clone()
compute_tl_dtype = { compute_tl_dtype = {
torch.torch.float8_e4m3fn: tl.bfloat16,
torch.float16: tl.float16, torch.float16: tl.float16,
torch.bfloat16: tl.bfloat16, torch.bfloat16: tl.bfloat16,
torch.float32: tl.float32 torch.float32: tl.float32
@ -175,12 +190,14 @@ def test_batched_mm(num_experts: int, max_tokens_per_expert: int, K: int,
use_fp8_w8a8 = dtype == torch.torch.float8_e4m3fn use_fp8_w8a8 = dtype == torch.torch.float8_e4m3fn
block_shape = [16, 16, 32] # 16 for k if not fp8 block_shape = [16, 16, 32] # 16 for k if not fp8
print(f"tensors.A {tensors.A.shape}") #print(f"tensors.A {tensors.A.shape}")
print(f"tensors.B {tensors.B.shape}") #print(f"tensors.B {tensors.B.shape}")
if use_fp8_w8a8: if use_fp8_w8a8:
A_scale = torch.ones((max_tokens_per_expert,K), dtype=torch.float32, device=tensors.A.device) #A_scale = torch.ones((max_tokens_per_expert,K), dtype=torch.float32, device=tensors.A.device)
B_scale = torch.ones((N, K), dtype=torch.float32, device=tensors.A.device) #B_scale = torch.ones((N, K), dtype=torch.float32, device=tensors.A.device)
A_scale = torch.ones(1, dtype=torch.float32, device=tensors.A.device)
B_scale = torch.ones(1, dtype=torch.float32, device=tensors.B.device)
else: else:
A_scale = None A_scale = None
B_scale = None B_scale = None
@ -205,19 +222,29 @@ def test_batched_mm(num_experts: int, max_tokens_per_expert: int, K: int,
"BLOCK_SIZE_K": block_shape[2], "BLOCK_SIZE_K": block_shape[2],
}) })
ref_output = ref_impl(tensors.A, ref_output = ref_output.to(dtype=out_dtype)
tensors.B, ref_output = ref_impl(tensors.A.to(dtype=out_dtype),
tensors.B.to(dtype=out_dtype),
ref_output, ref_output,
tensors.num_expert_tokens, tensors.num_expert_tokens,
A_scale, A_scale,
B_scale, B_scale,
block_shape[-2:]) block_shape[-2:])
ref_output2 = ref_impl(tensors.A,
tensors.B,
ref_output2,
tensors.num_expert_tokens,
A_scale,
B_scale,
block_shape[-2:])
rtol, atol = { rtol, atol = {
torch.torch.float8_e4m3fn: (6e-2, 6e-2),
torch.float16: (6e-2, 6e-2), torch.float16: (6e-2, 6e-2),
torch.bfloat16: (6e-2, 6e-2), torch.bfloat16: (6e-2, 6e-2),
torch.float32: (1e-2, 1e-2), torch.float32: (1e-2, 1e-2),
}[test_output.dtype] }[test_output.dtype]
torch.testing.assert_close(test_output, ref_output, atol=atol, rtol=rtol) torch.testing.assert_close(ref_output, ref_output2, atol=atol, rtol=rtol)
if not use_fp8_w8a8:
torch.testing.assert_close(test_output, ref_output2, atol=atol, rtol=rtol)