mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-05-15 04:26:41 +08:00
Optimize data movement (#20)
This commit is contained in:
parent
1f01a18d39
commit
897cb2ae28
20
cacheflow/models/activation.py
Normal file
20
cacheflow/models/activation.py
Normal file
@ -0,0 +1,20 @@
|
|||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
|
||||||
|
from cacheflow import activation_ops
|
||||||
|
|
||||||
|
|
||||||
|
class SiluAndMul(nn.Module):
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
x: torch.Tensor, # (num_tokens, 2 * d)
|
||||||
|
) -> torch.Tensor: # (num_tokens, d)
|
||||||
|
num_tokens = x.shape[0]
|
||||||
|
d = x.shape[1] // 2
|
||||||
|
out = torch.empty(num_tokens, d, dtype=x.dtype, device=x.device)
|
||||||
|
activation_ops.silu_and_mul(out, x)
|
||||||
|
return out
|
||||||
@ -1,6 +1,6 @@
|
|||||||
from typing import List, Optional
|
from typing import Optional
|
||||||
|
|
||||||
from flash_attn.flash_attention import FlashAttention
|
from flash_attn.flash_attn_interface import _flash_attn_forward
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
|
|
||||||
@ -16,40 +16,38 @@ class GPTCacheFlowAttention(nn.Module):
|
|||||||
super().__init__()
|
super().__init__()
|
||||||
self.scale = float(scale)
|
self.scale = float(scale)
|
||||||
|
|
||||||
self.flash_attn = FlashAttention(softmax_scale=self.scale)
|
|
||||||
|
|
||||||
def multi_query_kv_attention(
|
def multi_query_kv_attention(
|
||||||
self,
|
self,
|
||||||
output: torch.Tensor, # [num_prompt_tokens, num_heads, head_size]
|
output: torch.Tensor, # [num_prompt_tokens, num_heads, head_size]
|
||||||
query: torch.Tensor, # [num_prompt_tokens, num_heads, head_size]
|
query: torch.Tensor, # [num_prompt_tokens, num_heads, head_size]
|
||||||
key: torch.Tensor, # [num_prompt_tokens, num_heads, head_size]
|
key: torch.Tensor, # [num_prompt_tokens, num_heads, head_size]
|
||||||
value: torch.Tensor, # [num_prompt_tokens, num_heads, head_size]
|
value: torch.Tensor, # [num_prompt_tokens, num_heads, head_size]
|
||||||
prompt_lens: List[int],
|
cumulative_prompt_lens: torch.Tensor, # [num_prompts + 1]
|
||||||
|
max_prompt_len: int,
|
||||||
) -> None:
|
) -> None:
|
||||||
if query.dtype == torch.float:
|
if query.dtype == torch.float:
|
||||||
raise ValueError('The float data type is not supported by '
|
raise ValueError('The float data type is not supported by '
|
||||||
'FlashAttention. Use the half data type instead.')
|
'FlashAttention. Use the half data type instead.')
|
||||||
head_size = query.shape[2]
|
head_size = query.shape[-1]
|
||||||
if head_size > 128:
|
if head_size > 128:
|
||||||
raise ValueError('FlashAttention does not support head_size > 128.')
|
raise ValueError('FlashAttention does not support head_size > 128.')
|
||||||
|
|
||||||
device = query.device
|
# Directly call FlashAttention's internal function to avoid allocating
|
||||||
prefix_sum = [0]
|
# a new tensor for the output.
|
||||||
for prompt_len in prompt_lens:
|
_flash_attn_forward(
|
||||||
prefix_sum.append(prefix_sum[-1] + prompt_len)
|
query,
|
||||||
prefix_sum = torch.tensor(prefix_sum, dtype=torch.int, device=device)
|
key,
|
||||||
max_prompt_len = max(prompt_lens)
|
value,
|
||||||
|
output,
|
||||||
# FIXME(woosuk): Unnecessary copy. Optimize this.
|
cumulative_prompt_lens,
|
||||||
qkv = torch.stack([query, key, value], dim=1)
|
cumulative_prompt_lens,
|
||||||
out = self.flash_attn(
|
max_prompt_len,
|
||||||
qkv,
|
max_prompt_len,
|
||||||
cu_seqlens=prefix_sum,
|
dropout_p=0.0,
|
||||||
max_s=max_prompt_len,
|
softmax_scale=self.scale,
|
||||||
causal=True,
|
causal=True,
|
||||||
)[0]
|
return_softmax=False,
|
||||||
# FIXME(woosuk): Unnecessary copy. Optimize this.
|
)
|
||||||
output.copy_(out, non_blocking=True)
|
|
||||||
|
|
||||||
def single_query_cached_kv_attention(
|
def single_query_cached_kv_attention(
|
||||||
self,
|
self,
|
||||||
@ -90,21 +88,18 @@ class GPTCacheFlowAttention(nn.Module):
|
|||||||
input_metadata: InputMetadata,
|
input_metadata: InputMetadata,
|
||||||
cache_event: Optional[torch.cuda.Event],
|
cache_event: Optional[torch.cuda.Event],
|
||||||
) -> torch.Tensor: # [num_tokens, num_heads * head_size]
|
) -> torch.Tensor: # [num_tokens, num_heads * head_size]
|
||||||
# Pre-allocate the output tensor.
|
# NOTE: The query, key, and value tensors must be sliced from a qkv
|
||||||
output = torch.empty_like(query)
|
# tensor of shape [num_tokens, 3 * num_heads * head_size].
|
||||||
|
|
||||||
# Prune out paddings if any.
|
# Reshape the query, key, and value tensors.
|
||||||
query = query[:input_metadata.num_valid_tokens]
|
|
||||||
key = key[:input_metadata.num_valid_tokens]
|
|
||||||
value = value[:input_metadata.num_valid_tokens]
|
|
||||||
|
|
||||||
# Reshape the input tensors.
|
|
||||||
num_heads = value_cache.shape[1]
|
num_heads = value_cache.shape[1]
|
||||||
head_size = value_cache.shape[2]
|
head_size = value_cache.shape[2]
|
||||||
query = query.view(-1, num_heads, head_size)
|
query = query.view(-1, num_heads, head_size)
|
||||||
key = key.view(-1, num_heads, head_size)
|
key = key.view(-1, num_heads, head_size)
|
||||||
value = value.view(-1, num_heads, head_size)
|
value = value.view(-1, num_heads, head_size)
|
||||||
output = output.view(-1, num_heads, head_size)
|
|
||||||
|
# Pre-allocate the output tensor.
|
||||||
|
output = torch.empty_like(query)
|
||||||
|
|
||||||
# Compute the attention op for prompts.
|
# Compute the attention op for prompts.
|
||||||
num_prompt_tokens = input_metadata.num_prompt_tokens
|
num_prompt_tokens = input_metadata.num_prompt_tokens
|
||||||
@ -114,7 +109,8 @@ class GPTCacheFlowAttention(nn.Module):
|
|||||||
query[:num_prompt_tokens],
|
query[:num_prompt_tokens],
|
||||||
key[:num_prompt_tokens],
|
key[:num_prompt_tokens],
|
||||||
value[:num_prompt_tokens],
|
value[:num_prompt_tokens],
|
||||||
input_metadata.prompt_lens,
|
input_metadata.cumulative_prompt_lens,
|
||||||
|
input_metadata.max_prompt_len,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Wait until the cache op is done.
|
# Wait until the cache op is done.
|
||||||
@ -122,14 +118,22 @@ class GPTCacheFlowAttention(nn.Module):
|
|||||||
cache_event.wait()
|
cache_event.wait()
|
||||||
|
|
||||||
# Reshape the keys and values and store them in the cache.
|
# Reshape the keys and values and store them in the cache.
|
||||||
cache_ops.reshape_and_cache(
|
num_valid_tokens = input_metadata.num_valid_tokens
|
||||||
key, value, key_cache, value_cache, input_metadata.slot_mapping)
|
if num_valid_tokens > 0:
|
||||||
|
# The stride is 3 because the key and value are sliced from qkv.
|
||||||
|
cache_ops.reshape_and_cache(
|
||||||
|
key[:num_valid_tokens],
|
||||||
|
value[:num_valid_tokens],
|
||||||
|
key_cache,
|
||||||
|
value_cache,
|
||||||
|
input_metadata.slot_mapping,
|
||||||
|
)
|
||||||
|
|
||||||
if input_metadata.num_generation_tokens > 0:
|
if input_metadata.num_generation_tokens > 0:
|
||||||
# Compute the attention op for generation tokens.
|
# Compute the attention op for generation tokens.
|
||||||
self.single_query_cached_kv_attention(
|
self.single_query_cached_kv_attention(
|
||||||
output[num_prompt_tokens:],
|
output[num_prompt_tokens:num_valid_tokens],
|
||||||
query[num_prompt_tokens:],
|
query[num_prompt_tokens:num_valid_tokens],
|
||||||
key_cache,
|
key_cache,
|
||||||
value_cache,
|
value_cache,
|
||||||
input_metadata)
|
input_metadata)
|
||||||
@ -186,19 +190,15 @@ class LlamaCacheFlowAttention(GPTCacheFlowAttention):
|
|||||||
) -> torch.Tensor: # [num_tokens, num_heads * head_size]
|
) -> torch.Tensor: # [num_tokens, num_heads * head_size]
|
||||||
# Apply rotary embedding to the query and key before passing them
|
# Apply rotary embedding to the query and key before passing them
|
||||||
# to the attention op.
|
# to the attention op.
|
||||||
out_query = torch.empty_like(query)
|
|
||||||
out_key = torch.empty_like(key)
|
|
||||||
pos_encoding_ops.rotary_embedding_neox(
|
pos_encoding_ops.rotary_embedding_neox(
|
||||||
out_query,
|
|
||||||
out_key,
|
|
||||||
positions,
|
positions,
|
||||||
query,
|
query,
|
||||||
key,
|
key,
|
||||||
self.cos_sin_cache,
|
self.cos_sin_cache,
|
||||||
)
|
)
|
||||||
return super().forward(
|
return super().forward(
|
||||||
out_query,
|
query,
|
||||||
out_key,
|
key,
|
||||||
value,
|
value,
|
||||||
key_cache,
|
key_cache,
|
||||||
value_cache,
|
value_cache,
|
||||||
|
|||||||
@ -12,6 +12,7 @@ class InputMetadata:
|
|||||||
seq_groups: List[Tuple[List[int], SamplingParams]],
|
seq_groups: List[Tuple[List[int], SamplingParams]],
|
||||||
seq_logprobs: Dict[int, float], # Seq id -> cumulative logprobs.
|
seq_logprobs: Dict[int, float], # Seq id -> cumulative logprobs.
|
||||||
prompt_lens: List[int],
|
prompt_lens: List[int],
|
||||||
|
cumulative_prompt_lens: torch.Tensor,
|
||||||
slot_mapping: torch.Tensor,
|
slot_mapping: torch.Tensor,
|
||||||
context_lens: torch.Tensor,
|
context_lens: torch.Tensor,
|
||||||
max_context_len: int,
|
max_context_len: int,
|
||||||
@ -20,6 +21,7 @@ class InputMetadata:
|
|||||||
self.seq_groups = seq_groups
|
self.seq_groups = seq_groups
|
||||||
self.seq_logprobs = seq_logprobs
|
self.seq_logprobs = seq_logprobs
|
||||||
self.prompt_lens = prompt_lens
|
self.prompt_lens = prompt_lens
|
||||||
|
self.cumulative_prompt_lens = cumulative_prompt_lens
|
||||||
self.slot_mapping = slot_mapping
|
self.slot_mapping = slot_mapping
|
||||||
self.context_lens = context_lens
|
self.context_lens = context_lens
|
||||||
self.max_context_len = max_context_len
|
self.max_context_len = max_context_len
|
||||||
@ -27,6 +29,7 @@ class InputMetadata:
|
|||||||
|
|
||||||
self.num_prompts = len(prompt_lens)
|
self.num_prompts = len(prompt_lens)
|
||||||
self.num_prompt_tokens = sum(prompt_lens)
|
self.num_prompt_tokens = sum(prompt_lens)
|
||||||
|
self.max_prompt_len = max(prompt_lens) if prompt_lens else 0
|
||||||
self.num_generation_tokens = context_lens.shape[0]
|
self.num_generation_tokens = context_lens.shape[0]
|
||||||
self.num_valid_tokens = slot_mapping.shape[0]
|
self.num_valid_tokens = slot_mapping.shape[0]
|
||||||
if block_tables.numel() > 0:
|
if block_tables.numel() > 0:
|
||||||
@ -40,11 +43,13 @@ class InputMetadata:
|
|||||||
return (f'InputMetadata('
|
return (f'InputMetadata('
|
||||||
f'num_prompts={self.num_prompts}, '
|
f'num_prompts={self.num_prompts}, '
|
||||||
f'num_prompt_tokens={self.num_prompt_tokens}, '
|
f'num_prompt_tokens={self.num_prompt_tokens}, '
|
||||||
|
f'max_prompt_len={self.max_prompt_len}, '
|
||||||
f'num_generation_tokens={self.num_generation_tokens}, '
|
f'num_generation_tokens={self.num_generation_tokens}, '
|
||||||
f'num_valid_tokens={self.num_valid_tokens}, '
|
f'num_valid_tokens={self.num_valid_tokens}, '
|
||||||
f'max_num_blocks_per_seq={self.max_num_blocks_per_seq}, '
|
f'max_num_blocks_per_seq={self.max_num_blocks_per_seq}, '
|
||||||
f'max_context_len={self.max_context_len}), '
|
f'max_context_len={self.max_context_len}), '
|
||||||
f'prompt_lens={self.prompt_lens}, '
|
f'prompt_lens={self.prompt_lens}, '
|
||||||
|
f'cumulative_prompt_lens={self.cumulative_prompt_lens}, '
|
||||||
f'slot_mapping={self.slot_mapping}, '
|
f'slot_mapping={self.slot_mapping}, '
|
||||||
f'context_lens={self.context_lens}, '
|
f'context_lens={self.context_lens}, '
|
||||||
f'block_tables={self.block_tables})')
|
f'block_tables={self.block_tables})')
|
||||||
|
|||||||
@ -11,6 +11,7 @@ from torch import nn
|
|||||||
from transformers import LlamaConfig
|
from transformers import LlamaConfig
|
||||||
|
|
||||||
from cacheflow.models import InputMetadata
|
from cacheflow.models import InputMetadata
|
||||||
|
from cacheflow.models.activation import SiluAndMul
|
||||||
from cacheflow.models.attention import LlamaCacheFlowAttention
|
from cacheflow.models.attention import LlamaCacheFlowAttention
|
||||||
from cacheflow.models.layernorm import RMSNorm
|
from cacheflow.models.layernorm import RMSNorm
|
||||||
from cacheflow.models.sample import Sampler
|
from cacheflow.models.sample import Sampler
|
||||||
@ -39,16 +40,14 @@ class LlamaMLP(nn.Module):
|
|||||||
self.down_proj = RowParallelLinear(intermediate_size, hidden_size,
|
self.down_proj = RowParallelLinear(intermediate_size, hidden_size,
|
||||||
bias=False, input_is_parallel=True,
|
bias=False, input_is_parallel=True,
|
||||||
perform_initialization=False)
|
perform_initialization=False)
|
||||||
assert hidden_act == 'silu'
|
if hidden_act != 'silu':
|
||||||
self.act_fn = nn.SiLU()
|
raise ValueError(f'Unsupported activation: {hidden_act}. '
|
||||||
|
'Only silu is supported for now.')
|
||||||
|
self.act_fn = SiluAndMul()
|
||||||
|
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
gate_up, _ = self.gate_up_proj(x)
|
gate_up, _ = self.gate_up_proj(x)
|
||||||
gate_up = gate_up.reshape(gate_up.shape[:-1] + (2, -1))
|
x = self.act_fn(gate_up)
|
||||||
gate, up = torch.split(gate_up, 1, dim=-2)
|
|
||||||
gate = gate.squeeze(dim=-2).contiguous()
|
|
||||||
up = up.squeeze(dim=-2).contiguous()
|
|
||||||
x = self.act_fn(gate) * up
|
|
||||||
x, _ = self.down_proj(x)
|
x, _ = self.down_proj(x)
|
||||||
return x
|
return x
|
||||||
|
|
||||||
@ -94,11 +93,7 @@ class LlamaAttention(nn.Module):
|
|||||||
cache_event: Optional[torch.cuda.Event],
|
cache_event: Optional[torch.cuda.Event],
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
qkv, _ = self.qkv_proj(hidden_states)
|
qkv, _ = self.qkv_proj(hidden_states)
|
||||||
qkv = qkv.reshape(qkv.shape[:-1] + (3, -1))
|
q, k, v = qkv.chunk(chunks=3, dim=-1)
|
||||||
q, k, v = torch.split(qkv, 1, dim=-2)
|
|
||||||
q = q.squeeze(dim=-2).contiguous()
|
|
||||||
k = k.squeeze(dim=-2).contiguous()
|
|
||||||
v = v.squeeze(dim=-2).contiguous()
|
|
||||||
k_cache, v_cache = kv_cache
|
k_cache, v_cache = kv_cache
|
||||||
attn_output = self.attn(
|
attn_output = self.attn(
|
||||||
positions, q, k, v, k_cache, v_cache, input_metadata, cache_event)
|
positions, q, k, v, k_cache, v_cache, input_metadata, cache_event)
|
||||||
|
|||||||
@ -69,17 +69,14 @@ class OPTAttention(nn.Module):
|
|||||||
cache_event: Optional[torch.cuda.Event],
|
cache_event: Optional[torch.cuda.Event],
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
qkv, _ = self.qkv_proj(hidden_states)
|
qkv, _ = self.qkv_proj(hidden_states)
|
||||||
qkv = qkv.reshape(qkv.shape[:-1] + (3, -1))
|
q, k, v = qkv.chunk(chunks=3, dim=-1)
|
||||||
q, k, v = torch.split(qkv, 1, dim=-2)
|
|
||||||
q = q.squeeze(dim=-2).contiguous()
|
|
||||||
k = k.squeeze(dim=-2).contiguous()
|
|
||||||
v = v.squeeze(dim=-2).contiguous()
|
|
||||||
key_cache, value_cache = kv_cache
|
key_cache, value_cache = kv_cache
|
||||||
attn_output = self.attn(
|
attn_output = self.attn(
|
||||||
q, k, v, key_cache, value_cache, input_metadata, cache_event)
|
q, k, v, key_cache, value_cache, input_metadata, cache_event)
|
||||||
output, _ = self.out_proj(attn_output)
|
output, _ = self.out_proj(attn_output)
|
||||||
return output
|
return output
|
||||||
|
|
||||||
|
|
||||||
class OPTDecoderLayer(nn.Module):
|
class OPTDecoderLayer(nn.Module):
|
||||||
|
|
||||||
def __init__(self, config: OPTConfig):
|
def __init__(self, config: OPTConfig):
|
||||||
|
|||||||
@ -128,6 +128,11 @@ class Worker:
|
|||||||
slot = block_number * self.block_size + block_offset
|
slot = block_number * self.block_size + block_offset
|
||||||
slot_mapping.append(slot)
|
slot_mapping.append(slot)
|
||||||
|
|
||||||
|
cumulative_prompt_lens: List[int] = [0]
|
||||||
|
for prompt_len in prompt_lens:
|
||||||
|
cumulative_prompt_lens.append(
|
||||||
|
cumulative_prompt_lens[-1] + prompt_len)
|
||||||
|
|
||||||
# Add generation tokens.
|
# Add generation tokens.
|
||||||
max_context_len = 0
|
max_context_len = 0
|
||||||
max_num_blocks_per_seq = 0
|
max_num_blocks_per_seq = 0
|
||||||
@ -183,11 +188,14 @@ class Worker:
|
|||||||
for block_table in generation_block_tables]
|
for block_table in generation_block_tables]
|
||||||
block_tables_tensor = torch.tensor(
|
block_tables_tensor = torch.tensor(
|
||||||
padded_block_tables, dtype=torch.int, device='cuda')
|
padded_block_tables, dtype=torch.int, device='cuda')
|
||||||
|
cumulative_prompt_lens_tensor = torch.tensor(
|
||||||
|
cumulative_prompt_lens, dtype=torch.int, device='cuda')
|
||||||
|
|
||||||
input_metadata = InputMetadata(
|
input_metadata = InputMetadata(
|
||||||
seq_groups=seq_groups,
|
seq_groups=seq_groups,
|
||||||
seq_logprobs=seq_logprobs,
|
seq_logprobs=seq_logprobs,
|
||||||
prompt_lens=prompt_lens,
|
prompt_lens=prompt_lens,
|
||||||
|
cumulative_prompt_lens=cumulative_prompt_lens_tensor,
|
||||||
slot_mapping=slot_mapping_tensor,
|
slot_mapping=slot_mapping_tensor,
|
||||||
context_lens=context_lens_tensor,
|
context_lens=context_lens_tensor,
|
||||||
max_context_len=max_context_len,
|
max_context_len=max_context_len,
|
||||||
|
|||||||
12
csrc/activation.cpp
Normal file
12
csrc/activation.cpp
Normal file
@ -0,0 +1,12 @@
|
|||||||
|
#include <torch/extension.h>
|
||||||
|
|
||||||
|
void silu_and_mul(
|
||||||
|
torch::Tensor& out,
|
||||||
|
torch::Tensor& input);
|
||||||
|
|
||||||
|
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
|
||||||
|
m.def(
|
||||||
|
"silu_and_mul",
|
||||||
|
&silu_and_mul,
|
||||||
|
"Activation function used in SwiGLU.");
|
||||||
|
}
|
||||||
46
csrc/activation_kernels.cu
Normal file
46
csrc/activation_kernels.cu
Normal file
@ -0,0 +1,46 @@
|
|||||||
|
#include <torch/extension.h>
|
||||||
|
#include <ATen/cuda/CUDAContext.h>
|
||||||
|
|
||||||
|
namespace cacheflow {
|
||||||
|
|
||||||
|
template<typename T>
|
||||||
|
__device__ __forceinline__ T silu(const T& x) {
|
||||||
|
// x * sigmoid(x)
|
||||||
|
return (T) (((float) x) / (1.0f + expf((float) -x)));
|
||||||
|
}
|
||||||
|
|
||||||
|
template<typename scalar_t>
|
||||||
|
__global__ void silu_and_mul_kernel(
|
||||||
|
scalar_t* __restrict__ out, // [num_tokens, d]
|
||||||
|
const scalar_t* __restrict__ input, // [num_tokens, 2, d]
|
||||||
|
const int d) {
|
||||||
|
const int token_idx = blockIdx.x;
|
||||||
|
for (int idx = threadIdx.x; idx < d; idx += blockDim.x) {
|
||||||
|
const scalar_t x = __ldg(&input[token_idx * 2 * d + idx]);
|
||||||
|
const scalar_t y = __ldg(&input[token_idx * 2 * d + d + idx]);
|
||||||
|
out[token_idx * d + idx] = silu(x) * y;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace cacheflow
|
||||||
|
|
||||||
|
void silu_and_mul(
|
||||||
|
torch::Tensor& out, // [num_tokens, d]
|
||||||
|
torch::Tensor& input) // [num_tokens, 2 * d]
|
||||||
|
{
|
||||||
|
int num_tokens = input.size(0);
|
||||||
|
int d = input.size(1) / 2;
|
||||||
|
|
||||||
|
dim3 grid(num_tokens);
|
||||||
|
dim3 block(std::min(d, 1024));
|
||||||
|
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
||||||
|
AT_DISPATCH_FLOATING_TYPES_AND_HALF(
|
||||||
|
input.scalar_type(),
|
||||||
|
"silu_and_mul_kernel",
|
||||||
|
[&] {
|
||||||
|
cacheflow::silu_and_mul_kernel<scalar_t><<<grid, block, 0, stream>>>(
|
||||||
|
out.data_ptr<scalar_t>(),
|
||||||
|
input.data_ptr<scalar_t>(),
|
||||||
|
d);
|
||||||
|
});
|
||||||
|
}
|
||||||
@ -25,7 +25,8 @@ __global__ void single_query_cached_kv_attention_kernel(
|
|||||||
const float scale,
|
const float scale,
|
||||||
const int* __restrict__ block_tables, // [num_seqs, max_num_blocks_per_seq]
|
const int* __restrict__ block_tables, // [num_seqs, max_num_blocks_per_seq]
|
||||||
const int* __restrict__ context_lens, // [num_seqs]
|
const int* __restrict__ context_lens, // [num_seqs]
|
||||||
const int max_num_blocks_per_seq) {
|
const int max_num_blocks_per_seq,
|
||||||
|
const int q_stride) {
|
||||||
constexpr int THREAD_GROUP_SIZE = WARP_SIZE / BLOCK_SIZE;
|
constexpr int THREAD_GROUP_SIZE = WARP_SIZE / BLOCK_SIZE;
|
||||||
constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE;
|
constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE;
|
||||||
const int thread_idx = threadIdx.x;
|
const int thread_idx = threadIdx.x;
|
||||||
@ -56,7 +57,8 @@ __global__ void single_query_cached_kv_attention_kernel(
|
|||||||
// For example, if the the thread group size is 4, then the first thread in the group
|
// For example, if the the thread group size is 4, then the first thread in the group
|
||||||
// has 0, 4, 8, ... th vectors of the query, and the second thread has 1, 5, 9, ...
|
// has 0, 4, 8, ... th vectors of the query, and the second thread has 1, 5, 9, ...
|
||||||
// th vectors of the query, and so on.
|
// th vectors of the query, and so on.
|
||||||
const scalar_t* q_ptr = q + seq_idx * num_heads * HEAD_SIZE + head_idx * HEAD_SIZE;
|
// NOTE(woosuk): Because q is split from a qkv tensor, it may not be contiguous.
|
||||||
|
const scalar_t* q_ptr = q + seq_idx * q_stride + head_idx * HEAD_SIZE;
|
||||||
Q_vec q_vecs[NUM_VECS_PER_THREAD];
|
Q_vec q_vecs[NUM_VECS_PER_THREAD];
|
||||||
#pragma unroll
|
#pragma unroll
|
||||||
for (int i = 0; i < NUM_VECS_PER_THREAD; i++) {
|
for (int i = 0; i < NUM_VECS_PER_THREAD; i++) {
|
||||||
@ -264,7 +266,8 @@ __global__ void single_query_cached_kv_attention_kernel(
|
|||||||
scale, \
|
scale, \
|
||||||
block_tables_ptr, \
|
block_tables_ptr, \
|
||||||
context_lens_ptr, \
|
context_lens_ptr, \
|
||||||
max_num_blocks_per_seq);
|
max_num_blocks_per_seq, \
|
||||||
|
query_stride);
|
||||||
|
|
||||||
// TODO(woosuk): Tune NUM_THREADS.
|
// TODO(woosuk): Tune NUM_THREADS.
|
||||||
template<
|
template<
|
||||||
@ -284,6 +287,7 @@ void single_query_cached_kv_attention_launcher(
|
|||||||
int num_heads = query.size(1);
|
int num_heads = query.size(1);
|
||||||
int head_size = query.size(2);
|
int head_size = query.size(2);
|
||||||
int max_num_blocks_per_seq = block_tables.size(1);
|
int max_num_blocks_per_seq = block_tables.size(1);
|
||||||
|
int query_stride = query.stride(0);
|
||||||
|
|
||||||
T* out_ptr = reinterpret_cast<T*>(out.data_ptr());
|
T* out_ptr = reinterpret_cast<T*>(out.data_ptr());
|
||||||
T* query_ptr = reinterpret_cast<T*>(query.data_ptr());
|
T* query_ptr = reinterpret_cast<T*>(query.data_ptr());
|
||||||
@ -333,13 +337,13 @@ void single_query_cached_kv_attention_launcher(
|
|||||||
}
|
}
|
||||||
|
|
||||||
void single_query_cached_kv_attention(
|
void single_query_cached_kv_attention(
|
||||||
torch::Tensor& out,
|
torch::Tensor& out, // [num_seqs, num_heads, head_size]
|
||||||
torch::Tensor& query,
|
torch::Tensor& query, // [num_seqs, num_heads, head_size]
|
||||||
torch::Tensor& key_cache,
|
torch::Tensor& key_cache, // [num_blocks, num_heads, head_size/x, block_size, x]
|
||||||
torch::Tensor& value_cache,
|
torch::Tensor& value_cache, // [num_blocks, num_heads, head_size, block_size]
|
||||||
float scale,
|
float scale,
|
||||||
torch::Tensor& block_tables,
|
torch::Tensor& block_tables, // [num_seqs, max_num_blocks_per_seq]
|
||||||
torch::Tensor& context_lens,
|
torch::Tensor& context_lens, // [num_seqs]
|
||||||
int block_size,
|
int block_size,
|
||||||
int max_context_len) {
|
int max_context_len) {
|
||||||
// TODO(woosuk): Support BF16.
|
// TODO(woosuk): Support BF16.
|
||||||
|
|||||||
@ -81,6 +81,8 @@ __global__ void reshape_and_cache_kernel(
|
|||||||
scalar_t* __restrict__ key_cache, // [num_blocks, num_heads, head_size/x, block_size, x]
|
scalar_t* __restrict__ key_cache, // [num_blocks, num_heads, head_size/x, block_size, x]
|
||||||
scalar_t* __restrict__ value_cache, // [num_blocks, num_heads, head_size, block_size]
|
scalar_t* __restrict__ value_cache, // [num_blocks, num_heads, head_size, block_size]
|
||||||
const int* __restrict__ slot_mapping, // [num_tokens]
|
const int* __restrict__ slot_mapping, // [num_tokens]
|
||||||
|
const int key_stride,
|
||||||
|
const int value_stride,
|
||||||
const int num_heads,
|
const int num_heads,
|
||||||
const int head_size,
|
const int head_size,
|
||||||
const int block_size,
|
const int block_size,
|
||||||
@ -92,7 +94,8 @@ __global__ void reshape_and_cache_kernel(
|
|||||||
|
|
||||||
const int n = num_heads * head_size;
|
const int n = num_heads * head_size;
|
||||||
for (int i = threadIdx.x; i < n; i += blockDim.x) {
|
for (int i = threadIdx.x; i < n; i += blockDim.x) {
|
||||||
const int src_idx = token_idx * n + i;
|
const int src_key_idx = token_idx * key_stride + i;
|
||||||
|
const int src_value_idx = token_idx * value_stride + i;
|
||||||
|
|
||||||
const int head_idx = i / head_size;
|
const int head_idx = i / head_size;
|
||||||
const int head_offset = i % head_size;
|
const int head_offset = i % head_size;
|
||||||
@ -108,25 +111,29 @@ __global__ void reshape_and_cache_kernel(
|
|||||||
+ head_idx * head_size * block_size
|
+ head_idx * head_size * block_size
|
||||||
+ head_offset * block_size
|
+ head_offset * block_size
|
||||||
+ block_offset;
|
+ block_offset;
|
||||||
key_cache[tgt_key_idx] = __ldg(&key[src_idx]);
|
key_cache[tgt_key_idx] = __ldg(&key[src_key_idx]);
|
||||||
value_cache[tgt_value_idx] = __ldg(&value[src_idx]);
|
value_cache[tgt_value_idx] = __ldg(&value[src_value_idx]);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace cacheflow
|
} // namespace cacheflow
|
||||||
|
|
||||||
void reshape_and_cache(
|
void reshape_and_cache(
|
||||||
torch::Tensor& key,
|
torch::Tensor& key, // [num_tokens, num_heads, head_size]
|
||||||
torch::Tensor& value,
|
torch::Tensor& value, // [num_tokens, num_heads, head_size]
|
||||||
torch::Tensor& key_cache,
|
torch::Tensor& key_cache, // [num_blocks, num_heads, head_size/x, block_size, x]
|
||||||
torch::Tensor& value_cache,
|
torch::Tensor& value_cache, // [num_blocks, num_heads, head_size, block_size]
|
||||||
torch::Tensor& slot_mapping) {
|
torch::Tensor& slot_mapping) // [num_tokens]
|
||||||
|
{
|
||||||
int num_tokens = key.size(0);
|
int num_tokens = key.size(0);
|
||||||
int num_heads = key.size(1);
|
int num_heads = key.size(1);
|
||||||
int head_size = key.size(2);
|
int head_size = key.size(2);
|
||||||
int block_size = key_cache.size(3);
|
int block_size = key_cache.size(3);
|
||||||
int x = key_cache.size(4);
|
int x = key_cache.size(4);
|
||||||
|
|
||||||
|
int key_stride = key.stride(0);
|
||||||
|
int value_stride = value.stride(0);
|
||||||
|
|
||||||
dim3 grid(num_tokens);
|
dim3 grid(num_tokens);
|
||||||
dim3 block(std::min(num_heads * head_size, 512));
|
dim3 block(std::min(num_heads * head_size, 512));
|
||||||
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
||||||
@ -140,6 +147,8 @@ void reshape_and_cache(
|
|||||||
key_cache.data_ptr<scalar_t>(),
|
key_cache.data_ptr<scalar_t>(),
|
||||||
value_cache.data_ptr<scalar_t>(),
|
value_cache.data_ptr<scalar_t>(),
|
||||||
slot_mapping.data_ptr<int>(),
|
slot_mapping.data_ptr<int>(),
|
||||||
|
key_stride,
|
||||||
|
value_stride,
|
||||||
num_heads,
|
num_heads,
|
||||||
head_size,
|
head_size,
|
||||||
block_size,
|
block_size,
|
||||||
|
|||||||
@ -1,8 +1,6 @@
|
|||||||
#include <torch/extension.h>
|
#include <torch/extension.h>
|
||||||
|
|
||||||
void rotary_embedding_neox(
|
void rotary_embedding_neox(
|
||||||
torch::Tensor& out_query,
|
|
||||||
torch::Tensor& out_key,
|
|
||||||
torch::Tensor& positions,
|
torch::Tensor& positions,
|
||||||
torch::Tensor& query,
|
torch::Tensor& query,
|
||||||
torch::Tensor& key,
|
torch::Tensor& key,
|
||||||
|
|||||||
@ -5,12 +5,11 @@ namespace cacheflow {
|
|||||||
|
|
||||||
template<typename scalar_t>
|
template<typename scalar_t>
|
||||||
__global__ void rotary_embedding_neox_kernel(
|
__global__ void rotary_embedding_neox_kernel(
|
||||||
scalar_t* __restrict__ out_query, // [num_tokens, num_heads, head_size]
|
|
||||||
scalar_t* __restrict__ out_key, // [num_tokens, num_heads, head_size]
|
|
||||||
const int64_t* __restrict__ positions, // [num_tokens]
|
const int64_t* __restrict__ positions, // [num_tokens]
|
||||||
const scalar_t* __restrict__ query, // [num_tokens, num_heads, head_size]
|
scalar_t* __restrict__ query, // [num_tokens, num_heads, head_size]
|
||||||
const scalar_t* __restrict__ key, // [num_tokens, num_heads, head_size]
|
scalar_t* __restrict__ key, // [num_tokens, num_heads, head_size]
|
||||||
const scalar_t* __restrict__ cos_sin_cache, // [max_position, 2, head_size // 2]
|
const scalar_t* __restrict__ cos_sin_cache, // [max_position, 2, head_size // 2]
|
||||||
|
const int stride,
|
||||||
const int num_heads,
|
const int num_heads,
|
||||||
const int head_size) {
|
const int head_size) {
|
||||||
// Each thread block is responsible for one token.
|
// Each thread block is responsible for one token.
|
||||||
@ -19,41 +18,36 @@ __global__ void rotary_embedding_neox_kernel(
|
|||||||
const scalar_t* cache_ptr = cos_sin_cache + pos * head_size;
|
const scalar_t* cache_ptr = cos_sin_cache + pos * head_size;
|
||||||
|
|
||||||
const int embed_dim = head_size / 2;
|
const int embed_dim = head_size / 2;
|
||||||
const int n = num_heads * head_size;
|
const int n = num_heads * embed_dim;
|
||||||
for (int i = threadIdx.x; i < n; i += blockDim.x) {
|
for (int i = threadIdx.x; i < n; i += blockDim.x) {
|
||||||
const int idx = token_idx * n + i;
|
const int head_idx = i / embed_dim;
|
||||||
|
const int token_head = token_idx * stride + head_idx * head_size;
|
||||||
|
|
||||||
const int head_idx = i / head_size;
|
const int rot_offset = i % embed_dim;
|
||||||
const int head_offset = i % head_size;
|
|
||||||
const int token_head = token_idx * n + head_idx * head_size;
|
|
||||||
|
|
||||||
const bool is_first_half = head_offset < embed_dim;
|
|
||||||
const int rot_offset = head_offset % embed_dim;
|
|
||||||
const int x_index = rot_offset;
|
const int x_index = rot_offset;
|
||||||
const int y_index = embed_dim + rot_offset;
|
const int y_index = embed_dim + rot_offset;
|
||||||
|
|
||||||
|
const int out_x = token_idx * stride + head_idx * head_size + x_index;
|
||||||
|
const int out_y = token_idx * stride + head_idx * head_size + y_index;
|
||||||
|
|
||||||
const scalar_t cos = __ldg(cache_ptr + x_index);
|
const scalar_t cos = __ldg(cache_ptr + x_index);
|
||||||
const scalar_t sin = __ldg(cache_ptr + y_index);
|
const scalar_t sin = __ldg(cache_ptr + y_index);
|
||||||
|
|
||||||
const scalar_t q_x = __ldg(query + token_head + x_index);
|
const scalar_t q_x = query[token_head + x_index];
|
||||||
const scalar_t q_y = __ldg(query + token_head + y_index);
|
const scalar_t q_y = query[token_head + y_index];
|
||||||
const scalar_t q_cos = is_first_half ? q_x : q_y;
|
query[out_x] = q_x * cos - q_y * sin;
|
||||||
const scalar_t q_sin = is_first_half ? -q_y : q_x;
|
query[out_y] = q_y * cos + q_x * sin;
|
||||||
out_query[idx] = q_cos * cos + q_sin * sin;
|
|
||||||
|
|
||||||
const scalar_t k_x = __ldg(key + token_head + x_index);
|
const scalar_t k_x = key[token_head + x_index];
|
||||||
const scalar_t k_y = __ldg(key + token_head + y_index);
|
const scalar_t k_y = key[token_head + y_index];
|
||||||
const scalar_t k_cos = is_first_half ? k_x : k_y;
|
key[out_x] = k_x * cos - k_y * sin;
|
||||||
const scalar_t k_sin = is_first_half ? -k_y : k_x;
|
key[out_y] = k_y * cos + k_x * sin;
|
||||||
out_key[idx] = k_cos * cos + k_sin * sin;
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace cacheflow
|
} // namespace cacheflow
|
||||||
|
|
||||||
void rotary_embedding_neox(
|
void rotary_embedding_neox(
|
||||||
torch::Tensor& out_query, // [num_tokens, num_heads * head_size]
|
|
||||||
torch::Tensor& out_key, // [num_tokens, num_heads * head_size]
|
|
||||||
torch::Tensor& positions, // [num_tokens]
|
torch::Tensor& positions, // [num_tokens]
|
||||||
torch::Tensor& query, // [num_tokens, num_heads * head_size]
|
torch::Tensor& query, // [num_tokens, num_heads * head_size]
|
||||||
torch::Tensor& key, // [num_tokens, num_heads * head_size]
|
torch::Tensor& key, // [num_tokens, num_heads * head_size]
|
||||||
@ -62,21 +56,22 @@ void rotary_embedding_neox(
|
|||||||
int num_tokens = query.size(0);
|
int num_tokens = query.size(0);
|
||||||
int head_size = cos_sin_cache.size(1);
|
int head_size = cos_sin_cache.size(1);
|
||||||
int num_heads = query.size(1) / head_size;
|
int num_heads = query.size(1) / head_size;
|
||||||
|
int stride = query.stride(0);
|
||||||
|
TORCH_CHECK(stride == key.stride(0));
|
||||||
|
|
||||||
dim3 grid(num_tokens);
|
dim3 grid(num_tokens);
|
||||||
dim3 block(std::min(num_heads * head_size, 512));
|
dim3 block(std::min(num_heads * head_size / 2, 512));
|
||||||
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
||||||
AT_DISPATCH_FLOATING_TYPES_AND_HALF(
|
AT_DISPATCH_FLOATING_TYPES_AND_HALF(
|
||||||
query.scalar_type(),
|
query.scalar_type(),
|
||||||
"rotary_embedding_neox",
|
"rotary_embedding_neox",
|
||||||
[&] {
|
[&] {
|
||||||
cacheflow::rotary_embedding_neox_kernel<scalar_t><<<grid, block, 0, stream>>>(
|
cacheflow::rotary_embedding_neox_kernel<scalar_t><<<grid, block, 0, stream>>>(
|
||||||
out_query.data_ptr<scalar_t>(),
|
|
||||||
out_key.data_ptr<scalar_t>(),
|
|
||||||
positions.data_ptr<int64_t>(),
|
positions.data_ptr<int64_t>(),
|
||||||
query.data_ptr<scalar_t>(),
|
query.data_ptr<scalar_t>(),
|
||||||
key.data_ptr<scalar_t>(),
|
key.data_ptr<scalar_t>(),
|
||||||
cos_sin_cache.data_ptr<scalar_t>(),
|
cos_sin_cache.data_ptr<scalar_t>(),
|
||||||
|
stride,
|
||||||
num_heads,
|
num_heads,
|
||||||
head_size);
|
head_size);
|
||||||
});
|
});
|
||||||
|
|||||||
7
setup.py
7
setup.py
@ -39,6 +39,13 @@ layernorm_extension = cpp_extension.CUDAExtension(
|
|||||||
)
|
)
|
||||||
ext_modules.append(layernorm_extension)
|
ext_modules.append(layernorm_extension)
|
||||||
|
|
||||||
|
activation_extension = cpp_extension.CUDAExtension(
|
||||||
|
name='cacheflow.activation_ops',
|
||||||
|
sources=['csrc/activation.cpp', 'csrc/activation_kernels.cu'],
|
||||||
|
extra_compile_args={'cxx': CXX_FLAGS, 'nvcc': NVCC_FLAGS},
|
||||||
|
)
|
||||||
|
ext_modules.append(activation_extension)
|
||||||
|
|
||||||
setuptools.setup(
|
setuptools.setup(
|
||||||
name='cacheflow',
|
name='cacheflow',
|
||||||
ext_modules=ext_modules,
|
ext_modules=ext_modules,
|
||||||
|
|||||||
30
tests/kernels/activation.py
Normal file
30
tests/kernels/activation.py
Normal file
@ -0,0 +1,30 @@
|
|||||||
|
import torch
|
||||||
|
import torch.nn.functional as F
|
||||||
|
|
||||||
|
from cacheflow import activation_ops
|
||||||
|
|
||||||
|
|
||||||
|
def ref_silu_and_mul(x: torch.Tensor) -> torch.Tensor:
|
||||||
|
x1, x2 = x.chunk(chunks=2, dim=1)
|
||||||
|
return F.silu(x1) * x2
|
||||||
|
|
||||||
|
|
||||||
|
@torch.inference_mode()
|
||||||
|
def test_silu_and_mul(
|
||||||
|
num_tokens: int,
|
||||||
|
d: int,
|
||||||
|
dtype: torch.dtype,
|
||||||
|
) -> None:
|
||||||
|
x = torch.randn(num_tokens, 2 * d, dtype=dtype, device='cuda')
|
||||||
|
out = torch.empty(num_tokens, d, dtype=dtype, device='cuda')
|
||||||
|
activation_ops.silu_and_mul(out, x)
|
||||||
|
ref_out = ref_silu_and_mul(x)
|
||||||
|
assert torch.allclose(out, ref_out, atol=1e-5, rtol=1e-5)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
for dtype in [torch.half, torch.float]:
|
||||||
|
for num_tokens in [7, 83, 2048]:
|
||||||
|
for d in [512, 4096, 13824]:
|
||||||
|
print(f'Testing dtype={dtype}, num_tokens={num_tokens}, d={d}')
|
||||||
|
test_silu_and_mul(num_tokens, d, dtype)
|
||||||
@ -1,7 +1,7 @@
|
|||||||
import random
|
import random
|
||||||
from typing import List, Optional
|
from typing import List, Optional
|
||||||
|
|
||||||
from flash_attn.flash_attention import FlashAttention
|
from flash_attn.flash_attn_interface import _flash_attn_forward
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from cacheflow import attention_ops
|
from cacheflow import attention_ops
|
||||||
@ -105,8 +105,9 @@ def test_single_query_cached_kv_attention(
|
|||||||
num_blocks: int,
|
num_blocks: int,
|
||||||
dtype: torch.dtype,
|
dtype: torch.dtype,
|
||||||
) -> None:
|
) -> None:
|
||||||
query = torch.randn(
|
qkv = torch.randn(
|
||||||
num_tokens, num_heads, head_size, dtype=dtype, device='cuda')
|
num_tokens, 3, num_heads, head_size, dtype=dtype, device='cuda')
|
||||||
|
query, _, _ = qkv.unbind(dim=1)
|
||||||
x = 16 // torch.tensor([], dtype=dtype).element_size()
|
x = 16 // torch.tensor([], dtype=dtype).element_size()
|
||||||
key_block_shape = (num_heads, head_size // x, block_size, x)
|
key_block_shape = (num_heads, head_size // x, block_size, x)
|
||||||
key_cache = torch.randn(
|
key_cache = torch.randn(
|
||||||
@ -115,6 +116,11 @@ def test_single_query_cached_kv_attention(
|
|||||||
value_cache = torch.randn(
|
value_cache = torch.randn(
|
||||||
size=(num_blocks, *value_block_shape), dtype=dtype, device='cuda')
|
size=(num_blocks, *value_block_shape), dtype=dtype, device='cuda')
|
||||||
|
|
||||||
|
# Adjust the range of the values to reduce precision errors.
|
||||||
|
query = query / (head_size ** 0.5)
|
||||||
|
key_cache = key_cache / (head_size ** 0.5)
|
||||||
|
value_cache = value_cache / (head_size ** 0.5)
|
||||||
|
|
||||||
context_lens = [random.randint(1, MAX_SEQ_LEN) for _ in range(num_tokens)]
|
context_lens = [random.randint(1, MAX_SEQ_LEN) for _ in range(num_tokens)]
|
||||||
max_context_len = max(context_lens)
|
max_context_len = max(context_lens)
|
||||||
context_lens = torch.tensor(context_lens, dtype=torch.int, device='cuda')
|
context_lens = torch.tensor(context_lens, dtype=torch.int, device='cuda')
|
||||||
@ -130,7 +136,8 @@ def test_single_query_cached_kv_attention(
|
|||||||
block_tables = torch.tensor(block_tables, dtype=torch.int, device='cuda')
|
block_tables = torch.tensor(block_tables, dtype=torch.int, device='cuda')
|
||||||
|
|
||||||
scale = float(1.0 / (head_size ** 0.5))
|
scale = float(1.0 / (head_size ** 0.5))
|
||||||
output = torch.empty_like(query)
|
output = torch.empty(
|
||||||
|
num_tokens, num_heads, head_size, dtype=dtype, device='cuda')
|
||||||
attention_ops.single_query_cached_kv_attention(
|
attention_ops.single_query_cached_kv_attention(
|
||||||
output,
|
output,
|
||||||
query,
|
query,
|
||||||
@ -175,19 +182,28 @@ def test_multi_query_kv_attention(
|
|||||||
cu_seq_lens = torch.tensor(cu_seq_lens, dtype=torch.int, device='cuda')
|
cu_seq_lens = torch.tensor(cu_seq_lens, dtype=torch.int, device='cuda')
|
||||||
|
|
||||||
scale = float(1.0 / (head_size ** 0.5))
|
scale = float(1.0 / (head_size ** 0.5))
|
||||||
query = torch.randn(
|
qkv = torch.randn(
|
||||||
num_tokens, num_heads, head_size, dtype=dtype, device='cuda')
|
num_tokens, 3, num_heads, head_size, dtype=dtype, device='cuda')
|
||||||
key = torch.rand_like(query)
|
# Adjust the range of the values to reduce precision errors.
|
||||||
value = torch.rand_like(query)
|
qkv = qkv / (head_size ** 0.5)
|
||||||
|
|
||||||
qkv = torch.stack([query, key, value], dim=1)
|
query, key, value = qkv.unbind(dim=1)
|
||||||
flash_attn = FlashAttention(softmax_scale=scale)
|
output = torch.empty(
|
||||||
output = flash_attn(
|
num_tokens, num_heads, head_size, dtype=dtype, device='cuda')
|
||||||
qkv,
|
_flash_attn_forward(
|
||||||
cu_seqlens=cu_seq_lens,
|
query,
|
||||||
max_s=max_seq_len,
|
key,
|
||||||
|
value,
|
||||||
|
output,
|
||||||
|
cu_seq_lens,
|
||||||
|
cu_seq_lens,
|
||||||
|
max_seq_len,
|
||||||
|
max_seq_len,
|
||||||
|
dropout_p=0.0,
|
||||||
|
softmax_scale=scale,
|
||||||
causal=True,
|
causal=True,
|
||||||
)[0]
|
return_softmax=False,
|
||||||
|
)
|
||||||
|
|
||||||
cu_seq_lens = cu_seq_lens.cpu().tolist()
|
cu_seq_lens = cu_seq_lens.cpu().tolist()
|
||||||
ref_output = ref_multi_query_kv_attention(
|
ref_output = ref_multi_query_kv_attention(
|
||||||
|
|||||||
@ -17,10 +17,10 @@ def test_reshape_and_cache(
|
|||||||
slot_mapping = random.sample(range(num_slots), num_tokens)
|
slot_mapping = random.sample(range(num_slots), num_tokens)
|
||||||
slot_mapping = torch.tensor(slot_mapping, dtype=torch.int, device='cuda')
|
slot_mapping = torch.tensor(slot_mapping, dtype=torch.int, device='cuda')
|
||||||
|
|
||||||
kv_shape = (num_tokens, num_heads, head_size)
|
qkv = torch.randn(
|
||||||
key = torch.randn(size=kv_shape, dtype=dtype, device='cuda')
|
num_tokens, 3, num_heads, head_size, dtype=dtype, device='cuda')
|
||||||
value = torch.randn(size=kv_shape, dtype=dtype, device='cuda')
|
_, key, value = qkv.unbind(dim=1)
|
||||||
|
|
||||||
x = 16 // torch.tensor([], dtype=dtype).element_size()
|
x = 16 // torch.tensor([], dtype=dtype).element_size()
|
||||||
key_cache_shape = (num_blocks, num_heads, head_size // x, block_size, x)
|
key_cache_shape = (num_blocks, num_heads, head_size // x, block_size, x)
|
||||||
key_cache = torch.randn(size=key_cache_shape, dtype=dtype, device='cuda')
|
key_cache = torch.randn(size=key_cache_shape, dtype=dtype, device='cuda')
|
||||||
@ -35,7 +35,7 @@ def test_reshape_and_cache(
|
|||||||
|
|
||||||
for i in range(num_tokens):
|
for i in range(num_tokens):
|
||||||
reshaped_key = key.reshape(num_tokens, num_heads, head_size // x, x)
|
reshaped_key = key.reshape(num_tokens, num_heads, head_size // x, x)
|
||||||
block_idx = slot_mapping[i] // block_size
|
block_idx = torch.div(slot_mapping[i], block_size, rounding_mode='floor')
|
||||||
block_offset = slot_mapping[i] % block_size
|
block_offset = slot_mapping[i] % block_size
|
||||||
cloned_key_cache[block_idx, :, :, block_offset, :] = reshaped_key[i]
|
cloned_key_cache[block_idx, :, :, block_offset, :] = reshaped_key[i]
|
||||||
cloned_value_cache[block_idx, :, :, block_offset] = value[i]
|
cloned_value_cache[block_idx, :, :, block_offset] = value[i]
|
||||||
|
|||||||
@ -85,15 +85,13 @@ def test_rotary_embedding_neox(
|
|||||||
cos_sin_cache = torch.cat((cos, sin), dim=-1)
|
cos_sin_cache = torch.cat((cos, sin), dim=-1)
|
||||||
cos_sin_cache = cos_sin_cache.to(dtype=dtype, device='cuda')
|
cos_sin_cache = cos_sin_cache.to(dtype=dtype, device='cuda')
|
||||||
|
|
||||||
# Run the kernel.
|
# Run the kernel. The kernel is in-place, so we need to clone the inputs.
|
||||||
out_query = torch.empty_like(query)
|
out_query = query.clone()
|
||||||
out_key = torch.empty_like(key)
|
out_key = key.clone()
|
||||||
pos_encoding_ops.rotary_embedding_neox(
|
pos_encoding_ops.rotary_embedding_neox(
|
||||||
|
positions,
|
||||||
out_query,
|
out_query,
|
||||||
out_key,
|
out_key,
|
||||||
positions,
|
|
||||||
query,
|
|
||||||
key,
|
|
||||||
cos_sin_cache,
|
cos_sin_cache,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user