Mirror of https://git.datalinker.icu/vllm-project/vllm.git
commit f73d02aadc
parent c5ebe040ac
Signed-off-by: Vadim Gimpelson <vadim.gimpelson@centml.ai>
@@ -59,6 +59,8 @@ void apply_repetition_penalties_(
   int vocab_size = logits.size(-1);
   int num_seqs = logits.size(0);
 
+  if (num_seqs == 0) return;
+
   // Get number of SMs on the current device
   int sms = 0;
   cudaDeviceGetAttribute(&sms, cudaDevAttrMultiProcessorCount,
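Why the early return is needed (an illustration, not part of the commit): a zero-sequence batch is a legal tensor shape in PyTorch, so the custom op can be reached with num_seqs == 0, and the launch configuration computed below the guard from num_seqs and the SM count would then be degenerate. A minimal Python sketch of the edge case, assuming the kernel's grid size is derived from num_seqs:

# Minimal sketch (not vLLM code): zero-row tensors are valid inputs,
# so callers can legitimately reach the op with num_seqs == 0.
import torch

logits = torch.randn(0, 17)  # shape (num_seqs=0, vocab_size=17) is legal
assert logits.size(0) == 0 and logits.numel() == 0  # nothing to compute

# Assumption for illustration: a grid sized from num_seqs
# (e.g. blocks = min(num_seqs, sms)) would be zero here, and CUDA rejects
# kernel launches with a zero-sized grid, so the new guard turns the call
# into a well-defined no-op instead.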
@@ -75,3 +75,51 @@ def test_apply_repetition_penalties(
     # Test the operator by applying the opcheck utility
     opcheck(torch.ops._C.apply_repetition_penalties_,
             (logits.clone(), prompt_mask, output_mask, repetition_penalties))
+
+
+@pytest.mark.skipif(not current_platform.is_cuda(),
+                    reason="This test checks the CUDA kernel")
+@torch.inference_mode()
+def test_apply_repetition_penalties_zero_seqs() -> None:
+    """
+    Test the apply_repetition_penalties custom op with num_seqs=0
+    against a reference implementation.
+    """
+    num_seqs = 0
+    vocab_size = 17
+    repetition_penalty = 1.05
+    dtype = torch.float32
+    seed = 0
+
+    current_platform.seed_everything(seed)
+    torch.set_default_device("cuda:0")
+
+    # Create test data
+    logits = torch.randn(num_seqs, vocab_size, dtype=dtype)
+
+    # Create masks with no tokens marked as repeated
+    prompt_mask = torch.zeros(num_seqs, vocab_size, dtype=torch.bool)
+    output_mask = torch.zeros(num_seqs, vocab_size, dtype=torch.bool)
+
+    # There are no tokens to mark as repeated since num_seqs=0
+
+    # Create repetition penalties tensor
+    repetition_penalties = torch.full((num_seqs, ),
+                                      repetition_penalty,
+                                      dtype=dtype)
+
+    # Run both implementations
+    logits_torch = logits.clone()
+    logits_cuda = logits.clone()
+
+    apply_repetition_penalties_torch(logits_torch, prompt_mask, output_mask,
+                                     repetition_penalties)
+    apply_repetition_penalties_cuda(logits_cuda, prompt_mask, output_mask,
+                                    repetition_penalties)
+
+    # Compare the CUDA output to the torch reference
+    torch.testing.assert_close(logits_torch, logits_cuda, rtol=1e-3, atol=1e-3)
+
+    # Test the operator by applying the opcheck utility
+    opcheck(torch.ops._C.apply_repetition_penalties_,
+            (logits.clone(), prompt_mask, output_mask, repetition_penalties))
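For context, both implementations the test runs are expected to apply the standard repetition-penalty rule: for every token marked in prompt_mask or output_mask, positive logits are divided by the per-sequence penalty and non-positive logits are multiplied by it. The sketch below is an illustrative reimplementation under that assumption, not the exact vLLM apply_repetition_penalties_torch source; with num_seqs == 0 every tensor has zero rows, so it reduces to a no-op and the test's assert_close holds trivially.

import torch

def apply_repetition_penalties_ref(logits: torch.Tensor,
                                   prompt_mask: torch.Tensor,
                                   output_mask: torch.Tensor,
                                   repetition_penalties: torch.Tensor) -> None:
    # In-place: penalize tokens seen in the prompt or in generated output.
    repeated = prompt_mask | output_mask
    # Broadcast the per-sequence penalty over the vocab dimension; tokens
    # that are not repeated keep a neutral penalty of 1.
    penalties = torch.where(repeated,
                            repetition_penalties.unsqueeze(1),
                            torch.ones_like(logits))
    # Dividing positive logits and multiplying non-positive ones by the
    # penalty always makes repeated tokens less likely.
    logits[:] = torch.where(logits > 0, logits / penalties,
                            logits * penalties)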
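A note on the opcheck call at the end of the test (a sketch; the opcheck helper in vLLM's test suite is assumed here to wrap torch.library.opcheck): it re-runs the registered custom op under PyTorch's operator-testing utilities, checking the schema, fake-tensor propagation, and autograd registrations against eager execution, which is exactly the kind of harness that catches degenerate shapes such as num_seqs == 0.

import torch
# Assumes vLLM's compiled _C extension is loaded so that
# torch.ops._C.apply_repetition_penalties_ exists.

def check_registration(logits, prompt_mask, output_mask, penalties):
    # torch.library.opcheck raises if any registration (schema, fake/meta
    # kernel, autograd) disagrees with the op's eager behavior.
    torch.library.opcheck(
        torch.ops._C.apply_repetition_penalties_,
        (logits, prompt_mask, output_mask, penalties),
    )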