diff --git a/CMakeLists.txt b/CMakeLists.txt
index 87aa23c080f5..f11d28590b28 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -242,6 +242,7 @@ set(VLLM_EXT_SRC
     "csrc/activation_kernels.cu"
     "csrc/layernorm_kernels.cu"
     "csrc/layernorm_quant_kernels.cu"
+    "csrc/sampler.cu"
     "csrc/cuda_view.cu"
     "csrc/quantization/gptq/q_gemm.cu"
     "csrc/quantization/compressed_tensors/int8_quant_kernels.cu"
diff --git a/csrc/ops.h b/csrc/ops.h
index 7044b4588b81..297f32b4a2a0 100644
--- a/csrc/ops.h
+++ b/csrc/ops.h
@@ -92,6 +92,11 @@ void rms_norm(torch::Tensor& out, torch::Tensor& input, torch::Tensor& weight,
 void fused_add_rms_norm(torch::Tensor& input, torch::Tensor& residual,
                         torch::Tensor& weight, double epsilon);
 
+void apply_repetition_penalties_(torch::Tensor& logits,
+                                 const torch::Tensor& prompt_mask,
+                                 const torch::Tensor& output_mask,
+                                 const torch::Tensor& repetition_penalties);
+
 void rms_norm_static_fp8_quant(torch::Tensor& out, torch::Tensor& input,
                                torch::Tensor& weight, torch::Tensor& scale,
                                double epsilon);
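For orientation before the kernel itself: the declaration above (and the kernel below) apply the standard repetition-penalty rule, where a token that already appeared in the prompt or in the output gets its logit divided by the penalty if positive and multiplied by it if negative. A minimal Python sketch of that rule on toy values (the shapes and numbers here are illustrative, not taken from the patch):

```python
import torch

# Toy example: 1 sequence, 4-token vocabulary; tokens 0 and 1 were already seen.
logits = torch.tensor([[2.0, -1.0, 0.5, -3.0]])
seen = torch.tensor([[True, True, False, False]])
penalty = 1.05

# Positive logits shrink (divide), negative logits become more negative
# (multiply), so repeated tokens are always made less likely.
penalized = torch.where(logits > 0, logits / penalty, logits * penalty)
logits = torch.where(seen, penalized, logits)
# logits is now roughly [[1.905, -1.05, 0.5, -3.0]]
```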
diff --git a/csrc/sampler.cu b/csrc/sampler.cu
new file mode 100644
index 000000000000..ee5793dda0ef
--- /dev/null
+++ b/csrc/sampler.cu
@@ -0,0 +1,86 @@
+#include "dispatch_utils.h"
+
+#include <ATen/cuda/CUDAContext.h>
+#include <c10/cuda/CUDAGuard.h>
+
+#ifndef USE_ROCM
+  #include <cuda_runtime.h>
+#else
+  #include <hip/hip_runtime.h>
+#endif
+
+namespace vllm {
+
+template <typename scalar_t>
+__global__ void apply_repetition_penalties_kernel(
+    scalar_t* __restrict__ logits,          // [num_seqs, vocab_size]
+    const bool* __restrict__ prompt_mask,   // [num_seqs, vocab_size]
+    const bool* __restrict__ output_mask,   // [num_seqs, vocab_size]
+    const scalar_t* __restrict__ repetition_penalties,  // [num_seqs]
+    const int num_seqs, const int vocab_size, const int tile_size) {
+  // Each block handles one sequence and a tile of vocab
+  const int seq_idx = blockIdx.x;
+  if (seq_idx >= num_seqs) return;
+
+  const int tile_start = blockIdx.y * tile_size;
+  const int tile_end = min(tile_start + tile_size, vocab_size);
+
+  // Load repetition penalty for this sequence
+  const scalar_t penalty = repetition_penalties[seq_idx];
+
+  // Each thread processes multiple vocab items within the tile
+  for (int vocab_idx = tile_start + threadIdx.x; vocab_idx < tile_end;
+       vocab_idx += blockDim.x) {
+    const int64_t idx = static_cast<int64_t>(seq_idx) * vocab_size + vocab_idx;
+    const bool is_repeated = prompt_mask[idx] || output_mask[idx];
+    if (is_repeated) {
+      scalar_t logit = logits[idx];
+      if (logit > 0) {
+        logits[idx] = logit / penalty;
+      } else {
+        logits[idx] = logit * penalty;
+      }
+    }
+  }
+}
+
+}  // namespace vllm
+
+void apply_repetition_penalties_(
+    torch::Tensor& logits,             // [num_seqs, vocab_size], in-place
+    const torch::Tensor& prompt_mask,  // [num_seqs, vocab_size]
+    const torch::Tensor& output_mask,  // [num_seqs, vocab_size]
+    const torch::Tensor& repetition_penalties) {  // [num_seqs]
+  TORCH_CHECK(logits.is_contiguous());
+  TORCH_CHECK(prompt_mask.is_contiguous());
+  TORCH_CHECK(output_mask.is_contiguous());
+  TORCH_CHECK(repetition_penalties.is_contiguous());
+
+  int vocab_size = logits.size(-1);
+  int num_seqs = logits.size(0);
+
+  // Get number of SMs on the current device
+  int sms = 0;
+  cudaDeviceGetAttribute(&sms, cudaDevAttrMultiProcessorCount,
+                         logits.get_device());
+
+  // Compute tile_num and tile_size
+  int tile_num =
+      std::min(vocab_size, std::max(1, (sms + num_seqs - 1) / num_seqs));
+  int tile_size = (vocab_size + tile_num - 1) / tile_num;
+
+  // Each block handles one sequence and a tile of vocab
+  dim3 grid(num_seqs, tile_num);
+  dim3 block(std::min(tile_size, 1024));
+  const at::cuda::OptionalCUDAGuard device_guard(device_of(logits));
+  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
+  VLLM_DISPATCH_FLOATING_TYPES(
+      logits.scalar_type(), "apply_repetition_penalties_kernel", [&] {
+        vllm::apply_repetition_penalties_kernel<scalar_t>
+            <<<grid, block, 0, stream>>>(
+                logits.data_ptr<scalar_t>(), prompt_mask.data_ptr<bool>(),
+                output_mask.data_ptr<bool>(),
+                repetition_penalties.data_ptr<scalar_t>(), num_seqs,
+                vocab_size, tile_size);
+      });
+}
\ No newline at end of file
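The host code above sizes the launch so that a small batch still spreads across the GPU: roughly `sms / num_seqs` vocabulary tiles per sequence, with one block per (sequence, tile) pair. A small Python sketch of that heuristic; the SM count and shapes below are assumed example values, not measurements:

```python
def launch_geometry(num_seqs: int, vocab_size: int, sms: int):
    """Mirror of the tile/grid computation in csrc/sampler.cu (sketch only)."""
    tile_num = min(vocab_size, max(1, (sms + num_seqs - 1) // num_seqs))
    tile_size = (vocab_size + tile_num - 1) // tile_num
    grid = (num_seqs, tile_num)      # one block per (sequence, vocab tile)
    block = min(tile_size, 1024)     # threads per block
    return grid, block, tile_size

# Assumed example: 4 sequences, Qwen-sized vocab, 132 SMs (H100-class GPU).
print(launch_geometry(4, 151936, 132))  # -> ((4, 33), 1024, 4605)
```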
+ """ + current_platform.seed_everything(seed) + torch.set_default_device("cuda:0") + + # Create test data + logits = torch.randn(num_seqs, vocab_size, dtype=dtype) + + # Create masks with some random tokens marked as repeated + prompt_mask = torch.zeros(num_seqs, vocab_size, dtype=torch.bool) + output_mask = torch.zeros(num_seqs, vocab_size, dtype=torch.bool) + + # Mark some tokens as repeated in prompt and output + prompt_indices = torch.randint(0, vocab_size, + (num_seqs, max(1, vocab_size // 200))) + output_indices = torch.randint(0, vocab_size, + (num_seqs, max(1, vocab_size // 200))) + + for i in range(num_seqs): + prompt_mask[i, prompt_indices[i]] = True + output_mask[i, output_indices[i]] = True + + # Create repetition penalties tensor + repetition_penalties = torch.full((num_seqs, ), + repetition_penalty, + dtype=dtype) + + # Run all three implementations + logits_torch = logits.clone() + logits_cuda = logits.clone() + + apply_repetition_penalties_torch(logits_torch, prompt_mask, output_mask, + repetition_penalties) + apply_repetition_penalties_cuda(logits_cuda, prompt_mask, output_mask, + repetition_penalties) + + # Compare all outputs to reference + torch.testing.assert_close(logits_torch, logits_cuda, rtol=1e-3, atol=1e-3) + + # Test the operator by applying the opcheck utility + opcheck(torch.ops._C.apply_repetition_penalties_, + (logits.clone(), prompt_mask, output_mask, repetition_penalties)) diff --git a/vllm/_custom_ops.py b/vllm/_custom_ops.py index 008a7aa94939..3282edf410b6 100644 --- a/vllm/_custom_ops.py +++ b/vllm/_custom_ops.py @@ -282,6 +282,45 @@ def fused_add_rms_norm(input: torch.Tensor, residual: torch.Tensor, torch.ops._C.fused_add_rms_norm(input, residual, weight, epsilon) +def apply_repetition_penalties_torch( + logits: torch.Tensor, prompt_mask: torch.Tensor, + output_mask: torch.Tensor, repetition_penalties: torch.Tensor) -> None: + repetition_penalties = repetition_penalties.unsqueeze(dim=1).repeat( + 1, logits.size(1)) + # If token appears in prompt or output, apply, otherwise use 1.0 for no-op. + penalties = torch.where(prompt_mask | output_mask, repetition_penalties, + 1.0) + # If logits are positive, divide by penalty, otherwise multiply by penalty. + scaling = torch.where(logits > 0, 1.0 / penalties, penalties) + logits *= scaling + + +def apply_repetition_penalties_cuda( + logits: torch.Tensor, prompt_mask: torch.Tensor, + output_mask: torch.Tensor, repetition_penalties: torch.Tensor) -> None: + torch.ops._C.apply_repetition_penalties_(logits, prompt_mask, output_mask, + repetition_penalties) + + +def apply_repetition_penalties(logits: torch.Tensor, prompt_mask: torch.Tensor, + output_mask: torch.Tensor, + repetition_penalties: torch.Tensor) -> None: + """Apply repetition penalties to logits in-place. + + Args: + logits: The logits tensor of shape [num_seqs, vocab_size]. + prompt_mask: A boolean tensor indicating which tokens appear in the prompt. + output_mask: A boolean tensor indicating which tokens appear in the output. + repetition_penalties: The repetition penalties of shape (num_seqs, ). 
+ """ + if current_platform.is_cuda() and logits.is_contiguous(): + apply_repetition_penalties_cuda(logits, prompt_mask, output_mask, + repetition_penalties) + else: + apply_repetition_penalties_torch(logits, prompt_mask, output_mask, + repetition_penalties) + + def advance_step_flashattn(num_seqs: int, num_queries: int, block_size: int, input_tokens: torch.Tensor, sampled_token_ids: torch.Tensor, diff --git a/vllm/model_executor/layers/utils.py b/vllm/model_executor/layers/utils.py index d97d84238697..41b5253dca04 100644 --- a/vllm/model_executor/layers/utils.py +++ b/vllm/model_executor/layers/utils.py @@ -50,16 +50,11 @@ def apply_penalties(logits: torch.Tensor, prompt_tokens_tensor: torch.Tensor, vocab_size, num_seqs) output_bin_counts, output_mask = get_token_bin_counts_and_mask( output_tokens_tensor, vocab_size, num_seqs) - repetition_penalties = repetition_penalties.unsqueeze(dim=1).repeat( - 1, vocab_size) - # If token appears in prompt or output, apply, otherwise use 1.0 for no-op. - penalties = torch.where(prompt_mask | output_mask, repetition_penalties, - 1.0) - - # If logits are positive, divide by penalty, otherwise multiply by penalty. - scaling = torch.where(logits > 0, 1.0 / penalties, penalties) - logits *= scaling + # Apply repetition penalties as a custom op + from vllm._custom_ops import apply_repetition_penalties + apply_repetition_penalties(logits, prompt_mask, output_mask, + repetition_penalties) # We follow the definition in OpenAI API. # Refer to https://platform.openai.com/docs/api-reference/parameter-details