From c33aeecf24b19e51df115f8b819435951bf96110 Mon Sep 17 00:00:00 2001 From: Tyler Michael Smith Date: Mon, 3 Feb 2025 22:42:09 +0000 Subject: [PATCH] simplify - get rid of tokenshape Signed-off-by: Tyler Michael Smith --- csrc/cache.h | 6 ---- csrc/cache_kernels.cu | 44 ------------------------ csrc/torch_bindings.cpp | 12 ------- vllm/v1/attention/backends/flash_attn.py | 20 ++++------- vllm/v1/worker/gpu_model_runner.py | 18 ++-------- 5 files changed, 10 insertions(+), 90 deletions(-) diff --git a/csrc/cache.h b/csrc/cache.h index 46b89284c07e2..55ed30bd8ce48 100644 --- a/csrc/cache.h +++ b/csrc/cache.h @@ -28,12 +28,6 @@ void reshape_and_cache_flash(torch::Tensor& key, torch::Tensor& value, const std::string& kv_cache_dtype, torch::Tensor& k_scale, torch::Tensor& v_scale); -void reshape_and_cache_flash_full_cuda( - torch::Tensor& tokenshape, torch::Tensor& key, torch::Tensor& value, - torch::Tensor& key_cache, torch::Tensor& value_cache, - torch::Tensor& slot_mapping, const std::string& kv_cache_dtype, - torch::Tensor& k_scale, torch::Tensor& v_scale); - void concat_and_cache_mla(torch::Tensor& kv_c, torch::Tensor& k_pe, torch::Tensor& kv_cache, torch::Tensor& slot_mapping, const std::string& kv_cache_dtype, diff --git a/csrc/cache_kernels.cu b/csrc/cache_kernels.cu index a1ffd7bf6068c..2b3d8f94dbc13 100644 --- a/csrc/cache_kernels.cu +++ b/csrc/cache_kernels.cu @@ -434,50 +434,6 @@ void reshape_and_cache_flash( CALL_RESHAPE_AND_CACHE_FLASH); } -// KV_T is the stored data type of kv-cache. -// CACHE_T is the data type of key and value tensors. -// KV_DTYPE is the real data type of kv-cache. -#define CALL_RESHAPE_AND_CACHE_FLASH_FULL_CUDA(KV_T, CACHE_T, KV_DTYPE) \ - vllm::reshape_and_cache_flash_kernel \ - <<>>( \ - reinterpret_cast(key.data_ptr()), \ - reinterpret_cast(value.data_ptr()), \ - reinterpret_cast(key_cache.data_ptr()), \ - reinterpret_cast(value_cache.data_ptr()), \ - slot_mapping.data_ptr(), block_stride, key_stride, \ - value_stride, num_heads, head_size, block_size, \ - reinterpret_cast(k_scale.data_ptr()), \ - reinterpret_cast(v_scale.data_ptr())); - -void reshape_and_cache_flash_full_cuda( - torch::Tensor& tokenshape, // true num_tokens at first entry. - torch::Tensor& key, // [num_tokens, num_heads, head_size] - torch::Tensor& value, // [num_tokens, num_heads, head_size] - torch::Tensor& key_cache, // [num_blocks, block_size, num_heads, head_size] - torch::Tensor& - value_cache, // [num_blocks, block_size, num_heads, head_size] - torch::Tensor& slot_mapping, // [num_tokens] or [num_actual_tokens] - const std::string& kv_cache_dtype, torch::Tensor& k_scale, - torch::Tensor& v_scale) { - int padded_num_tokens = slot_mapping.size(0); - int num_heads = key.size(1); - int head_size = key.size(2); - int block_size = key_cache.size(1); - - int key_stride = key.stride(0); - int value_stride = value.stride(0); - int block_stride = key_cache.stride(0); - TORCH_CHECK(key_cache.stride(0) == value_cache.stride(0)); - - dim3 grid(padded_num_tokens); - dim3 block(std::min(num_heads * head_size, 512)); - const at::cuda::OptionalCUDAGuard device_guard(device_of(key)); - const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); - - DISPATCH_BY_KV_CACHE_DTYPE(key.dtype(), kv_cache_dtype, - CALL_RESHAPE_AND_CACHE_FLASH_FULL_CUDA); -} - #define CALL_CONCAT_AND_CACHE_MLA(KV_T, CACHE_T, KV_DTYPE) \ vllm::concat_and_cache_mla_kernel \ <<>>( \ diff --git a/csrc/torch_bindings.cpp b/csrc/torch_bindings.cpp index 3643ca18f73b7..186e9c0e81b77 100644 --- a/csrc/torch_bindings.cpp +++ b/csrc/torch_bindings.cpp @@ -470,18 +470,6 @@ TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _cache_ops), cache_ops) { cache_ops.impl("reshape_and_cache_flash", torch::kCUDA, &reshape_and_cache_flash); - // Reshape the key and value tensors and cache them. - cache_ops.def( - "reshape_and_cache_flash_full_cuda(Tensor tensorshape," - " Tensor key, Tensor value," - " Tensor! key_cache," - " Tensor! value_cache," - " Tensor slot_mapping," - " str kv_cache_dtype," - " Tensor k_scale, Tensor v_scale) -> ()"); - cache_ops.impl("reshape_and_cache_flash_full_cuda", torch::kCUDA, - &reshape_and_cache_flash_full_cuda); - // Concat kv_c and k_pe and cache them. cache_ops.def( "concat_and_cache_mla(Tensor kv_c, Tensor k_pe," diff --git a/vllm/v1/attention/backends/flash_attn.py b/vllm/v1/attention/backends/flash_attn.py index efe1291b340de..9b69b28858fe3 100755 --- a/vllm/v1/attention/backends/flash_attn.py +++ b/vllm/v1/attention/backends/flash_attn.py @@ -75,9 +75,6 @@ class FlashAttentionMetadata: block_table: torch.Tensor slot_mapping: torch.Tensor - # [num_actual_tokens, batch_size, max_query_len, max_seq_len] - tokenshape: torch.Tensor - # For cascade attention. use_cascade: bool common_prefix_len: int @@ -194,17 +191,15 @@ class FlashAttentionImpl(AttentionImpl): # Whenever making a change in this method, please benchmark the # performance to make sure it does not introduce any overhead. - tokenshape = attn_metadata.tokenshape - num_padded_tokens = key.shape[0] + num_actual_tokens = attn_metadata.num_actual_tokens # Reshape the input keys and values and store them in the cache. key_cache, value_cache = kv_cache.unbind(0) - torch.ops._C_cache_ops.reshape_and_cache_flash_full_cuda( - tokenshape, + torch.ops._C_cache_ops.reshape_and_cache_flash( key, value, key_cache, value_cache, - attn_metadata.slot_mapping[:num_padded_tokens], + attn_metadata.slot_mapping[:num_actual_tokens], self.kv_cache_dtype, layer._k_scale, layer._v_scale, @@ -213,16 +208,15 @@ class FlashAttentionImpl(AttentionImpl): # Compute attention and update output up to `num_actual_tokens`. if not attn_metadata.use_cascade: # Regular attention (common case). - num_actual_tokens = attn_metadata.num_actual_tokens batch_size = attn_metadata.block_table.shape[0] - + #TODO: Do we need to slice by [:batch_size+1]? flash_attn_varlen_func( - q=query[:num_padded_tokens], + q=query[:num_actual_tokens], k=key_cache, v=value_cache, - out=output[:num_padded_tokens], - cu_seqlens_q=attn_metadata.query_start_loc[:batch_size + 1], + out=output[:num_actual_tokens], + cu_seqlens_q=attn_metadata.query_start_loc[:batch_size+1], max_seqlen_q=attn_metadata.max_query_len, seqused_k=attn_metadata.seq_lens[:batch_size], max_seqlen_k=attn_metadata.max_seq_len, diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 97e06e783ab29..df153fb51c05f 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -185,7 +185,6 @@ class GPUModelRunner: # this one must be int64 dtype=torch.int64, device=self.device) - self.tokenshape = torch.zeros(4, dtype=torch.int32, device=self.device) # OPTIMIZATION: Cache the tensors rather than creating them every step. self.arange_np = np.arange(max(self.max_num_reqs + 1, @@ -221,11 +220,6 @@ class GPUModelRunner: pin_memory=self.pin_memory) self.seq_lens_np = self.seq_lens_cpu.numpy() - self.tokenshape_cpu = torch.zeros(4, - dtype=torch.int32, - device="cpu", - pin_memory=self.pin_memory) - def _update_states(self, scheduler_output: "SchedulerOutput") -> None: # Remove stopped requests from the cached states. # Keep the states of the preempted requests. @@ -466,13 +460,9 @@ class GPUModelRunner: self.slot_mapping_cpu[:total_num_scheduled_tokens], non_blocking=True) - self.tokenshape_cpu[ - 0] = total_num_scheduled_tokens # Tokens to process - self.tokenshape_cpu[1] = num_reqs # Number of requests - self.tokenshape_cpu[ - 2] = max_num_scheduled_tokens # Maximum query length - self.tokenshape_cpu[3] = max_seq_len # Maximum sequence length - self.tokenshape.copy_(self.tokenshape_cpu, non_blocking=True) + self.query_start_loc[num_reqs + 1:].fill_(-1) + self.positions[total_num_scheduled_tokens:].fill_(0) + self.slot_mapping[total_num_scheduled_tokens:].fill_(-1) # Prepare for cascade attention if needed. common_prefix_len = (scheduler_output.num_common_prefix_blocks * @@ -561,7 +551,6 @@ class GPUModelRunner: block_table=( self.input_batch.block_table.get_device_tensor()[:num_reqs]), slot_mapping=self.slot_mapping, - tokenshape=self.tokenshape, # Cascade stuff use_cascade=use_cascade, common_prefix_len=common_prefix_len, @@ -926,7 +915,6 @@ class GPUModelRunner: block_table=( self.input_batch.block_table.get_device_tensor()[:num_reqs]), slot_mapping=self.slot_mapping, - tokenshape=self.tokenshape, # Cascade stuff. Non-piecewise CUDA graphs NYI use_cascade=False, common_prefix_len=0,