From 2b0526fa15fcc00c2a0e2d2b1aa9cb1c43a7b2db Mon Sep 17 00:00:00 2001
From: Alexander Matveev
Date: Wed, 5 Feb 2025 16:54:57 +0000
Subject: [PATCH] works!

---
 vllm/v1/worker/tpu_model_runner.py | 70 +++++++++++++++++++++---------
 1 file changed, 50 insertions(+), 20 deletions(-)

diff --git a/vllm/v1/worker/tpu_model_runner.py b/vllm/v1/worker/tpu_model_runner.py
index 9a8c217ce7b79..9cd117baf318c 100644
--- a/vllm/v1/worker/tpu_model_runner.py
+++ b/vllm/v1/worker/tpu_model_runner.py
@@ -100,11 +100,6 @@ class TPUModelRunner(ModelRunnerBase):
         # Used to initialize positions / context_lens / seq_lens
         self.arange_np = np.arange(self.max_num_tokens, dtype=np.int32)
 
-        # Cached lists
-        self.req_ids = []
-        self.prompt_token_ids = []
-        self.sampled_token_ids = []
-
     def _get_prompts_and_decodes(
         self,
         scheduler_output: "SchedulerOutput",
@@ -170,11 +165,22 @@ class TPUModelRunner(ModelRunnerBase):
         seq_len = num_computed_tokens + prompt_len
         padded_seq_len = num_computed_tokens + padded_prompt_len
 
+        print("_prepare_prompt:")
+        print("  prompt_len = {}".format(prompt_len))
+        print("  padded_prompt_len = {}".format(padded_prompt_len))
+        print("  num_computed_tokens = {}".format(num_computed_tokens))
+        print("  num_prompt_tokens = {}".format(num_prompt_tokens))
+        print("  seq_len = {}".format(seq_len))
+        print("  padded_seq_len = {}".format(padded_seq_len))
+
         # Input tokens
         input_tokens_cpu = self.input_batch.token_ids_cpu_tensor[
             req_index, num_computed_tokens:padded_seq_len]
         input_tokens_cpu[prompt_len:] = 0
 
+        print("  input_tokens_cpu.shape = {} val = {}".format(
+            input_tokens_cpu.shape, input_tokens_cpu))
+
         # Input positions
         input_positions_np = self.input_positions_np[:padded_prompt_len]
         np.add(num_computed_tokens,
@@ -182,6 +188,9 @@ class TPUModelRunner(ModelRunnerBase):
                out=input_positions_np)
         input_positions_np[prompt_len:] = 0
 
+        print("  input_positions_np.shape = {} val = {}".format(
+            input_positions_np.shape, input_positions_np))
+
         # Slot mapping
         block_table_np = \
             self.input_batch.block_table.get_numpy_array()
@@ -195,12 +204,17 @@ class TPUModelRunner(ModelRunnerBase):
                out=slot_mapping_np)
         slot_mapping_np[prompt_len:] = _PAD_SLOT_ID
 
+        print("  slot_mapping_np.shape = {} val = {}".format(
+            slot_mapping_np.shape, slot_mapping_np))
+
         # Block table
         block_table_cpu = None
         if num_computed_tokens > 0:
             block_table_cpu = self.input_batch.block_table.get_cpu_tensor()
             block_table_cpu = block_table_cpu[req_index]
 
+        print("  block_table_cpu = {}".format(block_table_cpu))
+
         # Context len
         self.prompt_context_lens_cpu[0] = 0
         if num_computed_tokens > 0:
@@ -222,6 +236,18 @@ class TPUModelRunner(ModelRunnerBase):
         effective_query_lens = self.prompt_effective_query_lens_cpu.to(
             self.device)
 
+        print("  input_tokens.shape = {} val = {}".format(
+            input_tokens.shape, input_tokens))
+        print("  input_positions.shape = {} val = {}".format(
+            input_positions.shape, input_positions))
+        print("  slot_mapping.shape = {} val = {}".format(
+            slot_mapping.shape, slot_mapping))
+        print("  block_table = {}".format(block_table))
+        print("  context_lens.shape = {} val = {}".format(
+            context_lens.shape, context_lens))
+        print("  effective_query_lens.shape = {} val = {}".format(
+            effective_query_lens.shape, effective_query_lens))
+
         # Attn metadata
         attn_metadata = PallasMetadata(
             num_prefills=1,
@@ -372,17 +398,18 @@ class TPUModelRunner(ModelRunnerBase):
         # Init
         num_prompts = len(pd_info.prompt_req_ids)
         num_decodes = len(pd_info.decode_req_ids)
-        decode_token_ids_list = None
         decode_data = None
-        self.req_ids.clear()
-        self.prompt_token_ids.clear()
-        self.sampled_token_ids.clear()
+        prompt_sampled_token_ids = []
+        decode_sampled_token_ids = []
+        sampled_token_ids = []
 
         # Run each prompt individually
         is_first = True
         for i in range(num_prompts):
             req_id = pd_info.prompt_req_ids[i]
             req_index = num_decodes + i
+            assert req_index == self.input_batch.req_id_to_index[
+                req_id]  # TODO: Remove
             req_state = self.requests[req_id]
             num_scheduled_tokens = pd_info.prompt_scheduled_tokens[i]
             prompt_len = num_scheduled_tokens
@@ -419,8 +446,12 @@ class TPUModelRunner(ModelRunnerBase):
 
             # Get output token
             token_id = selected_token_ids_cpu[prompt_len - 1].item()
-            self.prompt_token_ids.append(token_id)
+            prompt_sampled_token_ids.append(token_id)
+            print(
+                "  -- Got token_id = {} for prompt_len = {} req_id = {} req_index = {} selected_token_ids_cpu = {}"
+                .format(token_id, prompt_len, req_id, req_index,
+                        selected_token_ids_cpu))
 
             # Add output token to the request
             self.input_batch.token_ids_cpu[req_index, seq_len] = token_id
             self.input_batch.num_tokens[req_index] += 1
@@ -451,29 +482,28 @@ class TPUModelRunner(ModelRunnerBase):
 
         for i in range(num_decodes):
             req_id = pd_info.decode_req_ids[i]
             req_index = i
+            assert req_index == self.input_batch.req_id_to_index[
+                req_id]  # TODO: Remove
             req_state = self.requests[req_id]
             seq_len = req_state.num_computed_tokens + 1
 
             token_id = decode_token_ids_list[i]
+            decode_sampled_token_ids.append(token_id)
             self.input_batch.token_ids_cpu[req_index, seq_len] = token_id
             self.input_batch.num_tokens[req_index] += 1
             req_state.output_token_ids.append(token_id)
 
-        # Create final req_id => token lists.
-        # This must match the actual batch index positions,
-        # so we put decodes first and then prompts.
-        self.req_ids.extend(pd_info.decode_req_ids)
-        self.req_ids.extend(pd_info.prompt_req_ids)
-        if decode_token_ids_list is not None:
-            self.sampled_token_ids.extend(decode_token_ids_list)
-        self.sampled_token_ids.extend(self.prompt_token_ids)
+        # Create the final sampled token id list. This must match the actual
+        # batch index positions, so we put decodes first and then prompts.
+        sampled_token_ids.extend(decode_sampled_token_ids)
+        sampled_token_ids.extend(prompt_sampled_token_ids)
 
         # Create output
         model_runner_output = ModelRunnerOutput(
-            req_ids=self.req_ids,
+            req_ids=self.input_batch.req_ids,
             req_id_to_index=self.input_batch.req_id_to_index,
-            sampled_token_ids=self.sampled_token_ids,
+            sampled_token_ids=sampled_token_ids,
             logprob_token_ids_cpu=None,
             logprobs_cpu=None,
         )
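
The bookkeeping change above hinges on one invariant: the input batch places decode requests at indices [0, num_decodes) and prompt requests after them, so the final sampled_token_ids list must be assembled decodes-first for the indices in req_id_to_index to line up. Below is a minimal standalone sketch of that invariant, not the vLLM code itself; the helper name assemble_sampled_token_ids and the request ids are hypothetical, introduced only for illustration.

    from typing import Dict, List


    def assemble_sampled_token_ids(
            decode_sampled_token_ids: List[int],
            prompt_sampled_token_ids: List[int]) -> List[int]:
        # Decodes occupy batch indices [0, num_decodes); prompts follow,
        # mirroring the extend() order in execute_model above.
        sampled_token_ids: List[int] = []
        sampled_token_ids.extend(decode_sampled_token_ids)
        sampled_token_ids.extend(prompt_sampled_token_ids)
        return sampled_token_ids


    # Hypothetical batch: two in-flight decodes, then one finished prefill.
    decode_sampled = [17, 42]  # one token sampled per decode request
    prompt_sampled = [7]       # token sampled at the last prompt position
    req_id_to_index: Dict[str, int] = {"dec-0": 0, "dec-1": 1, "pro-0": 2}

    tokens = assemble_sampled_token_ids(decode_sampled, prompt_sampled)
    for req_id, idx in sorted(req_id_to_index.items(), key=lambda kv: kv[1]):
        # Each request finds its sampled token at its own batch index.
        print("{} -> {}".format(req_id, tokens[idx]))

A related design point visible in the diff: the cached self.req_ids / self.prompt_token_ids / self.sampled_token_ids lists are replaced by fresh per-call lists, so the lists handed to ModelRunnerOutput are no longer aliased and mutated across execute_model steps.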