mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-04-01 01:27:06 +08:00
fix: use num_draft_tokens to trim draft_token_ids_cpu
Signed-off-by: zhuhaoran <zhuhaoran.zhr@alibaba-inc.com>
This commit is contained in:
parent
8c0779a646
commit
33c63f263d
@ -966,7 +966,9 @@ class InputBatch:
|
||||
]
|
||||
|
||||
def update_async_spec_token_ids(
|
||||
self, draft_token_ids_cpu: list[list[int]] | None
|
||||
self,
|
||||
draft_token_ids_cpu: list[list[int]] | None,
|
||||
num_draft_tokens: list[int] | None = None,
|
||||
) -> None:
|
||||
"""
|
||||
In async scheduling case, update spec_token_ids in sampling metadata
|
||||
@ -985,11 +987,14 @@ class InputBatch:
|
||||
prev_index = self.prev_req_id_to_index.get(req_id)
|
||||
if prev_index is None:
|
||||
continue
|
||||
assert prev_index < len(draft_token_ids_cpu)
|
||||
draft_ids = draft_token_ids_cpu[prev_index]
|
||||
if not draft_ids:
|
||||
continue
|
||||
assert index < len(spec_token_ids)
|
||||
|
||||
if num_draft_tokens is not None:
|
||||
scheduled_count = num_draft_tokens[index]
|
||||
assert scheduled_count <= len(draft_ids)
|
||||
draft_ids = draft_ids[:scheduled_count]
|
||||
spec_token_ids[index].clear()
|
||||
spec_token_ids[index].extend(draft_ids)
|
||||
|
||||
|
||||
@ -2582,7 +2582,10 @@ class GPUModelRunner(
|
||||
|
||||
# Update spec_token_ids with real draft tokens from previous step
|
||||
draft_token_ids_cpu = self._get_draft_token_ids_cpu()
|
||||
self.input_batch.update_async_spec_token_ids(draft_token_ids_cpu)
|
||||
self.input_batch.update_async_spec_token_ids(
|
||||
draft_token_ids_cpu,
|
||||
num_draft_tokens=spec_decode_metadata.num_draft_tokens,
|
||||
)
|
||||
|
||||
sampler_output = self.rejection_sampler(
|
||||
spec_decode_metadata,
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user