diff --git a/CMakeLists.txt b/CMakeLists.txt index dcec854a08721..cda1ffc795d1b 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -249,7 +249,6 @@ set(VLLM_EXT_SRC "csrc/quantization/gguf/gguf_kernel.cu" "csrc/quantization/activation_kernels.cu" "csrc/cuda_utils_kernels.cu" - "csrc/prepare_inputs/advance_step.cu" "csrc/custom_all_reduce.cu" "csrc/torch_bindings.cpp") diff --git a/csrc/ops.h b/csrc/ops.h index 207291eceb169..3e29f0a973dd6 100644 --- a/csrc/ops.h +++ b/csrc/ops.h @@ -145,22 +145,6 @@ void gelu_fast(torch::Tensor& out, torch::Tensor& input); void gelu_quick(torch::Tensor& out, torch::Tensor& input); -void advance_step_flashattn(int64_t num_seqs, int64_t num_queries, - int64_t block_size, torch::Tensor& input_tokens, - torch::Tensor& sampled_token_ids, - torch::Tensor& input_positions, - torch::Tensor& seq_lens, - torch::Tensor& slot_mapping, - torch::Tensor& block_tables); - -void advance_step_flashinfer( - int64_t num_seqs, int64_t num_queries, int64_t block_size, - torch::Tensor& input_tokens, torch::Tensor& sampled_token_ids, - torch::Tensor& input_positions, torch::Tensor& seq_lens, - torch::Tensor& slot_mapping, torch::Tensor& block_tables, - torch::Tensor& paged_kv_indices, torch::Tensor& paged_kv_indptr, - torch::Tensor& paged_kv_last_page_len, torch::Tensor& block_table_bounds); - void cutlass_mla_decode(torch::Tensor const& out, torch::Tensor const& q_nope, torch::Tensor const& q_pe, torch::Tensor const& kv_c_and_k_pe_cache, diff --git a/csrc/prepare_inputs/advance_step.cu b/csrc/prepare_inputs/advance_step.cu deleted file mode 100644 index 3d5077d9de461..0000000000000 --- a/csrc/prepare_inputs/advance_step.cu +++ /dev/null @@ -1,336 +0,0 @@ -/* - * The goal of this GPU kernel is to advance input tensors on the GPU directly - * PR: https://github.com/vllm-project/vllm/pull/6338 - * Current restrictions: - * 1. Specialized for DraftModelRunner - * 2. 
Supports flash_attn only
- */
-
-#include "advance_step.cuh"
-
-namespace prepare_inputs {
-
-//
-template <int const num_threads>
-__global__ void advance_step_flashattn_kernel(
-    int num_seqs, int num_queries, int block_size, long* input_tokens_ptr,
-    long const* sampled_token_ids_ptr, long* input_positions_ptr,
-    int* seq_lens_ptr, long* slot_mapping_ptr, int const* block_tables_ptr,
-    int64_t const block_tables_stride) {
-  int const n_pad = num_seqs - num_queries;
-  if (n_pad && blockIdx.x == 0) {
-    // Handle cuda graph padding
-    int const offset = num_queries;
-    for (int i = threadIdx.x; i < n_pad; i += blockDim.x) {
-      input_tokens_ptr[offset + i] = 0;
-      input_positions_ptr[offset + i] = 0;
-      slot_mapping_ptr[offset + i] = -1;
-    }
-  }
-
-  int num_query_blocks = div_ceil(num_queries, num_threads);
-
-  if (blockIdx.x >= num_query_blocks) {
-    return;
-  }
-
-  int cur_query_id = blockIdx.x * num_threads + threadIdx.x;
-
-  if (cur_query_id >= num_queries) {
-    return;
-  }
-
-  // Update input_tokens
-  input_tokens_ptr[cur_query_id] = sampled_token_ids_ptr[cur_query_id];
-
-  int seq_len = seq_lens_ptr[cur_query_id];
-  int next_seq_len = seq_len + 1;
-  int next_input_pos = next_seq_len - 1;
-
-  // Update seq_lens
-  seq_lens_ptr[cur_query_id] = next_seq_len;
-  // Update input_positions
-  input_positions_ptr[cur_query_id] = next_input_pos;
-
-  int const* seq_block_tables_ptr =
-      block_tables_ptr + block_tables_stride * cur_query_id;
-
-  int block_index = next_input_pos / block_size;
-  int block_offset = next_input_pos % block_size;
-
-  int slot_num = seq_block_tables_ptr[block_index] * block_size + block_offset;
-  // Update slot_mapping
-  slot_mapping_ptr[cur_query_id] = slot_num;
-}
-
-inline void verify_tensor(std::string const& name, torch::Tensor const& t,
-                          int64_t const size_0, int64_t const size_1,
-                          c10::ScalarType const type) {
-  bool size_0_cond = true;
-  if (size_0 != -1) {
-    size_0_cond = t.size(0) == size_0;
-  }
-
-  bool size_1_cond = true;
-  if (size_1 != -1) {
-    size_1_cond = t.size(1) == size_1;
-  }
-
-  bool is_contiguous = t.is_contiguous();
-  bool same_type = t.dtype() == type;
-
-  bool pass = size_0_cond && size_1_cond && is_contiguous && same_type;
-  if (!pass) {
-    TORCH_CHECK(false, "tensor: name = ", name, ", shape = ", t.sizes(),
-                " is_cont = ", t.is_contiguous(), ", type = ", t.dtype(),
-                " is not as expected: shape = [", size_0, ", ", size_1,
-                "], type = ", type);
-  }
-}
-
-/// each thread processes a block per query
-__global__ void advance_step_flashinfer_kernel(
-    int num_threads, int num_seqs, int num_queries, int block_size,
-    long* input_tokens_ptr, long const* sampled_token_ids_ptr,
-    long* input_positions_ptr, int* seq_lens_ptr, long* slot_mapping_ptr,
-    int const* block_tables_ptr, int64_t const block_tables_stride,
-    int* paged_kv_last_page_len_ptr, int* block_table_bound_ptr) {
-  int const n_pad = num_seqs - num_queries;
-  if (n_pad && blockIdx.x == 0) {
-    // Handle cuda graph padding
-    int const offset = num_queries;
-    for (int i = threadIdx.x; i < n_pad; i += blockDim.x) {
-      input_tokens_ptr[offset + i] = 0;
-      input_positions_ptr[offset + i] = 0;
-      slot_mapping_ptr[offset + i] = -1;
-    }
-  }
-  int num_query_blocks = div_ceil(num_queries, num_threads);
-
-  if (blockIdx.x < num_query_blocks) {
-    int cur_query_id = blockIdx.x * num_threads + threadIdx.x;
-
-    if (cur_query_id < num_queries) {
-      // Update input_tokens
-      input_tokens_ptr[cur_query_id] = sampled_token_ids_ptr[cur_query_id];
-
-      int seq_len = seq_lens_ptr[cur_query_id];
-      int next_seq_len = seq_len + 1;
-      int
next_input_pos = next_seq_len - 1; - - // Update seq_lens - seq_lens_ptr[cur_query_id] = next_seq_len; - // Update input_positions - input_positions_ptr[cur_query_id] = next_input_pos; - - int const* seq_block_tables_ptr = - block_tables_ptr + block_tables_stride * cur_query_id; - - int block_index = next_input_pos / block_size; - int block_offset = next_input_pos % block_size; - - // Update paged_kv_last_page_len - paged_kv_last_page_len_ptr[cur_query_id] = block_offset + 1; - - int slot_num = - seq_block_tables_ptr[block_index] * block_size + block_offset; - // Update slot_mapping - slot_mapping_ptr[cur_query_id] = slot_num; - block_table_bound_ptr[cur_query_id] = div_ceil(next_seq_len, block_size); - } - } -} - -__global__ void advance_step_flashinfer_indptr_kernel( - int num_threads, int num_seqs, int num_queries, int* paged_kv_indptr_ptr, - int* block_table_bound_ptr) { - int idx = blockIdx.x * num_threads + threadIdx.x; - // Update paged_kv_indptr - if (idx == 0) { - paged_kv_indptr_ptr[idx] = 0; - } - if (idx < num_queries) { - int sum = 0; - for (int i = 0; i <= idx; ++i) { - sum += block_table_bound_ptr[i]; - } - paged_kv_indptr_ptr[idx + 1] = sum; - } -} - -__global__ void advance_step_flashinfer_indices_kernel( - int num_seqs, int num_queries, int const* block_tables_ptr, - int64_t const max_num_blocks_per_seq, int* paged_kv_indices_ptr, - int* paged_kv_indptr_ptr, int* block_table_bound_ptr) { - // note: max_num_blocks_per_seq = block_tables.stride(0) - int tid = blockIdx.x * blockDim.x + threadIdx.x; - - // when cuda graphs are enabled, paged_kv_indptr tensor - // has to be updated for the padded queries - // tid represents a query# for paged_kv_indptr tensor - if (num_queries < tid && tid <= num_seqs) { - paged_kv_indptr_ptr[tid] = paged_kv_indptr_ptr[num_queries]; - } - - // each thread processes a block_ptr in block_tables - // block_tables shape: [num_queries, max_num_blocks_per_seq] - // paged_kv_indices is flattened block_tables. 
-  for (int idx = tid; idx < (num_seqs * max_num_blocks_per_seq);
-       idx += (gridDim.x * blockDim.x)) {
-    // block_tables-row = paged_kv_indptr[queryNum]
-    int queryNum = idx / max_num_blocks_per_seq;
-    int col = idx % max_num_blocks_per_seq;
-    if (queryNum < num_queries && col < block_table_bound_ptr[queryNum]) {
-      int indices_arr_idx = paged_kv_indptr_ptr[queryNum] + col;
-      int block_tables_idx = queryNum * max_num_blocks_per_seq + col;
-      paged_kv_indices_ptr[indices_arr_idx] =
-          block_tables_ptr[block_tables_idx];
-    }
-  }
-}
-
-void advance_step_flashattn(int num_seqs, int num_queries, int block_size,
-                            torch::Tensor& input_tokens,       // type: long
-                            torch::Tensor& sampled_token_ids,  // type: long
-                            torch::Tensor& input_positions,    // type: long
-                            torch::Tensor& seq_lens,           // type: int
-                            torch::Tensor& slot_mapping,       // type: long
-                            torch::Tensor& block_tables) {     // type: int
-
-  if (logging) {
-    printf("advance_step_flashattn:\n");
-    printf("  num_seqs = %d\n", num_seqs);
-    printf("  num_queries = %d\n", num_queries);
-    printf("  block_size = %d\n", block_size);
-  }
-  // Verify all tensors
-  verify_tensor("input_tokens", input_tokens, num_seqs, -1, at::kLong);
-  verify_tensor("sampled_token_ids", sampled_token_ids, num_queries, 1,
-                at::kLong);
-  verify_tensor("input_positions", input_positions, num_seqs, -1, at::kLong);
-  verify_tensor("seq_lens", seq_lens, num_seqs, -1, at::kInt);
-  verify_tensor("slot_mapping", slot_mapping, num_seqs, -1, at::kLong);
-  verify_tensor("block_tables", block_tables, num_seqs, -1, at::kInt);
-
-  int dev = sampled_token_ids.get_device();
-  cudaStream_t stream = at::cuda::getCurrentCUDAStream(dev);
-
-  int blocks;
-  cudaDeviceGetAttribute(&blocks, cudaDevAttrMultiProcessorCount, dev);
-
-  advance_step_flashattn_kernel<max_threads>
-      <<<blocks, max_threads, 0, stream>>>(
-          num_seqs, num_queries, block_size,
-          reinterpret_cast<long*>(input_tokens.data_ptr()),
-          reinterpret_cast<long const*>(sampled_token_ids.data_ptr()),
-          reinterpret_cast<long*>(input_positions.data_ptr()),
-          reinterpret_cast<int*>(seq_lens.data_ptr()),
-          reinterpret_cast<long*>(slot_mapping.data_ptr()),
-          reinterpret_cast<int const*>(block_tables.data_ptr()),
-          block_tables.stride(0));
-}
-
-void advance_step_flashinfer(
-    int num_seqs, int num_queries, int block_size,
-    torch::Tensor& input_tokens,            // type: long
-    torch::Tensor& sampled_token_ids,       // type: long
-    torch::Tensor& input_positions,         // type: long
-    torch::Tensor& seq_lens,                // type: int
-    torch::Tensor& slot_mapping,            // type: long
-    torch::Tensor& block_tables,            // type: int
-    torch::Tensor& paged_kv_indices,        // type: int
-    torch::Tensor& paged_kv_indptr,         // type: int
-    torch::Tensor& paged_kv_last_page_len,  // type: int
-    torch::Tensor& block_table_bound) {     // type: int
-
-  if (logging) {
-    printf("advance_step_flashinfer:\n");
-    printf("  num_seqs = %d\n", num_seqs);
-    printf("  num_queries = %d\n", num_queries);
-    printf("  block_size = %d\n", block_size);
-    printf("  block_tables.stride(0) = %zu\n", block_tables.stride(0));
-  }
-  // Verify all tensors
-  verify_tensor("input_tokens", input_tokens, num_seqs, -1, at::kLong);
-  // verify_tensor("sampled_token_ids", sampled_token_ids, num_queries, 1,
-  //               at::kLong);
-  verify_tensor("input_positions", input_positions, num_seqs, -1, at::kLong);
-  verify_tensor("seq_lens", seq_lens, num_seqs, -1, at::kInt);
-  verify_tensor("slot_mapping", slot_mapping, num_seqs, -1, at::kLong);
-  verify_tensor("block_tables", block_tables, num_seqs, -1, at::kInt);
-
-  verify_tensor("paged_kv_indices", paged_kv_indices, -1, -1, at::kInt);
-  verify_tensor("paged_kv_indptr", paged_kv_indptr, num_seqs + 1, -1,
-                at::kInt);
-  verify_tensor("paged_kv_last_page_len", paged_kv_last_page_len, num_seqs, -1,
-                at::kInt);
-
-  verify_tensor("block_table_bound", block_table_bound, num_seqs, -1, at::kInt);
-
-  int dev = sampled_token_ids.get_device();
-  cudaStream_t stream = at::cuda::getCurrentCUDAStream(dev);
-
-  int blocks;
-  int threads;
-  cudaDeviceGetAttribute(&blocks, cudaDevAttrMultiProcessorCount, dev);
-  cudaDeviceGetAttribute(&threads, cudaDevAttrMaxThreadsPerBlock, dev);
-
-  TORCH_CHECK((blocks * threads > num_queries),
-              "multi-step: not enough threads to map to num_queries = ",
-              num_queries, " block_tables.stride(0) = ", block_tables.stride(0),
-              " blocks = ", blocks, " max_threads = ", threads);
-  if (logging) {
-    printf("launching kernels with %d blocks and %d threads\n", blocks,
-           threads);
-  }
-  advance_step_flashinfer_kernel<<<blocks, threads, 0, stream>>>(
-      threads, num_seqs, num_queries, block_size,
-      reinterpret_cast<long*>(input_tokens.data_ptr()),
-      reinterpret_cast<long const*>(sampled_token_ids.data_ptr()),
-      reinterpret_cast<long*>(input_positions.data_ptr()),
-      reinterpret_cast<int*>(seq_lens.data_ptr()),
-      reinterpret_cast<long*>(slot_mapping.data_ptr()),
-      reinterpret_cast<int const*>(block_tables.data_ptr()),
-      block_tables.stride(0),
-      reinterpret_cast<int*>(paged_kv_last_page_len.data_ptr()),
-      reinterpret_cast<int*>(block_table_bound.data_ptr()));
-
-  advance_step_flashinfer_indptr_kernel<<<blocks, threads, 0, stream>>>(
-      threads, num_seqs, num_queries,
-      reinterpret_cast<int*>(paged_kv_indptr.data_ptr()),
-      reinterpret_cast<int*>(block_table_bound.data_ptr()));
-
-  advance_step_flashinfer_indices_kernel<<<blocks, threads, 0, stream>>>(
-      num_seqs, num_queries,
-      reinterpret_cast<int const*>(block_tables.data_ptr()),
-      block_tables.stride(0),
-      reinterpret_cast<int*>(paged_kv_indices.data_ptr()),
-      reinterpret_cast<int*>(paged_kv_indptr.data_ptr()),
-      reinterpret_cast<int*>(block_table_bound.data_ptr()));
-}
-
-}  // namespace prepare_inputs
-
-void advance_step_flashattn(int64_t num_seqs, int64_t num_queries,
-                            int64_t block_size, torch::Tensor& input_tokens,
-                            torch::Tensor& sampled_token_ids,
-                            torch::Tensor& input_positions,
-                            torch::Tensor& seq_lens,
-                            torch::Tensor& slot_mapping,
-                            torch::Tensor& block_tables) {
-  prepare_inputs::advance_step_flashattn(
-      num_seqs, num_queries, block_size, input_tokens, sampled_token_ids,
-      input_positions, seq_lens, slot_mapping, block_tables);
-}
-
-void advance_step_flashinfer(
-    int64_t num_seqs, int64_t num_queries, int64_t block_size,
-    torch::Tensor& input_tokens, torch::Tensor& sampled_token_ids,
-    torch::Tensor& input_positions, torch::Tensor& seq_lens,
-    torch::Tensor& slot_mapping, torch::Tensor& block_tables,
-    torch::Tensor& paged_kv_indices, torch::Tensor& paged_kv_indptr,
-    torch::Tensor& paged_kv_last_page_len, torch::Tensor& block_table_bound) {
-  prepare_inputs::advance_step_flashinfer(
-      num_seqs, num_queries, block_size, input_tokens, sampled_token_ids,
-      input_positions, seq_lens, slot_mapping, block_tables, paged_kv_indices,
-      paged_kv_indptr, paged_kv_last_page_len, block_table_bound);
-}
diff --git a/csrc/prepare_inputs/advance_step.cuh b/csrc/prepare_inputs/advance_step.cuh
deleted file mode 100644
index f21574681b1ab..0000000000000
--- a/csrc/prepare_inputs/advance_step.cuh
+++ /dev/null
@@ -1,19 +0,0 @@
-#pragma once
-
-#include <torch/all.h>
-
-#include <ATen/cuda/CUDAContext.h>
-#include <c10/cuda/CUDAGuard.h>
-#include <cuda.h>
-#include <cuda_fp16.h>
-#include <cuda_runtime.h>
-#include <iostream>
-
-namespace prepare_inputs {
-
-static constexpr int max_threads = 256;
-static constexpr bool logging = false;
-
-constexpr int div_ceil(int a, int b) { return (a + b - 1) / b; }
-
-}  // namespace prepare_inputs
diff --git a/csrc/torch_bindings.cpp b/csrc/torch_bindings.cpp
index
8c207be083d88..a547baec50d6a 100644 --- a/csrc/torch_bindings.cpp +++ b/csrc/torch_bindings.cpp @@ -142,25 +142,6 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { ops.def("gelu_quick(Tensor! out, Tensor input) -> ()"); ops.impl("gelu_quick", torch::kCUDA, &gelu_quick); - // prepare_inputs advance_step - ops.def( - "advance_step_flashattn(int num_seqs, int num_queries, int block_size, " - "Tensor! input_tokens, Tensor sampled_token_ids, " - "Tensor! input_positions, Tensor! seq_lens, Tensor! slot_mapping, " - "Tensor block_tables) -> ()"); - ops.impl("advance_step_flashattn", torch::kCUDA, &advance_step_flashattn); - - ops.def( - "advance_step_flashinfer(" - " int num_seqs, int num_queries, int block_size," - " Tensor! input_tokens, Tensor sampled_token_ids," - " Tensor! input_positions, Tensor! seq_lens, Tensor! slot_mapping," - " Tensor block_tables, Tensor! paged_kv_indices," - " Tensor! paged_kv_indptr, Tensor! paged_kv_last_page_len," - " Tensor! block_table_bounds" - ") -> ()"); - ops.impl("advance_step_flashinfer", torch::kCUDA, &advance_step_flashinfer); - // Layernorm // Apply Root Mean Square (RMS) Normalization to the input tensor. ops.def( diff --git a/vllm/_custom_ops.py b/vllm/_custom_ops.py index a020b171e894a..a318637c5aeba 100644 --- a/vllm/_custom_ops.py +++ b/vllm/_custom_ops.py @@ -319,38 +319,6 @@ def apply_repetition_penalties(logits: torch.Tensor, prompt_mask: torch.Tensor, repetition_penalties) -def advance_step_flashattn(num_seqs: int, num_queries: int, block_size: int, - input_tokens: torch.Tensor, - sampled_token_ids: torch.Tensor, - input_positions: torch.Tensor, - seq_lens: torch.Tensor, slot_mapping: torch.Tensor, - block_tables: torch.Tensor) -> None: - """Advance a step on GPU for existing inputs for a multi-step runner""" - return torch.ops._C.advance_step_flashattn(num_seqs, num_queries, - block_size, input_tokens, - sampled_token_ids, - input_positions, seq_lens, - slot_mapping, block_tables) - - -def advance_step_flashinfer(num_seqs: int, num_queries: int, block_size: int, - input_tokens: torch.Tensor, - sampled_token_ids: torch.Tensor, - input_positions: torch.Tensor, - seq_lens: torch.Tensor, slot_mapping: torch.Tensor, - block_tables: torch.Tensor, - paged_kv_indices: torch.Tensor, - paged_kv_indptr: torch.Tensor, - paged_kv_last_page_len: torch.Tensor, - block_table_bound: torch.Tensor) -> None: - - return torch.ops._C.advance_step_flashinfer( - num_seqs, num_queries, block_size, input_tokens, sampled_token_ids, - input_positions, seq_lens, slot_mapping, block_tables, - paged_kv_indices, paged_kv_indptr, paged_kv_last_page_len, - block_table_bound) - - # fused quant layer norm ops def rms_norm_dynamic_per_token_quant( input: torch.Tensor, diff --git a/vllm/attention/backends/abstract.py b/vllm/attention/backends/abstract.py index 2417fe06a6755..d21f07756871a 100644 --- a/vllm/attention/backends/abstract.py +++ b/vllm/attention/backends/abstract.py @@ -101,11 +101,6 @@ class AttentionBackend(ABC): ) -> None: raise NotImplementedError - def advance_step(self, model_input: "ModelRunnerInputBase", - sampled_token_ids: Optional[torch.Tensor], - block_size: int, num_seqs: int, num_queries: int) -> None: - raise NotImplementedError - @classmethod def full_cls_name(cls) -> tuple[str, str]: return (cls.__module__, cls.__qualname__) diff --git a/vllm/attention/backends/differential_flash_attn.py b/vllm/attention/backends/differential_flash_attn.py index bd9bc427728d0..fac3c318a87a0 100644 --- a/vllm/attention/backends/differential_flash_attn.py +++ 
b/vllm/attention/backends/differential_flash_attn.py @@ -35,8 +35,7 @@ from vllm.vllm_flash_attn import (flash_attn_varlen_func, flash_attn_with_kvcache) if TYPE_CHECKING: - from vllm.worker.model_runner import (ModelInputForGPUBuilder, - ModelInputForGPUWithSamplingMetadata) + from vllm.worker.model_runner import ModelInputForGPUBuilder logger = init_logger(__name__) @@ -326,79 +325,6 @@ class DifferentialFlashAttentionMetadata(AttentionMetadata): cross_block_tables=self.cross_block_tables) return self._cached_decode_metadata - def advance_step(self, - model_input: "ModelInputForGPUWithSamplingMetadata", - sampled_token_ids: Optional[torch.Tensor], - block_size: int, - num_seqs: int, - num_queries: int, - turn_prefills_into_decodes: bool = False): - """ - Update metadata in-place to advance one decode step. - """ - # When using cudagraph, the num_seqs is padded to the next captured - # batch sized, but num_queries tracks the actual number of requests in - # the batch. For --enforce-eager mode, num_seqs == num_queries - if num_seqs != num_queries: - assert num_seqs > num_queries - assert self.use_cuda_graph - - if turn_prefills_into_decodes: - # When Multi-Step is enabled with Chunked-Prefill, prefills and - # decodes are scheduled together. In the first step, all the - # prefills turn into decodes. This update reflects that - # conversion. - assert self.num_decode_tokens + self.num_prefills == num_seqs - self.num_decode_tokens += self.num_prefills - self.num_prefills = 0 - self.num_prefill_tokens = 0 - self.max_prefill_seq_len = 0 - self.max_query_len = 1 - - self.slot_mapping = self.slot_mapping[:num_seqs] - else: - assert self.seq_lens is not None - assert self.max_decode_seq_len == max(self.seq_lens) - - assert self.num_prefills == 0 - assert self.num_prefill_tokens == 0 - assert self.num_decode_tokens == num_seqs - assert self.slot_mapping.shape == (num_seqs, ) - - assert self.seq_lens is not None - assert len(self.seq_lens) == num_seqs - assert self.seq_lens_tensor is not None - assert self.seq_lens_tensor.shape == (num_seqs, ) - assert self.max_query_len == 1 - assert self.max_prefill_seq_len == 0 - - assert self.query_start_loc is not None - assert self.query_start_loc.shape == (num_queries + 1, ) - assert self.seq_start_loc is not None - assert self.seq_start_loc.shape == (num_seqs + 1, ) - - assert self.context_lens_tensor is not None - assert self.context_lens_tensor.shape == (num_queries, ) - - assert self.block_tables is not None - assert self.block_tables.shape[0] == num_seqs - - # Update query lengths. 
Note that we update only queries and not seqs, - # since tensors may be padded due to captured cuda graph batch size - for i in range(num_queries): - self.seq_lens[i] += 1 - self.max_decode_seq_len = max(self.seq_lens) - - ops.advance_step_flashattn(num_seqs=num_seqs, - num_queries=num_queries, - block_size=block_size, - input_tokens=model_input.input_tokens, - sampled_token_ids=sampled_token_ids, - input_positions=model_input.input_positions, - seq_lens=self.seq_lens_tensor, - slot_mapping=self.slot_mapping, - block_tables=self.block_tables) - class DifferentialFlashAttentionMetadataBuilder( AttentionMetadataBuilder[DifferentialFlashAttentionMetadata]): diff --git a/vllm/attention/backends/flash_attn.py b/vllm/attention/backends/flash_attn.py index ee36fd19e0122..e52480d5c5ce2 100755 --- a/vllm/attention/backends/flash_attn.py +++ b/vllm/attention/backends/flash_attn.py @@ -32,8 +32,7 @@ from vllm.vllm_flash_attn import (flash_attn_varlen_func, flash_attn_with_kvcache) if TYPE_CHECKING: - from vllm.worker.model_runner import (ModelInputForGPUBuilder, - ModelInputForGPUWithSamplingMetadata) + from vllm.worker.model_runner import ModelInputForGPUBuilder logger = init_logger(__name__) @@ -309,79 +308,6 @@ class FlashAttentionMetadata(AttentionMetadata): cross_block_tables=self.cross_block_tables) return self._cached_decode_metadata - def advance_step(self, - model_input: "ModelInputForGPUWithSamplingMetadata", - sampled_token_ids: Optional[torch.Tensor], - block_size: int, - num_seqs: int, - num_queries: int, - turn_prefills_into_decodes: bool = False): - """ - Update metadata in-place to advance one decode step. - """ - # When using cudagraph, the num_seqs is padded to the next captured - # batch sized, but num_queries tracks the actual number of requests in - # the batch. For --enforce-eager mode, num_seqs == num_queries - if num_seqs != num_queries: - assert num_seqs > num_queries - assert self.use_cuda_graph - - if turn_prefills_into_decodes: - # When Multi-Step is enabled with Chunked-Prefill, prefills and - # decodes are scheduled together. In the first step, all the - # prefills turn into decodes. This update reflects that - # conversion. - assert self.num_decode_tokens + self.num_prefills == num_seqs - self.num_decode_tokens += self.num_prefills - self.num_prefills = 0 - self.num_prefill_tokens = 0 - self.max_prefill_seq_len = 0 - self.max_query_len = 1 - - self.slot_mapping = self.slot_mapping[:num_seqs] - else: - assert self.seq_lens is not None - assert self.max_decode_seq_len == max(self.seq_lens) - - assert self.num_prefills == 0 - assert self.num_prefill_tokens == 0 - assert self.num_decode_tokens == num_seqs - assert self.slot_mapping.shape == (num_seqs, ) - - assert self.seq_lens is not None - assert len(self.seq_lens) == num_seqs - assert self.seq_lens_tensor is not None - assert self.seq_lens_tensor.shape == (num_seqs, ) - assert self.max_query_len == 1 - assert self.max_prefill_seq_len == 0 - - assert self.query_start_loc is not None - assert self.query_start_loc.shape == (num_queries + 1, ) - assert self.seq_start_loc is not None - assert self.seq_start_loc.shape == (num_seqs + 1, ) - - assert self.context_lens_tensor is not None - assert self.context_lens_tensor.shape == (num_queries, ) - - assert self.block_tables is not None - assert self.block_tables.shape[0] == num_seqs - - # Update query lengths. 
Note that we update only queries and not seqs, - # since tensors may be padded due to captured cuda graph batch size - for i in range(num_queries): - self.seq_lens[i] += 1 - self.max_decode_seq_len = max(self.seq_lens) - - ops.advance_step_flashattn(num_seqs=num_seqs, - num_queries=num_queries, - block_size=block_size, - input_tokens=model_input.input_tokens, - sampled_token_ids=sampled_token_ids, - input_positions=model_input.input_positions, - seq_lens=self.seq_lens_tensor, - slot_mapping=self.slot_mapping, - block_tables=self.block_tables) - class FlashAttentionMetadataBuilder( AttentionMetadataBuilder[FlashAttentionMetadata]): diff --git a/vllm/attention/backends/flashinfer.py b/vllm/attention/backends/flashinfer.py index 78d8a67e37f8f..208cacec38eb5 100644 --- a/vllm/attention/backends/flashinfer.py +++ b/vllm/attention/backends/flashinfer.py @@ -51,8 +51,7 @@ from vllm.utils.flashinfer import use_trtllm_attention logger = init_logger(__name__) if TYPE_CHECKING: - from vllm.worker.model_runner import (ModelInputForGPUBuilder, - ModelInputForGPUWithSamplingMetadata) + from vllm.worker.model_runner import ModelInputForGPUBuilder class FlashInferBackend(AttentionBackend): @@ -428,7 +427,7 @@ class FlashInferMetadata(AttentionMetadata): query_start_loc: Optional[torch.Tensor] = None block_tables: Optional[torch.Tensor] = None - # used for GPU in-place advance_step + # used for GPU operations seq_lens_tensor: Optional[torch.Tensor] = None block_table_bound: Optional[torch.Tensor] = None @@ -587,66 +586,6 @@ class FlashInferMetadata(AttentionMetadata): return None return self - def advance_step(self, - model_input: "ModelInputForGPUWithSamplingMetadata", - sampled_token_ids: Optional[torch.Tensor], - block_size: int, - num_seqs: int, - num_queries: int, - turn_prefills_into_decodes: bool = False): - """ - Update metadata in-place to advance one decode step. - """ - - if turn_prefills_into_decodes: - # When Multi-Step is enabled with Chunked-Prefill, prefills and - # decodes are scheduled together. In the first step, all the - # prefills turn into decodes. This update reflects that - # conversion. - assert self.num_decode_tokens + self.num_prefills == num_seqs - # Flashinfer doesn't support speculative decoding + chunked-prefill - # + multi-step scheduling yet. - assert self.decode_query_len == 1 - self.num_decode_tokens += self.num_prefills - self.num_prefills = 0 - self.num_prefill_tokens = 0 - self.max_prefill_seq_len = 0 - self.max_query_len = 1 - - self.slot_mapping = self.slot_mapping[:num_seqs] - else: - assert self.seq_lens_tensor is not None - - assert num_seqs > 0 - assert num_queries > 0 - assert model_input.attn_metadata is not None - assert sampled_token_ids is not None - - # When using cudagraph, the num_seqs is padded to the next captured - # batch sized, but num_queries tracks the actual number of requests in - # the batch. 
For --enforce-eager mode, num_seqs == num_queries - if num_seqs != num_queries: - assert num_seqs > num_queries - assert self.use_cuda_graph - - model_input.input_tokens[:num_queries] = sampled_token_ids.flatten() - - # Update GPU tensors - ops.advance_step_flashinfer( - num_seqs=num_seqs, - num_queries=num_queries, - block_size=block_size, - input_tokens=model_input.input_tokens, - sampled_token_ids=model_input.input_tokens, - input_positions=model_input.input_positions, - seq_lens=self.seq_lens_tensor, - slot_mapping=self.slot_mapping, - block_tables=self.block_tables, - paged_kv_indices=self.paged_kv_indices, - paged_kv_indptr=self.paged_kv_indptr, - paged_kv_last_page_len=self.paged_kv_last_page_len, - block_table_bound=self.block_table_bound) - class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]): diff --git a/vllm/attention/backends/flashmla.py b/vllm/attention/backends/flashmla.py index a242ac9bbe0b6..f23c096952ce0 100644 --- a/vllm/attention/backends/flashmla.py +++ b/vllm/attention/backends/flashmla.py @@ -3,7 +3,7 @@ from contextlib import contextmanager from dataclasses import dataclass -from typing import TYPE_CHECKING, List, Optional, Tuple, Type +from typing import List, Optional, Tuple, Type import torch @@ -18,9 +18,6 @@ from vllm.attention.ops.flashmla import (flash_mla_with_kvcache, get_mla_metadata, is_flashmla_supported) -if TYPE_CHECKING: - from vllm.worker.model_runner import ModelInputForGPUWithSamplingMetadata - class FlashMLABackend(MLACommonBackend): @@ -62,16 +59,6 @@ class FlashMLAMetadata(MLACommonMetadata): self.decode_num_splits return decode_metadata - def advance_step(self, - model_input: "ModelInputForGPUWithSamplingMetadata", - sampled_token_ids: Optional[torch.Tensor], - block_size: int, - num_seqs: int, - num_queries: int, - turn_prefills_into_decodes: bool = False): - raise NotImplementedError( - "advance_step is not implemented for FlashMLA") - class FlashMLAMetadataBuilder(MLACommonMetadataBuilder[FlashMLAMetadata]): diff --git a/vllm/attention/backends/mla/common.py b/vllm/attention/backends/mla/common.py index 52c4a9e7da3de..8ff7f56743230 100644 --- a/vllm/attention/backends/mla/common.py +++ b/vllm/attention/backends/mla/common.py @@ -234,8 +234,7 @@ except ImportError: flash_attn_varlen_func = None if TYPE_CHECKING: - from vllm.worker.model_runner import (ModelInputForGPUBuilder, - ModelInputForGPUWithSamplingMetadata) + from vllm.worker.model_runner import ModelInputForGPUBuilder is_hip = current_platform.is_rocm() @@ -631,90 +630,6 @@ class MLACommonMetadata(AttentionMetadata): is_profile_run=self.is_profile_run) return self._cached_decode_metadata - def advance_step(self, - model_input: "ModelInputForGPUWithSamplingMetadata", - sampled_token_ids: Optional[torch.Tensor], - block_size: int, - num_seqs: int, - num_queries: int, - turn_prefills_into_decodes: bool = False): - """ - Update metadata in-place to advance one decode step. - """ - # When using cudagraph, the num_seqs is padded to the next captured - # batch sized, but num_queries tracks the actual number of requests in - # the batch. For --enforce-eager mode, num_seqs == num_queries - if num_seqs != num_queries: - assert num_seqs > num_queries - - if turn_prefills_into_decodes: - # When Multi-Step is enabled with Chunked-Prefill, prefills and - # decodes are scheduled together. In the first step, all the - # prefills turn into decodes. This update reflects that - # conversion. 
- assert self.num_decode_tokens + self.num_prefills == num_seqs - self.num_decode_tokens += self.num_prefills - self.num_prefills = 0 - self.num_prefill_tokens = 0 - self.max_prefill_seq_len = 0 - self.max_query_len = 1 - - self.slot_mapping = self.slot_mapping[:num_seqs] - else: - assert self.seq_lens is not None - assert self.max_decode_seq_len == max(self.seq_lens) - - assert self.num_prefills == 0 - assert self.num_prefill_tokens == 0 - assert self.num_decode_tokens == num_seqs - assert self.slot_mapping.shape == (num_seqs, ) - - assert self.seq_lens is not None - assert len(self.seq_lens) == num_seqs - assert self.seq_lens_tensor is not None - assert self.seq_lens_tensor.shape == (num_seqs, ) - assert self.max_query_len == 1 - assert self.max_prefill_seq_len == 0 - - assert self.query_start_loc is not None - assert self.query_start_loc.shape == (num_queries + 1, ) - assert self.seq_start_loc is not None - assert self.seq_start_loc.shape == (num_seqs + 1, ) - - assert self.context_lens_tensor is not None - assert self.context_lens_tensor.shape == (num_queries, ) - - assert self.block_tables is not None - assert self.block_tables.shape[0] == num_seqs - - # Update query lengths. Note that we update only queries and not seqs, - # since tensors may be padded due to captured cuda graph batch size - for i in range(num_queries): - self.seq_lens[i] += 1 - self.max_decode_seq_len = max(self.seq_lens) - - self._ops_advance_step(num_seqs=num_seqs, - num_queries=num_queries, - block_size=block_size, - input_tokens=model_input.input_tokens, - sampled_token_ids=sampled_token_ids, - input_positions=model_input.input_positions) - - def _ops_advance_step(self, num_seqs: int, num_queries: int, - block_size: int, input_tokens: torch.Tensor, - sampled_token_ids: torch.Tensor, - input_positions: torch.Tensor) -> None: - # here we use advance_step_flashinfo to update the paged_kv_* tensors - ops.advance_step_flashattn(num_seqs=num_seqs, - num_queries=num_queries, - block_size=block_size, - input_tokens=input_tokens, - sampled_token_ids=sampled_token_ids, - input_positions=input_positions, - seq_lens=self.seq_lens_tensor, - slot_mapping=self.slot_mapping, - block_tables=self.block_tables) - class MLACommonMetadataBuilder(AttentionMetadataBuilder[T], Generic[T]): """ diff --git a/vllm/attention/backends/placeholder_attn.py b/vllm/attention/backends/placeholder_attn.py index 820ddcab77d71..e630a6c6de8c4 100644 --- a/vllm/attention/backends/placeholder_attn.py +++ b/vllm/attention/backends/placeholder_attn.py @@ -15,8 +15,7 @@ from vllm.attention.backends.utils import CommonAttentionState from vllm.multimodal import MultiModalPlaceholderMap if TYPE_CHECKING: - from vllm.worker.model_runner import (ModelInputForGPUBuilder, - ModelInputForGPUWithSamplingMetadata) + from vllm.worker.model_runner import (ModelInputForGPUBuilder) from vllm.utils import async_tensor_h2d # Placeholder attention backend for models like Mamba and pooling models that @@ -201,65 +200,6 @@ class PlaceholderAttentionMetadata(AttentionMetadata): ) return self._cached_decode_metadata - def advance_step(self, - model_input: "ModelInputForGPUWithSamplingMetadata", - sampled_token_ids: Optional[torch.Tensor], - block_size: int, - num_seqs: int, - num_queries: int, - turn_prefills_into_decodes: bool = False): - """ - Update metadata in-place to advance one decode step. - """ - # When using cudagraph, the num_seqs is padded to the next captured - # batch sized, but num_queries tracks the actual number of requests in - # the batch. 
For --enforce-eager mode, num_seqs == num_queries - if num_seqs != num_queries: - assert num_seqs > num_queries - assert self.use_cuda_graph - - assert not turn_prefills_into_decodes, \ - ("Multi-Step + Chunked-Prefill is not supported for attention-free" - "models. turn_prefills_into_decodes is a " - "Multi-Step + Chunked-Prefill specific parameter.") - - assert self.seq_lens is not None - assert self.max_decode_seq_len == max(self.seq_lens) - - assert self.num_prefills == 0 - assert self.num_prefill_tokens == 0 - assert self.num_decode_tokens == num_seqs - - assert self.seq_lens is not None - assert len(self.seq_lens) == num_seqs - assert self.seq_lens_tensor is not None - assert self.seq_lens_tensor.shape == (num_seqs, ) - assert self.max_query_len == 1 - assert self.max_prefill_seq_len == 0 - - assert self.query_start_loc is not None - assert self.query_start_loc.shape == (num_queries + 1, ) - assert self.seq_start_loc is not None - assert self.seq_start_loc.shape == (num_seqs + 1, ) - - assert self.context_lens_tensor is not None - assert self.context_lens_tensor.shape == (num_queries, ) - - # Update query lengths. Note that we update only queries and not seqs, - # since tensors may be padded due to captured cuda graph batch size - for i in range(num_queries): - self.seq_lens[i] += 1 - self.max_decode_seq_len = max(self.seq_lens) - - # Update sequences, masking off entries greater than num_queries - device = self.seq_lens_tensor.device - mask = torch.arange(self.seq_lens_tensor.size(0), - device=device) < num_queries - self.seq_lens_tensor += mask.to(self.seq_lens_tensor.dtype) - if sampled_token_ids is not None: - model_input.input_tokens.masked_scatter_( - mask, sampled_token_ids[:num_queries]) - class PlaceholderAttentionMetadataBuilder( AttentionMetadataBuilder[PlaceholderAttentionMetadata]): diff --git a/vllm/attention/backends/rocm_aiter_mla.py b/vllm/attention/backends/rocm_aiter_mla.py index a165a786d63d0..a2e9710437d95 100644 --- a/vllm/attention/backends/rocm_aiter_mla.py +++ b/vllm/attention/backends/rocm_aiter_mla.py @@ -7,7 +7,6 @@ from typing import TYPE_CHECKING, Optional, Type, Union import torch -import vllm._custom_ops as ops import vllm.envs as envs from vllm.attention.backends.mla.common import (MLACommonBackend, MLACommonImpl, @@ -107,26 +106,6 @@ class AiterMLAMetadata(MLACommonMetadata): return self._cached_decode_metadata - def _ops_advance_step(self, num_seqs: int, num_queries: int, - block_size: int, input_tokens: torch.Tensor, - sampled_token_ids: torch.Tensor, - input_positions: torch.Tensor) -> None: - - ops.advance_step_flashinfer( - num_seqs=num_seqs, - num_queries=num_queries, - block_size=block_size, - input_tokens=input_tokens, - sampled_token_ids=sampled_token_ids, - input_positions=input_positions, - seq_lens=self.seq_lens_tensor, - slot_mapping=self.slot_mapping, - block_tables=self.block_tables, - paged_kv_indices=self.paged_kv_indices, - paged_kv_indptr=self.paged_kv_indptr, - paged_kv_last_page_lens=self.paged_kv_last_page_lens, - block_table_bound=self.block_table_bound) - class AiterMLAMetadataBuilder(MLACommonMetadataBuilder[AiterMLAMetadata]): BLOCK_TABLE_EXTENDER: list[list[int]] = [[]] diff --git a/vllm/attention/backends/rocm_flash_attn.py b/vllm/attention/backends/rocm_flash_attn.py index da3d9ff32830c..63e467f5a7a22 100644 --- a/vllm/attention/backends/rocm_flash_attn.py +++ b/vllm/attention/backends/rocm_flash_attn.py @@ -4,7 +4,7 @@ import itertools from dataclasses import dataclass from functools import cache -from typing import 
TYPE_CHECKING, List, Optional, Tuple, Type +from typing import List, Optional, Tuple, Type import torch @@ -23,9 +23,6 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import ( GroupShape) from vllm.platforms import current_platform -if TYPE_CHECKING: - from vllm.worker.model_runner import ModelInputForGPUWithSamplingMetadata - logger = init_logger(__name__) _PARTITION_SIZE_ROCM = 256 @@ -261,69 +258,6 @@ class ROCmFlashAttentionMetadata(AttentionMetadata, PagedAttentionMetadata): self._cached_decode_metadata.query_start_loc = qs - qs[0] return self._cached_decode_metadata - def advance_step(self, - model_input: "ModelInputForGPUWithSamplingMetadata", - sampled_token_ids: Optional[torch.Tensor], - block_size: int, - num_seqs: int, - num_queries: int, - turn_prefills_into_decodes: bool = False): - """ - Update metadata in-place to advance one decode step. - """ - - assert not turn_prefills_into_decodes, \ - ("Chunked prefill is not supported with rocm_flash_attn yet." - "turn_prefills_into_decodes is a Multi-Step + Chunked-Prefill " - "specific parameter.") - - # When using cudagraph, the num_seqs is padded to the next captured - # batch sized, but num_queries tracks the actual number of requests in - # the batch. For --enforce-eager mode, num_seqs == num_queries - if num_seqs != num_queries: - assert num_seqs > num_queries - assert self.use_cuda_graph - - assert self.num_prefills == 0 - assert self.num_prefill_tokens == 0 - assert self.num_decode_tokens == num_seqs - assert self.slot_mapping.shape == (num_seqs, ) - - assert self.seq_lens is not None - assert len(self.seq_lens) == num_seqs - assert self.seq_lens_tensor is not None - assert self.seq_lens_tensor.shape == (num_seqs, ) - assert self.max_query_len == 1 - assert self.max_prefill_seq_len == 0 - assert self.max_decode_seq_len == max(self.seq_lens) - - assert self.query_start_loc is not None - assert self.query_start_loc.shape == (num_queries + 1, ) - assert self.seq_start_loc is not None - assert self.seq_start_loc.shape == (num_seqs + 1, ) - - assert self.context_lens_tensor is not None - assert self.context_lens_tensor.shape == (num_queries, ) - - assert self.block_tables is not None - assert self.block_tables.shape[0] == num_seqs - - # Update query lengths. Note that we update only queries and not seqs, - # since tensors may be padded due to captured cuda graph batch size - for i in range(num_queries): - self.seq_lens[i] += 1 - self.max_decode_seq_len = max(self.seq_lens) - - ops.advance_step_flashattn(num_seqs=num_seqs, - num_queries=num_queries, - block_size=block_size, - input_tokens=model_input.input_tokens, - sampled_token_ids=sampled_token_ids, - input_positions=model_input.input_positions, - seq_lens=self.seq_lens_tensor, - slot_mapping=self.slot_mapping, - block_tables=self.block_tables) - class ROCmFlashAttentionMetadataBuilder( CommonMetadataBuilder[ROCmFlashAttentionMetadata]): diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index a63797e3a46a2..a1c08fa814db4 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -762,8 +762,7 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]): has Prefills (if any). The rest of the steps are guaranteed to be all decodes. In this case, we set up the padding as if all the sequences are decodes so we may run all steps except the first step in CUDA graph - mode. The padding is accounted for in the multi-step `advance_step` - family of functions. + mode. 
Args: num_seqs (int): Number of sequences scheduled to run.