Use pytest format for unit tests (#107)

Woosuk Kwon 2023-05-17 17:11:23 -07:00 committed by GitHub
parent b322fd1607
commit 825d8892b5
5 changed files with 43 additions and 37 deletions

View File

@@ -10,7 +10,7 @@ def ref_silu_and_mul(x: torch.Tensor) -> torch.Tensor:
@torch.inference_mode()
def test_silu_and_mul(
def run_silu_and_mul(
num_tokens: int,
d: int,
dtype: torch.dtype,
@@ -22,9 +22,9 @@ def test_silu_and_mul(
assert torch.allclose(out, ref_out, atol=1e-5, rtol=1e-5)
if __name__ == '__main__':
def test_silu_and_mul() -> None:
for dtype in [torch.half, torch.bfloat16, torch.float]:
for num_tokens in [7, 83, 2048]:
for d in [512, 4096, 13824]:
for d in [512, 4096, 5120, 13824]:
print(f'Testing dtype={dtype}, num_tokens={num_tokens}, d={d}')
test_silu_and_mul(num_tokens, d, dtype)
run_silu_and_mul(num_tokens, d, dtype)

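Note on the pattern above: the kernel check now lives in run_silu_and_mul, and the pytest-discoverable test_silu_and_mul drives it with nested loops over dtype, num_tokens, and d. The same coverage could also be written with pytest.mark.parametrize, which collects each combination as a separate test case. The sketch below is a hypothetical alternative, not what this commit does, and assumes run_silu_and_mul is in scope with the signature shown in the diff.

import pytest
import torch

@pytest.mark.parametrize("dtype", [torch.half, torch.bfloat16, torch.float])
@pytest.mark.parametrize("num_tokens", [7, 83, 2048])
@pytest.mark.parametrize("d", [512, 4096, 5120, 13824])
def test_silu_and_mul_parametrized(
    num_tokens: int,
    d: int,
    dtype: torch.dtype,
) -> None:
    # Each (dtype, num_tokens, d) combination is collected as its own test,
    # so a failure pinpoints the offending shape/dtype directly.
    # run_silu_and_mul is assumed to be the helper defined in the module above.
    run_silu_and_mul(num_tokens, d, dtype)
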
View File

@@ -8,6 +8,7 @@ from xformers.ops.fmha.attn_bias import BlockDiagonalCausalMask
from cacheflow import attention_ops
MAX_SEQ_LEN = 4096
TEST_SEED = 0
def ref_masked_attention(
@@ -155,7 +156,8 @@ def ref_multi_query_cached_kv_attention(
return ref_output
def test_single_query_cached_kv_attention(
@torch.inference_mode()
def run_single_query_cached_kv_attention(
num_tokens: int,
num_heads: int,
head_size: int,
@@ -223,7 +225,8 @@ def test_single_query_cached_kv_attention(
assert torch.allclose(output, ref_output, atol=1e-3, rtol=1e-5)
def test_multi_query_kv_attention(
@torch.inference_mode()
def run_multi_query_kv_attention(
num_seqs: int,
num_heads: int,
head_size: int,
@@ -264,19 +267,16 @@ def test_multi_query_kv_attention(
assert torch.allclose(output, ref_output, atol=1e-3, rtol=1e-5)
@torch.inference_mode()
def test_attention(seed: int) -> None:
# NOTE(woosuk): Even when the seed is fixed, there is a chance that
# the test fails due to the precision issue. Re-run the test if it fails.
torch.random.manual_seed(seed)
torch.cuda.manual_seed(seed)
def test_single_query_cached_kv_attention() -> None:
torch.random.manual_seed(TEST_SEED)
torch.cuda.manual_seed(TEST_SEED)
for dtype in [torch.half, torch.bfloat16]:
for block_size in [8, 16, 32, 64]:
for head_size in [32, 64, 80, 96, 128, 160, 192, 256]:
print(f'Testing single_query_cached_kv_attention with '
f'dtype={dtype}, block_size={block_size}, '
f'head_size={head_size}')
test_single_query_cached_kv_attention(
run_single_query_cached_kv_attention(
num_tokens=37,
num_heads=3,
head_size=head_size,
@@ -285,17 +285,17 @@ def test_attention(seed: int) -> None:
dtype=dtype,
)
def test_multi_query_kv_attention() -> None:
torch.random.manual_seed(TEST_SEED)
torch.cuda.manual_seed(TEST_SEED)
for dtype in [torch.half, torch.bfloat16]:
for head_size in [32, 64, 80, 96, 128, 160, 192, 256]:
print(f'Testing multi_query_kv_attention with dtype={dtype}, '
f'head_size={head_size}')
test_multi_query_kv_attention(
run_multi_query_kv_attention(
num_seqs=5,
num_heads=3,
head_size=head_size,
dtype=dtype,
)
if __name__ == '__main__':
test_attention(seed=0)

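Both new tests seed torch's CPU and CUDA RNGs with the module-level TEST_SEED before looping, since the kernels are compared against a reference implementation within fixed tolerances. An autouse pytest fixture is one way such seeding is often factored out; the sketch below is a hypothetical refactor, not part of this diff.

import pytest
import torch

TEST_SEED = 0

@pytest.fixture(autouse=True)
def fixed_seed() -> None:
    # Reset CPU and CUDA RNG state before every test in the module, so each
    # test sees the same random inputs regardless of execution order.
    torch.random.manual_seed(TEST_SEED)
    torch.cuda.manual_seed(TEST_SEED)
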
View File

@@ -5,7 +5,8 @@ import torch
from cacheflow import cache_ops
def test_copy_blocks(
@torch.inference_mode()
def run_copy_blocks(
num_mappings: int,
num_layers: int,
num_heads: int,
@@ -60,7 +61,8 @@ def test_copy_blocks(
assert torch.allclose(value_cache, cloned_value_cache)
def test_reshape_and_cache(
@torch.inference_mode()
def run_reshape_and_cache(
num_tokens: int,
num_heads: int,
head_size: int,
@@ -99,7 +101,8 @@ def test_reshape_and_cache(
assert torch.allclose(value_cache, cloned_value_cache)
def test_gather_cached_kv(
@torch.inference_mode()
def run_gather_cached_kv(
num_tokens: int,
num_heads: int,
head_size: int,
@@ -140,19 +143,22 @@ def test_gather_cached_kv(
assert torch.allclose(value, cloned_value)
@torch.inference_mode()
def test_cache() -> None:
def test_copy_blocks() -> None:
for dtype in [torch.half, torch.bfloat16, torch.float]:
test_copy_blocks(
run_copy_blocks(
num_mappings=23, num_layers=7, num_heads=17, head_size=16,
block_size=8, num_blocks=1024, dtype=dtype)
test_reshape_and_cache(
num_tokens=3, num_heads=2, head_size=16, block_size=8, num_blocks=2,
dtype=dtype)
test_gather_cached_kv(
def test_reshape_and_cache() -> None:
for dtype in [torch.half, torch.bfloat16, torch.float]:
run_reshape_and_cache(
num_tokens=3, num_heads=2, head_size=16, block_size=8, num_blocks=2,
dtype=dtype)
if __name__ == '__main__':
test_cache()
def test_gather_cached_kv() -> None:
for dtype in [torch.half, torch.bfloat16, torch.float]:
run_gather_cached_kv(
num_tokens=3, num_heads=2, head_size=16, block_size=8, num_blocks=2,
dtype=dtype)

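The removed __main__ block used to run test_cache() as a script; after the split into test_copy_blocks, test_reshape_and_cache, and test_gather_cached_kv, pytest discovers each one by name. A test module in this format can still be run directly by handing its own path to pytest; a minimal sketch of such a guard (not part of this diff):

import sys

import pytest

if __name__ == '__main__':
    # Run only the tests defined in this file and propagate pytest's exit code.
    sys.exit(pytest.main([__file__, "-q"]))
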
View File

@@ -22,7 +22,7 @@ class RefRMSNorm(nn.Module):
@torch.inference_mode()
def test_rms_norm(
def run_rms_norm(
num_tokens: int,
hidden_size: int,
dtype: torch.dtype,
@@ -41,13 +41,13 @@ def test_rms_norm(
assert torch.allclose(out, ref_out, atol=1e-3, rtol=1e-5)
if __name__ == '__main__':
def test_rms_norm() -> None:
for dtype in [torch.half, torch.bfloat16, torch.float]:
for num_tokens in [7, 128, 2048]:
for hidden_size in [13, 64, 1024, 5120]:
print(f'Testing RMS kernel with dtype={dtype}, num_tokens='
f'{num_tokens}, hidden_size={hidden_size}')
test_rms_norm(
run_rms_norm(
num_tokens=num_tokens,
hidden_size=hidden_size,
dtype=dtype,

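The RMSNorm check keeps the assert torch.allclose(out, ref_out, atol=1e-3, rtol=1e-5) pattern. torch.testing.assert_close is an equivalent check whose failure message reports which elements diverge and by how much, which pytest surfaces directly; the sketch below is an alternative, not what the diff uses, and the helper name is made up.

import torch

def assert_rms_norm_close(out: torch.Tensor, ref_out: torch.Tensor) -> None:
    # Same tolerances as the test above, but a failure prints a detailed
    # mismatch report instead of a bare AssertionError.
    torch.testing.assert_close(out, ref_out, atol=1e-3, rtol=1e-5)
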
View File

@@ -76,7 +76,7 @@ class RefRotaryEmbeddingNeox(nn.Module):
@torch.inference_mode()
def test_rotary_embedding_neox(
def run_rotary_embedding_neox(
num_tokens: int,
num_heads: int,
head_size: int,
@@ -128,15 +128,15 @@ def test_rotary_embedding_neox(
assert torch.allclose(out_key, ref_key, atol=1e-3, rtol=1e-5)
if __name__ == '__main__':
def test_rotary_embedding_neox() -> None:
for dtype in [torch.half, torch.bfloat16, torch.float]:
for head_size in [32, 64, 80, 96, 128, 160, 192, 256]:
print(f'Running tests for head_size={head_size} and dtype={dtype}')
test_rotary_embedding_neox(
run_rotary_embedding_neox(
num_tokens=2145,
num_heads=5,
head_size=head_size,
max_position=8192,
rotary_dim=int(head_size * 0.25),
rotary_dim=head_size,
dtype=dtype,
)
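
All of these kernel tests launch CUDA kernels, so the converted pytest functions only pass on a machine with a GPU. A skipif marker is a common way to keep CPU-only environments from reporting failures; the sketch below is an illustration under that assumption, not part of this commit, and assumes run_rotary_embedding_neox is in scope as defined in the diff above.

import pytest
import torch

requires_cuda = pytest.mark.skipif(
    not torch.cuda.is_available(), reason="needs a CUDA device")

@requires_cuda
def test_rotary_embedding_neox_small() -> None:
    # A single small configuration, skipped automatically on CPU-only hosts.
    # run_rotary_embedding_neox is assumed to be the helper from the module above.
    run_rotary_embedding_neox(
        num_tokens=16,
        num_heads=5,
        head_size=64,
        max_position=8192,
        rotary_dim=64,
        dtype=torch.half,
    )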