From 1dcafb3dea62f556011be4df6f71769aa7260561 Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Fri, 28 Nov 2025 17:53:17 -0800 Subject: [PATCH] [Model Runner V2] Support penalties using bin counts (#29703) Signed-off-by: Woosuk Kwon --- vllm/v1/worker/gpu/input_batch.py | 15 +++ vllm/v1/worker/gpu/model_runner.py | 9 +- vllm/v1/worker/gpu/penalties.py | 85 ++++++++++++++ vllm/v1/worker/gpu/sampler.py | 3 + vllm/v1/worker/gpu/states.py | 182 +++++++++++++++++++++++++++-- 5 files changed, 280 insertions(+), 14 deletions(-) create mode 100644 vllm/v1/worker/gpu/penalties.py diff --git a/vllm/v1/worker/gpu/input_batch.py b/vllm/v1/worker/gpu/input_batch.py index 2a7048ae3c0e0..43fd53d3acaae 100644 --- a/vllm/v1/worker/gpu/input_batch.py +++ b/vllm/v1/worker/gpu/input_batch.py @@ -341,6 +341,8 @@ def _post_update_kernel( idx_mapping_ptr, num_computed_tokens_ptr, last_sampled_tokens_ptr, + output_bin_counts_ptr, + output_bin_counts_stride, sampled_tokens_ptr, sampled_tokens_stride, num_sampled_ptr, @@ -357,6 +359,15 @@ def _post_update_kernel( ) tl.store(last_sampled_tokens_ptr + req_state_idx, token_id) + for i in range(num_sampled): + token_id = tl.load(sampled_tokens_ptr + req_id * sampled_tokens_stride + i) + token_ptr = ( + output_bin_counts_ptr + req_state_idx * output_bin_counts_stride + token_id + ) + count = tl.load(token_ptr) + count += 1 + tl.store(token_ptr, count) + query_start = tl.load(query_start_loc_ptr + req_id) query_end = tl.load(query_start_loc_ptr + req_id + 1) query_len = query_end - query_start @@ -374,6 +385,8 @@ def post_update( num_computed_tokens: torch.Tensor, # [max_num_reqs] last_sampled_tokens: torch.Tensor, + # [max_num_reqs, vocab_size] + output_bin_counts: torch.Tensor, # [num_reqs, num_speculative_steps + 1] sampled_tokens: torch.Tensor, # [num_reqs] @@ -388,6 +401,8 @@ def post_update( idx_mapping, num_computed_tokens, last_sampled_tokens, + output_bin_counts, + output_bin_counts.stride(0), sampled_tokens, sampled_tokens.stride(0), num_sampled, diff --git a/vllm/v1/worker/gpu/model_runner.py b/vllm/v1/worker/gpu/model_runner.py index 0c9fdd0077f4a..9ba234544421d 100644 --- a/vllm/v1/worker/gpu/model_runner.py +++ b/vllm/v1/worker/gpu/model_runner.py @@ -512,7 +512,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): idx_mapping_np, num_scheduled_tokens, query_start_loc_np, - self.req_states.prefill_token_ids, + self.req_states.prefill_token_ids.np, self.req_states.num_computed_prefill_tokens, self.input_buffers.input_ids.np, ) @@ -681,7 +681,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): # Handle chunked prompts. 
pos_after_step = computed_prefill + input_batch.num_scheduled_tokens is_prompt_chunked = pos_after_step < prompt_lens - prefill_token_ids = self.req_states.prefill_token_ids + prefill_token_ids = self.req_states.prefill_token_ids.np query_start_loc = self.input_buffers.query_start_loc.np for i, req_id in enumerate(input_batch.req_ids): if not needs_prompt_logprobs[i]: @@ -756,6 +756,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): input_batch.idx_mapping, self.req_states.num_computed_tokens, self.req_states.last_sampled_tokens, + self.req_states.output_bin_counts, sampled_tokens, num_sampled, num_rejected, @@ -785,7 +786,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): idx_mapping_np = input_batch.idx_mapping_np with async_barrier(self.spec_decode_event): self.input_buffers.next_prefill_tokens.np[:num_reqs] = ( - self.req_states.prefill_token_ids[ + self.req_states.prefill_token_ids.np[ idx_mapping_np, self.req_states.num_computed_prefill_tokens[idx_mapping_np], ] @@ -896,7 +897,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): # barrier to avoid race conditions. pos = input_batch.positions[input_batch.logits_indices] sampling_metadata = self.req_states.make_sampling_metadata( - input_batch.idx_mapping_np, pos + input_batch.idx_mapping, input_batch.idx_mapping_np, pos ) if input_batch.num_draft_tokens > 0: sampling_metadata = self.req_states.expand_sampling_metadata( diff --git a/vllm/v1/worker/gpu/penalties.py b/vllm/v1/worker/gpu/penalties.py new file mode 100644 index 0000000000000..f87ee01718cd6 --- /dev/null +++ b/vllm/v1/worker/gpu/penalties.py @@ -0,0 +1,85 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import torch + +from vllm.triton_utils import tl, triton +from vllm.v1.worker.gpu.states import SamplingMetadata + + +@triton.jit +def _penalties_kernel( + logits_ptr, + logits_stride, + repetition_penalty_ptr, + frequency_penalty_ptr, + presence_penalty_ptr, + idx_mapping_ptr, + prompt_bin_counts_ptr, + prompt_bin_counts_stride, + output_bin_counts_ptr, + output_bin_counts_stride, + vocab_size, + BLOCK_SIZE: tl.constexpr, +): + batch_idx = tl.program_id(0) + rep_penalty = tl.load(repetition_penalty_ptr + batch_idx) + freq_penalty = tl.load(frequency_penalty_ptr + batch_idx) + pres_penalty = tl.load(presence_penalty_ptr + batch_idx) + + use_rep_penalty = rep_penalty != 1.0 + use_freq_penalty = freq_penalty != 0.0 + use_pres_penalty = pres_penalty != 0.0 + if not (use_rep_penalty or use_freq_penalty or use_pres_penalty): + # No penalties to apply. Early return. + return + + block_idx = tl.program_id(1) + block = block_idx * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) + mask = block < vocab_size + logits = tl.load(logits_ptr + batch_idx * logits_stride + block, mask=mask) + logits = logits.to(tl.float32) + + req_state_idx = tl.load(idx_mapping_ptr + batch_idx) + output_bin_counts = tl.load( + output_bin_counts_ptr + req_state_idx * output_bin_counts_stride + block, + mask=mask, + ) + + # Apply repetition penalties. + if use_rep_penalty: + prompt_bin_counts = tl.load( + prompt_bin_counts_ptr + req_state_idx * prompt_bin_counts_stride + block, + mask=mask, + ) + # If token appears in prompt or output, apply, otherwise use 1.0 for no-op. + scale = tl.where((prompt_bin_counts + output_bin_counts) > 0, rep_penalty, 1.0) + # If logits are positive, divide by penalty, otherwise multiply by penalty. 
+ scale = tl.where(logits > 0, 1.0 / scale, scale) + logits *= scale + + # Apply frequency penalties. + logits -= freq_penalty * output_bin_counts + # Apply presence penalties. + logits -= pres_penalty * (output_bin_counts > 0) + # Store back to logits. + tl.store(logits_ptr + batch_idx * logits_stride + block, logits, mask=mask) + + +def apply_penalties(logits: torch.Tensor, sampling_metadata: SamplingMetadata) -> None: + num_reqs, vocab_size = logits.shape + BLOCK_SIZE = 8192 + num_blocks = triton.cdiv(vocab_size, BLOCK_SIZE) + _penalties_kernel[(num_reqs, num_blocks)]( + logits, + logits.stride(0), + sampling_metadata.repetition_penalty, + sampling_metadata.frequency_penalty, + sampling_metadata.presence_penalty, + sampling_metadata.idx_mapping, + sampling_metadata.prompt_bin_counts, + sampling_metadata.prompt_bin_counts.stride(0), + sampling_metadata.output_bin_counts, + sampling_metadata.output_bin_counts.stride(0), + vocab_size, + BLOCK_SIZE=BLOCK_SIZE, + ) diff --git a/vllm/v1/worker/gpu/sampler.py b/vllm/v1/worker/gpu/sampler.py index d8676079ab951..6e0d6150a9669 100644 --- a/vllm/v1/worker/gpu/sampler.py +++ b/vllm/v1/worker/gpu/sampler.py @@ -8,6 +8,7 @@ from vllm.config.model import LogprobsMode from vllm.triton_utils import tl, triton from vllm.v1.outputs import LogprobsTensors, SamplerOutput from vllm.v1.sample.ops.topk_topp_sampler import apply_top_k_top_p +from vllm.v1.worker.gpu.penalties import apply_penalties from vllm.v1.worker.gpu.states import SamplingMetadata @@ -65,6 +66,8 @@ class Sampler: logits = apply_top_k_top_p( logits, sampling_metadata.top_k, sampling_metadata.top_p ) + # Apply penalties in place. + apply_penalties(logits, sampling_metadata) sampled = gumbel_sample( logits, diff --git a/vllm/v1/worker/gpu/states.py b/vllm/v1/worker/gpu/states.py index 513d45d95d7cd..64874b72e60cf 100644 --- a/vllm/v1/worker/gpu/states.py +++ b/vllm/v1/worker/gpu/states.py @@ -8,6 +8,8 @@ import torch from vllm.lora.request import LoRARequest from vllm.sampling_params import SamplingParams from vllm.triton_utils import tl, triton +from vllm.utils.platform_utils import is_uva_available +from vllm.utils.torch_utils import get_cuda_view_from_cpu_tensor from vllm.v1.outputs import LogprobsTensors from vllm.v1.utils import CpuGpuBuffer @@ -23,12 +25,21 @@ class SamplingMetadata: top_p: torch.Tensor | None top_k: torch.Tensor | None + repetition_penalty: torch.Tensor + frequency_penalty: torch.Tensor + presence_penalty: torch.Tensor + seeds: torch.Tensor pos: torch.Tensor # None means no logprobs, 0 means sampled token logprobs only max_num_logprobs: int | None + # For penalties + idx_mapping: torch.Tensor + prompt_bin_counts: torch.Tensor + output_bin_counts: torch.Tensor + @classmethod def make_dummy( cls, @@ -44,17 +55,35 @@ class SamplingMetadata: # top_k = torch.full((num_reqs,), 20, dtype=torch.int32, device=device) top_p = None top_k = None + # NOTE(woosuk): We must set penalties to their default values to make sure + # the penalties kernel does not touch the placeholder bin_counts tensors. 
+ repetition_penalty = torch.ones(num_reqs, dtype=torch.float32, device=device) + frequency_penalty = torch.zeros(num_reqs, dtype=torch.float32, device=device) + presence_penalty = torch.zeros(num_reqs, dtype=torch.float32, device=device) seeds = torch.zeros(num_reqs, dtype=torch.int64, device=device) pos = torch.zeros(num_reqs, dtype=torch.int64, device=device) max_num_logprobs = 20 + idx_mapping = torch.arange(num_reqs, dtype=torch.int32, device=device) + # NOTE(woosuk): These are placeholder tensors to avoid None checks in the + # penalties kernel. We use 2 instead of 1 as vocab_size to avoid Triton + # specialization and re-compilation at runtime. + prompt_bin_counts = torch.zeros(num_reqs, 2, dtype=torch.int32, device=device) + output_bin_counts = torch.zeros(num_reqs, 2, dtype=torch.int32, device=device) + return cls( temperature=temperature, top_p=top_p, top_k=top_k, + repetition_penalty=repetition_penalty, + frequency_penalty=frequency_penalty, + presence_penalty=presence_penalty, seeds=seeds, pos=pos, max_num_logprobs=max_num_logprobs, + idx_mapping=idx_mapping, + prompt_bin_counts=prompt_bin_counts, + output_bin_counts=output_bin_counts, ) @@ -83,9 +112,10 @@ class RequestState: self.extra_data: dict[str, ExtraData] = {} self.prompt_len = np.zeros(self.max_num_reqs, dtype=np.int32) - self.prefill_token_ids = np.zeros( - (self.max_num_reqs, self.max_model_len), - dtype=np.int32, + # NOTE(woosuk): This tensor can be extremely large (e.g., several GBs) + # depending on the configured max_num_reqs and max_model_len. + self.prefill_token_ids = UvaBuffer( + self.max_num_reqs, self.max_model_len, dtype=torch.int32 ) self.prefill_len = self._make_buffer(self.max_num_reqs, dtype=torch.int32) @@ -119,6 +149,9 @@ class RequestState: self.temperature = self._make_param(self.max_num_reqs, torch.float32) self.top_p = self._make_param(self.max_num_reqs, torch.float32) self.top_k = self._make_param(self.max_num_reqs, torch.int32) + self.repetition_penalty = self._make_param(self.max_num_reqs, torch.float32) + self.frequency_penalty = self._make_param(self.max_num_reqs, torch.float32) + self.presence_penalty = self._make_param(self.max_num_reqs, torch.float32) self.seeds = self._make_param(self.max_num_reqs, torch.int64) self.num_logprobs = np.empty(self.max_num_reqs, dtype=np.int32) @@ -126,6 +159,16 @@ class RequestState: self.num_logprobs.fill(-1) self.needs_prompt_logprobs = np.zeros(self.max_num_reqs, dtype=bool) + # Statistics for penalties. + # TODO(woosuk): These tensors are rarely used but can be extremely large. + # Optimize the memory usage. + self.prompt_bin_counts = torch.zeros( + self.max_num_reqs, self.vocab_size, dtype=torch.int32, device=self.device + ) + self.output_bin_counts = torch.zeros( + self.max_num_reqs, self.vocab_size, dtype=torch.int32, device=self.device + ) + def _make_param(self, size: int, dtype: torch.dtype) -> "Param": return Param(size, dtype=dtype, device=self.device, pin_memory=self.pin_memory) @@ -159,7 +202,7 @@ class RequestState: f"prefill_len {prefill_len} < prompt_len {prompt_len}" ) self.prefill_len.np[req_idx] = prefill_len - self.prefill_token_ids[req_idx, :prefill_len] = prefill_token_ids + self.prefill_token_ids.np[req_idx, :prefill_len] = prefill_token_ids self.num_computed_prefill_tokens[req_idx] = num_computed_tokens # FIXME(woosuk): This triggers a GPU operation whenever adding a new request. 
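
Reference note (illustrative only, not part of the patch): the arithmetic that _penalties_kernel in vllm/v1/worker/gpu/penalties.py performs per request can be restated in plain PyTorch as below. The helper name apply_penalties_reference and its signature are assumptions for this sketch; the real kernel also early-returns per request when all three penalties are at their default values, so the placeholder bin-count tensors created by make_dummy are never read.

import torch

def apply_penalties_reference(
    logits: torch.Tensor,              # [num_reqs, vocab_size] float32, updated in place
    repetition_penalty: torch.Tensor,  # [num_reqs], 1.0 disables it
    frequency_penalty: torch.Tensor,   # [num_reqs], 0.0 disables it
    presence_penalty: torch.Tensor,    # [num_reqs], 0.0 disables it
    idx_mapping: torch.Tensor,         # [num_reqs], batch slot -> persistent request row
    prompt_bin_counts: torch.Tensor,   # [max_num_reqs, vocab_size] integer counts
    output_bin_counts: torch.Tensor,   # [max_num_reqs, vocab_size] integer counts
) -> None:
    prompt_counts = prompt_bin_counts[idx_mapping].float()
    output_counts = output_bin_counts[idx_mapping].float()
    # Repetition penalty: for every token seen in the prompt or the output so far,
    # divide positive logits by the penalty and multiply negative logits by it.
    seen = (prompt_counts + output_counts) > 0
    rep = repetition_penalty.unsqueeze(1)
    scale = torch.where(seen, rep, torch.ones_like(rep))
    scale = torch.where(logits > 0, 1.0 / scale, scale)
    logits *= scale
    # Frequency penalty scales with how often a token has been generated;
    # presence penalty is a flat subtraction for any token generated at least once.
    logits -= frequency_penalty.unsqueeze(1) * output_counts
    logits -= presence_penalty.unsqueeze(1) * (output_counts > 0).float()

The divide-or-multiply split keeps the repetition penalty one-directional: a previously seen token's logit always moves toward being less likely, whatever its sign.
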
@@ -178,6 +221,18 @@ class RequestState: else: top_k = self.vocab_size self.top_k.np[req_idx] = top_k + self.repetition_penalty.np[req_idx] = sampling_params.repetition_penalty + self.frequency_penalty.np[req_idx] = sampling_params.frequency_penalty + self.presence_penalty.np[req_idx] = sampling_params.presence_penalty + + if use_penalty(sampling_params): + bincount( + self.prefill_token_ids.gpu[req_idx], + prefill_len, + prompt_len, + self.prompt_bin_counts[req_idx], + self.output_bin_counts[req_idx], + ) if sampling_params.seed is not None: seed = sampling_params.seed @@ -206,24 +261,32 @@ class RequestState: def make_sampling_metadata( self, - idx_mapping: np.ndarray, + idx_mapping: torch.Tensor, + idx_mapping_np: np.ndarray, pos: torch.Tensor, ) -> SamplingMetadata: - temperature = self.temperature.np[idx_mapping] + temperature = self.temperature.np[idx_mapping_np] temperature = self.temperature.copy_np_to_gpu(temperature) - top_p = self.top_p.np[idx_mapping] + top_p = self.top_p.np[idx_mapping_np] no_top_p = np.all(top_p == 1.0) top_p = self.top_p.copy_np_to_gpu(top_p) if not no_top_p else None - top_k = self.top_k.np[idx_mapping] + top_k = self.top_k.np[idx_mapping_np] no_top_k = np.all(top_k == self.vocab_size) top_k = self.top_k.copy_np_to_gpu(top_k) if not no_top_k else None - seeds = self.seeds.np[idx_mapping] + rep_penalty = self.repetition_penalty.np[idx_mapping_np] + rep_penalty = self.repetition_penalty.copy_np_to_gpu(rep_penalty) + freq_penalty = self.frequency_penalty.np[idx_mapping_np] + freq_penalty = self.frequency_penalty.copy_np_to_gpu(freq_penalty) + pres_penalty = self.presence_penalty.np[idx_mapping_np] + pres_penalty = self.presence_penalty.copy_np_to_gpu(pres_penalty) + + seeds = self.seeds.np[idx_mapping_np] seeds = self.seeds.copy_np_to_gpu(seeds) - num_logprobs = self.num_logprobs[idx_mapping] + num_logprobs = self.num_logprobs[idx_mapping_np] max_num_logprobs: int | None = int(np.max(num_logprobs)) if max_num_logprobs == -1: max_num_logprobs = None @@ -232,9 +295,15 @@ class RequestState: temperature=temperature, top_p=top_p, top_k=top_k, + repetition_penalty=rep_penalty, + frequency_penalty=freq_penalty, + presence_penalty=pres_penalty, seeds=seeds, pos=pos, max_num_logprobs=max_num_logprobs, + idx_mapping=idx_mapping, + prompt_bin_counts=self.prompt_bin_counts, + output_bin_counts=self.output_bin_counts, ) def expand_sampling_metadata( @@ -294,6 +363,14 @@ class ExtraData: in_progress_prompt_logprobs: list[LogprobsTensors] = field(default_factory=list) +class UvaBuffer: + def __init__(self, *size: int | torch.SymInt, dtype: torch.dtype): + assert is_uva_available() + self.cpu = torch.zeros(*size, dtype=dtype, device="cpu", pin_memory=True) + self.np = self.cpu.numpy() + self.gpu = get_cuda_view_from_cpu_tensor(self.cpu) + + # NOTE(woosuk): Re-compilation can happen at runtime since top_p and top_k can be None. 
@triton.jit def _expand_sampling_metadata_kernel( @@ -304,6 +381,12 @@ def _expand_sampling_metadata_kernel( top_k_ptr, expanded_top_k_ptr, seeds_ptr, + rep_penalty_ptr, + expanded_rep_penalty_ptr, + freq_penalty_ptr, + expanded_freq_penalty_ptr, + pres_penalty_ptr, + expanded_pres_penalty_ptr, expanded_seeds_ptr, cu_num_logits_ptr, BLOCK_SIZE: tl.constexpr, @@ -327,6 +410,15 @@ def _expand_sampling_metadata_kernel( top_k = tl.load(top_k_ptr + req_idx) tl.store(expanded_top_k_ptr + start_idx + block, top_k, mask=mask) + rep_penalty = tl.load(rep_penalty_ptr + req_idx) + tl.store(expanded_rep_penalty_ptr + start_idx + block, rep_penalty, mask=mask) + + freq_penalty = tl.load(freq_penalty_ptr + req_idx) + tl.store(expanded_freq_penalty_ptr + start_idx + block, freq_penalty, mask=mask) + + pres_penalty = tl.load(pres_penalty_ptr + req_idx) + tl.store(expanded_pres_penalty_ptr + start_idx + block, pres_penalty, mask=mask) + seed = tl.load(seeds_ptr + req_idx) tl.store(expanded_seeds_ptr + start_idx + block, seed, mask=mask) @@ -341,6 +433,9 @@ def expand_sampling_metadata( expanded_temp = create_empty(sampling_metadata.temperature) expanded_top_p = create_empty(sampling_metadata.top_p) expanded_top_k = create_empty(sampling_metadata.top_k) + expanded_repetition_penalty = create_empty(sampling_metadata.repetition_penalty) + expanded_frequency_penalty = create_empty(sampling_metadata.frequency_penalty) + expanded_presence_penalty = create_empty(sampling_metadata.presence_penalty) expanded_seeds = create_empty(sampling_metadata.seeds) num_reqs = cu_num_logits.shape[0] - 1 @@ -351,6 +446,12 @@ def expand_sampling_metadata( expanded_top_p, sampling_metadata.top_k, expanded_top_k, + sampling_metadata.repetition_penalty, + expanded_repetition_penalty, + sampling_metadata.frequency_penalty, + expanded_frequency_penalty, + sampling_metadata.presence_penalty, + expanded_presence_penalty, sampling_metadata.seeds, expanded_seeds, cu_num_logits, @@ -361,6 +462,67 @@ def expand_sampling_metadata( top_p=expanded_top_p, top_k=expanded_top_k, seeds=expanded_seeds, + repetition_penalty=expanded_repetition_penalty, + frequency_penalty=expanded_frequency_penalty, + presence_penalty=expanded_presence_penalty, pos=sampling_metadata.pos, max_num_logprobs=sampling_metadata.max_num_logprobs, + # TODO(woosuk): Support penalties with spec decoding. 
+ idx_mapping=sampling_metadata.idx_mapping, + prompt_bin_counts=sampling_metadata.prompt_bin_counts, + output_bin_counts=sampling_metadata.output_bin_counts, + ) + + +def use_penalty(sampling_params: SamplingParams) -> bool: + return ( + sampling_params.repetition_penalty != 1.0 + or sampling_params.frequency_penalty != 0.0 + or sampling_params.presence_penalty != 0.0 + ) + + +@triton.jit(do_not_specialize=["prefill_len", "prompt_len"]) +def _bincount_kernel( + prefill_token_ids_ptr, + prefill_len, + prompt_len, + prompt_bin_counts_ptr, + output_bin_counts_ptr, + BLOCK_SIZE: tl.constexpr, +): + block_idx = tl.program_id(0) + if block_idx * BLOCK_SIZE >= prefill_len: + return + + block = block_idx * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) + if block_idx * BLOCK_SIZE < prompt_len: + mask = block < prompt_len + prefill_tokens = tl.load(prefill_token_ids_ptr + block, mask=mask) + tl.atomic_add(prompt_bin_counts_ptr + prefill_tokens, 1, mask=mask) + if (block_idx + 1) * BLOCK_SIZE >= prompt_len: + mask = block < prefill_len + mask &= block >= prompt_len + prefill_tokens = tl.load(prefill_token_ids_ptr + block, mask=mask) + tl.atomic_add(output_bin_counts_ptr + prefill_tokens, 1, mask=mask) + + +def bincount( + prefill_token_ids: torch.Tensor, + prefill_len: int, + prompt_len: int, + prompt_bin_counts: torch.Tensor, + output_bin_counts: torch.Tensor, +) -> None: + prompt_bin_counts.zero_() + output_bin_counts.zero_() + BLOCK_SIZE = 1024 + num_blocks = triton.cdiv(prefill_len, BLOCK_SIZE) + _bincount_kernel[(num_blocks,)]( + prefill_token_ids, + prefill_len, + prompt_len, + prompt_bin_counts, + output_bin_counts, + BLOCK_SIZE=BLOCK_SIZE, )
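
End-of-patch reference (illustrative only, not part of the change): the two bin-count tensors are maintained in two places. _bincount_kernel seeds them from the prefill tokens when a request is added, splitting occurrences at prompt_len, and the loop added to _post_update_kernel increments output_bin_counts for every accepted token after each step. A rough host-side equivalent, with made-up helper names and the same shapes assumed, is:

import torch

def seed_bin_counts(
    prefill_token_ids: torch.Tensor,  # [prefill_len], tokens known when the request is added
    prompt_len: int,                  # how many of them belong to the original prompt
    vocab_size: int,
) -> tuple[torch.Tensor, torch.Tensor]:
    # Mirrors _bincount_kernel: tokens before prompt_len count as prompt occurrences,
    # tokens at or after it (e.g. a preempted and resumed request) count as output
    # occurrences.
    prompt_bin_counts = torch.bincount(
        prefill_token_ids[:prompt_len], minlength=vocab_size
    ).to(torch.int32)
    output_bin_counts = torch.bincount(
        prefill_token_ids[prompt_len:], minlength=vocab_size
    ).to(torch.int32)
    return prompt_bin_counts, output_bin_counts

def count_accepted_tokens(
    output_bin_counts: torch.Tensor,  # [vocab_size], the row owned by this request
    accepted_tokens: torch.Tensor,    # [num_sampled], tokens kept after rejection sampling
) -> None:
    # Mirrors the per-token loop added to _post_update_kernel.
    output_bin_counts += torch.bincount(
        accepted_tokens, minlength=output_bin_counts.numel()
    ).to(output_bin_counts.dtype)

Keeping the counts per persistent request slot (looked up through idx_mapping) lets the penalties kernel read them directly each step instead of rebuilding a [num_reqs, vocab_size] tensor, at the cost of the two max_num_reqs x vocab_size buffers flagged in the TODO above.
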