mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-09 09:55:02 +08:00
Signed-off-by: Vadim Gimpelson <vadim.gimpelson@centml.ai>
This commit is contained in:
parent
c5ebe040ac
commit
f73d02aadc
@ -59,6 +59,8 @@ void apply_repetition_penalties_(
|
||||
int vocab_size = logits.size(-1);
|
||||
int num_seqs = logits.size(0);
|
||||
|
||||
if (num_seqs == 0) return;
|
||||
|
||||
// Get number of SMs on the current device
|
||||
int sms = 0;
|
||||
cudaDeviceGetAttribute(&sms, cudaDevAttrMultiProcessorCount,
|
||||
|
||||
@ -75,3 +75,51 @@ def test_apply_repetition_penalties(
|
||||
# Test the operator by applying the opcheck utility
|
||||
opcheck(torch.ops._C.apply_repetition_penalties_,
|
||||
(logits.clone(), prompt_mask, output_mask, repetition_penalties))
|
||||
|
||||
|
||||
@pytest.mark.skipif(not current_platform.is_cuda(),
                    reason="This test for checking CUDA kernel")
@torch.inference_mode()
def test_apply_repetition_penalties_zero_seqs() -> None:
    """
    Exercise the apply_repetition_penalties custom op on an empty batch.

    All inputs are built with num_seqs=0. The torch helper and the CUDA
    helper are each run on their own copy of the logits, the two results
    are compared, and the registered op is then passed through opcheck.
    """
    current_platform.seed_everything(0)
    torch.set_default_device("cuda:0")

    # Zero sequences, 17 vocabulary entries.
    empty_shape = (0, 17)
    float_dtype = torch.float32

    # Inputs: logits plus two all-False masks (with no rows there is
    # nothing to mark as repeated).
    logits = torch.randn(empty_shape, dtype=float_dtype)
    prompt_mask = torch.zeros(empty_shape, dtype=torch.bool)
    output_mask = torch.zeros(empty_shape, dtype=torch.bool)

    # One penalty value per sequence, i.e. an empty 1-D tensor here.
    repetition_penalties = torch.full(empty_shape[:1],
                                      1.05,
                                      dtype=float_dtype)

    # Each implementation gets its own copy of the logits.
    logits_torch = logits.clone()
    logits_cuda = logits.clone()
    penalty_args = (prompt_mask, output_mask, repetition_penalties)

    apply_repetition_penalties_torch(logits_torch, *penalty_args)
    apply_repetition_penalties_cuda(logits_cuda, *penalty_args)

    # The two implementations must agree.
    torch.testing.assert_close(logits_torch, logits_cuda, rtol=1e-3, atol=1e-3)

    # Run the opcheck utility against the registered custom op.
    opcheck(torch.ops._C.apply_repetition_penalties_,
            (logits.clone(), *penalty_args))
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user