basic working test

Signed-off-by: Bill Nell <bnell@redhat.com>
This commit is contained in:
Bill Nell 2025-05-21 23:20:54 +00:00
parent 77f95b99a6
commit 9cfebf51ba
2 changed files with 54 additions and 51 deletions

View File

@ -198,9 +198,11 @@ def test_batched_mm(num_experts: int, max_tokens_per_expert: int, K: int,
#B_scale = torch.ones((N, K), dtype=torch.float32, device=tensors.A.device)
A_scale = torch.ones(1, dtype=torch.float32, device=tensors.A.device)
B_scale = torch.ones(1, dtype=torch.float32, device=tensors.B.device)
quant_block_shape = [1, 1]
else:
A_scale = None
B_scale = None
quant_block_shape = None
invoke_moe_batched_triton_kernel(
tensors.A,
@ -220,7 +222,9 @@ def test_batched_mm(num_experts: int, max_tokens_per_expert: int, K: int,
"BLOCK_SIZE_M": block_shape[0],
"BLOCK_SIZE_N": block_shape[1],
"BLOCK_SIZE_K": block_shape[2],
})
},
block_shape=quant_block_shape,
)
ref_output = ref_output.to(dtype=out_dtype)
ref_output = ref_impl(tensors.A.to(dtype=out_dtype),
@ -246,5 +250,4 @@ def test_batched_mm(num_experts: int, max_tokens_per_expert: int, K: int,
}[test_output.dtype]
torch.testing.assert_close(ref_output, ref_output2, atol=atol, rtol=rtol)
if not use_fp8_w8a8:
torch.testing.assert_close(test_output, ref_output2, atol=atol, rtol=rtol)
torch.testing.assert_close(test_output, ref_output2, atol=atol, rtol=rtol)

View File

@ -15,38 +15,38 @@ from vllm.model_executor.layers.fused_moe.utils import (_fp8_quantize,
@triton.jit
def moe_mmk(
a_ptrs,
b_ptrs,
K,
expert_id,
a_scale_ptr,
b_scale_ptr,
# The stride variables represent how much to increase the ptr by when
# moving by 1 element in a particular dimension. E.g. `stride_am` is
# how much to increase `a_ptr` by to get the element one row down
# (A has M rows).
stride_ak,
stride_bk,
stride_asm,
stride_ask,
stride_bse,
stride_bsk,
stride_bsn,
# Offsets and masks
offs_m,
offs_n,
mask_m,
# Block size for block-wise quantization
group_n: tl.constexpr,
group_k: tl.constexpr,
# Meta-parameters
BLOCK_M: tl.constexpr,
BLOCK_N: tl.constexpr,
BLOCK_K: tl.constexpr,
compute_type: tl.constexpr,
use_w8a8: tl.constexpr,
use_w8a16: tl.constexpr):
a_ptrs,
b_ptrs,
K,
expert_id,
a_scale_ptr,
b_scale_ptr,
# The stride variables represent how much to increase the ptr by when
# moving by 1 element in a particular dimension. E.g. `stride_am` is
# how much to increase `a_ptr` by to get the element one row down
# (A has M rows).
stride_ak,
stride_bk,
stride_asm,
stride_ask,
stride_bse,
stride_bsk,
stride_bsn,
# Offsets and masks
offs_m,
offs_n,
mask_m,
# Block size for block-wise quantization
group_n: tl.constexpr,
group_k: tl.constexpr,
# Meta-parameters
BLOCK_M: tl.constexpr,
BLOCK_N: tl.constexpr,
BLOCK_K: tl.constexpr,
compute_type: tl.constexpr,
use_w8a8: tl.constexpr,
use_w8a16: tl.constexpr
):
offs_k = tl.arange(0, BLOCK_K)
if use_w8a16:
@ -310,22 +310,22 @@ def batched_triton_kernel(
def invoke_moe_batched_triton_kernel(
A: torch.Tensor, # [E, max_tokens, K]
B: torch.Tensor, # [E, K, N]
C: torch.Tensor, # [E, max_tokens, N]
expert_num_tokens: torch.Tensor, # [E]
compute_type: tl.dtype,
# Quantization data
A_scale: Optional[torch.Tensor],
B_scale: Optional[torch.Tensor],
B_zp: torch.Tensor,
# Quantization schemes
use_fp8_w8a8: bool,
use_int8_w8a16: bool,
use_int4_w4a16: bool,
config: dict[str, int],
block_shape: Optional[list[int]] = None):
A: torch.Tensor, # [E, max_tokens, K]
B: torch.Tensor, # [E, K, N]
C: torch.Tensor, # [E, max_tokens, N]
expert_num_tokens: torch.Tensor, # [E]
compute_type: tl.dtype,
# Quantization data
A_scale: Optional[torch.Tensor],
B_scale: Optional[torch.Tensor],
B_zp: torch.Tensor,
# Quantization schemes
use_fp8_w8a8: bool,
use_int8_w8a16: bool,
use_int4_w4a16: bool,
config: dict[str, int],
block_shape: Optional[list[int]] = None
):
assert not use_int4_w4a16
max_num_tokens = A.size(1)
K = A.size(2)