fix: use num_draft_tokens to trim draft_token_ids_cpu

Signed-off-by: zhuhaoran <zhuhaoran.zhr@alibaba-inc.com>
This commit is contained in:
zhuhaoran 2025-12-12 22:59:57 +08:00
parent 8c0779a646
commit 33c63f263d
2 changed files with 12 additions and 4 deletions

View File

@ -966,7 +966,9 @@ class InputBatch:
]
def update_async_spec_token_ids(
self, draft_token_ids_cpu: list[list[int]] | None
self,
draft_token_ids_cpu: list[list[int]] | None,
num_draft_tokens: list[int] | None = None,
) -> None:
"""
In async scheduling case, update spec_token_ids in sampling metadata
@ -985,11 +987,14 @@ class InputBatch:
prev_index = self.prev_req_id_to_index.get(req_id)
if prev_index is None:
continue
assert prev_index < len(draft_token_ids_cpu)
draft_ids = draft_token_ids_cpu[prev_index]
if not draft_ids:
continue
assert index < len(spec_token_ids)
if num_draft_tokens is not None:
scheduled_count = num_draft_tokens[index]
assert scheduled_count <= len(draft_ids)
draft_ids = draft_ids[:scheduled_count]
spec_token_ids[index].clear()
spec_token_ids[index].extend(draft_ids)

View File

@ -2582,7 +2582,10 @@ class GPUModelRunner(
# Update spec_token_ids with real draft tokens from previous step
draft_token_ids_cpu = self._get_draft_token_ids_cpu()
self.input_batch.update_async_spec_token_ids(draft_token_ids_cpu)
self.input_batch.update_async_spec_token_ids(
draft_token_ids_cpu,
num_draft_tokens=spec_decode_metadata.num_draft_tokens,
)
sampler_output = self.rejection_sampler(
spec_decode_metadata,