mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-09 06:05:24 +08:00
[Perf] Avoid pageable HtoD transfer in MinTokensLogitsProcessor (#29826)
Signed-off-by: jthomson04 <jwillthomson19@gmail.com>
This commit is contained in:
parent
afb1e5b380
commit
1528e079e2
@ -110,7 +110,7 @@ class MinPLogitsProcessor(LogitsProcessor):
|
||||
# Identify valid tokens using threshold comparison
|
||||
invalid_token_mask = probability_values < adjusted_min_p
|
||||
# Apply mask using boolean indexing
|
||||
logits[invalid_token_mask] = -float("inf")
|
||||
logits.masked_fill_(invalid_token_mask, -float("inf"))
|
||||
return logits
|
||||
|
||||
|
||||
@ -178,6 +178,10 @@ class MinTokensLogitsProcessor(LogitsProcessor):
|
||||
self._device_tensor([], torch.int32),
|
||||
)
|
||||
|
||||
self.neg_inf_tensor = torch.tensor(
|
||||
-float("inf"), dtype=torch.float32, device=self.device
|
||||
)
|
||||
|
||||
def is_argmax_invariant(self) -> bool:
|
||||
"""By censoring stop tokens, min-tokens can change the outcome
|
||||
of the argmax operation in greedy sampling."""
|
||||
@ -229,7 +233,7 @@ class MinTokensLogitsProcessor(LogitsProcessor):
|
||||
def apply(self, logits: torch.Tensor) -> torch.Tensor:
|
||||
if self.min_toks:
|
||||
# Inhibit EOS token for requests which have not reached min length
|
||||
logits[self.logits_slice] = -float("inf")
|
||||
logits.index_put_(self.logits_slice, self.neg_inf_tensor)
|
||||
return logits
|
||||
|
||||
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user