From a238cbd89d07b4b0ed8fb3dff3c219a3ee3a1651 Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Fri, 5 Dec 2025 21:42:47 -0800 Subject: [PATCH] [Model Runner V2] Support min-p sampling (#30171) Signed-off-by: Woosuk Kwon --- vllm/v1/worker/gpu/sample/metadata.py | 13 +++++++ vllm/v1/worker/gpu/sample/min_p.py | 53 +++++++++++++++++++++++++++ vllm/v1/worker/gpu/sample/sampler.py | 4 ++ vllm/v1/worker/gpu/states.py | 7 ++++ 4 files changed, 77 insertions(+) create mode 100644 vllm/v1/worker/gpu/sample/min_p.py diff --git a/vllm/v1/worker/gpu/sample/metadata.py b/vllm/v1/worker/gpu/sample/metadata.py index 040771c051bb4..f10c72049cbae 100644 --- a/vllm/v1/worker/gpu/sample/metadata.py +++ b/vllm/v1/worker/gpu/sample/metadata.py @@ -13,6 +13,7 @@ class SamplingMetadata: top_p: torch.Tensor | None top_k: torch.Tensor | None + min_p: torch.Tensor | None repetition_penalty: torch.Tensor frequency_penalty: torch.Tensor @@ -44,6 +45,7 @@ class SamplingMetadata: # top_k = torch.full((num_reqs,), 20, dtype=torch.int32, device=device) top_p = None top_k = None + min_p = torch.zeros(num_reqs, dtype=torch.float32, device=device) # NOTE(woosuk): We must set penalties to their default values to make sure # the penalties kernel does not touch the placeholder bin_counts tensors. repetition_penalty = torch.ones(num_reqs, dtype=torch.float32, device=device) @@ -64,6 +66,7 @@ class SamplingMetadata: temperature=temperature, top_p=top_p, top_k=top_k, + min_p=min_p, repetition_penalty=repetition_penalty, frequency_penalty=frequency_penalty, presence_penalty=presence_penalty, @@ -85,6 +88,8 @@ def _expand_sampling_metadata_kernel( expanded_top_p_ptr, top_k_ptr, expanded_top_k_ptr, + min_p_ptr, + expanded_min_p_ptr, rep_penalty_ptr, expanded_rep_penalty_ptr, freq_penalty_ptr, @@ -115,6 +120,10 @@ def _expand_sampling_metadata_kernel( top_k = tl.load(top_k_ptr + req_idx) tl.store(expanded_top_k_ptr + start_idx + block, top_k, mask=mask) + if min_p_ptr is not None: + min_p = tl.load(min_p_ptr + req_idx) + tl.store(expanded_min_p_ptr + start_idx + block, min_p, mask=mask) + rep_penalty = tl.load(rep_penalty_ptr + req_idx) tl.store(expanded_rep_penalty_ptr + start_idx + block, rep_penalty, mask=mask) @@ -138,6 +147,7 @@ def expand_sampling_metadata( expanded_temp = create_empty(sampling_metadata.temperature) expanded_top_p = create_empty(sampling_metadata.top_p) expanded_top_k = create_empty(sampling_metadata.top_k) + expanded_min_p = create_empty(sampling_metadata.min_p) expanded_repetition_penalty = create_empty(sampling_metadata.repetition_penalty) expanded_frequency_penalty = create_empty(sampling_metadata.frequency_penalty) expanded_presence_penalty = create_empty(sampling_metadata.presence_penalty) @@ -151,6 +161,8 @@ def expand_sampling_metadata( expanded_top_p, sampling_metadata.top_k, expanded_top_k, + sampling_metadata.min_p, + expanded_min_p, sampling_metadata.repetition_penalty, expanded_repetition_penalty, sampling_metadata.frequency_penalty, @@ -166,6 +178,7 @@ def expand_sampling_metadata( temperature=expanded_temp, top_p=expanded_top_p, top_k=expanded_top_k, + min_p=expanded_min_p, seeds=expanded_seeds, repetition_penalty=expanded_repetition_penalty, frequency_penalty=expanded_frequency_penalty, diff --git a/vllm/v1/worker/gpu/sample/min_p.py b/vllm/v1/worker/gpu/sample/min_p.py new file mode 100644 index 0000000000000..0638818006f50 --- /dev/null +++ b/vllm/v1/worker/gpu/sample/min_p.py @@ -0,0 +1,53 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import torch + +from vllm.triton_utils import tl, triton + + +@triton.jit +def _min_p_kernel( + logits_ptr, + logits_stride, + min_p_ptr, + vocab_size, + BLOCK_SIZE: tl.constexpr, +): + req_idx = tl.program_id(0) + min_p = tl.load(min_p_ptr + req_idx).to(tl.float32) + if min_p == 0.0: + return + + max_val = float("-inf") + for i in range(0, vocab_size, BLOCK_SIZE): + block = i + tl.arange(0, BLOCK_SIZE) + mask = block < vocab_size + logits = tl.load( + logits_ptr + req_idx * logits_stride + block, mask=mask, other=float("-inf") + ) + max_val = tl.max(tl.maximum(logits, max_val)) + max_val = max_val.to(tl.float32) # type: ignore + + threshold = max_val + tl.log(min_p) + for i in range(0, vocab_size, BLOCK_SIZE): + block = i + tl.arange(0, BLOCK_SIZE) + mask = block < vocab_size + logits = tl.load( + logits_ptr + req_idx * logits_stride + block, mask=mask, other=float("-inf") + ) + logits = tl.where(logits < threshold, float("-inf"), logits) + tl.store(logits_ptr + req_idx * logits_stride + block, logits, mask=mask) + + +def apply_min_p(logits: torch.Tensor, min_p: torch.Tensor | None) -> None: + if min_p is None: + return + num_reqs, vocab_size = logits.shape + BLOCK_SIZE = 1024 + _min_p_kernel[(num_reqs,)]( + logits, + logits.stride(0), + min_p, + vocab_size, + BLOCK_SIZE=BLOCK_SIZE, + ) diff --git a/vllm/v1/worker/gpu/sample/sampler.py b/vllm/v1/worker/gpu/sample/sampler.py index 3429dd3e4d0fb..9a4224d8fddef 100644 --- a/vllm/v1/worker/gpu/sample/sampler.py +++ b/vllm/v1/worker/gpu/sample/sampler.py @@ -9,6 +9,7 @@ from vllm.v1.sample.ops.topk_topp_sampler import apply_top_k_top_p from vllm.v1.worker.gpu.sample.gumbel import gumbel_sample from vllm.v1.worker.gpu.sample.logprob import compute_topk_logprobs from vllm.v1.worker.gpu.sample.metadata import SamplingMetadata +from vllm.v1.worker.gpu.sample.min_p import apply_min_p from vllm.v1.worker.gpu.sample.penalties import apply_penalties_and_temperature @@ -61,6 +62,9 @@ class Sampler: # Apply penalties and temperature in place. apply_penalties_and_temperature(logits, sampling_metadata) + # Apply min_p in place. + apply_min_p(logits, sampling_metadata.min_p) + # Apply top_k and/or top_p. This might return a new tensor. logits = apply_top_k_top_p( logits, sampling_metadata.top_k, sampling_metadata.top_p ) diff --git a/vllm/v1/worker/gpu/states.py b/vllm/v1/worker/gpu/states.py index 367348c4a18f7..6823c0c8ee5c7 100644 --- a/vllm/v1/worker/gpu/states.py +++ b/vllm/v1/worker/gpu/states.py @@ -87,6 +87,7 @@ class RequestState: self.temperature = self._make_param(self.max_num_reqs, torch.float32) self.top_p = self._make_param(self.max_num_reqs, torch.float32) self.top_k = self._make_param(self.max_num_reqs, torch.int32) + self.min_p = self._make_param(self.max_num_reqs, torch.float32) self.repetition_penalty = self._make_param(self.max_num_reqs, torch.float32) self.frequency_penalty = self._make_param(self.max_num_reqs, torch.float32) self.presence_penalty = self._make_param(self.max_num_reqs, torch.float32) @@ -162,6 +163,7 @@ class RequestState: else: top_k = self.vocab_size self.top_k.np[req_idx] = top_k + self.min_p.np[req_idx] = sampling_params.min_p self.repetition_penalty.np[req_idx] = sampling_params.repetition_penalty self.frequency_penalty.np[req_idx] = sampling_params.frequency_penalty self.presence_penalty.np[req_idx] = sampling_params.presence_penalty @@ -217,6 +219,10 @@ class RequestState: no_top_k = np.all(top_k == self.vocab_size) top_k = self.top_k.copy_np_to_gpu(top_k) if not no_top_k else None + min_p = self.min_p.np[idx_mapping_np] + no_min_p = np.all(min_p == 0.0) + min_p = self.min_p.copy_np_to_gpu(min_p) if not no_min_p else None + rep_penalty = self.repetition_penalty.np[idx_mapping_np] rep_penalty = self.repetition_penalty.copy_np_to_gpu(rep_penalty) freq_penalty = self.frequency_penalty.np[idx_mapping_np] @@ -236,6 +242,7 @@ class RequestState: temperature=temperature, top_p=top_p, top_k=top_k, + min_p=min_p, repetition_penalty=rep_penalty, frequency_penalty=freq_penalty, presence_penalty=pres_penalty,