[Speculative Decoding] Move indices to device before filtering output (#10850)

Co-authored-by: Yang Zheng(SW)(Alex) <you@example.com>
This commit is contained in:
Yang Zheng 2024-12-03 17:01:39 +08:00 committed by GitHub
parent 9323a3153b
commit f6084f6324
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -120,6 +120,9 @@ class MultiStepWorker(ProposerWorkerBase, WorkerWrapperBase):
indices_of_seq_with_bonus_tokens) indices_of_seq_with_bonus_tokens)
model_outputs.append(model_output) model_outputs.append(model_output)
# move indices to device to avoid stream sync
indices_of_seq_with_bonus_tokens = torch.tensor(
indices_of_seq_with_bonus_tokens, device=self.device)
filtered_model_outputs = self._filter_model_output( filtered_model_outputs = self._filter_model_output(
model_outputs, indices_of_seq_with_bonus_tokens) model_outputs, indices_of_seq_with_bonus_tokens)
return filtered_model_outputs, True return filtered_model_outputs, True
@ -189,7 +192,7 @@ class MultiStepWorker(ProposerWorkerBase, WorkerWrapperBase):
@staticmethod @staticmethod
def _filter_model_output( def _filter_model_output(
expanded_batch_outputs: List[SamplerOutput], expanded_batch_outputs: List[SamplerOutput],
output_indices_to_retain: List[int]) -> List[SamplerOutput]: output_indices_to_retain: torch.Tensor) -> List[SamplerOutput]:
""" """
Filters the model output to include only the specified sequence Filters the model output to include only the specified sequence
outputs. This method contracts the expanded batch output from the outputs. This method contracts the expanded batch output from the
@ -199,8 +202,8 @@ class MultiStepWorker(ProposerWorkerBase, WorkerWrapperBase):
Args: Args:
expanded_batch_output (List[SamplerOutput]): The expanded output expanded_batch_output (List[SamplerOutput]): The expanded output
batch from the model. batch from the model.
output_indices_to_retain (List[int]): Indices of the model outputs output_indices_to_retain (torch.Tensor): Indices of the model
to retain. outputs to retain.
Returns: Returns:
List[SamplerOutput]: A list containing the filtered model List[SamplerOutput]: A list containing the filtered model