mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-10 04:34:57 +08:00
[TPU] Fix tpu structured decoding in mixed batches (#24458)
Signed-off-by: Chenyaaang <chenyangli@google.com>
This commit is contained in:
parent
3707cb2505
commit
c3f9773b2c
@ -1769,28 +1769,22 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
|||||||
self.grammar_bitmask_cpu.zero_()
|
self.grammar_bitmask_cpu.zero_()
|
||||||
self.require_structured_out_cpu.zero_()
|
self.require_structured_out_cpu.zero_()
|
||||||
|
|
||||||
# We receive the structured output bitmask from the scheduler, but the
|
sorted_struct_requests = sorted(
|
||||||
# indices of the requests in the batch may not match the indices of
|
scheduler_output.structured_output_request_ids.items(),
|
||||||
# the bitmask since the scheduler doesn't know how the tpu runner is
|
key=lambda item: item[1])
|
||||||
# ordering the requests in the batch. We need to match the order of
|
cumulative_mask_idx = 0
|
||||||
# bitmask with the order of requests
|
for req_id, _ in sorted_struct_requests:
|
||||||
struct_out_indices: list[int] = []
|
if req_id not in self.input_batch.req_id_to_index:
|
||||||
mask_indices: list[int] = []
|
|
||||||
for req_id in self.input_batch.req_ids:
|
|
||||||
mask_index = scheduler_output.structured_output_request_ids.get(
|
|
||||||
req_id)
|
|
||||||
if mask_index is None:
|
|
||||||
continue
|
continue
|
||||||
batch_index = self.input_batch.req_id_to_index[req_id]
|
batch_index = self.input_batch.req_id_to_index[req_id]
|
||||||
struct_out_indices.append(batch_index)
|
self.grammar_bitmask_cpu[batch_index] = torch.from_numpy(
|
||||||
mask_indices.append(mask_index)
|
grammar_bitmask[cumulative_mask_idx])
|
||||||
self.grammar_bitmask_cpu[struct_out_indices] = torch.from_numpy(
|
# It's not guaranteed that all requests in this batch require
|
||||||
grammar_bitmask[mask_indices])
|
# structured output, so create a bool tensor to represent
|
||||||
# It's not guaranteed that all requests in this batch require
|
# the requests that need structured output.
|
||||||
# structured output, so create a bool tensor to represent
|
self.require_structured_out_cpu[batch_index] = True
|
||||||
# the requests that need structured output.
|
cumulative_mask_idx += 1
|
||||||
struct_out_indices = torch.tensor(struct_out_indices, dtype=torch.long)
|
|
||||||
self.require_structured_out_cpu[struct_out_indices] = True
|
|
||||||
return self.require_structured_out_cpu[:num_reqs].to(logits.device), \
|
return self.require_structured_out_cpu[:num_reqs].to(logits.device), \
|
||||||
self.grammar_bitmask_cpu[:num_reqs].to(logits.device), \
|
self.grammar_bitmask_cpu[:num_reqs].to(logits.device), \
|
||||||
self.structured_decode_arange.to(logits.device)
|
self.structured_decode_arange.to(logits.device)
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user