mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-05-13 20:28:01 +08:00
fix: use num_draft_tokens to trim draft_token_ids_cpu
Signed-off-by: zhuhaoran <zhuhaoran.zhr@alibaba-inc.com>
This commit is contained in:
parent
8c0779a646
commit
33c63f263d
@ -966,7 +966,9 @@ class InputBatch:
|
|||||||
]
|
]
|
||||||
|
|
||||||
def update_async_spec_token_ids(
|
def update_async_spec_token_ids(
|
||||||
self, draft_token_ids_cpu: list[list[int]] | None
|
self,
|
||||||
|
draft_token_ids_cpu: list[list[int]] | None,
|
||||||
|
num_draft_tokens: list[int] | None = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""
|
"""
|
||||||
In async scheduling case, update spec_token_ids in sampling metadata
|
In async scheduling case, update spec_token_ids in sampling metadata
|
||||||
@ -985,11 +987,14 @@ class InputBatch:
|
|||||||
prev_index = self.prev_req_id_to_index.get(req_id)
|
prev_index = self.prev_req_id_to_index.get(req_id)
|
||||||
if prev_index is None:
|
if prev_index is None:
|
||||||
continue
|
continue
|
||||||
assert prev_index < len(draft_token_ids_cpu)
|
|
||||||
draft_ids = draft_token_ids_cpu[prev_index]
|
draft_ids = draft_token_ids_cpu[prev_index]
|
||||||
if not draft_ids:
|
if not draft_ids:
|
||||||
continue
|
continue
|
||||||
assert index < len(spec_token_ids)
|
|
||||||
|
if num_draft_tokens is not None:
|
||||||
|
scheduled_count = num_draft_tokens[index]
|
||||||
|
assert scheduled_count <= len(draft_ids)
|
||||||
|
draft_ids = draft_ids[:scheduled_count]
|
||||||
spec_token_ids[index].clear()
|
spec_token_ids[index].clear()
|
||||||
spec_token_ids[index].extend(draft_ids)
|
spec_token_ids[index].extend(draft_ids)
|
||||||
|
|
||||||
|
|||||||
@ -2582,7 +2582,10 @@ class GPUModelRunner(
|
|||||||
|
|
||||||
# Update spec_token_ids with real draft tokens from previous step
|
# Update spec_token_ids with real draft tokens from previous step
|
||||||
draft_token_ids_cpu = self._get_draft_token_ids_cpu()
|
draft_token_ids_cpu = self._get_draft_token_ids_cpu()
|
||||||
self.input_batch.update_async_spec_token_ids(draft_token_ids_cpu)
|
self.input_batch.update_async_spec_token_ids(
|
||||||
|
draft_token_ids_cpu,
|
||||||
|
num_draft_tokens=spec_decode_metadata.num_draft_tokens,
|
||||||
|
)
|
||||||
|
|
||||||
sampler_output = self.rejection_sampler(
|
sampler_output = self.rejection_sampler(
|
||||||
spec_decode_metadata,
|
spec_decode_metadata,
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user