From c3f9773b2c5002a3ba4421f5d7fedf036905d369 Mon Sep 17 00:00:00 2001 From: Chenyaaang <42742451+Chenyaaang@users.noreply.github.com> Date: Tue, 9 Sep 2025 23:34:25 +0530 Subject: [PATCH] [TPU] Fix tpu structured decoding in mixed batches (#24458) Signed-off-by: Chenyaaang --- vllm/v1/worker/tpu_model_runner.py | 34 ++++++++++++------------------ 1 file changed, 14 insertions(+), 20 deletions(-) diff --git a/vllm/v1/worker/tpu_model_runner.py b/vllm/v1/worker/tpu_model_runner.py index 5947b54d33ce..15af7ffac809 100644 --- a/vllm/v1/worker/tpu_model_runner.py +++ b/vllm/v1/worker/tpu_model_runner.py @@ -1769,28 +1769,22 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): self.grammar_bitmask_cpu.zero_() self.require_structured_out_cpu.zero_() - # We receive the structured output bitmask from the scheduler, but the - # indices of the requests in the batch may not match the indices of - # the bitmask since the scheduler doesn't know how the tpu runner is - # ordering the requests in the batch. We need to match the order of - # bitmask with the order of requests - struct_out_indices: list[int] = [] - mask_indices: list[int] = [] - for req_id in self.input_batch.req_ids: - mask_index = scheduler_output.structured_output_request_ids.get( - req_id) - if mask_index is None: + sorted_struct_requests = sorted( + scheduler_output.structured_output_request_ids.items(), + key=lambda item: item[1]) + cumulative_mask_idx = 0 + for req_id, _ in sorted_struct_requests: + if req_id not in self.input_batch.req_id_to_index: continue batch_index = self.input_batch.req_id_to_index[req_id] - struct_out_indices.append(batch_index) - mask_indices.append(mask_index) - self.grammar_bitmask_cpu[struct_out_indices] = torch.from_numpy( - grammar_bitmask[mask_indices]) - # It's not guaranteed that all requests in this batch require - # structured output, so create a bool tensor to represent - # the requests that need structured output. - struct_out_indices = torch.tensor(struct_out_indices, dtype=torch.long) - self.require_structured_out_cpu[struct_out_indices] = True + self.grammar_bitmask_cpu[batch_index] = torch.from_numpy( + grammar_bitmask[cumulative_mask_idx]) + # It's not guaranteed that all requests in this batch require + # structured output, so create a bool tensor to represent + # the requests that need structured output. + self.require_structured_out_cpu[batch_index] = True + cumulative_mask_idx += 1 + return self.require_structured_out_cpu[:num_reqs].to(logits.device), \ self.grammar_bitmask_cpu[:num_reqs].to(logits.device), \ self.structured_decode_arange.to(logits.device)