[TPU] Fix tpu structured decoding in mixed batches (#24458)

Signed-off-by: Chenyaaang <chenyangli@google.com>
This commit is contained in:
Chenyaaang 2025-09-09 23:34:25 +05:30 committed by GitHub
parent 3707cb2505
commit c3f9773b2c
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -1769,28 +1769,22 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
self.grammar_bitmask_cpu.zero_()
self.require_structured_out_cpu.zero_()
# We receive the structured output bitmask from the scheduler, but the
# indices of the requests in the batch may not match the indices of
# the bitmask since the scheduler doesn't know how the tpu runner is
# ordering the requests in the batch. We need to match the order of
# bitmask with the order of requests
struct_out_indices: list[int] = []
mask_indices: list[int] = []
for req_id in self.input_batch.req_ids:
mask_index = scheduler_output.structured_output_request_ids.get(
req_id)
if mask_index is None:
sorted_struct_requests = sorted(
scheduler_output.structured_output_request_ids.items(),
key=lambda item: item[1])
cumulative_mask_idx = 0
for req_id, _ in sorted_struct_requests:
if req_id not in self.input_batch.req_id_to_index:
continue
batch_index = self.input_batch.req_id_to_index[req_id]
struct_out_indices.append(batch_index)
mask_indices.append(mask_index)
self.grammar_bitmask_cpu[struct_out_indices] = torch.from_numpy(
grammar_bitmask[mask_indices])
self.grammar_bitmask_cpu[batch_index] = torch.from_numpy(
grammar_bitmask[cumulative_mask_idx])
# It's not guaranteed that all requests in this batch require
# structured output, so create a bool tensor to represent
# the requests that need structured output.
struct_out_indices = torch.tensor(struct_out_indices, dtype=torch.long)
self.require_structured_out_cpu[struct_out_indices] = True
self.require_structured_out_cpu[batch_index] = True
cumulative_mask_idx += 1
return self.require_structured_out_cpu[:num_reqs].to(logits.device), \
self.grammar_bitmask_cpu[:num_reqs].to(logits.device), \
self.structured_decode_arange.to(logits.device)