mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-10 06:45:01 +08:00
[BugFix] Fix logits repetition penalty cuda check (#22592)
This commit is contained in:
parent
afa5b7ca0b
commit
f919d4cb8f
@@ -311,7 +311,7 @@ def apply_repetition_penalties(logits: torch.Tensor, prompt_mask: torch.Tensor,
         output_mask: A boolean tensor indicating which tokens appear in the output.
         repetition_penalties: The repetition penalties of shape (num_seqs, ).
     """
-    if current_platform.is_cuda() and logits.is_contiguous():
+    if logits.is_cuda and logits.is_contiguous():
         apply_repetition_penalties_cuda(logits, prompt_mask, output_mask,
                                         repetition_penalties)
     else:
Loading…
x
Reference in New Issue
Block a user