Use pytest format for unit tests (#107)

parent b322fd1607
commit 825d8892b5
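
In each test file, the pattern of the change is the same: the old parameterized test_* function becomes a run_* helper (decorated with @torch.inference_mode()), and the old "if __name__ == '__main__':" driver loop becomes a parameterless test_* function that pytest can collect and run. A minimal, self-contained sketch of the resulting layout, using the silu_and_mul names from this diff; the pure-Python reference below stands in for the CUDA kernel call, and the parameter sweep is trimmed, so the sketch runs anywhere:

import torch
import torch.nn.functional as F


def ref_silu_and_mul(x: torch.Tensor) -> torch.Tensor:
    # Reference: SiLU on the first half of the last dim, gated by the second.
    d = x.shape[-1] // 2
    return F.silu(x[..., :d]) * x[..., d:]


@torch.inference_mode()
def run_silu_and_mul(num_tokens: int, d: int, dtype: torch.dtype) -> None:
    # run_* helpers hold the per-configuration setup and the assertion.
    x = torch.randn(num_tokens, 2 * d, dtype=dtype)
    out = ref_silu_and_mul(x)   # the real test calls the CUDA kernel here
    ref_out = ref_silu_and_mul(x)
    assert torch.allclose(out, ref_out, atol=1e-5, rtol=1e-5)


def test_silu_and_mul() -> None:
    # test_* wrappers are what pytest collects; they sweep the configurations.
    for dtype in [torch.float]:
        for num_tokens in [7, 83]:
            for d in [128, 512]:
                run_silu_and_mul(num_tokens, d, dtype)
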
@@ -10,7 +10,7 @@ def ref_silu_and_mul(x: torch.Tensor) -> torch.Tensor:
 
 
 @torch.inference_mode()
-def test_silu_and_mul(
+def run_silu_and_mul(
     num_tokens: int,
     d: int,
     dtype: torch.dtype,
@@ -22,9 +22,9 @@ def test_silu_and_mul(
     assert torch.allclose(out, ref_out, atol=1e-5, rtol=1e-5)
 
 
-if __name__ == '__main__':
+def test_silu_and_mul() -> None:
     for dtype in [torch.half, torch.bfloat16, torch.float]:
         for num_tokens in [7, 83, 2048]:
-            for d in [512, 4096, 13824]:
+            for d in [512, 4096, 5120, 13824]:
                 print(f'Testing dtype={dtype}, num_tokens={num_tokens}, d={d}')
-                test_silu_and_mul(num_tokens, d, dtype)
+                run_silu_and_mul(num_tokens, d, dtype)
@@ -8,6 +8,7 @@ from xformers.ops.fmha.attn_bias import BlockDiagonalCausalMask
 from cacheflow import attention_ops
 
 MAX_SEQ_LEN = 4096
+TEST_SEED = 0
 
 
 def ref_masked_attention(
@@ -155,7 +156,8 @@ def ref_multi_query_cached_kv_attention(
     return ref_output
 
 
-def test_single_query_cached_kv_attention(
+@torch.inference_mode()
+def run_single_query_cached_kv_attention(
     num_tokens: int,
     num_heads: int,
     head_size: int,
@@ -223,7 +225,8 @@ def test_single_query_cached_kv_attention(
     assert torch.allclose(output, ref_output, atol=1e-3, rtol=1e-5)
 
 
-def test_multi_query_kv_attention(
+@torch.inference_mode()
+def run_multi_query_kv_attention(
     num_seqs: int,
     num_heads: int,
     head_size: int,
@@ -264,19 +267,16 @@ def test_multi_query_kv_attention(
     assert torch.allclose(output, ref_output, atol=1e-3, rtol=1e-5)
 
 
-@torch.inference_mode()
-def test_attention(seed: int) -> None:
-    # NOTE(woosuk): Even when the seed is fixed, there is a chance that
-    # the test fails due to the precision issue. Re-run the test if it fails.
-    torch.random.manual_seed(seed)
-    torch.cuda.manual_seed(seed)
+def test_single_query_cached_kv_attention() -> None:
+    torch.random.manual_seed(TEST_SEED)
+    torch.cuda.manual_seed(TEST_SEED)
     for dtype in [torch.half, torch.bfloat16]:
         for block_size in [8, 16, 32, 64]:
             for head_size in [32, 64, 80, 96, 128, 160, 192, 256]:
                 print(f'Testing single_query_cached_kv_attention with '
                       f'dtype={dtype}, block_size={block_size}, '
                       f'head_size={head_size}')
-                test_single_query_cached_kv_attention(
+                run_single_query_cached_kv_attention(
                     num_tokens=37,
                     num_heads=3,
                     head_size=head_size,
@@ -285,17 +285,17 @@ def test_attention(seed: int) -> None:
                     dtype=dtype,
                 )
 
+
+def test_multi_query_kv_attention() -> None:
+    torch.random.manual_seed(TEST_SEED)
+    torch.cuda.manual_seed(TEST_SEED)
     for dtype in [torch.half, torch.bfloat16]:
         for head_size in [32, 64, 80, 96, 128, 160, 192, 256]:
             print(f'Testing multi_query_kv_attention with dtype={dtype}, '
                   f'head_size={head_size}')
-            test_multi_query_kv_attention(
+            run_multi_query_kv_attention(
                 num_seqs=5,
                 num_heads=3,
                 head_size=head_size,
                 dtype=dtype,
             )
-
-
-if __name__ == '__main__':
-    test_attention(seed=0)
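
Two things change in the attention tests besides the renames: the run_* helpers now carry @torch.inference_mode() themselves, and the RNG seeding moves from the removed test_attention(seed) driver into each new test_* function via the module-level TEST_SEED added above, so every test is reproducible when pytest runs it in isolation. For reference, a small self-contained sketch of what the decorator does (illustrative only, not part of this diff):

import torch


@torch.inference_mode()
def run_example() -> None:
    # Everything inside runs with autograd disabled; tensors created here are
    # inference tensors and never record gradient history.
    x = torch.ones(3)
    y = x * 2
    assert y.is_inference()
    assert not y.requires_grad


def test_inference_mode_example() -> None:
    run_example()
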
@@ -5,7 +5,8 @@ import torch
 from cacheflow import cache_ops
 
 
-def test_copy_blocks(
+@torch.inference_mode()
+def run_copy_blocks(
     num_mappings: int,
     num_layers: int,
     num_heads: int,
@@ -60,7 +61,8 @@ def test_copy_blocks(
     assert torch.allclose(value_cache, cloned_value_cache)
 
 
-def test_reshape_and_cache(
+@torch.inference_mode()
+def run_reshape_and_cache(
    num_tokens: int,
     num_heads: int,
     head_size: int,
@@ -99,7 +101,8 @@ def test_reshape_and_cache(
     assert torch.allclose(value_cache, cloned_value_cache)
 
 
-def test_gather_cached_kv(
+@torch.inference_mode()
+def run_gather_cached_kv(
     num_tokens: int,
     num_heads: int,
     head_size: int,
@@ -140,19 +143,22 @@ def test_gather_cached_kv(
     assert torch.allclose(value, cloned_value)
 
 
-@torch.inference_mode()
-def test_cache() -> None:
+def test_copy_blocks() -> None:
     for dtype in [torch.half, torch.bfloat16, torch.float]:
-        test_copy_blocks(
+        run_copy_blocks(
             num_mappings=23, num_layers=7, num_heads=17, head_size=16,
             block_size=8, num_blocks=1024, dtype=dtype)
-        test_reshape_and_cache(
-            num_tokens=3, num_heads=2, head_size=16, block_size=8, num_blocks=2,
-            dtype=dtype)
-        test_gather_cached_kv(
+
+
+def test_reshape_and_cache() -> None:
+    for dtype in [torch.half, torch.bfloat16, torch.float]:
+        run_reshape_and_cache(
+            num_tokens=3, num_heads=2, head_size=16, block_size=8, num_blocks=2,
+            dtype=dtype)
+
+
+def test_gather_cached_kv() -> None:
+    for dtype in [torch.half, torch.bfloat16, torch.float]:
+        run_gather_cached_kv(
             num_tokens=3, num_heads=2, head_size=16, block_size=8, num_blocks=2,
             dtype=dtype)
-
-
-if __name__ == '__main__':
-    test_cache()
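
The single test_cache() driver is likewise split into three independent test_* functions, one per kernel, so a failure in one no longer masks the others. This commit keeps plain for-loops over dtypes inside each test; an equivalent pytest idiom that it does not use would be pytest.mark.parametrize, which turns every combination into its own collected test case. A self-contained sketch with a stand-in computation (none of these names appear in the diff):

import pytest
import torch


def ref_double(x: torch.Tensor) -> torch.Tensor:
    # Reference result the stand-in "kernel" below is compared against.
    return x + x


@pytest.mark.parametrize("num_tokens", [7, 83])
@pytest.mark.parametrize("dtype", [torch.float, torch.bfloat16])
def test_double(num_tokens: int, dtype: torch.dtype) -> None:
    # pytest expands the two parametrize decorators into one test per
    # (dtype, num_tokens) pair, each reported and failing independently.
    x = torch.randn(num_tokens, 16, dtype=dtype)
    out = x * 2   # stand-in for the CUDA kernel under test
    assert torch.allclose(out, ref_double(x), atol=1e-5, rtol=1e-5)
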
@@ -22,7 +22,7 @@ class RefRMSNorm(nn.Module):
 
 
 @torch.inference_mode()
-def test_rms_norm(
+def run_rms_norm(
     num_tokens: int,
     hidden_size: int,
     dtype: torch.dtype,
@@ -41,13 +41,13 @@ def test_rms_norm(
     assert torch.allclose(out, ref_out, atol=1e-3, rtol=1e-5)
 
 
-if __name__ == '__main__':
+def test_rms_norm() -> None:
     for dtype in [torch.half, torch.bfloat16, torch.float]:
         for num_tokens in [7, 128, 2048]:
             for hidden_size in [13, 64, 1024, 5120]:
                 print(f'Testing RMS kernel with dtype={dtype}, num_tokens='
                       f'{num_tokens}, hidden_size={hidden_size}')
-                test_rms_norm(
+                run_rms_norm(
                     num_tokens=num_tokens,
                     hidden_size=hidden_size,
                     dtype=dtype,
@@ -76,7 +76,7 @@ class RefRotaryEmbeddingNeox(nn.Module):
 
 
 @torch.inference_mode()
-def test_rotary_embedding_neox(
+def run_rotary_embedding_neox(
     num_tokens: int,
     num_heads: int,
     head_size: int,
@@ -128,15 +128,15 @@ def test_rotary_embedding_neox(
     assert torch.allclose(out_key, ref_key, atol=1e-3, rtol=1e-5)
 
 
-if __name__ == '__main__':
+def test_rotary_embedding_neox() -> None:
     for dtype in [torch.half, torch.bfloat16, torch.float]:
         for head_size in [32, 64, 80, 96, 128, 160, 192, 256]:
             print(f'Running tests for head_size={head_size} and dtype={dtype}')
-            test_rotary_embedding_neox(
+            run_rotary_embedding_neox(
                 num_tokens=2145,
                 num_heads=5,
                 head_size=head_size,
                 max_position=8192,
-                rotary_dim=int(head_size * 0.25),
+                rotary_dim=head_size,
                 dtype=dtype,
             )