From 627efde8133c6982ac69b73ae50aed06151ba603 Mon Sep 17 00:00:00 2001
From: Alexander Matveev
Date: Tue, 4 Feb 2025 22:16:19 +0000
Subject: [PATCH] fixes

---
 vllm/v1/worker/tpu_model_runner.py | 147 +++++++++++++++--------------
 1 file changed, 75 insertions(+), 72 deletions(-)

diff --git a/vllm/v1/worker/tpu_model_runner.py b/vllm/v1/worker/tpu_model_runner.py
index f9dab7eafbee2..3fd235d4a7acc 100644
--- a/vllm/v1/worker/tpu_model_runner.py
+++ b/vllm/v1/worker/tpu_model_runner.py
@@ -33,11 +33,13 @@ logger = init_logger(__name__)
 # Here we utilize the behavior that out-of-bound index is ignored.
 # FIXME(woosuk): Find a more reliable way to prevent possible bugs.
 _PAD_SLOT_ID = 1_000_000_000
-# FIXME(woosuk): Temporarily disabled top-p sampling since it's too slow.
-_ENABLE_TOP_P = False
-# FIXME(woosuk): A temporary hack to support `n > 1`.
-# This can significantly affect the performance if too large.
-_MAX_NUM_SAMPLES = 128
+
+
+@dataclass
+class PromptDecodeInfo:
+    prompt_req_ids: List[str]
+    decode_req_ids: List[str]
+    prompt_scheduled_tokens: List[int]


 @dataclass
@@ -63,55 +65,42 @@ class TPUModelRunner(ModelRunnerBase):
     ):
         super().__init__(vllm_config, device)

-        # Persistent batch.
-        self.input_batch = InputBatch(
-            max_num_reqs=self.max_num_reqs,
-            max_model_len=self.max_model_len,
-            max_num_blocks_per_req=self.max_num_blocks_per_req,
-            device=self.device,
-            pin_memory=self.pin_memory,
-            vocab_size=self.model_config.get_vocab_size(),
-        )
-
-        # Request states.
-        self.requests: Dict[str, CachedRequestState] = {}
-
         # KV caches for forward pass
         self.kv_caches: List[Tuple[torch.Tensor, torch.Tensor]] = []

-        # Cache torch/numpy tensors
-        self.input_ids_cpu = torch.zeros(self.max_num_tokens,
+        # Cached torch/numpy tensors
+        self.input_ids_cpu = torch.empty(self.max_num_tokens,
                                          dtype=torch.int32,
-                                         device="cpu",
-                                         pin_memory=self.pin_memory)
+                                         device="cpu")
         self.input_ids_np = self.input_ids_cpu.numpy()

-        self.input_positions_cpu = torch.empty(self.max_model_len,
-                                               dtype=torch.int64,
-                                               device="cpu",
-                                               pin_memory=self.pin_memory)
+        self.input_positions_cpu = torch.empty(self.max_num_tokens,
+                                               dtype=torch.int32,
+                                               device="cpu")
         self.input_positions_np = self.input_positions_cpu.numpy()

-        self.slot_mapping_cpu = torch.zeros(self.max_num_tokens,
-                                            dtype=torch.int32,
-                                            device="cpu",
-                                            pin_memory=self.pin_memory)
+        self.slot_mapping_cpu = torch.empty(self.max_num_tokens,
+                                            dtype=torch.int64,
+                                            device="cpu")
         self.slot_mapping_np = self.slot_mapping_cpu.numpy()

-        self.prompt_context_lens_cpu = torch.zeros((1),
+        self.prompt_context_lens_cpu = torch.empty((1),
                                                    dtype=torch.int32,
                                                    device="cpu")
-        self.prompt_effective_query_lens = torch.zeros((1),
-                                                       dtype=torch.int32,
-                                                       device="cpu")
+        self.prompt_effective_query_lens_cpu = torch.empty((1),
+                                                           dtype=torch.int32,
+                                                           device="cpu")

-        self.decode_context_lens_cpu = torch.zeros(self.max_model_len,
+        self.decode_context_lens_cpu = torch.empty(self.max_num_tokens,
                                                    dtype=torch.int32,
                                                    device="cpu")
         self.decode_context_lens_np = self.decode_context_lens_cpu.numpy()

-        self.arange_np = np.arange(self.max_model_len, dtype=np.int32)
+        # Range tensor with values [0 .. self.max_num_tokens - 1].
+        # Used to initialize positions / context_lens / seq_lens
+        self.arange_np = np.arange(self.max_num_tokens, dtype=np.int32)

+        # Cached lists
         self.req_ids = []
         self.prompt_token_ids = []
         self.sampled_token_ids = []
@@ -119,7 +108,7 @@ class TPUModelRunner(ModelRunnerBase):
     def _get_prompts_and_decodes(
         self,
         scheduler_output: "SchedulerOutput",
-    ):
+    ) -> PromptDecodeInfo:
         total_num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
         assert total_num_scheduled_tokens > 0
         num_reqs = self.input_batch.num_reqs
@@ -157,10 +146,11 @@ class TPUModelRunner(ModelRunnerBase):
             # Must be prompt
             assert num_computed_tokens < num_prompt_tokens

-            prompt_scheduled_tokens.append(num_scheduled_tokens)
             prompt_req_ids.append(req_id)
+            prompt_scheduled_tokens.append(num_scheduled_tokens)

-        return prompt_req_ids, decode_req_ids, prompt_scheduled_tokens
+        return PromptDecodeInfo(prompt_req_ids, decode_req_ids,
+                                prompt_scheduled_tokens)

     def _prepare_prompt(self, req_index: int,
                         num_scheduled_tokens: int) -> PromptData:
@@ -203,7 +193,7 @@ class TPUModelRunner(ModelRunnerBase):
         np.add(block_numbers_np * self.block_size,
                block_offsets_np,
                out=slot_mapping_np)
-        slot_mapping_np[:, prompt_len:] = _PAD_SLOT_ID
+        slot_mapping_np[prompt_len:] = _PAD_SLOT_ID

         # Block table
         block_table_cpu = None
@@ -217,7 +207,7 @@ class TPUModelRunner(ModelRunnerBase):
         self.prompt_context_lens_cpu[0] = seq_len

         # Effective query len
-        self.prompt_effective_query_lens[0] = prompt_len
+        self.prompt_effective_query_lens_cpu[0] = prompt_len

         # Get final tensors
         input_tokens = input_tokens_cpu.reshape(1, -1).to(self.device)
@@ -230,7 +220,7 @@ class TPUModelRunner(ModelRunnerBase):
         context_lens = self.prompt_context_lens_cpu.reshape(1,
                                                             -1).to(self.device)

-        effective_query_lens = self.prompt_effective_query_lens.reshape(
+        effective_query_lens = self.prompt_effective_query_lens_cpu.reshape(
             1, -1).to(self.device)

         # Attn metadata
@@ -263,7 +253,7 @@ class TPUModelRunner(ModelRunnerBase):
                0,
                out=input_positions_np)
         input_positions_np[batch_size:] = 0
-        input_positions_cpu = torch.from_numpy(input_positions_np)
+        input_positions_cpu = self.input_positions_cpu[:padded_batch_size]

         # Input tokens
         input_tokens_cpu = self.input_ids_cpu[:padded_batch_size]
@@ -334,29 +324,31 @@ class TPUModelRunner(ModelRunnerBase):
         ensure_decodes_first(self.input_batch)

         # Prepare prompts/decodes info
-        prompt_req_ids, decode_req_ids, prompt_scheduled_tokens = self._get_prompts_and_decodes(
-            scheduler_output)
+        pd_info = self._get_prompts_and_decodes(scheduler_output)

         # Init
-        decode_token_ids = None
+        num_prompts = len(pd_info.prompt_req_ids)
+        num_decodes = len(pd_info.decode_req_ids)
+        decode_token_ids_list = None
         decode_data = None
         self.req_ids.clear()
         self.prompt_token_ids.clear()
         self.sampled_token_ids.clear()

-        # Run each prompt
+        # Run each prompt individually
         is_first = True
-        for i, req_id in enumerate(prompt_req_ids):
-            req_index = len(decode_req_ids) + i
+        for i in range(num_prompts):
+            req_id = pd_info.prompt_req_ids[i]
+            req_index = num_decodes + i
             req_state = self.requests[req_id]
-            num_scheduled_tokens = prompt_scheduled_tokens[i]
-            seq_len = req_state.num_computed_tokens + num_scheduled_tokens
+            num_scheduled_tokens = pd_info.prompt_scheduled_tokens[i]
             prompt_len = num_scheduled_tokens
+            seq_len = req_state.num_computed_tokens + num_scheduled_tokens

             # Prepare first prompt
             if is_first:
                 prompt_data = self._prepare_prompt(req_index,
-                                                   prompt_scheduled_tokens[i])
+                                                   num_scheduled_tokens)
                 is_first = False

             # Run forward pass
@@ -369,29 +361,36 @@ class TPUModelRunner(ModelRunnerBase):
                                                 self.kv_caches)

             # In parallel to TPU execution, prepare the next iteration
-            if i < len(prompt_req_ids) - 1:
+            if i < num_prompts - 1:
+                # There is next prompt => prepare it
                 prompt_data = self._prepare_prompt(
-                    req_index + 1, prompt_scheduled_tokens[i + 1])
-            elif i == len(prompt_req_ids) - 1 and len(decode_req_ids) > 0:
-                decode_data = self._prepare_decode(decode_req_ids)
+                    req_index + 1, pd_info.prompt_scheduled_tokens[i + 1])
+            elif i == num_prompts - 1 and num_decodes > 0:
+                # There is next decode => prepare it
+                decode_data = self._prepare_decode(pd_info.decode_req_ids)

-            # Update cached state
+            # Update cached state (if prompt is fully done)
             if seq_len >= len(req_state.prompt_token_ids):
                 # Transfer sampled tokens from TPU to CPU
-                token_id = selected_token_ids.cpu()[prompt_len - 1].item()
+                selected_token_ids_cpu = selected_token_ids.cpu()
+
+                # Get output token
+                token_id = selected_token_ids_cpu[prompt_len - 1].item()
                 self.prompt_token_ids.append(token_id)

-                # Update cached state
+                # Add output token to the request
                 self.input_batch.token_ids_cpu[req_index, seq_len] = token_id
                 self.input_batch.num_tokens[req_index] += 1
                 req_state.output_token_ids.append(token_id)

         # Run decodes (a single batch)
-        if len(decode_req_ids) > 0:
-            if decode_data is None:
-                decode_data = self._prepare_decode(decode_req_ids)
+        if num_decodes > 0:

-            # Forward
+            # Prepare decode (if was not yet prepared)
+            if decode_data is None:
+                decode_data = self._prepare_decode(pd_info.decode_req_ids)
+
+            # Run forward pass
             with set_forward_context(decode_data.attn_metadata,
                                      self.vllm_config):
                 assert self.model is not None
@@ -401,26 +400,30 @@ class TPUModelRunner(ModelRunnerBase):
                                                 self.kv_caches)

         # Transfer sampled tokens from TPU to CPU
-        decode_token_ids = selected_token_ids.cpu().tolist()
+        decode_token_ids_cpu = selected_token_ids.cpu()
+        # Convert to list
+        decode_token_ids_list = decode_token_ids_cpu.tolist()

-        # Update cached state
-        for i, req_id in enumerate(decode_req_ids):
+        # Update cached state for each decode request
+        for i in range(num_decodes):
+            req_id = pd_info.decode_req_ids[i]
             req_index = i
             req_state = self.requests[req_id]
             seq_len = req_state.num_computed_tokens + 1
-            token_id = decode_token_ids[i]
+            token_id = decode_token_ids_list[i]

             self.input_batch.token_ids_cpu[req_index, seq_len] = token_id
             self.input_batch.num_tokens[req_index] += 1
             req_state.output_token_ids.append(token_id)

         # Create final req_id => token lists.
-        # This must match the actual batch index positions
-        self.req_ids.extend(decode_req_ids)
-        self.req_ids.extend(prompt_req_ids)
-        if decode_token_ids is not None:
-            self.sampled_token_ids.extend(decode_token_ids)
+        # This must match the actual batch index positions,
+        # so we put decodes first and then prompts.
+        self.req_ids.extend(pd_info.decode_req_ids)
+        self.req_ids.extend(pd_info.prompt_req_ids)
+        if decode_token_ids_list is not None:
+            self.sampled_token_ids.extend(decode_token_ids_list)
         self.sampled_token_ids.extend(self.prompt_token_ids)

         # Create output
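
Note (not part of the patch): the loop changed above pipelines host-side preparation with device execution. While the TPU runs the current prompt, the CPU prepares either the next prompt or the single batched decode step, and decode requests are kept at the front of the batch. Below is a minimal sketch of that control flow under assumed, hypothetical helpers (prepare_prompt, prepare_decode, run_on_device are stand-ins, not vLLM APIs); the overlap only materializes if run_on_device dispatches asynchronously, as XLA does, and blocks only when results are read.

from typing import Any, Callable, List, Optional


def run_pipelined(prompts: List[Any],
                  decodes: List[Any],
                  prepare_prompt: Callable[[Any], Any],
                  prepare_decode: Callable[[List[Any]], Any],
                  run_on_device: Callable[[Any], Any]) -> List[Any]:
    """Run each prompt individually, then all decodes as one batch."""
    outputs: List[Any] = []
    decode_data: Optional[Any] = None

    prompt_data = prepare_prompt(prompts[0]) if prompts else None
    for i in range(len(prompts)):
        # Kick off the current prompt on the device.
        outputs.append(run_on_device(prompt_data))

        # While the device is busy, prepare the next iteration on the CPU:
        # the next prompt if there is one, otherwise the decode batch.
        if i < len(prompts) - 1:
            prompt_data = prepare_prompt(prompts[i + 1])
        elif decodes:
            decode_data = prepare_decode(decodes)

    if decodes:
        if decode_data is None:
            # No prompts were scheduled, so nothing was pre-prepared.
            decode_data = prepare_decode(decodes)
        outputs.append(run_on_device(decode_data))

    return outputs

This is also why the patched decode branch guards with "if decode_data is None": when at least one prompt ran, the decode batch has already been prepared during the last prompt's forward pass.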