From c2867d5bc152ecffaf9984ff7cb7b9d69f52bf0c Mon Sep 17 00:00:00 2001
From: Alexander Matveev
Date: Tue, 4 Feb 2025 21:12:07 +0000
Subject: [PATCH] Optimize decode/prompt prepare code

---
 vllm/v1/worker/block_table.py      |   8 +
 vllm/v1/worker/gpu_input_batch.py  |  69 ++++
 vllm/v1/worker/tpu_model_runner.py | 546 +++++++++++++++--------------
 3 files changed, 358 insertions(+), 265 deletions(-)

diff --git a/vllm/v1/worker/block_table.py b/vllm/v1/worker/block_table.py
index 26a2084b131fa..366dd21af0020 100644
--- a/vllm/v1/worker/block_table.py
+++ b/vllm/v1/worker/block_table.py
@@ -57,6 +57,14 @@ class BlockTable:
             src, :num_blocks]
         self.num_blocks_per_row[tgt] = num_blocks
 
+    def swap_row(self, src: int, tgt: int) -> None:
+        num_blocks_src = self.num_blocks_per_row[src]
+        num_blocks_tgt = self.num_blocks_per_row[tgt]
+        self.num_blocks_per_row[src] = num_blocks_tgt
+        self.num_blocks_per_row[tgt] = num_blocks_src
+
+        self.block_table_np[[src, tgt]] = self.block_table_np[[tgt, src]]
+
     def commit(self, num_reqs: int) -> None:
         self.block_table[:num_reqs].copy_(self.block_table_cpu[:num_reqs],
                                           non_blocking=True)
diff --git a/vllm/v1/worker/gpu_input_batch.py b/vllm/v1/worker/gpu_input_batch.py
index 07187968bee7c..82e88a7dbf8df 100644
--- a/vllm/v1/worker/gpu_input_batch.py
+++ b/vllm/v1/worker/gpu_input_batch.py
@@ -436,3 +436,72 @@ class InputBatch:
     @property
     def no_prompt_logprob(self) -> bool:
         return len(self.prompt_logprob_reqs) == 0
+
+
+def swap_positions(b: InputBatch, id_1, id_2):
+    assert id_1 != id_2
+    req_id_1 = b.req_ids[id_1]
+    req_id_2 = b.req_ids[id_2]
+    assert req_id_1 is not None
+    assert req_id_2 is not None
+    assert id_1 == b.req_id_to_index[req_id_1]
+    assert id_2 == b.req_id_to_index[req_id_2]
+
+    b.req_ids[id_1], b.req_ids[id_2] = b.req_ids[id_2], b.req_ids[id_1]
+    # req_id_to_index is keyed by req_id, not by batch index
+    b.req_id_to_index[req_id_1], b.req_id_to_index[req_id_2] = \
+        b.req_id_to_index[req_id_2], b.req_id_to_index[req_id_1]
+
+    ids = [id_1, id_2]
+    rev_ids = [id_2, id_1]
+    b.num_tokens[ids] = b.num_tokens[rev_ids]
+    b.token_ids_cpu[ids] = b.token_ids_cpu[rev_ids]
+    b.num_prompt_tokens[ids] = b.num_prompt_tokens[rev_ids]
+    b.num_computed_tokens_cpu[ids] = b.num_computed_tokens_cpu[rev_ids]
+
+    b.block_table.swap_row(id_1, id_2)
+
+    b.temperature_cpu[ids] = b.temperature_cpu[rev_ids]
+    b.top_p_cpu[ids] = b.top_p_cpu[rev_ids]
+    b.top_k_cpu[ids] = b.top_k_cpu[rev_ids]
+    b.frequency_penalties_cpu[ids] = b.frequency_penalties_cpu[rev_ids]
+    b.presence_penalties_cpu[ids] = b.presence_penalties_cpu[rev_ids]
+    b.repetition_penalties_cpu[ids] = b.repetition_penalties_cpu[rev_ids]
+
+    b.min_tokens[id_1], b.min_tokens[id_2] = b.min_tokens[id_2], b.min_tokens[
+        id_1]
+    b.stop_token_ids[id_1], b.stop_token_ids[id_2] = b.stop_token_ids[
+        id_2], b.stop_token_ids[id_1]
+
+    # generators is a sparse dict keyed by batch index; swap via pop so a
+    # missing entry does not raise KeyError
+    gen_1 = b.generators.pop(id_1, None)
+    gen_2 = b.generators.pop(id_2, None)
+    if gen_1 is not None:
+        b.generators[id_2] = gen_1
+    if gen_2 is not None:
+        b.generators[id_1] = gen_2
+
+
+def ensure_decodes_first(b: InputBatch):
+    num_reqs = b.num_reqs
+    while True:
+        # Find the first prompt index
+        first_prompt_index = None
+        for i in range(num_reqs):
+            if b.num_computed_tokens_cpu[i] < b.num_prompt_tokens[i]:
+                first_prompt_index = i
+                break
+        if first_prompt_index is None:
+            break
+
+        # Find the last decode index
+        last_decode_index = None
+        for i in reversed(range(num_reqs)):
+            if b.num_computed_tokens_cpu[i] >= b.num_prompt_tokens[i]:
+                last_decode_index = i
+                break
+        if last_decode_index is None:
+            break
+
+        # Sanity check
+        assert first_prompt_index != last_decode_index
+
+        # Check if done
+        if first_prompt_index > last_decode_index:
+            break
+
+        # Swap
+        swap_positions(b, first_prompt_index, last_decode_index)
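Note: the two helpers above reorder the batch so that all decode requests
occupy the leading rows, using the minimum number of row swaps. The
invariant is easiest to eyeball on a toy version of the same loop; this is
a standalone sketch only (plain Python lists stand in for the real
InputBatch fields, and 'D'/'P' stand in for the
num_computed_tokens >= num_prompt_tokens test):

    def ensure_decodes_first_toy(kinds):
        """Toy version of the loop above: 'D' rows are decodes,
        'P' rows are prompts."""
        kinds = list(kinds)
        while True:
            first_prompt = next(
                (i for i, k in enumerate(kinds) if k == "P"), None)
            last_decode = next(
                (i for i in reversed(range(len(kinds))) if kinds[i] == "D"),
                None)
            if first_prompt is None or last_decode is None:
                break
            if first_prompt > last_decode:
                break  # already partitioned: decodes first, prompts last
            kinds[first_prompt], kinds[last_decode] = (
                kinds[last_decode], kinds[first_prompt])
        return kinds

    # Two swaps suffice; rows already in place are never touched.
    assert ensure_decodes_first_toy("PDPD") == ["D", "D", "P", "P"]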
diff --git a/vllm/v1/worker/tpu_model_runner.py b/vllm/v1/worker/tpu_model_runner.py
index fc659c864abe6..f9dab7eafbee2 100644
--- a/vllm/v1/worker/tpu_model_runner.py
+++ b/vllm/v1/worker/tpu_model_runner.py
@@ -3,6 +3,7 @@
 from dataclasses import dataclass
 from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, cast
 from unittest.mock import patch
+import numpy as np
 import torch
 import torch.distributed
 import torch.nn as nn
@@ -20,7 +21,8 @@ from vllm.v1.attention.backends.pallas import (PallasAttentionBackend,
 from vllm.v1.kv_cache_interface import FullAttentionSpec, KVCacheConfig
 from vllm.v1.outputs import ModelRunnerOutput
 from vllm.v1.utils import bind_kv_cache
-from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch
+from vllm.v1.worker.gpu_input_batch import (CachedRequestState, InputBatch,
+                                            ensure_decodes_first)
 from vllm.v1.worker.model_runner_base import ExecutionMode, ModelRunnerBase
 
 if TYPE_CHECKING:
@@ -39,22 +41,14 @@ _MAX_NUM_SAMPLES = 128
 
 
 @dataclass
-class PromptInputData:
-
-    req_ids: List
-    prompt_lens: List
-    input_tokens: List
-    input_positions: List
-    attn_metadata: List
-
-    def zipped(self):
-        return zip(self.req_ids, self.prompt_lens, self.input_tokens,
-                   self.input_positions, self.attn_metadata)
+class PromptData:
+    input_tokens: torch.Tensor
+    input_positions: torch.Tensor
+    attn_metadata: PallasMetadata
 
 
 @dataclass
-class DecodeInputData:
-    req_ids: List
+class DecodeData:
     input_tokens: Optional[torch.Tensor] = None
     input_positions: Optional[torch.Tensor] = None
     attn_metadata: Optional[PallasMetadata] = None
@@ -85,250 +79,248 @@ class TPUModelRunner(ModelRunnerBase):
         # KV caches for forward pass
         self.kv_caches: List[Tuple[torch.Tensor, torch.Tensor]] = []
 
-        # Used to initialize positions for the individual prefills
-        self.prefill_input_positions = torch.tensor(range(self.max_model_len),
-                                                    device="cpu",
-                                                    dtype=torch.int32).reshape(
-                                                        1, -1)
+        # Cache CPU-side torch/numpy scratch buffers
+        self.input_ids_cpu = torch.zeros(self.max_num_tokens,
+                                         dtype=torch.int32,
+                                         device="cpu",
+                                         pin_memory=self.pin_memory)
+        self.input_ids_np = self.input_ids_cpu.numpy()
 
-    def _prepare_prompt_inputs(
+        self.input_positions_cpu = torch.empty(self.max_model_len,
+                                               dtype=torch.int64,
+                                               device="cpu",
+                                               pin_memory=self.pin_memory)
+        self.input_positions_np = self.input_positions_cpu.numpy()
+
+        # int64, matching the slot dtype the removed code fed to the Pallas
+        # kernels (it called .long() / used torch.int64)
+        self.slot_mapping_cpu = torch.zeros(self.max_num_tokens,
+                                            dtype=torch.int64,
+                                            device="cpu",
+                                            pin_memory=self.pin_memory)
+        self.slot_mapping_np = self.slot_mapping_cpu.numpy()
+
+        self.prompt_context_lens_cpu = torch.zeros(1,
+                                                   dtype=torch.int32,
+                                                   device="cpu")
+        self.prompt_effective_query_lens = torch.zeros(1,
+                                                       dtype=torch.int32,
+                                                       device="cpu")
+
+        self.decode_context_lens_cpu = torch.zeros(self.max_model_len,
+                                                   dtype=torch.int32,
+                                                   device="cpu")
+        self.decode_context_lens_np = self.decode_context_lens_cpu.numpy()
+
+        self.arange_np = np.arange(self.max_model_len, dtype=np.int32)
+
+        self.req_ids: List[str] = []
+        self.prompt_token_ids: List[int] = []
+        self.sampled_token_ids: List[int] = []
+
+    def _get_prompts_and_decodes(
         self,
         scheduler_output: "SchedulerOutput",
-    ) -> PromptInputData:
+    ) -> Tuple[List[str], List[str], List[int]]:
         total_num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
         assert total_num_scheduled_tokens > 0
         num_reqs = self.input_batch.num_reqs
         assert num_reqs > 0
 
-        req_ids = []
-        prompt_lens = []
-        input_tokens_list = []
-        input_positions_list = []
-        attn_metadata_list = []
-        for req_id in self.input_batch.req_ids[:num_reqs]:
-            assert req_id is not None
-            req_index = self.input_batch.req_id_to_index[req_id]
-            req_state = self.requests[req_id]
+        # Traverse decodes first
+        decode_req_ids = []
+        for i in range(num_reqs):
+            req_id = self.input_batch.req_ids[i]
 
+            num_computed_tokens = self.input_batch.num_computed_tokens_cpu[i]
+            num_prompt_tokens = self.input_batch.num_prompt_tokens[i]
             num_scheduled_tokens = scheduler_output.num_scheduled_tokens[
                 req_id]
-            num_computed_tokens = req_state.num_computed_tokens
-            num_prompt_tokens = len(req_state.prompt_token_ids)
 
-            # Detect whether this is a prompt (can be full or chunked)
-            if num_computed_tokens >= num_prompt_tokens:
-                # This is a decode => Skip
-                continue
+            if num_computed_tokens < num_prompt_tokens:
+                # This is a prompt
+                break
 
-            # This is a prompt
-            req_ids.append(req_id)
+            # This is a decode
+            assert num_scheduled_tokens == 1
+            decode_req_ids.append(req_id)
 
-            # Prompt len
-            prompt_len = num_scheduled_tokens
-            prompt_lens.append(prompt_len)
-            padded_prompt_len = _get_padded_prefill_len(prompt_len)
-            assert padded_prompt_len <= self.max_model_len
+        # Traverse prompts
+        prompt_req_ids = []
+        prompt_scheduled_tokens = []
+        for i in range(len(decode_req_ids), num_reqs):
+            req_id = self.input_batch.req_ids[i]
 
-            # Seq len
-            seq_len = num_computed_tokens + prompt_len
+            num_computed_tokens = self.input_batch.num_computed_tokens_cpu[i]
+            num_prompt_tokens = self.input_batch.num_prompt_tokens[i]
+            num_scheduled_tokens = scheduler_output.num_scheduled_tokens[
+                req_id]
 
-            # Input tokens
-            input_tokens = torch.zeros((1, padded_prompt_len),
-                                       dtype=torch.int32,
-                                       device="cpu")
-            input_tokens[:, :prompt_len] = torch.from_numpy(
-                self.input_batch.token_ids_cpu[req_index,
-                                               num_computed_tokens:seq_len])
-            # input_tokens = torch.from_numpy(self.input_batch.token_ids_cpu[
-            #     req_index, num_computed_tokens:padded_seq_len].reshape(1, -1))
-            # input_tokens[:, prompt_len:] = 0
-            input_tokens_list.append(input_tokens.to(self.device))
+            # Must be a prompt
+            assert num_computed_tokens < num_prompt_tokens
 
-            # Input positions
-            input_positions = torch.zeros((1, padded_prompt_len),
-                                          dtype=torch.int32,
-                                          device="cpu")
-            input_positions[:, :
-                            prompt_len] = self.prefill_input_positions[:,
-                                                                       num_computed_tokens:
-                                                                       seq_len]
-            # input_positions[:, prompt_len:] = 0
-            input_positions_list.append(input_positions.to(self.device))
+            prompt_scheduled_tokens.append(num_scheduled_tokens)
+            prompt_req_ids.append(req_id)
 
-            # Slot mapping
-            block_table_cpu_tensor = \
-                self.input_batch.block_table.get_cpu_tensor()
-            block_numbers = block_table_cpu_tensor[req_index,
-                                                   input_positions //
-                                                   self.block_size].reshape(
-                                                       1, -1)
+        return prompt_req_ids, decode_req_ids, prompt_scheduled_tokens
 
-            block_offsets = input_positions % self.block_size
-            slot_mapping = block_numbers * self.block_size + block_offsets
-            slot_mapping[:, prompt_len:] = _PAD_SLOT_ID
-            slot_mapping = slot_mapping.long()
+    def _prepare_prompt(self, req_index: int,
+                        num_scheduled_tokens: int) -> PromptData:
+        num_computed_tokens = self.input_batch.num_computed_tokens_cpu[
+            req_index]
+        num_prompt_tokens = self.input_batch.num_prompt_tokens[req_index]
 
-            # Block table
-            block_table = None
-            if num_computed_tokens > 0:
-                block_table = block_table_cpu_tensor[req_index].unsqueeze(0)
-                block_table = block_table.to(self.device)
+        # Must be a prompt
+        assert num_computed_tokens < num_prompt_tokens
 
-            # Context len
-            context_len = 0
-            if num_computed_tokens > 0:
-                context_len = seq_len
-            context_lens = torch.tensor([context_len],
-                                        dtype=torch.int32,
-                                        device="cpu")
+        # Prompt len
+        prompt_len = num_scheduled_tokens
+        padded_prompt_len = _get_padded_prompt_len(prompt_len)
+        assert padded_prompt_len <= self.max_model_len
 
-            # Effective query len
-            effective_query_lens = torch.tensor([prompt_len],
-                                                dtype=torch.int32,
-                                                device="cpu")
+        # Seq len
+        seq_len = num_computed_tokens + prompt_len
 
-            # Attn metadata
-            attn_metadata_list.append(
-                PallasMetadata(
-                    num_prefills=1,
-                    num_prefill_tokens=0,  # NOTE: This is not used.
-                    num_decode_tokens=0,
-                    slot_mapping=slot_mapping.to(self.device),
-                    multi_modal_placeholder_index_maps=None,
-                    enable_kv_scales_calculation=True,
-                    block_tables=block_table,
-                    context_lens=context_lens.to(self.device),
-                    effective_query_lens=effective_query_lens.to(self.device),
-                ))
+        # Input tokens. Copy into the scratch buffer instead of zeroing a
+        # view of token_ids_cpu_tensor: writing zeros through the view would
+        # clobber the not-yet-computed prompt tokens of a chunked prefill.
+        input_tokens_cpu = self.input_ids_cpu[:padded_prompt_len]
+        input_tokens_cpu[:prompt_len] = self.input_batch.token_ids_cpu_tensor[
+            req_index, num_computed_tokens:seq_len]
+        input_tokens_cpu[prompt_len:] = 0
 
-            # TODO: Remove this
-            # if num_computed_tokens > 0:
-            #     print("-------------------")
-            #     print("input_tokens.shape = {}".format(input_tokens.shape))
-            #     print("input_positions.shape = {}".format(
-            #         input_positions.shape))
-            #     print("slot_mapping.shape = {}".format(slot_mapping.shape))
-            #     print("block_table.shape = {}".format(block_table.shape))
-            #     print("context_lens.shape = {} data = {}".format(
-            #         context_lens.shape, context_lens))
-            #     print("effective_query_lens.shape = {} data = {}".format(
-            #         effective_query_lens.shape, effective_query_lens))
+        # Input positions
+        input_positions_np = self.input_positions_np[:padded_prompt_len]
+        np.add(num_computed_tokens,
+               self.arange_np[:padded_prompt_len],
+               out=input_positions_np)
+        input_positions_np[prompt_len:] = 0
 
-        return PromptInputData(
-            req_ids=req_ids,
-            prompt_lens=prompt_lens,
-            input_tokens=input_tokens_list,
-            input_positions=input_positions_list,
-            attn_metadata=attn_metadata_list,
+        # Slot mapping
+        block_table_np = \
+            self.input_batch.block_table.get_numpy_array()
+        block_numbers_np = block_table_np[req_index, input_positions_np //
+                                          self.block_size]
+        block_offsets_np = input_positions_np % self.block_size
+
+        slot_mapping_np = self.slot_mapping_np[:padded_prompt_len]
+        np.add(block_numbers_np * self.block_size,
+               block_offsets_np,
+               out=slot_mapping_np)
+        # slot_mapping_np is 1-D here, so no leading [:, ...] slice
+        slot_mapping_np[prompt_len:] = _PAD_SLOT_ID
+
+        # Block table
+        block_table_cpu = None
+        if num_computed_tokens > 0:
+            block_table_cpu = self.input_batch.block_table.get_cpu_tensor()
+            block_table_cpu = block_table_cpu[req_index]
+
+        # Context len
+        self.prompt_context_lens_cpu[0] = 0
+        if num_computed_tokens > 0:
+            self.prompt_context_lens_cpu[0] = seq_len
+
+        # Effective query len
+        self.prompt_effective_query_lens[0] = prompt_len
+
+        # Get final tensors
+        input_tokens = input_tokens_cpu.reshape(1, -1).to(self.device)
+        input_positions = self.input_positions_cpu[:padded_prompt_len].reshape(
+            1, -1).to(self.device)
+        slot_mapping = self.slot_mapping_cpu[:padded_prompt_len].reshape(
+            1, -1).to(self.device)
+        block_table = block_table_cpu.reshape(1, -1).to(
+            self.device) if block_table_cpu is not None else None
+
+        context_lens = self.prompt_context_lens_cpu.reshape(
+            1, -1).to(self.device)
+        effective_query_lens = self.prompt_effective_query_lens.reshape(
+            1, -1).to(self.device)
+
+        # Attn metadata
+        attn_metadata = PallasMetadata(
+            num_prefills=1,
+            num_prefill_tokens=0,  # NOTE: This is not used.
+            num_decode_tokens=0,
+            slot_mapping=slot_mapping,
+            multi_modal_placeholder_index_maps=None,
+            enable_kv_scales_calculation=True,
+            block_tables=block_table,
+            context_lens=context_lens,
+            effective_query_lens=effective_query_lens,
         )
 
-    def _prepare_decode_inputs(
+        return PromptData(input_tokens, input_positions, attn_metadata)
+
+    def _prepare_decode(
         self,
-        scheduler_output: "SchedulerOutput",
-    ) -> DecodeInputData:
-        total_num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
-        assert total_num_scheduled_tokens > 0
-        num_reqs = self.input_batch.num_reqs
-        assert num_reqs > 0
-
-        block_table_cpu_tensor = self.input_batch.block_table.get_cpu_tensor()
-
-        req_ids = []
-        req_indices = []
-        input_tokens = []
-        input_positions = []
-        slot_mapping = []
-        context_lens = []
-        for req_id in self.input_batch.req_ids[:num_reqs]:
-            assert req_id is not None
-            req_index = self.input_batch.req_id_to_index[req_id]
-            req_state = self.requests[req_id]
-
-            num_scheduled_tokens = scheduler_output.num_scheduled_tokens[
-                req_id]
-            num_computed_tokens = req_state.num_computed_tokens
-            num_prompt_tokens = len(req_state.prompt_token_ids)
-
-            # Detect whether this is a decode
-            if num_computed_tokens < num_prompt_tokens:
-                # This is a prompt => Skip
-                continue
-
-            # This is a decode
-            req_ids.append(req_id)
-            req_indices.append(req_index)
-
-            # Seq len
-            seq_len = num_computed_tokens + num_scheduled_tokens
-
-            # Sanity check decode
-            assert num_scheduled_tokens == 1
-            assert seq_len == req_state.num_tokens
-
-            # Input token
-            input_tokens.append([
-                self.input_batch.token_ids_cpu[req_index, num_computed_tokens]
-            ])
-
-            # Position
-            input_positions.append([num_computed_tokens])
-
-            # Slot mapping
-            block_number = block_table_cpu_tensor[req_index,
-                                                  num_computed_tokens //
-                                                  self.block_size]
-            block_offset = num_computed_tokens % self.block_size
-            slot_id = block_number * self.block_size + block_offset
-            slot_mapping.append([slot_id])
-
-            # Context len
-            context_lens.append(seq_len)
-
-        # Compute padding
-        batch_size = len(input_tokens)
+        decode_req_ids: List[str],
+    ) -> DecodeData:
+        # Batch size
+        batch_size = len(decode_req_ids)
         padded_batch_size = _get_padded_batch_size(batch_size)
-        num_padding = padded_batch_size - batch_size
+        assert padded_batch_size <= self.max_model_len
 
-        # Add padding
-        input_tokens.extend([[0]] * num_padding)
-        input_positions.extend([[0]] * num_padding)
-        slot_mapping.extend([[_PAD_SLOT_ID]] * num_padding)
-        context_lens.extend([0] * num_padding)
-        req_indices.extend([0] * num_padding)
+        # Input positions
+        input_positions_np = self.input_positions_np[:padded_batch_size]
+        np.add(self.input_batch.num_computed_tokens_cpu[:padded_batch_size],
+               0,
+               out=input_positions_np)
+        input_positions_np[batch_size:] = 0
+        input_positions_cpu = torch.from_numpy(input_positions_np)
 
-        # Create tensors
-        input_tokens_tensor = torch.tensor(input_tokens,
-                                           dtype=torch.int32,
-                                           device="cpu")
-        input_positions_tensor = torch.tensor(input_positions,
-                                              dtype=torch.int32,
-                                              device="cpu")
-        slot_mapping_tensor = torch.tensor(slot_mapping,
-                                           dtype=torch.int64,
-                                           device="cpu")
-        context_lens_tensor = torch.tensor(context_lens,
-                                           dtype=torch.int32,
-                                           device="cpu")
-        block_tables_tensor = block_table_cpu_tensor[req_indices]
+        # Input tokens: gather one token per decode from each request's row
+        # of the token buffer via a flattened index_select (a dim=1
+        # index_select would select whole columns across all requests).
+        # Padding rows reuse row 0, like the removed list-based code did,
+        # and are masked out below.
+        req_indices_np = self.arange_np[:padded_batch_size].copy()
+        req_indices_np[batch_size:] = 0
+        token_indices_np = (
+            input_positions_np +
+            req_indices_np * self.input_batch.token_ids_cpu.shape[1])
+        input_tokens_cpu = self.input_ids_cpu[:padded_batch_size]
+        torch.index_select(self.input_batch.token_ids_cpu_tensor.flatten(),
+                           0,
+                           torch.from_numpy(token_indices_np),
+                           out=input_tokens_cpu)
+        input_tokens_cpu[batch_size:] = 0
+
+        # Slot mapping
+        block_table_np = self.input_batch.block_table.get_numpy_array()
+        block_numbers_np = block_table_np[req_indices_np,
+                                          input_positions_np //
+                                          self.block_size]
+
+        block_offsets_np = input_positions_np % self.block_size
+
+        slot_mapping_np = self.slot_mapping_np[:padded_batch_size]
+        np.add(block_numbers_np * self.block_size,
+               block_offsets_np,
+               out=slot_mapping_np)
+        # slot_mapping_np is 1-D here, so no leading [:, ...] slice
+        slot_mapping_np[batch_size:] = _PAD_SLOT_ID
+
+        # Block tables: padded to the padded batch size so the leading
+        # dimension matches num_decode_tokens (the removed code achieved
+        # this via the padded req_indices list)
+        block_table_cpu = self.input_batch.block_table.get_cpu_tensor()
+        block_table_cpu = block_table_cpu[:padded_batch_size]
+
+        # Context lens
+        context_lens_np = self.decode_context_lens_np[:padded_batch_size]
+        np.add(self.input_batch.num_computed_tokens_cpu[:padded_batch_size],
+               1,
+               out=context_lens_np)
+        context_lens_np[batch_size:] = 0
+
+        # Get final tensors ([batch, 1] shapes, matching the removed
+        # list-based code)
+        input_tokens = input_tokens_cpu.reshape(-1, 1).to(self.device)
+        input_positions = input_positions_cpu.reshape(-1, 1).to(self.device)
+        slot_mapping = self.slot_mapping_cpu[:padded_batch_size].reshape(
+            -1, 1).to(self.device)
+        block_table = block_table_cpu.to(self.device)
+        context_lens = self.decode_context_lens_cpu[:padded_batch_size].to(
+            self.device)
 
         # Attn metadata
         attn_metadata = PallasMetadata(
             num_prefills=0,
             num_prefill_tokens=0,
             num_decode_tokens=padded_batch_size,
-            slot_mapping=slot_mapping_tensor.to(self.device),
+            slot_mapping=slot_mapping,
             multi_modal_placeholder_index_maps=None,
             enable_kv_scales_calculation=True,
-            block_tables=block_tables_tensor.to(self.device),
-            context_lens=context_lens_tensor.to(self.device),
+            block_tables=block_table,
+            context_lens=context_lens,
             effective_query_lens=None,
         )
 
-        return DecodeInputData(
-            req_ids=req_ids,
-            input_tokens=input_tokens_tensor.to(self.device),
-            input_positions=input_positions_tensor.to(self.device),
-            attn_metadata=attn_metadata)
+        return DecodeData(input_tokens=input_tokens,
+                          input_positions=input_positions,
+                          attn_metadata=attn_metadata)
 
     @torch.no_grad()
     def execute_model(
@@ -338,17 +330,67 @@ class TPUModelRunner(ModelRunnerBase):
         # Update cached state
         self.update_states(scheduler_output)
 
-        # Prepare inputs
-        prompt_data = self._prepare_prompt_inputs(scheduler_output)
-        decode_data = self._prepare_decode_inputs(scheduler_output)
+        # If necessary, swap decodes/prompts so all decodes are at the start
+        ensure_decodes_first(self.input_batch)
+
+        # Prepare prompts/decodes info
+        prompt_req_ids, decode_req_ids, prompt_scheduled_tokens = \
+            self._get_prompts_and_decodes(scheduler_output)
 
         # Init
-        num_reqs = self.input_batch.num_reqs
-        assert num_reqs > 0
-        sampled_token_ids_list = [0] * num_reqs
+        decode_token_ids = None
+        decode_data = None
+        self.req_ids.clear()
+        self.prompt_token_ids.clear()
+        self.sampled_token_ids.clear()
+
+        # Run each prompt individually
+        is_first = True
+        for i, req_id in enumerate(prompt_req_ids):
+            # Prompts follow the decodes in the batch (ensure_decodes_first)
+            req_index = len(decode_req_ids) + i
+            req_state = self.requests[req_id]
+            num_scheduled_tokens = prompt_scheduled_tokens[i]
+            seq_len = req_state.num_computed_tokens + num_scheduled_tokens
+            prompt_len = num_scheduled_tokens
+
+            # Prepare the first prompt here; later ones are prepared below,
+            # overlapped with the previous forward pass
+            if is_first:
+                prompt_data = self._prepare_prompt(req_index,
+                                                   num_scheduled_tokens)
+                is_first = False
+
+            # Run forward pass
+            with set_forward_context(prompt_data.attn_metadata,
+                                     self.vllm_config):
+                assert self.model is not None
+                selected_token_ids = self.model(prompt_data.input_tokens,
+                                                prompt_data.input_positions,
+                                                prompt_data.attn_metadata,
+                                                self.kv_caches)
+
+            # While the TPU executes, prepare the next iteration on the CPU
+            if i < len(prompt_req_ids) - 1:
+                prompt_data = self._prepare_prompt(
+                    req_index + 1, prompt_scheduled_tokens[i + 1])
+            elif i == len(prompt_req_ids) - 1 and len(decode_req_ids) > 0:
+                decode_data = self._prepare_decode(decode_req_ids)
+
+            # Update cached state
+            if seq_len >= len(req_state.prompt_token_ids):
+                # Transfer sampled tokens from TPU to CPU
+                token_id = selected_token_ids.cpu()[prompt_len - 1].item()
+                self.prompt_token_ids.append(token_id)
+
+                self.input_batch.token_ids_cpu[req_index, seq_len] = token_id
+                self.input_batch.num_tokens[req_index] += 1
+                req_state.output_token_ids.append(token_id)
+            else:
+                # Chunked prefill that has not reached the last prompt token
+                # yet: nothing was sampled. Keep a placeholder entry so that
+                # sampled_token_ids stays aligned with req_ids below.
+                self.prompt_token_ids.append(0)
 
         # Run decodes (a single batch)
-        if len(decode_data.req_ids) > 0:
+        if len(decode_req_ids) > 0:
+            # Prepared above unless there were no prompts in this step
+            if decode_data is None:
+                decode_data = self._prepare_decode(decode_req_ids)
+
             # Forward
             with set_forward_context(decode_data.attn_metadata,
                                      self.vllm_config):
@@ -359,59 +401,33 @@ class TPUModelRunner(ModelRunnerBase):
                                                 self.kv_caches)
 
             # Transfer sampled tokens from TPU to CPU
-            selected_token_ids_list = selected_token_ids.cpu().tolist()
+            decode_token_ids = selected_token_ids.cpu().tolist()
 
             # Update cached state
-            for i, req_id in enumerate(decode_data.req_ids):
-                req_index = self.input_batch.req_id_to_index[req_id]
+            for i, req_id in enumerate(decode_req_ids):
+                # Decodes occupy the leading batch rows
+                req_index = i
                 req_state = self.requests[req_id]
+                seq_len = req_state.num_computed_tokens + 1
 
-                seq_len = (req_state.num_computed_tokens +
-                           scheduler_output.num_scheduled_tokens[req_id])
-
-                token_id = selected_token_ids_list[i]
+                token_id = decode_token_ids[i]
                 self.input_batch.token_ids_cpu[req_index, seq_len] = token_id
                 self.input_batch.num_tokens[req_index] += 1
                 req_state.output_token_ids.append(token_id)
 
-                sampled_token_ids_list[req_index] = token_id
-
-        # Run each prompt
-        for (req_id, prompt_len, input_tokens, input_positions,
-             attn_metadata) in prompt_data.zipped():
-            assert req_id is not None
-            req_state = self.requests[req_id]
-            req_index = self.input_batch.req_id_to_index[req_id]
-
-            # Forward
-            with set_forward_context(attn_metadata, self.vllm_config):
-                assert self.model is not None
-                selected_token_ids = self.model(input_tokens, input_positions,
-                                                attn_metadata, self.kv_caches)
-
-            seq_len = (req_state.num_computed_tokens +
-                       scheduler_output.num_scheduled_tokens[req_id])
-            if seq_len >= len(req_state.prompt_token_ids):
-                # Transfer sampled tokens from TPU to CPU
-                token_id = selected_token_ids.cpu()[prompt_len - 1].item()
-                sampled_token_ids_list[req_index] = token_id
-
-                # Update cached state
-                self.input_batch.token_ids_cpu[req_index, seq_len] = token_id
-                self.input_batch.num_tokens[req_index] += 1
-                req_state.output_token_ids.append(token_id)
-
-        # Get req_ids
-        assert all(
-            req_id is not None for req_id in
-            self.input_batch.req_ids[:num_reqs]), "req_ids contains None"
-        req_ids = cast(List[str], self.input_batch.req_ids[:num_reqs])
+        # Create the final req_id => token lists.
+        # This must match the actual batch index positions:
+        # decodes first, then prompts (see ensure_decodes_first).
+        self.req_ids.extend(decode_req_ids)
+        self.req_ids.extend(prompt_req_ids)
+        if decode_token_ids is not None:
+            # Drop the padded tail so the list stays aligned with req_ids
+            self.sampled_token_ids.extend(
+                decode_token_ids[:len(decode_req_ids)])
+        self.sampled_token_ids.extend(self.prompt_token_ids)
 
+        # Create output
         model_runner_output = ModelRunnerOutput(
-            req_ids=req_ids,
+            req_ids=self.req_ids,
             req_id_to_index=self.input_batch.req_id_to_index,
-            sampled_token_ids=sampled_token_ids_list,
+            sampled_token_ids=self.sampled_token_ids,
             logprob_token_ids_cpu=None,
             logprobs_cpu=None,
         )
@@ -710,7 +726,7 @@ class ModelWrapperV1(nn.Module):
         return argmax_token_ids
 
 
-def _get_padded_prefill_len(x: int) -> int:
+def _get_padded_prompt_len(x: int) -> int:
     # NOTE(woosuk): The pallas FlashAttention kernel requires the sequence
     # length to be a multiple of 16. We pad the prompt length to the nearest
     # multiple of 16. This is also good for performance.
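Note: the padding and slot-mapping arithmetic used throughout this patch
can be sanity-checked standalone. The body of _get_padded_prompt_len is not
part of this diff (only the rename is), so the multiple-of-16 rule below is
taken from the NOTE comment above and is an assumption; block sizes and
block ids are toy values:

    import numpy as np

    # Assumed padding rule, per the multiple-of-16 NOTE above.
    def padded_prompt_len(x: int) -> int:
        return ((x + 15) // 16) * 16

    # Slot-mapping arithmetic shared by _prepare_prompt/_prepare_decode:
    # the token at `position` lives in block `position // block_size` of
    # the request's block-table row, at offset `position % block_size`.
    block_size = 16
    block_table_row = np.array([7, 3, 9], dtype=np.int32)  # toy block ids
    positions = np.arange(40, dtype=np.int32)

    block_numbers = block_table_row[positions // block_size]
    block_offsets = positions % block_size
    slot_mapping = block_numbers * block_size + block_offsets

    assert padded_prompt_len(40) == 48
    assert slot_mapping[0] == 7 * 16       # position 0: block 7, offset 0
    assert slot_mapping[17] == 3 * 16 + 1  # position 17: block 3, offset 1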