mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-04-21 14:27:09 +08:00
renaming for consistency
Signed-off-by: Lucas Wilkinson <lwilkinson@neuralmagic.com>
This commit is contained in:
parent
534cd0006d
commit
2326814c11
@ -28,7 +28,7 @@ void reshape_and_cache_flash(torch::Tensor& key, torch::Tensor& value,
|
||||
const std::string& kv_cache_dtype,
|
||||
torch::Tensor& k_scale, torch::Tensor& v_scale);
|
||||
|
||||
void concat_and_cache_mla(torch::Tensor& ckv, torch::Tensor& k_pe,
|
||||
void concat_and_cache_mla(torch::Tensor& kv_c, torch::Tensor& k_pe,
|
||||
torch::Tensor& kv_cache, torch::Tensor& slot_mapping,
|
||||
const std::string& kv_cache_dtype,
|
||||
torch::Tensor& scale);
|
||||
|
||||
@ -248,13 +248,13 @@ __global__ void reshape_and_cache_flash_kernel(
|
||||
|
||||
template <typename scalar_t, typename cache_t, Fp8KVCacheDataType kv_dt>
|
||||
__global__ void concat_and_cache_mla_kernel(
|
||||
const scalar_t* __restrict__ ckv, // [num_tokens, kv_lora_rank]
|
||||
const scalar_t* __restrict__ kv_c, // [num_tokens, kv_lora_rank]
|
||||
const scalar_t* __restrict__ k_pe, // [num_tokens, pe_dim]
|
||||
cache_t* __restrict__ kv_cache, // [num_blocks, block_size, (kv_lora_rank
|
||||
// + pe_dim)]
|
||||
const int64_t* __restrict__ slot_mapping, // [num_tokens]
|
||||
const int block_stride, //
|
||||
const int ckv_stride, //
|
||||
const int kv_c_stride, //
|
||||
const int k_pe_stride, //
|
||||
const int kv_lora_rank, //
|
||||
const int pe_dim, //
|
||||
@ -286,7 +286,7 @@ __global__ void concat_and_cache_mla_kernel(
|
||||
}
|
||||
};
|
||||
|
||||
copy(ckv, kv_cache, ckv_stride, block_stride, kv_lora_rank, 0);
|
||||
copy(kv_c, kv_cache, kv_c_stride, block_stride, kv_lora_rank, 0);
|
||||
copy(k_pe, kv_cache, k_pe_stride, block_stride, pe_dim, kv_lora_rank);
|
||||
}
|
||||
|
||||
@ -391,18 +391,18 @@ void reshape_and_cache_flash(
|
||||
// KV_T is the stored data type of kv-cache.
|
||||
// CACHE_T is the data type of key and value tensors.
|
||||
// KV_DTYPE is the real data type of kv-cache.
|
||||
#define CALL_CONCAT_AND_CACHE_MLA(KV_T, CACHE_T, KV_DTYPE) \
|
||||
vllm::concat_and_cache_mla_kernel<KV_T, CACHE_T, KV_DTYPE> \
|
||||
<<<grid, block, 0, stream>>>( \
|
||||
reinterpret_cast<KV_T*>(ckv.data_ptr()), \
|
||||
reinterpret_cast<KV_T*>(k_pe.data_ptr()), \
|
||||
reinterpret_cast<CACHE_T*>(kv_cache.data_ptr()), \
|
||||
slot_mapping.data_ptr<int64_t>(), block_stride, ckv_stride, \
|
||||
k_pe_stride, kv_lora_rank, pe_dim, block_size, \
|
||||
#define CALL_CONCAT_AND_CACHE_MLA(KV_T, CACHE_T, KV_DTYPE) \
|
||||
vllm::concat_and_cache_mla_kernel<KV_T, CACHE_T, KV_DTYPE> \
|
||||
<<<grid, block, 0, stream>>>( \
|
||||
reinterpret_cast<KV_T*>(kv_c.data_ptr()), \
|
||||
reinterpret_cast<KV_T*>(k_pe.data_ptr()), \
|
||||
reinterpret_cast<CACHE_T*>(kv_cache.data_ptr()), \
|
||||
slot_mapping.data_ptr<int64_t>(), block_stride, kv_c_stride, \
|
||||
k_pe_stride, kv_lora_rank, pe_dim, block_size, \
|
||||
reinterpret_cast<const float*>(scale.data_ptr()));
|
||||
|
||||
void concat_and_cache_mla(
|
||||
torch::Tensor& ckv, // [num_tokens, kv_lora_rank]
|
||||
torch::Tensor& kv_c, // [num_tokens, kv_lora_rank]
|
||||
torch::Tensor& k_pe, // [num_tokens, pe_dim]
|
||||
torch::Tensor& kv_cache, // [num_blocks, block_size, (kv_lora_rank +
|
||||
// pe_dim)]
|
||||
@ -419,22 +419,22 @@ void concat_and_cache_mla(
|
||||
// For compatibility with both cases, we use slot_mapping.size(0) as the
|
||||
// number of tokens.
|
||||
int num_tokens = slot_mapping.size(0);
|
||||
int kv_lora_rank = ckv.size(1);
|
||||
int kv_lora_rank = kv_c.size(1);
|
||||
int pe_dim = k_pe.size(1);
|
||||
int block_size = kv_cache.size(1);
|
||||
|
||||
TORCH_CHECK(kv_cache.size(2) == kv_lora_rank + pe_dim);
|
||||
|
||||
int ckv_stride = ckv.stride(0);
|
||||
int kv_c_stride = kv_c.stride(0);
|
||||
int k_pe_stride = k_pe.stride(0);
|
||||
int block_stride = kv_cache.stride(0);
|
||||
|
||||
dim3 grid(num_tokens);
|
||||
dim3 block(std::min(kv_lora_rank, 512));
|
||||
const at::cuda::OptionalCUDAGuard device_guard(device_of(ckv));
|
||||
const at::cuda::OptionalCUDAGuard device_guard(device_of(kv_c));
|
||||
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
||||
|
||||
DISPATCH_BY_KV_CACHE_DTYPE(ckv.dtype(), kv_cache_dtype,
|
||||
DISPATCH_BY_KV_CACHE_DTYPE(kv_c.dtype(), kv_cache_dtype,
|
||||
CALL_CONCAT_AND_CACHE_MLA);
|
||||
}
|
||||
|
||||
|
||||
@ -463,9 +463,9 @@ TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _cache_ops), cache_ops) {
|
||||
cache_ops.impl("reshape_and_cache_flash", torch::kCUDA,
|
||||
&reshape_and_cache_flash);
|
||||
|
||||
// Concat ckv and k_pe and cache them.
|
||||
// Concat kv_c and k_pe and cache them.
|
||||
cache_ops.def(
|
||||
"concat_and_cache_mla(Tensor ckv, Tensor k_pe,"
|
||||
"concat_and_cache_mla(Tensor kv_c, Tensor k_pe,"
|
||||
" Tensor! kv_cache,"
|
||||
" Tensor slot_mapping,"
|
||||
" str kv_cache_dtype,"
|
||||
|
||||
@ -981,14 +981,14 @@ def reshape_and_cache_flash(
|
||||
|
||||
|
||||
def concat_and_cache_mla(
|
||||
ckv: torch.Tensor,
|
||||
kv_c: torch.Tensor,
|
||||
k_pe: torch.Tensor,
|
||||
kv_cache: torch.Tensor,
|
||||
slot_mapping: torch.Tensor,
|
||||
kv_cache_dtype: str,
|
||||
scale: torch.Tensor,
|
||||
) -> None:
|
||||
torch.ops._C_cache_ops.concat_and_cache_mla(ckv, k_pe, kv_cache,
|
||||
torch.ops._C_cache_ops.concat_and_cache_mla(kv_c, k_pe, kv_cache,
|
||||
slot_mapping, kv_cache_dtype,
|
||||
scale)
|
||||
|
||||
|
||||
@ -269,7 +269,7 @@ class AttentionImpl(ABC, Generic[T]):
|
||||
self,
|
||||
layer: AttentionLayer,
|
||||
query: torch.Tensor, # For MLA hidden_states_or_cq
|
||||
key: torch.Tensor, # For MLA ckv_normed
|
||||
key: torch.Tensor, # For MLA kv_c_normed
|
||||
value: torch.Tensor, # For MLA k_pe
|
||||
kv_cache: torch.Tensor,
|
||||
attn_metadata: T,
|
||||
|
||||
@ -52,7 +52,7 @@ class MLAImplCommon(AttentionImpl):
|
||||
|
||||
1. The hidden states (B, H) are projected down into cq (B, Lq) and
|
||||
kv_c_k_pe (B, Lkv+R).
|
||||
2. The kv_c_k_pe is split into ckv (B, Lkv) and k_pe (B, R). cq
|
||||
2. The kv_c_k_pe is split into kv_c (B, Lkv) and k_pe (B, R). cq
|
||||
and kv_c are normalized.
|
||||
|
||||
#
|
||||
@ -249,7 +249,7 @@ class MLAImplCommon(AttentionImpl):
|
||||
def _forward_prefill(
|
||||
self,
|
||||
q: torch.Tensor,
|
||||
ckv_normed: torch.Tensor,
|
||||
kv_c_normed: torch.Tensor,
|
||||
k_pe: torch.Tensor,
|
||||
attn_metadata: MLAMetadataCommon,
|
||||
) -> torch.Tensor:
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user