[BugFix] Fix logits repetition penalty cuda check (#22592)

This commit is contained in:
Eugene Cheah 2025-08-10 22:52:31 -07:00 committed by GitHub
parent afa5b7ca0b
commit f919d4cb8f
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -311,7 +311,7 @@ def apply_repetition_penalties(logits: torch.Tensor, prompt_mask: torch.Tensor,
output_mask: A boolean tensor indicating which tokens appear in the output.
repetition_penalties: The repetition penalties of shape (num_seqs, ).
"""
if current_platform.is_cuda() and logits.is_contiguous():
if logits.is_cuda and logits.is_contiguous():
apply_repetition_penalties_cuda(logits, prompt_mask, output_mask,
repetition_penalties)
else: