mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-16 08:56:02 +08:00
[Speculative Decoding] Move indices to device before filtering output (#10850)
Co-authored-by: Yang Zheng(SW)(Alex) <you@example.com>
This commit is contained in:
parent
9323a3153b
commit
f6084f6324
@ -120,6 +120,9 @@ class MultiStepWorker(ProposerWorkerBase, WorkerWrapperBase):
|
||||
indices_of_seq_with_bonus_tokens)
|
||||
model_outputs.append(model_output)
|
||||
|
||||
# move indices to device to avoid stream sync
|
||||
indices_of_seq_with_bonus_tokens = torch.tensor(
|
||||
indices_of_seq_with_bonus_tokens, device=self.device)
|
||||
filtered_model_outputs = self._filter_model_output(
|
||||
model_outputs, indices_of_seq_with_bonus_tokens)
|
||||
return filtered_model_outputs, True
|
||||
@ -189,7 +192,7 @@ class MultiStepWorker(ProposerWorkerBase, WorkerWrapperBase):
|
||||
@staticmethod
|
||||
def _filter_model_output(
|
||||
expanded_batch_outputs: List[SamplerOutput],
|
||||
output_indices_to_retain: List[int]) -> List[SamplerOutput]:
|
||||
output_indices_to_retain: torch.Tensor) -> List[SamplerOutput]:
|
||||
"""
|
||||
Filters the model output to include only the specified sequence
|
||||
outputs. This method contracts the expanded batch output from the
|
||||
@ -199,8 +202,8 @@ class MultiStepWorker(ProposerWorkerBase, WorkerWrapperBase):
|
||||
Args:
|
||||
expanded_batch_output (List[SamplerOutput]): The expanded output
|
||||
batch from the model.
|
||||
output_indices_to_retain (List[int]): Indices of the model outputs
|
||||
to retain.
|
||||
output_indices_to_retain (torch.Tensor): Indices of the model
|
||||
outputs to retain.
|
||||
|
||||
Returns:
|
||||
List[SamplerOutput]: A list containing the filtered model
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user