From 897cb2ae28e93de1b22ecfbffcccfb9493f8f4d9 Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Sun, 2 Apr 2023 00:30:17 -0700 Subject: [PATCH] Optimize data movement (#20) --- cacheflow/models/activation.py | 20 +++++++ cacheflow/models/attention.py | 92 +++++++++++++++--------------- cacheflow/models/input_metadata.py | 5 ++ cacheflow/models/llama.py | 19 +++--- cacheflow/models/opt.py | 7 +-- cacheflow/worker/worker.py | 8 +++ csrc/activation.cpp | 12 ++++ csrc/activation_kernels.cu | 46 +++++++++++++++ csrc/attention_kernels.cu | 22 ++++--- csrc/cache_kernels.cu | 25 +++++--- csrc/pos_encoding.cpp | 2 - csrc/pos_encoding_kernels.cu | 49 +++++++--------- setup.py | 7 +++ tests/kernels/activation.py | 30 ++++++++++ tests/kernels/attention.py | 46 ++++++++++----- tests/kernels/cache.py | 10 ++-- tests/kernels/pos_encoding.py | 10 ++-- 17 files changed, 275 insertions(+), 135 deletions(-) create mode 100644 cacheflow/models/activation.py create mode 100644 csrc/activation.cpp create mode 100644 csrc/activation_kernels.cu create mode 100644 tests/kernels/activation.py diff --git a/cacheflow/models/activation.py b/cacheflow/models/activation.py new file mode 100644 index 0000000000000..c3267ebcb7d45 --- /dev/null +++ b/cacheflow/models/activation.py @@ -0,0 +1,20 @@ +import torch +import torch.nn as nn + +from cacheflow import activation_ops + + +class SiluAndMul(nn.Module): + + def __init__(self): + super().__init__() + + def forward( + self, + x: torch.Tensor, # (num_tokens, 2 * d) + ) -> torch.Tensor: # (num_tokens, d) + num_tokens = x.shape[0] + d = x.shape[1] // 2 + out = torch.empty(num_tokens, d, dtype=x.dtype, device=x.device) + activation_ops.silu_and_mul(out, x) + return out diff --git a/cacheflow/models/attention.py b/cacheflow/models/attention.py index 8b132e4423798..6fa197e7c8b90 100644 --- a/cacheflow/models/attention.py +++ b/cacheflow/models/attention.py @@ -1,6 +1,6 @@ -from typing import List, Optional +from typing import Optional -from flash_attn.flash_attention import FlashAttention +from flash_attn.flash_attn_interface import _flash_attn_forward import torch import torch.nn as nn @@ -16,40 +16,38 @@ class GPTCacheFlowAttention(nn.Module): super().__init__() self.scale = float(scale) - self.flash_attn = FlashAttention(softmax_scale=self.scale) - def multi_query_kv_attention( self, - output: torch.Tensor, # [num_prompt_tokens, num_heads, head_size] - query: torch.Tensor, # [num_prompt_tokens, num_heads, head_size] - key: torch.Tensor, # [num_prompt_tokens, num_heads, head_size] - value: torch.Tensor, # [num_prompt_tokens, num_heads, head_size] - prompt_lens: List[int], + output: torch.Tensor, # [num_prompt_tokens, num_heads, head_size] + query: torch.Tensor, # [num_prompt_tokens, num_heads, head_size] + key: torch.Tensor, # [num_prompt_tokens, num_heads, head_size] + value: torch.Tensor, # [num_prompt_tokens, num_heads, head_size] + cumulative_prompt_lens: torch.Tensor, # [num_prompts + 1] + max_prompt_len: int, ) -> None: if query.dtype == torch.float: raise ValueError('The float data type is not supported by ' 'FlashAttention. Use the half data type instead.') - head_size = query.shape[2] + head_size = query.shape[-1] if head_size > 128: raise ValueError('FlashAttention does not support head_size > 128.') - device = query.device - prefix_sum = [0] - for prompt_len in prompt_lens: - prefix_sum.append(prefix_sum[-1] + prompt_len) - prefix_sum = torch.tensor(prefix_sum, dtype=torch.int, device=device) - max_prompt_len = max(prompt_lens) - - # FIXME(woosuk): Unnecessary copy. 
Optimize this. - qkv = torch.stack([query, key, value], dim=1) - out = self.flash_attn( - qkv, - cu_seqlens=prefix_sum, - max_s=max_prompt_len, + # Directly call FlashAttention's internal function to avoid allocating + # a new tensor for the output. + _flash_attn_forward( + query, + key, + value, + output, + cumulative_prompt_lens, + cumulative_prompt_lens, + max_prompt_len, + max_prompt_len, + dropout_p=0.0, + softmax_scale=self.scale, causal=True, - )[0] - # FIXME(woosuk): Unnecessary copy. Optimize this. - output.copy_(out, non_blocking=True) + return_softmax=False, + ) def single_query_cached_kv_attention( self, @@ -90,21 +88,18 @@ class GPTCacheFlowAttention(nn.Module): input_metadata: InputMetadata, cache_event: Optional[torch.cuda.Event], ) -> torch.Tensor: # [num_tokens, num_heads * head_size] - # Pre-allocate the output tensor. - output = torch.empty_like(query) + # NOTE: The query, key, and value tensors must be sliced from a qkv + # tensor of shape [num_tokens, 3 * num_heads * head_size]. - # Prune out paddings if any. - query = query[:input_metadata.num_valid_tokens] - key = key[:input_metadata.num_valid_tokens] - value = value[:input_metadata.num_valid_tokens] - - # Reshape the input tensors. + # Reshape the query, key, and value tensors. num_heads = value_cache.shape[1] head_size = value_cache.shape[2] query = query.view(-1, num_heads, head_size) key = key.view(-1, num_heads, head_size) value = value.view(-1, num_heads, head_size) - output = output.view(-1, num_heads, head_size) + + # Pre-allocate the output tensor. + output = torch.empty_like(query) # Compute the attention op for prompts. num_prompt_tokens = input_metadata.num_prompt_tokens @@ -114,7 +109,8 @@ class GPTCacheFlowAttention(nn.Module): query[:num_prompt_tokens], key[:num_prompt_tokens], value[:num_prompt_tokens], - input_metadata.prompt_lens, + input_metadata.cumulative_prompt_lens, + input_metadata.max_prompt_len, ) # Wait until the cache op is done. @@ -122,14 +118,22 @@ class GPTCacheFlowAttention(nn.Module): cache_event.wait() # Reshape the keys and values and store them in the cache. - cache_ops.reshape_and_cache( - key, value, key_cache, value_cache, input_metadata.slot_mapping) + num_valid_tokens = input_metadata.num_valid_tokens + if num_valid_tokens > 0: + # The stride is 3 because the key and value are sliced from qkv. + cache_ops.reshape_and_cache( + key[:num_valid_tokens], + value[:num_valid_tokens], + key_cache, + value_cache, + input_metadata.slot_mapping, + ) if input_metadata.num_generation_tokens > 0: # Compute the attention op for generation tokens. self.single_query_cached_kv_attention( - output[num_prompt_tokens:], - query[num_prompt_tokens:], + output[num_prompt_tokens:num_valid_tokens], + query[num_prompt_tokens:num_valid_tokens], key_cache, value_cache, input_metadata) @@ -186,19 +190,15 @@ class LlamaCacheFlowAttention(GPTCacheFlowAttention): ) -> torch.Tensor: # [num_tokens, num_heads * head_size] # Apply rotary embedding to the query and key before passing them # to the attention op. 
- out_query = torch.empty_like(query) - out_key = torch.empty_like(key) pos_encoding_ops.rotary_embedding_neox( - out_query, - out_key, positions, query, key, self.cos_sin_cache, ) return super().forward( - out_query, - out_key, + query, + key, value, key_cache, value_cache, diff --git a/cacheflow/models/input_metadata.py b/cacheflow/models/input_metadata.py index 8a341fbac6276..c61bfff20a66b 100644 --- a/cacheflow/models/input_metadata.py +++ b/cacheflow/models/input_metadata.py @@ -12,6 +12,7 @@ class InputMetadata: seq_groups: List[Tuple[List[int], SamplingParams]], seq_logprobs: Dict[int, float], # Seq id -> cumulative logprobs. prompt_lens: List[int], + cumulative_prompt_lens: torch.Tensor, slot_mapping: torch.Tensor, context_lens: torch.Tensor, max_context_len: int, @@ -20,6 +21,7 @@ class InputMetadata: self.seq_groups = seq_groups self.seq_logprobs = seq_logprobs self.prompt_lens = prompt_lens + self.cumulative_prompt_lens = cumulative_prompt_lens self.slot_mapping = slot_mapping self.context_lens = context_lens self.max_context_len = max_context_len @@ -27,6 +29,7 @@ class InputMetadata: self.num_prompts = len(prompt_lens) self.num_prompt_tokens = sum(prompt_lens) + self.max_prompt_len = max(prompt_lens) if prompt_lens else 0 self.num_generation_tokens = context_lens.shape[0] self.num_valid_tokens = slot_mapping.shape[0] if block_tables.numel() > 0: @@ -40,11 +43,13 @@ class InputMetadata: return (f'InputMetadata(' f'num_prompts={self.num_prompts}, ' f'num_prompt_tokens={self.num_prompt_tokens}, ' + f'max_prompt_len={self.max_prompt_len}, ' f'num_generation_tokens={self.num_generation_tokens}, ' f'num_valid_tokens={self.num_valid_tokens}, ' f'max_num_blocks_per_seq={self.max_num_blocks_per_seq}, ' f'max_context_len={self.max_context_len}), ' f'prompt_lens={self.prompt_lens}, ' + f'cumulative_prompt_lens={self.cumulative_prompt_lens}, ' f'slot_mapping={self.slot_mapping}, ' f'context_lens={self.context_lens}, ' f'block_tables={self.block_tables})') diff --git a/cacheflow/models/llama.py b/cacheflow/models/llama.py index 86f350bd11387..2a3c8b007adf1 100644 --- a/cacheflow/models/llama.py +++ b/cacheflow/models/llama.py @@ -11,6 +11,7 @@ from torch import nn from transformers import LlamaConfig from cacheflow.models import InputMetadata +from cacheflow.models.activation import SiluAndMul from cacheflow.models.attention import LlamaCacheFlowAttention from cacheflow.models.layernorm import RMSNorm from cacheflow.models.sample import Sampler @@ -39,16 +40,14 @@ class LlamaMLP(nn.Module): self.down_proj = RowParallelLinear(intermediate_size, hidden_size, bias=False, input_is_parallel=True, perform_initialization=False) - assert hidden_act == 'silu' - self.act_fn = nn.SiLU() + if hidden_act != 'silu': + raise ValueError(f'Unsupported activation: {hidden_act}. 
' + 'Only silu is supported for now.') + self.act_fn = SiluAndMul() def forward(self, x): gate_up, _ = self.gate_up_proj(x) - gate_up = gate_up.reshape(gate_up.shape[:-1] + (2, -1)) - gate, up = torch.split(gate_up, 1, dim=-2) - gate = gate.squeeze(dim=-2).contiguous() - up = up.squeeze(dim=-2).contiguous() - x = self.act_fn(gate) * up + x = self.act_fn(gate_up) x, _ = self.down_proj(x) return x @@ -94,11 +93,7 @@ class LlamaAttention(nn.Module): cache_event: Optional[torch.cuda.Event], ) -> torch.Tensor: qkv, _ = self.qkv_proj(hidden_states) - qkv = qkv.reshape(qkv.shape[:-1] + (3, -1)) - q, k, v = torch.split(qkv, 1, dim=-2) - q = q.squeeze(dim=-2).contiguous() - k = k.squeeze(dim=-2).contiguous() - v = v.squeeze(dim=-2).contiguous() + q, k, v = qkv.chunk(chunks=3, dim=-1) k_cache, v_cache = kv_cache attn_output = self.attn( positions, q, k, v, k_cache, v_cache, input_metadata, cache_event) diff --git a/cacheflow/models/opt.py b/cacheflow/models/opt.py index eed20ea41ffdd..9ecd9e70f138a 100644 --- a/cacheflow/models/opt.py +++ b/cacheflow/models/opt.py @@ -69,17 +69,14 @@ class OPTAttention(nn.Module): cache_event: Optional[torch.cuda.Event], ) -> torch.Tensor: qkv, _ = self.qkv_proj(hidden_states) - qkv = qkv.reshape(qkv.shape[:-1] + (3, -1)) - q, k, v = torch.split(qkv, 1, dim=-2) - q = q.squeeze(dim=-2).contiguous() - k = k.squeeze(dim=-2).contiguous() - v = v.squeeze(dim=-2).contiguous() + q, k, v = qkv.chunk(chunks=3, dim=-1) key_cache, value_cache = kv_cache attn_output = self.attn( q, k, v, key_cache, value_cache, input_metadata, cache_event) output, _ = self.out_proj(attn_output) return output + class OPTDecoderLayer(nn.Module): def __init__(self, config: OPTConfig): diff --git a/cacheflow/worker/worker.py b/cacheflow/worker/worker.py index f309cf12f672d..db0d46aabe9e1 100644 --- a/cacheflow/worker/worker.py +++ b/cacheflow/worker/worker.py @@ -128,6 +128,11 @@ class Worker: slot = block_number * self.block_size + block_offset slot_mapping.append(slot) + cumulative_prompt_lens: List[int] = [0] + for prompt_len in prompt_lens: + cumulative_prompt_lens.append( + cumulative_prompt_lens[-1] + prompt_len) + # Add generation tokens. 
max_context_len = 0 max_num_blocks_per_seq = 0 @@ -183,11 +188,14 @@ class Worker: for block_table in generation_block_tables] block_tables_tensor = torch.tensor( padded_block_tables, dtype=torch.int, device='cuda') + cumulative_prompt_lens_tensor = torch.tensor( + cumulative_prompt_lens, dtype=torch.int, device='cuda') input_metadata = InputMetadata( seq_groups=seq_groups, seq_logprobs=seq_logprobs, prompt_lens=prompt_lens, + cumulative_prompt_lens=cumulative_prompt_lens_tensor, slot_mapping=slot_mapping_tensor, context_lens=context_lens_tensor, max_context_len=max_context_len, diff --git a/csrc/activation.cpp b/csrc/activation.cpp new file mode 100644 index 0000000000000..f95afdc00e3c9 --- /dev/null +++ b/csrc/activation.cpp @@ -0,0 +1,12 @@ +#include + +void silu_and_mul( + torch::Tensor& out, + torch::Tensor& input); + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { + m.def( + "silu_and_mul", + &silu_and_mul, + "Activation function used in SwiGLU."); +} diff --git a/csrc/activation_kernels.cu b/csrc/activation_kernels.cu new file mode 100644 index 0000000000000..12ee6c54827c3 --- /dev/null +++ b/csrc/activation_kernels.cu @@ -0,0 +1,46 @@ +#include +#include + +namespace cacheflow { + +template +__device__ __forceinline__ T silu(const T& x) { + // x * sigmoid(x) + return (T) (((float) x) / (1.0f + expf((float) -x))); +} + +template +__global__ void silu_and_mul_kernel( + scalar_t* __restrict__ out, // [num_tokens, d] + const scalar_t* __restrict__ input, // [num_tokens, 2, d] + const int d) { + const int token_idx = blockIdx.x; + for (int idx = threadIdx.x; idx < d; idx += blockDim.x) { + const scalar_t x = __ldg(&input[token_idx * 2 * d + idx]); + const scalar_t y = __ldg(&input[token_idx * 2 * d + d + idx]); + out[token_idx * d + idx] = silu(x) * y; + } +} + +} // namespace cacheflow + +void silu_and_mul( + torch::Tensor& out, // [num_tokens, d] + torch::Tensor& input) // [num_tokens, 2 * d] +{ + int num_tokens = input.size(0); + int d = input.size(1) / 2; + + dim3 grid(num_tokens); + dim3 block(std::min(d, 1024)); + const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + AT_DISPATCH_FLOATING_TYPES_AND_HALF( + input.scalar_type(), + "silu_and_mul_kernel", + [&] { + cacheflow::silu_and_mul_kernel<<>>( + out.data_ptr(), + input.data_ptr(), + d); + }); +} diff --git a/csrc/attention_kernels.cu b/csrc/attention_kernels.cu index 5b24120eadfbf..60c0d0c6cd03c 100644 --- a/csrc/attention_kernels.cu +++ b/csrc/attention_kernels.cu @@ -25,7 +25,8 @@ __global__ void single_query_cached_kv_attention_kernel( const float scale, const int* __restrict__ block_tables, // [num_seqs, max_num_blocks_per_seq] const int* __restrict__ context_lens, // [num_seqs] - const int max_num_blocks_per_seq) { + const int max_num_blocks_per_seq, + const int q_stride) { constexpr int THREAD_GROUP_SIZE = WARP_SIZE / BLOCK_SIZE; constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE; const int thread_idx = threadIdx.x; @@ -56,7 +57,8 @@ __global__ void single_query_cached_kv_attention_kernel( // For example, if the the thread group size is 4, then the first thread in the group // has 0, 4, 8, ... th vectors of the query, and the second thread has 1, 5, 9, ... // th vectors of the query, and so on. - const scalar_t* q_ptr = q + seq_idx * num_heads * HEAD_SIZE + head_idx * HEAD_SIZE; + // NOTE(woosuk): Because q is split from a qkv tensor, it may not be contiguous. 
+ const scalar_t* q_ptr = q + seq_idx * q_stride + head_idx * HEAD_SIZE; Q_vec q_vecs[NUM_VECS_PER_THREAD]; #pragma unroll for (int i = 0; i < NUM_VECS_PER_THREAD; i++) { @@ -264,7 +266,8 @@ __global__ void single_query_cached_kv_attention_kernel( scale, \ block_tables_ptr, \ context_lens_ptr, \ - max_num_blocks_per_seq); + max_num_blocks_per_seq, \ + query_stride); // TODO(woosuk): Tune NUM_THREADS. template< @@ -284,6 +287,7 @@ void single_query_cached_kv_attention_launcher( int num_heads = query.size(1); int head_size = query.size(2); int max_num_blocks_per_seq = block_tables.size(1); + int query_stride = query.stride(0); T* out_ptr = reinterpret_cast(out.data_ptr()); T* query_ptr = reinterpret_cast(query.data_ptr()); @@ -333,13 +337,13 @@ void single_query_cached_kv_attention_launcher( } void single_query_cached_kv_attention( - torch::Tensor& out, - torch::Tensor& query, - torch::Tensor& key_cache, - torch::Tensor& value_cache, + torch::Tensor& out, // [num_seqs, num_heads, head_size] + torch::Tensor& query, // [num_seqs, num_heads, head_size] + torch::Tensor& key_cache, // [num_blocks, num_heads, head_size/x, block_size, x] + torch::Tensor& value_cache, // [num_blocks, num_heads, head_size, block_size] float scale, - torch::Tensor& block_tables, - torch::Tensor& context_lens, + torch::Tensor& block_tables, // [num_seqs, max_num_blocks_per_seq] + torch::Tensor& context_lens, // [num_seqs] int block_size, int max_context_len) { // TODO(woosuk): Support BF16. diff --git a/csrc/cache_kernels.cu b/csrc/cache_kernels.cu index d7a0faa814108..8b5537c47229b 100644 --- a/csrc/cache_kernels.cu +++ b/csrc/cache_kernels.cu @@ -81,6 +81,8 @@ __global__ void reshape_and_cache_kernel( scalar_t* __restrict__ key_cache, // [num_blocks, num_heads, head_size/x, block_size, x] scalar_t* __restrict__ value_cache, // [num_blocks, num_heads, head_size, block_size] const int* __restrict__ slot_mapping, // [num_tokens] + const int key_stride, + const int value_stride, const int num_heads, const int head_size, const int block_size, @@ -92,7 +94,8 @@ __global__ void reshape_and_cache_kernel( const int n = num_heads * head_size; for (int i = threadIdx.x; i < n; i += blockDim.x) { - const int src_idx = token_idx * n + i; + const int src_key_idx = token_idx * key_stride + i; + const int src_value_idx = token_idx * value_stride + i; const int head_idx = i / head_size; const int head_offset = i % head_size; @@ -108,25 +111,29 @@ __global__ void reshape_and_cache_kernel( + head_idx * head_size * block_size + head_offset * block_size + block_offset; - key_cache[tgt_key_idx] = __ldg(&key[src_idx]); - value_cache[tgt_value_idx] = __ldg(&value[src_idx]); + key_cache[tgt_key_idx] = __ldg(&key[src_key_idx]); + value_cache[tgt_value_idx] = __ldg(&value[src_value_idx]); } } } // namespace cacheflow void reshape_and_cache( - torch::Tensor& key, - torch::Tensor& value, - torch::Tensor& key_cache, - torch::Tensor& value_cache, - torch::Tensor& slot_mapping) { + torch::Tensor& key, // [num_tokens, num_heads, head_size] + torch::Tensor& value, // [num_tokens, num_heads, head_size] + torch::Tensor& key_cache, // [num_blocks, num_heads, head_size/x, block_size, x] + torch::Tensor& value_cache, // [num_blocks, num_heads, head_size, block_size] + torch::Tensor& slot_mapping) // [num_tokens] +{ int num_tokens = key.size(0); int num_heads = key.size(1); int head_size = key.size(2); int block_size = key_cache.size(3); int x = key_cache.size(4); + int key_stride = key.stride(0); + int value_stride = value.stride(0); + dim3 
grid(num_tokens); dim3 block(std::min(num_heads * head_size, 512)); const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); @@ -140,6 +147,8 @@ void reshape_and_cache( key_cache.data_ptr(), value_cache.data_ptr(), slot_mapping.data_ptr(), + key_stride, + value_stride, num_heads, head_size, block_size, diff --git a/csrc/pos_encoding.cpp b/csrc/pos_encoding.cpp index a10bec85a98a7..5966751a489f2 100644 --- a/csrc/pos_encoding.cpp +++ b/csrc/pos_encoding.cpp @@ -1,8 +1,6 @@ #include void rotary_embedding_neox( - torch::Tensor& out_query, - torch::Tensor& out_key, torch::Tensor& positions, torch::Tensor& query, torch::Tensor& key, diff --git a/csrc/pos_encoding_kernels.cu b/csrc/pos_encoding_kernels.cu index 50cf209fb200d..525f0fef4a3ae 100644 --- a/csrc/pos_encoding_kernels.cu +++ b/csrc/pos_encoding_kernels.cu @@ -5,12 +5,11 @@ namespace cacheflow { template __global__ void rotary_embedding_neox_kernel( - scalar_t* __restrict__ out_query, // [num_tokens, num_heads, head_size] - scalar_t* __restrict__ out_key, // [num_tokens, num_heads, head_size] const int64_t* __restrict__ positions, // [num_tokens] - const scalar_t* __restrict__ query, // [num_tokens, num_heads, head_size] - const scalar_t* __restrict__ key, // [num_tokens, num_heads, head_size] + scalar_t* __restrict__ query, // [num_tokens, num_heads, head_size] + scalar_t* __restrict__ key, // [num_tokens, num_heads, head_size] const scalar_t* __restrict__ cos_sin_cache, // [max_position, 2, head_size // 2] + const int stride, const int num_heads, const int head_size) { // Each thread block is responsible for one token. @@ -19,41 +18,36 @@ __global__ void rotary_embedding_neox_kernel( const scalar_t* cache_ptr = cos_sin_cache + pos * head_size; const int embed_dim = head_size / 2; - const int n = num_heads * head_size; + const int n = num_heads * embed_dim; for (int i = threadIdx.x; i < n; i += blockDim.x) { - const int idx = token_idx * n + i; + const int head_idx = i / embed_dim; + const int token_head = token_idx * stride + head_idx * head_size; - const int head_idx = i / head_size; - const int head_offset = i % head_size; - const int token_head = token_idx * n + head_idx * head_size; - - const bool is_first_half = head_offset < embed_dim; - const int rot_offset = head_offset % embed_dim; + const int rot_offset = i % embed_dim; const int x_index = rot_offset; const int y_index = embed_dim + rot_offset; + const int out_x = token_idx * stride + head_idx * head_size + x_index; + const int out_y = token_idx * stride + head_idx * head_size + y_index; + const scalar_t cos = __ldg(cache_ptr + x_index); const scalar_t sin = __ldg(cache_ptr + y_index); - const scalar_t q_x = __ldg(query + token_head + x_index); - const scalar_t q_y = __ldg(query + token_head + y_index); - const scalar_t q_cos = is_first_half ? q_x : q_y; - const scalar_t q_sin = is_first_half ? -q_y : q_x; - out_query[idx] = q_cos * cos + q_sin * sin; + const scalar_t q_x = query[token_head + x_index]; + const scalar_t q_y = query[token_head + y_index]; + query[out_x] = q_x * cos - q_y * sin; + query[out_y] = q_y * cos + q_x * sin; - const scalar_t k_x = __ldg(key + token_head + x_index); - const scalar_t k_y = __ldg(key + token_head + y_index); - const scalar_t k_cos = is_first_half ? k_x : k_y; - const scalar_t k_sin = is_first_half ? 
-k_y : k_x; - out_key[idx] = k_cos * cos + k_sin * sin; + const scalar_t k_x = key[token_head + x_index]; + const scalar_t k_y = key[token_head + y_index]; + key[out_x] = k_x * cos - k_y * sin; + key[out_y] = k_y * cos + k_x * sin; } } } // namespace cacheflow void rotary_embedding_neox( - torch::Tensor& out_query, // [num_tokens, num_heads * head_size] - torch::Tensor& out_key, // [num_tokens, num_heads * head_size] torch::Tensor& positions, // [num_tokens] torch::Tensor& query, // [num_tokens, num_heads * head_size] torch::Tensor& key, // [num_tokens, num_heads * head_size] @@ -62,21 +56,22 @@ void rotary_embedding_neox( int num_tokens = query.size(0); int head_size = cos_sin_cache.size(1); int num_heads = query.size(1) / head_size; + int stride = query.stride(0); + TORCH_CHECK(stride == key.stride(0)); dim3 grid(num_tokens); - dim3 block(std::min(num_heads * head_size, 512)); + dim3 block(std::min(num_heads * head_size / 2, 512)); const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); AT_DISPATCH_FLOATING_TYPES_AND_HALF( query.scalar_type(), "rotary_embedding_neox", [&] { cacheflow::rotary_embedding_neox_kernel<<>>( - out_query.data_ptr(), - out_key.data_ptr(), positions.data_ptr(), query.data_ptr(), key.data_ptr(), cos_sin_cache.data_ptr(), + stride, num_heads, head_size); }); diff --git a/setup.py b/setup.py index df7551989946d..e96c730333795 100644 --- a/setup.py +++ b/setup.py @@ -39,6 +39,13 @@ layernorm_extension = cpp_extension.CUDAExtension( ) ext_modules.append(layernorm_extension) +activation_extension = cpp_extension.CUDAExtension( + name='cacheflow.activation_ops', + sources=['csrc/activation.cpp', 'csrc/activation_kernels.cu'], + extra_compile_args={'cxx': CXX_FLAGS, 'nvcc': NVCC_FLAGS}, +) +ext_modules.append(activation_extension) + setuptools.setup( name='cacheflow', ext_modules=ext_modules, diff --git a/tests/kernels/activation.py b/tests/kernels/activation.py new file mode 100644 index 0000000000000..3d9a9a644f6d0 --- /dev/null +++ b/tests/kernels/activation.py @@ -0,0 +1,30 @@ +import torch +import torch.nn.functional as F + +from cacheflow import activation_ops + + +def ref_silu_and_mul(x: torch.Tensor) -> torch.Tensor: + x1, x2 = x.chunk(chunks=2, dim=1) + return F.silu(x1) * x2 + + +@torch.inference_mode() +def test_silu_and_mul( + num_tokens: int, + d: int, + dtype: torch.dtype, +) -> None: + x = torch.randn(num_tokens, 2 * d, dtype=dtype, device='cuda') + out = torch.empty(num_tokens, d, dtype=dtype, device='cuda') + activation_ops.silu_and_mul(out, x) + ref_out = ref_silu_and_mul(x) + assert torch.allclose(out, ref_out, atol=1e-5, rtol=1e-5) + + +if __name__ == '__main__': + for dtype in [torch.half, torch.float]: + for num_tokens in [7, 83, 2048]: + for d in [512, 4096, 13824]: + print(f'Testing dtype={dtype}, num_tokens={num_tokens}, d={d}') + test_silu_and_mul(num_tokens, d, dtype) diff --git a/tests/kernels/attention.py b/tests/kernels/attention.py index b6766e1eddc26..409da9efa2ef9 100644 --- a/tests/kernels/attention.py +++ b/tests/kernels/attention.py @@ -1,7 +1,7 @@ import random from typing import List, Optional -from flash_attn.flash_attention import FlashAttention +from flash_attn.flash_attn_interface import _flash_attn_forward import torch from cacheflow import attention_ops @@ -105,8 +105,9 @@ def test_single_query_cached_kv_attention( num_blocks: int, dtype: torch.dtype, ) -> None: - query = torch.randn( - num_tokens, num_heads, head_size, dtype=dtype, device='cuda') + qkv = torch.randn( + num_tokens, 3, num_heads, head_size, dtype=dtype, 
device='cuda') + query, _, _ = qkv.unbind(dim=1) x = 16 // torch.tensor([], dtype=dtype).element_size() key_block_shape = (num_heads, head_size // x, block_size, x) key_cache = torch.randn( @@ -115,6 +116,11 @@ def test_single_query_cached_kv_attention( value_cache = torch.randn( size=(num_blocks, *value_block_shape), dtype=dtype, device='cuda') + # Adjust the range of the values to reduce precision errors. + query = query / (head_size ** 0.5) + key_cache = key_cache / (head_size ** 0.5) + value_cache = value_cache / (head_size ** 0.5) + context_lens = [random.randint(1, MAX_SEQ_LEN) for _ in range(num_tokens)] max_context_len = max(context_lens) context_lens = torch.tensor(context_lens, dtype=torch.int, device='cuda') @@ -130,7 +136,8 @@ def test_single_query_cached_kv_attention( block_tables = torch.tensor(block_tables, dtype=torch.int, device='cuda') scale = float(1.0 / (head_size ** 0.5)) - output = torch.empty_like(query) + output = torch.empty( + num_tokens, num_heads, head_size, dtype=dtype, device='cuda') attention_ops.single_query_cached_kv_attention( output, query, @@ -175,19 +182,28 @@ def test_multi_query_kv_attention( cu_seq_lens = torch.tensor(cu_seq_lens, dtype=torch.int, device='cuda') scale = float(1.0 / (head_size ** 0.5)) - query = torch.randn( - num_tokens, num_heads, head_size, dtype=dtype, device='cuda') - key = torch.rand_like(query) - value = torch.rand_like(query) + qkv = torch.randn( + num_tokens, 3, num_heads, head_size, dtype=dtype, device='cuda') + # Adjust the range of the values to reduce precision errors. + qkv = qkv / (head_size ** 0.5) - qkv = torch.stack([query, key, value], dim=1) - flash_attn = FlashAttention(softmax_scale=scale) - output = flash_attn( - qkv, - cu_seqlens=cu_seq_lens, - max_s=max_seq_len, + query, key, value = qkv.unbind(dim=1) + output = torch.empty( + num_tokens, num_heads, head_size, dtype=dtype, device='cuda') + _flash_attn_forward( + query, + key, + value, + output, + cu_seq_lens, + cu_seq_lens, + max_seq_len, + max_seq_len, + dropout_p=0.0, + softmax_scale=scale, causal=True, - )[0] + return_softmax=False, + ) cu_seq_lens = cu_seq_lens.cpu().tolist() ref_output = ref_multi_query_kv_attention( diff --git a/tests/kernels/cache.py b/tests/kernels/cache.py index 9eebe437448f8..d6b1c3d2dd480 100644 --- a/tests/kernels/cache.py +++ b/tests/kernels/cache.py @@ -17,10 +17,10 @@ def test_reshape_and_cache( slot_mapping = random.sample(range(num_slots), num_tokens) slot_mapping = torch.tensor(slot_mapping, dtype=torch.int, device='cuda') - kv_shape = (num_tokens, num_heads, head_size) - key = torch.randn(size=kv_shape, dtype=dtype, device='cuda') - value = torch.randn(size=kv_shape, dtype=dtype, device='cuda') - + qkv = torch.randn( + num_tokens, 3, num_heads, head_size, dtype=dtype, device='cuda') + _, key, value = qkv.unbind(dim=1) + x = 16 // torch.tensor([], dtype=dtype).element_size() key_cache_shape = (num_blocks, num_heads, head_size // x, block_size, x) key_cache = torch.randn(size=key_cache_shape, dtype=dtype, device='cuda') @@ -35,7 +35,7 @@ def test_reshape_and_cache( for i in range(num_tokens): reshaped_key = key.reshape(num_tokens, num_heads, head_size // x, x) - block_idx = slot_mapping[i] // block_size + block_idx = torch.div(slot_mapping[i], block_size, rounding_mode='floor') block_offset = slot_mapping[i] % block_size cloned_key_cache[block_idx, :, :, block_offset, :] = reshaped_key[i] cloned_value_cache[block_idx, :, :, block_offset] = value[i] diff --git a/tests/kernels/pos_encoding.py b/tests/kernels/pos_encoding.py 
index 2dbce545e3455..502eedfbdcf59 100644 --- a/tests/kernels/pos_encoding.py +++ b/tests/kernels/pos_encoding.py @@ -85,15 +85,13 @@ def test_rotary_embedding_neox( cos_sin_cache = torch.cat((cos, sin), dim=-1) cos_sin_cache = cos_sin_cache.to(dtype=dtype, device='cuda') - # Run the kernel. - out_query = torch.empty_like(query) - out_key = torch.empty_like(key) + # Run the kernel. The kernel is in-place, so we need to clone the inputs. + out_query = query.clone() + out_key = key.clone() pos_encoding_ops.rotary_embedding_neox( + positions, out_query, out_key, - positions, - query, - key, cos_sin_cache, )
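
A minimal standalone sketch (plain PyTorch, not part of this patch; the shapes are illustrative) of why the kernels above now take an explicit row stride: q, k, and v sliced out of a fused qkv projection with chunk() are non-contiguous views into the qkv buffer, so the CUDA kernels index with token_idx * stride instead of assuming a densely packed [num_tokens, num_heads, head_size] layout.

    import torch

    num_tokens, num_heads, head_size = 4, 8, 64
    qkv = torch.randn(num_tokens, 3 * num_heads * head_size)

    # Slice q/k/v as views; no copy and no .contiguous() call.
    q, k, v = qkv.chunk(chunks=3, dim=-1)
    q = q.view(num_tokens, num_heads, head_size)  # still shares qkv's storage

    # The view is not contiguous: its row stride spans all of q, k, and v.
    assert not q.is_contiguous()
    assert q.stride(0) == 3 * num_heads * head_size  # the q_stride/key_stride passed to the kernels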