renaming for consistency

Signed-off-by: Lucas Wilkinson <lwilkinson@neuralmagic.com>
This commit is contained in:
Lucas Wilkinson 2025-01-30 04:00:26 +00:00
parent 534cd0006d
commit 2326814c11
6 changed files with 24 additions and 24 deletions

View File

@ -28,7 +28,7 @@ void reshape_and_cache_flash(torch::Tensor& key, torch::Tensor& value,
const std::string& kv_cache_dtype,
torch::Tensor& k_scale, torch::Tensor& v_scale);
void concat_and_cache_mla(torch::Tensor& ckv, torch::Tensor& k_pe,
void concat_and_cache_mla(torch::Tensor& kv_c, torch::Tensor& k_pe,
torch::Tensor& kv_cache, torch::Tensor& slot_mapping,
const std::string& kv_cache_dtype,
torch::Tensor& scale);

View File

@ -248,13 +248,13 @@ __global__ void reshape_and_cache_flash_kernel(
template <typename scalar_t, typename cache_t, Fp8KVCacheDataType kv_dt>
__global__ void concat_and_cache_mla_kernel(
const scalar_t* __restrict__ ckv, // [num_tokens, kv_lora_rank]
const scalar_t* __restrict__ kv_c, // [num_tokens, kv_lora_rank]
const scalar_t* __restrict__ k_pe, // [num_tokens, pe_dim]
cache_t* __restrict__ kv_cache, // [num_blocks, block_size, (kv_lora_rank
// + pe_dim)]
const int64_t* __restrict__ slot_mapping, // [num_tokens]
const int block_stride, //
const int ckv_stride, //
const int kv_c_stride, //
const int k_pe_stride, //
const int kv_lora_rank, //
const int pe_dim, //
@ -286,7 +286,7 @@ __global__ void concat_and_cache_mla_kernel(
}
};
copy(ckv, kv_cache, ckv_stride, block_stride, kv_lora_rank, 0);
copy(kv_c, kv_cache, kv_c_stride, block_stride, kv_lora_rank, 0);
copy(k_pe, kv_cache, k_pe_stride, block_stride, pe_dim, kv_lora_rank);
}
@ -391,18 +391,18 @@ void reshape_and_cache_flash(
// KV_T is the stored data type of kv-cache.
// CACHE_T is the data type of key and value tensors.
// KV_DTYPE is the real data type of kv-cache.
#define CALL_CONCAT_AND_CACHE_MLA(KV_T, CACHE_T, KV_DTYPE) \
vllm::concat_and_cache_mla_kernel<KV_T, CACHE_T, KV_DTYPE> \
<<<grid, block, 0, stream>>>( \
reinterpret_cast<KV_T*>(ckv.data_ptr()), \
reinterpret_cast<KV_T*>(k_pe.data_ptr()), \
reinterpret_cast<CACHE_T*>(kv_cache.data_ptr()), \
slot_mapping.data_ptr<int64_t>(), block_stride, ckv_stride, \
k_pe_stride, kv_lora_rank, pe_dim, block_size, \
#define CALL_CONCAT_AND_CACHE_MLA(KV_T, CACHE_T, KV_DTYPE) \
vllm::concat_and_cache_mla_kernel<KV_T, CACHE_T, KV_DTYPE> \
<<<grid, block, 0, stream>>>( \
reinterpret_cast<KV_T*>(kv_c.data_ptr()), \
reinterpret_cast<KV_T*>(k_pe.data_ptr()), \
reinterpret_cast<CACHE_T*>(kv_cache.data_ptr()), \
slot_mapping.data_ptr<int64_t>(), block_stride, kv_c_stride, \
k_pe_stride, kv_lora_rank, pe_dim, block_size, \
reinterpret_cast<const float*>(scale.data_ptr()));
void concat_and_cache_mla(
torch::Tensor& ckv, // [num_tokens, kv_lora_rank]
torch::Tensor& kv_c, // [num_tokens, kv_lora_rank]
torch::Tensor& k_pe, // [num_tokens, pe_dim]
torch::Tensor& kv_cache, // [num_blocks, block_size, (kv_lora_rank +
// pe_dim)]
@ -419,22 +419,22 @@ void concat_and_cache_mla(
// For compatibility with both cases, we use slot_mapping.size(0) as the
// number of tokens.
int num_tokens = slot_mapping.size(0);
int kv_lora_rank = ckv.size(1);
int kv_lora_rank = kv_c.size(1);
int pe_dim = k_pe.size(1);
int block_size = kv_cache.size(1);
TORCH_CHECK(kv_cache.size(2) == kv_lora_rank + pe_dim);
int ckv_stride = ckv.stride(0);
int kv_c_stride = kv_c.stride(0);
int k_pe_stride = k_pe.stride(0);
int block_stride = kv_cache.stride(0);
dim3 grid(num_tokens);
dim3 block(std::min(kv_lora_rank, 512));
const at::cuda::OptionalCUDAGuard device_guard(device_of(ckv));
const at::cuda::OptionalCUDAGuard device_guard(device_of(kv_c));
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
DISPATCH_BY_KV_CACHE_DTYPE(ckv.dtype(), kv_cache_dtype,
DISPATCH_BY_KV_CACHE_DTYPE(kv_c.dtype(), kv_cache_dtype,
CALL_CONCAT_AND_CACHE_MLA);
}

View File

@ -463,9 +463,9 @@ TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _cache_ops), cache_ops) {
cache_ops.impl("reshape_and_cache_flash", torch::kCUDA,
&reshape_and_cache_flash);
// Concat ckv and k_pe and cache them.
// Concat kv_c and k_pe and cache them.
cache_ops.def(
"concat_and_cache_mla(Tensor ckv, Tensor k_pe,"
"concat_and_cache_mla(Tensor kv_c, Tensor k_pe,"
" Tensor! kv_cache,"
" Tensor slot_mapping,"
" str kv_cache_dtype,"

View File

@ -981,14 +981,14 @@ def reshape_and_cache_flash(
def concat_and_cache_mla(
ckv: torch.Tensor,
kv_c: torch.Tensor,
k_pe: torch.Tensor,
kv_cache: torch.Tensor,
slot_mapping: torch.Tensor,
kv_cache_dtype: str,
scale: torch.Tensor,
) -> None:
torch.ops._C_cache_ops.concat_and_cache_mla(ckv, k_pe, kv_cache,
torch.ops._C_cache_ops.concat_and_cache_mla(kv_c, k_pe, kv_cache,
slot_mapping, kv_cache_dtype,
scale)

View File

@ -269,7 +269,7 @@ class AttentionImpl(ABC, Generic[T]):
self,
layer: AttentionLayer,
query: torch.Tensor, # For MLA hidden_states_or_cq
key: torch.Tensor, # For MLA ckv_normed
key: torch.Tensor, # For MLA kv_c_normed
value: torch.Tensor, # For MLA k_pe
kv_cache: torch.Tensor,
attn_metadata: T,

View File

@ -52,7 +52,7 @@ class MLAImplCommon(AttentionImpl):
1. The hidden states (B, H) are projected down into cq (B, Lq) and
kv_c_k_pe (B, Lkv+R).
2. The kv_c_k_pe is split into ckv (B, Lkv) and k_pe (B, R). cq
2. The kv_c_k_pe is split into kv_c (B, Lkv) and k_pe (B, R). cq
and kv_c are normalized.
#
@ -249,7 +249,7 @@ class MLAImplCommon(AttentionImpl):
def _forward_prefill(
self,
q: torch.Tensor,
ckv_normed: torch.Tensor,
kv_c_normed: torch.Tensor,
k_pe: torch.Tensor,
attn_metadata: MLAMetadataCommon,
) -> torch.Tensor: