From 22771e5d83e6550d16f6e85bbaa6c0a8b55a5c0f Mon Sep 17 00:00:00 2001
From: Woosuk Kwon
Date: Sun, 31 Aug 2025 20:41:38 -0700
Subject: [PATCH] work

Signed-off-by: Woosuk Kwon
---
 vllm/v1/worker/gpu_block_table.py   |   2 +-
 vllm/v1/worker/gpu_input_batch.py   | 108 ---------------
 vllm/v1/worker/gpu_model_runner.py  | 199 ++++++++++++----------------
 vllm/v1/worker/gpu_worker_states.py | 134 ++++++++++++++++++-
 4 files changed, 219 insertions(+), 224 deletions(-)

diff --git a/vllm/v1/worker/gpu_block_table.py b/vllm/v1/worker/gpu_block_table.py
index f828ea5009c59..01700db523887 100644
--- a/vllm/v1/worker/gpu_block_table.py
+++ b/vllm/v1/worker/gpu_block_table.py
@@ -156,8 +156,8 @@ class BlockTables:
         self,
         cu_num_tokens: torch.Tensor,
         pos: torch.Tensor,
+        num_tokens: int,
     ) -> tuple[torch.Tensor, ...]:
-        num_tokens = pos.shape[0]
         num_reqs = cu_num_tokens.shape[0] - 1
         num_groups = self.num_kv_cache_groups
         _compute_slot_mappings_kernel[(num_reqs + 1, num_groups)](
diff --git a/vllm/v1/worker/gpu_input_batch.py b/vllm/v1/worker/gpu_input_batch.py
index f0b2eae545ce6..c91236f63e9e0 100644
--- a/vllm/v1/worker/gpu_input_batch.py
+++ b/vllm/v1/worker/gpu_input_batch.py
@@ -3,10 +3,8 @@
 from dataclasses import dataclass
 from typing import Any, Optional
 
-import numba
 import numpy as np
 import torch
-from numba import types
 
 from vllm.v1.spec_decode.metadata import SpecDecodeMetadata
 
@@ -35,109 +33,3 @@ class InputBatch:
 
     spec_decode_metadata: Optional[SpecDecodeMetadata]
     logits_indices: torch.Tensor
-
-
-# NOTE: With the type annotations, this function is pre-compiled
-# before the first call.
-@numba.jit(
-    [
-        types.none(
-            types.int32[:],  # idx_mapping
-            types.int32[:, :],  # token_ids
-            types.int32[:],  # num_computed_tokens
-            types.int32[:],  # num_scheduled_tokens
-            types.int32[:],  # input_ids
-            types.int32[:],  # query_start_loc
-            types.int32[:],  # seq_lens
-            types.int64[:],  # positions
-        )
-    ],
-    nopython=True,
-    cache=True,
-)
-def prepare_inputs(
-    # Inputs
-    idx_mapping: np.ndarray,  # batch_idx -> req_idx
-    token_ids: np.ndarray,  # [N, max_model_len]
-    num_computed_tokens: np.ndarray,  # [N]
-    num_scheduled_tokens: np.ndarray,  # [B]
-    # Outputs
-    input_ids: np.ndarray,  # [num_input_tokens]
-    query_start_loc: np.ndarray,  # [B + 1]
-    seq_lens: np.ndarray,  # [B]
-    positions: np.ndarray,  # [num_input_tokens]
-) -> None:
-    num_reqs = num_scheduled_tokens.shape[0]
-    query_start_loc[0] = 0
-
-    cu_num_tokens = 0
-    for i in range(num_reqs):
-        req_idx = idx_mapping[i]
-        start = num_computed_tokens[req_idx]
-        end = start + num_scheduled_tokens[i]
-        seq_lens[i] = end
-
-        start_idx = cu_num_tokens
-        end_idx = start_idx + num_scheduled_tokens[i]
-        input_ids[start_idx:end_idx] = token_ids[req_idx, start:end]
-        positions[start_idx:end_idx] = np.arange(start, end)
-
-        cu_num_tokens = end_idx
-        query_start_loc[i + 1] = cu_num_tokens
-
-    # Pad the inputs for CUDA graphs.
-    # Note: pad query_start_loc to be non-decreasing, as kernels
-    # like FlashAttention requires that
-    query_start_loc[num_reqs + 1:].fill(cu_num_tokens)
-    # Fill unused with 0 for full cuda graph mode.
-    seq_lens[num_reqs:].fill(0)
-
-
-def prepare_spec_decode(
-    # Inputs
-    query_start_loc: np.ndarray,  # [B + 1]
-    num_draft_tokens: np.ndarray,  # [B]
-    # Outputs
-    cu_num_draft_tokens: np.ndarray,  # [B]
-    logits_indices: np.ndarray,  # [N + B]
-    target_logits_indices: np.ndarray,  # [N]
-    bonus_logits_indices: np.ndarray,  # [B]
-) -> int:  # N
-    # Inputs:
-    # query_start_loc:       [  0,   4, 104, 107, 207, 209]
-    # num_draft_tokens:      [  3,   0,   2,   0,   1]
-    # Outputs:
-    # cu_num_draft_tokens:   [  3,   3,   5,   5,   6]
-    # logits_indices:        [  0,   1,   2,   3, 103, 104, 105, 106,
-    #                         206, 207, 208]
-    # target_logits_indices: [  0,   1,   2,   5,   6,   9]
-    # bonus_logits_indices:  [  3,   4,   7,   8,  10]
-    # return: 6 (total number of draft tokens)
-
-    cu_num_draft = 0
-    cu_num_sample = 0
-    num_reqs = num_draft_tokens.shape[0]
-    for i in range(num_reqs):
-        q_end_idx = query_start_loc[i + 1]
-        draft_len = num_draft_tokens[i]
-
-        # The last draft_len + 1 query tokens are used for sampling.
-        sample_len = draft_len + 1
-        sample_start_idx = cu_num_sample
-        sample_end_idx = sample_start_idx + sample_len
-        logits_indices[sample_start_idx:sample_end_idx] = (np.arange(
-            q_end_idx - sample_len, q_end_idx))
-
-        # For each query, the first draft_len tokens need target logits for
-        # rejection sampling. The draft_len + 1th token is used for bonus token.
-        draft_start_idx = cu_num_draft
-        draft_end_idx = draft_start_idx + draft_len
-        target_logits_indices[draft_start_idx:draft_end_idx] = (np.arange(
-            sample_start_idx, sample_end_idx - 1))
-        bonus_logits_indices[i] = sample_end_idx - 1
-
-        cu_num_draft += draft_len
-        cu_num_draft_tokens[i] = cu_num_draft
-        cu_num_sample += sample_len
-
-    return cu_num_draft
diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py
index cbe26a38960ce..ab388384208b4 100644
--- a/vllm/v1/worker/gpu_model_runner.py
+++ b/vllm/v1/worker/gpu_model_runner.py
@@ -77,9 +77,8 @@ from vllm.v1.spec_decode.metadata import SpecDecodeMetadata
 from vllm.v1.spec_decode.ngram_proposer import NgramProposer
 from vllm.v1.utils import CpuGpuBuffer
 from vllm.v1.worker.gpu_block_table import BlockTables
-from vllm.v1.worker.gpu_input_batch import (InputBatch, prepare_inputs,
-                                            prepare_spec_decode)
-from vllm.v1.worker.gpu_worker_states import RequestState
+from vllm.v1.worker.gpu_input_batch import InputBatch
+from vllm.v1.worker.gpu_worker_states import RequestState, prepare_inputs
 from vllm.v1.worker.kv_connector_model_runner_mixin import (
     KVConnectorModelRunnerMixin, KVConnectorOutput)
 from vllm.v1.worker.lora_model_runner_mixin import LoRAModelRunnerMixin
@@ -233,24 +232,17 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
         # Persistent buffers for CUDA graphs.
         self.input_ids = self._make_buffer(self.max_num_tokens,
                                            dtype=torch.int32)
+        self.inputs_embeds = torch.zeros(
+            (self.max_num_tokens, self.hidden_size),
+            dtype=self.dtype,
+            device=self.device)
         self.positions = self._make_buffer(self.max_num_tokens,
                                            dtype=torch.int64)
         self.query_start_loc = self._make_buffer(self.max_num_reqs + 1,
                                                  dtype=torch.int32)
         self.seq_lens = self._make_buffer(self.max_num_reqs,
                                           dtype=torch.int32)
-        self.inputs_embeds = torch.zeros(
-            (self.max_num_tokens, self.hidden_size),
-            dtype=self.dtype,
-            device=self.device)
         self.cu_num_draft_tokens = self._make_buffer(self.max_num_reqs,
                                                      dtype=torch.int32)
-        self.spec_logits_indices = self._make_buffer(self.max_num_tokens +
-                                                     self.max_num_reqs,
-                                                     dtype=torch.int32)
-        self.target_logits_indices = self._make_buffer(self.max_num_tokens,
-                                                       dtype=torch.int32)
-        self.bonus_logits_indices = self._make_buffer(self.max_num_reqs,
-                                                      dtype=torch.int32)
 
         # Only relevant for models using M-RoPE (e.g, Qwen2-VL)
         if self.uses_mrope:
@@ -543,8 +535,6 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
         # batch_idx -> req_id
         req_ids = sorted(scheduler_output.num_scheduled_tokens,
                          key=scheduler_output.num_scheduled_tokens.get)
-        # req_id -> batch_idx
-        req_id_to_batch_idx = {req_id: i for i, req_id in enumerate(req_ids)}
         # batch_idx -> req_idx
         idx_mapping_list = [
             self.requests.req_id_to_index[req_id] for req_id in req_ids
@@ -552,49 +542,50 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
         ]
         self.idx_mapping.np[:num_reqs] = idx_mapping_list
         idx_mapping_np = self.idx_mapping.np[:num_reqs]
         idx_mapping = self.idx_mapping.copy_to_gpu(num_reqs)
+        # req_id -> batch_idx
+        req_id_to_batch_idx = {req_id: i for i, req_id in enumerate(req_ids)}
 
         # OPTIMIZATION: Start copying the block table first.
         # This way, we can overlap the copy with the following CPU operations.
         block_tables = self.block_tables.compute_block_tables(idx_mapping)
 
         # Get the number of scheduled tokens for each request.
-        tokens = [scheduler_output.num_scheduled_tokens[i] for i in req_ids]
-        num_scheduled_tokens = np.array(tokens, dtype=np.int32)
-        max_num_scheduled_tokens = max(tokens)
+        num_scheduled_tokens = np.array(
+            [scheduler_output.num_scheduled_tokens[i] for i in req_ids],
+            dtype=np.int32)
 
         prepare_inputs(
-            idx_mapping=idx_mapping_np,
-            token_ids=self.requests.token_ids.np,
-            num_computed_tokens=self.requests.num_computed_tokens.np,
-            num_scheduled_tokens=num_scheduled_tokens,
-            input_ids=self.input_ids.np,
-            query_start_loc=self.query_start_loc.np,
-            seq_lens=self.seq_lens.np,
-            positions=self.positions.np,
+            idx_mapping_np,
+            self.requests.token_ids.np,
+            self.requests.num_computed_tokens.np,
+            num_scheduled_tokens,
+            self.input_ids.np,
+            self.query_start_loc.np,
+            self.seq_lens.np,
+            self.positions.np,
         )
 
-        # Calculate M-RoPE positions.
-        # Only relevant for models using M-RoPE (e.g, Qwen2-VL)
-        if self.uses_mrope:
-            self._calc_mrope_positions(scheduler_output)
-
-        # Prepare the attention metadata.
-        self.query_start_loc.copy_to_gpu()
-        query_start_loc = self.query_start_loc.gpu[:num_reqs + 1]
-
-        self.seq_lens.copy_to_gpu()
-        seq_lens = self.seq_lens.gpu[:num_reqs]
-        max_seq_len = self.seq_lens.np[:num_reqs].max().item()
-
-        # Copy the tensors to the GPU.
         self.input_ids.copy_to_gpu(total_num_scheduled_tokens)
+        self.positions.copy_to_gpu(total_num_scheduled_tokens)
+
+        # NOTE(woosuk): We should copy the whole query_start_loc and seq_lens
+        # tensors from CPU to GPU, because they may include paddings needed
+        # for full CUDA graph mode.
+        self.query_start_loc.copy_to_gpu()
+        self.seq_lens.copy_to_gpu()
+        query_start_loc = self.query_start_loc.gpu[:num_reqs + 1]
+        max_query_len = int(num_scheduled_tokens.max())
+        seq_lens = self.seq_lens.gpu[:num_reqs]
+        max_seq_len = int(self.seq_lens.np[:num_reqs].max())
+
+        # Compute the slot mappings on GPUs.
+        slot_mappings = self.block_tables.compute_slot_mappings(
+            query_start_loc, self.positions.gpu, total_num_scheduled_tokens)
+
         if self.uses_mrope:
-            # Only relevant for models using M-RoPE (e.g, Qwen2-VL)
-            self.mrope_positions.gpu[:, :total_num_scheduled_tokens].copy_(
-                self.mrope_positions.cpu[:, :total_num_scheduled_tokens],
-                non_blocking=True)
-        else:
-            # Common case (1D positions)
-            self.positions.copy_to_gpu(total_num_scheduled_tokens)
+            self._calc_mrope_positions(req_ids, num_scheduled_tokens)
+            # Optimization: To avoid gather and scatter, copy the whole M-RoPE
+            # tensor from CPU to GPU although only a part of it is used.
+            self.mrope_positions.copy_to_gpu()
 
         use_spec_decode = len(
             scheduler_output.scheduled_spec_decode_tokens) > 0
@@ -603,19 +594,15 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
             # partial requests. While we should not sample any token
             # from these partial requests, we do so for simplicity.
             # We will ignore the sampled tokens from the partial requests.
-            # TODO: Support prompt logprobs.
             logits_indices = query_start_loc[1:] - 1
             spec_decode_metadata = None
         else:
             # Get the number of draft tokens for each request.
-            num_draft_tokens = np.zeros(num_reqs, dtype=np.int32)
-            for i, req_id in enumerate(req_ids):
-                draft_token_ids = (
-                    scheduler_output.scheduled_spec_decode_tokens.get(req_id))
-                if draft_token_ids:
-                    num_draft_tokens[i] = len(draft_token_ids)
-            spec_decode_metadata = self._calc_spec_decode_metadata(
-                num_draft_tokens)
+            spec_decode_metadata = self._prepare_spec_decode_metadata(
+                req_ids,
+                scheduler_output.scheduled_spec_decode_tokens,
+                query_start_loc,
+            )
             logits_indices = spec_decode_metadata.logits_indices
 
         logits_indices_padded = None
@@ -643,9 +630,6 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
                 self.kv_sharing_fast_prefill_logits_indices[:num_logits_padded]
             )
 
-        slot_mappings = self.block_tables.compute_slot_mappings(
-            query_start_loc, self.positions.gpu[:total_num_scheduled_tokens])
-
         # Used in the below loop.
         query_start_loc_cpu = self.query_start_loc.cpu[:num_reqs + 1]
         seq_lens_cpu = self.seq_lens.cpu[:num_reqs]
@@ -689,7 +673,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
                 num_computed_tokens_cpu=num_computed_tokens_cpu,
                 num_reqs=num_reqs,
                 num_actual_tokens=total_num_scheduled_tokens,
-                max_query_len=max_num_scheduled_tokens,
+                max_query_len=max_query_len,
                 max_seq_len=max_seq_len,
                 block_table_tensor=blk_table_tensor,
                 slot_mapping=slot_mapping,
@@ -734,7 +718,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
             idx_mapping_np=idx_mapping_np,
             num_reqs=num_reqs,
             total_num_tokens=total_num_scheduled_tokens,
-            max_query_len=max_num_scheduled_tokens,
+            max_query_len=max_query_len,
            attn_metadata=attn_metadata,
             spec_decode_metadata=spec_decode_metadata,
             spec_decode_common_attn_metadata=spec_decode_common_attn_metadata,
@@ -836,17 +820,44 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
         )
         return common_prefix_len if use_cascade else 0
 
-    def _calc_mrope_positions(self, input_batch: InputBatch):
-        mrope_pos_ptr = 0
-        for i, req_id in enumerate(input_batch.req_ids):
-            req = self.requests[req_id]
-            assert req.mrope_positions is not None
+    def _prepare_spec_decode_metadata(
+        self,
+        req_ids: list[str],
+        req_id_to_draft_token_ids: dict[str, list[int]],
+        query_start_loc: torch.Tensor,
+    ) -> SpecDecodeMetadata:
+        # Get the number of draft tokens for each request.
+        num_reqs = len(req_ids)
+        num_draft_tokens = np.zeros(num_reqs, dtype=np.int32)
+        for i, req_id in enumerate(req_ids):
+            draft_token_ids = req_id_to_draft_token_ids.get(req_id)
+            if draft_token_ids:
+                num_draft_tokens[i] = len(draft_token_ids)
+        np.cumsum(num_draft_tokens,
+                  dtype=np.int32,
+                  out=self.cu_num_draft_tokens.np[:num_reqs])
+        cu_num_draft_tokens = self.cu_num_draft_tokens.copy_to_gpu(num_reqs)
+        return self.requests.make_spec_decode_metadata(
+            query_start_loc,
+            cu_num_draft_tokens,
+            self.cu_num_draft_tokens.np[:num_reqs],
+            self.input_ids.gpu,
+        )
 
-            num_computed_tokens = \
-                self.requests.num_computed_tokens_cpu[i]
-            num_scheduled_tokens = \
-                input_batch.num_scheduled_tokens[i]
-            num_prompt_tokens = len(req.prompt_token_ids)
+    def _calc_mrope_positions(
+        self,
+        req_ids: list[str],
+        query_lens: np.ndarray,
+    ):
+        mrope_pos_ptr = 0
+        for i, req_id in enumerate(req_ids):
+            req_idx = self.requests.req_id_to_index[req_id]
+            req_data = self.requests.req_data[req_idx]
+            assert req_data.mrope_positions is not None
+
+            num_computed_tokens = self.requests.num_computed_tokens.np[req_idx]
+            num_scheduled_tokens = query_lens[i]
+            num_prompt_tokens = self.requests.num_prompt_tokens.np[req_idx]
 
             if num_computed_tokens + num_scheduled_tokens > num_prompt_tokens:
                 prompt_part_len = max(0,
@@ -867,7 +878,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
                 src_end = num_computed_tokens + prompt_part_len
 
                 self.mrope_positions.cpu[:, dst_start:dst_end] = (
-                    req.mrope_positions[:, src_start:src_end])
+                    req_data.mrope_positions[:, src_start:src_end])
                 mrope_pos_ptr += prompt_part_len
 
             if completion_part_len > 0:
@@ -878,49 +889,13 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
                 MRotaryEmbedding.get_next_input_positions_tensor(
                     out=self.mrope_positions.np,
                     out_offset=dst_start,
-                    mrope_position_delta=req.mrope_position_delta,
+                    mrope_position_delta=req_data.mrope_position_delta,
                     context_len=num_computed_tokens + prompt_part_len,
                     num_new_tokens=completion_part_len,
                 )
                 mrope_pos_ptr += completion_part_len
 
-    def _calc_spec_decode_metadata(
-        self,
-        num_draft_tokens: np.ndarray,
-    ) -> SpecDecodeMetadata:
-        num_reqs = num_draft_tokens.shape[0]
-        total_num_draft_tokens = prepare_spec_decode(
-            self.query_start_loc.np,
-            num_draft_tokens,
-            self.cu_num_draft_tokens.np,
-            self.logits_indices.np,
-            self.target_logits_indices.np,
-            self.bonus_logits_indices.np,
-        )
-
-        cu_num_draft_tokens = self.cu_num_draft_tokens.copy_to_gpu(num_reqs)
-        logits_indices = self.logits_indices.copy_to_gpu(
-            num_reqs + total_num_draft_tokens)
-        target_logits_indices = self.target_logits_indices.copy_to_gpu(
-            total_num_draft_tokens)
-        bonus_logits_indices = self.bonus_logits_indices.copy_to_gpu(num_reqs)
-
-        # Compute the draft token ids.
-        # draft_token_indices: [  1,  2,  3, 105, 106, 208]
-        draft_token_ids = self.input_ids.gpu[logits_indices]
-        draft_token_ids = draft_token_ids[target_logits_indices + 1]
-
-        metadata = SpecDecodeMetadata(
-            draft_token_ids=draft_token_ids,
-            num_draft_tokens=num_draft_tokens.tolist(),
-            cu_num_draft_tokens=cu_num_draft_tokens,
-            target_logits_indices=target_logits_indices,
-            bonus_logits_indices=bonus_logits_indices,
-            logits_indices=logits_indices,
-        )
-        return metadata
-
     def _execute_mm_encoder(self, scheduler_output: "SchedulerOutput"):
         scheduled_encoder_inputs = scheduler_output.scheduled_encoder_inputs
         if not scheduled_encoder_inputs:
@@ -1353,7 +1328,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
             inputs_embeds = None
             model_kwargs = self._init_model_kwargs(num_input_tokens)
         if self.uses_mrope:
-            positions = self.mrope_positions.gpu[:, :num_input_tokens]
+            positions = self.mrope_positions[:, :num_input_tokens]
         else:
             positions = self.positions.gpu[:num_input_tokens]
 
@@ -2117,7 +2092,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
 
         model_kwargs = self._init_model_kwargs(num_tokens)
         if self.uses_mrope:
-            positions = self.mrope_positions.gpu[:, :num_tokens]
+            positions = self.mrope_positions[:, :num_tokens]
         else:
             positions = self.positions.gpu[:num_tokens]
 
diff --git a/vllm/v1/worker/gpu_worker_states.py b/vllm/v1/worker/gpu_worker_states.py
index ab07237cc1c7d..18a47ba3b5506 100644
--- a/vllm/v1/worker/gpu_worker_states.py
+++ b/vllm/v1/worker/gpu_worker_states.py
@@ -5,10 +5,12 @@
 from dataclasses import dataclass
 from typing import Optional, Union
 
+import numba
 import numpy as np
 import torch
 import triton
 import triton.language as tl
+from numba import types
 from typing_extensions import deprecated
 
 from vllm.lora.request import LoRARequest
@@ -18,6 +20,7 @@ from vllm.pooling_params import PoolingParams
 from vllm.sampling_params import SamplingParams, SamplingType
 from vllm.v1.sample.logits_processor import LogitsProcessors
 from vllm.v1.sample.metadata import SamplingMetadata
+from vllm.v1.spec_decode.metadata import SpecDecodeMetadata
 
 
 @dataclass
@@ -158,6 +161,10 @@ class RequestState:
             is_scalar=num_cols == 1,
         )
 
+    @property
+    def num_cached_reqs(self) -> int:
+        return len(self.req_id_to_index)
+
     def add_request(
         self,
         req_id: str,
@@ -292,9 +299,43 @@ class RequestState:
             logitsprocs=None,
         )
 
-    @property
-    def num_cached_reqs(self) -> int:
-        return len(self.req_id_to_index)
+    def make_spec_decode_metadata(
+        self,
+        query_start_loc: torch.Tensor,
+        cu_num_draft_tokens: torch.Tensor,
+        cu_num_draft_tokens_np: np.ndarray,
+        input_ids: torch.Tensor,
+    ) -> SpecDecodeMetadata:
+        batch_size = query_start_loc.shape[0] - 1
+        total_num_draft_tokens = cu_num_draft_tokens_np[batch_size - 1]
+        logits_indices = torch.empty(total_num_draft_tokens + batch_size,
+                                     dtype=torch.int32,
+                                     device=self.device)
+        target_logits_indices = torch.empty(total_num_draft_tokens,
+                                            dtype=torch.int32,
+                                            device=self.device)
+        bonus_logits_indices = torch.empty(batch_size,
+                                           dtype=torch.int32,
+                                           device=self.device)
+        _prepare_spec_decode_kernel[(batch_size, )](
+            query_start_loc,
+            cu_num_draft_tokens,
+            logits_indices,
+            target_logits_indices,
+            bonus_logits_indices,
+            BLOCK_SIZE=triton.next_power_of_2(32 + 1),
+        )
+
+        draft_token_ids = input_ids[logits_indices]
+        draft_token_ids = draft_token_ids[target_logits_indices + 1]
+        return SpecDecodeMetadata(
+            draft_token_ids=draft_token_ids,
+            num_draft_tokens=np.diff(cu_num_draft_tokens_np, prepend=0).tolist(),
+            cu_num_draft_tokens=cu_num_draft_tokens,
+            target_logits_indices=target_logits_indices,
+            bonus_logits_indices=bonus_logits_indices,
+            logits_indices=logits_indices,
+        )
 
 
 @triton.jit
@@ -333,3 +374,90 @@ def _make_sampling_metadata_kernel(
 
     repetition_penalties = tl.load(src_repetition_penalties + req_idx)
     tl.store(dst_repetition_penalties + batch_idx, repetition_penalties)
+
+
+@triton.jit
+def _prepare_spec_decode_kernel(
+    query_start_loc,  # [B + 1]
+    cu_num_draft_tokens,  # [B]
+    logits_indices,  # [N + B]
+    target_logits_indices,  # [N]
+    bonus_logits_indices,  # [B]
+    BLOCK_SIZE: tl.constexpr,
+):
+    batch_idx = tl.program_id(0)
+
+    if batch_idx == 0:
+        draft_start_idx = 0
+    else:
+        draft_start_idx = tl.load(cu_num_draft_tokens + batch_idx - 1)
+    draft_end_idx = tl.load(cu_num_draft_tokens + batch_idx)
+    draft_len = draft_end_idx - draft_start_idx
+    sample_len = draft_len + 1
+
+    q_end_idx = tl.load(query_start_loc + batch_idx + 1)
+
+    sample_start_idx = draft_start_idx + batch_idx
+    sample_end_idx = sample_start_idx + sample_len
+    offset = tl.arange(0, BLOCK_SIZE)
+    tl.store(logits_indices + sample_start_idx + offset,
+             q_end_idx - sample_len + offset,
+             mask=offset < sample_len)
+    tl.store(target_logits_indices + draft_start_idx + offset,
+             sample_start_idx + offset,
+             mask=offset < draft_len)
+    tl.store(bonus_logits_indices + batch_idx, sample_end_idx - 1)
+
+
+# NOTE: With the type annotations, this function is pre-compiled
+# before the first call.
+@numba.jit(
+    [
+        types.none(
+            types.int32[:],  # idx_mapping
+            types.int32[:, :],  # token_ids
+            types.int32[:],  # num_computed_tokens
+            types.int32[:],  # num_scheduled_tokens
+            types.int32[:],  # input_ids
+            types.int32[:],  # query_start_loc
+            types.int32[:],  # seq_lens
+            types.int64[:],  # positions
+        )
+    ],
+    nopython=True,
+    cache=True,
+)
+def prepare_inputs(
+    idx_mapping: np.ndarray,  # batch_idx -> req_idx
+    token_ids: np.ndarray,  # [N, max_model_len]
+    num_computed_tokens: np.ndarray,  # [N]
+    num_scheduled_tokens: np.ndarray,  # [B]
+    input_ids: np.ndarray,  # [num_input_tokens]
+    query_start_loc: np.ndarray,  # [B + 1]
+    seq_lens: np.ndarray,  # [B]
+    positions: np.ndarray,  # [num_input_tokens]
+) -> None:
+    num_reqs = num_scheduled_tokens.shape[0]
+    query_start_loc[0] = 0
+
+    cu_num_tokens = 0
+    for i in range(num_reqs):
+        req_idx = idx_mapping[i]
+        query_len = num_scheduled_tokens[i]
+        start = num_computed_tokens[req_idx]
+        end = start + query_len
+        seq_lens[i] = end
+
+        start_idx = cu_num_tokens
+        end_idx = start_idx + query_len
+        input_ids[start_idx:end_idx] = token_ids[req_idx, start:end]
+        positions[start_idx:end_idx] = np.arange(start, end, dtype=np.int64)
+
+        cu_num_tokens = end_idx
+        query_start_loc[i + 1] = cu_num_tokens
+
+    # Pad the inputs for CUDA graphs.
+    # Note: pad query_start_loc to be non-decreasing, as kernels
+    # like FlashAttention require that
+    query_start_loc[num_reqs + 1:].fill(cu_num_tokens)
+    # Fill unused with 0 for full cuda graph mode.
+    seq_lens[num_reqs:].fill(0)
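
A pure-NumPy reference of the index computation that the new _prepare_spec_decode_kernel
performs (launched per request by make_spec_decode_metadata) may help when reviewing the
Triton version. The sketch below is illustrative only and is not part of the patch; it
reuses the worked example documented in the removed numba prepare_spec_decode helper.

    import numpy as np

    # Worked example from the removed prepare_spec_decode docstring.
    query_start_loc = np.array([0, 4, 104, 107, 207, 209], dtype=np.int32)  # [B + 1]
    num_draft_tokens = np.array([3, 0, 2, 0, 1], dtype=np.int32)            # [B]
    cu_num_draft_tokens = np.cumsum(num_draft_tokens, dtype=np.int32)       # [3, 3, 5, 5, 6]

    num_reqs = num_draft_tokens.shape[0]
    total_num_draft = int(cu_num_draft_tokens[-1])
    logits_indices = np.empty(total_num_draft + num_reqs, dtype=np.int32)
    target_logits_indices = np.empty(total_num_draft, dtype=np.int32)
    bonus_logits_indices = np.empty(num_reqs, dtype=np.int32)

    # Each loop iteration mirrors one Triton program (one batch_idx).
    for i in range(num_reqs):
        draft_start = 0 if i == 0 else int(cu_num_draft_tokens[i - 1])
        draft_len = int(cu_num_draft_tokens[i]) - draft_start
        sample_len = draft_len + 1
        q_end = int(query_start_loc[i + 1])
        # The last draft_len + 1 query tokens of request i are sampled.
        sample_start = draft_start + i
        logits_indices[sample_start:sample_start + sample_len] = np.arange(
            q_end - sample_len, q_end)
        # The first draft_len of those need target logits for rejection sampling;
        # the last one produces the bonus token.
        target_logits_indices[draft_start:draft_start + draft_len] = np.arange(
            sample_start, sample_start + draft_len)
        bonus_logits_indices[i] = sample_start + sample_len - 1

    assert logits_indices.tolist() == [0, 1, 2, 3, 103, 104, 105, 106, 206, 207, 208]
    assert target_logits_indices.tolist() == [0, 1, 2, 5, 6, 9]
    assert bonus_logits_indices.tolist() == [3, 4, 7, 8, 10]

The Triton kernel parallelizes this loop over batch_idx, with BLOCK_SIZE sized to cover
the longest draft_len + 1 of any request.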