mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-04-01 03:27:03 +08:00
Fix bug where the previous step's num_reqs may differ from the current num_reqs
Signed-off-by: zhuhaoran <zhuhaoran.zhr@alibaba-inc.com>
This commit is contained in:
parent
78555c06a2
commit
0490418742
@ -614,9 +614,7 @@ class GPUModelRunner(
|
||||
device="cpu",
|
||||
pin_memory=self.pin_memory,
|
||||
)
|
||||
# Flag to track if valid draft tokens were copied this step.
|
||||
# Reset to False at step start, set True in _copy_draft_token_ids.
|
||||
self._has_draft_tokens: bool = False
|
||||
self._prev_copy_draft_num_reqs: int = 0
|
||||
|
||||
# Ephemeral state transferred between execute_model() and sample_tokens().
|
||||
self.execute_model_state: ExecuteModelState | None = None
|
||||
@ -3405,6 +3403,7 @@ class GPUModelRunner(
|
||||
):
|
||||
return
|
||||
|
||||
self._prev_copy_draft_num_reqs = num_reqs
|
||||
default_stream = torch.cuda.current_stream()
|
||||
with torch.cuda.stream(self.draft_token_ids_copy_stream):
|
||||
self.draft_token_ids_copy_stream.wait_stream(default_stream) # type: ignore
|
||||
@ -3414,7 +3413,6 @@ class GPUModelRunner(
|
||||
draft_token_ids[:num_reqs], non_blocking=True
|
||||
)
|
||||
self.draft_token_ids_copy_event.record()
|
||||
self._has_draft_tokens = True
|
||||
|
||||
def _get_draft_token_ids_cpu(self) -> list[list[int]] | None:
|
||||
"""Get previously copied draft token ids from CPU.
|
||||
@ -3423,21 +3421,20 @@ class GPUModelRunner(
|
||||
for async scheduling with spec decode + penalty/bad_words.
|
||||
Returns None if no draft tokens were copied in previous step.
|
||||
"""
|
||||
if not self._has_draft_tokens:
|
||||
return None
|
||||
|
||||
if isinstance(self._draft_token_ids, list):
|
||||
return self._draft_token_ids
|
||||
|
||||
if self.draft_token_ids_copy_event is None or self.draft_token_ids_cpu is None:
|
||||
if (
|
||||
self.draft_token_ids_copy_event is None
|
||||
or self.draft_token_ids_cpu is None
|
||||
or not self._prev_copy_draft_num_reqs
|
||||
):
|
||||
return None
|
||||
|
||||
self._has_draft_tokens = False
|
||||
_prev_copy_draft_num_reqs = self._prev_copy_draft_num_reqs
|
||||
self._prev_copy_draft_num_reqs = 0
|
||||
self.draft_token_ids_copy_event.synchronize()
|
||||
num_reqs = self.input_batch.num_reqs
|
||||
if num_reqs == 0:
|
||||
return None
|
||||
return self.draft_token_ids_cpu[:num_reqs].tolist()
|
||||
return self.draft_token_ids_cpu[:_prev_copy_draft_num_reqs].tolist()
|
||||
|
||||
def propose_draft_token_ids(
|
||||
self,
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user