Add missing rocm_skinny_gemms kernel test to CI (#17060)
Signed-off-by: mgoin <mgoin64@gmail.com>
commit 82e43b2d7e
parent 67309a1cb5
tests/kernels/quant_utils.py
@@ -87,3 +87,63 @@ def ref_dynamic_per_tensor_fp8_quant(x: torch.tensor) \
     ref_out = (as_float32_tensor(x) * ref_iscale).clamp(
         fp8_traits_min, fp8_traits_max).to(FP8_DTYPE)
     return ref_out, ref_scale.view((1, ))
+
+
+def native_w8a8_block_matmul(A: torch.Tensor, B: torch.Tensor,
+                             As: torch.Tensor, Bs: torch.Tensor, block_size,
+                             output_dtype):
+    """This function performs matrix multiplication with block-wise
+    quantization using native torch.
+    It is agnostic to the input data type and can be used for both int8 and
+    fp8 data types.
+
+    It takes two input tensors `A` and `B` (int8) with scales `As` and
+    `Bs` (float32).
+    The output is returned in the specified `output_dtype`.
+    """
+    A = A.to(torch.float32)
+    B = B.to(torch.float32)
+    assert A.shape[-1] == B.shape[-1]
+    assert B.ndim == 2 and B.is_contiguous() and Bs.ndim == 2
+    assert len(block_size) == 2
+    block_n, block_k = block_size[0], block_size[1]
+    assert (A.shape[-1] + block_k - 1) // block_k == As.shape[-1]
+    assert A.shape[:-1] == As.shape[:-1]
+
+    M = A.numel() // A.shape[-1]
+    N, K = B.shape
+    origin_C_shape = A.shape[:-1] + (N, )
+    A = A.reshape(M, A.shape[-1])
+    As = As.reshape(M, As.shape[-1])
+    n_tiles = (N + block_n - 1) // block_n
+    k_tiles = (K + block_k - 1) // block_k
+    assert n_tiles == Bs.shape[0]
+    assert k_tiles == Bs.shape[1]
+
+    C_shape = (M, N)
+    C = torch.zeros(C_shape, dtype=torch.float32, device=A.device)
+
+    A_tiles = [
+        A[:, i * block_k:min((i + 1) * block_k, K)] for i in range(k_tiles)
+    ]
+    B_tiles = [[
+        B[
+            j * block_n:min((j + 1) * block_n, N),
+            i * block_k:min((i + 1) * block_k, K),
+        ] for i in range(k_tiles)
+    ] for j in range(n_tiles)]
+    C_tiles = [
+        C[:, j * block_n:min((j + 1) * block_n, N)] for j in range(n_tiles)
+    ]
+    As_tiles = [As[:, i:i + 1] for i in range(k_tiles)]
+
+    for i in range(k_tiles):
+        for j in range(n_tiles):
+            a = A_tiles[i]
+            b = B_tiles[j][i]
+            c = C_tiles[j]
+            s = As_tiles[i] * Bs[j][i]
+            c[:, :] += torch.matmul(a, b.t()) * s
+
+    C = C.reshape(origin_C_shape).to(output_dtype)
+    return C
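For reference, below is a minimal sketch (not part of this commit) of how the relocated helper might be exercised. It assumes a vllm checkout where tests.kernels.quant_utils is importable; the shapes, block size, and random scales are made up for illustration.

import torch

from tests.kernels.quant_utils import native_w8a8_block_matmul

# Hypothetical shapes; K is deliberately not a multiple of block_k so the
# last K-tile is ragged.
M, K, N = 4, 200, 160
block_size = (128, 128)  # (block_n, block_k)
k_tiles = (K + block_size[1] - 1) // block_size[1]  # 2 K-blocks
n_tiles = (N + block_size[0] - 1) // block_size[0]  # 2 N-blocks

# Fake int8 activations/weights with one float32 scale per block.
A = torch.randint(-128, 127, (M, K), dtype=torch.int8)
B = torch.randint(-128, 127, (N, K), dtype=torch.int8)
As = torch.rand(M, k_tiles, dtype=torch.float32) * 1e-2
Bs = torch.rand(n_tiles, k_tiles, dtype=torch.float32) * 1e-2

# Inside the helper, each (K-block i, N-block j) tile is dequantized by
# As[:, i] * Bs[j, i] before being accumulated into the fp32 output.
out = native_w8a8_block_matmul(A, B, As, Bs, block_size, torch.float16)
print(out.shape)  # torch.Size([4, 160])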
@@ -6,7 +6,7 @@ import itertools
 import pytest
 import torch
 
-from tests.kernels.utils_block import native_w8a8_block_matmul
+from tests.kernels.quant_utils import native_w8a8_block_matmul
 from vllm.config import VllmConfig, set_current_vllm_config
 from vllm.model_executor.layers.activation import SiluAndMul
 from vllm.model_executor.layers.fused_moe import fused_moe
The same one-line import update appears in a second test file:

@@ -6,7 +6,7 @@ import itertools
 import pytest
 import torch
 
-from tests.kernels.utils_block import native_w8a8_block_matmul
+from tests.kernels.quant_utils import native_w8a8_block_matmul
 from vllm.config import VllmConfig, set_current_vllm_config
 from vllm.model_executor.layers.activation import SiluAndMul
 from vllm.model_executor.layers.fused_moe import fused_moe
tests/kernels/utils_block.py (deleted file)
@@ -1,63 +0,0 @@
-# SPDX-License-Identifier: Apache-2.0
-
-import torch
-
-
-def native_w8a8_block_matmul(A: torch.Tensor, B: torch.Tensor,
-                             As: torch.Tensor, Bs: torch.Tensor, block_size,
-                             output_dtype):
-    """This function performs matrix multiplication with block-wise
-    quantization using native torch.
-    It is agnostic to the input data type and can be used for both int8 and
-    fp8 data types.
-
-    It takes two input tensors `A` and `B` (int8) with scales `As` and
-    `Bs` (float32).
-    The output is returned in the specified `output_dtype`.
-    """
-    A = A.to(torch.float32)
-    B = B.to(torch.float32)
-    assert A.shape[-1] == B.shape[-1]
-    assert B.ndim == 2 and B.is_contiguous() and Bs.ndim == 2
-    assert len(block_size) == 2
-    block_n, block_k = block_size[0], block_size[1]
-    assert (A.shape[-1] + block_k - 1) // block_k == As.shape[-1]
-    assert A.shape[:-1] == As.shape[:-1]
-
-    M = A.numel() // A.shape[-1]
-    N, K = B.shape
-    origin_C_shape = A.shape[:-1] + (N, )
-    A = A.reshape(M, A.shape[-1])
-    As = As.reshape(M, As.shape[-1])
-    n_tiles = (N + block_n - 1) // block_n
-    k_tiles = (K + block_k - 1) // block_k
-    assert n_tiles == Bs.shape[0]
-    assert k_tiles == Bs.shape[1]
-
-    C_shape = (M, N)
-    C = torch.zeros(C_shape, dtype=torch.float32, device=A.device)
-
-    A_tiles = [
-        A[:, i * block_k:min((i + 1) * block_k, K)] for i in range(k_tiles)
-    ]
-    B_tiles = [[
-        B[
-            j * block_n:min((j + 1) * block_n, N),
-            i * block_k:min((i + 1) * block_k, K),
-        ] for i in range(k_tiles)
-    ] for j in range(n_tiles)]
-    C_tiles = [
-        C[:, j * block_n:min((j + 1) * block_n, N)] for j in range(n_tiles)
-    ]
-    As_tiles = [As[:, i:i + 1] for i in range(k_tiles)]
-
-    for i in range(k_tiles):
-        for j in range(n_tiles):
-            a = A_tiles[i]
-            b = B_tiles[j][i]
-            c = C_tiles[j]
-            s = As_tiles[i] * Bs[j][i]
-            c[:, :] += torch.matmul(a, b.t()) * s
-
-    C = C.reshape(origin_C_shape).to(output_dtype)
-    return C