[Perf] Avoid pageable HtoD transfer in MinTokensLogitsProcessor (#29826)

Signed-off-by: jthomson04 <jwillthomson19@gmail.com>
This commit is contained in:
jthomson04 2025-12-02 13:25:52 -08:00 committed by GitHub
parent afb1e5b380
commit 1528e079e2
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -110,7 +110,7 @@ class MinPLogitsProcessor(LogitsProcessor):
# Identify valid tokens using threshold comparison
invalid_token_mask = probability_values < adjusted_min_p
# Apply mask using boolean indexing
logits[invalid_token_mask] = -float("inf")
logits.masked_fill_(invalid_token_mask, -float("inf"))
return logits
@ -178,6 +178,10 @@ class MinTokensLogitsProcessor(LogitsProcessor):
self._device_tensor([], torch.int32),
)
self.neg_inf_tensor = torch.tensor(
-float("inf"), dtype=torch.float32, device=self.device
)
def is_argmax_invariant(self) -> bool:
"""By censoring stop tokens, min-tokens can change the outcome
of the argmax operation in greedy sampling."""
@ -229,7 +233,7 @@ class MinTokensLogitsProcessor(LogitsProcessor):
def apply(self, logits: torch.Tensor) -> torch.Tensor:
if self.min_toks:
# Inhibit EOS token for requests which have not reached min length
logits[self.logits_slice] = -float("inf")
logits.index_put_(self.logits_slice, self.neg_inf_tensor)
return logits