mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-10 01:15:26 +08:00
[TPU] Fix tpu structured decoding in mixed batches (#24458)
Signed-off-by: Chenyaaang <chenyangli@google.com>
This commit is contained in:
parent
3707cb2505
commit
c3f9773b2c
@ -1769,28 +1769,22 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
||||
self.grammar_bitmask_cpu.zero_()
|
||||
self.require_structured_out_cpu.zero_()
|
||||
|
||||
# We receive the structured output bitmask from the scheduler, but the
|
||||
# indices of the requests in the batch may not match the indices of
|
||||
# the bitmask since the scheduler doesn't know how the tpu runner is
|
||||
# ordering the requests in the batch. We need to match the order of
|
||||
# bitmask with the order of requests
|
||||
struct_out_indices: list[int] = []
|
||||
mask_indices: list[int] = []
|
||||
for req_id in self.input_batch.req_ids:
|
||||
mask_index = scheduler_output.structured_output_request_ids.get(
|
||||
req_id)
|
||||
if mask_index is None:
|
||||
sorted_struct_requests = sorted(
|
||||
scheduler_output.structured_output_request_ids.items(),
|
||||
key=lambda item: item[1])
|
||||
cumulative_mask_idx = 0
|
||||
for req_id, _ in sorted_struct_requests:
|
||||
if req_id not in self.input_batch.req_id_to_index:
|
||||
continue
|
||||
batch_index = self.input_batch.req_id_to_index[req_id]
|
||||
struct_out_indices.append(batch_index)
|
||||
mask_indices.append(mask_index)
|
||||
self.grammar_bitmask_cpu[struct_out_indices] = torch.from_numpy(
|
||||
grammar_bitmask[mask_indices])
|
||||
self.grammar_bitmask_cpu[batch_index] = torch.from_numpy(
|
||||
grammar_bitmask[cumulative_mask_idx])
|
||||
# It's not guaranteed that all requests in this batch require
|
||||
# structured output, so create a bool tensor to represent
|
||||
# the requests that need structured output.
|
||||
struct_out_indices = torch.tensor(struct_out_indices, dtype=torch.long)
|
||||
self.require_structured_out_cpu[struct_out_indices] = True
|
||||
self.require_structured_out_cpu[batch_index] = True
|
||||
cumulative_mask_idx += 1
|
||||
|
||||
return self.require_structured_out_cpu[:num_reqs].to(logits.device), \
|
||||
self.grammar_bitmask_cpu[:num_reqs].to(logits.device), \
|
||||
self.structured_decode_arange.to(logits.device)
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user