Signed-off-by: zhuhaoran <zhuhaoran.zhr@alibaba-inc.com>
This commit is contained in:
zhuhaoran 2025-12-12 00:02:47 +08:00
parent 461e7b3e32
commit 78555c06a2
2 changed files with 22 additions and 11 deletions

View File

@@ -952,7 +952,18 @@ class InputBatch:
sampled_token_ids = self.sampled_token_ids_cpu.tolist() sampled_token_ids = self.sampled_token_ids_cpu.tolist()
# Replace placeholder token id(s) with actual sampled id(s). # Replace placeholder token id(s) with actual sampled id(s).
if sampled_ids := sampled_token_ids[prev_index]: if sampled_ids := sampled_token_ids[prev_index]:
req_output_token_ids[-len(sampled_ids) :] = sampled_ids num_placeholders = 0
for t in reversed(req_output_token_ids):
if t == -1:
num_placeholders += 1
else:
break
if num_placeholders == 0:
continue
assert num_placeholders <= len(sampled_ids)
req_output_token_ids[-num_placeholders:] = sampled_ids[
:num_placeholders
]
def update_async_spec_token_ids( def update_async_spec_token_ids(
self, draft_token_ids_cpu: list[list[int]] | None self, draft_token_ids_cpu: list[list[int]] | None
@@ -971,13 +982,16 @@ class InputBatch:
return return
for index, req_id in enumerate(self.req_ids): for index, req_id in enumerate(self.req_ids):
prev_index = self.prev_req_id_to_index.get(req_id, default=None) prev_index = self.prev_req_id_to_index.get(req_id)
if prev_index is None: if prev_index is None:
continue continue
assert prev_index < len(draft_token_ids_cpu) assert prev_index < len(draft_token_ids_cpu)
draft_ids = draft_token_ids_cpu[prev_index] draft_ids = draft_token_ids_cpu[prev_index]
if not draft_ids:
continue
assert index < len(spec_token_ids) assert index < len(spec_token_ids)
spec_token_ids[index] = draft_ids spec_token_ids[index].clear()
spec_token_ids[index].extend(draft_ids)
@property @property
def num_reqs(self) -> int: def num_reqs(self) -> int:

View File

@@ -3423,14 +3423,13 @@ class GPUModelRunner(
for async scheduling with spec decode + penalty/bad_words. for async scheduling with spec decode + penalty/bad_words.
Returns None if no draft tokens were copied in previous step. Returns None if no draft tokens were copied in previous step.
""" """
if not self._has_draft_tokens:
return None
if isinstance(self._draft_token_ids, list): if isinstance(self._draft_token_ids, list):
return self._draft_token_ids return self._draft_token_ids
if ( if self.draft_token_ids_copy_event is None or self.draft_token_ids_cpu is None:
self.draft_token_ids_copy_event is None
or self.draft_token_ids_cpu is None
or not self._has_draft_tokens
):
return None return None
self._has_draft_tokens = False self._has_draft_tokens = False
@@ -3600,9 +3599,7 @@ class GPUModelRunner(
mm_embed_inputs=mm_embed_inputs, mm_embed_inputs=mm_embed_inputs,
) )
self._copy_draft_token_ids( self._copy_draft_token_ids(draft_token_ids, self.input_batch.num_reqs)
self._draft_token_ids, self.input_batch.num_reqs
)
return draft_token_ids return draft_token_ids
def update_config(self, overrides: dict[str, Any]) -> None: def update_config(self, overrides: dict[str, Any]) -> None: