diff --git a/csrc/cache.h b/csrc/cache.h index 11c4c5001daaa..1623c2ff8a351 100644 --- a/csrc/cache.h +++ b/csrc/cache.h @@ -28,6 +28,15 @@ void reshape_and_cache_flash(torch::Tensor& key, torch::Tensor& value, const std::string& kv_cache_dtype, const double k_scale, const double v_scale); +void reshape_and_cache_flash_full_cuda( + torch::Tensor& tokenshape, + torch::Tensor& key, torch::Tensor& value, + torch::Tensor& key_cache, + torch::Tensor& value_cache, + torch::Tensor& slot_mapping, + const std::string& kv_cache_dtype, + const double k_scale, const double v_scale); + // Just for unittest void convert_fp8(torch::Tensor& dst_cache, torch::Tensor& src_cache, const double scale, const std::string& kv_cache_dtype); diff --git a/csrc/cache_kernels.cu b/csrc/cache_kernels.cu index 8a95279f9a25a..45675a1b091d1 100644 --- a/csrc/cache_kernels.cu +++ b/csrc/cache_kernels.cu @@ -245,6 +245,52 @@ __global__ void reshape_and_cache_flash_kernel( } } } + +template <typename scalar_t, typename cache_t, Fp8KVCacheDataType kv_dt> +__global__ void reshape_and_cache_flash_full_cuda_kernel( + const int32_t* __restrict__ tensorshape, + const scalar_t* __restrict__ key, // [num_tokens, num_heads, head_size] + const scalar_t* __restrict__ value, // [num_tokens, num_heads, head_size] + cache_t* __restrict__ key_cache, // [num_blocks, block_size, num_heads, + // head_size] + cache_t* __restrict__ value_cache, // [num_blocks, block_size, num_heads, + // head_size] + const int64_t* __restrict__ slot_mapping, // [num_tokens] + const int block_stride, const int key_stride, const int value_stride, + const int num_heads, const int head_size, const int block_size, + const float k_scale, const float v_scale) { + const int64_t token_idx = blockIdx.x; + + int32_t unpadded_num_tokens = tensorshape[0]; + if(token_idx >= unpadded_num_tokens) { + return; + } + + const int64_t slot_idx = slot_mapping[token_idx]; + const int64_t block_idx = slot_idx / block_size; + const int64_t block_offset = slot_idx % block_size; + const int n = num_heads * head_size; + for 
(int i = threadIdx.x; i < n; i += blockDim.x) { + const int64_t src_key_idx = token_idx * key_stride + i; + const int64_t src_value_idx = token_idx * value_stride + i; + const int head_idx = i / head_size; + const int head_offset = i % head_size; + const int64_t tgt_key_value_idx = block_idx * block_stride + + block_offset * num_heads * head_size + + head_idx * head_size + head_offset; + scalar_t tgt_key = key[src_key_idx]; + scalar_t tgt_value = value[src_value_idx]; + if constexpr (kv_dt == Fp8KVCacheDataType::kAuto) { + key_cache[tgt_key_value_idx] = tgt_key; + value_cache[tgt_key_value_idx] = tgt_value; + } else { + key_cache[tgt_key_value_idx] = + fp8::scaled_convert<cache_t, scalar_t, kv_dt>(tgt_key, k_scale); + value_cache[tgt_key_value_idx] = + fp8::scaled_convert<cache_t, scalar_t, kv_dt>(tgt_value, v_scale); + } + } +} } // namespace vllm // KV_T is the stored data type of kv-cache. @@ -339,6 +385,49 @@ void reshape_and_cache_flash( CALL_RESHAPE_AND_CACHE_FLASH); } +// KV_T is the stored data type of kv-cache. +// CACHE_T is the data type of key and value tensors. +// KV_DTYPE is the real data type of kv-cache. +#define CALL_RESHAPE_AND_CACHE_FLASH_FULL_CUDA(KV_T, CACHE_T, KV_DTYPE)\ + vllm::reshape_and_cache_flash_full_cuda_kernel<KV_T, CACHE_T, KV_DTYPE> \ + <<<grid, block, 0, stream>>>( \ + reinterpret_cast<const int32_t*>(tokenshape.data_ptr()), \ + reinterpret_cast<KV_T*>(key.data_ptr()), \ + reinterpret_cast<KV_T*>(value.data_ptr()), \ + reinterpret_cast<CACHE_T*>(key_cache.data_ptr()), \ + reinterpret_cast<CACHE_T*>(value_cache.data_ptr()), \ + slot_mapping.data_ptr<int64_t>(), block_stride, key_stride, \ + value_stride, num_heads, head_size, block_size, k_scale, v_scale); + +void reshape_and_cache_flash_full_cuda( + torch::Tensor& tokenshape, // true num_tokens at first entry. 
+ torch::Tensor& key, // [num_tokens, num_heads, head_size] + torch::Tensor& value, // [num_tokens, num_heads, head_size] + torch::Tensor& key_cache, // [num_blocks, block_size, num_heads, head_size] + torch::Tensor& + value_cache, // [num_blocks, block_size, num_heads, head_size] + torch::Tensor& slot_mapping, // [num_tokens] or [num_actual_tokens] + const std::string& kv_cache_dtype, const double k_scale, + const double v_scale) { + int padded_num_tokens = slot_mapping.size(0); + int num_heads = key.size(1); + int head_size = key.size(2); + int block_size = key_cache.size(1); + + int key_stride = key.stride(0); + int value_stride = value.stride(0); + int block_stride = key_cache.stride(0); + TORCH_CHECK(key_cache.stride(0) == value_cache.stride(0)); + + dim3 grid(padded_num_tokens); + dim3 block(std::min(num_heads * head_size, 512)); + const at::cuda::OptionalCUDAGuard device_guard(device_of(key)); + const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + + DISPATCH_BY_KV_CACHE_DTYPE(key.dtype(), kv_cache_dtype, + CALL_RESHAPE_AND_CACHE_FLASH_FULL_CUDA); +} + namespace vllm { template <typename Tout, typename Tin, Fp8KVCacheDataType kv_dt> diff --git a/csrc/torch_bindings.cpp b/csrc/torch_bindings.cpp index 956258c1001d3..1865b4d8f297d 100644 --- a/csrc/torch_bindings.cpp +++ b/csrc/torch_bindings.cpp @@ -460,6 +460,18 @@ TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _cache_ops), cache_ops) { cache_ops.impl("reshape_and_cache_flash", torch::kCUDA, &reshape_and_cache_flash); + // Reshape the key and value tensors and cache them. + cache_ops.def( + "reshape_and_cache_flash_full_cuda(Tensor tensorshape," + " Tensor key, Tensor value," + " Tensor! key_cache," + " Tensor! value_cache," + " Tensor slot_mapping," + " str kv_cache_dtype," + " float k_scale, float v_scale) -> ()"); + cache_ops.impl("reshape_and_cache_flash_full_cuda", torch::kCUDA, + &reshape_and_cache_flash_full_cuda); + // Convert the key and value cache to fp8 data type. cache_ops.def( "convert_fp8(Tensor! 
dst_cache, Tensor src_cache, float scale, " diff --git a/vllm/compilation/backends.py b/vllm/compilation/backends.py index a8dd628b9cd6f..f2e8e17f06b4b 100644 --- a/vllm/compilation/backends.py +++ b/vllm/compilation/backends.py @@ -514,6 +514,7 @@ class VllmBackend: if not self.compilation_config.use_cudagraph or \ not self.compilation_config.cudagraph_copy_inputs: +# return self.graph return self.split_gm # if we need to copy input buffers for cudagraph diff --git a/vllm/config.py b/vllm/config.py index 580541685f4d6..a60b45da8122d 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -2705,7 +2705,7 @@ class CompilationConfig(BaseModel): custom_ops: List[str] = Field(default_factory=list) splitting_ops: List[str] = Field(default=None) # type: ignore - use_inductor: bool = True + use_inductor: bool = False candidate_compile_sizes: Optional[List[int]] = Field(default=None) inductor_compile_config: Dict = Field(default_factory=dict) inductor_passes: Dict[str, str] = Field(default_factory=dict) @@ -3181,8 +3181,7 @@ class VllmConfig: self.compilation_config.cudagraph_num_of_warmups = 1 self.compilation_config.pass_config.enable_fusion = False self.compilation_config.pass_config.enable_reshape = False -# self.compilation_config.level = CompilationLevel.PIECEWISE - self.compilation_config.level = CompilationLevel.NO_COMPILATION + self.compilation_config.level = CompilationLevel.PIECEWISE self._set_cudagraph_sizes() @@ -3263,8 +3262,7 @@ class VllmConfig: batch_size_capture_list = [] if self.model_config is not None and \ not self.model_config.enforce_eager: - batch_size_capture_list = [1, 2, 4 - ] + [i for i in range(8, 513, 8)] + batch_size_capture_list = [1, 2, 4] + [i for i in range(8, 513, 8)] self.compilation_config.init_with_cudagraph_sizes( batch_size_capture_list) diff --git a/vllm/v1/attention/backends/flash_attn.py b/vllm/v1/attention/backends/flash_attn.py index a691c1dc7032f..3c908f1ae8cbc 100644 --- a/vllm/v1/attention/backends/flash_attn.py +++ 
b/vllm/v1/attention/backends/flash_attn.py @@ -65,6 +65,9 @@ class FlashAttentionMetadata: block_table: torch.Tensor slot_mapping: torch.Tensor + # [num_actual_tokens, batch_size, max_query_len, max_seq_len] + tokenshape: torch.Tensor + # For cascade attention. use_cascade: bool common_prefix_len: int @@ -155,7 +158,7 @@ class FlashAttentionImpl(AttentionImpl): assert output is not None, "Output tensor must be provided." if attn_metadata is None: - # Profiling run. + # Dynamic shape profiling run. return output # IMPORTANT! @@ -167,19 +170,17 @@ class FlashAttentionImpl(AttentionImpl): # Whenever making a change in this method, please benchmark the # performance to make sure it does not introduce any overhead. - num_actual_tokens = attn_metadata.num_actual_tokens + tokenshape = attn_metadata.tokenshape + num_padded_tokens = key.shape[0] # Reshape the input keys and values and store them in the cache. - # NOTE(woosuk): Here, key and value are padded while slot_mapping is - # not padded. However, we don't need to do key[:num_actual_tokens] and - # value[:num_actual_tokens] because the reshape_and_cache_flash op uses - # the slot_mapping's shape to determine the number of actual tokens. key_cache, value_cache = kv_cache.unbind(0) - torch.ops._C_cache_ops.reshape_and_cache_flash( + torch.ops._C_cache_ops.reshape_and_cache_flash_full_cuda( + tokenshape, key, value, key_cache, value_cache, - attn_metadata.slot_mapping[:num_actual_tokens], + attn_metadata.slot_mapping[:num_padded_tokens], self.kv_cache_dtype, k_scale, v_scale, @@ -188,13 +189,15 @@ class FlashAttentionImpl(AttentionImpl): # Compute attention and update output up to `num_actual_tokens`. if not attn_metadata.use_cascade: # Regular attention (common case). 
+ num_actual_tokens = attn_metadata.num_actual_tokens batch_size = attn_metadata.block_table.shape[0] + print(f"q, k v shapes: {query.shape}") flash_attn_varlen_func( - q=query[:num_actual_tokens], + q=query[:num_padded_tokens], k=key_cache, v=value_cache, - out=output[:num_actual_tokens], + out=output[:num_padded_tokens], cu_seqlens_q=attn_metadata.query_start_loc[:batch_size+1], max_seqlen_q=attn_metadata.max_query_len, cu_seqlens_k=attn_metadata.seq_start_loc[:batch_size+1], diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 41a9c86590c70..d50d322066717 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -1,6 +1,6 @@ import gc import time -from typing import TYPE_CHECKING, Dict, List, Tuple, cast +from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, cast import numpy as np import torch @@ -110,9 +110,10 @@ class GPUModelRunner: vocab_size=model_config.get_vocab_size(), ) - self.use_cuda_graph = (self.vllm_config.compilation_config.level - == CompilationLevel.PIECEWISE - and not self.model_config.enforce_eager) +# self.use_cuda_graph = (self.vllm_config.compilation_config.level +# == CompilationLevel.PIECEWISE +# and not self.model_config.enforce_eager) + self.use_cuda_graph = not self.model_config.enforce_eager # TODO(woosuk): Provide an option to tune the max cudagraph batch size. # The convention is different. # self.cudagraph_batch_sizes sorts in ascending order. @@ -149,6 +150,7 @@ class GPUModelRunner: # this one must be int64 dtype=torch.int64, device=self.device) + self.tokenshape = torch.zeros(4, dtype=torch.int32, device=self.device) # OPTIMIZATION: Cache the tensors rather than creating them every step. 
self.arange_np = np.arange(max(self.max_num_reqs + 1, @@ -183,6 +185,10 @@ class GPUModelRunner: pin_memory=self.pin_memory) self.seq_start_loc_np = self.seq_start_loc_cpu.numpy() + self.tokenshape_cpu = torch.zeros(4, dtype=torch.int32, + device="cpu", + pin_memory=self.pin_memory) + def _update_states(self, scheduler_output: "SchedulerOutput") -> None: # Remove stopped requests from the cached states. # Keep the states of the pre-empted requests. @@ -379,6 +385,12 @@ class GPUModelRunner: self.slot_mapping_cpu[:total_num_scheduled_tokens], non_blocking=True) + self.tokenshape_cpu[0] = total_num_scheduled_tokens # Actual number of tokens to process + self.tokenshape_cpu[1] = num_reqs # Number of requests + self.tokenshape_cpu[2] = max_num_scheduled_tokens # Maximum query length + self.tokenshape_cpu[3] = max_seq_len # Maximum sequence length + self.tokenshape.copy_(self.tokenshape_cpu, non_blocking=True) + # Prepare for cascade attention if needed. common_prefix_len = (scheduler_output.num_common_prefix_blocks * self.block_size) @@ -468,6 +480,7 @@ class GPUModelRunner: seq_start_loc=self.seq_start_loc, block_table=self.input_batch.block_table[:num_reqs], slot_mapping=self.slot_mapping, + tokenshape=self.tokenshape, # Cascade stuff use_cascade=use_cascade, common_prefix_len=common_prefix_len, @@ -710,6 +723,7 @@ class GPUModelRunner: model: nn.Module, num_tokens: int, kv_caches: List[torch.Tensor], + attn_metadata: Optional[FlashAttentionMetadata], ) -> torch.Tensor: if self.is_multimodal_model: input_ids = None @@ -717,7 +731,7 @@ class GPUModelRunner: else: input_ids = self.input_ids[:num_tokens] inputs_embeds = None - with set_forward_context(None, self.vllm_config): + with set_forward_context(attn_metadata, self.vllm_config): hidden_states = model( input_ids=input_ids, positions=self.positions[:num_tokens], @@ -726,6 +740,28 @@ class GPUModelRunner: inputs_embeds=inputs_embeds, ) return hidden_states + + def metadata_for_dummy_run(self, num_tokens) -> 
FlashAttentionMetadata: + # Create placeholder metadata + num_reqs = num_tokens + max_query_len = num_tokens + max_seq_len = num_tokens + return FlashAttentionMetadata( + num_actual_tokens=num_tokens, + max_query_len=max_query_len, + query_start_loc=self.query_start_loc, + max_seq_len=max_seq_len, + seq_start_loc=self.seq_start_loc, + block_table=self.input_batch.block_table[:num_reqs], + slot_mapping=self.slot_mapping, + tokenshape=self.tokenshape, + # Cascade stuff. Non-piecewise CUDA graphs NYI + use_cascade=None, + common_prefix_len=0, + cu_prefix_query_lens=None, + cu_prefix_kv_lens=None, + cu_suffix_kv_lens=None, + ) def profile_run(self) -> None: # use an empty tensor instead of `None`` to force Dynamo to pass @@ -831,7 +867,7 @@ class GPUModelRunner: # Trigger compilation for general shape. hidden_states = self._dummy_run(self.model, self.max_num_tokens, - dummy_kv_caches) + dummy_kv_caches, None) logits = self.model.compute_logits(hidden_states, None) logits = logits[:self.max_num_tokens] # TODO(woosuk): Consider the memory usage of the sampler. @@ -849,10 +885,11 @@ class GPUModelRunner: # can reuse the memory pool allocated for the large shapes. with graph_capture(): for num_tokens in reversed(self.cudagraph_batch_sizes): + attn_metadata = self.metadata_for_dummy_run(num_tokens) for _ in range(self.vllm_config.compilation_config. cudagraph_num_of_warmups): - self._dummy_run(self.model, num_tokens, self.kv_caches) - self._dummy_run(self.model, num_tokens, self.kv_caches) + self._dummy_run(self.model, num_tokens, self.kv_caches, attn_metadata) + self._dummy_run(self.model, num_tokens, self.kv_caches, attn_metadata) end_time = time.perf_counter() end_free_gpu_memory = torch.cuda.mem_get_info()[0]