mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-05-19 14:47:06 +08:00
wip
Signed-off-by: Bill Nell <bnell@redhat.com>
This commit is contained in:
parent
25ed6738d4
commit
bbe888d033
@ -136,16 +136,18 @@ def ref_impl(
|
|||||||
for e in range(num_experts):
|
for e in range(num_experts):
|
||||||
num_tokens = num_expert_tokens_cpu[e]
|
num_tokens = num_expert_tokens_cpu[e]
|
||||||
if A.dtype == torch.torch.float8_e4m3fn:
|
if A.dtype == torch.torch.float8_e4m3fn:
|
||||||
C[e, :, :] = native_w8a8_block_matmul(A[e, :, :],
|
tmp = native_w8a8_block_matmul(A[e, :, :],
|
||||||
B[e].transpose(0, 1),
|
B[e].transpose(0, 1),
|
||||||
A_scale,
|
A_scale,
|
||||||
B_scale,
|
B_scale,
|
||||||
[1,1])#block_shape)
|
[1,1])#block_shape)
|
||||||
|
C[e, :num_tokens, :] = tmp[:num_tokens, :]
|
||||||
else:
|
else:
|
||||||
C[e, :num_tokens, :] = A[e, :num_tokens, :] @ B[e].transpose(0, 1)
|
C[e, :num_tokens, :] = A[e, :num_tokens, :] @ B[e].transpose(0, 1)
|
||||||
|
|
||||||
return C
|
return C
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("num_experts", [16, 32])
|
@pytest.mark.parametrize("num_experts", [16, 32])
|
||||||
@pytest.mark.parametrize("max_tokens_per_expert",
|
@pytest.mark.parametrize("max_tokens_per_expert",
|
||||||
[32, 64, 128, 192, 224, 256, 512])
|
[32, 64, 128, 192, 224, 256, 512])
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user