diff --git a/benchmarks/kernels/benchmark_paged_attention.py b/benchmarks/kernels/benchmark_paged_attention.py index 5c3650fa72d17..ca7967c1ab0d2 100644 --- a/benchmarks/kernels/benchmark_paged_attention.py +++ b/benchmarks/kernels/benchmark_paged_attention.py @@ -16,7 +16,7 @@ PARTITION_SIZE = 512 def main( version: str, num_seqs: int, - context_len: int, + seq_len: int, num_query_heads: int, num_kv_heads: int, head_size: int, @@ -48,12 +48,12 @@ def main( dtype=torch.float, device=device) - context_lens = [context_len for _ in range(num_seqs)] - max_context_len = max(context_lens) - context_lens = torch.tensor(context_lens, dtype=torch.int, device=device) + seq_lens = [seq_len for _ in range(num_seqs)] + max_seq_len = max(seq_lens) + seq_lens = torch.tensor(seq_lens, dtype=torch.int, device=device) # Create the block tables. - max_num_blocks_per_seq = (max_context_len + block_size - 1) // block_size + max_num_blocks_per_seq = (max_seq_len + block_size - 1) // block_size block_tables = [] for _ in range(num_seqs): block_table = [ @@ -77,8 +77,7 @@ def main( # Prepare for the paged attention kernel. output = torch.empty_like(query) if version == "v2": - num_partitions = ((max_context_len + PARTITION_SIZE - 1) // - PARTITION_SIZE) + num_partitions = ((max_seq_len + PARTITION_SIZE - 1) // PARTITION_SIZE) tmp_output = torch.empty( size=(num_seqs, num_query_heads, num_partitions, head_size), dtype=output.dtype, @@ -110,9 +109,9 @@ def main( num_kv_heads, scale, block_tables, - context_lens, + seq_lens, block_size, - max_context_len, + max_seq_len, alibi_slopes, kv_cache_dtype, kv_scale, @@ -129,9 +128,9 @@ def main( num_kv_heads, scale, block_tables, - context_lens, + seq_lens, block_size, - max_context_len, + max_seq_len, alibi_slopes, kv_cache_dtype, kv_scale, @@ -166,7 +165,7 @@ if __name__ == '__main__': choices=["v1", "v2"], default="v2") parser.add_argument("--batch-size", type=int, default=8) - parser.add_argument("--context-len", type=int, default=4096) + parser.add_argument("--seq_len", type=int, default=4096) parser.add_argument("--num-query-heads", type=int, default=64) parser.add_argument("--num-kv-heads", type=int, default=8) parser.add_argument("--head-size", @@ -199,7 +198,7 @@ if __name__ == '__main__': main( version=args.version, num_seqs=args.batch_size, - context_len=args.context_len, + seq_len=args.seq_len, num_query_heads=args.num_query_heads, num_kv_heads=args.num_kv_heads, head_size=args.head_size, diff --git a/csrc/attention/attention_kernels.cu b/csrc/attention/attention_kernels.cu index f3a5bbfd3098d..8b1b5e098015f 100644 --- a/csrc/attention/attention_kernels.cu +++ b/csrc/attention/attention_kernels.cu @@ -104,7 +104,7 @@ __device__ void paged_attention_kernel( const int num_kv_heads, // [num_heads] const float scale, const int* __restrict__ block_tables, // [num_seqs, max_num_blocks_per_seq] - const int* __restrict__ context_lens, // [num_seqs] + const int* __restrict__ seq_lens, // [num_seqs] const int max_num_blocks_per_seq, const float* __restrict__ alibi_slopes, // [num_heads] const int q_stride, @@ -115,23 +115,23 @@ __device__ void paged_attention_kernel( const int partition_idx = blockIdx.z; const int max_num_partitions = gridDim.z; constexpr bool USE_PARTITIONING = PARTITION_SIZE > 0; - const int context_len = context_lens[seq_idx]; - if (USE_PARTITIONING && partition_idx * PARTITION_SIZE >= context_len) { + const int seq_len = seq_lens[seq_idx]; + if (USE_PARTITIONING && partition_idx * PARTITION_SIZE >= seq_len) { // No work to do. 
Terminate the thread block. return; } - const int num_context_blocks = DIVIDE_ROUND_UP(context_len, BLOCK_SIZE); - const int num_blocks_per_partition = USE_PARTITIONING ? PARTITION_SIZE / BLOCK_SIZE : num_context_blocks; + const int num_seq_blocks = DIVIDE_ROUND_UP(seq_len, BLOCK_SIZE); + const int num_blocks_per_partition = USE_PARTITIONING ? PARTITION_SIZE / BLOCK_SIZE : num_seq_blocks; // [start_block_idx, end_block_idx) is the range of blocks to process. const int start_block_idx = USE_PARTITIONING ? partition_idx * num_blocks_per_partition : 0; - const int end_block_idx = MIN(start_block_idx + num_blocks_per_partition, num_context_blocks); + const int end_block_idx = MIN(start_block_idx + num_blocks_per_partition, num_seq_blocks); const int num_blocks = end_block_idx - start_block_idx; // [start_token_idx, end_token_idx) is the range of tokens to process. const int start_token_idx = start_block_idx * BLOCK_SIZE; - const int end_token_idx = MIN(start_token_idx + num_blocks * BLOCK_SIZE, context_len); + const int end_token_idx = MIN(start_token_idx + num_blocks * BLOCK_SIZE, seq_len); const int num_tokens = end_token_idx - start_token_idx; constexpr int THREAD_GROUP_SIZE = MAX(WARP_SIZE / BLOCK_SIZE, 1); @@ -245,12 +245,12 @@ __device__ void paged_attention_kernel( // This includes a reduction across the threads in the same thread group. float qk = scale * Qk_dot::dot(q_vecs[thread_group_offset], k_vecs); // Add the ALiBi bias if slopes are given. - qk += (alibi_slope != 0) ? alibi_slope * (token_idx - context_len + 1) : 0; + qk += (alibi_slope != 0) ? alibi_slope * (token_idx - seq_len + 1) : 0; if (thread_group_offset == 0) { // Store the partial reductions to shared memory. // NOTE(woosuk): It is required to zero out the masked logits. - const bool mask = token_idx >= context_len; + const bool mask = token_idx >= seq_len; logits[token_idx - start_token_idx] = mask ? 0.f : qk; // Update the max value. qk_max = mask ? qk_max : fmaxf(qk_max, qk); @@ -364,14 +364,14 @@ __device__ void paged_attention_kernel( } else { v_vec = *reinterpret_cast(v_ptr + offset); } - if (block_idx == num_context_blocks - 1) { + if (block_idx == num_seq_blocks - 1) { // NOTE(woosuk): When v_vec contains the tokens that are out of the context, // we should explicitly zero out the values since they may contain NaNs. // See https://github.com/vllm-project/vllm/issues/641#issuecomment-1682544472 scalar_t* v_vec_ptr = reinterpret_cast(&v_vec); #pragma unroll for (int j = 0; j < V_VEC_SIZE; j++) { - v_vec_ptr[j] = token_idx + j < context_len ? v_vec_ptr[j] : zero_value; + v_vec_ptr[j] = token_idx + j < seq_len ? 
v_vec_ptr[j] : zero_value; } } accs[i] += dot(logits_vec, v_vec); @@ -457,7 +457,7 @@ __global__ void paged_attention_v1_kernel( const int num_kv_heads, // [num_heads] const float scale, const int* __restrict__ block_tables, // [num_seqs, max_num_blocks_per_seq] - const int* __restrict__ context_lens, // [num_seqs] + const int* __restrict__ seq_lens, // [num_seqs] const int max_num_blocks_per_seq, const float* __restrict__ alibi_slopes, // [num_heads] const int q_stride, @@ -466,7 +466,7 @@ __global__ void paged_attention_v1_kernel( const float kv_scale) { paged_attention_kernel( /* exp_sums */ nullptr, /* max_logits */ nullptr, - out, q, k_cache, v_cache, num_kv_heads, scale, block_tables, context_lens, + out, q, k_cache, v_cache, num_kv_heads, scale, block_tables, seq_lens, max_num_blocks_per_seq, alibi_slopes, q_stride, kv_block_stride, kv_head_stride, kv_scale); } @@ -489,7 +489,7 @@ __global__ void paged_attention_v2_kernel( const int num_kv_heads, // [num_heads] const float scale, const int* __restrict__ block_tables, // [num_seqs, max_num_blocks_per_seq] - const int* __restrict__ context_lens, // [num_seqs] + const int* __restrict__ seq_lens, // [num_seqs] const int max_num_blocks_per_seq, const float* __restrict__ alibi_slopes, // [num_heads] const int q_stride, @@ -498,7 +498,7 @@ __global__ void paged_attention_v2_kernel( const float kv_scale) { paged_attention_kernel( exp_sums, max_logits, tmp_out, q, k_cache, v_cache, num_kv_heads, scale, - block_tables, context_lens, max_num_blocks_per_seq, alibi_slopes, + block_tables, seq_lens, max_num_blocks_per_seq, alibi_slopes, q_stride, kv_block_stride, kv_head_stride, kv_scale); } @@ -513,13 +513,13 @@ __global__ void paged_attention_v2_reduce_kernel( const float* __restrict__ exp_sums, // [num_seqs, num_heads, max_num_partitions] const float* __restrict__ max_logits, // [num_seqs, num_heads, max_num_partitions] const scalar_t* __restrict__ tmp_out, // [num_seqs, num_heads, max_num_partitions, head_size] - const int* __restrict__ context_lens, // [num_seqs] + const int* __restrict__ seq_lens, // [num_seqs] const int max_num_partitions) { const int num_heads = gridDim.x; const int head_idx = blockIdx.x; const int seq_idx = blockIdx.y; - const int context_len = context_lens[seq_idx]; - const int num_partitions = DIVIDE_ROUND_UP(context_len, PARTITION_SIZE); + const int seq_len = seq_lens[seq_idx]; + const int num_partitions = DIVIDE_ROUND_UP(seq_len, PARTITION_SIZE); if (num_partitions == 1) { // No need to reduce. Only copy tmp_out to out. 
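
The hunks above only rename `context_len` to `seq_len` inside `paged_attention_kernel`; the partitioning arithmetic itself is unchanged. As a reading aid, here is a minimal Python sketch of that index math (the helper names are mine; `BLOCK_SIZE = 16` and `PARTITION_SIZE = 512` mirror values that appear elsewhere in the patch but are only examples here):

```python
BLOCK_SIZE = 16
PARTITION_SIZE = 512

def divide_round_up(a: int, b: int) -> int:
    return (a + b - 1) // b

def partition_ranges(seq_len: int, partition_idx: int):
    """Block and token ranges handled by one (seq, head, partition) thread block."""
    num_seq_blocks = divide_round_up(seq_len, BLOCK_SIZE)
    blocks_per_partition = PARTITION_SIZE // BLOCK_SIZE
    start_block = partition_idx * blocks_per_partition
    end_block = min(start_block + blocks_per_partition, num_seq_blocks)
    start_token = start_block * BLOCK_SIZE
    end_token = min(start_token + (end_block - start_block) * BLOCK_SIZE, seq_len)
    return (start_block, end_block), (start_token, end_token)

def alibi_bias(alibi_slope: float, token_idx: int, seq_len: int) -> float:
    # Mirrors qk += alibi_slope * (token_idx - seq_len + 1): the newest token
    # gets bias 0, earlier tokens get increasingly negative bias.
    return alibi_slope * (token_idx - seq_len + 1)

# Example: a 1000-token sequence is covered by two partitions in the v2 kernel.
assert partition_ranges(1000, 0) == ((0, 32), (0, 512))
assert partition_ranges(1000, 1) == ((32, 63), (512, 1000))
```

The same ceiling division also drives `max_num_blocks_per_seq` in the benchmark script and `num_partitions` in the v2 launcher above.
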
scalar_t* out_ptr = out + seq_idx * num_heads * HEAD_SIZE + head_idx * HEAD_SIZE; @@ -616,7 +616,7 @@ __global__ void paged_attention_v2_reduce_kernel( num_kv_heads, \ scale, \ block_tables_ptr, \ - context_lens_ptr, \ + seq_lens_ptr, \ max_num_blocks_per_seq, \ alibi_slopes_ptr, \ q_stride, \ @@ -639,8 +639,8 @@ void paged_attention_v1_launcher( int num_kv_heads, float scale, torch::Tensor& block_tables, - torch::Tensor& context_lens, - int max_context_len, + torch::Tensor& seq_lens, + int max_seq_len, const c10::optional& alibi_slopes, float kv_scale) { int num_seqs = query.size(0); @@ -664,11 +664,11 @@ void paged_attention_v1_launcher( CACHE_T* key_cache_ptr = reinterpret_cast(key_cache.data_ptr()); CACHE_T* value_cache_ptr = reinterpret_cast(value_cache.data_ptr()); int* block_tables_ptr = block_tables.data_ptr(); - int* context_lens_ptr = context_lens.data_ptr(); + int* seq_lens_ptr = seq_lens.data_ptr(); constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE; - int padded_max_context_len = DIVIDE_ROUND_UP(max_context_len, BLOCK_SIZE) * BLOCK_SIZE; - int logits_size = padded_max_context_len * sizeof(float); + int padded_max_seq_len = DIVIDE_ROUND_UP(max_seq_len, BLOCK_SIZE) * BLOCK_SIZE; + int logits_size = padded_max_seq_len * sizeof(float); int outputs_size = (NUM_WARPS / 2) * head_size * sizeof(float); // Python-side check in vllm.worker.worker._check_if_can_support_max_seq_len // Keep that in sync with the logic here! @@ -715,8 +715,8 @@ void paged_attention_v1_launcher( num_kv_heads, \ scale, \ block_tables, \ - context_lens, \ - max_context_len, \ + seq_lens, \ + max_seq_len, \ alibi_slopes, \ kv_scale); @@ -746,9 +746,9 @@ void paged_attention_v1( int num_kv_heads, // [num_heads] float scale, torch::Tensor& block_tables, // [num_seqs, max_num_blocks_per_seq] - torch::Tensor& context_lens, // [num_seqs] + torch::Tensor& seq_lens, // [num_seqs] int block_size, - int max_context_len, + int max_seq_len, const c10::optional& alibi_slopes, const std::string& kv_cache_dtype, float kv_scale) { @@ -790,7 +790,7 @@ void paged_attention_v1( num_kv_heads, \ scale, \ block_tables_ptr, \ - context_lens_ptr, \ + seq_lens_ptr, \ max_num_blocks_per_seq, \ alibi_slopes_ptr, \ q_stride, \ @@ -803,7 +803,7 @@ void paged_attention_v1( exp_sums_ptr, \ max_logits_ptr, \ tmp_out_ptr, \ - context_lens_ptr, \ + seq_lens_ptr, \ max_num_partitions); template< @@ -824,8 +824,8 @@ void paged_attention_v2_launcher( int num_kv_heads, float scale, torch::Tensor& block_tables, - torch::Tensor& context_lens, - int max_context_len, + torch::Tensor& seq_lens, + int max_seq_len, const c10::optional& alibi_slopes, float kv_scale) { int num_seqs = query.size(0); @@ -852,10 +852,10 @@ void paged_attention_v2_launcher( CACHE_T* key_cache_ptr = reinterpret_cast(key_cache.data_ptr()); CACHE_T* value_cache_ptr = reinterpret_cast(value_cache.data_ptr()); int* block_tables_ptr = block_tables.data_ptr(); - int* context_lens_ptr = context_lens.data_ptr(); + int* seq_lens_ptr = seq_lens.data_ptr(); constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE; - int max_num_partitions = DIVIDE_ROUND_UP(max_context_len, PARTITION_SIZE); + int max_num_partitions = DIVIDE_ROUND_UP(max_seq_len, PARTITION_SIZE); int logits_size = PARTITION_SIZE * sizeof(float); int outputs_size = (NUM_WARPS / 2) * head_size * sizeof(float); @@ -909,8 +909,8 @@ void paged_attention_v2_launcher( num_kv_heads, \ scale, \ block_tables, \ - context_lens, \ - max_context_len, \ + seq_lens, \ + max_seq_len, \ alibi_slopes, \ kv_scale); @@ -943,9 +943,9 @@ void 
paged_attention_v2( int num_kv_heads, // [num_heads] float scale, torch::Tensor& block_tables, // [num_seqs, max_num_blocks_per_seq] - torch::Tensor& context_lens, // [num_seqs] + torch::Tensor& seq_lens, // [num_seqs] int block_size, - int max_context_len, + int max_seq_len, const c10::optional& alibi_slopes, const std::string& kv_cache_dtype, float kv_scale) { diff --git a/csrc/cpu/attention.cpp b/csrc/cpu/attention.cpp index 365bbd5e23728..c1d765be05598 100644 --- a/csrc/cpu/attention.cpp +++ b/csrc/cpu/attention.cpp @@ -70,11 +70,11 @@ template FORCE_INLINE std::pair reduceSoftmaxAlibi(T *data, const int size, const int capacity, const float alibi_slope, const int start_index, - const int context_len) { - data[0] += alibi_slope * (start_index - context_len + 1); + const int seq_len) { + data[0] += alibi_slope * (start_index - seq_len + 1); T max = data[0]; for (int i = 1; i < size; ++i) { - T qk = data[i] + alibi_slope * (start_index + i - context_len + 1); + T qk = data[i] + alibi_slope * (start_index + i - seq_len + 1); data[i] = qk; max = max >= qk ? max : qk; } @@ -225,7 +225,7 @@ struct paged_attention_v1_impl { const int num_kv_heads, const float scale, const int *__restrict__ block_tables, // [num_seqs, max_num_blocks_per_seq] - const int *__restrict__ context_lens, // [num_seqs] + const int *__restrict__ seq_lens, // [num_seqs] const int max_num_blocks_per_seq, const float *__restrict__ alibi_slopes, // [num_heads] const int q_stride, const int kv_block_stride, const int kv_head_stride, @@ -235,32 +235,32 @@ struct paged_attention_v1_impl { static_assert(BLOCK_SIZE == 16); - int max_context_len = max_num_blocks_per_seq * BLOCK_SIZE; - int max_context_len_padded = (max_context_len + 15) & 0xFFFFFFF0; - TORCH_CHECK((max_context_len_padded * sizeof(float)) % 64 == 0); + int max_seq_len = max_num_blocks_per_seq * BLOCK_SIZE; + int max_seq_len_padded = (max_seq_len + 15) & 0xFFFFFFF0; + TORCH_CHECK((max_seq_len_padded * sizeof(float)) % 64 == 0); const int parallel_work_item_num = omp_get_max_threads(); size_t logits_bytes = - parallel_work_item_num * max_context_len_padded * sizeof(float); + parallel_work_item_num * max_seq_len_padded * sizeof(float); float *logits = (float *)std::aligned_alloc( 64, logits_bytes); // Cacheline alignment for each context token. 
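
On the CPU path, `reduceSoftmaxAlibi` now receives `seq_len` in place of `context_len`, but the bias it applies is the same. A rough Python sketch of what that reduction computes is below; only the bias-and-max pass is visible in the hunk above, so the exponentiation/normalization step is assumed from the function's name and its `(max, sum)` return pair, and the `capacity` padding handling is omitted:

```python
import math
from typing import List, Tuple

def reduce_softmax_alibi(data: List[float], size: int, alibi_slope: float,
                         start_index: int, seq_len: int) -> Tuple[float, float]:
    """Sketch only: bias each logit by its distance from the sequence's last
    token, then softmax-normalize in place (normalization assumed)."""
    for i in range(size):
        data[i] += alibi_slope * (start_index + i - seq_len + 1)
    max_val = max(data[:size])
    exp_sum = 0.0
    for i in range(size):
        data[i] = math.exp(data[i] - max_val)
        exp_sum += data[i]
    for i in range(size):
        data[i] /= exp_sum
    return max_val, exp_sum
```
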
- // [parallel_work_item_num, max_context_len_padded] + // [parallel_work_item_num, max_seq_len_padded] #pragma omp parallel for collapse(2) schedule(dynamic, 1) for (int seq_idx = 0; seq_idx < num_seqs; ++seq_idx) { for (int head_idx = 0; head_idx < num_heads; ++head_idx) { - int context_len = context_lens[seq_idx]; + int seq_len = seq_lens[seq_idx]; const int *seq_block_table = block_tables + max_num_blocks_per_seq * seq_idx; - const int block_num = (context_len + BLOCK_SIZE - 1) / BLOCK_SIZE; + const int block_num = (seq_len + BLOCK_SIZE - 1) / BLOCK_SIZE; const int64_t kv_head_idx = head_idx / num_queries_per_kv; const scalar_t *__restrict__ q_vec_ptr = q + seq_idx * q_stride + head_idx * HEAD_SIZE; const int last_block_token_num = - context_len - (block_num - 1) * BLOCK_SIZE; + seq_len - (block_num - 1) * BLOCK_SIZE; float *__restrict__ thread_block_logits = - logits + omp_get_thread_num() * max_context_len_padded; + logits + omp_get_thread_num() * max_seq_len_padded; // Compute logits for (int block_idx = 0; block_idx < block_num; ++block_idx) { @@ -278,11 +278,11 @@ struct paged_attention_v1_impl { // Compute softmax if (alibi_slopes) { - reduceSoftmaxAlibi(thread_block_logits, context_len, + reduceSoftmaxAlibi(thread_block_logits, seq_len, block_num * BLOCK_SIZE, alibi_slopes[head_idx], 0, - context_len); + seq_len); } else { - reduceSoftmax(thread_block_logits, context_len, + reduceSoftmax(thread_block_logits, seq_len, block_num * BLOCK_SIZE); } @@ -340,7 +340,7 @@ struct paged_attention_v1_impl { #define LAUNCH_V1_ATTENTION_KERNEL(T, HEAD_SIZE, BLOCK_SIZE) \ paged_attention_v1_impl::call( \ out_ptr, query_ptr, key_cache_ptr, value_cache_ptr, num_kv_heads, scale, \ - block_tables_ptr, context_lens_ptr, max_num_blocks_per_seq, \ + block_tables_ptr, seq_lens_ptr, max_num_blocks_per_seq, \ alibi_slopes_ptr, q_stride, kv_block_stride, kv_head_stride, num_seqs, \ num_heads); @@ -348,8 +348,8 @@ template void paged_attention_v1_impl_launcher( torch::Tensor &out, torch::Tensor &query, torch::Tensor &key_cache, torch::Tensor &value_cache, int num_kv_heads, float scale, - torch::Tensor &block_tables, torch::Tensor &context_lens, - int max_context_len, const c10::optional &alibi_slopes) { + torch::Tensor &block_tables, torch::Tensor &seq_lens, + int max_seq_len, const c10::optional &alibi_slopes) { int num_seqs = query.size(0); int num_heads = query.size(1); int head_size = query.size(2); @@ -369,7 +369,7 @@ void paged_attention_v1_impl_launcher( T *key_cache_ptr = reinterpret_cast(key_cache.data_ptr()); T *value_cache_ptr = reinterpret_cast(value_cache.data_ptr()); int *block_tables_ptr = block_tables.data_ptr(); - int *context_lens_ptr = context_lens.data_ptr(); + int *seq_lens_ptr = seq_lens.data_ptr(); switch (head_size) { case 64: @@ -399,7 +399,7 @@ void paged_attention_v1_impl_launcher( #define CALL_V1_KERNEL_LAUNCHER(T, BLOCK_SIZE) \ paged_attention_v1_impl_launcher( \ out, query, key_cache, value_cache, num_kv_heads, scale, block_tables, \ - context_lens, max_context_len, alibi_slopes); + seq_lens, max_seq_len, alibi_slopes); #define CALL_V1_KERNEL_LAUNCHER_BLOCK_SIZE(T) \ switch (block_size) { \ @@ -416,8 +416,8 @@ void paged_attention_v1(torch::Tensor &out, torch::Tensor &query, torch::Tensor &key_cache, torch::Tensor &value_cache, int num_kv_heads, float scale, torch::Tensor &block_tables, - torch::Tensor &context_lens, int block_size, - int max_context_len, + torch::Tensor &seq_lens, int block_size, + int max_seq_len, const c10::optional &alibi_slopes, const std::string 
&kv_cache_dtype, float kv_scale) { TORCH_CHECK(kv_scale == 1.0f); @@ -448,7 +448,7 @@ struct paged_attention_v2_impl { const int num_kv_heads, const float scale, const int *__restrict__ block_tables, // [num_seqs, max_num_blocks_per_seq] - const int *__restrict__ context_lens, // [num_seqs] + const int *__restrict__ seq_lens, // [num_seqs] const int max_num_blocks_per_seq, const float *__restrict__ alibi_slopes, // [num_heads] const int q_stride, const int kv_block_stride, const int kv_head_stride, @@ -465,22 +465,22 @@ struct paged_attention_v2_impl { for (int partition_idx = 0; partition_idx < max_num_partitions; ++partition_idx) { for (int head_idx = 0; head_idx < num_heads; ++head_idx) { - const int context_len = context_lens[seq_idx]; + const int seq_len = seq_lens[seq_idx]; const int start_token_idx = partition_idx * PARTITION_SIZE; - if (start_token_idx >= context_len) + if (start_token_idx >= seq_len) continue; const int partition_num = - (context_len + PARTITION_SIZE - 1) / PARTITION_SIZE; + (seq_len + PARTITION_SIZE - 1) / PARTITION_SIZE; const bool no_reduce = (partition_num == 1); - const int context_token_num = - (std::min(context_len, start_token_idx + PARTITION_SIZE) - + const int token_num = + (std::min(seq_len, start_token_idx + PARTITION_SIZE) - start_token_idx); const int block_num = - (context_token_num + BLOCK_SIZE - 1) / BLOCK_SIZE; + (token_num + BLOCK_SIZE - 1) / BLOCK_SIZE; const int last_block_token_num = - context_token_num - (block_num - 1) * BLOCK_SIZE; + token_num - (block_num - 1) * BLOCK_SIZE; const int *seq_block_table = block_tables + max_num_blocks_per_seq * seq_idx + start_token_idx / BLOCK_SIZE; @@ -507,10 +507,10 @@ struct paged_attention_v2_impl { std::pair max_and_sum; if (alibi_slopes) { max_and_sum = reduceSoftmaxAlibi( - logits, context_token_num, block_num * BLOCK_SIZE, - alibi_slopes[head_idx], start_token_idx, context_len); + logits, token_num, block_num * BLOCK_SIZE, + alibi_slopes[head_idx], start_token_idx, seq_len); } else { - max_and_sum = reduceSoftmax(logits, context_token_num, + max_and_sum = reduceSoftmax(logits, token_num, block_num * BLOCK_SIZE); } @@ -583,9 +583,9 @@ struct paged_attention_v2_impl { #pragma omp parallel for collapse(2) schedule(static, 1) for (int seq_idx = 0; seq_idx < num_seqs; ++seq_idx) { for (int head_idx = 0; head_idx < num_heads; ++head_idx) { - const int context_len = context_lens[seq_idx]; + const int seq_len = seq_lens[seq_idx]; const int partition_num = - (context_len + PARTITION_SIZE - 1) / PARTITION_SIZE; + (seq_len + PARTITION_SIZE - 1) / PARTITION_SIZE; if (partition_num == 1) continue; @@ -612,9 +612,9 @@ struct paged_attention_v2_impl { for (int seq_idx = 0; seq_idx < num_seqs; ++seq_idx) { for (int head_idx = 0; head_idx < num_heads; ++head_idx) { for (int group_idx = 0; group_idx < head_group_num; ++group_idx) { - const int context_len = context_lens[seq_idx]; + const int seq_len = seq_lens[seq_idx]; const int partition_num = - (context_len + PARTITION_SIZE - 1) / PARTITION_SIZE; + (seq_len + PARTITION_SIZE - 1) / PARTITION_SIZE; if (partition_num == 1) continue; @@ -649,7 +649,7 @@ struct paged_attention_v2_impl { paged_attention_v2_impl::call( \ out_ptr, exp_sums_ptr, max_logits_ptr, tmp_out_ptr, query_ptr, \ key_cache_ptr, value_cache_ptr, num_kv_heads, scale, block_tables_ptr, \ - context_lens_ptr, max_num_blocks_per_seq, alibi_slopes_ptr, q_stride, \ + seq_lens_ptr, max_num_blocks_per_seq, alibi_slopes_ptr, q_stride, \ kv_block_stride, kv_head_stride, num_seqs, num_heads, \ 
max_num_partitions); @@ -658,8 +658,8 @@ void paged_attention_v2_impl_launcher( torch::Tensor &out, torch::Tensor &exp_sums, torch::Tensor &max_logits, torch::Tensor &tmp_out, torch::Tensor &query, torch::Tensor &key_cache, torch::Tensor &value_cache, int num_kv_heads, float scale, - torch::Tensor &block_tables, torch::Tensor &context_lens, int block_size, - int max_context_len, const c10::optional &alibi_slopes) { + torch::Tensor &block_tables, torch::Tensor &seq_lens, int block_size, + int max_seq_len, const c10::optional &alibi_slopes) { int num_seqs = query.size(0); int num_heads = query.size(1); int head_size = query.size(2); @@ -683,7 +683,7 @@ void paged_attention_v2_impl_launcher( T *key_cache_ptr = reinterpret_cast(key_cache.data_ptr()); T *value_cache_ptr = reinterpret_cast(value_cache.data_ptr()); int *block_tables_ptr = block_tables.data_ptr(); - int *context_lens_ptr = context_lens.data_ptr(); + int *seq_lens_ptr = seq_lens.data_ptr(); switch (head_size) { case 64: @@ -713,8 +713,8 @@ void paged_attention_v2_impl_launcher( #define CALL_V2_KERNEL_LAUNCHER(T, BLOCK_SIZE) \ paged_attention_v2_impl_launcher( \ out, exp_sums, max_logits, tmp_out, query, key_cache, value_cache, \ - num_kv_heads, scale, block_tables, context_lens, block_size, \ - max_context_len, alibi_slopes); + num_kv_heads, scale, block_tables, seq_lens, block_size, \ + max_seq_len, alibi_slopes); #define CALL_V2_KERNEL_LAUNCHER_BLOCK_SIZE(T) \ switch (block_size) { \ @@ -732,8 +732,8 @@ void paged_attention_v2(torch::Tensor &out, torch::Tensor &exp_sums, torch::Tensor &query, torch::Tensor &key_cache, torch::Tensor &value_cache, int num_kv_heads, float scale, torch::Tensor &block_tables, - torch::Tensor &context_lens, int block_size, - int max_context_len, + torch::Tensor &seq_lens, int block_size, + int max_seq_len, const c10::optional &alibi_slopes, const std::string &kv_cache_dtype, float kv_scale) { TORCH_CHECK(kv_scale == 1.0f); diff --git a/csrc/ops.h b/csrc/ops.h index 8ae052427052f..9541adcb3de88 100644 --- a/csrc/ops.h +++ b/csrc/ops.h @@ -10,9 +10,9 @@ void paged_attention_v1( int num_kv_heads, float scale, torch::Tensor& block_tables, - torch::Tensor& context_lens, + torch::Tensor& seq_lens, int block_size, - int max_context_len, + int max_seq_len, const c10::optional& alibi_slopes, const std::string& kv_cache_dtype, float kv_scale); @@ -28,9 +28,9 @@ void paged_attention_v2( int num_kv_heads, float scale, torch::Tensor& block_tables, - torch::Tensor& context_lens, + torch::Tensor& seq_lens, int block_size, - int max_context_len, + int max_seq_len, const c10::optional& alibi_slopes, const std::string& kv_cache_dtype, float kv_scale); diff --git a/tests/kernels/test_attention.py b/tests/kernels/test_attention.py index 9b1f3e30b6dca..84539205e0ae3 100644 --- a/tests/kernels/test_attention.py +++ b/tests/kernels/test_attention.py @@ -61,7 +61,7 @@ def ref_single_query_cached_kv_attention( key_cache: torch.Tensor, value_cache: torch.Tensor, block_tables: torch.Tensor, - context_lens: torch.Tensor, + seq_lens: torch.Tensor, scale: float, alibi_slopes: Optional[torch.Tensor], ) -> None: @@ -72,15 +72,15 @@ def ref_single_query_cached_kv_attention( num_seqs = query.shape[0] block_tables = block_tables.cpu().tolist() - context_lens = context_lens.cpu().tolist() + seq_lens = seq_lens.cpu().tolist() for i in range(num_seqs): q = query[i].unsqueeze(0) block_table = block_tables[i] - context_len = int(context_lens[i]) + seq_len = int(seq_lens[i]) keys = [] values = [] - for j in range(context_len): + for j in 
range(seq_len): block_number = int(block_table[j // block_size]) block_offset = j % block_size @@ -100,8 +100,8 @@ def ref_single_query_cached_kv_attention( alibi_bias = None if alibi_slopes is not None: # Create the ALiBi bias used in the paged attention kernel. - position_ids = torch.arange(context_len).int() - alibi_bias = (position_ids - context_len + 1).float() + position_ids = torch.arange(seq_len).int() + alibi_bias = (position_ids - seq_len + 1).float() alibi_bias = alibi_slopes.view(-1, 1, 1) * alibi_bias.view( 1, 1, -1) @@ -149,13 +149,13 @@ def test_paged_attention( if use_alibi: alibi_slopes = torch.randn(num_query_heads, dtype=torch.float) - context_lens = [random.randint(1, MAX_SEQ_LEN) for _ in range(num_seqs)] - context_lens[-1] = MAX_SEQ_LEN - max_context_len = max(context_lens) - context_lens = torch.tensor(context_lens, dtype=torch.int) + seq_lens = [random.randint(1, MAX_SEQ_LEN) for _ in range(num_seqs)] + seq_lens[-1] = MAX_SEQ_LEN + max_seq_len = max(seq_lens) + seq_lens = torch.tensor(seq_lens, dtype=torch.int) # Create the block tables. - max_num_blocks_per_seq = (max_context_len + block_size - 1) // block_size + max_num_blocks_per_seq = (max_seq_len + block_size - 1) // block_size block_tables = [] for _ in range(num_seqs): block_table = [ @@ -186,16 +186,15 @@ def test_paged_attention( num_kv_heads, scale, block_tables, - context_lens, + seq_lens, block_size, - max_context_len, + max_seq_len, alibi_slopes, kv_cache_dtype, kv_scale, ) elif version == "v2": - num_partitions = ((max_context_len + PARTITION_SIZE - 1) // - PARTITION_SIZE) + num_partitions = ((max_seq_len + PARTITION_SIZE - 1) // PARTITION_SIZE) assert PARTITION_SIZE % block_size == 0 num_seqs, num_heads, head_size = output.shape tmp_output = torch.empty( @@ -218,9 +217,9 @@ def test_paged_attention( num_kv_heads, scale, block_tables, - context_lens, + seq_lens, block_size, - max_context_len, + max_seq_len, alibi_slopes, kv_cache_dtype, kv_scale, @@ -255,7 +254,7 @@ def test_paged_attention( key_cache, value_cache, block_tables, - context_lens, + seq_lens, scale, alibi_slopes, ) diff --git a/tests/kernels/test_prefix_prefill.py b/tests/kernels/test_prefix_prefill.py index 8ab1167384c45..5a5987e2242fa 100644 --- a/tests/kernels/test_prefix_prefill.py +++ b/tests/kernels/test_prefix_prefill.py @@ -51,12 +51,12 @@ def test_contexted_kv_attention( cache_size = 640 block_size = 32 max_block_per_request = 64 - subquery_lens = [random.randint(16, MAX_SEQ_LEN) for _ in range(BS)] + query_lens = [random.randint(16, MAX_SEQ_LEN) for _ in range(BS)] ctx_lens = [random.randint(16, MAX_CTX_LEN) for _ in range(BS)] - seq_lens = [a + b for a, b in zip(subquery_lens, ctx_lens)] + seq_lens = [a + b for a, b in zip(query_lens, ctx_lens)] num_kv_heads = num_heads // num_queries_per_kv - num_tokens = sum(subquery_lens) + num_tokens = sum(query_lens) query = torch.empty(num_tokens, num_heads, head_size, dtype=dtype) query.uniform_(-1e-3, 1e-3) output = torch.empty(num_tokens, num_heads, head_size, dtype=dtype) @@ -75,15 +75,15 @@ def test_contexted_kv_attention( num_kv_heads, head_size, dtype=dtype) - k = torch.zeros(sum(subquery_lens), num_kv_heads, head_size, dtype=dtype) - v = torch.zeros(sum(subquery_lens), num_kv_heads, head_size, dtype=dtype) + k = torch.zeros(sum(query_lens), num_kv_heads, head_size, dtype=dtype) + v = torch.zeros(sum(query_lens), num_kv_heads, head_size, dtype=dtype) values = torch.arange(0, cache_size, dtype=torch.long) values = values[torch.randperm(cache_size)] block_table = values[:BS * 
max_block_per_request].view( BS, max_block_per_request) b_seq_len = torch.tensor(seq_lens, dtype=torch.long) b_ctx_len = torch.tensor(ctx_lens, dtype=torch.long) - b_start_loc = torch.cumsum(torch.tensor([0] + subquery_lens[:-1], + b_start_loc = torch.cumsum(torch.tensor([0] + query_lens[:-1], dtype=torch.long), dim=0) max_input_len = MAX_SEQ_LEN @@ -92,7 +92,7 @@ def test_contexted_kv_attention( dtype=torch.long), dim=0) for i in range(BS): - for j in range(subquery_lens[i]): + for j in range(query_lens[i]): k[b_start_loc[i] + j].copy_(key[b_seq_start_loc[i] + b_ctx_len[i] + j]) v[b_start_loc[i] + j].copy_(value[b_seq_start_loc[i] + @@ -178,7 +178,7 @@ def test_contexted_kv_attention( value = value.unsqueeze(0) attn_bias = BlockDiagonalCausalFromBottomRightMask.from_seqlens( - subquery_lens, seq_lens) + query_lens, seq_lens) if sliding_window > 0: attn_bias = attn_bias.make_local_attention_from_bottomright( sliding_window) diff --git a/tests/samplers/test_sampler.py b/tests/samplers/test_sampler.py index 7859f0b21812f..e4fea165a4d46 100644 --- a/tests/samplers/test_sampler.py +++ b/tests/samplers/test_sampler.py @@ -58,7 +58,7 @@ def _do_sample( device: str, ): seq_group_metadata_list = [] - prompt_lens = [] + seq_lens = [] for i in range(batch_size): seq_group_metadata_list.append( SequenceGroupMetadata( @@ -68,12 +68,12 @@ def _do_sample( sampling_params=sampling_params, block_tables={0: [1]}, )) - prompt_lens.append(seq_group_metadata_list[-1].seq_data[0].get_len()) + seq_lens.append(seq_group_metadata_list[-1].seq_data[0].get_len()) sampling_metadata = SamplingMetadata.prepare( seq_group_metadata_list, - prompt_lens, - subquery_lens=prompt_lens, + seq_lens, + query_lens=seq_lens, device=device, pin_memory=model_runner.pin_memory) return sampler(logits=input_tensor, sampling_metadata=sampling_metadata) @@ -421,7 +421,7 @@ def test_sampler_min_tokens_penalty(seed: int, device: str): "Invalid test case, need seq_group_metadata_list" batch_size = 0 - prompt_lens = [] + seq_lens = [] sampling_params_per_row = [] for sgm in seq_group_metadata_list: sampling_params = sgm.sampling_params @@ -431,7 +431,7 @@ def test_sampler_min_tokens_penalty(seed: int, device: str): # a prompt seq_group has only one sequence seq_data = next(iter(sgm.seq_data.values())) prompt_len = seq_data.get_prompt_len() - prompt_lens.append(prompt_len) + seq_lens.append(prompt_len) if sgm.sampling_params.prompt_logprobs: # with prompt_logprobs each token in the prompt has a row in @@ -451,8 +451,8 @@ def test_sampler_min_tokens_penalty(seed: int, device: str): _, fake_logits, sampler, model_runner = _prepare_test(batch_size) sampling_metadata = SamplingMetadata.prepare( seq_group_metadata_list, - prompt_lens=prompt_lens if prompt_lens else None, - subquery_lens=prompt_lens if prompt_lens else None, + seq_lens=seq_lens if seq_lens else None, + query_lens=seq_lens if seq_lens else None, device=device, pin_memory=model_runner.pin_memory) # the logits tensor is modified in-place by the sampler @@ -497,7 +497,7 @@ def test_sampler_mixed(seed: int, device: str): seq_group_metadata_list = [] expected_tokens: List[Optional[List[int]]] = [] - prompt_lens = [] + seq_lens = [] for i in range(batch_size): expected: Optional[List[int]] = None sampling_type = random.randint(0, 3) @@ -532,13 +532,13 @@ def test_sampler_mixed(seed: int, device: str): sampling_params=sampling_params, block_tables={0: [1]}, )) - prompt_lens.append(seq_group_metadata_list[-1].seq_data[0].get_len()) + 
seq_lens.append(seq_group_metadata_list[-1].seq_data[0].get_len()) def test_sampling(model_runner: ModelRunner): sampling_metadata = SamplingMetadata.prepare( seq_group_metadata_list, - prompt_lens, - subquery_lens=prompt_lens, + seq_lens, + query_lens=seq_lens, device=device, pin_memory=model_runner.pin_memory) sampler_output = sampler(logits=fake_logits, @@ -575,7 +575,7 @@ def test_sampler_mixed(seed: int, device: str): # Shuffle the batch and resample target_index = list(range(batch_size)) for list_to_shuffle in (target_index, seq_group_metadata_list, - expected_tokens, prompt_lens): + expected_tokens, seq_lens): random.Random(seed).shuffle(list_to_shuffle) target_index = torch.tensor(target_index) input_tensor.data = input_tensor.index_select(0, target_index) @@ -620,7 +620,7 @@ def test_sampler_top_k_top_p(seed: int, device: str): assert len(warpers) == 2 # top_p and top_k seq_group_metadata_list = [] - prompt_lens = [] + seq_lens = [] for i in range(batch_size): seq_group_metadata_list.append( SequenceGroupMetadata( @@ -634,12 +634,12 @@ def test_sampler_top_k_top_p(seed: int, device: str): ), block_tables={0: [1]}, )) - prompt_lens.append(seq_group_metadata_list[-1].seq_data[0].get_len()) + seq_lens.append(seq_group_metadata_list[-1].seq_data[0].get_len()) sampling_metadata = SamplingMetadata.prepare( seq_group_metadata_list, - prompt_lens, - subquery_lens=prompt_lens, + seq_lens, + query_lens=seq_lens, device=device, pin_memory=model_runner.pin_memory) diff --git a/tests/spec_decode/e2e/conftest.py b/tests/spec_decode/e2e/conftest.py index 0eb784a9c5ac5..492620cf6e2cf 100644 --- a/tests/spec_decode/e2e/conftest.py +++ b/tests/spec_decode/e2e/conftest.py @@ -45,7 +45,7 @@ class AsyncLLM: gpu_memory_utilization: float = 0.9, swap_space: int = 4, enforce_eager: bool = False, - max_context_len_to_capture: int = 8192, + max_seq_len_to_capture: int = 8192, disable_custom_all_reduce: bool = False, **kwargs, ) -> None: @@ -66,7 +66,7 @@ class AsyncLLM: gpu_memory_utilization=gpu_memory_utilization, swap_space=swap_space, enforce_eager=enforce_eager, - max_context_len_to_capture=max_context_len_to_capture, + max_seq_len_to_capture=max_seq_len_to_capture, engine_use_ray=True, disable_custom_all_reduce=disable_custom_all_reduce, **kwargs, diff --git a/tests/spec_decode/test_multi_step_worker.py b/tests/spec_decode/test_multi_step_worker.py index 98f2731de9aa3..cc0427633e688 100644 --- a/tests/spec_decode/test_multi_step_worker.py +++ b/tests/spec_decode/test_multi_step_worker.py @@ -34,7 +34,7 @@ def test_assert_enough_kv_space(num_steps: int): list(range(block_size * 2)), ] - final_seq_lens = [ + final_prompt_lens = [ len(prompt + output) + num_steps for prompt, output in zip(prompts, prev_output_tokens) ] @@ -43,7 +43,7 @@ def test_assert_enough_kv_space(num_steps: int): prompts, num_gpu_blocks, block_size, - final_seq_lens, + final_prompt_lens, continuations=prev_output_tokens) assert_enough_kv_space = MultiStepWorker._assert_enough_kv_space # pylint: disable=protected-access @@ -103,17 +103,21 @@ def test_same_output_for_single_step(): [6, 7, 8, 9, 10], ] - final_seq_lens = [len(prompt) + num_steps for prompt in prompts] + final_prompt_lens = [len(prompt) + num_steps for prompt in prompts] multi_step_execute_model_data = create_execute_model_data( seq_group_metadata_list=create_seq_group_metadata_from_prompts( - prompts, num_gpu_blocks, block_size, - final_seq_lens=final_seq_lens)) + prompts, + num_gpu_blocks, + block_size, + final_prompt_lens=final_prompt_lens)) 
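
In the sampler tests above, the same list is now passed to `SamplingMetadata.prepare` as both `seq_lens` and `query_lens`. That is only valid because every request in these tests is a fresh prompt; the general relationship behind the rename is sketched below (illustrative numbers, not code from the patch):

```python
# seq_len counts all tokens of the sequence so far, query_len counts only the
# tokens processed in this iteration, and context_len is what was already
# computed before it.
def split_lengths(computed_tokens: int, new_tokens: int):
    context_len = computed_tokens
    query_len = new_tokens
    seq_len = context_len + query_len
    return context_len, query_len, seq_len

# Fresh prefill of a 7-token prompt: nothing computed yet.
assert split_lengths(0, 7) == (0, 7, 7)   # query_len == seq_len
# One decode step on the same sequence afterwards.
assert split_lengths(7, 1) == (7, 1, 8)   # query_len == 1
```
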
single_step_execute_model_data = create_execute_model_data( seq_group_metadata_list=create_seq_group_metadata_from_prompts( - prompts, num_gpu_blocks, block_size, - final_seq_lens=final_seq_lens)) + prompts, + num_gpu_blocks, + block_size, + final_prompt_lens=final_prompt_lens)) zero_kv_cache(multi_step_worker.cache_engine) set_random_seed(seed) @@ -181,7 +185,7 @@ def test_same_output_for_multi_step(): random.randint(0, 1000) for _ in range(random.randint(10, 20)) ] for _ in range(10)] - final_seq_lens = [len(prompt) + num_steps for prompt in prompts] + final_prompt_lens = [len(prompt) + num_steps for prompt in prompts] rand_seeds = list(random.randint(0, 100) for _ in range(num_steps)) multi_step_worker.execute_model = patch_execute_model_with_seeds( @@ -195,7 +199,7 @@ def test_same_output_for_multi_step(): num_gpu_blocks, block_size, continuations=continuations, - final_seq_lens=final_seq_lens), ) + final_prompt_lens=final_prompt_lens), ) # Run multi-step. zero_kv_cache(multi_step_worker.cache_engine) @@ -217,7 +221,7 @@ def test_same_output_for_multi_step(): num_gpu_blocks, block_size, continuations=continuations, - final_seq_lens=final_seq_lens)) + final_prompt_lens=final_prompt_lens)) single_step_output.extend( worker.execute_model(**execute_model_data.to_dict(), )) diff --git a/tests/spec_decode/test_ngram_worker.py b/tests/spec_decode/test_ngram_worker.py index ee4135015713d..e7e2e87f599dd 100644 --- a/tests/spec_decode/test_ngram_worker.py +++ b/tests/spec_decode/test_ngram_worker.py @@ -43,11 +43,13 @@ def test_ngram_algo_correctness_for_single_no_match(): ] proposal_len = 5 - final_seq_lens = [len(prompt) + proposal_len for prompt in prompts] + final_prompt_lens = [len(prompt) + proposal_len for prompt in prompts] ngram_sampler_output_data = create_execute_model_data( seq_group_metadata_list=create_seq_group_metadata_from_prompts( - prompts, num_gpu_blocks, block_size, - final_seq_lens=final_seq_lens)) + prompts, + num_gpu_blocks, + block_size, + final_prompt_lens=final_prompt_lens)) proposals = proposer.get_proposals( **ngram_sampler_output_data.to_dict(), @@ -110,11 +112,13 @@ def test_ngram_algo_correctness_for_batches_not_match_all(): ] proposal_len = 5 - final_seq_lens = [len(prompt) + proposal_len for prompt in prompts] + final_prompt_lens = [len(prompt) + proposal_len for prompt in prompts] ngram_sampler_output_data = create_execute_model_data( seq_group_metadata_list=create_seq_group_metadata_from_prompts( - prompts, num_gpu_blocks, block_size, - final_seq_lens=final_seq_lens)) + prompts, + num_gpu_blocks, + block_size, + final_prompt_lens=final_prompt_lens)) proposals = proposer.get_proposals( **ngram_sampler_output_data.to_dict(), @@ -180,11 +184,13 @@ def test_ngram_algo_correctness_for_batches_match_all(): ] proposal_len = 5 - final_seq_lens = [len(prompt) + proposal_len for prompt in prompts] + final_prompt_lens = [len(prompt) + proposal_len for prompt in prompts] ngram_sampler_output_data = create_execute_model_data( seq_group_metadata_list=create_seq_group_metadata_from_prompts( - prompts, num_gpu_blocks, block_size, - final_seq_lens=final_seq_lens)) + prompts, + num_gpu_blocks, + block_size, + final_prompt_lens=final_prompt_lens)) proposals = proposer.get_proposals( **ngram_sampler_output_data.to_dict(), diff --git a/tests/spec_decode/utils.py b/tests/spec_decode/utils.py index 4f8295d25cf41..87c7d88a80f42 100644 --- a/tests/spec_decode/utils.py +++ b/tests/spec_decode/utils.py @@ -144,7 +144,7 @@ def create_seq_group_metadata_from_prompts( prompts: 
List[List[int]], num_gpu_blocks: int, block_size: int, - final_seq_lens: List[int], + final_prompt_lens: List[int], continuations: Optional[List[List[int]]] = None, seq_ids: Optional[List[int]] = None, ) -> List[SequenceGroupMetadata]: @@ -162,7 +162,7 @@ def create_seq_group_metadata_from_prompts( free_gpu_blocks.pop() for _ in range(round_up_to_next_block(final_len, block_size)) ] - for i, final_len in enumerate(final_seq_lens) + for i, final_len in enumerate(final_prompt_lens) } return [ @@ -251,13 +251,13 @@ def create_batch(batch_size, prev_output_tokens = [[ next(iterator) for _ in range(prev_output_token_len) ] for _ in range(batch_size)] - final_seq_lens = [ + final_prompt_lens = [ len(prompt) + len(prev_output_token) + k + 1 for prompt, prev_output_token in zip(prompts, prev_output_tokens) ] execute_model_data = create_execute_model_data( create_seq_group_metadata_from_prompts(prompts, num_gpu_blocks, - block_size, final_seq_lens, + block_size, final_prompt_lens, prev_output_tokens, seq_ids), ) return execute_model_data, prompts, prev_output_tokens diff --git a/tests/test_logits_processor.py b/tests/test_logits_processor.py index dbaeb4de18258..179e8d25a341b 100644 --- a/tests/test_logits_processor.py +++ b/tests/test_logits_processor.py @@ -70,7 +70,7 @@ def test_logits_processors(seed: int, device: str): return logits seq_group_metadata_list = [] - prompt_lens = [] + seq_lens = [] for i in range(batch_size): seq_group_metadata_list.append( SequenceGroupMetadata( @@ -81,12 +81,12 @@ def test_logits_processors(seed: int, device: str): logits_processors=[pick_ith]), block_tables={0: [1]}, )) - prompt_lens.append(seq_group_metadata_list[-1].seq_data[0].get_len()) + seq_lens.append(seq_group_metadata_list[-1].seq_data[0].get_len()) sampling_metadata = SamplingMetadata.prepare( seq_group_metadata_list, - prompt_lens, - subquery_lens=prompt_lens, + seq_lens, + query_lens=seq_lens, device=model_runner.device, pin_memory=model_runner.pin_memory) logits_processor_output = logits_processor( diff --git a/tests/worker/test_model_runner.py b/tests/worker/test_model_runner.py index 56fe6db589f18..e7975d0ef48b9 100644 --- a/tests/worker/test_model_runner.py +++ b/tests/worker/test_model_runner.py @@ -23,14 +23,14 @@ def test_prepare_prompt(batch_size): lora_config=None) model_runner.set_block_size(16) - prompt_lens = [] + seq_lens = [] seq_group_metadata_list = [] block_tables = {0: [1]} for i in range(batch_size): # make sure all tokens fit into one block - prompt_len = i % (model_runner.block_size - 1) + 1 - prompt_lens.append(prompt_len) - seq_data = SequenceData(list(range(prompt_len))) + seq_len = i % (model_runner.block_size - 1) + 1 + seq_lens.append(seq_len) + seq_data = SequenceData(list(range(seq_len))) seq_group_metadata = SequenceGroupMetadata( request_id=f"test_{i}", is_prompt=True, @@ -43,29 +43,29 @@ def test_prepare_prompt(batch_size): expected_selected_token_indices = [] selected_token_start_idx = 0 - for prompt_len in prompt_lens: + for seq_len in seq_lens: expected_selected_token_indices.append(selected_token_start_idx + - prompt_len - 1) - selected_token_start_idx += prompt_len - (input_tokens, input_positions, attn_metadata, return_prompt_lens, _, _, _, - _, _, - slot_mapping) = (model_runner._prepare_prompt(seq_group_metadata_list)) - assert return_prompt_lens == prompt_lens + seq_len - 1) + selected_token_start_idx += seq_len + (input_tokens, input_positions, attn_metadata, return_seq_lens, _, _, _, _, + _, slot_mapping) = 
(model_runner._prepare_prompt(seq_group_metadata_list)) + assert return_seq_lens == seq_lens assert len(slot_mapping) == len(input_tokens) # Verify input metadata is correct for prompts. device = model_runner.device assert attn_metadata.is_prompt is True - assert torch.allclose(attn_metadata.prompt_lens_tensor, - torch.tensor(prompt_lens, device=device)) - assert attn_metadata.prompt_lens == prompt_lens - assert attn_metadata.max_prompt_len == max(prompt_lens) + assert torch.allclose( + attn_metadata.seq_lens_tensor, + torch.tensor(seq_lens, device=device, dtype=torch.int)) + assert attn_metadata.seq_lens == seq_lens + assert attn_metadata.max_seq_len == max(seq_lens) # Test subquery start locs. start_idx = 0 start_loc = [start_idx] - for prompt_len in prompt_lens: - start_idx += prompt_len + for seq_len in seq_lens: + start_idx += seq_len start_loc.append(start_idx) assert torch.allclose( attn_metadata.subquery_start_loc, @@ -75,17 +75,16 @@ def test_prepare_prompt(batch_size): # equivalent to subquery_start_loc. start_idx = 0 seq_start_loc = [start_idx] - for prompt_len in prompt_lens: - start_idx += prompt_len + for seq_len in seq_lens: + start_idx += seq_len seq_start_loc.append(start_idx) assert torch.allclose( attn_metadata.seq_start_loc, torch.tensor(start_loc, dtype=torch.int32, device=device)) - assert attn_metadata.max_context_len is None assert torch.allclose( - attn_metadata.context_lens, - torch.zeros(attn_metadata.context_lens.shape[0], + attn_metadata.context_lens_tensor, + torch.zeros(attn_metadata.context_lens_tensor.shape[0], dtype=torch.int, device=device)) @@ -96,18 +95,18 @@ def test_prepare_prompt(batch_size): # Cuda graph should not be used for prerill. assert attn_metadata.use_cuda_graph is False - assert len(input_tokens) == sum(prompt_lens) - assert len(input_positions) == sum(prompt_lens) + assert len(input_tokens) == sum(seq_lens) + assert len(input_positions) == sum(seq_lens) torch.testing.assert_close(input_tokens, input_positions) sampling_metadata = SamplingMetadata.prepare( seq_group_metadata_list, - prompt_lens, - subquery_lens=prompt_lens, + seq_lens, + query_lens=seq_lens, device=model_runner.device, pin_memory=model_runner.pin_memory) - assert len(input_tokens) == sum(prompt_lens) - assert len(input_positions) == sum(prompt_lens) + assert len(input_tokens) == sum(seq_lens) + assert len(input_positions) == sum(seq_lens) actual = sampling_metadata.selected_token_indices expected = torch.tensor(expected_selected_token_indices, device=actual.device, @@ -146,13 +145,13 @@ def test_prepare_decode_cuda_graph(batch_size): lora_config=None) model_runner.set_block_size(16) - prompt_lens = [] + seq_lens = [] seq_group_metadata_list = [] for i in range(batch_size): # make sure all tokens fit into one block - prompt_len = i % (model_runner.block_size - 1) + 1 - prompt_lens.append(prompt_len) - seq_data = list(range(prompt_len)) + seq_len = i % (model_runner.block_size - 1) + 1 + seq_lens.append(seq_len) + seq_data = list(range(seq_len)) seq_data = SequenceData(seq_data) seq_group_metadata = SequenceGroupMetadata( request_id=f"test_{i}", @@ -172,14 +171,13 @@ def test_prepare_decode_cuda_graph(batch_size): # Verify input metadata is correct for prompts. 
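
The prefill test above checks cumulative start locations and the sampler's selected token indices against `seq_lens`. A small self-contained sketch of that bookkeeping, with made-up lengths, shows the shape of both quantities for a pure-prefill batch (where `subquery_start_loc` and `seq_start_loc` coincide):

```python
seq_lens = [3, 5, 2]

# Cumulative start locations: [0, 3, 8, 10].
start_loc = [0]
for seq_len in seq_lens:
    start_loc.append(start_loc[-1] + seq_len)

# The sampler reads only the last token of each prompt, so the selected
# indices are each start plus its length minus one: [2, 7, 9].
selected_token_indices = [start + seq_len - 1
                          for start, seq_len in zip(start_loc, seq_lens)]

assert start_loc == [0, 3, 8, 10]
assert selected_token_indices == [2, 7, 9]
```
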
device = model_runner.device assert attn_metadata.is_prompt is False - assert attn_metadata.prompt_lens is None - assert attn_metadata.max_prompt_len is None + assert attn_metadata.seq_lens is None assert attn_metadata.subquery_start_loc is None assert attn_metadata.seq_start_loc is None - assert attn_metadata.max_context_len == max(prompt_lens) + assert attn_metadata.max_seq_len == max(seq_lens) assert torch.allclose( - attn_metadata.context_lens[:len(prompt_lens)], - torch.tensor(prompt_lens, dtype=torch.int, device=device)) + attn_metadata.seq_lens_tensor[:len(seq_lens)], + torch.tensor(seq_lens, dtype=torch.int, device=device)) # block table's first index corresponds to each batch, meaning in # decoding it is each token. @@ -198,13 +196,13 @@ def test_prepare_decode_cuda_graph(batch_size): # Verify Sampling expected_selected_token_indices = [] selected_token_start_idx = 0 - for prompt_len in prompt_lens: + for seq_len in seq_lens: expected_selected_token_indices.append(selected_token_start_idx) selected_token_start_idx += 1 sampling_metadata = SamplingMetadata.prepare( seq_group_metadata_list, - prompt_lens, - subquery_lens=prompt_lens, + seq_lens, + query_lens=seq_lens, device=model_runner.device, pin_memory=model_runner.pin_memory) actual = sampling_metadata.selected_token_indices @@ -241,14 +239,13 @@ def test_empty_seq_group(): assert attn_metadata is None assert len(slot_mapping) == 0 - (input_tokens, input_positions, attn_metadata, return_prompt_lens, _, _, _, - _, _, - slot_mapping) = (model_runner._prepare_prompt(seq_group_metadata_list)) + (input_tokens, input_positions, attn_metadata, return_seq_lens, _, _, _, _, + _, slot_mapping) = (model_runner._prepare_prompt(seq_group_metadata_list)) assert len(input_tokens) == 0 assert len(input_positions) == 0 assert attn_metadata is None assert len(slot_mapping) == 0 - assert len(return_prompt_lens) == 0 + assert len(return_seq_lens) == 0 @pytest.fixture @@ -288,7 +285,7 @@ def test_hybrid_batches(batch_size, enforce_eager, distributed_init): model_runner.set_block_size(16) # Add prefill requests. 
- prompt_lens = [] + seq_lens = [] seq_group_metadata_list = [] prefill_metadata_list = [] decode_metadata_list = [] @@ -297,9 +294,9 @@ def test_hybrid_batches(batch_size, enforce_eager, distributed_init): decode_batch_size = batch_size - prefill_batch_size for i in range(prefill_batch_size): # make sure all tokens fit into one block - prompt_len = i % (model_runner.block_size - 1) + 1 - prompt_lens.append(prompt_len) - seq_data = SequenceData(list(range(prompt_len))) + seq_len = i % (model_runner.block_size - 1) + 1 + seq_lens.append(seq_len) + seq_data = SequenceData(list(range(seq_len))) seq_group_metadata = SequenceGroupMetadata( request_id=f"test_{i}", is_prompt=True, @@ -314,8 +311,8 @@ def test_hybrid_batches(batch_size, enforce_eager, distributed_init): # Add decode requests for i in range(prefill_batch_size, batch_size): # make sure all tokens fit into one block - prompt_len = i % (model_runner.block_size - 1) + 1 - prompt_toks = list(range(prompt_len)) + seq_len = i % (model_runner.block_size - 1) + 1 + prompt_toks = list(range(seq_len)) seq_data = SequenceData(prompt_toks) seq_group_metadata = SequenceGroupMetadata( request_id=f"test_{i}", @@ -343,7 +340,7 @@ def test_hybrid_batches(batch_size, enforce_eager, distributed_init): else: assert attn_metadata.num_decode_tokens == _get_graph_batch_size( decode_batch_size) - assert attn_metadata.num_prefill_tokens == sum(prompt_lens) + assert attn_metadata.num_prefill_tokens == sum(seq_lens) # Verify attn metadata is consistent. We don't need to test individual # values here because they are tested above. diff --git a/vllm/_custom_ops.py b/vllm/_custom_ops.py index 3faed5ea85307..b43f646fec88e 100644 --- a/vllm/_custom_ops.py +++ b/vllm/_custom_ops.py @@ -39,17 +39,17 @@ def paged_attention_v1( num_kv_heads: int, scale: float, block_tables: torch.Tensor, - context_lens: torch.Tensor, + seq_lens: torch.Tensor, block_size: int, - max_context_len: int, + max_seq_len: int, alibi_slopes: Optional[torch.Tensor], kv_cache_dtype: str, kv_scale: float, ) -> None: vllm_ops.paged_attention_v1(out, query, key_cache, value_cache, - num_kv_heads, scale, block_tables, - context_lens, block_size, max_context_len, - alibi_slopes, kv_cache_dtype, kv_scale) + num_kv_heads, scale, block_tables, seq_lens, + block_size, max_seq_len, alibi_slopes, + kv_cache_dtype, kv_scale) def paged_attention_v2( @@ -63,17 +63,17 @@ def paged_attention_v2( num_kv_heads: int, scale: float, block_tables: torch.Tensor, - context_lens: torch.Tensor, + seq_lens: torch.Tensor, block_size: int, - max_context_len: int, + max_seq_len: int, alibi_slopes: Optional[torch.Tensor], kv_cache_dtype: str, kv_scale: float, ) -> None: vllm_ops.paged_attention_v2(out, exp_sum, max_logits, tmp_out, query, key_cache, value_cache, num_kv_heads, scale, - block_tables, context_lens, block_size, - max_context_len, alibi_slopes, kv_cache_dtype, + block_tables, seq_lens, block_size, + max_seq_len, alibi_slopes, kv_cache_dtype, kv_scale) diff --git a/vllm/attention/backends/flash_attn.py b/vllm/attention/backends/flash_attn.py index 10b8c19b7499e..fc7501ed5e91f 100644 --- a/vllm/attention/backends/flash_attn.py +++ b/vllm/attention/backends/flash_attn.py @@ -66,27 +66,24 @@ class FlashAttentionMetadata(AttentionMetadataPerStage, # Currently, input sequences can only contain all prompts # or all decoding. True if all sequences are prompts. is_prompt: bool - # (batch_size,). The prompt length per sequence. None if it is a decoding. - prompt_lens: Optional[List[int]] - # prompt_lens stored as a tensor. 
- prompt_lens_tensor: Optional[torch.Tensor] + # (batch_size,). The sequence length per sequence. Sequence length means + # the computed tokens + new tokens None if it is a decoding. + seq_lens: Optional[List[int]] + # seq_lens stored as a tensor. + seq_lens_tensor: Optional[torch.Tensor] - # NOTE(sang): Definition of context_len, subquery_len, and seqlen. + # NOTE(sang): Definition of context_len, query_len, and seq_len. # |---------- N-1 iteration --------| # |---------------- N iteration ---------------------| # |- tokenA -|......................|-- newTokens ---| # |---------- context_len ----------| - # |-------------------- seqlen ----------------------| - # |- subquery_len -| + # |-------------------- seq_len ----------------------| + # |-- query_len ---| - # WARNING(sang): context_len has different definition depending on if it is - # prefill vs decoding. When it is prefill, it doesn't include new tokens. - # When it is for decoding, it includes a new token. - - # Maximum subquery length in the batch. - max_subquery_len: Optional[int] - # Maximum prompt length in the batch. - max_prompt_len: Optional[int] + # Maximum query length in the batch. + max_query_len: Optional[int] + # Maximum sequence length in the batch. + max_seq_len: Optional[int] # (batch_size + 1,). The cumulative subquery lengths of the sequences in # the batch, used to index into subquery. E.g., if the subquery length # is [4, 6], it is [0, 4, 10]. @@ -95,6 +92,9 @@ class FlashAttentionMetadata(AttentionMetadataPerStage, # the batch, used to index into sequence. E.g., if the sequence length is # [4, 6], it is [0, 4, 10]. seq_start_loc: Optional[torch.Tensor] + # (batch_size,) A tensor of context lengths (tokens that are computed + # so far). + context_lens_tensor: Optional[torch.Tensor] # Whether or not if cuda graph is enabled. # Cuda-graph is currently enabled for decoding only. @@ -223,8 +223,8 @@ class FlashAttentionImpl(AttentionImpl): v=value, cu_seqlens_q=prefill_meta.seq_start_loc, cu_seqlens_k=prefill_meta.seq_start_loc, - max_seqlen_q=prefill_meta.max_prompt_len, - max_seqlen_k=prefill_meta.max_prompt_len, + max_seqlen_q=prefill_meta.max_seq_len, + max_seqlen_k=prefill_meta.max_seq_len, softmax_scale=self.scale, causal=True, window_size=self.sliding_window, @@ -245,9 +245,9 @@ class FlashAttentionImpl(AttentionImpl): value_cache, prefill_meta.block_tables, prefill_meta.subquery_start_loc, - prefill_meta.prompt_lens_tensor, - prefill_meta.context_lens, - prefill_meta.max_subquery_len, + prefill_meta.seq_lens_tensor, + prefill_meta.context_lens_tensor, + prefill_meta.max_query_len, self.alibi_slopes, self.sliding_window[0], ) @@ -258,8 +258,8 @@ class FlashAttentionImpl(AttentionImpl): key_cache, value_cache, decode_meta.block_tables, - decode_meta.context_lens, - decode_meta.max_context_len, + decode_meta.seq_lens_tensor, + decode_meta.max_seq_len, attn_metadata.kv_cache_dtype, self.num_kv_heads, self.scale, diff --git a/vllm/attention/backends/rocm_flash_attn.py b/vllm/attention/backends/rocm_flash_attn.py index 3bc436315c3de..c411b3971b8f1 100644 --- a/vllm/attention/backends/rocm_flash_attn.py +++ b/vllm/attention/backends/rocm_flash_attn.py @@ -64,27 +64,24 @@ class ROCmFlashAttentionMetadata(AttentionMetadataPerStage, # Currently, input sequences can only contain all prompts # or all decoding. True if all sequences are prompts. is_prompt: bool - # (batch_size,). The prompt length per sequence. None if it is a decoding. - prompt_lens: Optional[List[int]] - # prompt_lens stored as a tensor. 
- prompt_lens_tensor: Optional[torch.Tensor] + # (batch_size,). The sequence length per sequence. Sequence length means + # the computed tokens + new tokens None if it is a decoding. + seq_lens: Optional[List[int]] + # seq_lens stored as a tensor. + seq_lens_tensor: Optional[torch.Tensor] - # NOTE(sang): Definition of context_len, subquery_len, and seqlen. + # NOTE(sang): Definition of context_len, query_len, and seq_len. # |---------- N-1 iteration --------| # |---------------- N iteration ---------------------| # |- tokenA -|......................|-- newTokens ---| # |---------- context_len ----------| - # |-------------------- seqlen ----------------------| - # |- subquery_len -| + # |-------------------- seq_len ----------------------| + # |-- query_len ---| - # WARNING(sang): context_len has different definition depending on if it is - # prefill vs decoding. When it is prefill, it doesn't include new tokens. - # When it is for decoding, it includes a new token. - - # Maximum subquery length in the batch. - max_subquery_len: Optional[int] - # Maximum prompt length in the batch. - max_prompt_len: Optional[int] + # Maximum query length in the batch. + max_query_len: Optional[int] + # Maximum sequence length in the batch. + max_seq_len: Optional[int] # (batch_size + 1,). The cumulative subquery lengths of the sequences in # the batch, used to index into subquery. E.g., if the subquery length # is [4, 6], it is [0, 4, 10]. @@ -98,6 +95,9 @@ class ROCmFlashAttentionMetadata(AttentionMetadataPerStage, # Cuda-graph is currently enabled for decoding only. # TODO(woosuk): Move `use_cuda_graph` out since it's unrelated to attention. use_cuda_graph: bool + # (batch_size,) A tensor of context lengths (tokens that are computed + # so far). + context_lens_tensor: Optional[torch.Tensor] class ROCmFlashAttentionImpl(AttentionImpl): @@ -247,7 +247,7 @@ class ROCmFlashAttentionImpl(AttentionImpl): if prefill_meta := attn_metadata.prefill_metadata: # Prompt run. 
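
The metadata renames above replace `prompt_lens_tensor`/`context_lens`/`max_context_len` with `seq_lens_tensor`/`context_lens_tensor`/`max_seq_len` and add `max_query_len`. A short illustrative sketch of how these batch-level tensors relate (numbers invented; this is not code from the backends):

```python
import torch

seq_lens = torch.tensor([10, 6, 9], dtype=torch.int)    # computed + new tokens
query_lens = torch.tensor([4, 6, 1], dtype=torch.int)   # new tokens this step
context_lens_tensor = seq_lens - query_lens             # tokens computed so far

assert context_lens_tensor.tolist() == [6, 0, 8]
# max_query_len bounds the prefix-prefill kernel's query tiles, while
# max_seq_len bounds how far the paged-attention decode kernel walks the KV cache.
max_query_len = int(query_lens.max())
max_seq_len = int(seq_lens.max())
assert (max_query_len, max_seq_len) == (6, 10)
```
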
- assert prefill_meta.prompt_lens is not None + assert prefill_meta.seq_lens is not None if kv_cache is None or prefill_meta.block_tables.numel() == 0: # triton attention # When block_tables are not filled, it means q and k are the @@ -260,8 +260,8 @@ class ROCmFlashAttentionImpl(AttentionImpl): None, prefill_meta.seq_start_loc, prefill_meta.seq_start_loc, - prefill_meta.max_prompt_len, - prefill_meta.max_prompt_len, + prefill_meta.max_seq_len, + prefill_meta.max_seq_len, True, self.scale, ) @@ -274,7 +274,7 @@ class ROCmFlashAttentionImpl(AttentionImpl): query, key, value, - prefill_meta.prompt_lens, + prefill_meta.seq_lens, self.scale, ) else: @@ -284,8 +284,8 @@ class ROCmFlashAttentionImpl(AttentionImpl): v=value, cu_seqlens_q=prefill_meta.seq_start_loc, cu_seqlens_k=prefill_meta.seq_start_loc, - max_seqlen_q=prefill_meta.max_prompt_len, - max_seqlen_k=prefill_meta.max_prompt_len, + max_seqlen_q=prefill_meta.max_seq_len, + max_seqlen_k=prefill_meta.max_seq_len, softmax_scale=self.scale, causal=True, ) @@ -303,9 +303,9 @@ class ROCmFlashAttentionImpl(AttentionImpl): value_cache, prefill_meta.block_tables, prefill_meta.subquery_start_loc, - prefill_meta.prompt_lens_tensor, - prefill_meta.context_lens, - prefill_meta.max_subquery_len, + prefill_meta.seq_lens_tensor, + prefill_meta.context_lens_tensor, + prefill_meta.max_query_len, self.alibi_slopes, self.sliding_window[0], ) @@ -317,8 +317,8 @@ class ROCmFlashAttentionImpl(AttentionImpl): key_cache, value_cache, decode_meta.block_tables, - decode_meta.context_lens, - decode_meta.max_context_len, + decode_meta.seq_lens_tensor, + decode_meta.max_seq_len, attn_metadata.kv_cache_dtype, self.num_kv_heads, self.scale, @@ -334,13 +334,13 @@ def _naive_attention( query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, - prompt_lens: List[int], + seq_lens: List[int], scale: float, ) -> torch.Tensor: output = torch.empty_like(query) start = 0 - for _, prompt_len in enumerate(prompt_lens): - end = start + prompt_len + for _, seq_len in enumerate(seq_lens): + end = start + seq_len out = _naive_masked_attention( query[start:end], key[start:end], @@ -349,7 +349,7 @@ def _naive_attention( ) # TODO(woosuk): Unnecessary copy. Optimize. output[start:end].copy_(out) - start += prompt_len + start += seq_len return output diff --git a/vllm/attention/backends/torch_sdpa.py b/vllm/attention/backends/torch_sdpa.py index 55a7ce59ac6e0..f75a279086a26 100644 --- a/vllm/attention/backends/torch_sdpa.py +++ b/vllm/attention/backends/torch_sdpa.py @@ -58,7 +58,7 @@ class TorchSDPAMetadata(AttentionMetadata, PagedAttentionMetadata, # or all decoding. True if all sequences are prompts. is_prompt: bool slot_mapping: torch.Tensor - prompt_lens: Optional[List[int]] + seq_lens: Optional[List[int]] def __post_init__(self): # Set during the execution of the first attention op. 
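The ROCm fallback path above iterates over seq_lens to slice a packed token batch back into per-sequence chunks; a standalone sketch of that start/end bookkeeping, with made-up shapes:

    import torch

    seq_lens = [3, 5, 2]
    packed = torch.randn(sum(seq_lens), 8, 64)   # [num_tokens, num_heads, head_size]

    chunks = []
    start = 0
    for seq_len in seq_lens:
        end = start + seq_len
        chunks.append(packed[start:end])         # one [seq_len, num_heads, head_size] slice
        start = end                              # same as `start += seq_len` above

    assert [c.shape[0] for c in chunks] == seq_lens
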
@@ -136,7 +136,7 @@ class TorchSDPABackendImpl(AttentionImpl): kv_scale) if attn_metadata.is_prompt: - assert attn_metadata.prompt_lens is not None + assert attn_metadata.seq_lens is not None if (kv_cache is None or attn_metadata.block_tables.numel() == 0): if self.num_kv_heads != self.num_heads: key = key.repeat_interleave(self.num_queries_per_kv, dim=1) @@ -147,13 +147,13 @@ class TorchSDPABackendImpl(AttentionImpl): if self.alibi_slopes is not None: att_masks = _make_alibi_bias( self.alibi_slopes, query.dtype, - attn_metadata.prompt_lens) # type: ignore + attn_metadata.seq_lens) # type: ignore elif self.sliding_window is not None: att_masks = _make_sliding_window_bias( - attn_metadata.prompt_lens, self.sliding_window, + attn_metadata.seq_lens, self.sliding_window, query.dtype) # type: ignore else: - att_masks = [None] * len(attn_metadata.prompt_lens) + att_masks = [None] * len(attn_metadata.seq_lens) attn_metadata.attn_bias = att_masks query = query.movedim(0, query.dim() - 2) @@ -164,9 +164,9 @@ class TorchSDPABackendImpl(AttentionImpl): output = torch.empty( (num_tokens, self.num_heads, self.head_size), dtype=query.dtype) - for prompt_len, mask in zip(attn_metadata.prompt_lens, - attn_metadata.attn_bias): - end = start + prompt_len + for seq_len, mask in zip(attn_metadata.seq_lens, + attn_metadata.attn_bias): + end = start + seq_len sub_out = scaled_dot_product_attention( query[:, start:end, :], key[:, start:end, :], @@ -189,8 +189,8 @@ class TorchSDPABackendImpl(AttentionImpl): key_cache, value_cache, attn_metadata.block_tables, - attn_metadata.context_lens, - attn_metadata.max_context_len, + attn_metadata.seq_lens_tensor, + attn_metadata.max_seq_len, attn_metadata.kv_cache_dtype, self.num_kv_heads, self.scale, @@ -205,13 +205,13 @@ class TorchSDPABackendImpl(AttentionImpl): def _make_alibi_bias( alibi_slopes: torch.Tensor, dtype: torch.dtype, - prompt_lens: List[int], + seq_lens: List[int], ) -> List[torch.Tensor]: attn_biases = [] - for prompt_len in prompt_lens: - bias = torch.arange(prompt_len, dtype=dtype) + for seq_len in seq_lens: + bias = torch.arange(seq_len, dtype=dtype) # NOTE(zhuohan): HF uses - # `bias = bias[None, :].repeat(prompt_len, 1)` + # `bias = bias[None, :].repeat(seq_len, 1)` # here. We find that both biases give the same results, but # the bias below more accurately follows the original ALiBi # paper. @@ -221,7 +221,7 @@ def _make_alibi_bias( bias = bias[None, :].repeat((num_heads, 1, 1)) bias.mul_(alibi_slopes[:, None, None]) inf_mask = torch.empty( - (1, prompt_len, prompt_len), + (1, seq_len, seq_len), dtype=bias.dtype).fill_(-torch.inf).triu_(diagonal=1) attn_biases.append((bias + inf_mask).to(dtype)) @@ -229,14 +229,14 @@ def _make_alibi_bias( def _make_sliding_window_bias( - prompt_lens: List[int], + seq_lens: List[int], window_size: Optional[int], dtype: torch.dtype, ) -> List[torch.Tensor]: attn_biases = [] - for prompt_len in prompt_lens: + for seq_len in seq_lens: tensor = torch.full( - (1, prompt_len, prompt_len), + (1, seq_len, seq_len), dtype=dtype, fill_value=1, ) diff --git a/vllm/attention/backends/xformers.py b/vllm/attention/backends/xformers.py index dc64ac0bf985d..60f6d43f2eaa4 100644 --- a/vllm/attention/backends/xformers.py +++ b/vllm/attention/backends/xformers.py @@ -66,28 +66,24 @@ class XFormersMetadata(AttentionMetadataPerStage, PagedAttentionMetadata): # Currently, input sequences can only contain all prompts # or all decoding. True if all sequences are prompts. is_prompt: bool - # (batch_size,). The prompt length per sequence. 
None if it is a decoding. - prompt_lens: Optional[List[int]] - # prompt_lens stored as a tensor. - prompt_lens_tensor: Optional[torch.Tensor] + # (batch_size,). The sequence length per sequence. Sequence length means + # the computed tokens + new tokens None if it is a decoding. + seq_lens: Optional[List[int]] + # seq_lens stored as a tensor. + seq_lens_tensor: Optional[torch.Tensor] - # NOTE(sang): Definition of context_len, subquery_len, and seqlen. # |---------- N-1 iteration --------| # |---------------- N iteration ---------------------| # |- tokenA -|......................|-- newTokens ---| # |---------- context_len ----------| - # |-------------------- seqlen ----------------------| - # |- subquery_len -| + # |-------------------- seq_len ----------------------| + # |-- query_len ---| - # WARNING(sang): context_len has different definition depending on if it is - # prefill vs decoding. When it is prefill, it doesn't include new tokens. - # When it is for decoding, it includes a new token. - - # Maximum subquery length in the batch. - max_subquery_len: Optional[int] + # Maximum query length in the batch. + max_query_len: Optional[int] # FIXME: It is for flash attn. - # Maximum prompt length in the batch. - max_prompt_len: Optional[int] + # Maximum sequence length in the batch. + max_seq_len: Optional[int] # (batch_size + 1,). The cumulative subquery lengths of the sequences in # the batch, used to index into subquery. E.g., if the subquery length # is [4, 6], it is [0, 4, 10]. @@ -97,6 +93,9 @@ class XFormersMetadata(AttentionMetadataPerStage, PagedAttentionMetadata): # the batch, used to index into sequence. E.g., if the sequence length is # [4, 6], it is [0, 4, 10]. seq_start_loc: Optional[torch.Tensor] + # (batch_size,) A tensor of context lengths (tokens that are computed + # so far). + context_lens_tensor: Optional[torch.Tensor] # Whether or not if cuda graph is enabled. # Cuda-graph is currently enabled for decoding only. @@ -242,9 +241,9 @@ class XFormersImpl(AttentionImpl): value_cache, prefill_meta.block_tables, prefill_meta.subquery_start_loc, - prefill_meta.prompt_lens_tensor, - prefill_meta.context_lens, - prefill_meta.max_subquery_len, + prefill_meta.seq_lens_tensor, + prefill_meta.context_lens_tensor, + prefill_meta.max_query_len, self.alibi_slopes, self.sliding_window, ) @@ -257,8 +256,8 @@ class XFormersImpl(AttentionImpl): key_cache, value_cache, decode_meta.block_tables, - decode_meta.context_lens, - decode_meta.max_context_len, + decode_meta.seq_lens_tensor, + decode_meta.max_seq_len, attn_metadata.kv_cache_dtype, self.num_kv_heads, self.scale, @@ -289,7 +288,7 @@ class XFormersImpl(AttentionImpl): value: shape = [num_prefill_tokens, num_kv_heads, head_size] attn_metadata: Metadata for attention. """ - assert attn_metadata.prompt_lens is not None + assert attn_metadata.seq_lens is not None original_query = query if self.num_kv_heads != self.num_heads: # GQA/MQA requires the shape [B, M, G, H, K]. @@ -310,7 +309,7 @@ class XFormersImpl(AttentionImpl): if attn_metadata.attn_bias is None: if self.alibi_slopes is None: attn_bias = BlockDiagonalCausalMask.from_seqlens( - attn_metadata.prompt_lens) + attn_metadata.seq_lens) if self.sliding_window is not None: attn_bias = attn_bias.make_local_attention( self.sliding_window) @@ -318,7 +317,7 @@ class XFormersImpl(AttentionImpl): else: attn_metadata.attn_bias = _make_alibi_bias( self.alibi_slopes, self.num_kv_heads, query.dtype, - attn_metadata.prompt_lens) + attn_metadata.seq_lens) # No alibi slopes. 
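Both _make_alibi_bias helpers touched by this diff build a per-head causal bias from relative positions, now keyed by seq_lens; a small self-contained sketch of that construction with made-up sizes (this mirrors the pattern, it is not a copy of either helper):

    import torch

    seq_len, num_heads = 4, 2
    alibi_slopes = torch.tensor([0.5, 0.25])

    bias = torch.arange(seq_len, dtype=torch.float32)
    bias = bias[None, :] - bias[:, None]                # relative positions, [seq_len, seq_len]
    bias = bias[None, :, :].repeat(num_heads, 1, 1)     # one copy per head
    bias = bias * alibi_slopes[:, None, None]           # scale by each head's slope
    causal = torch.full((1, seq_len, seq_len), float("-inf")).triu_(diagonal=1)
    attn_bias = bias + causal                           # [num_heads, seq_len, seq_len]
    assert attn_bias.shape == (num_heads, seq_len, seq_len)
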
# TODO(woosuk): Too many view operations. Let's try to reduce @@ -343,8 +342,8 @@ class XFormersImpl(AttentionImpl): # one. This is inefficient, especially when we have many short prompts. output = torch.empty_like(original_query) start = 0 - for i, prompt_len in enumerate(attn_metadata.prompt_lens): - end = start + prompt_len + for i, seq_len in enumerate(attn_metadata.seq_lens): + end = start + seq_len out = xops.memory_efficient_attention_forward( query[None, start:end], key[None, start:end], @@ -354,7 +353,7 @@ class XFormersImpl(AttentionImpl): scale=self.scale) # TODO(woosuk): Unnecessary copy. Optimize. output[start:end].copy_(out.view_as(original_query[start:end])) - start += prompt_len + start += seq_len return output @@ -362,13 +361,13 @@ def _make_alibi_bias( alibi_slopes: torch.Tensor, num_kv_heads: int, dtype: torch.dtype, - prompt_lens: List[int], + seq_lens: List[int], ) -> LowerTriangularMaskWithTensorBias: attn_biases = [] - for prompt_len in prompt_lens: - bias = torch.arange(prompt_len, dtype=dtype) + for seq_len in seq_lens: + bias = torch.arange(seq_len, dtype=dtype) # NOTE(zhuohan): HF uses - # `bias = bias[None, :].repeat(prompt_len, 1)` + # `bias = bias[None, :].repeat(seq_len, 1)` # here. We find that both biases give the same results, but # the bias below more accurately follows the original ALiBi # paper. @@ -376,16 +375,16 @@ def _make_alibi_bias( # element. bias = bias[None, :] - bias[:, None] - padded_len = (prompt_len + 7) // 8 * 8 + padded_len = (seq_len + 7) // 8 * 8 num_heads = alibi_slopes.shape[0] bias = torch.empty( 1, # batch size num_heads, - prompt_len, + seq_len, padded_len, device=alibi_slopes.device, dtype=dtype, - )[:, :, :, :prompt_len].copy_(bias) + )[:, :, :, :seq_len].copy_(bias) bias.mul_(alibi_slopes[:, None, None]) if num_heads != num_kv_heads: bias = bias.unflatten(1, (num_kv_heads, num_heads // num_kv_heads)) diff --git a/vllm/attention/ops/paged_attn.py b/vllm/attention/ops/paged_attn.py index c20b94ac8315b..00a0f10c0950b 100644 --- a/vllm/attention/ops/paged_attn.py +++ b/vllm/attention/ops/paged_attn.py @@ -13,12 +13,11 @@ _PARTITION_SIZE = 512 @dataclass class PagedAttentionMetadata: """Metadata for PagedAttention.""" - # (batch_size,). The length of context (tokens stored in KV cache) per - # sequence. WARNING: When it is a prefill request, it doesn't include new - # tokens. When it is for decoding, it includes a new token. - context_lens: Optional[torch.Tensor] - # Maximum context length in the batch. - max_context_len: Optional[int] + # (batch_size,). The length of sequences (entire tokens seen so far) per + # sequence. + seq_lens_tensor: Optional[torch.Tensor] + # Maximum sequence length in the batch. + max_seq_len: Optional[int] # (batch_size, max_blocks_per_seq). # Block addresses per sequence. (Seq id -> list of physical block) # E.g., [0, 1, 2] means tokens are stored in 0th, 1st, and 2nd blocks @@ -85,8 +84,8 @@ class PagedAttention: key_cache: torch.Tensor, value_cache: torch.Tensor, block_tables: torch.Tensor, - context_lens: torch.Tensor, - max_context_len: int, + seq_lens: torch.Tensor, + max_seq_len: int, kv_cache_dtype: str, num_kv_heads: int, scale: float, @@ -97,7 +96,7 @@ class PagedAttention: block_size = value_cache.shape[3] num_seqs, num_heads, head_size = query.shape - max_num_partitions = ((max_context_len + _PARTITION_SIZE - 1) // + max_num_partitions = ((max_seq_len + _PARTITION_SIZE - 1) // _PARTITION_SIZE) # NOTE(woosuk): We use a simple heuristic to decide whether to use # PagedAttention V1 or V2. 
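The V1/V2 dispatch that this NOTE introduces continues just below (the use_v1 expression); a small sketch of the same arithmetic, with the partition size and thresholds copied from the surrounding code and made-up batch shapes:

    _PARTITION_SIZE = 512

    def use_v1_kernel(max_seq_len: int, num_seqs: int, num_heads: int) -> bool:
        # V1 for short sequences when either a single partition suffices or the
        # batch already offers enough seq * head parallelism; otherwise V2.
        max_num_partitions = (max_seq_len + _PARTITION_SIZE - 1) // _PARTITION_SIZE
        return max_seq_len <= 8192 and (max_num_partitions == 1
                                        or num_seqs * num_heads > 512)

    assert use_v1_kernel(max_seq_len=4096, num_seqs=16, num_heads=64)       # V1
    assert not use_v1_kernel(max_seq_len=16384, num_seqs=16, num_heads=64)  # V2
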
If the number of partitions is 1, we use @@ -106,7 +105,7 @@ class PagedAttention: # to parallelize. # TODO(woosuk): Tune this heuristic. # For context len > 8192, use V2 kernel to avoid shared memory shortage. - use_v1 = (max_context_len <= 8192 + use_v1 = (max_seq_len <= 8192 and (max_num_partitions == 1 or num_seqs * num_heads > 512)) if use_v1: # Run PagedAttention V1. @@ -118,9 +117,9 @@ class PagedAttention: num_kv_heads, scale, block_tables, - context_lens, + seq_lens, block_size, - max_context_len, + max_seq_len, alibi_slopes, kv_cache_dtype, kv_scale, @@ -150,9 +149,9 @@ class PagedAttention: num_kv_heads, scale, block_tables, - context_lens, + seq_lens, block_size, - max_context_len, + max_seq_len, alibi_slopes, kv_cache_dtype, kv_scale, @@ -168,9 +167,9 @@ class PagedAttention: value_cache: torch.Tensor, block_tables: torch.Tensor, subquery_start_loc: torch.Tensor, - prompt_lens_tensor: torch.Tensor, + seq_lens_tensor: torch.Tensor, context_lens: torch.Tensor, - max_subquery_len: int, + max_query_len: int, alibi_slopes: Optional[torch.Tensor], sliding_window: Optional[int], ) -> torch.Tensor: @@ -185,9 +184,9 @@ class PagedAttention: block_tables, # subquery_start_loc is (batch_size + 1,) subquery_start_loc[:-1], - prompt_lens_tensor, + seq_lens_tensor, context_lens, - max_subquery_len, + max_query_len, alibi_slopes, sliding_window, ) diff --git a/vllm/config.py b/vllm/config.py index aaa2f60739d55..3bdd3f774bc27 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -63,7 +63,10 @@ class ModelConfig: If False, we will use CUDA graph and eager execution in hybrid. max_context_len_to_capture: Maximum context len covered by CUDA graphs. When a sequence has context length larger than this, we fall back - to eager mode. + to eager mode (DEPRECATED. Use max_seq_len_to_capture instead). + max_seq_len_to_capture: Maximum sequence len covered by CUDA graphs. + When a sequence has context length larger than this, we fall back + to eager mode skip_tokenizer_init: If true, skip initialization of tokenizer and detokenizer. """ @@ -84,6 +87,7 @@ class ModelConfig: quantization_param_path: Optional[str] = None, enforce_eager: bool = False, max_context_len_to_capture: Optional[int] = None, + max_seq_len_to_capture: Optional[int] = None, max_logprobs: int = 5, skip_tokenizer_init: bool = False, ) -> None: @@ -99,6 +103,11 @@ class ModelConfig: self.quantization_param_path = quantization_param_path self.enforce_eager = enforce_eager self.max_context_len_to_capture = max_context_len_to_capture + if self.max_context_len_to_capture is not None: + raise ValueError("`max_context_len_to_capture` is deprecated. 
" + "Use `max_seq_len_to_capture` instead.") + self.max_seq_len_to_capture = (max_seq_len_to_capture + or max_context_len_to_capture) self.max_logprobs = max_logprobs self.skip_tokenizer_init = skip_tokenizer_init @@ -190,10 +199,10 @@ class ModelConfig: "non-quantized models.", self.quantization) def _verify_cuda_graph(self) -> None: - if self.max_context_len_to_capture is None: - self.max_context_len_to_capture = self.max_model_len - self.max_context_len_to_capture = min(self.max_context_len_to_capture, - self.max_model_len) + if self.max_seq_len_to_capture is None: + self.max_seq_len_to_capture = self.max_model_len + self.max_seq_len_to_capture = min(self.max_seq_len_to_capture, + self.max_model_len) def verify_with_parallel_config( self, @@ -772,8 +781,8 @@ class SpeculativeConfig: max_model_len=None, quantization=draft_quantization, enforce_eager=target_model_config.enforce_eager, - max_context_len_to_capture=target_model_config. - max_context_len_to_capture, + max_seq_len_to_capture=target_model_config. + max_seq_len_to_capture, max_logprobs=target_model_config.max_logprobs, ) diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index 7637616ae6089..1c8e1079bed58 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -44,7 +44,8 @@ class EngineArgs: tokenizer_revision: Optional[str] = None quantization: Optional[str] = None enforce_eager: bool = False - max_context_len_to_capture: int = 8192 + max_context_len_to_capture: Optional[int] = None + max_seq_len_to_capture: int = 8192 disable_custom_all_reduce: bool = False tokenizer_pool_size: int = 0 tokenizer_pool_type: str = "ray" @@ -322,6 +323,14 @@ class EngineArgs: default=EngineArgs.max_context_len_to_capture, help='Maximum context length covered by CUDA ' 'graphs. When a sequence has context length ' + 'larger than this, we fall back to eager mode. ' + '(DEPRECATED. Use --max-seq_len-to-capture instead' + ')') + parser.add_argument('--max-seq_len-to-capture', + type=int, + default=EngineArgs.max_seq_len_to_capture, + help='Maximum sequence length covered by CUDA ' + 'graphs. When a sequence has context length ' 'larger than this, we fall back to eager mode.') parser.add_argument('--disable-custom-all-reduce', action='store_true', @@ -492,7 +501,8 @@ class EngineArgs: self.code_revision, self.tokenizer_revision, self.max_model_len, self.quantization, self.quantization_param_path, self.enforce_eager, self.max_context_len_to_capture, - self.max_logprobs, self.skip_tokenizer_init) + self.max_seq_len_to_capture, self.max_logprobs, + self.skip_tokenizer_init) cache_config = CacheConfig(self.block_size, self.gpu_memory_utilization, self.swap_space, self.kv_cache_dtype, diff --git a/vllm/entrypoints/llm.py b/vllm/entrypoints/llm.py index b022707794a78..3ed660e183360 100644 --- a/vllm/entrypoints/llm.py +++ b/vllm/entrypoints/llm.py @@ -69,6 +69,9 @@ class LLM: disable CUDA graph and always execute the model in eager mode. If False, we will use CUDA graph and eager execution in hybrid. max_context_len_to_capture: Maximum context len covered by CUDA graphs. + When a sequence has context length larger than this, we fall back + to eager mode (DEPRECATED. Use `max_seq_len_to_capture` instead). + max_seq_len_to_capture: Maximum sequence len covered by CUDA graphs. When a sequence has context length larger than this, we fall back to eager mode. 
disable_custom_all_reduce: See ParallelConfig @@ -90,7 +93,8 @@ class LLM: gpu_memory_utilization: float = 0.9, swap_space: int = 4, enforce_eager: bool = False, - max_context_len_to_capture: int = 8192, + max_context_len_to_capture: Optional[int] = None, + max_seq_len_to_capture: int = 8192, disable_custom_all_reduce: bool = False, **kwargs, ) -> None: @@ -112,6 +116,7 @@ class LLM: swap_space=swap_space, enforce_eager=enforce_eager, max_context_len_to_capture=max_context_len_to_capture, + max_seq_len_to_capture=max_seq_len_to_capture, disable_custom_all_reduce=disable_custom_all_reduce, **kwargs, ) diff --git a/vllm/model_executor/layers/sampler.py b/vllm/model_executor/layers/sampler.py index d79c99e5d0a45..2de7763605dfc 100644 --- a/vllm/model_executor/layers/sampler.py +++ b/vllm/model_executor/layers/sampler.py @@ -1033,8 +1033,8 @@ def _get_next_prompt_tokens(seq_group: SequenceGroupToSample) -> List[int]: assert seq_group.is_prompt, ( "Caller should ensure the sequence group is in a prefill stage.") seq_ids = seq_group.seq_ids - subquery_len = seq_group.subquery_len - assert subquery_len is not None + query_len = seq_group.query_len + assert query_len is not None # prompt has only 1 seq id. assert len(seq_ids) == 1 seq_data = seq_group.seq_data[seq_ids[0]] @@ -1042,7 +1042,7 @@ def _get_next_prompt_tokens(seq_group: SequenceGroupToSample) -> List[int]: prompt_tokens = seq_data.prompt_token_ids # +1 because we are looking for a next prompt token. next_token_index_start = computed_len + 1 - next_token_index_end = min(computed_len + subquery_len + 1, + next_token_index_end = min(computed_len + query_len + 1, len(prompt_tokens)) next_prompt_tokens = prompt_tokens[ next_token_index_start:next_token_index_end] diff --git a/vllm/model_executor/sampling_metadata.py b/vllm/model_executor/sampling_metadata.py index 12156b2ba1aa2..9969c45963e9a 100644 --- a/vllm/model_executor/sampling_metadata.py +++ b/vllm/model_executor/sampling_metadata.py @@ -16,17 +16,26 @@ _SEED_0_REPLACEMENT = 3403598558 @dataclass class SequenceGroupToSample: + # |---------- N-1 iteration --------| + # |---------------- N iteration ---------------------| + # |- tokenA -|......................|-- newTokens ---| + # |---------- context_len ----------| + # |-------------------- seq_len ----------------------| + # |-- query_len ---| + # Sequence ids for the sequence group in a previous step. seq_ids: List[int] sampling_params: SamplingParams # seq_id -> sequence data. seq_data: Dict[int, SequenceData] - # The length of the prompt of the sequence group. None if it is in a decode + # The length of the sequence (all tokens seen in the past + new token to + # compute attention) of the sequence group. None if it is in a decode # stage. - prompt_len: Optional[int] - # The length of the query tokens to compute in the current step. None if it - # is in a decode stage. The length of subquery_len <= prompt_len. - subquery_len: Optional[int] + seq_len: Optional[int] + # The length of new query tokens to compute in the current step. None if it + # is in a decode stage. The length of query_len <= seq_len if chunked + # prefill is enabled. + query_len: Optional[int] # A random number generator for sampling. generator: Optional[torch.Generator] # True if the sequence group is in prefill stage. 
False if it is in a @@ -46,8 +55,8 @@ class SequenceGroupToSample: if len(self.prompt_logprob_indices) > 0: assert self.sampling_params.prompt_logprobs is not None if self.is_prompt: - assert self.prompt_len is not None - assert self.subquery_len is not None + assert self.seq_len is not None + assert self.query_len is not None class SamplingMetadata: @@ -94,8 +103,8 @@ class SamplingMetadata: @staticmethod def prepare( seq_group_metadata_list: List[SequenceGroupMetadata], - prompt_lens: List[int], - subquery_lens: Optional[List[int]], + seq_lens: List[int], + query_lens: Optional[List[int]], device: str, pin_memory: bool, ) -> "SamplingMetadata": @@ -104,8 +113,8 @@ class SamplingMetadata: selected_token_indices, categorized_sample_indices, num_prompts, - ) = _prepare_seq_groups(seq_group_metadata_list, prompt_lens, - subquery_lens, device) + ) = _prepare_seq_groups(seq_group_metadata_list, seq_lens, query_lens, + device) selected_token_indices = async_tensor_h2d(selected_token_indices, dtype=torch.long, target_device=device, @@ -137,8 +146,8 @@ class SamplingMetadata: def _prepare_seq_groups( seq_group_metadata_list: List[SequenceGroupMetadata], - prompt_lens: List[int], - subquery_lens: Optional[List[int]], + seq_lens: List[int], + query_lens: Optional[List[int]], device: str, ) -> Tuple[List[SequenceGroupToSample], List[int], Dict[ SamplingType, List[Tuple[int, int]]], int]: @@ -146,9 +155,9 @@ def _prepare_seq_groups( Args: seq_group_metadata_list: A list of sequence group to batch. - prompt_lens: A list of prompt lens per sequence group. + seq_lens: A list of sequence lens per sequence group. Index of prompt len should match with seq_group_metadata_list. - subquery_lens: A list of query lengths. Prompt lens include the length + query_lens: A list of query lengths. Prompt lens include the length of entire prompt tokens, and it could be shorter. device: A device to use for random number generator, `SequenceGroupToSample.generator`. @@ -189,8 +198,8 @@ def _prepare_seq_groups( is_prompt = seq_group_metadata.is_prompt generator: Optional[torch.Generator] = None # If the current seq group is in decode stage, it is None. - prompt_len: Optional[int] = None - subquery_len: Optional[int] = None + seq_len: Optional[int] = None + query_len: Optional[int] = None prompt_logprob_indices: List[int] = [] sample_indices: List[int] = [] do_sample = seq_group_metadata.do_sample @@ -203,12 +212,12 @@ def _prepare_seq_groups( num_prompts += 1 num_prefill_sample = len(seq_ids) assert num_prefill_sample == 1 - assert subquery_lens is not None and prompt_lens is not None - subquery_len, prompt_len = subquery_lens[i], prompt_lens[i] + assert query_lens is not None and seq_lens is not None + query_len, seq_len = query_lens[i], seq_lens[i] # If we need sampling, exclude num_prefill_sample tokens from # prompt logprob. 
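The split that the next two assignments compute can be checked with concrete numbers; a small sketch with made-up values (num_prefill_sample is 1, per the assert above):

    query_len, num_prefill_sample, do_sample = 6, 1, True

    prompt_logprob_len = (query_len - num_prefill_sample) if do_sample else query_len
    sample_len = num_prefill_sample if do_sample else 0

    # 5 tokens of the chunk get prompt logprobs only; the last token is sampled.
    assert (prompt_logprob_len, sample_len) == (5, 1)
    assert prompt_logprob_len + sample_len == query_len
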
- prompt_logprob_len = (subquery_len - num_prefill_sample - if do_sample else subquery_len) + prompt_logprob_len = (query_len - num_prefill_sample + if do_sample else query_len) sample_len = num_prefill_sample if do_sample else 0 else: # Decode @@ -267,8 +276,8 @@ def _prepare_seq_groups( seq_ids=seq_ids, sampling_params=sampling_params, seq_data=seq_group_metadata.seq_data, - prompt_len=prompt_len, - subquery_len=subquery_len, + seq_len=seq_len, + query_len=query_len, generator=generator, is_prompt=is_prompt, prompt_logprob_indices=list(prompt_logprob_indices), @@ -367,8 +376,8 @@ class SamplingTensors: and sampling_params.prompt_logprobs is not None): # For tokens in the prompt that we only need to get # their logprobs - subquery_len = seq_group.subquery_len - assert subquery_len is not None + query_len = seq_group.query_len + assert query_len is not None prefill_len = len(seq_group.prompt_logprob_indices) temperatures += [temperature] * prefill_len top_ps += [top_p] * prefill_len @@ -397,8 +406,8 @@ class SamplingTensors: if is_prompt: prompt_best_of.append(sampling_params.best_of) - subquery_len = seq_group.subquery_len - assert subquery_len is not None + query_len = seq_group.query_len + assert query_len is not None for seq_id in seq_ids: seq_data = seq_group.seq_data[seq_id] diff --git a/vllm/worker/cpu_model_runner.py b/vllm/worker/cpu_model_runner.py index 34d7d3dffea18..193b021b7a11e 100644 --- a/vllm/worker/cpu_model_runner.py +++ b/vllm/worker/cpu_model_runner.py @@ -80,7 +80,7 @@ class CPUModelRunner: input_tokens: List[int] = [] input_positions: List[int] = [] slot_mapping: List[int] = [] - prompt_lens: List[int] = [] + seq_lens: List[int] = [] multi_modal_input_list: List[torch.Tensor] = [] for seq_group_metadata in seq_group_metadata_list: @@ -92,15 +92,15 @@ class CPUModelRunner: seq_data = seq_group_metadata.seq_data[seq_id] prompt_tokens = seq_data.get_token_ids() computed_len = seq_data.get_num_computed_tokens() - prompt_len = len(prompt_tokens) + seq_len = len(prompt_tokens) - prompt_lens.append(prompt_len) # Prompt token num + seq_lens.append(seq_len) # Prompt token num input_tokens.extend(prompt_tokens) # Token ids # Token position ids # NOTE(woosuk): Here we assume that the first token in the prompt # is always the first token in the sequence. - input_positions.extend(list(range(computed_len, prompt_len))) + input_positions.extend(list(range(computed_len, seq_len))) if seq_group_metadata.multi_modal_data: multi_modal_input_list.append( @@ -109,15 +109,15 @@ class CPUModelRunner: # Compute the slot mapping. block_table = seq_group_metadata.block_tables[seq_id] # Mask the [0, start_idx) tokens of the prompt with _PAD_SLOT_ID, - # where start_idx is max(0, prompt_len - sliding_window). + # where start_idx is max(0, seq_len - sliding_window). # For example, if the prompt len is 10, sliding window is 8, and # block size is 4, the first two tokens are masked and the slot # mapping will be [-1, -1, 2, 3, 4, 5, 6, 7, 0, 1]. 
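The worked example in the comment just above can be reproduced with a few lines; a simplified sketch of the loop that follows (the block table values are made up so that the evicted first block is reused for the prompt's tail, and computed_len is taken to be 0):

    _PAD_SLOT_ID = -1

    def prompt_slot_mapping(seq_len, block_size, block_table, sliding_window=None):
        # Tokens older than the sliding window get the pad slot; the rest map to
        # block_number * block_size + offset.
        start_idx = max(0, seq_len - sliding_window) if sliding_window else 0
        mapping = []
        for i in range(seq_len):
            if i < start_idx:
                mapping.append(_PAD_SLOT_ID)
                continue
            block_number = block_table[i // block_size]
            mapping.append(block_number * block_size + i % block_size)
        return mapping

    assert prompt_slot_mapping(10, 4, block_table=[0, 1, 0], sliding_window=8) == \
        [-1, -1, 2, 3, 4, 5, 6, 7, 0, 1]
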
start_idx = 0 if self.sliding_window is not None: - start_idx = max(0, prompt_len - self.sliding_window) + start_idx = max(0, seq_len - self.sliding_window) - for i in range(computed_len, prompt_len): + for i in range(computed_len, seq_len): if i < start_idx: slot_mapping.append(_PAD_SLOT_ID) continue @@ -151,19 +151,19 @@ class CPUModelRunner: attn_metadata = self.attn_backend.make_metadata( is_prompt=True, - prompt_lens=prompt_lens, - num_prefills=len(prompt_lens), + seq_lens=seq_lens, + seq_lens_tensor=None, + max_seq_len=None, + num_prefills=len(seq_lens), num_prefill_tokens=num_prompt_tokens, num_decode_tokens=0, prefill_metadata=None, decode_metadata=None, - max_context_len=None, - context_lens=None, block_tables=torch.tensor([]), slot_mapping=slot_mapping, kv_cache_dtype=self.kv_cache_dtype, ) - return (input_tokens, input_positions, attn_metadata, prompt_lens, + return (input_tokens, input_positions, attn_metadata, seq_lens, multi_modal_input) def _prepare_decode( @@ -174,7 +174,7 @@ class CPUModelRunner: input_tokens: List[int] = [] input_positions: List[int] = [] slot_mapping: List[int] = [] - context_lens: List[int] = [] + seq_lens: List[int] = [] block_tables: List[List[int]] = [] for seq_group_metadata in seq_group_metadata_list: @@ -192,9 +192,9 @@ class CPUModelRunner: position = seq_len - 1 input_positions.append(position) - context_len = seq_len if self.sliding_window is None else min( + seq_len = seq_len if self.sliding_window is None else min( seq_len, self.sliding_window) - context_lens.append(context_len) + seq_lens.append(seq_len) block_table = seq_group_metadata.block_tables[seq_id] block_number = block_table[position // self.block_size] @@ -208,7 +208,7 @@ class CPUModelRunner: block_table = block_table[-sliding_window_blocks:] block_tables.append(block_table) - max_context_len = max(context_lens) + max_seq_len = max(seq_lens) input_tokens = torch.tensor(input_tokens, dtype=torch.long, @@ -219,9 +219,9 @@ class CPUModelRunner: slot_mapping = torch.tensor(slot_mapping, dtype=torch.long, device=self.device) - context_lens = torch.tensor(context_lens, - dtype=torch.int, - device=self.device) + seq_lens_tensor = torch.tensor(seq_lens, + dtype=torch.int, + device=self.device) max_block_table_len = max( len(block_table) for block_table in block_tables) @@ -236,14 +236,14 @@ class CPUModelRunner: attn_metadata = self.attn_backend.make_metadata( is_prompt=False, slot_mapping=slot_mapping, - prompt_lens=None, + seq_lens=seq_lens, + seq_lens_tensor=seq_lens_tensor, + max_seq_len=max_seq_len, num_prefill_tokens=0, num_decode_tokens=len(input_tokens), - max_context_len=max_context_len, num_prefills=0, prefill_metadata=None, decode_metadata=None, - context_lens=context_lens, block_tables=block_tables, kv_cache_dtype=self.kv_cache_dtype, ) @@ -265,20 +265,20 @@ class CPUModelRunner: is_prompt = seq_group_metadata_list[0].is_prompt # Prepare input tensors. if is_prompt: - (input_tokens, input_positions, attn_metadata, prompt_lens, + (input_tokens, input_positions, attn_metadata, seq_lens, multi_modal_input ) = self._prepare_prompt(seq_group_metadata_list) else: (input_tokens, input_positions, attn_metadata) = self._prepare_decode(seq_group_metadata_list) - prompt_lens = [] + seq_lens = [] sampling_metadata = SamplingMetadata.prepare( seq_group_metadata_list, - prompt_lens, - # subquery_lens is not needed if chunked prefill is not + seq_lens, + # query_lens is not needed if chunked prefill is not # supported. 
Since CPU worker doesn't support chunked prefill - # just use prompt_lens instead. - prompt_lens, + # just use seq_lens instead. + seq_lens, self.device, pin_memory=False) # Broadcast the metadata. @@ -300,7 +300,7 @@ class CPUModelRunner: sampling_metadata = SamplingMetadata( seq_groups=None, seq_data=None, - prompt_lens=None, + seq_lens=None, selected_token_indices=selected_token_indices, categorized_sample_indices=None, generators=None, diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index 0704f5fec54d0..bbb1f5205af5e 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -42,8 +42,8 @@ class PreparePromptMetadata(NamedTuple): input_tokens: List[int] input_positions: List[int] attn_metadata: Optional[AttentionMetadataPerStage] - prompt_lens: List[int] - subquery_lens: List[int] + seq_lens: List[int] + query_lens: List[int] lora_index_mapping: List[int] lora_prompt_mapping: List[int] lora_requests: Set[LoRARequest] @@ -56,8 +56,8 @@ class PreparePromptMetadata(NamedTuple): input_tokens=[], input_positions=[], attn_metadata=None, - prompt_lens=[], - subquery_lens=[], + seq_lens=[], + query_lens=[], lora_index_mapping=[], lora_prompt_mapping=[], lora_requests=set(), @@ -134,9 +134,8 @@ class ModelRunner: self.graph_memory_pool: Optional[Tuple[ int, int]] = None # Set during graph capture. - self.max_context_len_to_capture = ( - self.model_config.max_context_len_to_capture - if self.model_config is not None else 0) + self.max_seq_len_to_capture = (self.model_config.max_seq_len_to_capture + if self.model_config is not None else 0) self.pin_memory = is_pin_memory_available() self.kv_cache_dtype = kv_cache_dtype @@ -149,7 +148,7 @@ class ModelRunner: self.model: torch.nn.Module # Set after load_model self.block_size: int # Set after initial profiling. # When using CUDA graph, the input block tables must be padded to - # max_context_len_to_capture. However, creating the block table in + # max_seq_len_to_capture. However, creating the block table in # Python can be expensive. To optimize this, we cache the block table # in numpy and only copy the actual input content at every iteration. # The shape of the cached block table will be @@ -218,7 +217,7 @@ class ModelRunner: def get_max_block_per_batch(self) -> int: block_size = self.block_size - return (self.max_context_len_to_capture + block_size - 1) // block_size + return (self.max_seq_len_to_capture + block_size - 1) // block_size def _prepare_prompt( self, @@ -231,9 +230,9 @@ class ModelRunner: lora_prompt_mapping: List[int] = [] lora_requests: Set[LoRARequest] = set() - prompt_lens: List[int] = [] + seq_lens: List[int] = [] context_lens: List[int] = [] - subquery_lens: List[int] = [] + query_lens: List[int] = [] prefix_block_tables: List[List[int]] = [] multi_modal_input_list: List[torch.Tensor] = [] @@ -257,21 +256,19 @@ class ModelRunner: token_chunk_size = seq_group_metadata.token_chunk_size seq_data = seq_group_metadata.seq_data[seq_id] - computed_len = seq_data.get_num_computed_tokens() + context_len = seq_data.get_num_computed_tokens() # We should use get_len here because in case of preemption # it contains output tokens. 
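The renamed bookkeeping below caps each prefill chunk at token_chunk_size; a sketch with made-up numbers, where get_len() is taken to be the total prompt length:

    total_len = 12          # seq_data.get_len(): full prompt (plus output after preemption)
    context_len = 5         # seq_data.get_num_computed_tokens()
    token_chunk_size = 4    # budget for this chunked-prefill step

    seq_len = min(total_len, context_len + token_chunk_size)   # 9
    query_len = seq_len - context_len                          # 4 new tokens this step

    assert (seq_len, query_len) == (9, 4)
    # The final chunk is shorter: with context_len == 9 the same formula gives
    # seq_len == 12 and query_len == 3.
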
- prefill_end = min(seq_data.get_len(), - computed_len + token_chunk_size) - prompt_tokens = seq_data.get_token_ids()[computed_len:prefill_end] - prompt_len = prefill_end - prompt_lens.append(prompt_len) + seq_len = min(seq_data.get_len(), context_len + token_chunk_size) + prompt_tokens = seq_data.get_token_ids()[context_len:seq_len] + seq_lens.append(seq_len) # NOTE: This only works for oooooooxxx style attention. if computed_block_nums is not None and len( computed_block_nums) > 0 and self.sliding_window is None: # Prefix is not supported with sliding_window - computed_len = len(computed_block_nums) * self.block_size - prompt_tokens = prompt_tokens[computed_len:] + context_len = len(computed_block_nums) * self.block_size + prompt_tokens = prompt_tokens[context_len:] prefix_block_tables.append(computed_block_nums) elif self.scheduler_config.chunked_prefill_enabled: if seq_group_metadata.block_tables is not None: @@ -285,25 +282,25 @@ class ModelRunner: prefix_block_tables.append([]) # Right now, prefill start is always 0. However, this # assumption can be changed once chunked prefill is introduced. - assert computed_len == 0 + assert context_len == 0 # actual prompt lens - context_lens.append(computed_len) - subquery_lens.append(prompt_len - computed_len) + context_lens.append(context_len) + query_lens.append(seq_len - context_len) input_tokens.extend(prompt_tokens) # NOTE(woosuk): Here we assume that the first token in the prompt # is always the first token in the sequence. - input_positions.extend(list(range(computed_len, prefill_end))) + input_positions.extend(list(range(context_len, seq_len))) lora_id = seq_group_metadata.lora_int_id if lora_id > 0: lora_requests.add(seq_group_metadata.lora_request) - lora_index_mapping += [lora_id] * (prompt_len - computed_len) + lora_index_mapping += [lora_id] * (seq_len - context_len) lora_prompt_mapping.extend( [lora_id] * - (prompt_len - computed_len + (seq_len - context_len if seq_group_metadata.sampling_params.prompt_logprobs else 1)) if seq_group_metadata.multi_modal_data: @@ -313,24 +310,24 @@ class ModelRunner: if seq_group_metadata.block_tables is None: # During memory profiling, the block tables are not initialized # yet. In this case, we just use a dummy slot mapping. - slot_mapping.extend([_PAD_SLOT_ID] * prompt_len) + slot_mapping.extend([_PAD_SLOT_ID] * seq_len) continue # Compute the slot mapping. block_table = seq_group_metadata.block_tables[seq_id] # Mask the [0, start_idx) tokens of the prompt with _PAD_SLOT_ID, - # where start_idx is max(0, prompt_len - sliding_window). + # where start_idx is max(0, seq_len - sliding_window). # For example, if the prompt len is 10, sliding window is 8, and # block size is 4, the first two tokens are masked and the slot # mapping will be [-1, -1, 2, 3, 4, 5, 6, 7, 0, 1]. 
start_idx = 0 if self.sliding_window is not None: - assert computed_len == 0, ( + assert context_len == 0, ( "Prefix caching is currently not supported with " "sliding window attention") - start_idx = max(0, prompt_len - self.sliding_window) + start_idx = max(0, seq_len - self.sliding_window) - for i in range(computed_len, prefill_end): + for i in range(context_len, seq_len): if i < start_idx: slot_mapping.append(_PAD_SLOT_ID) continue @@ -340,9 +337,9 @@ class ModelRunner: slot = block_number * self.block_size + block_offset slot_mapping.append(slot) - max_subquery_len = max(subquery_lens) - max_prompt_len = max(prompt_lens) - assert max_subquery_len > 0 + max_query_len = max(query_lens) + max_seq_len = max(seq_lens) + assert max_query_len > 0 context_lens_tensor = torch.tensor(context_lens, dtype=torch.int, @@ -369,40 +366,39 @@ class ModelRunner: # Query length can be shorter than key (i.e., prompt) when prefill # is chunked or prefix cached. - subquery_lens_tensor = torch.tensor(subquery_lens, - dtype=torch.long, - device=self.device) - subquery_start_loc = torch.zeros(subquery_lens_tensor.shape[0] + 1, + query_lens_tensor = torch.tensor(query_lens, + dtype=torch.long, + device=self.device) + subquery_start_loc = torch.zeros(query_lens_tensor.shape[0] + 1, dtype=torch.int32, device=self.device) - prompt_lens_tensor = torch.tensor(prompt_lens, - dtype=torch.long, - device=self.device) - seq_start_loc = torch.zeros(prompt_lens_tensor.shape[0] + 1, + seq_lens_tensor = torch.tensor(seq_lens, + dtype=torch.int, + device=self.device) + seq_start_loc = torch.zeros(seq_lens_tensor.shape[0] + 1, dtype=torch.int32, device=self.device) - torch.cumsum(subquery_lens_tensor, + torch.cumsum(query_lens_tensor, dim=0, dtype=subquery_start_loc.dtype, out=subquery_start_loc[1:]) - torch.cumsum(prompt_lens_tensor, + torch.cumsum(seq_lens_tensor, dim=0, dtype=seq_start_loc.dtype, out=seq_start_loc[1:]) attn_metadata = self.attn_backend.make_metadata( is_prompt=True, - prompt_lens=prompt_lens, - prompt_lens_tensor=prompt_lens_tensor, - max_subquery_len=max_subquery_len, - max_context_len=None, - max_prompt_len=max_prompt_len, + seq_lens=seq_lens, + seq_lens_tensor=seq_lens_tensor, + max_query_len=max_query_len, + max_seq_len=max_seq_len, subquery_start_loc=subquery_start_loc, seq_start_loc=seq_start_loc, - context_lens=context_lens_tensor, + context_lens_tensor=context_lens_tensor, block_tables=block_tables, use_cuda_graph=False, ) @@ -411,8 +407,8 @@ class ModelRunner: input_tokens=input_tokens, input_positions=input_positions, attn_metadata=attn_metadata, - prompt_lens=prompt_lens, - subquery_lens=subquery_lens, + seq_lens=seq_lens, + query_lens=query_lens, lora_index_mapping=lora_index_mapping, lora_prompt_mapping=lora_prompt_mapping, lora_requests=lora_requests, @@ -427,7 +423,7 @@ class ModelRunner: input_tokens: List[int] = [] input_positions: List[int] = [] slot_mapping: List[int] = [] - context_lens: List[int] = [] + seq_lens: List[int] = [] block_tables: List[List[int]] = [] lora_index_mapping: List[int] = [] lora_prompt_mapping: List[int] = [] @@ -455,9 +451,9 @@ class ModelRunner: position = seq_len - 1 input_positions.append(position) - context_len = seq_len if self.sliding_window is None else min( + seq_len = seq_len if self.sliding_window is None else min( seq_len, self.sliding_window) - context_lens.append(context_len) + seq_lens.append(seq_len) block_table = seq_group_metadata.block_tables[seq_id] block_number = block_table[position // self.block_size] @@ -477,11 +473,10 @@ class 
ModelRunner: # See `capture_model` API for more details. # For decoding requests, batch_size == input_tokens. batch_size = len(input_tokens) - max_context_len = max(context_lens) - use_captured_graph = ( - not self.model_config.enforce_eager - and batch_size <= _BATCH_SIZES_TO_CAPTURE[-1] - and max_context_len <= self.max_context_len_to_capture) + max_seq_len = max(seq_lens) + use_captured_graph = (not self.model_config.enforce_eager + and batch_size <= _BATCH_SIZES_TO_CAPTURE[-1] + and max_seq_len <= self.max_seq_len_to_capture) if use_captured_graph: graph_batch_size = _get_graph_batch_size(batch_size) assert graph_batch_size >= batch_size @@ -489,21 +484,21 @@ class ModelRunner: input_tokens.append(0) input_positions.append(0) slot_mapping.append(_PAD_SLOT_ID) - context_lens.append(1) + seq_lens.append(1) block_tables.append([]) lora_index_mapping.append(0) batch_size = graph_batch_size - context_lens_tensor = torch.tensor(context_lens, - dtype=torch.int, - device=self.device) + seq_lens_tensor = torch.tensor(seq_lens, + dtype=torch.int, + device=self.device) if use_captured_graph: # When using cuda-graph all these tensors should be # padded. - assert context_lens_tensor.shape[0] == len(input_tokens) - assert context_lens_tensor.shape[0] == len(input_positions) - assert context_lens_tensor.shape[0] == len(slot_mapping) + assert seq_lens_tensor.shape[0] == len(input_tokens) + assert seq_lens_tensor.shape[0] == len(input_positions) + assert seq_lens_tensor.shape[0] == len(slot_mapping) # The shape of graph_block_tables is # [max batch size, max context len // block size]. @@ -525,14 +520,13 @@ class ModelRunner: attn_metadata = self.attn_backend.make_metadata( is_prompt=False, - prompt_lens=None, - prompt_lens_tensor=None, - max_subquery_len=None, - max_context_len=max_context_len, - max_prompt_len=None, + seq_lens=None, + seq_lens_tensor=seq_lens_tensor, + max_query_len=None, + max_seq_len=max_seq_len, subquery_start_loc=None, seq_start_loc=None, - context_lens=context_lens_tensor, + context_lens_tensor=None, block_tables=block_tables, use_cuda_graph=use_captured_graph, ) @@ -565,8 +559,8 @@ class ModelRunner: input_tokens, input_positions, prefill_attn_metadata, - prompt_lens, - subquery_lens, + seq_lens, + query_lens, lora_index_mapping, lora_prompt_mapping, lora_requests, @@ -583,13 +577,13 @@ class ModelRunner: decode_slot_mapping, ) = self._prepare_decode(decode_reqs) sampling_metadata = SamplingMetadata.prepare( - seq_group_metadata_list, prompt_lens, subquery_lens, - self.device, self.pin_memory) + seq_group_metadata_list, seq_lens, query_lens, self.device, + self.pin_memory) if not self.scheduler_config.chunked_prefill_enabled: assert (len(prefill_reqs) and len(decode_reqs)) == 0 - num_prefills = len(prompt_lens) + num_prefills = len(seq_lens) num_prefill_tokens = len(input_tokens) num_decode_tokens = len(decode_input_tokens) @@ -886,7 +880,7 @@ class ModelRunner: input_positions = torch.zeros(max_batch_size, dtype=torch.long).cuda() slot_mapping = torch.empty(max_batch_size, dtype=torch.long).cuda() slot_mapping.fill_(_PAD_SLOT_ID) - context_lens = torch.ones(max_batch_size, dtype=torch.int32).cuda() + seq_lens = torch.ones(max_batch_size, dtype=torch.int32).cuda() block_tables = torch.from_numpy(self.graph_block_tables).cuda() graph_batch_size = _get_graph_batch_size( @@ -908,14 +902,13 @@ class ModelRunner: # Create dummy attn_metadata. 
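Decode batches are padded up to a captured graph size before the (renamed) seq_lens tensor is built; a simplified sketch with an illustrative capture list, leaving out the exact rounding done by _get_graph_batch_size since it is not part of this hunk:

    _PAD_SLOT_ID = -1
    _BATCH_SIZES_TO_CAPTURE = [1, 2, 4, 8, 16, 32]   # illustrative list

    def pad_decode_batch(seq_lens, slot_mapping):
        batch_size = len(seq_lens)
        graph_batch_size = next(s for s in _BATCH_SIZES_TO_CAPTURE if s >= batch_size)
        pad = graph_batch_size - batch_size
        # Padded entries use seq_len = 1 and the pad slot, matching the loop above.
        return seq_lens + [1] * pad, slot_mapping + [_PAD_SLOT_ID] * pad

    seq_lens, slot_mapping = pad_decode_batch([17, 33, 9], [100, 101, 102])
    assert seq_lens == [17, 33, 9, 1] and slot_mapping[-1] == _PAD_SLOT_ID
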
decode_metadata = self.attn_backend.make_metadata( is_prompt=False, - prompt_lens=None, - prompt_lens_tensor=None, - max_subquery_len=None, - max_context_len=self.max_context_len_to_capture, - max_prompt_len=None, + seq_lens=None, + seq_lens_tensor=seq_lens[:batch_size], + max_query_len=None, + max_seq_len=self.max_seq_len_to_capture, subquery_start_loc=None, seq_start_loc=None, - context_lens=context_lens[:batch_size], + context_lens_tensor=None, block_tables=block_tables[:batch_size], use_cuda_graph=True, ) @@ -1025,7 +1018,7 @@ class CUDAGraphRunner: "positions": positions, "kv_caches": kv_caches, "slot_mapping": attn_metadata.slot_mapping, - "context_lens": attn_metadata.decode_metadata.context_lens, + "seq_lens_tensor": attn_metadata.decode_metadata.seq_lens_tensor, "block_tables": attn_metadata.decode_metadata.block_tables, } self.output_buffers = {"hidden_states": hidden_states} @@ -1047,8 +1040,8 @@ class CUDAGraphRunner: self.input_buffers["positions"].copy_(positions, non_blocking=True) self.input_buffers["slot_mapping"].copy_(attn_metadata.slot_mapping, non_blocking=True) - self.input_buffers["context_lens"].copy_( - attn_metadata.decode_metadata.context_lens, non_blocking=True) + self.input_buffers["seq_lens_tensor"].copy_( + attn_metadata.decode_metadata.seq_lens_tensor, non_blocking=True) self.input_buffers["block_tables"].copy_( attn_metadata.decode_metadata.block_tables, non_blocking=True) # Run the graph. diff --git a/vllm/worker/neuron_model_runner.py b/vllm/worker/neuron_model_runner.py index a974e85c22f45..a336be04e124f 100644 --- a/vllm/worker/neuron_model_runner.py +++ b/vllm/worker/neuron_model_runner.py @@ -52,7 +52,7 @@ class NeuronModelRunner: input_positions: List[List[int]] = [] input_block_ids: List[int] = [] - prompt_lens: List[int] = [] + seq_lens: List[int] = [] for seq_group_metadata in seq_group_metadata_list: assert seq_group_metadata.is_prompt seq_ids = list(seq_group_metadata.seq_data.keys()) @@ -61,26 +61,26 @@ class NeuronModelRunner: seq_data = seq_group_metadata.seq_data[seq_id] prompt_tokens = seq_data.get_token_ids() - prompt_len = len(prompt_tokens) - prompt_lens.append(prompt_len) + seq_len = len(prompt_tokens) + seq_lens.append(seq_len) input_tokens.append(prompt_tokens) - input_positions.append(list(range(prompt_len))) + input_positions.append(list(range(seq_len))) assert seq_group_metadata.block_tables is not None block_table = seq_group_metadata.block_tables[seq_id] assert len(block_table) == 1 input_block_ids.append(block_table[0]) - max_prompt_len = max(prompt_lens) - assert max_prompt_len > 0 + max_seq_len = max(seq_lens) + assert max_seq_len > 0 input_tokens = make_tensor_with_pad(input_tokens, - max_prompt_len, + max_seq_len, pad=0, dtype=torch.long, device=self.device) input_positions = make_tensor_with_pad(input_positions, - max_prompt_len, + max_seq_len, pad=0, dtype=torch.long, device=self.device) @@ -88,7 +88,7 @@ class NeuronModelRunner: dtype=torch.long, device=self.device) - return input_tokens, input_positions, input_block_ids, prompt_lens + return input_tokens, input_positions, input_block_ids, seq_lens def _prepare_decode( self, @@ -149,18 +149,18 @@ class NeuronModelRunner: # Prepare input tensors. 
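The Neuron runner pads every prompt to the batch-wide max_seq_len before stacking; a torch-only sketch of that padding (make_tensor_with_pad is a vLLM utility not shown in this hunk, so the equivalent is written out by hand with made-up token ids):

    import torch

    prompts = [[11, 12, 13], [21, 22]]
    seq_lens = [len(p) for p in prompts]
    max_seq_len = max(seq_lens)

    input_tokens = torch.tensor(
        [p + [0] * (max_seq_len - len(p)) for p in prompts], dtype=torch.long)
    input_positions = torch.tensor(
        [list(range(len(p))) + [0] * (max_seq_len - len(p)) for p in prompts],
        dtype=torch.long)

    assert input_tokens.shape == (len(prompts), max_seq_len)
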
if is_prompt: (input_tokens, input_positions, input_block_ids, - prompt_lens) = self._prepare_prompt(seq_group_metadata_list) + seq_lens) = self._prepare_prompt(seq_group_metadata_list) else: (input_tokens, input_positions, input_block_ids) = self._prepare_decode(seq_group_metadata_list) - prompt_lens = [] + seq_lens = [] sampling_metadata = SamplingMetadata.prepare( seq_group_metadata_list, - prompt_lens, - # subquery_lens is not needed if chunked prefill is not + seq_lens, + # query_lens is not needed if chunked prefill is not # supported. Since neuron worker doesn't support chunked prefill - # just use prompt_lens instead. - prompt_lens, + # just use seq_lens instead. + seq_lens, self.device, self.pin_memory)