From 996b92ccb4fcb91529a22a5d83725d8d80b4156e Mon Sep 17 00:00:00 2001
From: Alexander Matveev
Date: Wed, 5 Feb 2025 20:28:33 +0000
Subject: [PATCH] v1/tpu: double-buffer CPU input tensors and fix
 swap_positions indexing

---
 vllm/v1/worker/gpu_input_batch.py  |  15 +-
 vllm/v1/worker/tpu_model_runner.py | 259 +++++++++++++++++------
 2 files changed, 161 insertions(+), 113 deletions(-)

diff --git a/vllm/v1/worker/gpu_input_batch.py b/vllm/v1/worker/gpu_input_batch.py
index b9577f85dc8cf..864601c48b075 100644
--- a/vllm/v1/worker/gpu_input_batch.py
+++ b/vllm/v1/worker/gpu_input_batch.py
@@ -448,8 +448,8 @@ def swap_positions(b: InputBatch, id_1, id_2):
     assert id_2 == b.req_id_to_index[req_id_2]
 
     b.req_ids[id_1], b.req_ids[id_2] = b.req_ids[id_2], b.req_ids[id_1]
-    b.req_id_to_index[id_1], b.req_id_to_index[id_2] = b.req_id_to_index[
-        id_2], b.req_id_to_index[id_1]
+    b.req_id_to_index[req_id_1], b.req_id_to_index[
+        req_id_2] = b.req_id_to_index[req_id_2], b.req_id_to_index[req_id_1]
 
     ids = [id_1, id_2]
     rev_ids = [id_2, id_1]
@@ -471,8 +471,13 @@ def swap_positions(b: InputBatch, id_1, id_2):
         id_1]
     b.stop_token_ids[id_1], b.stop_token_ids[id_2] = b.stop_token_ids[
         id_2], b.stop_token_ids[id_1]
-    b.generators[id_1], b.generators[id_2] = b.generators[id_2], b.generators[
-        id_1]
+
+    gen_1 = b.generators.pop(id_1, None)
+    gen_2 = b.generators.pop(id_2, None)
+    if gen_1 is not None:
+        b.generators[id_2] = gen_1
+    if gen_2 is not None:
+        b.generators[id_1] = gen_2
 
 
 def ensure_decodes_first(b: InputBatch):
@@ -504,6 +509,4 @@ def ensure_decodes_first(b: InputBatch):
             break
 
         # Swap
-        print("Swapping first_prompt_index = {} with last_decode_index = {}".
-              format(first_prompt_index, last_decode_index))
         swap_positions(b, first_prompt_index, last_decode_index)
diff --git a/vllm/v1/worker/tpu_model_runner.py b/vllm/v1/worker/tpu_model_runner.py
index 9cd117baf318c..31301ff0e21a7 100644
--- a/vllm/v1/worker/tpu_model_runner.py
+++ b/vllm/v1/worker/tpu_model_runner.py
@@ -69,37 +69,57 @@ class TPUModelRunner(ModelRunnerBase):
         self.kv_caches: List[Tuple[torch.Tensor, torch.Tensor]] = []
 
         # Cached torch/numpy tensors
-        self.input_ids_cpu = torch.empty(self.max_num_tokens,
-                                         dtype=torch.int32,
-                                         device="cpu")
-        self.input_ids_np = self.input_ids_cpu.numpy()
+        self.num_swaps = 2
+        self.cur_swap_id = 0
+        self.input_ids_cpu = []
+        self.input_ids_np = []
+        self.input_positions_cpu = []
+        self.input_positions_np = []
+        self.slot_mapping_cpu = []
+        self.slot_mapping_np = []
+        self.prompt_context_lens_cpu = []
+        self.prompt_effective_query_lens_cpu = []
+        self.decode_context_lens_cpu = []
+        self.decode_context_lens_np = []
+        for _ in range(self.num_swaps):
+            self.input_ids_cpu.append(
+                torch.empty(self.max_num_tokens,
+                            dtype=torch.int32,
+                            device="cpu"))
+            self.input_ids_np.append(self.input_ids_cpu[-1].numpy())
 
-        self.input_positions_cpu = torch.empty(self.max_num_tokens,
-                                               dtype=torch.int32,
-                                               device="cpu")
-        self.input_positions_np = self.input_positions_cpu.numpy()
+            self.input_positions_cpu.append(
+                torch.empty(self.max_num_tokens,
+                            dtype=torch.int32,
+                            device="cpu"))
+            self.input_positions_np.append(
+                self.input_positions_cpu[-1].numpy())
 
-        self.slot_mapping_cpu = torch.empty(self.max_num_tokens,
-                                            dtype=torch.int64,
-                                            device="cpu")
-        self.slot_mapping_np = self.slot_mapping_cpu.numpy()
+            self.slot_mapping_cpu.append(
+                torch.empty(self.max_num_tokens,
+                            dtype=torch.int64,
+                            device="cpu"))
+            self.slot_mapping_np.append(self.slot_mapping_cpu[-1].numpy())
 
-        self.prompt_context_lens_cpu = torch.empty((1),
-                                                   dtype=torch.int32,
-                                                   device="cpu")
-        self.prompt_effective_query_lens_cpu = torch.empty((1),
-                                                           dtype=torch.int32,
-                                                           device="cpu")
+            self.prompt_context_lens_cpu.append(
+                torch.empty((1), dtype=torch.int32, device="cpu"))
+            self.prompt_effective_query_lens_cpu.append(
+                torch.empty((1), dtype=torch.int32, device="cpu"))
 
-        self.decode_context_lens_cpu = torch.empty(self.max_num_tokens,
-                                                   dtype=torch.int32,
-                                                   device="cpu")
-        self.decode_context_lens_np = self.decode_context_lens_cpu.numpy()
+            self.decode_context_lens_cpu.append(
+                torch.empty(self.max_num_tokens,
+                            dtype=torch.int32,
+                            device="cpu"))
+            self.decode_context_lens_np.append(
+                self.decode_context_lens_cpu[-1].numpy())
 
         # Range tensor with values [0 .. self.max_num_tokens - 1].
         # Used to initialize positions / context_lens / seq_lens
         self.arange_np = np.arange(self.max_num_tokens, dtype=np.int32)
 
+    def swap_step(self):
+        self.cur_swap_id = (self.cur_swap_id + 1) % self.num_swaps
+
     def _get_prompts_and_decodes(
         self,
         scheduler_output: "SchedulerOutput",
@@ -165,31 +185,35 @@ class TPUModelRunner(ModelRunnerBase):
         seq_len = num_computed_tokens + prompt_len
         padded_seq_len = num_computed_tokens + padded_prompt_len
 
-        print("_prepare_prompt:")
-        print("  prompt_len = {}".format(prompt_len))
-        print("  padded_prompt_len = {}".format(padded_prompt_len))
-        print("  num_computed_tokens = {}".format(num_computed_tokens))
-        print("  num_prompt_tokens = {}".format(num_prompt_tokens))
-        print("  seq_len = {}".format(seq_len))
-        print("  padded_seq_len = {}".format(padded_seq_len))
+        # DEBUG
+        # print("_prepare_prompt:")
+        # print("  prompt_len = {}".format(prompt_len))
+        # print("  padded_prompt_len = {}".format(padded_prompt_len))
+        # print("  num_computed_tokens = {}".format(num_computed_tokens))
+        # print("  num_prompt_tokens = {}".format(num_prompt_tokens))
+        # print("  seq_len = {}".format(seq_len))
+        # print("  padded_seq_len = {}".format(padded_seq_len))
 
         # Input tokens
         input_tokens_cpu = self.input_batch.token_ids_cpu_tensor[
             req_index, num_computed_tokens:padded_seq_len]
         input_tokens_cpu[prompt_len:] = 0
-        print("  input_tokens_cpu.shape = {} val = {}".format(
-            input_tokens_cpu.shape, input_tokens_cpu))
+        # DEBUG
+        # print("  input_tokens_cpu.shape = {} val = {}".format(
+        #     input_tokens_cpu.shape, input_tokens_cpu))
 
         # Input positions
-        input_positions_np = self.input_positions_np[:padded_prompt_len]
+        input_positions_np = self.input_positions_np[
+            self.cur_swap_id][:padded_prompt_len]
         np.add(num_computed_tokens,
                self.arange_np[:padded_prompt_len],
                out=input_positions_np)
         input_positions_np[prompt_len:] = 0
-        print("  input_positions_np.shape = {} val = {}".format(
-            input_positions_np.shape, input_positions_np))
+        # DEBUG
+        # print("  input_positions_np.shape = {} val = {}".format(
+        #     input_positions_np.shape, input_positions_np))
 
         # Slot mapping
         block_table_np = \
@@ -198,14 +222,16 @@ class TPUModelRunner(ModelRunnerBase):
                 self.block_size]
         block_offsets_np = input_positions_np % self.block_size
 
-        slot_mapping_np = self.slot_mapping_np[:padded_prompt_len]
+        slot_mapping_np = self.slot_mapping_np[
+            self.cur_swap_id][:padded_prompt_len]
         np.add(block_numbers_np * self.block_size,
               block_offsets_np,
              out=slot_mapping_np)
         slot_mapping_np[prompt_len:] = _PAD_SLOT_ID
-        print("  slot_mapping_np.shape = {} val = {}".format(
-            slot_mapping_np.shape, slot_mapping_np))
+        # DEBUG
+        # print("  slot_mapping_np.shape = {} val = {}".format(
+        #     slot_mapping_np.shape, slot_mapping_np))
 
         # Block table
         block_table_cpu = None
         if num_computed_tokens > 0:
@@ -213,40 +239,47 @@ class TPUModelRunner(ModelRunnerBase):
             block_table_cpu = self.input_batch.block_table.get_cpu_tensor()
             block_table_cpu = block_table_cpu[req_index]
-        print("  block_table_cpu = {}".format(block_table_cpu))
+        # DEBUG
+        # print("  block_table_cpu = {}".format(block_table_cpu))
 
         # Context len
-        self.prompt_context_lens_cpu[0] = 0
+        self.prompt_context_lens_cpu[self.cur_swap_id][0] = 0
         if num_computed_tokens > 0:
-            self.prompt_context_lens_cpu[0] = seq_len
+            self.prompt_context_lens_cpu[self.cur_swap_id][0] = seq_len
 
         # Effective query len
-        self.prompt_effective_query_lens_cpu[0] = prompt_len
+        self.prompt_effective_query_lens_cpu[self.cur_swap_id][0] = prompt_len
 
         # Get final tensors
         input_tokens = input_tokens_cpu.reshape(1, -1).to(self.device)
-        input_positions = self.input_positions_cpu[:padded_prompt_len].reshape(
-            1, -1).to(self.device)
-        slot_mapping = self.slot_mapping_cpu[:padded_prompt_len].reshape(
-            1, -1).to(self.device)
+        input_positions = self.input_positions_cpu[
+            self.cur_swap_id][:padded_prompt_len].reshape(1,
+                                                          -1).to(self.device)
+        slot_mapping = self.slot_mapping_cpu[
+            self.cur_swap_id][:padded_prompt_len].reshape(1,
+                                                          -1).to(self.device)
         block_table = block_table_cpu.reshape(1, -1).to(
             self.device) if block_table_cpu is not None else None
-        context_lens = self.prompt_context_lens_cpu.to(self.device)
-        effective_query_lens = self.prompt_effective_query_lens_cpu.to(
+        context_lens = self.prompt_context_lens_cpu[self.cur_swap_id].to(
             self.device)
+        effective_query_lens = self.prompt_effective_query_lens_cpu[
+            self.cur_swap_id].to(self.device)
 
-        print("  input_tokens.shape = {} val = {}".format(
-            input_tokens.shape, input_tokens))
-        print("  input_positions.shape = {} val = {}".format(
-            input_positions.shape, input_positions))
-        print("  slot_mapping.shape = {} val = {}".format(
-            slot_mapping.shape, slot_mapping))
-        print("  block_table = {}".format(block_table))
-        print("  context_lens.shape = {} val = {}".format(
-            context_lens.shape, context_lens))
-        print("  effective_query_lens.shape = {} val = {}".format(
-            effective_query_lens.shape, effective_query_lens))
+        self.swap_step()
+
+        # DEBUG
+        # print("  input_tokens.shape = {} val = {}".format(
+        #     input_tokens.shape, input_tokens))
+        # print("  input_positions.shape = {} val = {}".format(
+        #     input_positions.shape, input_positions))
+        # print("  slot_mapping.shape = {} val = {}".format(
+        #     slot_mapping.shape, slot_mapping))
+        # print("  block_table = {}".format(block_table))
+        # print("  context_lens.shape = {} val = {}".format(
+        #     context_lens.shape, context_lens))
+        # print("  effective_query_lens.shape = {} val = {}".format(
+        #     effective_query_lens.shape, effective_query_lens))
 
         # Attn metadata
         attn_metadata = PallasMetadata(
@@ -275,78 +308,91 @@ class TPUModelRunner(ModelRunnerBase):
         # Init [0 .. batch_size - 1]
         req_indices_np = self.arange_np[:padded_batch_size]
 
-        print("_prepare_decode:")
-        print("  batch_size = {}".format(batch_size))
-        print("  padded_batch_size = {}".format(padded_batch_size))
-        print("  req_indices_np.shape = {} val = {}".format(
-            req_indices_np.shape, req_indices_np))
+        # DEBUG
+        # print("_prepare_decode:")
+        # print("  batch_size = {}".format(batch_size))
+        # print("  padded_batch_size = {}".format(padded_batch_size))
+        # print("  req_indices_np.shape = {} val = {}".format(
+        #     req_indices_np.shape, req_indices_np))
 
         # Input positions
-        input_positions_np = self.input_positions_np[:padded_batch_size]
+        input_positions_np = self.input_positions_np[
+            self.cur_swap_id][:padded_batch_size]
         np.add(self.input_batch.num_computed_tokens_cpu[:padded_batch_size],
                0,
               out=input_positions_np)
         input_positions_np[batch_size:] = 0
-        input_positions_cpu = self.input_positions_cpu[:padded_batch_size]
+        input_positions_cpu = self.input_positions_cpu[
+            self.cur_swap_id][:padded_batch_size]
 
-        print("  input_positions_cpu.shape = {} data = {}".format(
-            input_positions_cpu.shape, input_positions_cpu))
+        # DEBUG
+        # print("  input_positions_cpu.shape = {} data = {}".format(
+        #     input_positions_cpu.shape, input_positions_cpu))
 
         # Input tokens
         token_indices_np = (
             input_positions_np +
             req_indices_np * self.input_batch.token_ids_cpu.shape[1])
-        input_tokens_cpu = self.input_ids_cpu[:padded_batch_size]
+        input_tokens_cpu = self.input_ids_cpu[
+            self.cur_swap_id][:padded_batch_size]
         torch.index_select(self.input_batch.token_ids_cpu_tensor.flatten(),
                            0,
                            torch.from_numpy(token_indices_np),
                            out=input_tokens_cpu)
         input_tokens_cpu[batch_size:] = 0
 
-        print("  token_indices_np.shape = {} val = {}".format(
-            token_indices_np.shape, token_indices_np))
-
-        print("  input_tokens_cpu.shape = {} data = {}".format(
-            input_tokens_cpu.shape, input_tokens_cpu))
+        # DEBUG
+        # print("  token_indices_np.shape = {} val = {}".format(
+        #     token_indices_np.shape, token_indices_np))
+        # print("  input_tokens_cpu.shape = {} data = {}".format(
+        #     input_tokens_cpu.shape, input_tokens_cpu))
 
         # Slot mapping
         block_table_indices_np = (
             req_indices_np * self.max_num_blocks_per_req +
            input_positions_np // self.block_size)
 
-        print(
-            "  block_table_indices_np.shape = {} data = {} max_num_blocks_per_req = {}"
-            .format(block_table_indices_np.shape, block_table_indices_np,
-                    self.max_num_blocks_per_req))
+        # DEBUG
+        # print(
+        #     "  block_table_indices_np.shape = {} data = {} max_num_blocks_per_req = {}"
+        #     .format(block_table_indices_np.shape, block_table_indices_np,
+        #             self.max_num_blocks_per_req))
+
         block_table_cpu = self.input_batch.block_table.get_cpu_tensor()
-        print("  block_table_cpu.shape = {} data = {}".format(
-            block_table_cpu.shape, block_table_cpu[:padded_batch_size, :10]))
+        # DEBUG
+        # print("  block_table_cpu.shape = {} data = {}".format(
+        #     block_table_cpu.shape, block_table_cpu[:padded_batch_size, :10]))
+
         block_numbers_np = block_table_cpu.flatten(
         )[block_table_indices_np].numpy()
-        print("  block_numbers_np.shape = {} data = {}".format(
-            block_numbers_np.shape, block_numbers_np))
+        # DEBUG
+        # print("  block_numbers_np.shape = {} data = {}".format(
+        #     block_numbers_np.shape, block_numbers_np))
+
         block_offsets_np = input_positions_np % self.block_size
-        print("  block_offsets_np.shape = {} data = {}".format(
-            block_offsets_np.shape, block_offsets_np))
+        # DEBUG
+        # print("  block_offsets_np.shape = {} data = {}".format(
+        #     block_offsets_np.shape, block_offsets_np))
 
-        slot_mapping_np = self.slot_mapping_np[:padded_batch_size]
+        slot_mapping_np = self.slot_mapping_np[
+            self.cur_swap_id][:padded_batch_size]
        np.add(block_numbers_np * self.block_size,
               block_offsets_np,
              out=slot_mapping_np)
         slot_mapping_np[batch_size:] = _PAD_SLOT_ID
-        print("  slot_mapping_np.shape = {} data = {}".format(
-            slot_mapping_np.shape, slot_mapping_np))
+        # DEBUG
+        # print("  slot_mapping_np.shape = {} data = {}".format(
+        #     slot_mapping_np.shape, slot_mapping_np))
 
         block_table_cpu = block_table_cpu[:padded_batch_size]
 
         # Context lens
-        context_lens_np = self.decode_context_lens_np[:padded_batch_size]
+        context_lens_np = self.decode_context_lens_np[
+            self.cur_swap_id][:padded_batch_size]
         np.add(self.input_batch.num_computed_tokens_cpu[:padded_batch_size],
                1,
                out=context_lens_np)
@@ -355,14 +401,18 @@ class TPUModelRunner(ModelRunnerBase):
         # Get final tensors
         input_tokens = input_tokens_cpu.reshape(-1, 1).to(self.device)
         input_positions = input_positions_cpu.reshape(-1, 1).to(self.device)
-        slot_mapping = self.slot_mapping_cpu[:padded_batch_size].reshape(
-            -1, 1).to(self.device)
+        slot_mapping = self.slot_mapping_cpu[
+            self.cur_swap_id][:padded_batch_size].reshape(-1,
+                                                          1).to(self.device)
         block_table = block_table_cpu.to(self.device)
-        context_lens = self.decode_context_lens_cpu[:padded_batch_size].to(
-            self.device)
+        context_lens = self.decode_context_lens_cpu[
+            self.cur_swap_id][:padded_batch_size].to(self.device)
 
-        print("  context_lens.shape = {} val = {}".format(
-            context_lens.shape, context_lens))
+        self.swap_step()
+
+        # DEBUG
+        # print("  context_lens.shape = {} val = {}".format(
+        #     context_lens.shape, context_lens))
 
         # Attn metadata
         attn_metadata = PallasMetadata(
@@ -399,9 +449,7 @@ class TPUModelRunner(ModelRunnerBase):
         num_prompts = len(pd_info.prompt_req_ids)
         num_decodes = len(pd_info.decode_req_ids)
         decode_data = None
-        prompt_sampled_token_ids = []
-        decode_sampled_token_ids = []
-        sampled_token_ids = []
+        sampled_token_ids = [0] * self.input_batch.num_reqs
 
         # Run each prompt individually
         is_first = True
@@ -446,12 +494,14 @@ class TPUModelRunner(ModelRunnerBase):
 
             # Get output token
            token_id = selected_token_ids_cpu[prompt_len - 1].item()
-            prompt_sampled_token_ids.append(token_id)
+            sampled_token_ids[req_index] = token_id
+
+            # DEBUG
+            # print(
+            #     " -- Got token_id = {} for prompt_len = {} req_id = {} req_index = {} selected_token_ids_cpu = {}"
+            #     .format(token_id, prompt_len, req_id, req_index,
+            #             selected_token_ids_cpu))
 
-            print(
-                " -- Got token_id = {} for prompt_len = {} req_id = {} req_index = {} selected_token_ids_cpu = {}"
-                .format(token_id, prompt_len, req_id, req_index,
-                        selected_token_ids_cpu))
             # Add output token to the request
             self.input_batch.token_ids_cpu[req_index, seq_len] = token_id
             self.input_batch.num_tokens[req_index] += 1
@@ -488,17 +538,12 @@ class TPUModelRunner(ModelRunnerBase):
 
             seq_len = req_state.num_computed_tokens + 1
             token_id = decode_token_ids_list[i]
-            decode_sampled_token_ids.append(token_id)
+            sampled_token_ids[req_index] = token_id
 
             self.input_batch.token_ids_cpu[req_index, seq_len] = token_id
             self.input_batch.num_tokens[req_index] += 1
             req_state.output_token_ids.append(token_id)
 
-        # Create the final sampled token id list. This must match the actual
-        # batch index positions, so we put decodes first and then prompts.
-        sampled_token_ids.extend(decode_sampled_token_ids)
-        sampled_token_ids.extend(prompt_sampled_token_ids)
-
         # Create output
         model_runner_output = ModelRunnerOutput(
             req_ids=self.input_batch.req_ids,
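-- 

Note on the staging-tensor rotation: every CPU input tensor now exists once
per swap slot, and swap_step() advances cur_swap_id after each
_prepare_prompt / _prepare_decode call. A minimal standalone sketch of the
pattern (StagingBuffers is a hypothetical name, and the motivation is an
assumption: the slot filled for step N is not rewritten until step
N + num_swaps, so a host-to-device copy started in the previous step can
still be draining while the next step fills the other slot):

import torch

class StagingBuffers:
    def __init__(self, max_num_tokens: int, num_swaps: int = 2):
        self.num_swaps = num_swaps
        self.cur_swap_id = 0
        # One CPU staging tensor per swap slot, as in the patch.
        self.input_ids_cpu = [
            torch.empty(max_num_tokens, dtype=torch.int32, device="cpu")
            for _ in range(num_swaps)
        ]

    def current(self) -> torch.Tensor:
        return self.input_ids_cpu[self.cur_swap_id]

    def swap_step(self) -> None:
        # Same rotation as TPUModelRunner.swap_step().
        self.cur_swap_id = (self.cur_swap_id + 1) % self.num_swaps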
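The slot-mapping arithmetic is the same in _prepare_prompt and
_prepare_decode: look up the physical block holding a token's logical
position, then offset within that block. A self-contained numpy sketch with
a made-up block table:

import numpy as np

block_size = 16
# Toy block table for one request: logical pages 0 and 1 live in
# physical blocks 7 and 3.
block_table = np.array([7, 3], dtype=np.int64)

positions = np.arange(20, dtype=np.int64)  # token positions 0..19
block_numbers = block_table[positions // block_size]
slot_mapping = block_numbers * block_size + positions % block_size
# positions 0..15 -> slots 112..127 (block 7); 16..19 -> 48..51 (block 3)

Padding entries past prompt_len / batch_size are then overwritten with
_PAD_SLOT_ID, exactly as the patch does.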
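_prepare_decode fetches the next input token for all requests with a single
index_select over the flattened (num_reqs, max_model_len) token table; the
flat index is position + row * row_stride. A toy version of that gather:

import numpy as np
import torch

token_ids = torch.arange(12, dtype=torch.int32).reshape(3, 4)  # 3 reqs
positions = np.array([2, 0, 3])  # next position per request
req_indices = np.arange(3)
flat_indices = positions + req_indices * token_ids.shape[1]
picked = torch.index_select(token_ids.flatten(), 0,
                            torch.from_numpy(flat_indices))
assert picked.tolist() == [2, 4, 11]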
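The gpu_input_batch.py hunks fix two bugs in swap_positions: req_id_to_index
is keyed by request-id strings, so it must be re-keyed through req_id_1 /
req_id_2 rather than the integer positions id_1 / id_2 (the old code
silently inserted integer keys), and b.generators is a sparse dict, so the
tuple-swap indexing raised KeyError for requests without a generator, hence
the pop(..., None) dance. A toy reproduction of the re-keying:

req_ids = ["a", "b"]
req_id_to_index = {"a": 0, "b": 1}

id_1, id_2 = 0, 1
req_id_1, req_id_2 = req_ids[id_1], req_ids[id_2]

# Old (buggy) form: req_id_to_index[id_1], req_id_to_index[id_2] = ...
# added integer keys 0/1 instead of updating "a"/"b".
req_id_to_index[req_id_1], req_id_to_index[req_id_2] = (
    req_id_to_index[req_id_2], req_id_to_index[req_id_1])
req_ids[id_1], req_ids[id_2] = req_ids[id_2], req_ids[id_1]

assert req_ids == ["b", "a"]
assert req_id_to_index == {"a": 1, "b": 0}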
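ensure_decodes_first (only its tail is visible above) appears to be a
two-pointer partition: walk first_prompt_index up from the front and
last_decode_index down from the back, swapping pairs until the pointers
meet, so decode requests end up at the low batch indices. A sketch of that
walk over a plain list of is_prompt flags, with the swap inlined instead of
calling swap_positions:

def decodes_first_order(is_prompt: list) -> list:
    order = list(range(len(is_prompt)))
    first, last = 0, len(order) - 1
    while first < last:
        while first < last and not is_prompt[order[first]]:
            first += 1  # already a decode, keep scanning forward
        while first < last and is_prompt[order[last]]:
            last -= 1  # already a prompt, keep scanning backward
        if first < last:
            order[first], order[last] = order[last], order[first]
    return order

assert decodes_first_order([True, False, True, False]) == [3, 1, 2, 0]

This decode-first layout is also why execute_model now preallocates
sampled_token_ids and writes each result at req_index, instead of
concatenating separate decode and prompt lists at the end.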