diff --git a/csrc/cache.h b/csrc/cache.h
index f1bd802b80eb6..46b89284c07e2 100644
--- a/csrc/cache.h
+++ b/csrc/cache.h
@@ -29,13 +29,10 @@ void reshape_and_cache_flash(torch::Tensor& key, torch::Tensor& value,
                              torch::Tensor& k_scale, torch::Tensor& v_scale);
 
 void reshape_and_cache_flash_full_cuda(
-    torch::Tensor& tokenshape,
-    torch::Tensor& key, torch::Tensor& value,
-    torch::Tensor& key_cache,
-    torch::Tensor& value_cache,
-    torch::Tensor& slot_mapping,
-    const std::string& kv_cache_dtype,
-    const double k_scale, const double v_scale);
+    torch::Tensor& tokenshape, torch::Tensor& key, torch::Tensor& value,
+    torch::Tensor& key_cache, torch::Tensor& value_cache,
+    torch::Tensor& slot_mapping, const std::string& kv_cache_dtype,
+    torch::Tensor& k_scale, torch::Tensor& v_scale);
 
 void concat_and_cache_mla(torch::Tensor& kv_c, torch::Tensor& k_pe,
                           torch::Tensor& kv_cache, torch::Tensor& slot_mapping,
diff --git a/csrc/cache_kernels.cu b/csrc/cache_kernels.cu
index f660c0c39275b..a1ffd7bf6068c 100644
--- a/csrc/cache_kernels.cu
+++ b/csrc/cache_kernels.cu
@@ -262,8 +262,8 @@ __global__ void reshape_and_cache_flash_full_cuda_kernel(
   const int64_t token_idx = blockIdx.x;
   int32_t unpadded_num_tokens = tensorshape[0];
 
-  if(token_idx >= unpadded_num_tokens) {
-      return;
+  if (token_idx >= unpadded_num_tokens) {
+    return;
   }
 
   const int64_t slot_idx = slot_mapping[token_idx];
@@ -437,27 +437,29 @@ void reshape_and_cache_flash(
 
 // KV_T is the stored data type of kv-cache.
 // CACHE_T is the data type of key and value tensors.
 // KV_DTYPE is the real data type of kv-cache.
-#define CALL_RESHAPE_AND_CACHE_FLASH_FULL_CUDA(KV_T, CACHE_T, KV_DTYPE)\
-  vllm::reshape_and_cache_flash_full_cuda_kernel<KV_T, CACHE_T, KV_DTYPE> \
-      <<<grid, block, 0, stream>>>( \
-          reinterpret_cast<const int32_t*>(tokenshape.data_ptr()), \
-          reinterpret_cast<KV_T*>(key.data_ptr()), \
-          reinterpret_cast<KV_T*>(value.data_ptr()), \
-          reinterpret_cast<CACHE_T*>(key_cache.data_ptr()), \
-          reinterpret_cast<CACHE_T*>(value_cache.data_ptr()), \
-          slot_mapping.data_ptr<int64_t>(), block_stride, key_stride, \
-          value_stride, num_heads, head_size, block_size, k_scale, v_scale);
+#define CALL_RESHAPE_AND_CACHE_FLASH_FULL_CUDA(KV_T, CACHE_T, KV_DTYPE)    \
+  vllm::reshape_and_cache_flash_full_cuda_kernel<KV_T, CACHE_T, KV_DTYPE>  \
+      <<<grid, block, 0, stream>>>(                                        \
+          reinterpret_cast<const int32_t*>(tokenshape.data_ptr()),         \
+          reinterpret_cast<KV_T*>(key.data_ptr()),                         \
+          reinterpret_cast<KV_T*>(value.data_ptr()),                       \
+          reinterpret_cast<CACHE_T*>(key_cache.data_ptr()),                \
+          reinterpret_cast<CACHE_T*>(value_cache.data_ptr()),              \
+          slot_mapping.data_ptr<int64_t>(), block_stride, key_stride,      \
+          value_stride, num_heads, head_size, block_size,                  \
+          reinterpret_cast<const float*>(k_scale.data_ptr()),              \
+          reinterpret_cast<const float*>(v_scale.data_ptr()));
 
 void reshape_and_cache_flash_full_cuda(
-    torch::Tensor& tokenshape,   // true num_tokens at first entry.
-    torch::Tensor& key,          // [num_tokens, num_heads, head_size]
-    torch::Tensor& value,        // [num_tokens, num_heads, head_size]
+    torch::Tensor& tokenshape,  // true num_tokens at first entry.
+    torch::Tensor& key,         // [num_tokens, num_heads, head_size]
+    torch::Tensor& value,       // [num_tokens, num_heads, head_size]
     torch::Tensor& key_cache,   // [num_blocks, block_size, num_heads, head_size]
     torch::Tensor& value_cache, // [num_blocks, block_size, num_heads, head_size]
     torch::Tensor& slot_mapping, // [num_tokens] or [num_actual_tokens]
-    const std::string& kv_cache_dtype, const double k_scale,
-    const double v_scale) {
+    const std::string& kv_cache_dtype, torch::Tensor& k_scale,
+    torch::Tensor& v_scale) {
   int padded_num_tokens = slot_mapping.size(0);
   int num_heads = key.size(1);
   int head_size = key.size(2);
diff --git a/csrc/torch_bindings.cpp b/csrc/torch_bindings.cpp
index 46a3225262c73..3643ca18f73b7 100644
--- a/csrc/torch_bindings.cpp
+++ b/csrc/torch_bindings.cpp
@@ -478,7 +478,7 @@ TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _cache_ops), cache_ops) {
       "  Tensor! value_cache,"
       "  Tensor slot_mapping,"
       "  str kv_cache_dtype,"
-      "  float k_scale, float v_scale) -> ()");
+      "  Tensor k_scale, Tensor v_scale) -> ()");
   cache_ops.impl("reshape_and_cache_flash_full_cuda", torch::kCUDA,
                  &reshape_and_cache_flash_full_cuda);
 
diff --git a/vllm/v1/attention/backends/flash_attn.py b/vllm/v1/attention/backends/flash_attn.py
index 7c9e9e14de0be..efe1291b340de 100755
--- a/vllm/v1/attention/backends/flash_attn.py
+++ b/vllm/v1/attention/backends/flash_attn.py
@@ -215,16 +215,16 @@ class FlashAttentionImpl(AttentionImpl):
         # Regular attention (common case).
         num_actual_tokens = attn_metadata.num_actual_tokens
         batch_size = attn_metadata.block_table.shape[0]
-        
+
         #TODO: Do we need to slice by [:batch_size+1]?
         flash_attn_varlen_func(
             q=query[:num_padded_tokens],
             k=key_cache,
             v=value_cache,
             out=output[:num_padded_tokens],
-            cu_seqlens_q=attn_metadata.query_start_loc[:batch_size+1],
+            cu_seqlens_q=attn_metadata.query_start_loc[:batch_size + 1],
             max_seqlen_q=attn_metadata.max_query_len,
-            seqused_k=attn_metadata.seq_lens[:batch_size+1],
+            seqused_k=attn_metadata.seq_lens[:batch_size],
             max_seqlen_k=attn_metadata.max_seq_len,
             softmax_scale=self.scale,
             causal=True,
diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py
index f062c8a86c7f2..97e06e783ab29 100644
--- a/vllm/v1/worker/gpu_model_runner.py
+++ b/vllm/v1/worker/gpu_model_runner.py
@@ -122,9 +122,9 @@ class GPUModelRunner:
             vocab_size=model_config.get_vocab_size(),
         )
 
-# self.use_cuda_graph = (self.vllm_config.compilation_config.level
-# == CompilationLevel.PIECEWISE
-# and not self.model_config.enforce_eager)
+        # self.use_cuda_graph = (self.vllm_config.compilation_config.level
+        #                        == CompilationLevel.PIECEWISE
+        #                        and not self.model_config.enforce_eager)
         self.use_cuda_graph = not self.model_config.enforce_eager
         # TODO(woosuk): Provide an option to tune the max cudagraph batch size.
         # The convention is different.
@@ -221,9 +221,10 @@ class GPUModelRunner:
                                         pin_memory=self.pin_memory)
         self.seq_lens_np = self.seq_lens_cpu.numpy()
 
-        self.tokenshape_cpu = torch.zeros(4, dtype=torch.int32,
-                                          device="cpu",
-                                          pin_memory=self.pin_memory)
+        self.tokenshape_cpu = torch.zeros(4,
+                                          dtype=torch.int32,
+                                          device="cpu",
+                                          pin_memory=self.pin_memory)
 
     def _update_states(self, scheduler_output: "SchedulerOutput") -> None:
         # Remove stopped requests from the cached states.
@@ -459,16 +460,18 @@ class GPUModelRunner:
         self.query_start_loc[:num_reqs + 1].copy_(
             self.query_start_loc_cpu[:num_reqs + 1], non_blocking=True)
-        self.seq_lens[:num_reqs].copy(self.seq_lens_cpu[:num_reqs],
-                                      non_blocking=True)
+        self.seq_lens[:num_reqs].copy_(self.seq_lens_cpu[:num_reqs],
+                                       non_blocking=True)
         self.slot_mapping[:total_num_scheduled_tokens].copy_(
             self.slot_mapping_cpu[:total_num_scheduled_tokens],
             non_blocking=True)
-        self.tokenshape_cpu[0] = total_num_scheduled_tokens # Actual number of tokens to process
-        self.tokenshape_cpu[1] = num_reqs # Number of requests
-        self.tokenshape_cpu[2] = max_num_scheduled_tokens # Maximum query length
-        self.tokenshape_cpu[3] = max_seq_len # Maximum sequence length
+        self.tokenshape_cpu[
+            0] = total_num_scheduled_tokens  # Tokens to process
+        self.tokenshape_cpu[1] = num_reqs  # Number of requests
+        self.tokenshape_cpu[
+            2] = max_num_scheduled_tokens  # Maximum query length
+        self.tokenshape_cpu[3] = max_seq_len  # Maximum sequence length
         self.tokenshape.copy_(self.tokenshape_cpu, non_blocking=True)
 
         # Prepare for cascade attention if needed.
@@ -908,7 +911,7 @@ class GPUModelRunner:
             inputs_embeds=inputs_embeds,
         )
         return hidden_states
-    
+
     def metadata_for_dummy_run(self, num_tokens) -> FlashAttentionMetadata:
         # Create placeholder metadata
         num_reqs = num_tokens
@@ -919,16 +922,17 @@ class GPUModelRunner:
             max_query_len=max_query_len,
             query_start_loc=self.query_start_loc,
             max_seq_len=max_seq_len,
-            seq_start_loc=self.seq_start_loc,
-            block_table=self.input_batch.block_table[:num_reqs],
+            seq_lens=self.seq_lens,
+            block_table=(
+                self.input_batch.block_table.get_device_tensor()[:num_reqs]),
             slot_mapping=self.slot_mapping,
             tokenshape=self.tokenshape,
             # Cascade stuff. Non-piecewise CUDA graphs NYI
-            use_cascade=None,
+            use_cascade=False,
             common_prefix_len=0,
             cu_prefix_query_lens=None,
-            cu_prefix_kv_lens=None,
-            cu_suffix_kv_lens=None,
+            prefix_kv_lens=None,
+            suffix_kv_lens=None,
         )
 
     def profile_run(self) -> None:
@@ -1058,8 +1062,8 @@ class GPUModelRunner:
             attn_metadata = self.metadata_for_dummy_run(num_tokens)
             for _ in range(self.vllm_config.compilation_config.
                            cudagraph_num_of_warmups):
-                self._dummy_run(self.model, num_tokens, attn_metadata=attn_metadata)
-            self._dummy_run(self.model, num_tokens, attn_metadata=attn_metadata)
+                self._dummy_run(num_tokens, attn_metadata=attn_metadata)
+            self._dummy_run(num_tokens, attn_metadata=attn_metadata)
         end_time = time.perf_counter()
         end_free_gpu_memory = torch.cuda.mem_get_info()[0]
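Caller-side note (not part of the patch): with the "Tensor k_scale, Tensor v_scale" schema registered above, Python callers must pass the scales as float32 tensors rather than Python floats, together with the int32 tokenshape tensor whose first entry carries the true (unpadded) token count. Below is a minimal sketch of an invocation, assuming the extension is built with this patch and that the op registers under torch.ops._C_cache_ops like the other cache ops; key, value, key_cache, value_cache, slot_mapping, and num_actual_tokens are placeholder names for the usual vLLM tensors, not identifiers from the patch.

    import torch

    # Placeholders (assumed): key/value are [num_tokens, num_heads, head_size],
    # key_cache/value_cache are [num_blocks, block_size, num_heads, head_size],
    # slot_mapping is an int64 tensor of length num_padded_tokens.
    k_scale = torch.tensor(1.0, dtype=torch.float32, device="cuda")
    v_scale = torch.tensor(1.0, dtype=torch.float32, device="cuda")

    # tokenshape layout mirrors gpu_model_runner.py:
    #   [0] true token count, [1] num_reqs,
    #   [2] max query length, [3] max sequence length.
    tokenshape = torch.zeros(4, dtype=torch.int32, device="cuda")
    tokenshape[0] = num_actual_tokens

    torch.ops._C_cache_ops.reshape_and_cache_flash_full_cuda(
        tokenshape, key, value, key_cache, value_cache,
        slot_mapping, "auto", k_scale, v_scale)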