Signed-off-by: zhuhaoran <zhuhaoran.zhr@alibaba-inc.com>
This commit is contained in:
zhuhaoran 2025-12-12 00:02:47 +08:00
parent 461e7b3e32
commit 78555c06a2
2 changed files with 22 additions and 11 deletions

View File

@ -952,7 +952,18 @@ class InputBatch:
sampled_token_ids = self.sampled_token_ids_cpu.tolist()
# Replace placeholder token id(s) with actual sampled id(s).
if sampled_ids := sampled_token_ids[prev_index]:
req_output_token_ids[-len(sampled_ids) :] = sampled_ids
num_placeholders = 0
for t in reversed(req_output_token_ids):
if t == -1:
num_placeholders += 1
else:
break
if num_placeholders == 0:
continue
assert num_placeholders <= len(sampled_ids)
req_output_token_ids[-num_placeholders:] = sampled_ids[
:num_placeholders
]
def update_async_spec_token_ids(
self, draft_token_ids_cpu: list[list[int]] | None
@ -971,13 +982,16 @@ class InputBatch:
return
for index, req_id in enumerate(self.req_ids):
prev_index = self.prev_req_id_to_index.get(req_id, default=None)
prev_index = self.prev_req_id_to_index.get(req_id)
if prev_index is None:
continue
assert prev_index < len(draft_token_ids_cpu)
draft_ids = draft_token_ids_cpu[prev_index]
if not draft_ids:
continue
assert index < len(spec_token_ids)
spec_token_ids[index] = draft_ids
spec_token_ids[index].clear()
spec_token_ids[index].extend(draft_ids)
@property
def num_reqs(self) -> int:

View File

@ -3423,14 +3423,13 @@ class GPUModelRunner(
for async scheduling with spec decode + penalty/bad_words.
Returns None if no draft tokens were copied in the previous step.
"""
if not self._has_draft_tokens:
return None
if isinstance(self._draft_token_ids, list):
return self._draft_token_ids
if (
self.draft_token_ids_copy_event is None
or self.draft_token_ids_cpu is None
or not self._has_draft_tokens
):
if self.draft_token_ids_copy_event is None or self.draft_token_ids_cpu is None:
return None
self._has_draft_tokens = False
@ -3600,9 +3599,7 @@ class GPUModelRunner(
mm_embed_inputs=mm_embed_inputs,
)
self._copy_draft_token_ids(
self._draft_token_ids, self.input_batch.num_reqs
)
self._copy_draft_token_ids(draft_token_ids, self.input_batch.num_reqs)
return draft_token_ids
def update_config(self, overrides: dict[str, Any]) -> None: