[TPU] Fix tpu structured decoding in mixed batches (#24458)

Signed-off-by: Chenyaaang <chenyangli@google.com>
This commit is contained in:
Chenyaaang 2025-09-09 23:34:25 +05:30 committed by GitHub
parent 3707cb2505
commit c3f9773b2c
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@@ -1769,28 +1769,22 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
self.grammar_bitmask_cpu.zero_() self.grammar_bitmask_cpu.zero_()
self.require_structured_out_cpu.zero_() self.require_structured_out_cpu.zero_()
# We receive the structured output bitmask from the scheduler, but the sorted_struct_requests = sorted(
# indices of the requests in the batch may not match the indices of scheduler_output.structured_output_request_ids.items(),
# the bitmask since the scheduler doesn't know how the tpu runner is key=lambda item: item[1])
# ordering the requests in the batch. We need to match the order of cumulative_mask_idx = 0
# bitmask with the order of requests for req_id, _ in sorted_struct_requests:
struct_out_indices: list[int] = [] if req_id not in self.input_batch.req_id_to_index:
mask_indices: list[int] = []
for req_id in self.input_batch.req_ids:
mask_index = scheduler_output.structured_output_request_ids.get(
req_id)
if mask_index is None:
continue continue
batch_index = self.input_batch.req_id_to_index[req_id] batch_index = self.input_batch.req_id_to_index[req_id]
struct_out_indices.append(batch_index) self.grammar_bitmask_cpu[batch_index] = torch.from_numpy(
mask_indices.append(mask_index) grammar_bitmask[cumulative_mask_idx])
self.grammar_bitmask_cpu[struct_out_indices] = torch.from_numpy( # It's not guaranteed that all requests in this batch require
grammar_bitmask[mask_indices]) # structured output, so create a bool tensor to represent
# It's not guaranteed that all requests in this batch require # the requests that need structured output.
# structured output, so create a bool tensor to represent self.require_structured_out_cpu[batch_index] = True
# the requests that need structured output. cumulative_mask_idx += 1
struct_out_indices = torch.tensor(struct_out_indices, dtype=torch.long)
self.require_structured_out_cpu[struct_out_indices] = True
return self.require_structured_out_cpu[:num_reqs].to(logits.device), \ return self.require_structured_out_cpu[:num_reqs].to(logits.device), \
self.grammar_bitmask_cpu[:num_reqs].to(logits.device), \ self.grammar_bitmask_cpu[:num_reqs].to(logits.device), \
self.structured_decode_arange.to(logits.device) self.structured_decode_arange.to(logits.device)