mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-18 06:45:01 +08:00
[Speculative Decoding] Move indices to device before filtering output (#10850)
Co-authored-by: Yang Zheng(SW)(Alex) <you@example.com>
This commit is contained in:
parent
9323a3153b
commit
f6084f6324
@ -120,6 +120,9 @@ class MultiStepWorker(ProposerWorkerBase, WorkerWrapperBase):
|
|||||||
indices_of_seq_with_bonus_tokens)
|
indices_of_seq_with_bonus_tokens)
|
||||||
model_outputs.append(model_output)
|
model_outputs.append(model_output)
|
||||||
|
|
||||||
|
# move indices to device to avoid stream sync
|
||||||
|
indices_of_seq_with_bonus_tokens = torch.tensor(
|
||||||
|
indices_of_seq_with_bonus_tokens, device=self.device)
|
||||||
filtered_model_outputs = self._filter_model_output(
|
filtered_model_outputs = self._filter_model_output(
|
||||||
model_outputs, indices_of_seq_with_bonus_tokens)
|
model_outputs, indices_of_seq_with_bonus_tokens)
|
||||||
return filtered_model_outputs, True
|
return filtered_model_outputs, True
|
||||||
@ -189,7 +192,7 @@ class MultiStepWorker(ProposerWorkerBase, WorkerWrapperBase):
|
|||||||
@staticmethod
|
@staticmethod
|
||||||
def _filter_model_output(
|
def _filter_model_output(
|
||||||
expanded_batch_outputs: List[SamplerOutput],
|
expanded_batch_outputs: List[SamplerOutput],
|
||||||
output_indices_to_retain: List[int]) -> List[SamplerOutput]:
|
output_indices_to_retain: torch.Tensor) -> List[SamplerOutput]:
|
||||||
"""
|
"""
|
||||||
Filters the model output to include only the specified sequence
|
Filters the model output to include only the specified sequence
|
||||||
outputs. This method contracts the expanded batch output from the
|
outputs. This method contracts the expanded batch output from the
|
||||||
@ -199,8 +202,8 @@ class MultiStepWorker(ProposerWorkerBase, WorkerWrapperBase):
|
|||||||
Args:
|
Args:
|
||||||
expanded_batch_output (List[SamplerOutput]): The expanded output
|
expanded_batch_output (List[SamplerOutput]): The expanded output
|
||||||
batch from the model.
|
batch from the model.
|
||||||
output_indices_to_retain (List[int]): Indices of the model outputs
|
output_indices_to_retain (torch.Tensor): Indices of the model
|
||||||
to retain.
|
outputs to retain.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
List[SamplerOutput]: A list containing the filtered model
|
List[SamplerOutput]: A list containing the filtered model
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user