mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-06-09 05:42:15 +08:00
simplify - get rid of tokenshape
Signed-off-by: Tyler Michael Smith <tyler@neuralmagic.com>
This commit is contained in:
parent
230730c34d
commit
c33aeecf24
@ -28,12 +28,6 @@ void reshape_and_cache_flash(torch::Tensor& key, torch::Tensor& value,
|
|||||||
const std::string& kv_cache_dtype,
|
const std::string& kv_cache_dtype,
|
||||||
torch::Tensor& k_scale, torch::Tensor& v_scale);
|
torch::Tensor& k_scale, torch::Tensor& v_scale);
|
||||||
|
|
||||||
void reshape_and_cache_flash_full_cuda(
|
|
||||||
torch::Tensor& tokenshape, torch::Tensor& key, torch::Tensor& value,
|
|
||||||
torch::Tensor& key_cache, torch::Tensor& value_cache,
|
|
||||||
torch::Tensor& slot_mapping, const std::string& kv_cache_dtype,
|
|
||||||
torch::Tensor& k_scale, torch::Tensor& v_scale);
|
|
||||||
|
|
||||||
void concat_and_cache_mla(torch::Tensor& kv_c, torch::Tensor& k_pe,
|
void concat_and_cache_mla(torch::Tensor& kv_c, torch::Tensor& k_pe,
|
||||||
torch::Tensor& kv_cache, torch::Tensor& slot_mapping,
|
torch::Tensor& kv_cache, torch::Tensor& slot_mapping,
|
||||||
const std::string& kv_cache_dtype,
|
const std::string& kv_cache_dtype,
|
||||||
|
|||||||
@ -434,50 +434,6 @@ void reshape_and_cache_flash(
|
|||||||
CALL_RESHAPE_AND_CACHE_FLASH);
|
CALL_RESHAPE_AND_CACHE_FLASH);
|
||||||
}
|
}
|
||||||
|
|
||||||
// KV_T is the stored data type of kv-cache.
|
|
||||||
// CACHE_T is the data type of key and value tensors.
|
|
||||||
// KV_DTYPE is the real data type of kv-cache.
|
|
||||||
#define CALL_RESHAPE_AND_CACHE_FLASH_FULL_CUDA(KV_T, CACHE_T, KV_DTYPE) \
|
|
||||||
vllm::reshape_and_cache_flash_kernel<KV_T, CACHE_T, KV_DTYPE> \
|
|
||||||
<<<grid, block, 0, stream>>>( \
|
|
||||||
reinterpret_cast<KV_T*>(key.data_ptr()), \
|
|
||||||
reinterpret_cast<KV_T*>(value.data_ptr()), \
|
|
||||||
reinterpret_cast<CACHE_T*>(key_cache.data_ptr()), \
|
|
||||||
reinterpret_cast<CACHE_T*>(value_cache.data_ptr()), \
|
|
||||||
slot_mapping.data_ptr<int64_t>(), block_stride, key_stride, \
|
|
||||||
value_stride, num_heads, head_size, block_size, \
|
|
||||||
reinterpret_cast<const float*>(k_scale.data_ptr()), \
|
|
||||||
reinterpret_cast<const float*>(v_scale.data_ptr()));
|
|
||||||
|
|
||||||
void reshape_and_cache_flash_full_cuda(
|
|
||||||
torch::Tensor& tokenshape, // true num_tokens at first entry.
|
|
||||||
torch::Tensor& key, // [num_tokens, num_heads, head_size]
|
|
||||||
torch::Tensor& value, // [num_tokens, num_heads, head_size]
|
|
||||||
torch::Tensor& key_cache, // [num_blocks, block_size, num_heads, head_size]
|
|
||||||
torch::Tensor&
|
|
||||||
value_cache, // [num_blocks, block_size, num_heads, head_size]
|
|
||||||
torch::Tensor& slot_mapping, // [num_tokens] or [num_actual_tokens]
|
|
||||||
const std::string& kv_cache_dtype, torch::Tensor& k_scale,
|
|
||||||
torch::Tensor& v_scale) {
|
|
||||||
int padded_num_tokens = slot_mapping.size(0);
|
|
||||||
int num_heads = key.size(1);
|
|
||||||
int head_size = key.size(2);
|
|
||||||
int block_size = key_cache.size(1);
|
|
||||||
|
|
||||||
int key_stride = key.stride(0);
|
|
||||||
int value_stride = value.stride(0);
|
|
||||||
int block_stride = key_cache.stride(0);
|
|
||||||
TORCH_CHECK(key_cache.stride(0) == value_cache.stride(0));
|
|
||||||
|
|
||||||
dim3 grid(padded_num_tokens);
|
|
||||||
dim3 block(std::min(num_heads * head_size, 512));
|
|
||||||
const at::cuda::OptionalCUDAGuard device_guard(device_of(key));
|
|
||||||
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
|
||||||
|
|
||||||
DISPATCH_BY_KV_CACHE_DTYPE(key.dtype(), kv_cache_dtype,
|
|
||||||
CALL_RESHAPE_AND_CACHE_FLASH_FULL_CUDA);
|
|
||||||
}
|
|
||||||
|
|
||||||
#define CALL_CONCAT_AND_CACHE_MLA(KV_T, CACHE_T, KV_DTYPE) \
|
#define CALL_CONCAT_AND_CACHE_MLA(KV_T, CACHE_T, KV_DTYPE) \
|
||||||
vllm::concat_and_cache_mla_kernel<KV_T, CACHE_T, KV_DTYPE> \
|
vllm::concat_and_cache_mla_kernel<KV_T, CACHE_T, KV_DTYPE> \
|
||||||
<<<grid, block, 0, stream>>>( \
|
<<<grid, block, 0, stream>>>( \
|
||||||
|
|||||||
@ -470,18 +470,6 @@ TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _cache_ops), cache_ops) {
|
|||||||
cache_ops.impl("reshape_and_cache_flash", torch::kCUDA,
|
cache_ops.impl("reshape_and_cache_flash", torch::kCUDA,
|
||||||
&reshape_and_cache_flash);
|
&reshape_and_cache_flash);
|
||||||
|
|
||||||
// Reshape the key and value tensors and cache them.
|
|
||||||
cache_ops.def(
|
|
||||||
"reshape_and_cache_flash_full_cuda(Tensor tensorshape,"
|
|
||||||
" Tensor key, Tensor value,"
|
|
||||||
" Tensor! key_cache,"
|
|
||||||
" Tensor! value_cache,"
|
|
||||||
" Tensor slot_mapping,"
|
|
||||||
" str kv_cache_dtype,"
|
|
||||||
" Tensor k_scale, Tensor v_scale) -> ()");
|
|
||||||
cache_ops.impl("reshape_and_cache_flash_full_cuda", torch::kCUDA,
|
|
||||||
&reshape_and_cache_flash_full_cuda);
|
|
||||||
|
|
||||||
// Concat kv_c and k_pe and cache them.
|
// Concat kv_c and k_pe and cache them.
|
||||||
cache_ops.def(
|
cache_ops.def(
|
||||||
"concat_and_cache_mla(Tensor kv_c, Tensor k_pe,"
|
"concat_and_cache_mla(Tensor kv_c, Tensor k_pe,"
|
||||||
|
|||||||
@ -75,9 +75,6 @@ class FlashAttentionMetadata:
|
|||||||
block_table: torch.Tensor
|
block_table: torch.Tensor
|
||||||
slot_mapping: torch.Tensor
|
slot_mapping: torch.Tensor
|
||||||
|
|
||||||
# [num_actual_tokens, batch_size, max_query_len, max_seq_len]
|
|
||||||
tokenshape: torch.Tensor
|
|
||||||
|
|
||||||
# For cascade attention.
|
# For cascade attention.
|
||||||
use_cascade: bool
|
use_cascade: bool
|
||||||
common_prefix_len: int
|
common_prefix_len: int
|
||||||
@ -194,17 +191,15 @@ class FlashAttentionImpl(AttentionImpl):
|
|||||||
# Whenever making a change in this method, please benchmark the
|
# Whenever making a change in this method, please benchmark the
|
||||||
# performance to make sure it does not introduce any overhead.
|
# performance to make sure it does not introduce any overhead.
|
||||||
|
|
||||||
tokenshape = attn_metadata.tokenshape
|
num_actual_tokens = attn_metadata.num_actual_tokens
|
||||||
num_padded_tokens = key.shape[0]
|
|
||||||
# Reshape the input keys and values and store them in the cache.
|
# Reshape the input keys and values and store them in the cache.
|
||||||
key_cache, value_cache = kv_cache.unbind(0)
|
key_cache, value_cache = kv_cache.unbind(0)
|
||||||
torch.ops._C_cache_ops.reshape_and_cache_flash_full_cuda(
|
torch.ops._C_cache_ops.reshape_and_cache_flash(
|
||||||
tokenshape,
|
|
||||||
key,
|
key,
|
||||||
value,
|
value,
|
||||||
key_cache,
|
key_cache,
|
||||||
value_cache,
|
value_cache,
|
||||||
attn_metadata.slot_mapping[:num_padded_tokens],
|
attn_metadata.slot_mapping[:num_actual_tokens],
|
||||||
self.kv_cache_dtype,
|
self.kv_cache_dtype,
|
||||||
layer._k_scale,
|
layer._k_scale,
|
||||||
layer._v_scale,
|
layer._v_scale,
|
||||||
@ -213,16 +208,15 @@ class FlashAttentionImpl(AttentionImpl):
|
|||||||
# Compute attention and update output up to `num_actual_tokens`.
|
# Compute attention and update output up to `num_actual_tokens`.
|
||||||
if not attn_metadata.use_cascade:
|
if not attn_metadata.use_cascade:
|
||||||
# Regular attention (common case).
|
# Regular attention (common case).
|
||||||
num_actual_tokens = attn_metadata.num_actual_tokens
|
|
||||||
batch_size = attn_metadata.block_table.shape[0]
|
batch_size = attn_metadata.block_table.shape[0]
|
||||||
|
|
||||||
#TODO: Do we need to slice by [:batch_size+1]?
|
#TODO: Do we need to slice by [:batch_size+1]?
|
||||||
flash_attn_varlen_func(
|
flash_attn_varlen_func(
|
||||||
q=query[:num_padded_tokens],
|
q=query[:num_actual_tokens],
|
||||||
k=key_cache,
|
k=key_cache,
|
||||||
v=value_cache,
|
v=value_cache,
|
||||||
out=output[:num_padded_tokens],
|
out=output[:num_actual_tokens],
|
||||||
cu_seqlens_q=attn_metadata.query_start_loc[:batch_size + 1],
|
cu_seqlens_q=attn_metadata.query_start_loc[:batch_size+1],
|
||||||
max_seqlen_q=attn_metadata.max_query_len,
|
max_seqlen_q=attn_metadata.max_query_len,
|
||||||
seqused_k=attn_metadata.seq_lens[:batch_size],
|
seqused_k=attn_metadata.seq_lens[:batch_size],
|
||||||
max_seqlen_k=attn_metadata.max_seq_len,
|
max_seqlen_k=attn_metadata.max_seq_len,
|
||||||
|
|||||||
@ -185,7 +185,6 @@ class GPUModelRunner:
|
|||||||
# this one must be int64
|
# this one must be int64
|
||||||
dtype=torch.int64,
|
dtype=torch.int64,
|
||||||
device=self.device)
|
device=self.device)
|
||||||
self.tokenshape = torch.zeros(4, dtype=torch.int32, device=self.device)
|
|
||||||
|
|
||||||
# OPTIMIZATION: Cache the tensors rather than creating them every step.
|
# OPTIMIZATION: Cache the tensors rather than creating them every step.
|
||||||
self.arange_np = np.arange(max(self.max_num_reqs + 1,
|
self.arange_np = np.arange(max(self.max_num_reqs + 1,
|
||||||
@ -221,11 +220,6 @@ class GPUModelRunner:
|
|||||||
pin_memory=self.pin_memory)
|
pin_memory=self.pin_memory)
|
||||||
self.seq_lens_np = self.seq_lens_cpu.numpy()
|
self.seq_lens_np = self.seq_lens_cpu.numpy()
|
||||||
|
|
||||||
self.tokenshape_cpu = torch.zeros(4,
|
|
||||||
dtype=torch.int32,
|
|
||||||
device="cpu",
|
|
||||||
pin_memory=self.pin_memory)
|
|
||||||
|
|
||||||
def _update_states(self, scheduler_output: "SchedulerOutput") -> None:
|
def _update_states(self, scheduler_output: "SchedulerOutput") -> None:
|
||||||
# Remove stopped requests from the cached states.
|
# Remove stopped requests from the cached states.
|
||||||
# Keep the states of the preempted requests.
|
# Keep the states of the preempted requests.
|
||||||
@ -466,13 +460,9 @@ class GPUModelRunner:
|
|||||||
self.slot_mapping_cpu[:total_num_scheduled_tokens],
|
self.slot_mapping_cpu[:total_num_scheduled_tokens],
|
||||||
non_blocking=True)
|
non_blocking=True)
|
||||||
|
|
||||||
self.tokenshape_cpu[
|
self.query_start_loc[num_reqs + 1:].fill_(-1)
|
||||||
0] = total_num_scheduled_tokens # Tokens to process
|
self.positions[total_num_scheduled_tokens:].fill_(0)
|
||||||
self.tokenshape_cpu[1] = num_reqs # Number of requests
|
self.slot_mapping[total_num_scheduled_tokens:].fill_(-1)
|
||||||
self.tokenshape_cpu[
|
|
||||||
2] = max_num_scheduled_tokens # Maximum query length
|
|
||||||
self.tokenshape_cpu[3] = max_seq_len # Maximum sequence length
|
|
||||||
self.tokenshape.copy_(self.tokenshape_cpu, non_blocking=True)
|
|
||||||
|
|
||||||
# Prepare for cascade attention if needed.
|
# Prepare for cascade attention if needed.
|
||||||
common_prefix_len = (scheduler_output.num_common_prefix_blocks *
|
common_prefix_len = (scheduler_output.num_common_prefix_blocks *
|
||||||
@ -561,7 +551,6 @@ class GPUModelRunner:
|
|||||||
block_table=(
|
block_table=(
|
||||||
self.input_batch.block_table.get_device_tensor()[:num_reqs]),
|
self.input_batch.block_table.get_device_tensor()[:num_reqs]),
|
||||||
slot_mapping=self.slot_mapping,
|
slot_mapping=self.slot_mapping,
|
||||||
tokenshape=self.tokenshape,
|
|
||||||
# Cascade stuff
|
# Cascade stuff
|
||||||
use_cascade=use_cascade,
|
use_cascade=use_cascade,
|
||||||
common_prefix_len=common_prefix_len,
|
common_prefix_len=common_prefix_len,
|
||||||
@ -926,7 +915,6 @@ class GPUModelRunner:
|
|||||||
block_table=(
|
block_table=(
|
||||||
self.input_batch.block_table.get_device_tensor()[:num_reqs]),
|
self.input_batch.block_table.get_device_tensor()[:num_reqs]),
|
||||||
slot_mapping=self.slot_mapping,
|
slot_mapping=self.slot_mapping,
|
||||||
tokenshape=self.tokenshape,
|
|
||||||
# Cascade stuff. Non-piecewise CUDA graphs NYI
|
# Cascade stuff. Non-piecewise CUDA graphs NYI
|
||||||
use_cascade=False,
|
use_cascade=False,
|
||||||
common_prefix_len=0,
|
common_prefix_len=0,
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user