Categorize tests/kernels/ based on kernel type (#16799)

Signed-off-by: mgoin <mgoin64@gmail.com>
This commit is contained in:
Michael Goin 2025-04-23 07:21:07 -06:00 committed by GitHub
parent aa72d9a4ea
commit 6317a5174a
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
55 changed files with 80 additions and 49 deletions

View File

@ -16,7 +16,7 @@ import numpy
import pytest
import yaml
RTOL = 0.05
RTOL = 0.08
TEST_DATA_FILE = os.environ.get(
"LM_EVAL_TEST_DATA_FILE",
".buildkite/lm-eval-harness/configs/Meta-Llama-3-8B-Instruct.yaml")

View File

@ -317,15 +317,46 @@ steps:
commands:
- pytest -v -s compile/test_full_graph.py
- label: Kernels Test %N # 1h each
mirror_hardwares: [amd]
- label: Kernels Core Operation Test
source_file_dependencies:
- csrc/
- vllm/attention
- tests/kernels
- tests/kernels/core
commands:
- pytest -v -s kernels --shard-id=$$BUILDKITE_PARALLEL_JOB --num-shards=$$BUILDKITE_PARALLEL_JOB_COUNT
parallelism: 4
- pytest -v -s kernels/core
- label: Kernels Attention Test %N
source_file_dependencies:
- csrc/attention/
- vllm/attention
- vllm/v1/attention
- tests/kernels/attention
commands:
- pytest -v -s kernels/attention --shard-id=$$BUILDKITE_PARALLEL_JOB --num-shards=$$BUILDKITE_PARALLEL_JOB_COUNT
parallelism: 2
- label: Kernels Quantization Test %N
source_file_dependencies:
- csrc/quantization/
- vllm/model_executor/layers/quantization
- tests/kernels/quantization
commands:
- pytest -v -s kernels/quantization --shard-id=$$BUILDKITE_PARALLEL_JOB --num-shards=$$BUILDKITE_PARALLEL_JOB_COUNT
parallelism: 2
- label: Kernels MoE Test
source_file_dependencies:
- csrc/moe/
- tests/kernels/moe
- vllm/model_executor/layers/fused_moe/
commands:
- pytest -v -s kernels/moe
- label: Kernels Mamba Test
source_file_dependencies:
- csrc/mamba/
- tests/kernels/mamba
commands:
- pytest -v -s kernels/mamba
- label: Tensorizer Test # 11min
# mirror_hardwares: [amd]

View File

@ -6,13 +6,12 @@ from typing import Optional
import pytest
import torch
from tests.kernels.allclose_default import get_default_atol, get_default_rtol
from tests.kernels.utils import opcheck
from vllm import _custom_ops as ops
from vllm.platforms import current_platform
from vllm.utils import get_max_shared_memory_bytes
from .allclose_default import get_default_atol, get_default_rtol
if not current_platform.is_rocm():
from xformers import ops as xops
from xformers.ops.fmha.attn_bias import BlockDiagonalCausalMask

View File

@ -156,6 +156,15 @@ def test_env(
expected = ("TRITON_MLA_VLLM_V1"
if use_v1 else "TRITON_MLA")
assert backend.get_name() == expected
elif name == "FLASHINFER":
backend = get_attn_backend(16,
torch.float16,
torch.float16,
block_size,
False,
use_mla=use_mla)
expected = "FLASHINFER_VLLM_V1" if use_v1 else name
assert backend.get_name() == expected
else:
backend = get_attn_backend(16,
torch.float16,

View File

@ -6,14 +6,13 @@ from typing import Optional
import pytest
import torch
from tests.kernels.allclose_default import get_default_atol, get_default_rtol
from vllm import _custom_ops as ops
from vllm.attention.ops.blocksparse_attention.interface import (
LocalStridedBlockSparseAttn)
from vllm.platforms import current_platform
from vllm.utils import get_max_shared_memory_bytes
from .allclose_default import get_default_atol, get_default_rtol
FLOAT32_BYTES = torch.finfo(torch.float).bits // 8
# This will change depending on the compute capability.
# - 512 as a buffer

View File

@ -5,6 +5,7 @@ import random
import pytest
import torch
from tests.kernels.allclose_default import get_default_atol, get_default_rtol
from tests.kernels.utils import opcheck
from vllm.model_executor.layers.activation import (FastGELU, FatreluAndMul,
GeluAndMul, MulAndSilu,
@ -12,8 +13,6 @@ from vllm.model_executor.layers.activation import (FastGELU, FatreluAndMul,
SiluAndMul)
from vllm.platforms import current_platform
from .allclose_default import get_default_atol, get_default_rtol
DTYPES = [torch.half, torch.bfloat16, torch.float]
NUM_TOKENS = [7, 83, 2048] # Arbitrary values for testing
D = [512, 13824] # Arbitrary values for testing

View File

@ -0,0 +1,25 @@
# SPDX-License-Identifier: Apache-2.0
"""
Tests for miscellaneous utilities
"""
import torch
from tests.kernels.utils import opcheck
def test_convert_fp8_opcheck():
data = torch.randn((256, 256), dtype=torch.float32, device="cuda")
result = torch.empty_like(data, dtype=torch.float8_e4m3fn)
opcheck(torch.ops._C_cache_ops.convert_fp8, (result, data, 1.0, "fp8"))
# TODO: Add this back, currently fails with
# csrc/cuda_utils_kernels.cu:15 'invalid argument'
# @pytest.mark.skipif(not current_platform.is_cuda(),
# reason="Only supported for CUDA")
# def test_cuda_utils_opcheck():
# opcheck(torch.ops._C_cuda_utils.get_device_attribute, (0, 0))
# opcheck(
# torch.ops._C_cuda_utils.
# get_max_shared_memory_per_block_device_attribute, (0, ))

View File

@ -6,11 +6,10 @@ from typing import Callable, Optional
import pytest
import torch
from tests.kernels.allclose_default import get_default_atol, get_default_rtol
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.platforms import current_platform
from .allclose_default import get_default_atol, get_default_rtol
IS_NEOX_STYLE = [True, False]
DTYPES = [torch.half, torch.bfloat16, torch.float]
HEAD_SIZES = [64, 80, 112, 120, 256]

View File

@ -6,6 +6,7 @@ import itertools
import pytest
import torch
from tests.kernels.utils_block import native_w8a8_block_matmul
from vllm.config import VllmConfig, set_current_vllm_config
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.fused_moe import fused_moe
@ -18,8 +19,6 @@ from vllm.model_executor.layers.quantization.utils.fp8_utils import (
per_token_group_quant_fp8, w8a8_block_fp8_matmul)
from vllm.platforms import current_platform
from .utils_block import native_w8a8_block_matmul
dg_available = False
try:
import deep_gemm

View File

@ -6,6 +6,7 @@ import itertools
import pytest
import torch
from tests.kernels.utils_block import native_w8a8_block_matmul
from vllm.config import VllmConfig, set_current_vllm_config
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.fused_moe import fused_moe
@ -13,8 +14,6 @@ from vllm.model_executor.layers.quantization.utils.int8_utils import (
w8a8_block_int8_matmul)
from vllm.platforms import current_platform
from .utils_block import native_w8a8_block_matmul
if current_platform.get_device_capability() < (7, 0):
pytest.skip("INT8 Triton requires CUDA 7.0 or higher",
allow_module_level=True)

View File

@ -7,13 +7,12 @@ Run `pytest tests/kernels/test_semi_structured.py`.
import pytest
import torch
from tests.kernels.utils import baseline_scaled_mm, to_fp8, to_int8
from vllm import _custom_ops as ops
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
sparse_cutlass_supported)
from vllm.platforms import current_platform
from .utils import baseline_scaled_mm, to_fp8, to_int8
CUDA_DEVICES = [
f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2)
]

View File

@ -8,13 +8,11 @@ import random
import pytest
import torch
from tests.kernels.utils import opcheck
from tests.kernels.utils import baseline_scaled_mm, opcheck, to_fp8, to_int8
from vllm import _custom_ops as ops
from vllm.platforms import current_platform
from vllm.utils import cdiv
from .utils import baseline_scaled_mm, to_fp8, to_int8
MNK_FACTORS = [
(1, 256, 128),
(1, 16384, 1024),

View File

@ -1,25 +0,0 @@
# SPDX-License-Identifier: Apache-2.0
"""
Tests for miscellaneous utilities
"""
import pytest
import torch
from tests.kernels.utils import opcheck
from vllm.platforms import current_platform
def test_convert_fp8_opcheck():
data = torch.randn((256, 256), dtype=torch.float32, device="cuda")
result = torch.empty_like(data, dtype=torch.float8_e4m3fn)
opcheck(torch.ops._C_cache_ops.convert_fp8, (result, data, 1.0, "fp8"))
@pytest.mark.skipif(not current_platform.is_cuda(),
reason="Only supported for CUDA")
def test_cuda_utils_opcheck():
opcheck(torch.ops._C_cuda_utils.get_device_attribute, (0, 0))
opcheck(
torch.ops._C_cuda_utils.
get_max_shared_memory_per_block_device_attribute, (0, ))