[BUG] Fix #20484. Support empty sequence in cuda penalty kernel (#20491)

Signed-off-by: Vadim Gimpelson <vadim.gimpelson@centml.ai>
This commit is contained in:
Vadim Gimpelson 2025-07-06 06:38:02 +04:00 committed by GitHub
parent c5ebe040ac
commit f73d02aadc
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 50 additions and 0 deletions

View File

@ -59,6 +59,8 @@ void apply_repetition_penalties_(
int vocab_size = logits.size(-1);
int num_seqs = logits.size(0);
if (num_seqs == 0) return;
// Get number of SMs on the current device
int sms = 0;
cudaDeviceGetAttribute(&sms, cudaDevAttrMultiProcessorCount,

View File

@ -75,3 +75,51 @@ def test_apply_repetition_penalties(
# Test the operator by applying the opcheck utility
opcheck(torch.ops._C.apply_repetition_penalties_,
(logits.clone(), prompt_mask, output_mask, repetition_penalties))
@pytest.mark.skipif(not current_platform.is_cuda(),
reason="This test for checking CUDA kernel")
@torch.inference_mode()
def test_apply_repetition_penalties_zero_seqs() -> None:
"""
Test the apply_repetition_penalties custom op with num_seqs=0
against a reference implementation.
"""
num_seqs = 0
vocab_size = 17
repetition_penalty = 1.05
dtype = torch.float32
seed = 0
current_platform.seed_everything(seed)
torch.set_default_device("cuda:0")
# Create test data
logits = torch.randn(num_seqs, vocab_size, dtype=dtype)
# Create masks with some random tokens marked as repeated
prompt_mask = torch.zeros(num_seqs, vocab_size, dtype=torch.bool)
output_mask = torch.zeros(num_seqs, vocab_size, dtype=torch.bool)
# No tokens to mark as repeated since num_seqs=0
# Create repetition penalties tensor
repetition_penalties = torch.full((num_seqs, ),
repetition_penalty,
dtype=dtype)
# Run all three implementations
logits_torch = logits.clone()
logits_cuda = logits.clone()
apply_repetition_penalties_torch(logits_torch, prompt_mask, output_mask,
repetition_penalties)
apply_repetition_penalties_cuda(logits_cuda, prompt_mask, output_mask,
repetition_penalties)
# Compare all outputs to reference
torch.testing.assert_close(logits_torch, logits_cuda, rtol=1e-3, atol=1e-3)
# Test the operator by applying the opcheck utility
opcheck(torch.ops._C.apply_repetition_penalties_,
(logits.clone(), prompt_mask, output_mask, repetition_penalties))