[Model Runner V2] Support penalties using bin counts (#29703)

Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
Authored by Woosuk Kwon on 2025-11-28 17:53:17 -08:00; committed by GitHub
parent ea3370b428
commit 1dcafb3dea
5 changed files with 280 additions and 14 deletions

View File

@ -341,6 +341,8 @@ def _post_update_kernel(
idx_mapping_ptr,
num_computed_tokens_ptr,
last_sampled_tokens_ptr,
output_bin_counts_ptr,
output_bin_counts_stride,
sampled_tokens_ptr,
sampled_tokens_stride,
num_sampled_ptr,
@ -357,6 +359,15 @@ def _post_update_kernel(
)
tl.store(last_sampled_tokens_ptr + req_state_idx, token_id)
for i in range(num_sampled):
token_id = tl.load(sampled_tokens_ptr + req_id * sampled_tokens_stride + i)
token_ptr = (
output_bin_counts_ptr + req_state_idx * output_bin_counts_stride + token_id
)
count = tl.load(token_ptr)
count += 1
tl.store(token_ptr, count)
query_start = tl.load(query_start_loc_ptr + req_id)
query_end = tl.load(query_start_loc_ptr + req_id + 1)
query_len = query_end - query_start
@ -374,6 +385,8 @@ def post_update(
num_computed_tokens: torch.Tensor,
# [max_num_reqs]
last_sampled_tokens: torch.Tensor,
# [max_num_reqs, vocab_size]
output_bin_counts: torch.Tensor,
# [num_reqs, num_speculative_steps + 1]
sampled_tokens: torch.Tensor,
# [num_reqs]
@ -388,6 +401,8 @@ def post_update(
idx_mapping,
num_computed_tokens,
last_sampled_tokens,
output_bin_counts,
output_bin_counts.stride(0),
sampled_tokens,
sampled_tokens.stride(0),
num_sampled,
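
For reference, the new lines above keep a running per-request histogram of sampled token ids. A minimal PyTorch sketch of the same update (tensor names mirror the diff; the loop-based form and the helper name are illustrative, not part of the commit):

import torch

def update_output_bin_counts(
    output_bin_counts: torch.Tensor,  # [max_num_reqs, vocab_size]
    idx_mapping: torch.Tensor,        # [num_reqs], batch index -> request state index
    sampled_tokens: torch.Tensor,     # [num_reqs, num_speculative_steps + 1]
    num_sampled: torch.Tensor,        # [num_reqs], accepted tokens per request
) -> None:
    # Equivalent of the per-request loop added to _post_update_kernel:
    # each accepted token increments its bin in the request's output histogram.
    for req_id in range(sampled_tokens.shape[0]):
        req_state_idx = int(idx_mapping[req_id])
        for i in range(int(num_sampled[req_id])):
            token_id = int(sampled_tokens[req_id, i])
            output_bin_counts[req_state_idx, token_id] += 1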

View File

@ -512,7 +512,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
idx_mapping_np,
num_scheduled_tokens,
query_start_loc_np,
self.req_states.prefill_token_ids,
self.req_states.prefill_token_ids.np,
self.req_states.num_computed_prefill_tokens,
self.input_buffers.input_ids.np,
)
@ -681,7 +681,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
# Handle chunked prompts.
pos_after_step = computed_prefill + input_batch.num_scheduled_tokens
is_prompt_chunked = pos_after_step < prompt_lens
prefill_token_ids = self.req_states.prefill_token_ids
prefill_token_ids = self.req_states.prefill_token_ids.np
query_start_loc = self.input_buffers.query_start_loc.np
for i, req_id in enumerate(input_batch.req_ids):
if not needs_prompt_logprobs[i]:
@ -756,6 +756,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
input_batch.idx_mapping,
self.req_states.num_computed_tokens,
self.req_states.last_sampled_tokens,
self.req_states.output_bin_counts,
sampled_tokens,
num_sampled,
num_rejected,
@ -785,7 +786,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
idx_mapping_np = input_batch.idx_mapping_np
with async_barrier(self.spec_decode_event):
self.input_buffers.next_prefill_tokens.np[:num_reqs] = (
self.req_states.prefill_token_ids[
self.req_states.prefill_token_ids.np[
idx_mapping_np,
self.req_states.num_computed_prefill_tokens[idx_mapping_np],
]
@ -896,7 +897,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
# barrier to avoid race conditions.
pos = input_batch.positions[input_batch.logits_indices]
sampling_metadata = self.req_states.make_sampling_metadata(
input_batch.idx_mapping_np, pos
input_batch.idx_mapping, input_batch.idx_mapping_np, pos
)
if input_batch.num_draft_tokens > 0:
sampling_metadata = self.req_states.expand_sampling_metadata(
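
The new `.np` accessor appears because `prefill_token_ids` becomes a `UvaBuffer` (added in gpu/states.py below): one pinned CPU allocation exposed both as a numpy view for host-side code and as a zero-copy CUDA view for GPU kernels such as `bincount`. A rough sketch of that duality, assuming a CUDA device with UVA support (variable names are illustrative):

import torch
from vllm.utils.torch_utils import get_cuda_view_from_cpu_tensor

# One pinned host allocation, two views of the same memory.
cpu = torch.zeros(4, 16, dtype=torch.int32, device="cpu", pin_memory=True)
np_view = cpu.numpy()                          # what the runner indexes as ".np"
gpu_view = get_cuda_view_from_cpu_tensor(cpu)  # what GPU kernels index as ".gpu"
np_view[0, :3] = [101, 7, 42]                  # host-side write
# gpu_view[0, :3] reads the same values on the device without an explicit H2D copy.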

View File

@ -0,0 +1,85 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import torch
from vllm.triton_utils import tl, triton
from vllm.v1.worker.gpu.states import SamplingMetadata
@triton.jit
def _penalties_kernel(
logits_ptr,
logits_stride,
repetition_penalty_ptr,
frequency_penalty_ptr,
presence_penalty_ptr,
idx_mapping_ptr,
prompt_bin_counts_ptr,
prompt_bin_counts_stride,
output_bin_counts_ptr,
output_bin_counts_stride,
vocab_size,
BLOCK_SIZE: tl.constexpr,
):
batch_idx = tl.program_id(0)
rep_penalty = tl.load(repetition_penalty_ptr + batch_idx)
freq_penalty = tl.load(frequency_penalty_ptr + batch_idx)
pres_penalty = tl.load(presence_penalty_ptr + batch_idx)
use_rep_penalty = rep_penalty != 1.0
use_freq_penalty = freq_penalty != 0.0
use_pres_penalty = pres_penalty != 0.0
if not (use_rep_penalty or use_freq_penalty or use_pres_penalty):
# No penalties to apply. Early return.
return
block_idx = tl.program_id(1)
block = block_idx * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
mask = block < vocab_size
logits = tl.load(logits_ptr + batch_idx * logits_stride + block, mask=mask)
logits = logits.to(tl.float32)
req_state_idx = tl.load(idx_mapping_ptr + batch_idx)
output_bin_counts = tl.load(
output_bin_counts_ptr + req_state_idx * output_bin_counts_stride + block,
mask=mask,
)
# Apply repetition penalties.
if use_rep_penalty:
prompt_bin_counts = tl.load(
prompt_bin_counts_ptr + req_state_idx * prompt_bin_counts_stride + block,
mask=mask,
)
# If the token appears in the prompt or output, apply the penalty; otherwise use 1.0 (no-op).
scale = tl.where((prompt_bin_counts + output_bin_counts) > 0, rep_penalty, 1.0)
# Divide positive logits by the penalty; multiply negative logits by it.
scale = tl.where(logits > 0, 1.0 / scale, scale)
logits *= scale
# Apply frequency penalties.
logits -= freq_penalty * output_bin_counts
# Apply presence penalties.
logits -= pres_penalty * (output_bin_counts > 0)
# Store back to logits.
tl.store(logits_ptr + batch_idx * logits_stride + block, logits, mask=mask)
def apply_penalties(logits: torch.Tensor, sampling_metadata: SamplingMetadata) -> None:
num_reqs, vocab_size = logits.shape
BLOCK_SIZE = 8192
num_blocks = triton.cdiv(vocab_size, BLOCK_SIZE)
_penalties_kernel[(num_reqs, num_blocks)](
logits,
logits.stride(0),
sampling_metadata.repetition_penalty,
sampling_metadata.frequency_penalty,
sampling_metadata.presence_penalty,
sampling_metadata.idx_mapping,
sampling_metadata.prompt_bin_counts,
sampling_metadata.prompt_bin_counts.stride(0),
sampling_metadata.output_bin_counts,
sampling_metadata.output_bin_counts.stride(0),
vocab_size,
BLOCK_SIZE=BLOCK_SIZE,
)
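
As a sanity check for the kernel above, the same penalty math can be written in plain PyTorch. The sketch below is a hypothetical reference (not part of the commit) and assumes the bin-count rows have already been gathered per batch row, whereas the kernel gathers them through idx_mapping:

import torch

def apply_penalties_reference(
    logits: torch.Tensor,              # [num_reqs, vocab_size]
    repetition_penalty: torch.Tensor,  # [num_reqs], 1.0 means disabled
    frequency_penalty: torch.Tensor,   # [num_reqs], 0.0 means disabled
    presence_penalty: torch.Tensor,    # [num_reqs], 0.0 means disabled
    prompt_bin_counts: torch.Tensor,   # [num_reqs, vocab_size]
    output_bin_counts: torch.Tensor,   # [num_reqs, vocab_size]
) -> torch.Tensor:
    logits = logits.float()
    seen = (prompt_bin_counts + output_bin_counts) > 0
    rep = repetition_penalty.unsqueeze(1)
    scale = torch.where(seen, rep, torch.ones_like(rep))
    # Positive logits are divided by the penalty, negative logits multiplied by it.
    scale = torch.where(logits > 0, 1.0 / scale, scale)
    logits = logits * scale
    logits = logits - frequency_penalty.unsqueeze(1) * output_bin_counts
    logits = logits - presence_penalty.unsqueeze(1) * (output_bin_counts > 0).float()
    return logits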

View File

@ -8,6 +8,7 @@ from vllm.config.model import LogprobsMode
from vllm.triton_utils import tl, triton
from vllm.v1.outputs import LogprobsTensors, SamplerOutput
from vllm.v1.sample.ops.topk_topp_sampler import apply_top_k_top_p
from vllm.v1.worker.gpu.penalties import apply_penalties
from vllm.v1.worker.gpu.states import SamplingMetadata
@ -65,6 +66,8 @@ class Sampler:
logits = apply_top_k_top_p(
logits, sampling_metadata.top_k, sampling_metadata.top_p
)
# Apply penalties in place.
apply_penalties(logits, sampling_metadata)
sampled = gumbel_sample(
logits,

View File

@ -8,6 +8,8 @@ import torch
from vllm.lora.request import LoRARequest
from vllm.sampling_params import SamplingParams
from vllm.triton_utils import tl, triton
from vllm.utils.platform_utils import is_uva_available
from vllm.utils.torch_utils import get_cuda_view_from_cpu_tensor
from vllm.v1.outputs import LogprobsTensors
from vllm.v1.utils import CpuGpuBuffer
@ -23,12 +25,21 @@ class SamplingMetadata:
top_p: torch.Tensor | None
top_k: torch.Tensor | None
repetition_penalty: torch.Tensor
frequency_penalty: torch.Tensor
presence_penalty: torch.Tensor
seeds: torch.Tensor
pos: torch.Tensor
# None means no logprobs, 0 means sampled token logprobs only
max_num_logprobs: int | None
# For penalties
idx_mapping: torch.Tensor
prompt_bin_counts: torch.Tensor
output_bin_counts: torch.Tensor
@classmethod
def make_dummy(
cls,
@ -44,17 +55,35 @@ class SamplingMetadata:
# top_k = torch.full((num_reqs,), 20, dtype=torch.int32, device=device)
top_p = None
top_k = None
# NOTE(woosuk): We must set penalties to their default values to make sure
# the penalties kernel does not touch the placeholder bin_counts tensors.
repetition_penalty = torch.ones(num_reqs, dtype=torch.float32, device=device)
frequency_penalty = torch.zeros(num_reqs, dtype=torch.float32, device=device)
presence_penalty = torch.zeros(num_reqs, dtype=torch.float32, device=device)
seeds = torch.zeros(num_reqs, dtype=torch.int64, device=device)
pos = torch.zeros(num_reqs, dtype=torch.int64, device=device)
max_num_logprobs = 20
idx_mapping = torch.arange(num_reqs, dtype=torch.int32, device=device)
# NOTE(woosuk): These are placeholder tensors to avoid None checks in the
# penalties kernel. We use 2 instead of 1 as vocab_size to avoid Triton
# specialization and re-compilation at runtime.
prompt_bin_counts = torch.zeros(num_reqs, 2, dtype=torch.int32, device=device)
output_bin_counts = torch.zeros(num_reqs, 2, dtype=torch.int32, device=device)
return cls(
temperature=temperature,
top_p=top_p,
top_k=top_k,
repetition_penalty=repetition_penalty,
frequency_penalty=frequency_penalty,
presence_penalty=presence_penalty,
seeds=seeds,
pos=pos,
max_num_logprobs=max_num_logprobs,
idx_mapping=idx_mapping,
prompt_bin_counts=prompt_bin_counts,
output_bin_counts=output_bin_counts,
)
@ -83,9 +112,10 @@ class RequestState:
self.extra_data: dict[str, ExtraData] = {}
self.prompt_len = np.zeros(self.max_num_reqs, dtype=np.int32)
self.prefill_token_ids = np.zeros(
(self.max_num_reqs, self.max_model_len),
dtype=np.int32,
# NOTE(woosuk): This tensor can be extremely large (e.g., several GBs)
# depending on the configured max_num_reqs and max_model_len.
self.prefill_token_ids = UvaBuffer(
self.max_num_reqs, self.max_model_len, dtype=torch.int32
)
self.prefill_len = self._make_buffer(self.max_num_reqs, dtype=torch.int32)
@ -119,6 +149,9 @@ class RequestState:
self.temperature = self._make_param(self.max_num_reqs, torch.float32)
self.top_p = self._make_param(self.max_num_reqs, torch.float32)
self.top_k = self._make_param(self.max_num_reqs, torch.int32)
self.repetition_penalty = self._make_param(self.max_num_reqs, torch.float32)
self.frequency_penalty = self._make_param(self.max_num_reqs, torch.float32)
self.presence_penalty = self._make_param(self.max_num_reqs, torch.float32)
self.seeds = self._make_param(self.max_num_reqs, torch.int64)
self.num_logprobs = np.empty(self.max_num_reqs, dtype=np.int32)
@ -126,6 +159,16 @@ class RequestState:
self.num_logprobs.fill(-1)
self.needs_prompt_logprobs = np.zeros(self.max_num_reqs, dtype=bool)
# Statistics for penalties.
# TODO(woosuk): These tensors are rarely used but can be extremely large.
# Optimize the memory usage.
self.prompt_bin_counts = torch.zeros(
self.max_num_reqs, self.vocab_size, dtype=torch.int32, device=self.device
)
self.output_bin_counts = torch.zeros(
self.max_num_reqs, self.vocab_size, dtype=torch.int32, device=self.device
)
def _make_param(self, size: int, dtype: torch.dtype) -> "Param":
return Param(size, dtype=dtype, device=self.device, pin_memory=self.pin_memory)
@ -159,7 +202,7 @@ class RequestState:
f"prefill_len {prefill_len} < prompt_len {prompt_len}"
)
self.prefill_len.np[req_idx] = prefill_len
self.prefill_token_ids[req_idx, :prefill_len] = prefill_token_ids
self.prefill_token_ids.np[req_idx, :prefill_len] = prefill_token_ids
self.num_computed_prefill_tokens[req_idx] = num_computed_tokens
# FIXME(woosuk): This triggers a GPU operation whenever adding a new request.
@ -178,6 +221,18 @@ class RequestState:
else:
top_k = self.vocab_size
self.top_k.np[req_idx] = top_k
self.repetition_penalty.np[req_idx] = sampling_params.repetition_penalty
self.frequency_penalty.np[req_idx] = sampling_params.frequency_penalty
self.presence_penalty.np[req_idx] = sampling_params.presence_penalty
if use_penalty(sampling_params):
bincount(
self.prefill_token_ids.gpu[req_idx],
prefill_len,
prompt_len,
self.prompt_bin_counts[req_idx],
self.output_bin_counts[req_idx],
)
if sampling_params.seed is not None:
seed = sampling_params.seed
@ -206,24 +261,32 @@ class RequestState:
def make_sampling_metadata(
self,
idx_mapping: np.ndarray,
idx_mapping: torch.Tensor,
idx_mapping_np: np.ndarray,
pos: torch.Tensor,
) -> SamplingMetadata:
temperature = self.temperature.np[idx_mapping]
temperature = self.temperature.np[idx_mapping_np]
temperature = self.temperature.copy_np_to_gpu(temperature)
top_p = self.top_p.np[idx_mapping]
top_p = self.top_p.np[idx_mapping_np]
no_top_p = np.all(top_p == 1.0)
top_p = self.top_p.copy_np_to_gpu(top_p) if not no_top_p else None
top_k = self.top_k.np[idx_mapping]
top_k = self.top_k.np[idx_mapping_np]
no_top_k = np.all(top_k == self.vocab_size)
top_k = self.top_k.copy_np_to_gpu(top_k) if not no_top_k else None
seeds = self.seeds.np[idx_mapping]
rep_penalty = self.repetition_penalty.np[idx_mapping_np]
rep_penalty = self.repetition_penalty.copy_np_to_gpu(rep_penalty)
freq_penalty = self.frequency_penalty.np[idx_mapping_np]
freq_penalty = self.frequency_penalty.copy_np_to_gpu(freq_penalty)
pres_penalty = self.presence_penalty.np[idx_mapping_np]
pres_penalty = self.presence_penalty.copy_np_to_gpu(pres_penalty)
seeds = self.seeds.np[idx_mapping_np]
seeds = self.seeds.copy_np_to_gpu(seeds)
num_logprobs = self.num_logprobs[idx_mapping]
num_logprobs = self.num_logprobs[idx_mapping_np]
max_num_logprobs: int | None = int(np.max(num_logprobs))
if max_num_logprobs == -1:
max_num_logprobs = None
@ -232,9 +295,15 @@ class RequestState:
temperature=temperature,
top_p=top_p,
top_k=top_k,
repetition_penalty=rep_penalty,
frequency_penalty=freq_penalty,
presence_penalty=pres_penalty,
seeds=seeds,
pos=pos,
max_num_logprobs=max_num_logprobs,
idx_mapping=idx_mapping,
prompt_bin_counts=self.prompt_bin_counts,
output_bin_counts=self.output_bin_counts,
)
def expand_sampling_metadata(
@ -294,6 +363,14 @@ class ExtraData:
in_progress_prompt_logprobs: list[LogprobsTensors] = field(default_factory=list)
class UvaBuffer:
def __init__(self, *size: int | torch.SymInt, dtype: torch.dtype):
assert is_uva_available()
self.cpu = torch.zeros(*size, dtype=dtype, device="cpu", pin_memory=True)
self.np = self.cpu.numpy()
self.gpu = get_cuda_view_from_cpu_tensor(self.cpu)
# NOTE(woosuk): Re-compilation can happen at runtime since top_p and top_k can be None.
@triton.jit
def _expand_sampling_metadata_kernel(
@ -304,6 +381,12 @@ def _expand_sampling_metadata_kernel(
top_k_ptr,
expanded_top_k_ptr,
seeds_ptr,
rep_penalty_ptr,
expanded_rep_penalty_ptr,
freq_penalty_ptr,
expanded_freq_penalty_ptr,
pres_penalty_ptr,
expanded_pres_penalty_ptr,
expanded_seeds_ptr,
cu_num_logits_ptr,
BLOCK_SIZE: tl.constexpr,
@ -327,6 +410,15 @@ def _expand_sampling_metadata_kernel(
top_k = tl.load(top_k_ptr + req_idx)
tl.store(expanded_top_k_ptr + start_idx + block, top_k, mask=mask)
rep_penalty = tl.load(rep_penalty_ptr + req_idx)
tl.store(expanded_rep_penalty_ptr + start_idx + block, rep_penalty, mask=mask)
freq_penalty = tl.load(freq_penalty_ptr + req_idx)
tl.store(expanded_freq_penalty_ptr + start_idx + block, freq_penalty, mask=mask)
pres_penalty = tl.load(pres_penalty_ptr + req_idx)
tl.store(expanded_pres_penalty_ptr + start_idx + block, pres_penalty, mask=mask)
seed = tl.load(seeds_ptr + req_idx)
tl.store(expanded_seeds_ptr + start_idx + block, seed, mask=mask)
@ -341,6 +433,9 @@ def expand_sampling_metadata(
expanded_temp = create_empty(sampling_metadata.temperature)
expanded_top_p = create_empty(sampling_metadata.top_p)
expanded_top_k = create_empty(sampling_metadata.top_k)
expanded_repetition_penalty = create_empty(sampling_metadata.repetition_penalty)
expanded_frequency_penalty = create_empty(sampling_metadata.frequency_penalty)
expanded_presence_penalty = create_empty(sampling_metadata.presence_penalty)
expanded_seeds = create_empty(sampling_metadata.seeds)
num_reqs = cu_num_logits.shape[0] - 1
@ -351,6 +446,12 @@ def expand_sampling_metadata(
expanded_top_p,
sampling_metadata.top_k,
expanded_top_k,
sampling_metadata.repetition_penalty,
expanded_repetition_penalty,
sampling_metadata.frequency_penalty,
expanded_frequency_penalty,
sampling_metadata.presence_penalty,
expanded_presence_penalty,
sampling_metadata.seeds,
expanded_seeds,
cu_num_logits,
@ -361,6 +462,67 @@ def expand_sampling_metadata(
top_p=expanded_top_p,
top_k=expanded_top_k,
seeds=expanded_seeds,
repetition_penalty=expanded_repetition_penalty,
frequency_penalty=expanded_frequency_penalty,
presence_penalty=expanded_presence_penalty,
pos=sampling_metadata.pos,
max_num_logprobs=sampling_metadata.max_num_logprobs,
# TODO(woosuk): Support penalties with spec decoding.
idx_mapping=sampling_metadata.idx_mapping,
prompt_bin_counts=sampling_metadata.prompt_bin_counts,
output_bin_counts=sampling_metadata.output_bin_counts,
)
def use_penalty(sampling_params: SamplingParams) -> bool:
return (
sampling_params.repetition_penalty != 1.0
or sampling_params.frequency_penalty != 0.0
or sampling_params.presence_penalty != 0.0
)
@triton.jit(do_not_specialize=["prefill_len", "prompt_len"])
def _bincount_kernel(
prefill_token_ids_ptr,
prefill_len,
prompt_len,
prompt_bin_counts_ptr,
output_bin_counts_ptr,
BLOCK_SIZE: tl.constexpr,
):
block_idx = tl.program_id(0)
if block_idx * BLOCK_SIZE >= prefill_len:
return
block = block_idx * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
if block_idx * BLOCK_SIZE < prompt_len:
mask = block < prompt_len
prefill_tokens = tl.load(prefill_token_ids_ptr + block, mask=mask)
tl.atomic_add(prompt_bin_counts_ptr + prefill_tokens, 1, mask=mask)
if (block_idx + 1) * BLOCK_SIZE >= prompt_len:
mask = block < prefill_len
mask &= block >= prompt_len
prefill_tokens = tl.load(prefill_token_ids_ptr + block, mask=mask)
tl.atomic_add(output_bin_counts_ptr + prefill_tokens, 1, mask=mask)
def bincount(
prefill_token_ids: torch.Tensor,
prefill_len: int,
prompt_len: int,
prompt_bin_counts: torch.Tensor,
output_bin_counts: torch.Tensor,
) -> None:
prompt_bin_counts.zero_()
output_bin_counts.zero_()
BLOCK_SIZE = 1024
num_blocks = triton.cdiv(prefill_len, BLOCK_SIZE)
_bincount_kernel[(num_blocks,)](
prefill_token_ids,
prefill_len,
prompt_len,
prompt_bin_counts,
output_bin_counts,
BLOCK_SIZE=BLOCK_SIZE,
)
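
For comparison, the two histograms filled by `_bincount_kernel` could also be computed with `torch.bincount`. A non-kernel sketch (the function name is illustrative), assuming the layout used above where tokens in [prompt_len, prefill_len) are previously generated output:

import torch

def bincount_reference(
    prefill_token_ids: torch.Tensor,  # [max_model_len], int32
    prefill_len: int,
    prompt_len: int,
    vocab_size: int,
) -> tuple[torch.Tensor, torch.Tensor]:
    # Prompt tokens go into prompt_bin_counts; any already-generated output
    # tokens in the prefill (e.g., after preemption) go into output_bin_counts.
    prompt_bin_counts = torch.bincount(
        prefill_token_ids[:prompt_len], minlength=vocab_size
    ).to(torch.int32)
    output_bin_counts = torch.bincount(
        prefill_token_ids[prompt_len:prefill_len], minlength=vocab_size
    ).to(torch.int32)
    return prompt_bin_counts, output_bin_counts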