mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-05-31 18:47:09 +08:00
[Kernel] cuda kernels for upcoming decode context parallel feature (#23791)
Co-authored-by: hongchao <hongchao@msh.team>
This commit is contained in:
parent
daa1273b14
commit
186aced5ff
17
csrc/cache.h
17
csrc/cache.h
@ -36,6 +36,13 @@ void concat_and_cache_mla(torch::Tensor& kv_c, torch::Tensor& k_pe,
|
|||||||
const std::string& kv_cache_dtype,
|
const std::string& kv_cache_dtype,
|
||||||
torch::Tensor& scale);
|
torch::Tensor& scale);
|
||||||
|
|
||||||
|
void cp_fused_concat_and_cache_mla(torch::Tensor& kv_c, torch::Tensor& k_pe,
|
||||||
|
torch::Tensor& cp_local_token_select_indices,
|
||||||
|
torch::Tensor& kv_cache,
|
||||||
|
torch::Tensor& slot_mapping,
|
||||||
|
const std::string& kv_cache_dtype,
|
||||||
|
torch::Tensor& scale);
|
||||||
|
|
||||||
// Just for unittest
|
// Just for unittest
|
||||||
void convert_fp8(torch::Tensor& dst_cache, torch::Tensor& src_cache,
|
void convert_fp8(torch::Tensor& dst_cache, torch::Tensor& src_cache,
|
||||||
const double scale, const std::string& kv_cache_dtype);
|
const double scale, const std::string& kv_cache_dtype);
|
||||||
@ -47,4 +54,12 @@ void gather_and_maybe_dequant_cache(
|
|||||||
torch::Tensor const& cu_seq_lens, // [BATCH+1]
|
torch::Tensor const& cu_seq_lens, // [BATCH+1]
|
||||||
int64_t batch_size, const std::string& kv_cache_dtype,
|
int64_t batch_size, const std::string& kv_cache_dtype,
|
||||||
torch::Tensor const& scale,
|
torch::Tensor const& scale,
|
||||||
std::optional<torch::Tensor> seq_starts = std::nullopt);
|
std::optional<torch::Tensor> seq_starts = std::nullopt);
|
||||||
|
|
||||||
|
// TODO(hc): cp_gather_cache need support scaled kvcahe in the future.
|
||||||
|
void cp_gather_cache(
|
||||||
|
torch::Tensor const& src_cache, // [NUM_BLOCKS, BLOCK_SIZE, ENTRIES...]
|
||||||
|
torch::Tensor const& dst, // [TOT_TOKENS, ENTRIES...]
|
||||||
|
torch::Tensor const& block_table, // [BATCH, BLOCK_INDICES]
|
||||||
|
torch::Tensor const& cu_seq_lens, // [BATCH+1]
|
||||||
|
int64_t batch_size, std::optional<torch::Tensor> seq_starts = std::nullopt);
|
||||||
|
|||||||
@ -1,6 +1,7 @@
|
|||||||
#include <torch/all.h>
|
#include <torch/all.h>
|
||||||
#include <ATen/cuda/CUDAContext.h>
|
#include <ATen/cuda/CUDAContext.h>
|
||||||
#include <c10/cuda/CUDAGuard.h>
|
#include <c10/cuda/CUDAGuard.h>
|
||||||
|
#include <c10/cuda/CUDAException.h>
|
||||||
|
|
||||||
#include "cuda_utils.h"
|
#include "cuda_utils.h"
|
||||||
#include "cuda_compat.h"
|
#include "cuda_compat.h"
|
||||||
@ -395,6 +396,51 @@ __global__ void concat_and_cache_mla_kernel(
|
|||||||
copy(k_pe, kv_cache, k_pe_stride, block_stride, pe_dim, kv_lora_rank);
|
copy(k_pe, kv_cache, k_pe_stride, block_stride, pe_dim, kv_lora_rank);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
template <typename scalar_t, typename cache_t, Fp8KVCacheDataType kv_dt>
|
||||||
|
__global__ void cp_fused_concat_and_cache_mla_kernel(
|
||||||
|
const scalar_t* __restrict__ kv_c, // [num_full_tokens, kv_lora_rank]
|
||||||
|
const scalar_t* __restrict__ k_pe, // [num_full_tokens, pe_dim]
|
||||||
|
const int64_t* __restrict__ cp_local_token_select_indices, // [num_tokens]
|
||||||
|
cache_t* __restrict__ kv_cache, // [num_blocks, block_size, (kv_lora_rank
|
||||||
|
// + pe_dim)]
|
||||||
|
const int64_t* __restrict__ slot_mapping, // [num_tokens]
|
||||||
|
const int block_stride, //
|
||||||
|
const int entry_stride, //
|
||||||
|
const int kv_c_stride, //
|
||||||
|
const int k_pe_stride, //
|
||||||
|
const int kv_lora_rank, //
|
||||||
|
const int pe_dim, //
|
||||||
|
const int block_size, //
|
||||||
|
const float* scale //
|
||||||
|
) {
|
||||||
|
const int64_t token_idx = cp_local_token_select_indices[blockIdx.x];
|
||||||
|
const int64_t slot_idx = slot_mapping[blockIdx.x];
|
||||||
|
// NOTE: slot_idx can be -1 if the token is padded
|
||||||
|
if (slot_idx < 0) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
const int64_t block_idx = slot_idx / block_size;
|
||||||
|
const int64_t block_offset = slot_idx % block_size;
|
||||||
|
|
||||||
|
auto copy = [&](const scalar_t* __restrict__ src, cache_t* __restrict__ dst,
|
||||||
|
int src_stride, int dst_stride, int size, int offset) {
|
||||||
|
for (int i = threadIdx.x; i < size; i += blockDim.x) {
|
||||||
|
const int64_t src_idx = token_idx * src_stride + i;
|
||||||
|
const int64_t dst_idx =
|
||||||
|
block_idx * block_stride + block_offset * entry_stride + i + offset;
|
||||||
|
if constexpr (kv_dt == Fp8KVCacheDataType::kAuto) {
|
||||||
|
dst[dst_idx] = src[src_idx];
|
||||||
|
} else {
|
||||||
|
dst[dst_idx] =
|
||||||
|
fp8::scaled_convert<cache_t, scalar_t, kv_dt>(src[src_idx], *scale);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
copy(kv_c, kv_cache, kv_c_stride, block_stride, kv_lora_rank, 0);
|
||||||
|
copy(k_pe, kv_cache, k_pe_stride, block_stride, pe_dim, kv_lora_rank);
|
||||||
|
}
|
||||||
|
|
||||||
} // namespace vllm
|
} // namespace vllm
|
||||||
|
|
||||||
// KV_T is the data type of key and value tensors.
|
// KV_T is the data type of key and value tensors.
|
||||||
@ -508,6 +554,20 @@ void reshape_and_cache_flash(
|
|||||||
kv_c_stride, k_pe_stride, kv_lora_rank, pe_dim, block_size, \
|
kv_c_stride, k_pe_stride, kv_lora_rank, pe_dim, block_size, \
|
||||||
reinterpret_cast<const float*>(scale.data_ptr()));
|
reinterpret_cast<const float*>(scale.data_ptr()));
|
||||||
|
|
||||||
|
// KV_T is the data type of key and value tensors.
|
||||||
|
// CACHE_T is the stored data type of kv-cache.
|
||||||
|
// KV_DTYPE is the real data type of kv-cache.
|
||||||
|
#define CALL_CP_FUSED_CONCAT_AND_CACHE_MLA(KV_T, CACHE_T, KV_DTYPE) \
|
||||||
|
vllm::cp_fused_concat_and_cache_mla_kernel<KV_T, CACHE_T, KV_DTYPE> \
|
||||||
|
<<<grid, block, 0, stream>>>( \
|
||||||
|
reinterpret_cast<KV_T*>(kv_c.data_ptr()), \
|
||||||
|
reinterpret_cast<KV_T*>(k_pe.data_ptr()), \
|
||||||
|
cp_local_token_select_indices.data_ptr<int64_t>(), \
|
||||||
|
reinterpret_cast<CACHE_T*>(kv_cache.data_ptr()), \
|
||||||
|
slot_mapping.data_ptr<int64_t>(), block_stride, entry_stride, \
|
||||||
|
kv_c_stride, k_pe_stride, kv_lora_rank, pe_dim, block_size, \
|
||||||
|
reinterpret_cast<const float*>(scale.data_ptr()));
|
||||||
|
|
||||||
void concat_and_cache_mla(
|
void concat_and_cache_mla(
|
||||||
torch::Tensor& kv_c, // [num_tokens, kv_lora_rank]
|
torch::Tensor& kv_c, // [num_tokens, kv_lora_rank]
|
||||||
torch::Tensor& k_pe, // [num_tokens, pe_dim]
|
torch::Tensor& k_pe, // [num_tokens, pe_dim]
|
||||||
@ -546,6 +606,50 @@ void concat_and_cache_mla(
|
|||||||
CALL_CONCAT_AND_CACHE_MLA);
|
CALL_CONCAT_AND_CACHE_MLA);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Note(hc): cp_fused_concat_and_cache_mla fuses the following three kernel
|
||||||
|
// calls into one:
|
||||||
|
// k_c_normed.index_select(0, cp_local_token_select_indices) + \
|
||||||
|
// k_pe.squeeze(1).index_select(0, cp_local_token_select_indices) + \
|
||||||
|
// concat_and_cache_mla.
|
||||||
|
void cp_fused_concat_and_cache_mla(
|
||||||
|
torch::Tensor& kv_c, // [num_total_tokens, kv_lora_rank]
|
||||||
|
torch::Tensor& k_pe, // [num_total_tokens, pe_dim]
|
||||||
|
torch::Tensor& cp_local_token_select_indices, // [num_tokens]
|
||||||
|
torch::Tensor& kv_cache, // [num_blocks, block_size, (kv_lora_rank +
|
||||||
|
// pe_dim)]
|
||||||
|
torch::Tensor& slot_mapping, // [num_tokens] or [num_actual_tokens]
|
||||||
|
const std::string& kv_cache_dtype, torch::Tensor& scale) {
|
||||||
|
// NOTE(woosuk): In vLLM V1, key.size(0) can be different from
|
||||||
|
// slot_mapping.size(0) because of padding for CUDA graphs.
|
||||||
|
// In vLLM V0, key.size(0) is always equal to slot_mapping.size(0) because
|
||||||
|
// both include padding.
|
||||||
|
// In vLLM V1, however, key.size(0) can be larger than slot_mapping.size(0)
|
||||||
|
// since key includes padding for CUDA graphs, while slot_mapping does not.
|
||||||
|
// In this case, slot_mapping.size(0) represents the actual number of tokens
|
||||||
|
// before padding.
|
||||||
|
// For compatibility with both cases, we use slot_mapping.size(0) as the
|
||||||
|
// number of tokens.
|
||||||
|
int num_tokens = slot_mapping.size(0);
|
||||||
|
int kv_lora_rank = kv_c.size(1);
|
||||||
|
int pe_dim = k_pe.size(1);
|
||||||
|
int block_size = kv_cache.size(1);
|
||||||
|
|
||||||
|
TORCH_CHECK(kv_cache.size(2) == kv_lora_rank + pe_dim);
|
||||||
|
|
||||||
|
int kv_c_stride = kv_c.stride(0);
|
||||||
|
int k_pe_stride = k_pe.stride(0);
|
||||||
|
int block_stride = kv_cache.stride(0);
|
||||||
|
int entry_stride = kv_cache.stride(1);
|
||||||
|
|
||||||
|
dim3 grid(num_tokens);
|
||||||
|
dim3 block(std::min(kv_lora_rank, 512));
|
||||||
|
const at::cuda::OptionalCUDAGuard device_guard(device_of(kv_c));
|
||||||
|
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
||||||
|
|
||||||
|
DISPATCH_BY_KV_CACHE_DTYPE(kv_c.dtype(), kv_cache_dtype,
|
||||||
|
CALL_CP_FUSED_CONCAT_AND_CACHE_MLA);
|
||||||
|
}
|
||||||
|
|
||||||
namespace vllm {
|
namespace vllm {
|
||||||
|
|
||||||
template <typename Tout, typename Tin, Fp8KVCacheDataType kv_dt>
|
template <typename Tout, typename Tin, Fp8KVCacheDataType kv_dt>
|
||||||
@ -779,3 +883,146 @@ void gather_and_maybe_dequant_cache(
|
|||||||
|
|
||||||
DISPATCH_BY_KV_CACHE_DTYPE(dst.dtype(), kv_cache_dtype, CALL_GATHER_CACHE);
|
DISPATCH_BY_KV_CACHE_DTYPE(dst.dtype(), kv_cache_dtype, CALL_GATHER_CACHE);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
namespace vllm {
|
||||||
|
template <typename scalar_t>
|
||||||
|
// Note(hc): The cp_gather_cache allows seq_starts to no longer be divisible by
|
||||||
|
// block_size.
|
||||||
|
__global__ void cp_gather_cache(
|
||||||
|
const scalar_t* __restrict__ src_cache, // [NUM_BLOCKS, BLOCK_SIZE,
|
||||||
|
// ENTRY_SIZE]
|
||||||
|
scalar_t* __restrict__ dst, // [TOT_TOKENS, ENTRY_SIZE]
|
||||||
|
const int32_t* __restrict__ block_table, // [BATCH, BLOCK_INDICES]
|
||||||
|
const int32_t* __restrict__ cu_seq_lens, // [BATCH+1]
|
||||||
|
const int32_t block_size, const int32_t entry_size,
|
||||||
|
const int64_t block_table_stride, const int64_t cache_block_stride,
|
||||||
|
const int64_t cache_entry_stride, const int64_t dst_entry_stride,
|
||||||
|
const int32_t* __restrict__ seq_starts // Optional: starting offsets per
|
||||||
|
// batch
|
||||||
|
) {
|
||||||
|
const int64_t bid = blockIdx.x; // Batch ID
|
||||||
|
const int32_t num_splits = gridDim.y;
|
||||||
|
const int32_t split = blockIdx.y;
|
||||||
|
const int32_t seq_start = cu_seq_lens[bid];
|
||||||
|
const int32_t seq_end = cu_seq_lens[bid + 1];
|
||||||
|
const int32_t seq_len = seq_end - seq_start;
|
||||||
|
const int32_t tot_slots = seq_len;
|
||||||
|
const int32_t split_slots = cuda_utils::ceil_div(tot_slots, num_splits);
|
||||||
|
|
||||||
|
const int32_t split_start = split * split_slots;
|
||||||
|
const int32_t split_end = min((split + 1) * split_slots, tot_slots);
|
||||||
|
|
||||||
|
const bool is_active_split = (split_start < tot_slots);
|
||||||
|
const bool is_last_split = (split_end == tot_slots);
|
||||||
|
|
||||||
|
if (!is_active_split) return;
|
||||||
|
|
||||||
|
// Adjust the pointer for the block_table for this batch.
|
||||||
|
// If seq_starts is provided, compute an offset based on it
|
||||||
|
const int32_t batch_offset = bid * block_table_stride;
|
||||||
|
int32_t offset = split_start;
|
||||||
|
if (seq_starts != nullptr) {
|
||||||
|
offset += seq_starts[bid];
|
||||||
|
}
|
||||||
|
int32_t offset_div = offset / block_size;
|
||||||
|
offset = offset % block_size;
|
||||||
|
const int32_t* batch_block_table = block_table + batch_offset;
|
||||||
|
|
||||||
|
// Adjust dst pointer based on the cumulative sequence lengths.
|
||||||
|
dst += seq_start * dst_entry_stride;
|
||||||
|
|
||||||
|
auto copy_entry = [&](const scalar_t* __restrict__ _src,
|
||||||
|
scalar_t* __restrict__ _dst) {
|
||||||
|
for (int i = threadIdx.x; i < entry_size; i += blockDim.x)
|
||||||
|
_dst[i] = _src[i];
|
||||||
|
};
|
||||||
|
|
||||||
|
for (int pid = split_start; pid < split_end; ++pid) {
|
||||||
|
auto block_id = batch_block_table[offset_div];
|
||||||
|
auto block_start_ptr = src_cache + block_id * cache_block_stride;
|
||||||
|
auto block_dst_ptr = dst + pid * dst_entry_stride;
|
||||||
|
copy_entry(block_start_ptr + offset * cache_entry_stride, block_dst_ptr);
|
||||||
|
offset += 1;
|
||||||
|
// bump to next block
|
||||||
|
if (offset == block_size) {
|
||||||
|
offset_div += 1;
|
||||||
|
offset = 0;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} // namespace vllm
|
||||||
|
|
||||||
|
// Macro to dispatch the kernel based on the data type.
|
||||||
|
#define CALL_CP_GATHER_CACHE(CPY_DTYPE) \
|
||||||
|
vllm::cp_gather_cache<CPY_DTYPE><<<grid, block, 0, stream>>>( \
|
||||||
|
reinterpret_cast<CPY_DTYPE*>(src_cache.data_ptr()), \
|
||||||
|
reinterpret_cast<CPY_DTYPE*>(dst.data_ptr()), \
|
||||||
|
block_table.data_ptr<int32_t>(), cu_seq_lens.data_ptr<int32_t>(), \
|
||||||
|
block_size, entry_size, block_table_stride, cache_block_stride, \
|
||||||
|
cache_entry_stride, dst_entry_stride, seq_starts_ptr);
|
||||||
|
|
||||||
|
// Gather sequences from the cache into the destination tensor.
|
||||||
|
// - cu_seq_lens contains the cumulative sequence lengths for each batch
|
||||||
|
// - block_table contains the cache block indices for each sequence
|
||||||
|
// - Optionally, seq_starts (if provided) offsets the starting slot index by
|
||||||
|
// seq_starts[bid]
|
||||||
|
void cp_gather_cache(
|
||||||
|
torch::Tensor const& src_cache, // [NUM_BLOCKS, BLOCK_SIZE, ENTRIES...]
|
||||||
|
torch::Tensor const& dst, // [TOT_TOKENS, ENTRIES...]
|
||||||
|
torch::Tensor const& block_table, // [BATCH, BLOCK_INDICES]
|
||||||
|
torch::Tensor const& cu_seq_lens, // [BATCH+1]
|
||||||
|
int64_t batch_size,
|
||||||
|
std::optional<torch::Tensor> seq_starts = std::nullopt) {
|
||||||
|
at::cuda::OptionalCUDAGuard device_guard(src_cache.device());
|
||||||
|
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
||||||
|
|
||||||
|
int32_t block_size = src_cache.size(1);
|
||||||
|
int32_t entry_size = src_cache.flatten(2, -1).size(2);
|
||||||
|
|
||||||
|
TORCH_CHECK(block_table.dtype() == torch::kInt32,
|
||||||
|
"block_table must be int32");
|
||||||
|
TORCH_CHECK(cu_seq_lens.dtype() == torch::kInt32,
|
||||||
|
"cu_seq_lens must be int32");
|
||||||
|
if (seq_starts.has_value()) {
|
||||||
|
TORCH_CHECK(seq_starts.value().dtype() == torch::kInt32,
|
||||||
|
"seq_starts must be int32");
|
||||||
|
}
|
||||||
|
|
||||||
|
TORCH_CHECK(src_cache.device() == dst.device(),
|
||||||
|
"src_cache and dst must be on the same device");
|
||||||
|
TORCH_CHECK(src_cache.device() == block_table.device(),
|
||||||
|
"src_cache and block_table must be on the same device");
|
||||||
|
TORCH_CHECK(src_cache.device() == cu_seq_lens.device(),
|
||||||
|
"src_cache and cu_seq_lens must be on the same device");
|
||||||
|
if (seq_starts.has_value()) {
|
||||||
|
TORCH_CHECK(src_cache.device() == seq_starts.value().device(),
|
||||||
|
"src_cache and seq_starts must be on the same device");
|
||||||
|
}
|
||||||
|
|
||||||
|
int64_t block_table_stride = block_table.stride(0);
|
||||||
|
int64_t cache_block_stride = src_cache.stride(0);
|
||||||
|
int64_t cache_entry_stride = src_cache.stride(1);
|
||||||
|
int64_t dst_entry_stride = dst.stride(0);
|
||||||
|
|
||||||
|
// Decide on the number of splits based on the batch size.
|
||||||
|
int num_splits = batch_size > 128 ? 2 : batch_size > 64 ? 4 : 16;
|
||||||
|
dim3 grid(batch_size, num_splits);
|
||||||
|
dim3 block(1024);
|
||||||
|
|
||||||
|
TORCH_CHECK(src_cache.dtype() == dst.dtype(),
|
||||||
|
"src_cache and dst must have the same dtype");
|
||||||
|
|
||||||
|
const int dtype_bits = src_cache.element_size() * 8;
|
||||||
|
const int32_t* seq_starts_ptr =
|
||||||
|
seq_starts.has_value() ? seq_starts.value().data_ptr<int32_t>() : nullptr;
|
||||||
|
|
||||||
|
if (dtype_bits == 32) {
|
||||||
|
CALL_CP_GATHER_CACHE(uint32_t);
|
||||||
|
} else if (dtype_bits == 16) {
|
||||||
|
CALL_CP_GATHER_CACHE(uint16_t);
|
||||||
|
} else if (dtype_bits == 8) {
|
||||||
|
CALL_CP_GATHER_CACHE(uint8_t);
|
||||||
|
} else {
|
||||||
|
TORCH_CHECK(false, "Unsupported data type width: ", dtype_bits);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
@ -686,6 +686,16 @@ TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _cache_ops), cache_ops) {
|
|||||||
" Tensor scale) -> ()");
|
" Tensor scale) -> ()");
|
||||||
cache_ops.impl("concat_and_cache_mla", torch::kCUDA, &concat_and_cache_mla);
|
cache_ops.impl("concat_and_cache_mla", torch::kCUDA, &concat_and_cache_mla);
|
||||||
|
|
||||||
|
cache_ops.def(
|
||||||
|
"cp_fused_concat_and_cache_mla(Tensor kv_c, Tensor k_pe,"
|
||||||
|
" Tensor cp_local_token_select_indices,"
|
||||||
|
" Tensor! kv_cache,"
|
||||||
|
" Tensor slot_mapping,"
|
||||||
|
" str kv_cache_dtype,"
|
||||||
|
" Tensor scale) -> ()");
|
||||||
|
cache_ops.impl("cp_fused_concat_and_cache_mla", torch::kCUDA,
|
||||||
|
&cp_fused_concat_and_cache_mla);
|
||||||
|
|
||||||
// Convert the key and value cache to fp8 data type.
|
// Convert the key and value cache to fp8 data type.
|
||||||
cache_ops.def(
|
cache_ops.def(
|
||||||
"convert_fp8(Tensor! dst_cache, Tensor src_cache, float scale, "
|
"convert_fp8(Tensor! dst_cache, Tensor src_cache, float scale, "
|
||||||
@ -702,6 +712,11 @@ TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _cache_ops), cache_ops) {
|
|||||||
" Tensor scale, Tensor? seq_starts) -> ()");
|
" Tensor scale, Tensor? seq_starts) -> ()");
|
||||||
cache_ops.impl("gather_and_maybe_dequant_cache", torch::kCUDA,
|
cache_ops.impl("gather_and_maybe_dequant_cache", torch::kCUDA,
|
||||||
&gather_and_maybe_dequant_cache);
|
&gather_and_maybe_dequant_cache);
|
||||||
|
|
||||||
|
cache_ops.def(
|
||||||
|
"cp_gather_cache(Tensor src_cache, Tensor! dst, Tensor block_table, "
|
||||||
|
"Tensor cu_seq_lens, int batch_size, Tensor? seq_starts) -> ()");
|
||||||
|
cache_ops.impl("cp_gather_cache", torch::kCUDA, &cp_gather_cache);
|
||||||
}
|
}
|
||||||
|
|
||||||
TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _cuda_utils), cuda_utils) {
|
TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _cuda_utils), cuda_utils) {
|
||||||
|
|||||||
@ -790,6 +790,78 @@ def test_gather_and_maybe_dequant_cache_mla(kv_lora_rank, qk_rope_head_dim,
|
|||||||
torch.testing.assert_close(dst, expected)
|
torch.testing.assert_close(dst, expected)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("kv_lora_rank", [512])
|
||||||
|
@pytest.mark.parametrize("qk_rope_head_dim", [64])
|
||||||
|
@pytest.mark.parametrize("block_size", [16])
|
||||||
|
@pytest.mark.parametrize("num_blocks", [1024])
|
||||||
|
@pytest.mark.parametrize("max_seq_len", [512])
|
||||||
|
@pytest.mark.parametrize("batch_size", [8])
|
||||||
|
@pytest.mark.parametrize("dtype", [torch.float32])
|
||||||
|
@pytest.mark.parametrize("kv_cache_dtype",
|
||||||
|
["auto"]) # You can also test "fp8" if needed.
|
||||||
|
@pytest.mark.parametrize("device", CUDA_DEVICES)
|
||||||
|
@torch.inference_mode()
|
||||||
|
def test_cp_gather_cache_mla(kv_lora_rank, qk_rope_head_dim, block_size,
|
||||||
|
num_blocks, max_seq_len, batch_size, dtype,
|
||||||
|
kv_cache_dtype, device):
|
||||||
|
entry_size = kv_lora_rank + qk_rope_head_dim
|
||||||
|
src_cache = _create_mla_cache(num_blocks, block_size, entry_size, dtype,
|
||||||
|
kv_cache_dtype, device)
|
||||||
|
_fill_mla_cache(src_cache, kv_cache_dtype=kv_cache_dtype)
|
||||||
|
|
||||||
|
seq_len_tensor = torch.randint(0,
|
||||||
|
max_seq_len + 1, (batch_size, ),
|
||||||
|
device=device)
|
||||||
|
|
||||||
|
total_tokens = seq_len_tensor.sum()
|
||||||
|
cu_seq_lens = torch.empty((batch_size + 1),
|
||||||
|
dtype=torch.int32,
|
||||||
|
device=device)
|
||||||
|
cu_seq_lens[0] = 0
|
||||||
|
cu_seq_lens[1:] = seq_len_tensor.cumsum(dim=0).to(dtype=torch.int32)
|
||||||
|
print("seq_len_tensor", seq_len_tensor)
|
||||||
|
|
||||||
|
tot_blocks_tensor = (seq_len_tensor + block_size - 1) // block_size
|
||||||
|
block_table = torch.empty((batch_size, num_blocks),
|
||||||
|
dtype=torch.int32,
|
||||||
|
device=device)
|
||||||
|
|
||||||
|
for b in range(batch_size):
|
||||||
|
perm = torch.randperm(num_blocks, device=device)
|
||||||
|
block_table[b, :] = perm
|
||||||
|
|
||||||
|
dst = torch.zeros((total_tokens, entry_size),
|
||||||
|
dtype=src_cache.dtype,
|
||||||
|
device=device)
|
||||||
|
|
||||||
|
expected_batches = []
|
||||||
|
for b in range(batch_size):
|
||||||
|
s = seq_len_tensor[b]
|
||||||
|
if s == 0:
|
||||||
|
continue
|
||||||
|
tot = tot_blocks_tensor[b]
|
||||||
|
blocks = block_table[b, :tot].tolist()
|
||||||
|
|
||||||
|
gathered_rows = []
|
||||||
|
for i in range(tot - 1):
|
||||||
|
gathered_rows.append(src_cache[blocks[i]])
|
||||||
|
remaining = s - (tot - 1) * block_size
|
||||||
|
gathered_rows.append(src_cache[blocks[-1], :remaining, :])
|
||||||
|
|
||||||
|
batch_expected = torch.cat(gathered_rows, dim=0)
|
||||||
|
expected_batches.append(batch_expected)
|
||||||
|
expected = torch.cat(expected_batches, dim=0)
|
||||||
|
|
||||||
|
opcheck(
|
||||||
|
torch.ops._C_cache_ops.cp_gather_cache,
|
||||||
|
(src_cache, dst, block_table, cu_seq_lens, batch_size, None),
|
||||||
|
test_utils=DEFAULT_OPCHECK_TEST_UTILS,
|
||||||
|
)
|
||||||
|
|
||||||
|
ops.cp_gather_cache(src_cache, dst, block_table, cu_seq_lens, batch_size)
|
||||||
|
torch.testing.assert_close(dst, expected)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("kv_lora_rank", KV_LORA_RANKS)
|
@pytest.mark.parametrize("kv_lora_rank", KV_LORA_RANKS)
|
||||||
@pytest.mark.parametrize("qk_rope_head_dim", QK_ROPE_HEAD_DIMS)
|
@pytest.mark.parametrize("qk_rope_head_dim", QK_ROPE_HEAD_DIMS)
|
||||||
@pytest.mark.parametrize("num_tokens", NUM_TOKENS_MLA)
|
@pytest.mark.parametrize("num_tokens", NUM_TOKENS_MLA)
|
||||||
|
|||||||
@ -1625,6 +1625,20 @@ def concat_and_cache_mla(
|
|||||||
scale)
|
scale)
|
||||||
|
|
||||||
|
|
||||||
|
def cp_fused_concat_and_cache_mla(
|
||||||
|
kv_c: torch.Tensor,
|
||||||
|
k_pe: torch.Tensor,
|
||||||
|
cp_local_token_select_indices: torch.Tensor,
|
||||||
|
kv_cache: torch.Tensor,
|
||||||
|
slot_mapping: torch.Tensor,
|
||||||
|
kv_cache_dtype: str,
|
||||||
|
scale: torch.Tensor,
|
||||||
|
) -> None:
|
||||||
|
torch.ops._C_cache_ops.cp_fused_concat_and_cache_mla(
|
||||||
|
kv_c, k_pe, cp_local_token_select_indices, kv_cache, slot_mapping,
|
||||||
|
kv_cache_dtype, scale)
|
||||||
|
|
||||||
|
|
||||||
def copy_blocks(key_caches: list[torch.Tensor],
|
def copy_blocks(key_caches: list[torch.Tensor],
|
||||||
value_caches: list[torch.Tensor],
|
value_caches: list[torch.Tensor],
|
||||||
block_mapping: torch.Tensor) -> None:
|
block_mapping: torch.Tensor) -> None:
|
||||||
@ -1662,6 +1676,16 @@ def gather_and_maybe_dequant_cache(
|
|||||||
scale, seq_starts)
|
scale, seq_starts)
|
||||||
|
|
||||||
|
|
||||||
|
def cp_gather_cache(src_cache: torch.Tensor,
|
||||||
|
dst: torch.Tensor,
|
||||||
|
block_table: torch.Tensor,
|
||||||
|
cu_seq_lens: torch.Tensor,
|
||||||
|
batch_size: int,
|
||||||
|
seq_starts: Optional[torch.Tensor] = None) -> None:
|
||||||
|
torch.ops._C_cache_ops.cp_gather_cache(src_cache, dst, block_table,
|
||||||
|
cu_seq_lens, batch_size, seq_starts)
|
||||||
|
|
||||||
|
|
||||||
def get_device_attribute(attribute: int, device: int) -> int:
|
def get_device_attribute(attribute: int, device: int) -> int:
|
||||||
return torch.ops._C_cuda_utils.get_device_attribute(attribute, device)
|
return torch.ops._C_cuda_utils.get_device_attribute(attribute, device)
|
||||||
|
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user