Clean up kernel unit tests (#938)

Woosuk Kwon 2023-09-06 08:57:38 +09:00 committed by GitHub
parent 22379d5513
commit fbd80ad409
6 changed files with 368 additions and 403 deletions

tests/kernels/conftest.py (new file)

@@ -0,0 +1,43 @@
from typing import List, Tuple

import pytest
import torch


def create_kv_caches(
    num_blocks: int,
    block_size: int,
    num_layers: int,
    num_heads: int,
    head_size: int,
    dtype: torch.dtype,
    seed: int,
) -> Tuple[List[torch.Tensor], List[torch.Tensor]]:
    torch.random.manual_seed(seed)
    torch.cuda.manual_seed(seed)

    scale = head_size**-0.5
    x = 16 // torch.tensor([], dtype=dtype).element_size()
    key_cache_shape = (num_blocks, num_heads, head_size // x, block_size, x)
    key_caches = []
    for _ in range(num_layers):
        key_cache = torch.empty(size=key_cache_shape,
                                dtype=dtype,
                                device='cuda')
        key_cache.uniform_(-scale, scale)
        key_caches.append(key_cache)

    value_cache_shape = (num_blocks, num_heads, head_size, block_size)
    value_caches = []
    for _ in range(num_layers):
        value_cache = torch.empty(size=value_cache_shape,
                                  dtype=dtype,
                                  device='cuda')
        value_cache.uniform_(-scale, scale)
        value_caches.append(value_cache)
    return key_caches, value_caches


@pytest.fixture()
def kv_cache_factory():
    return create_kv_caches
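
The fixture hands back the factory function itself rather than pre-built tensors, so each test controls the cache shapes, dtype, and seed it needs. A minimal sketch of how a test might consume the fixture; the test name and sizes below are illustrative only and not part of this commit:

import pytest
import torch


@pytest.mark.parametrize("dtype", [torch.half])
@torch.inference_mode()
def test_kv_cache_factory_shapes(kv_cache_factory, dtype):
    # Hypothetical smoke test (requires a CUDA device): build one layer of
    # caches and check the paged-KV layouts produced by the factory.
    key_caches, value_caches = kv_cache_factory(
        num_blocks=16, block_size=8, num_layers=1,
        num_heads=4, head_size=64, dtype=dtype, seed=0)
    x = 16 // torch.tensor([], dtype=dtype).element_size()
    assert key_caches[0].shape == (16, 4, 64 // x, 8, x)
    assert value_caches[0].shape == (16, 4, 64, 8)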

tests/kernels/test_activation.py

@@ -1,20 +1,34 @@
+import pytest
 import torch
 import torch.nn.functional as F
 from transformers.activations import get_activation

 from vllm import activation_ops

+DTYPES = [torch.half, torch.bfloat16, torch.float]
+NUM_TOKENS = [7, 83, 2048]  # Arbitrary values for testing
+D = [512, 4096, 5120, 13824]  # Arbitrary values for testing
+SEEDS = [0]
+

 def ref_silu_and_mul(x: torch.Tensor) -> torch.Tensor:
     x1, x2 = x.chunk(chunks=2, dim=1)
     return F.silu(x1) * x2


+@pytest.mark.parametrize("num_tokens", NUM_TOKENS)
+@pytest.mark.parametrize("d", D)
+@pytest.mark.parametrize("dtype", DTYPES)
+@pytest.mark.parametrize("seed", SEEDS)
 @torch.inference_mode()
-def run_silu_and_mul(
+def test_silu_and_mul(
     num_tokens: int,
     d: int,
     dtype: torch.dtype,
+    seed: int,
 ) -> None:
+    torch.random.manual_seed(seed)
+    torch.cuda.manual_seed(seed)
     x = torch.randn(num_tokens, 2 * d, dtype=dtype, device='cuda')
     out = torch.empty(num_tokens, d, dtype=dtype, device='cuda')
     activation_ops.silu_and_mul(out, x)
@@ -22,20 +36,19 @@ def run_silu_and_mul(
     assert torch.allclose(out, ref_out, atol=1e-5, rtol=1e-5)


-def test_silu_and_mul() -> None:
-    for dtype in [torch.half, torch.bfloat16, torch.float]:
-        for num_tokens in [7, 83, 2048]:
-            for d in [512, 4096, 5120, 13824]:
-                print(f'Testing dtype={dtype}, num_tokens={num_tokens}, d={d}')
-                run_silu_and_mul(num_tokens, d, dtype)
-
-
+@pytest.mark.parametrize("num_tokens", NUM_TOKENS)
+@pytest.mark.parametrize("d", D)
+@pytest.mark.parametrize("dtype", DTYPES)
+@pytest.mark.parametrize("seed", SEEDS)
 @torch.inference_mode()
-def run_gelu_new(
+def test_gelu_new(
     num_tokens: int,
     d: int,
     dtype: torch.dtype,
+    seed: int,
 ) -> None:
+    torch.random.manual_seed(seed)
+    torch.cuda.manual_seed(seed)
     x = torch.randn(num_tokens, d, dtype=dtype, device='cuda')
     out = torch.empty(num_tokens, d, dtype=dtype, device='cuda')
     activation_ops.gelu_new(out, x)
@@ -43,30 +56,20 @@ def run_gelu_new(
     assert torch.allclose(out, ref_out, atol=1e-5, rtol=1e-5)


-def test_gelu_new() -> None:
-    for dtype in [torch.half, torch.bfloat16, torch.float]:
-        for num_tokens in [7, 83, 2048]:
-            for d in [512, 4096, 5120, 13824]:
-                print(f'Testing dtype={dtype}, num_tokens={num_tokens}, d={d}')
-                run_gelu_new(num_tokens, d, dtype)
-
-
-@torch.inference_mode()
-def run_gelu_fast(
+@pytest.mark.parametrize("num_tokens", NUM_TOKENS)
+@pytest.mark.parametrize("d", D)
+@pytest.mark.parametrize("dtype", DTYPES)
+@pytest.mark.parametrize("seed", SEEDS)
+@torch.inference_mode()
+def test_gelu_fast(
     num_tokens: int,
     d: int,
     dtype: torch.dtype,
+    seed: int,
 ) -> None:
+    torch.random.manual_seed(seed)
+    torch.cuda.manual_seed(seed)
     x = torch.randn(num_tokens, d, dtype=dtype, device='cuda')
     out = torch.empty(num_tokens, d, dtype=dtype, device='cuda')
     activation_ops.gelu_fast(out, x)
     ref_out = get_activation("gelu_fast")(x)
     assert torch.allclose(out, ref_out, atol=1e-5, rtol=1e-5)
-
-
-def test_gelu_fast() -> None:
-    for dtype in [torch.half, torch.bfloat16, torch.float]:
-        for num_tokens in [7, 83, 2048]:
-            for d in [512, 4096, 5120, 13824]:
-                print(f'Testing dtype={dtype}, num_tokens={num_tokens}, d={d}')
-                run_gelu_fast(num_tokens, d, dtype)

tests/kernels/test_attention.py

@@ -1,14 +1,24 @@
 import random
-from typing import List, Optional
+from typing import List, Optional, Tuple

+import pytest
 import torch
 from xformers import ops as xops
 from xformers.ops.fmha.attn_bias import BlockDiagonalCausalMask

 from vllm import attention_ops

-MAX_SEQ_LEN = 4096
-TEST_SEED = 0
+MAX_SEQ_LEN = 8192
+NUM_BLOCKS = 128  # Arbitrary values for testing
+
+DTYPES = [torch.half, torch.bfloat16, torch.float]
+NUM_GEN_SEQS = [7]  # Arbitrary values for testing
+NUM_PREFILL_SEQS = [1, 3, 7]  # Arbitrary values for testing
+NUM_HEADS = [(40, 40), (64, 8)]  # Arbitrary values for testing
+HEAD_SIZES = [64, 80, 96, 112, 128, 256]
+BLOCK_SIZES = [8, 16, 32]
+USE_ALIBI = [False]  # TODO(woosuk): Add USE_ALIBI=True
+SEEDS = [0]


 def ref_masked_attention(
@@ -18,29 +28,34 @@ def ref_masked_attention(
     scale: float,
     attn_mask: Optional[torch.Tensor] = None,
 ) -> torch.Tensor:
-    query = query * scale
-    attn = torch.einsum('qhd,khd->hqk', query, key)
+    attn_weights = scale * torch.einsum("qhd,khd->hqk", query, key).float()
     if attn_mask is not None:
-        attn = attn + attn_mask
-    attn = torch.softmax(attn, dim=-1)
-    out = torch.einsum('hqk,khd->qhd', attn, value)
+        attn_weights = attn_weights + attn_mask.float()
+    attn_weights = torch.softmax(attn_weights, dim=-1).to(value.dtype)
+    out = torch.einsum("hqk,khd->qhd", attn_weights, value)
     return out


 def ref_single_query_cached_kv_attention(
     output: torch.Tensor,
     query: torch.Tensor,
+    num_queries_per_kv: int,
     key_cache: torch.Tensor,
     value_cache: torch.Tensor,
     block_tables: torch.Tensor,
     context_lens: torch.Tensor,
+    scale: float,
+    alibi_slopes: Optional[torch.Tensor],
 ) -> None:
-    num_heads = value_cache.shape[1]
+    num_query_heads = query.shape[1]
+    num_kv_heads = value_cache.shape[1]
     head_size = value_cache.shape[2]
     block_size = value_cache.shape[3]
+    num_seqs = query.shape[0]

-    num_input_tokens = query.shape[0]
-    for i in range(num_input_tokens):
+    block_tables = block_tables.cpu().tolist()
+    context_lens = context_lens.cpu().tolist()
+    for i in range(num_seqs):
         q = query[i].unsqueeze(0)
         block_table = block_tables[i]
         context_len = int(context_lens[i])
@@ -52,30 +67,138 @@ def ref_single_query_cached_kv_attention(
             block_offset = j % block_size
             k = key_cache[block_number, :, :, block_offset, :]
-            k = k.reshape(num_heads, head_size)
+            k = k.reshape(num_kv_heads, head_size)
             keys.append(k)
             v = value_cache[block_number, :, :, block_offset]
             values.append(v)
         keys = torch.stack(keys, dim=0)
         values = torch.stack(values, dim=0)
+        if num_queries_per_kv > 1:
+            # Handle MQA and GQA
+            keys = torch.repeat_interleave(keys, num_queries_per_kv, dim=1)
+            values = torch.repeat_interleave(values, num_queries_per_kv, dim=1)

-        scale = 1.0 / (head_size**0.5)
-        out = ref_masked_attention(q, keys, values, scale)
-        out = out.view(num_heads, head_size)
+        alibi_bias = None
+        if alibi_slopes is not None:
+            # Create the ALiBi bias used in the paged attention kernel.
+            position_ids = torch.arange(context_len, device="cuda").int()
+            alibi_bias = (context_len - position_ids).float()
+            alibi_bias = alibi_slopes.view(-1, 1, 1) * alibi_bias.view(
+                1, 1, -1)
+
+        out = ref_masked_attention(q, keys, values, scale, alibi_bias)
+        out = out.view(num_query_heads, head_size)
         output[i].copy_(out, non_blocking=True)


+@pytest.mark.parametrize("num_seqs", NUM_GEN_SEQS)
+@pytest.mark.parametrize("num_heads", NUM_HEADS)
+@pytest.mark.parametrize("head_size", HEAD_SIZES)
+@pytest.mark.parametrize("use_alibi", USE_ALIBI)
+@pytest.mark.parametrize("block_size", BLOCK_SIZES)
+@pytest.mark.parametrize("dtype", DTYPES)
+@pytest.mark.parametrize("seed", SEEDS)
+@torch.inference_mode()
+def test_single_query_cached_kv_attention(
+    kv_cache_factory,
+    num_seqs: int,
+    num_heads: Tuple[int, int],
+    head_size: int,
+    use_alibi: bool,
+    block_size: int,
+    dtype: torch.dtype,
+    seed: int,
+) -> None:
+    random.seed(seed)
+    torch.random.manual_seed(seed)
+    torch.cuda.manual_seed(seed)
+
+    scale = float(1.0 / (head_size**0.5))
+    num_query_heads, num_kv_heads = num_heads
+    query = torch.empty(num_seqs,
+                        num_query_heads,
+                        head_size,
+                        dtype=dtype,
+                        device="cuda")
+    query.uniform_(-scale, scale)
+
+    assert num_query_heads % num_kv_heads == 0
+    num_queries_per_kv = num_query_heads // num_kv_heads
+    head_mapping = torch.repeat_interleave(
+        torch.arange(num_kv_heads, dtype=torch.int32, device="cuda"),
+        num_queries_per_kv)
+    alibi_slopes = None
+    if use_alibi:
+        alibi_slopes = torch.randn(num_query_heads,
+                                   dtype=torch.float,
+                                   device="cuda")
+
+    context_lens = [random.randint(1, MAX_SEQ_LEN) for _ in range(num_seqs)]
+    max_context_len = max(context_lens)
+    context_lens = torch.tensor(context_lens, dtype=torch.int, device="cuda")
+
+    # Create the block tables.
+    max_num_blocks_per_seq = (max_context_len + block_size - 1) // block_size
+    block_tables = []
+    for _ in range(num_seqs):
+        block_table = [
+            random.randint(0, NUM_BLOCKS - 1)
+            for _ in range(max_num_blocks_per_seq)
+        ]
+        block_tables.append(block_table)
+    block_tables = torch.tensor(block_tables, dtype=torch.int, device="cuda")
+
+    # Create the KV caches.
+    key_caches, value_caches = kv_cache_factory(NUM_BLOCKS, block_size, 1,
+                                                num_kv_heads, head_size, dtype,
+                                                seed)
+    key_cache, value_cache = key_caches[0], value_caches[0]
+
+    # Call the paged attention kernel.
+    output = torch.empty_like(query)
+    attention_ops.single_query_cached_kv_attention(
+        output,
+        query,
+        key_cache,
+        value_cache,
+        head_mapping,
+        scale,
+        block_tables,
+        context_lens,
+        block_size,
+        max_context_len,
+        alibi_slopes,
+    )
+
+    # Run the reference implementation.
+    ref_output = torch.empty_like(query)
+    ref_single_query_cached_kv_attention(
+        ref_output,
+        query,
+        num_queries_per_kv,
+        key_cache,
+        value_cache,
+        block_tables,
+        context_lens,
+        scale,
+        alibi_slopes,
+    )
+
+    # NOTE(woosuk): Due to the kernel-level differences in the two
+    # implementations, there is a small numerical difference in the two
+    # outputs. Thus, we use a relaxed tolerance for the test.
+    assert torch.allclose(output, ref_output, atol=1e-3, rtol=1e-5)


 def ref_multi_query_kv_attention(
     cu_seq_lens: List[int],
     query: torch.Tensor,
     key: torch.Tensor,
     value: torch.Tensor,
+    scale: float,
     dtype: torch.dtype,
 ) -> torch.Tensor:
-    head_size = query.shape[-1]
-    scale = 1.0 / (head_size**0.5)
     num_seqs = len(cu_seq_lens) - 1
     ref_outputs = []
     for i in range(num_seqs):
@@ -87,7 +210,7 @@ def ref_multi_query_kv_attention(
         attn_mask = torch.triu(torch.ones(seq_len, seq_len, dtype=dtype),
                                diagonal=1)
         attn_mask = attn_mask * torch.finfo(dtype).min
-        attn_mask = attn_mask.to(dtype=dtype, device='cuda')
+        attn_mask = attn_mask.to(dtype=dtype, device="cuda")

         ref_output = ref_masked_attention(
             query[start_idx:end_idx],
@@ -101,171 +224,42 @@ def ref_multi_query_kv_attention(
     return ref_output


-def ref_multi_query_cached_kv_attention(
-    cu_query_lens: List[int],
-    query: torch.Tensor,
-    key_cache: torch.Tensor,
-    value_cache: torch.Tensor,
-    block_tables: torch.Tensor,
-    context_lens: torch.Tensor,
-    dtype: torch.dtype,
-) -> torch.Tensor:
-    num_heads = value_cache.shape[1]
-    head_size = value_cache.shape[2]
-    block_size = value_cache.shape[3]
-    scale = 1.0 / (head_size**0.5)
-
-    num_queries = len(cu_query_lens) - 1
-    ref_outputs = []
-    for i in range(num_queries):
-        start_idx = cu_query_lens[i]
-        end_idx = cu_query_lens[i + 1]
-        query_len = end_idx - start_idx
-        context_len = int(context_lens[i])
-        block_table = block_tables[i]
-
-        # Create attention mask
-        attn_mask = torch.triu(torch.ones(query_len, context_len),
-                               diagonal=context_len - query_len + 1) * -1e5
-        attn_mask = attn_mask.to(dtype=dtype, device='cuda')
-
-        keys = []
-        values = []
-        for j in range(context_len):
-            block_number = int(block_table[j // block_size])
-            block_offset = j % block_size
-
-            k = key_cache[block_number, :, :, block_offset, :]
-            k = k.reshape(num_heads, head_size)
-            keys.append(k)
-
-            v = value_cache[block_number, :, :, block_offset]
-            values.append(v)
-        keys = torch.stack(keys, dim=0)
-        values = torch.stack(values, dim=0)
-
-        ref_output = ref_masked_attention(
-            query[start_idx:end_idx],
-            keys,
-            values,
-            scale,
-            attn_mask=attn_mask,
-        )
-        ref_outputs.append(ref_output)
-    ref_output = torch.cat(ref_outputs, dim=0)
-    return ref_output
-
-
-@torch.inference_mode()
-def run_single_query_cached_kv_attention(
-    num_tokens: int,
-    num_heads: int,
-    head_size: int,
-    block_size: int,
-    num_blocks: int,
-    dtype: torch.dtype,
-    num_kv_heads: int = None,
-) -> None:
-    qkv = torch.empty(num_tokens,
-                      3,
-                      num_heads,
-                      head_size,
-                      dtype=dtype,
-                      device='cuda')
-    qkv.uniform_(-1e-3, 1e-3)
-    query, _, _ = qkv.unbind(dim=1)
-
-    x = 16 // torch.tensor([], dtype=dtype).element_size()
-    key_block_shape = (num_heads, head_size // x, block_size, x)
-    key_cache = torch.empty(size=(num_blocks, *key_block_shape),
-                            dtype=dtype,
-                            device='cuda')
-    key_cache.uniform_(-1e-3, 1e-3)
-    value_block_shape = (num_heads, head_size, block_size)
-    value_cache = torch.empty(size=(num_blocks, *value_block_shape),
-                              dtype=dtype,
-                              device='cuda')
-    value_cache.uniform_(-1e-3, 1e-3)
-
-    context_lens = [random.randint(1, MAX_SEQ_LEN) for _ in range(num_tokens)]
-    max_context_len = max(context_lens)
-    context_lens = torch.tensor(context_lens, dtype=torch.int, device='cuda')
-
-    max_num_blocks_per_seq = (max_context_len + block_size - 1) // block_size
-    block_tables = []
-    for _ in range(num_tokens):
-        block_table = [
-            random.randint(0, num_blocks - 1)
-            for _ in range(max_num_blocks_per_seq)
-        ]
-        block_tables.append(block_table)
-    block_tables = torch.tensor(block_tables, dtype=torch.int, device='cuda')
-    head_mapping = torch.arange(num_heads, dtype=torch.int32, device="cuda")
-
-    scale = float(1.0 / (head_size**0.5))
-
-    num_kv_heads = num_heads if num_kv_heads is None else num_kv_heads
-    assert num_heads % num_kv_heads == 0
-    num_queries_per_kv = num_heads // num_kv_heads
-    head_mapping = torch.repeat_interleave(
-        torch.arange(num_kv_heads, dtype=torch.int32, device="cuda"),
-        num_queries_per_kv)
-
-    output = torch.empty(num_tokens,
-                         num_heads,
-                         head_size,
-                         dtype=dtype,
-                         device='cuda')
-    attention_ops.single_query_cached_kv_attention(
-        output,
-        query,
-        key_cache,
-        value_cache,
-        head_mapping,
-        scale,
-        block_tables,
-        context_lens,
-        block_size,
-        max_context_len,
-        None,  # ALiBi slopes.
-    )
-
-    ref_output = torch.empty_like(query)
-    ref_single_query_cached_kv_attention(
-        ref_output,
-        query,
-        key_cache,
-        value_cache,
-        block_tables,
-        context_lens,
-    )
-    # NOTE(woosuk): Due to the difference in the data types the two
-    # implementations use for attention softmax logits and accumulation,
-    # there is a small difference in the final outputs.
-    # We should use a relaxed tolerance for the test.
-    assert torch.allclose(output, ref_output, atol=1e-3, rtol=1e-5)
-
-
+@pytest.mark.parametrize("num_seqs", NUM_PREFILL_SEQS)
+@pytest.mark.parametrize("num_heads", NUM_HEADS)
+@pytest.mark.parametrize("head_size", HEAD_SIZES)
+@pytest.mark.parametrize("dtype", DTYPES)
+@pytest.mark.parametrize("seed", SEEDS)
 @torch.inference_mode()
-def run_multi_query_kv_attention(
+def test_multi_query_kv_attention(
     num_seqs: int,
-    num_heads: int,
+    num_heads: Tuple[int, int],
     head_size: int,
     dtype: torch.dtype,
+    seed: int,
 ) -> None:
+    random.seed(seed)
+    torch.random.manual_seed(seed)
+    torch.cuda.manual_seed(seed)
+
     seq_lens = random.sample(range(1, MAX_SEQ_LEN), num_seqs)
     num_tokens = sum(seq_lens)

     scale = float(1.0 / (head_size**0.5))
+    num_query_heads, num_kv_heads = num_heads
     qkv = torch.empty(num_tokens,
-                      3,
-                      num_heads,
+                      num_query_heads + 2 * num_kv_heads,
                       head_size,
                       dtype=dtype,
-                      device='cuda')
-    qkv.uniform_(-1e-3, 1e-3)
-    query, key, value = qkv.unbind(dim=1)
+                      device="cuda")
+    qkv.uniform_(-scale, scale)
+    query, key, value = qkv.split(
+        [num_query_heads, num_kv_heads, num_kv_heads], dim=1)
+
+    num_queries_per_kv = num_query_heads // num_kv_heads
+    if num_queries_per_kv > 1:
+        # Handle MQA and GQA
+        key = torch.repeat_interleave(key, num_queries_per_kv, dim=1)
+        value = torch.repeat_interleave(value, num_queries_per_kv, dim=1)
     attn_bias = BlockDiagonalCausalMask.from_seqlens(seq_lens)
     output = xops.memory_efficient_attention_forward(
         query.unsqueeze(0),
@@ -285,40 +279,7 @@ def run_multi_query_kv_attention(
         query,
         key,
         value,
+        scale,
         dtype,
     )
     assert torch.allclose(output, ref_output, atol=1e-3, rtol=1e-5)
-
-
-def test_single_query_cached_kv_attention() -> None:
-    torch.random.manual_seed(TEST_SEED)
-    torch.cuda.manual_seed(TEST_SEED)
-    for dtype in [torch.half, torch.bfloat16, torch.float]:
-        for block_size in [8, 16, 32]:
-            for head_size in [64, 80, 96, 112, 128, 256]:
-                print(f'Testing single_query_cached_kv_attention with '
-                      f'dtype={dtype}, block_size={block_size}, '
-                      f'head_size={head_size}')
-                run_single_query_cached_kv_attention(
-                    num_tokens=37,
-                    num_heads=3,
-                    head_size=head_size,
-                    block_size=block_size,
-                    num_blocks=1024,
-                    dtype=dtype,
-                )
-
-
-def test_multi_query_kv_attention() -> None:
-    torch.random.manual_seed(TEST_SEED)
-    torch.cuda.manual_seed(TEST_SEED)
-    for dtype in [torch.half, torch.bfloat16, torch.float]:
-        for head_size in [64, 80, 96, 112, 128, 256]:
-            print(f'Testing multi_query_kv_attention with dtype={dtype}, '
-                  f'head_size={head_size}')
-            run_multi_query_kv_attention(
-                num_seqs=5,
-                num_heads=3,
-                head_size=head_size,
-                dtype=dtype,
-            )

tests/kernels/test_cache.py

@@ -1,12 +1,32 @@
 import random

+import pytest
 import torch

 from vllm import cache_ops

+DTYPES = [torch.half, torch.bfloat16, torch.float]
+NUM_TOKENS = [7, 83, 2048]  # Arbitrary values for testing
+NUM_LAYERS = [5]  # Arbitrary values for testing
+NUM_HEADS = [8]  # Arbitrary values for testing
+HEAD_SIZES = [64, 80, 96, 112, 128, 256]
+BLOCK_SIZES = [8, 16, 32]
+NUM_BLOCKS = [1024]  # Arbitrary values for testing
+NUM_MAPPINGS = [32, 256]  # Arbitrary values for testing
+SEEDS = [0]
+

+@pytest.mark.parametrize("num_mappings", NUM_MAPPINGS)
+@pytest.mark.parametrize("num_layers", NUM_LAYERS)
+@pytest.mark.parametrize("num_heads", NUM_HEADS)
+@pytest.mark.parametrize("head_size", HEAD_SIZES)
+@pytest.mark.parametrize("block_size", BLOCK_SIZES)
+@pytest.mark.parametrize("num_blocks", NUM_BLOCKS)
+@pytest.mark.parametrize("dtype", DTYPES)
+@pytest.mark.parametrize("seed", SEEDS)
 @torch.inference_mode()
-def run_copy_blocks(
+def test_copy_blocks(
+    kv_cache_factory,
     num_mappings: int,
     num_layers: int,
     num_heads: int,
@@ -14,48 +34,43 @@ def run_copy_blocks(
     block_size: int,
     num_blocks: int,
     dtype: torch.dtype,
+    seed: int,
 ) -> None:
-    # Generate random block mappings.
+    random.seed(seed)
+    torch.random.manual_seed(seed)
+    torch.cuda.manual_seed(seed)
+
+    # Generate random block mappings where each source block is mapped to two
+    # destination blocks.
+    assert 2 * num_mappings <= num_blocks
     src_blocks = random.sample(range(num_blocks), num_mappings)
     remainig_blocks = list(set(range(num_blocks)) - set(src_blocks))
-    dst_blocks = random.sample(remainig_blocks, num_mappings)
-    block_mapping = {src: [dst] for src, dst in zip(src_blocks, dst_blocks)}
+    dst_blocks = random.sample(remainig_blocks, 2 * num_mappings)
+    block_mapping = {}
+    for i in range(num_mappings):
+        src = src_blocks[i]
+        dst1 = dst_blocks[2 * i]
+        dst2 = dst_blocks[2 * i + 1]
+        block_mapping[src] = [dst1, dst2]

-    # Create the KV cache.
-    x = 16 // torch.tensor([], dtype=dtype).element_size()
-    key_cache_shape = (num_blocks, num_heads, head_size // x, block_size, x)
-    key_caches = []
-    for _ in range(num_layers):
-        key_cache = torch.randn(size=key_cache_shape,
-                                dtype=dtype,
-                                device='cuda')
-        key_caches.append(key_cache)
-    cloned_key_caches = []
-    for key_cache in key_caches:
-        cloned_key_caches.append(key_cache.clone())
-
-    value_cache_shape = (num_blocks, num_heads, head_size, block_size)
-    value_caches = []
-    for _ in range(num_layers):
-        value_cache = torch.randn(size=value_cache_shape,
-                                  dtype=dtype,
-                                  device='cuda')
-        value_caches.append(value_cache)
-    cloned_value_caches = []
-    for value_cache in value_caches:
-        cloned_value_caches.append(value_cache.clone())
+    # Create the KV caches.
+    key_caches, value_caches = kv_cache_factory(num_blocks, block_size,
+                                                num_layers, num_heads,
+                                                head_size, dtype, seed)
+
+    # Clone the KV caches.
+    cloned_key_caches = [key_cache.clone() for key_cache in key_caches]
+    cloned_value_caches = [value_cache.clone() for value_cache in value_caches]

     # Call the copy blocks kernel.
     cache_ops.copy_blocks(key_caches, value_caches, block_mapping)

-    # Reference implementation.
+    # Run the reference implementation.
     for src, dsts in block_mapping.items():
         for dst in dsts:
-            for key_cache, cloned_key_cache in zip(key_caches,
-                                                   cloned_key_caches):
+            for cloned_key_cache in cloned_key_caches:
                 cloned_key_cache[dst] = cloned_key_cache[src]
-            for value_cache, cloned_value_cache in zip(value_caches,
-                                                       cloned_value_caches):
+            for cloned_value_cache in cloned_value_caches:
                 cloned_value_cache[dst] = cloned_value_cache[src]

     # Compare the results.
@@ -66,15 +81,29 @@ def run_copy_blocks(
         assert torch.allclose(value_cache, cloned_value_cache)


+@pytest.mark.parametrize("num_tokens", NUM_TOKENS)
+@pytest.mark.parametrize("num_heads", NUM_HEADS)
+@pytest.mark.parametrize("head_size", HEAD_SIZES)
+@pytest.mark.parametrize("block_size", BLOCK_SIZES)
+@pytest.mark.parametrize("num_blocks", NUM_BLOCKS)
+@pytest.mark.parametrize("dtype", DTYPES)
+@pytest.mark.parametrize("seed", SEEDS)
 @torch.inference_mode()
-def run_reshape_and_cache(
+def test_reshape_and_cache(
+    kv_cache_factory,
     num_tokens: int,
     num_heads: int,
     head_size: int,
     block_size: int,
     num_blocks: int,
     dtype: torch.dtype,
+    seed: int,
 ) -> None:
+    random.seed(seed)
+    torch.random.manual_seed(seed)
+    torch.cuda.manual_seed(seed)
+
+    # Create a random slot mapping.
     num_slots = block_size * num_blocks
     slot_mapping = random.sample(range(num_slots), num_tokens)
     slot_mapping = torch.tensor(slot_mapping, dtype=torch.int, device='cuda')
@@ -87,110 +116,31 @@ def run_reshape_and_cache(
                       device='cuda')
     _, key, value = qkv.unbind(dim=1)

-    x = 16 // torch.tensor([], dtype=dtype).element_size()
-    key_cache_shape = (num_blocks, num_heads, head_size // x, block_size, x)
-    key_cache = torch.randn(size=key_cache_shape, dtype=dtype, device='cuda')
-    cloned_key_cache = key_cache.clone()
-
-    value_cache_shape = (num_blocks, num_heads, head_size, block_size)
-    value_cache = torch.randn(size=value_cache_shape,
-                              dtype=dtype,
-                              device='cuda')
+    # Create the KV caches.
+    key_caches, value_caches = kv_cache_factory(num_blocks, block_size, 1,
+                                                num_heads, head_size, dtype,
+                                                seed)
+    key_cache, value_cache = key_caches[0], value_caches[0]
+
+    # Clone the KV caches.
+    cloned_key_cache = key_cache.clone()
     cloned_value_cache = value_cache.clone()

+    # Call the reshape_and_cache kernel.
     cache_ops.reshape_and_cache(key, value, key_cache, value_cache,
                                 slot_mapping)

+    # Run the reference implementation.
+    reshaped_key = key.reshape(num_tokens, *key_cache[0, :, :, 0, :].shape)
+    block_indicies = torch.div(slot_mapping, block_size, rounding_mode='floor')
+    block_indicies = block_indicies.cpu().tolist()
+    block_offsets = slot_mapping % block_size
+    block_offsets = block_offsets.cpu().tolist()
     for i in range(num_tokens):
-        reshaped_key = key.reshape(num_tokens, num_heads, head_size // x, x)
-        block_idx = torch.div(slot_mapping[i],
-                              block_size,
-                              rounding_mode='floor')
-        block_offset = slot_mapping[i] % block_size
+        block_idx = block_indicies[i]
+        block_offset = block_offsets[i]
         cloned_key_cache[block_idx, :, :, block_offset, :] = reshaped_key[i]
         cloned_value_cache[block_idx, :, :, block_offset] = value[i]

     assert torch.allclose(key_cache, cloned_key_cache)
     assert torch.allclose(value_cache, cloned_value_cache)
-
-
-@torch.inference_mode()
-def run_gather_cached_kv(
-    num_tokens: int,
-    num_heads: int,
-    head_size: int,
-    block_size: int,
-    num_blocks: int,
-    dtype: torch.dtype,
-) -> None:
-    num_slots = block_size * num_blocks
-    slot_mapping = random.sample(range(num_slots), num_tokens)
-    slot_mapping = torch.tensor(slot_mapping, dtype=torch.int, device='cuda')
-
-    qkv = torch.randn(num_tokens,
-                      3,
-                      num_heads,
-                      head_size,
-                      dtype=dtype,
-                      device='cuda')
-    _, key, value = qkv.unbind(dim=1)
-
-    qkv_clone = qkv.clone()
-    _, cloned_key, cloned_value = qkv_clone.unbind(dim=1)
-
-    x = 16 // torch.tensor([], dtype=dtype).element_size()
-    key_cache_shape = (num_blocks, num_heads, head_size // x, block_size, x)
-    key_cache = torch.randn(size=key_cache_shape, dtype=dtype, device='cuda')
-
-    value_cache_shape = (num_blocks, num_heads, head_size, block_size)
-    value_cache = torch.randn(size=value_cache_shape,
-                              dtype=dtype,
-                              device='cuda')
-
-    cache_ops.gather_cached_kv(key, value, key_cache, value_cache,
-                               slot_mapping)
-
-    # Reference implementation.
-    for i in range(num_tokens):
-        reshaped_key = cloned_key.reshape(num_tokens, num_heads,
-                                          head_size // x, x)
-        block_idx = torch.div(slot_mapping[i],
-                              block_size,
-                              rounding_mode='floor')
-        block_offset = slot_mapping[i] % block_size
-        reshaped_key[i] = key_cache[block_idx, :, :, block_offset, :]
-        cloned_value[i] = value_cache[block_idx, :, :, block_offset]
-
-    assert torch.allclose(key, cloned_key)
-    assert torch.allclose(value, cloned_value)
-
-
-def test_copy_blocks() -> None:
-    for dtype in [torch.half, torch.bfloat16, torch.float]:
-        run_copy_blocks(num_mappings=23,
-                        num_layers=7,
-                        num_heads=17,
-                        head_size=16,
-                        block_size=8,
-                        num_blocks=1024,
-                        dtype=dtype)
-
-
-def test_reshape_and_cache() -> None:
-    for dtype in [torch.half, torch.bfloat16, torch.float]:
-        run_reshape_and_cache(num_tokens=3,
-                              num_heads=2,
-                              head_size=16,
-                              block_size=8,
-                              num_blocks=2,
-                              dtype=dtype)
-
-
-def test_gather_cached_kv() -> None:
-    for dtype in [torch.half, torch.bfloat16, torch.float]:
-        run_gather_cached_kv(num_tokens=3,
-                             num_heads=2,
-                             head_size=16,
-                             block_size=8,
-                             num_blocks=2,
-                             dtype=dtype)

tests/kernels/test_layernorm.py

@@ -1,35 +1,50 @@
+import pytest
 import torch
 import torch.nn as nn

 from vllm import layernorm_ops

+DTYPES = [torch.half, torch.bfloat16, torch.float]
+HIDDEN_SIZES = [67, 768, 2048, 5120, 8192]  # Arbitrary values for testing
+NUM_TOKENS = [7, 83, 4096]  # Arbitrary values for testing
+SEEDS = [0]
+

 class RefRMSNorm(nn.Module):

     def __init__(self, hidden_size, eps=1e-6):
         super().__init__()
         weight = torch.empty(hidden_size)
-        weight.uniform_(-1e-3, 1e-3)
+        weight.normal_(mean=1.0, std=0.1)
         self.weight = nn.Parameter(weight)
         self.variance_epsilon = eps

     def forward(self, hidden_states):
-        variance = hidden_states.to(torch.float32).pow(2).mean(-1,
-                                                               keepdim=True)
+        input_dtype = hidden_states.dtype
+        hidden_states = hidden_states.to(torch.float32)
+        variance = hidden_states.pow(2).mean(-1, keepdim=True)
         hidden_states = hidden_states * torch.rsqrt(variance +
                                                     self.variance_epsilon)
-        if self.weight.dtype in [torch.half, torch.float16, torch.bfloat16]:
-            hidden_states = hidden_states.to(self.weight.dtype)
-        return self.weight * hidden_states
+        return self.weight * hidden_states.to(input_dtype)


+@pytest.mark.parametrize("num_tokens", NUM_TOKENS)
+@pytest.mark.parametrize("hidden_size", HIDDEN_SIZES)
+@pytest.mark.parametrize("dtype", DTYPES)
+@pytest.mark.parametrize("seed", SEEDS)
 @torch.inference_mode()
-def run_rms_norm(
+def test_rms_norm(
     num_tokens: int,
     hidden_size: int,
     dtype: torch.dtype,
+    seed: int,
 ) -> None:
-    x = torch.randn(num_tokens, hidden_size, dtype=dtype, device='cuda')
+    torch.random.manual_seed(seed)
+    torch.cuda.manual_seed(seed)
+
+    scale = float(hidden_size**-0.5)
+    x = torch.empty(num_tokens, hidden_size, dtype=dtype, device="cuda")
+    x.uniform_(-scale, scale)
     ref = RefRMSNorm(hidden_size).to(dtype).cuda()
     out = torch.empty_like(x)
@@ -40,17 +55,4 @@ def run_rms_norm(
         ref.variance_epsilon,
     )
     ref_out = ref(x)
-    assert torch.allclose(out, ref_out, atol=1e-3, rtol=1e-5)
-
-
-def test_rms_norm() -> None:
-    for dtype in [torch.half, torch.bfloat16, torch.float]:
-        for num_tokens in [7, 128, 2048]:
-            for hidden_size in [13, 64, 1024, 5120]:
-                print(f'Testing RMS kernel with dtype={dtype}, num_tokens='
-                      f'{num_tokens}, hidden_size={hidden_size}')
-                run_rms_norm(
-                    num_tokens=num_tokens,
-                    hidden_size=hidden_size,
-                    dtype=dtype,
-                )
+    assert torch.allclose(out, ref_out, atol=1e-2, rtol=1e-5)

tests/kernels/test_pos_encoding.py

@@ -1,11 +1,19 @@
-from typing import Tuple
+from typing import Optional, Tuple

+import pytest
 import torch
 import torch.nn as nn
 import torch.nn.functional as F

 from vllm import pos_encoding_ops

+DTYPES = [torch.half, torch.bfloat16, torch.float]
+HEAD_SIZES = [64, 80, 96, 112, 128, 256]
+ROTARY_DIMS = [None, 32]  # None means rotary dim == head size
+NUM_HEADS = [7, 12, 40, 52]  # Arbitrary values for testing
+NUM_TOKENS = [7, 83, 2048]  # Arbitrary values for testing
+SEEDS = [0]


 def rotate_half(x: torch.Tensor) -> torch.Tensor:
     x1 = x[..., :x.shape[-1] // 2]
@@ -74,16 +82,28 @@ class RefRotaryEmbeddingNeox(nn.Module):
         return query, key


+@pytest.mark.parametrize("num_tokens", NUM_TOKENS)
+@pytest.mark.parametrize("num_heads", NUM_HEADS)
+@pytest.mark.parametrize("head_size", HEAD_SIZES)
+@pytest.mark.parametrize("rotary_dim", ROTARY_DIMS)
+@pytest.mark.parametrize("dtype", DTYPES)
+@pytest.mark.parametrize("seed", SEEDS)
 @torch.inference_mode()
-def run_rotary_embedding_neox(
+def test_rotary_embedding_neox(
     num_tokens: int,
     num_heads: int,
     head_size: int,
-    max_position: int,
-    rotary_dim: int,
+    rotary_dim: Optional[int],
     dtype: torch.dtype,
+    seed: int,
+    max_position: int = 8192,
     base: int = 10000,
 ) -> None:
+    if rotary_dim is None:
+        rotary_dim = head_size
+    torch.random.manual_seed(seed)
+    torch.cuda.manual_seed(seed)
+
     positions = torch.randint(0, max_position, (num_tokens, ), device='cuda')
     query = torch.randn(num_tokens,
                         num_heads * head_size,
@@ -97,7 +117,7 @@ def run_rotary_embedding_neox(
     # Create the rotary embedding.
     inv_freq = 1.0 / (base**(torch.arange(0, rotary_dim, 2) / rotary_dim))
     t = torch.arange(max_position).float()
-    freqs = torch.einsum('i,j -> ij', t, inv_freq.float())
+    freqs = torch.einsum("i,j -> ij", t, inv_freq.float())
     cos = freqs.cos()
     sin = freqs.sin()
     cos_sin_cache = torch.cat((cos, sin), dim=-1)
@@ -129,19 +149,5 @@ def run_rotary_embedding_neox(
     ref_key = ref_key.view(num_tokens, num_heads * head_size)

     # Compare the results.
-    assert torch.allclose(out_query, ref_query, atol=1e-3, rtol=1e-5)
-    assert torch.allclose(out_key, ref_key, atol=1e-3, rtol=1e-5)
-
-
-def test_rotary_embedding_neox() -> None:
-    for dtype in [torch.half, torch.bfloat16, torch.float]:
-        for head_size in [32, 64, 80, 96, 128, 160, 192, 256]:
-            print(f'Running tests for head_size={head_size} and dtype={dtype}')
-            run_rotary_embedding_neox(
-                num_tokens=2145,
-                num_heads=5,
-                head_size=head_size,
-                max_position=8192,
-                rotary_dim=head_size,
-                dtype=dtype,
-            )
+    assert torch.allclose(out_query, ref_query, atol=1e-5, rtol=1e-5)
+    assert torch.allclose(out_key, ref_key, atol=1e-5, rtol=1e-5)
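
With the hand-rolled loop drivers removed, individual cases can be selected straight from the pytest command line; for example, assuming pytest is invoked from the repository root:

    # Run all kernel tests across the parametrized grids.
    pytest tests/kernels

    # Run only the block-copy cases from the cache test module.
    pytest tests/kernels/test_cache.py -k test_copy_blocks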