fix corner case for update_async_output_token_ids

Signed-off-by: zhuhaoran <zhuhaoran.zhr@alibaba-inc.com>
This commit is contained in:
zhuhaoran 2025-12-23 17:52:38 +08:00
parent 699800a28e
commit 8d339e86e5

View File

@ -942,18 +942,15 @@ class InputBatch:
sampled_token_ids = self.sampled_token_ids_cpu.tolist() sampled_token_ids = self.sampled_token_ids_cpu.tolist()
# Replace placeholder token id(s) with actual sampled id(s). # Replace placeholder token id(s) with actual sampled id(s).
if sampled_ids := sampled_token_ids[prev_index]: if sampled_ids := sampled_token_ids[prev_index]:
num_placeholders = 0 num_replace = 0
for t in reversed(req_output_token_ids): for t in sampled_ids:
if t == -1: if t == -1:
num_placeholders += 1
else:
break break
if num_placeholders == 0: num_replace += 1
if num_replace == 0:
continue continue
assert num_placeholders <= len(sampled_ids) req_output_token_ids[-num_replace:] = sampled_ids[:num_replace]
req_output_token_ids[-num_placeholders:] = sampled_ids[
:num_placeholders
]
def update_async_spec_token_ids( def update_async_spec_token_ids(
self, self,