fix bug about pre num_reqs maybe != current num_reqs

Signed-off-by: zhuhaoran <zhuhaoran.zhr@alibaba-inc.com>
This commit is contained in:
zhuhaoran 2025-12-12 01:08:23 +08:00
parent 78555c06a2
commit 0490418742

View File

@ -614,9 +614,7 @@ class GPUModelRunner(
device="cpu",
pin_memory=self.pin_memory,
)
# Flag to track if valid draft tokens were copied this step.
# Reset to False at step start, set True in _copy_draft_token_ids.
self._has_draft_tokens: bool = False
self._prev_copy_draft_num_reqs: int = 0
# Ephemeral state transferred between execute_model() and sample_tokens().
self.execute_model_state: ExecuteModelState | None = None
@ -3405,6 +3403,7 @@ class GPUModelRunner(
):
return
self._prev_copy_draft_num_reqs = num_reqs
default_stream = torch.cuda.current_stream()
with torch.cuda.stream(self.draft_token_ids_copy_stream):
self.draft_token_ids_copy_stream.wait_stream(default_stream) # type: ignore
@ -3414,7 +3413,6 @@ class GPUModelRunner(
draft_token_ids[:num_reqs], non_blocking=True
)
self.draft_token_ids_copy_event.record()
self._has_draft_tokens = True
def _get_draft_token_ids_cpu(self) -> list[list[int]] | None:
"""Get previously copied draft token ids from CPU.
@ -3423,21 +3421,20 @@ class GPUModelRunner(
for async scheduling with spec decode + penalty/bad_words.
Returns None if no draft tokens were copied in previous step.
"""
if not self._has_draft_tokens:
return None
if isinstance(self._draft_token_ids, list):
return self._draft_token_ids
if self.draft_token_ids_copy_event is None or self.draft_token_ids_cpu is None:
if (
self.draft_token_ids_copy_event is None
or self.draft_token_ids_cpu is None
or not self._prev_copy_draft_num_reqs
):
return None
self._has_draft_tokens = False
_prev_copy_draft_num_reqs = self._prev_copy_draft_num_reqs
self._prev_copy_draft_num_reqs = 0
self.draft_token_ids_copy_event.synchronize()
num_reqs = self.input_batch.num_reqs
if num_reqs == 0:
return None
return self.draft_token_ids_cpu[:num_reqs].tolist()
return self.draft_token_ids_cpu[:_prev_copy_draft_num_reqs].tolist()
def propose_draft_token_ids(
self,