mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-04-12 10:27:04 +08:00
basic working test
Signed-off-by: Bill Nell <bnell@redhat.com>
This commit is contained in:
parent
77f95b99a6
commit
9cfebf51ba
@ -198,9 +198,11 @@ def test_batched_mm(num_experts: int, max_tokens_per_expert: int, K: int,
|
||||
#B_scale = torch.ones((N, K), dtype=torch.float32, device=tensors.A.device)
|
||||
A_scale = torch.ones(1, dtype=torch.float32, device=tensors.A.device)
|
||||
B_scale = torch.ones(1, dtype=torch.float32, device=tensors.B.device)
|
||||
quant_block_shape = [1, 1]
|
||||
else:
|
||||
A_scale = None
|
||||
B_scale = None
|
||||
quant_block_shape = None
|
||||
|
||||
invoke_moe_batched_triton_kernel(
|
||||
tensors.A,
|
||||
@ -220,7 +222,9 @@ def test_batched_mm(num_experts: int, max_tokens_per_expert: int, K: int,
|
||||
"BLOCK_SIZE_M": block_shape[0],
|
||||
"BLOCK_SIZE_N": block_shape[1],
|
||||
"BLOCK_SIZE_K": block_shape[2],
|
||||
})
|
||||
},
|
||||
block_shape=quant_block_shape,
|
||||
)
|
||||
|
||||
ref_output = ref_output.to(dtype=out_dtype)
|
||||
ref_output = ref_impl(tensors.A.to(dtype=out_dtype),
|
||||
@ -246,5 +250,4 @@ def test_batched_mm(num_experts: int, max_tokens_per_expert: int, K: int,
|
||||
}[test_output.dtype]
|
||||
|
||||
torch.testing.assert_close(ref_output, ref_output2, atol=atol, rtol=rtol)
|
||||
if not use_fp8_w8a8:
|
||||
torch.testing.assert_close(test_output, ref_output2, atol=atol, rtol=rtol)
|
||||
torch.testing.assert_close(test_output, ref_output2, atol=atol, rtol=rtol)
|
||||
|
||||
@ -15,38 +15,38 @@ from vllm.model_executor.layers.fused_moe.utils import (_fp8_quantize,
|
||||
|
||||
@triton.jit
|
||||
def moe_mmk(
|
||||
a_ptrs,
|
||||
b_ptrs,
|
||||
K,
|
||||
expert_id,
|
||||
a_scale_ptr,
|
||||
b_scale_ptr,
|
||||
# The stride variables represent how much to increase the ptr by when
|
||||
# moving by 1 element in a particular dimension. E.g. `stride_am` is
|
||||
# how much to increase `a_ptr` by to get the element one row down
|
||||
# (A has M rows).
|
||||
stride_ak,
|
||||
stride_bk,
|
||||
stride_asm,
|
||||
stride_ask,
|
||||
stride_bse,
|
||||
stride_bsk,
|
||||
stride_bsn,
|
||||
# Offsets and masks
|
||||
offs_m,
|
||||
offs_n,
|
||||
mask_m,
|
||||
# Block size for block-wise quantization
|
||||
group_n: tl.constexpr,
|
||||
group_k: tl.constexpr,
|
||||
# Meta-parameters
|
||||
BLOCK_M: tl.constexpr,
|
||||
BLOCK_N: tl.constexpr,
|
||||
BLOCK_K: tl.constexpr,
|
||||
compute_type: tl.constexpr,
|
||||
use_w8a8: tl.constexpr,
|
||||
use_w8a16: tl.constexpr):
|
||||
|
||||
a_ptrs,
|
||||
b_ptrs,
|
||||
K,
|
||||
expert_id,
|
||||
a_scale_ptr,
|
||||
b_scale_ptr,
|
||||
# The stride variables represent how much to increase the ptr by when
|
||||
# moving by 1 element in a particular dimension. E.g. `stride_am` is
|
||||
# how much to increase `a_ptr` by to get the element one row down
|
||||
# (A has M rows).
|
||||
stride_ak,
|
||||
stride_bk,
|
||||
stride_asm,
|
||||
stride_ask,
|
||||
stride_bse,
|
||||
stride_bsk,
|
||||
stride_bsn,
|
||||
# Offsets and masks
|
||||
offs_m,
|
||||
offs_n,
|
||||
mask_m,
|
||||
# Block size for block-wise quantization
|
||||
group_n: tl.constexpr,
|
||||
group_k: tl.constexpr,
|
||||
# Meta-parameters
|
||||
BLOCK_M: tl.constexpr,
|
||||
BLOCK_N: tl.constexpr,
|
||||
BLOCK_K: tl.constexpr,
|
||||
compute_type: tl.constexpr,
|
||||
use_w8a8: tl.constexpr,
|
||||
use_w8a16: tl.constexpr
|
||||
):
|
||||
offs_k = tl.arange(0, BLOCK_K)
|
||||
|
||||
if use_w8a16:
|
||||
@ -310,22 +310,22 @@ def batched_triton_kernel(
|
||||
|
||||
|
||||
def invoke_moe_batched_triton_kernel(
|
||||
A: torch.Tensor, # [E, max_tokens, K]
|
||||
B: torch.Tensor, # [E, K, N]
|
||||
C: torch.Tensor, # [E, max_tokens, N]
|
||||
expert_num_tokens: torch.Tensor, # [E]
|
||||
compute_type: tl.dtype,
|
||||
# Quantization data
|
||||
A_scale: Optional[torch.Tensor],
|
||||
B_scale: Optional[torch.Tensor],
|
||||
B_zp: torch.Tensor,
|
||||
# Quantization schemes
|
||||
use_fp8_w8a8: bool,
|
||||
use_int8_w8a16: bool,
|
||||
use_int4_w4a16: bool,
|
||||
config: dict[str, int],
|
||||
block_shape: Optional[list[int]] = None):
|
||||
|
||||
A: torch.Tensor, # [E, max_tokens, K]
|
||||
B: torch.Tensor, # [E, K, N]
|
||||
C: torch.Tensor, # [E, max_tokens, N]
|
||||
expert_num_tokens: torch.Tensor, # [E]
|
||||
compute_type: tl.dtype,
|
||||
# Quantization data
|
||||
A_scale: Optional[torch.Tensor],
|
||||
B_scale: Optional[torch.Tensor],
|
||||
B_zp: torch.Tensor,
|
||||
# Quantization schemes
|
||||
use_fp8_w8a8: bool,
|
||||
use_int8_w8a16: bool,
|
||||
use_int4_w4a16: bool,
|
||||
config: dict[str, int],
|
||||
block_shape: Optional[list[int]] = None
|
||||
):
|
||||
assert not use_int4_w4a16
|
||||
max_num_tokens = A.size(1)
|
||||
K = A.size(2)
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user