mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-06-09 19:55:41 +08:00
fix bug in async scheduling with spec decode (placeholder token replacement and draft token id handling)
Signed-off-by: zhuhaoran <zhuhaoran.zhr@alibaba-inc.com>
This commit is contained in:
parent
461e7b3e32
commit
78555c06a2
@ -952,7 +952,18 @@ class InputBatch:
|
|||||||
sampled_token_ids = self.sampled_token_ids_cpu.tolist()
|
sampled_token_ids = self.sampled_token_ids_cpu.tolist()
|
||||||
# Replace placeholder token id(s) with actual sampled id(s).
|
# Replace placeholder token id(s) with actual sampled id(s).
|
||||||
if sampled_ids := sampled_token_ids[prev_index]:
|
if sampled_ids := sampled_token_ids[prev_index]:
|
||||||
req_output_token_ids[-len(sampled_ids) :] = sampled_ids
|
num_placeholders = 0
|
||||||
|
for t in reversed(req_output_token_ids):
|
||||||
|
if t == -1:
|
||||||
|
num_placeholders += 1
|
||||||
|
else:
|
||||||
|
break
|
||||||
|
if num_placeholders == 0:
|
||||||
|
continue
|
||||||
|
assert num_placeholders <= len(sampled_ids)
|
||||||
|
req_output_token_ids[-num_placeholders:] = sampled_ids[
|
||||||
|
:num_placeholders
|
||||||
|
]
|
||||||
|
|
||||||
def update_async_spec_token_ids(
|
def update_async_spec_token_ids(
|
||||||
self, draft_token_ids_cpu: list[list[int]] | None
|
self, draft_token_ids_cpu: list[list[int]] | None
|
||||||
@ -971,13 +982,16 @@ class InputBatch:
|
|||||||
return
|
return
|
||||||
|
|
||||||
for index, req_id in enumerate(self.req_ids):
|
for index, req_id in enumerate(self.req_ids):
|
||||||
prev_index = self.prev_req_id_to_index.get(req_id, default=None)
|
prev_index = self.prev_req_id_to_index.get(req_id)
|
||||||
if prev_index is None:
|
if prev_index is None:
|
||||||
continue
|
continue
|
||||||
assert prev_index < len(draft_token_ids_cpu)
|
assert prev_index < len(draft_token_ids_cpu)
|
||||||
draft_ids = draft_token_ids_cpu[prev_index]
|
draft_ids = draft_token_ids_cpu[prev_index]
|
||||||
|
if not draft_ids:
|
||||||
|
continue
|
||||||
assert index < len(spec_token_ids)
|
assert index < len(spec_token_ids)
|
||||||
spec_token_ids[index] = draft_ids
|
spec_token_ids[index].clear()
|
||||||
|
spec_token_ids[index].extend(draft_ids)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def num_reqs(self) -> int:
|
def num_reqs(self) -> int:
|
||||||
|
|||||||
@ -3423,14 +3423,13 @@ class GPUModelRunner(
|
|||||||
for async scheduling with spec decode + penalty/bad_words.
|
for async scheduling with spec decode + penalty/bad_words.
|
||||||
Returns None if no draft tokens were copied in previous step.
|
Returns None if no draft tokens were copied in previous step.
|
||||||
"""
|
"""
|
||||||
|
if not self._has_draft_tokens:
|
||||||
|
return None
|
||||||
|
|
||||||
if isinstance(self._draft_token_ids, list):
|
if isinstance(self._draft_token_ids, list):
|
||||||
return self._draft_token_ids
|
return self._draft_token_ids
|
||||||
|
|
||||||
if (
|
if self.draft_token_ids_copy_event is None or self.draft_token_ids_cpu is None:
|
||||||
self.draft_token_ids_copy_event is None
|
|
||||||
or self.draft_token_ids_cpu is None
|
|
||||||
or not self._has_draft_tokens
|
|
||||||
):
|
|
||||||
return None
|
return None
|
||||||
|
|
||||||
self._has_draft_tokens = False
|
self._has_draft_tokens = False
|
||||||
@ -3600,9 +3599,7 @@ class GPUModelRunner(
|
|||||||
mm_embed_inputs=mm_embed_inputs,
|
mm_embed_inputs=mm_embed_inputs,
|
||||||
)
|
)
|
||||||
|
|
||||||
self._copy_draft_token_ids(
|
self._copy_draft_token_ids(draft_token_ids, self.input_batch.num_reqs)
|
||||||
self._draft_token_ids, self.input_batch.num_reqs
|
|
||||||
)
|
|
||||||
return draft_token_ids
|
return draft_token_ids
|
||||||
|
|
||||||
def update_config(self, overrides: dict[str, Any]) -> None:
|
def update_config(self, overrides: dict[str, Any]) -> None:
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user