diff --git a/csrc/cache.h b/csrc/cache.h
index 2ee9c0363a071..55ed30bd8ce48 100644
--- a/csrc/cache.h
+++ b/csrc/cache.h
@@ -28,7 +28,7 @@ void reshape_and_cache_flash(torch::Tensor& key, torch::Tensor& value,
                              const std::string& kv_cache_dtype,
                              torch::Tensor& k_scale, torch::Tensor& v_scale);
 
-void concat_and_cache_mla(torch::Tensor& ckv, torch::Tensor& k_pe,
+void concat_and_cache_mla(torch::Tensor& kv_c, torch::Tensor& k_pe,
                           torch::Tensor& kv_cache, torch::Tensor& slot_mapping,
                           const std::string& kv_cache_dtype,
                           torch::Tensor& scale);
diff --git a/csrc/cache_kernels.cu b/csrc/cache_kernels.cu
index 0d1a52fc551b9..23a46b6ed8ad8 100644
--- a/csrc/cache_kernels.cu
+++ b/csrc/cache_kernels.cu
@@ -248,13 +248,13 @@ __global__ void reshape_and_cache_flash_kernel(
 
 template <typename scalar_t, typename cache_t, Fp8KVCacheDataType kv_dt>
 __global__ void concat_and_cache_mla_kernel(
-    const scalar_t* __restrict__ ckv,   // [num_tokens, kv_lora_rank]
+    const scalar_t* __restrict__ kv_c,  // [num_tokens, kv_lora_rank]
     const scalar_t* __restrict__ k_pe,  // [num_tokens, pe_dim]
     cache_t* __restrict__ kv_cache,     // [num_blocks, block_size, (kv_lora_rank
                                         //  + pe_dim)]
     const int64_t* __restrict__ slot_mapping,  // [num_tokens]
     const int block_stride,                    //
-    const int ckv_stride,                      //
+    const int kv_c_stride,                     //
     const int k_pe_stride,                     //
     const int kv_lora_rank,                    //
     const int pe_dim,                          //
@@ -286,7 +286,7 @@ __global__ void concat_and_cache_mla_kernel(
     }
   };
 
-  copy(ckv, kv_cache, ckv_stride, block_stride, kv_lora_rank, 0);
+  copy(kv_c, kv_cache, kv_c_stride, block_stride, kv_lora_rank, 0);
   copy(k_pe, kv_cache, k_pe_stride, block_stride, pe_dim, kv_lora_rank);
 }
 
@@ -391,18 +391,18 @@ void reshape_and_cache_flash(
 // KV_T is the stored data type of kv-cache.
 // CACHE_T is the data type of key and value tensors.
 // KV_DTYPE is the real data type of kv-cache.
-#define CALL_CONCAT_AND_CACHE_MLA(KV_T, CACHE_T, KV_DTYPE)              \
-  vllm::concat_and_cache_mla_kernel<KV_T, CACHE_T, KV_DTYPE>            \
-      <<<grid, block, 0, stream>>>(                                     \
-          reinterpret_cast<KV_T*>(ckv.data_ptr()),                      \
-          reinterpret_cast<KV_T*>(k_pe.data_ptr()),                     \
-          reinterpret_cast<CACHE_T*>(kv_cache.data_ptr()),              \
-          slot_mapping.data_ptr<int64_t>(), block_stride, ckv_stride,   \
-          k_pe_stride, kv_lora_rank, pe_dim, block_size,                \
+#define CALL_CONCAT_AND_CACHE_MLA(KV_T, CACHE_T, KV_DTYPE)              \
+  vllm::concat_and_cache_mla_kernel<KV_T, CACHE_T, KV_DTYPE>            \
+      <<<grid, block, 0, stream>>>(                                     \
+          reinterpret_cast<KV_T*>(kv_c.data_ptr()),                     \
+          reinterpret_cast<KV_T*>(k_pe.data_ptr()),                     \
+          reinterpret_cast<CACHE_T*>(kv_cache.data_ptr()),              \
+          slot_mapping.data_ptr<int64_t>(), block_stride, kv_c_stride,  \
+          k_pe_stride, kv_lora_rank, pe_dim, block_size,                \
           reinterpret_cast<const float*>(scale.data_ptr()));
 
 void concat_and_cache_mla(
-    torch::Tensor& ckv,       // [num_tokens, kv_lora_rank]
+    torch::Tensor& kv_c,      // [num_tokens, kv_lora_rank]
     torch::Tensor& k_pe,      // [num_tokens, pe_dim]
     torch::Tensor& kv_cache,  // [num_blocks, block_size, (kv_lora_rank +
                               //  pe_dim)]
@@ -419,22 +419,22 @@ void concat_and_cache_mla(
   // For compatibility with both cases, we use slot_mapping.size(0) as the
   // number of tokens.
   int num_tokens = slot_mapping.size(0);
-  int kv_lora_rank = ckv.size(1);
+  int kv_lora_rank = kv_c.size(1);
   int pe_dim = k_pe.size(1);
   int block_size = kv_cache.size(1);
 
   TORCH_CHECK(kv_cache.size(2) == kv_lora_rank + pe_dim);
 
-  int ckv_stride = ckv.stride(0);
+  int kv_c_stride = kv_c.stride(0);
   int k_pe_stride = k_pe.stride(0);
   int block_stride = kv_cache.stride(0);
 
   dim3 grid(num_tokens);
   dim3 block(std::min(kv_lora_rank, 512));
-  const at::cuda::OptionalCUDAGuard device_guard(device_of(ckv));
+  const at::cuda::OptionalCUDAGuard device_guard(device_of(kv_c));
   const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
 
-  DISPATCH_BY_KV_CACHE_DTYPE(ckv.dtype(), kv_cache_dtype,
+  DISPATCH_BY_KV_CACHE_DTYPE(kv_c.dtype(), kv_cache_dtype,
                              CALL_CONCAT_AND_CACHE_MLA);
 }
 
diff --git a/csrc/torch_bindings.cpp b/csrc/torch_bindings.cpp
index 8a67cd26482da..1846d9ac29943 100644
--- a/csrc/torch_bindings.cpp
+++ b/csrc/torch_bindings.cpp
@@ -463,9 +463,9 @@ TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _cache_ops), cache_ops) {
   cache_ops.impl("reshape_and_cache_flash", torch::kCUDA,
                  &reshape_and_cache_flash);
 
-  // Concat ckv and k_pe and cache them.
+  // Concat kv_c and k_pe and cache them.
   cache_ops.def(
-      "concat_and_cache_mla(Tensor ckv, Tensor k_pe,"
+      "concat_and_cache_mla(Tensor kv_c, Tensor k_pe,"
       "                     Tensor! kv_cache,"
       "                     Tensor slot_mapping,"
       "                     str kv_cache_dtype,"
diff --git a/vllm/_custom_ops.py b/vllm/_custom_ops.py
index 80abb68fbe949..2dae174685424 100644
--- a/vllm/_custom_ops.py
+++ b/vllm/_custom_ops.py
@@ -981,14 +981,14 @@ def reshape_and_cache_flash(
 
 
 def concat_and_cache_mla(
-    ckv: torch.Tensor,
+    kv_c: torch.Tensor,
     k_pe: torch.Tensor,
    kv_cache: torch.Tensor,
     slot_mapping: torch.Tensor,
     kv_cache_dtype: str,
     scale: torch.Tensor,
 ) -> None:
-    torch.ops._C_cache_ops.concat_and_cache_mla(ckv, k_pe, kv_cache,
+    torch.ops._C_cache_ops.concat_and_cache_mla(kv_c, k_pe, kv_cache,
                                                 slot_mapping, kv_cache_dtype,
                                                 scale)
 
diff --git a/vllm/attention/backends/abstract.py b/vllm/attention/backends/abstract.py
index 0ccf58970f43c..7ad242b7001fa 100644
--- a/vllm/attention/backends/abstract.py
+++ b/vllm/attention/backends/abstract.py
@@ -269,7 +269,7 @@ class AttentionImpl(ABC, Generic[T]):
         self,
         layer: AttentionLayer,
         query: torch.Tensor,  # For MLA hidden_states_or_cq
-        key: torch.Tensor,  # For MLA ckv_normed
+        key: torch.Tensor,  # For MLA kv_c_normed
         value: torch.Tensor,  # For MLA k_pe
         kv_cache: torch.Tensor,
         attn_metadata: T,
diff --git a/vllm/attention/backends/mla/utils.py b/vllm/attention/backends/mla/utils.py
index 9cdc31d2bd7d9..a3b45fadffa6b 100644
--- a/vllm/attention/backends/mla/utils.py
+++ b/vllm/attention/backends/mla/utils.py
@@ -52,7 +52,7 @@ class MLAImplCommon(AttentionImpl):
     1. The hidden states (B, H) are projected down into cq (B, Lq) and
        kv_c_k_pe (B, Lkv+R).
-    2. The kv_c_k_pe is split into ckv (B, Lkv) and k_pe (B, R). cq
+    2. The kv_c_k_pe is split into kv_c (B, Lkv) and k_pe (B, R). cq
        and kv_c are normalized.
 
     #
@@ -249,7 +249,7 @@ class MLAImplCommon(AttentionImpl):
     def _forward_prefill(
         self,
         q: torch.Tensor,
-        ckv_normed: torch.Tensor,
+        kv_c_normed: torch.Tensor,
         k_pe: torch.Tensor,
         attn_metadata: MLAMetadataCommon,
     ) -> torch.Tensor: