mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-05-27 17:47:52 +08:00
basic working test
Signed-off-by: Bill Nell <bnell@redhat.com>
This commit is contained in:
parent
77f95b99a6
commit
9cfebf51ba
@ -198,9 +198,11 @@ def test_batched_mm(num_experts: int, max_tokens_per_expert: int, K: int,
|
|||||||
#B_scale = torch.ones((N, K), dtype=torch.float32, device=tensors.A.device)
|
#B_scale = torch.ones((N, K), dtype=torch.float32, device=tensors.A.device)
|
||||||
A_scale = torch.ones(1, dtype=torch.float32, device=tensors.A.device)
|
A_scale = torch.ones(1, dtype=torch.float32, device=tensors.A.device)
|
||||||
B_scale = torch.ones(1, dtype=torch.float32, device=tensors.B.device)
|
B_scale = torch.ones(1, dtype=torch.float32, device=tensors.B.device)
|
||||||
|
quant_block_shape = [1, 1]
|
||||||
else:
|
else:
|
||||||
A_scale = None
|
A_scale = None
|
||||||
B_scale = None
|
B_scale = None
|
||||||
|
quant_block_shape = None
|
||||||
|
|
||||||
invoke_moe_batched_triton_kernel(
|
invoke_moe_batched_triton_kernel(
|
||||||
tensors.A,
|
tensors.A,
|
||||||
@ -220,7 +222,9 @@ def test_batched_mm(num_experts: int, max_tokens_per_expert: int, K: int,
|
|||||||
"BLOCK_SIZE_M": block_shape[0],
|
"BLOCK_SIZE_M": block_shape[0],
|
||||||
"BLOCK_SIZE_N": block_shape[1],
|
"BLOCK_SIZE_N": block_shape[1],
|
||||||
"BLOCK_SIZE_K": block_shape[2],
|
"BLOCK_SIZE_K": block_shape[2],
|
||||||
})
|
},
|
||||||
|
block_shape=quant_block_shape,
|
||||||
|
)
|
||||||
|
|
||||||
ref_output = ref_output.to(dtype=out_dtype)
|
ref_output = ref_output.to(dtype=out_dtype)
|
||||||
ref_output = ref_impl(tensors.A.to(dtype=out_dtype),
|
ref_output = ref_impl(tensors.A.to(dtype=out_dtype),
|
||||||
@ -246,5 +250,4 @@ def test_batched_mm(num_experts: int, max_tokens_per_expert: int, K: int,
|
|||||||
}[test_output.dtype]
|
}[test_output.dtype]
|
||||||
|
|
||||||
torch.testing.assert_close(ref_output, ref_output2, atol=atol, rtol=rtol)
|
torch.testing.assert_close(ref_output, ref_output2, atol=atol, rtol=rtol)
|
||||||
if not use_fp8_w8a8:
|
torch.testing.assert_close(test_output, ref_output2, atol=atol, rtol=rtol)
|
||||||
torch.testing.assert_close(test_output, ref_output2, atol=atol, rtol=rtol)
|
|
||||||
|
|||||||
@ -15,38 +15,38 @@ from vllm.model_executor.layers.fused_moe.utils import (_fp8_quantize,
|
|||||||
|
|
||||||
@triton.jit
|
@triton.jit
|
||||||
def moe_mmk(
|
def moe_mmk(
|
||||||
a_ptrs,
|
a_ptrs,
|
||||||
b_ptrs,
|
b_ptrs,
|
||||||
K,
|
K,
|
||||||
expert_id,
|
expert_id,
|
||||||
a_scale_ptr,
|
a_scale_ptr,
|
||||||
b_scale_ptr,
|
b_scale_ptr,
|
||||||
# The stride variables represent how much to increase the ptr by when
|
# The stride variables represent how much to increase the ptr by when
|
||||||
# moving by 1 element in a particular dimension. E.g. `stride_am` is
|
# moving by 1 element in a particular dimension. E.g. `stride_am` is
|
||||||
# how much to increase `a_ptr` by to get the element one row down
|
# how much to increase `a_ptr` by to get the element one row down
|
||||||
# (A has M rows).
|
# (A has M rows).
|
||||||
stride_ak,
|
stride_ak,
|
||||||
stride_bk,
|
stride_bk,
|
||||||
stride_asm,
|
stride_asm,
|
||||||
stride_ask,
|
stride_ask,
|
||||||
stride_bse,
|
stride_bse,
|
||||||
stride_bsk,
|
stride_bsk,
|
||||||
stride_bsn,
|
stride_bsn,
|
||||||
# Offsets and masks
|
# Offsets and masks
|
||||||
offs_m,
|
offs_m,
|
||||||
offs_n,
|
offs_n,
|
||||||
mask_m,
|
mask_m,
|
||||||
# Block size for block-wise quantization
|
# Block size for block-wise quantization
|
||||||
group_n: tl.constexpr,
|
group_n: tl.constexpr,
|
||||||
group_k: tl.constexpr,
|
group_k: tl.constexpr,
|
||||||
# Meta-parameters
|
# Meta-parameters
|
||||||
BLOCK_M: tl.constexpr,
|
BLOCK_M: tl.constexpr,
|
||||||
BLOCK_N: tl.constexpr,
|
BLOCK_N: tl.constexpr,
|
||||||
BLOCK_K: tl.constexpr,
|
BLOCK_K: tl.constexpr,
|
||||||
compute_type: tl.constexpr,
|
compute_type: tl.constexpr,
|
||||||
use_w8a8: tl.constexpr,
|
use_w8a8: tl.constexpr,
|
||||||
use_w8a16: tl.constexpr):
|
use_w8a16: tl.constexpr
|
||||||
|
):
|
||||||
offs_k = tl.arange(0, BLOCK_K)
|
offs_k = tl.arange(0, BLOCK_K)
|
||||||
|
|
||||||
if use_w8a16:
|
if use_w8a16:
|
||||||
@ -310,22 +310,22 @@ def batched_triton_kernel(
|
|||||||
|
|
||||||
|
|
||||||
def invoke_moe_batched_triton_kernel(
|
def invoke_moe_batched_triton_kernel(
|
||||||
A: torch.Tensor, # [E, max_tokens, K]
|
A: torch.Tensor, # [E, max_tokens, K]
|
||||||
B: torch.Tensor, # [E, K, N]
|
B: torch.Tensor, # [E, K, N]
|
||||||
C: torch.Tensor, # [E, max_tokens, N]
|
C: torch.Tensor, # [E, max_tokens, N]
|
||||||
expert_num_tokens: torch.Tensor, # [E]
|
expert_num_tokens: torch.Tensor, # [E]
|
||||||
compute_type: tl.dtype,
|
compute_type: tl.dtype,
|
||||||
# Quantization data
|
# Quantization data
|
||||||
A_scale: Optional[torch.Tensor],
|
A_scale: Optional[torch.Tensor],
|
||||||
B_scale: Optional[torch.Tensor],
|
B_scale: Optional[torch.Tensor],
|
||||||
B_zp: torch.Tensor,
|
B_zp: torch.Tensor,
|
||||||
# Quantization schemes
|
# Quantization schemes
|
||||||
use_fp8_w8a8: bool,
|
use_fp8_w8a8: bool,
|
||||||
use_int8_w8a16: bool,
|
use_int8_w8a16: bool,
|
||||||
use_int4_w4a16: bool,
|
use_int4_w4a16: bool,
|
||||||
config: dict[str, int],
|
config: dict[str, int],
|
||||||
block_shape: Optional[list[int]] = None):
|
block_shape: Optional[list[int]] = None
|
||||||
|
):
|
||||||
assert not use_int4_w4a16
|
assert not use_int4_w4a16
|
||||||
max_num_tokens = A.size(1)
|
max_num_tokens = A.size(1)
|
||||||
K = A.size(2)
|
K = A.size(2)
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user