[Speculative Decoding] Move indices to device before filtering output (#10850)

Co-authored-by: Yang Zheng(SW)(Alex) <you@example.com>
Authored by Yang Zheng on 2024-12-03 17:01:39 +08:00, committed by GitHub
parent 9323a3153b
commit f6084f6324

@@ -120,6 +120,9 @@ class MultiStepWorker(ProposerWorkerBase, WorkerWrapperBase):
                     indices_of_seq_with_bonus_tokens)
                 model_outputs.append(model_output)
 
+        # move indices to device to avoid stream sync
+        indices_of_seq_with_bonus_tokens = torch.tensor(
+            indices_of_seq_with_bonus_tokens, device=self.device)
         filtered_model_outputs = self._filter_model_output(
             model_outputs, indices_of_seq_with_bonus_tokens)
         return filtered_model_outputs, True
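
For reviewers, here is a minimal sketch of the pattern this hunk adopts. It is not the worker code itself; the tensor shapes, the names `expanded_outputs` and `filtered`, and the toy vocab size are illustrative. The idea is to convert the Python index list into a device tensor once, before filtering, so each per-step gather stays on the device instead of paying a host-to-device copy (and the associated stream synchronization) inside `_filter_model_output`.

```python
import torch

device = "cuda" if torch.cuda.is_available() else "cpu"

# Pretend per-step outputs from an expanded batch: 3 steps, 6 sequences,
# toy vocab size 8 (all values are illustrative).
expanded_outputs = [torch.randn(6, 8, device=device) for _ in range(3)]

# Indices of the sequences to keep, computed on the CPU as a Python list.
indices_of_seq_with_bonus_tokens = [0, 2, 4]

# One explicit host-to-device transfer up front...
indices = torch.tensor(indices_of_seq_with_bonus_tokens, device=device)

# ...so every per-step gather below runs with a device-resident index tensor.
filtered = [step_output[indices] for step_output in expanded_outputs]
print([t.shape for t in filtered])  # 3 x torch.Size([3, 8])
```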
@@ -189,7 +192,7 @@ class MultiStepWorker(ProposerWorkerBase, WorkerWrapperBase):
     @staticmethod
     def _filter_model_output(
             expanded_batch_outputs: List[SamplerOutput],
-            output_indices_to_retain: List[int]) -> List[SamplerOutput]:
+            output_indices_to_retain: torch.Tensor) -> List[SamplerOutput]:
         """
         Filters the model output to include only the specified sequence
         outputs. This method contracts the expanded batch output from the
@@ -199,8 +202,8 @@ class MultiStepWorker(ProposerWorkerBase, WorkerWrapperBase):
         Args:
             expanded_batch_output (List[SamplerOutput]): The expanded output
                 batch from the model.
-            output_indices_to_retain (List[int]): Indices of the model outputs
-                to retain.
+            output_indices_to_retain (torch.Tensor): Indices of the model
+                outputs to retain.
 
         Returns:
             List[SamplerOutput]: A list containing the filtered model
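
To illustrate what a tensor-indexed filter of this kind can look like, here is a simplified stand-in. It is not vLLM's actual `_filter_model_output` or `SamplerOutput`: the `StepOutput` dataclass, its fields, and `filter_step_outputs` are hypothetical names. Tensor-valued fields are gathered directly with the device-resident index tensor; plain Python-list fields still need host-side integers, so the sketch does that conversion once.

```python
from dataclasses import dataclass
from typing import List

import torch


@dataclass
class StepOutput:
    # Hypothetical per-step output: one entry per sequence in the batch.
    sampled_token_ids: torch.Tensor  # shape: [num_seqs]
    seq_texts: List[str]


def filter_step_outputs(
        expanded_batch_outputs: List[StepOutput],
        output_indices_to_retain: torch.Tensor) -> List[StepOutput]:
    """Keep only the sequences selected by `output_indices_to_retain`."""
    # Python-list fields need host integers; do the device-to-host copy once.
    host_indices = output_indices_to_retain.tolist()
    return [
        StepOutput(
            # Tensor fields are gathered with the device-resident index
            # tensor, so nothing extra is copied to the device per step.
            sampled_token_ids=step.sampled_token_ids[output_indices_to_retain],
            seq_texts=[step.seq_texts[i] for i in host_indices],
        )
        for step in expanded_batch_outputs
    ]


# Toy usage: three steps, four sequences each, keep sequences 0 and 2.
device = "cuda" if torch.cuda.is_available() else "cpu"
steps = [
    StepOutput(
        sampled_token_ids=torch.randint(0, 100, (4,), device=device),
        seq_texts=[f"seq{i}" for i in range(4)],
    )
    for _ in range(3)
]
kept = filter_step_outputs(steps, torch.tensor([0, 2], device=device))
print(kept[0].sampled_token_ids.shape, kept[0].seq_texts)  # torch.Size([2]) ['seq0', 'seq2']
```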