diff --git a/vllm/v1/worker/gpu_input_batch.py b/vllm/v1/worker/gpu_input_batch.py index 82e88a7dbf8df..b9577f85dc8cf 100644 --- a/vllm/v1/worker/gpu_input_batch.py +++ b/vllm/v1/worker/gpu_input_batch.py @@ -504,4 +504,6 @@ def ensure_decodes_first(b: InputBatch): break # Swap + print("Swapping first_prompt_index = {} with last_decode_index = {}". + format(first_prompt_index, last_decode_index)) swap_positions(b, first_prompt_index, last_decode_index) diff --git a/vllm/v1/worker/tpu_model_runner.py b/vllm/v1/worker/tpu_model_runner.py index 3fd235d4a7acc..9a8c217ce7b79 100644 --- a/vllm/v1/worker/tpu_model_runner.py +++ b/vllm/v1/worker/tpu_model_runner.py @@ -218,10 +218,9 @@ class TPUModelRunner(ModelRunnerBase): block_table = block_table_cpu.reshape(1, -1).to( self.device) if block_table_cpu is not None else None - context_lens = self.prompt_context_lens_cpu.reshape(1, - -1).to(self.device) - effective_query_lens = self.prompt_effective_query_lens_cpu.reshape( - 1, -1).to(self.device) + context_lens = self.prompt_context_lens_cpu.to(self.device) + effective_query_lens = self.prompt_effective_query_lens_cpu.to( + self.device) # Attn metadata attn_metadata = PallasMetadata( @@ -247,6 +246,15 @@ class TPUModelRunner(ModelRunnerBase): padded_batch_size = _get_padded_batch_size(batch_size) assert padded_batch_size <= self.max_model_len + # Init [0 .. batch_size - 1] + req_indices_np = self.arange_np[:padded_batch_size] + + print("_prepare_decode:") + print(" batch_size = {}".format(batch_size)) + print(" padded_batch_size = {}".format(padded_batch_size)) + print(" req_indices_np.shape = {} val = {}".format( + req_indices_np.shape, req_indices_np)) + # Input positions input_positions_np = self.input_positions_np[:padded_batch_size] np.add(self.input_batch.num_computed_tokens_cpu[:padded_batch_size], @@ -255,29 +263,61 @@ class TPUModelRunner(ModelRunnerBase): input_positions_np[batch_size:] = 0 input_positions_cpu = self.input_positions_cpu[:padded_batch_size] + print(" input_positions_cpu.shape = {} data = {}".format( + input_positions_cpu.shape, input_positions_cpu)) + # Input tokens + token_indices_np = ( + input_positions_np + + req_indices_np * self.input_batch.token_ids_cpu.shape[1]) input_tokens_cpu = self.input_ids_cpu[:padded_batch_size] - torch.index_select(self.input_batch.token_ids_cpu_tensor, - 1, - input_positions_cpu, + torch.index_select(self.input_batch.token_ids_cpu_tensor.flatten(), + 0, + torch.from_numpy(token_indices_np), out=input_tokens_cpu) - input_tokens_cpu[:batch_size] = 0 + input_tokens_cpu[batch_size:] = 0 + + print(" token_indices_np.shape = {} val = {}".format( + token_indices_np.shape, token_indices_np)) + + print(" input_tokens_cpu.shape = {} data = {}".format( + input_tokens_cpu.shape, input_tokens_cpu)) # Slot mapping + block_table_indices_np = ( + req_indices_np * self.max_num_blocks_per_req + + input_positions_np // self.block_size) + + print( + " block_table_indices_np.shape = {} data = {} max_num_blocks_per_req = {}" + .format(block_table_indices_np.shape, block_table_indices_np, + self.max_num_blocks_per_req)) block_table_cpu = self.input_batch.block_table.get_cpu_tensor() - block_numbers_cpu = torch.index_select( - block_table_cpu, 1, input_positions_cpu // self.block_size) - block_numbers_np = block_numbers_cpu.numpy() + + print(" block_table_cpu.shape = {} data = {}".format( + block_table_cpu.shape, block_table_cpu[:padded_batch_size, :10])) + + block_numbers_np = block_table_cpu.flatten( + )[block_table_indices_np].numpy() + + print(" block_numbers_np.shape = {} data = {}".format( + block_numbers_np.shape, block_numbers_np)) block_offsets_np = input_positions_np % self.block_size + print(" block_offsets_np.shape = {} data = {}".format( + block_offsets_np.shape, block_offsets_np)) + slot_mapping_np = self.slot_mapping_np[:padded_batch_size] np.add(block_numbers_np * self.block_size, block_offsets_np, out=slot_mapping_np) - slot_mapping_np[:, batch_size:] = _PAD_SLOT_ID + slot_mapping_np[batch_size:] = _PAD_SLOT_ID - block_table_cpu = block_table_cpu[:len(decode_req_ids)] + print(" slot_mapping_np.shape = {} data = {}".format( + slot_mapping_np.shape, slot_mapping_np)) + + block_table_cpu = block_table_cpu[:padded_batch_size] # Context lens context_lens_np = self.decode_context_lens_np[:padded_batch_size] @@ -287,14 +327,17 @@ class TPUModelRunner(ModelRunnerBase): context_lens_np[batch_size:] = 0 # Get final tensors - input_tokens = input_tokens_cpu.to(self.device) - input_positions = input_positions_cpu.to(self.device) - slot_mapping = self.slot_mapping_cpu[:padded_batch_size].to( - self.device) + input_tokens = input_tokens_cpu.reshape(-1, 1).to(self.device) + input_positions = input_positions_cpu.reshape(-1, 1).to(self.device) + slot_mapping = self.slot_mapping_cpu[:padded_batch_size].reshape( + -1, 1).to(self.device) block_table = block_table_cpu.to(self.device) context_lens = self.decode_context_lens_cpu[:padded_batch_size].to( self.device) + print(" context_lens.shape = {} val = {}".format( + context_lens.shape, context_lens)) + # Attn metadata attn_metadata = PallasMetadata( num_prefills=0,