[Model Runner V2] Fuse penalties and temperature into single kernel (#29720)

Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
Woosuk Kwon 2025-11-29 02:29:16 -08:00 committed by GitHub
parent 04a797cd0e
commit f223ed4181
3 changed files with 58 additions and 44 deletions
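For orientation, here is a rough PyTorch sketch of the math the fused kernel below implements. The function name and signature are mine; the real Triton kernel additionally indexes the bin counts through idx_mapping, processes the vocab row in blocks, and early-exits for requests with all-default parameters.

import torch

def penalties_and_temperature_reference(
    logits: torch.Tensor,             # [num_reqs, vocab_size], float32
    rep_penalty: torch.Tensor,        # [num_reqs], 1.0 means disabled
    freq_penalty: torch.Tensor,       # [num_reqs], 0.0 means disabled
    pres_penalty: torch.Tensor,       # [num_reqs], 0.0 means disabled
    temperature: torch.Tensor,        # [num_reqs], 0.0 means greedy
    prompt_bin_counts: torch.Tensor,  # [num_reqs, vocab_size] token counts in the prompt
    output_bin_counts: torch.Tensor,  # [num_reqs, vocab_size] token counts in the output
) -> torch.Tensor:
    seen = (prompt_bin_counts + output_bin_counts) > 0
    scale = torch.where(seen, rep_penalty.view(-1, 1), torch.ones_like(logits))
    # Positive logits are divided by the penalty, negative ones multiplied.
    logits = logits * torch.where(logits > 0, 1.0 / scale, scale)
    logits = logits - freq_penalty.view(-1, 1) * output_bin_counts
    logits = logits - pres_penalty.view(-1, 1) * (output_bin_counts > 0)
    # Temperature 0.0 marks greedy requests; dividing by 1.0 keeps them intact.
    temp = torch.where(temperature == 0.0, torch.ones_like(temperature), temperature)
    return logits / temp.view(-1, 1)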

vllm/v1/worker/gpu/sample/gumbel.py

@@ -45,8 +45,9 @@ def _gumbel_sample_kernel(
     # Apply temperature.
     if APPLY_TEMPERATURE:
-        # NOTE(woosuk): Use div_rn to match the behavior of torch.
-        logits = tl.div_rn(logits, temp)
+        # NOTE(woosuk): Match the behavior of _penalties_and_temperature_kernel.
+        # E.g., if the kernel uses tl.div_rn, we should use tl.div_rn here too.
+        logits = logits / temp
     # Apply gumbel noise.
     logits = tl.where(mask, logits + gumbel_noise, float("-inf"))
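The rewritten NOTE encodes a cross-kernel invariant rather than a fixed choice: temperature can be applied either by the fused penalties kernel or, when APPLY_TEMPERATURE is set, inside the gumbel kernel, and both must use the same division (and thus the same rounding) so the two paths stay bit-identical. A toy statement of that invariant, assuming plain float32 division in both (my construction, not from the diff):

import torch

logits = torch.randn(2, 5, dtype=torch.float32)
temp = torch.tensor([[0.7], [1.3]], dtype=torch.float32)

via_fused_kernel = logits / temp   # temperature in _penalties_and_temperature_kernel
via_gumbel_kernel = logits / temp  # temperature in _gumbel_sample_kernel
assert torch.equal(via_fused_kernel, via_gumbel_kernel)  # must match bit-for-bit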

vllm/v1/worker/gpu/sample/penalties.py

@@ -7,12 +7,13 @@ from vllm.v1.worker.gpu.sample.metadata import SamplingMetadata
 @triton.jit
-def _penalties_kernel(
+def _penalties_and_temperature_kernel(
     logits_ptr,
     logits_stride,
     repetition_penalty_ptr,
     frequency_penalty_ptr,
     presence_penalty_ptr,
+    temperature_ptr,
     idx_mapping_ptr,
     prompt_bin_counts_ptr,
     prompt_bin_counts_stride,
@@ -25,12 +26,16 @@ def _penalties_kernel(
     rep_penalty = tl.load(repetition_penalty_ptr + batch_idx)
     freq_penalty = tl.load(frequency_penalty_ptr + batch_idx)
     pres_penalty = tl.load(presence_penalty_ptr + batch_idx)
+    temperature = tl.load(temperature_ptr + batch_idx)
+    temperature = tl.where(temperature == 0.0, 1.0, temperature)
     use_rep_penalty = rep_penalty != 1.0
     use_freq_penalty = freq_penalty != 0.0
     use_pres_penalty = pres_penalty != 0.0
-    if not (use_rep_penalty or use_freq_penalty or use_pres_penalty):
-        # No penalties to apply. Early return.
+    use_penalty = use_rep_penalty or use_freq_penalty or use_pres_penalty
+    use_temperature = temperature != 1.0
+    if not (use_penalty or use_temperature):
+        # Early return to avoid loading logits.
         return
     block_idx = tl.program_id(1)
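The early exit is a per-request, scalar-only check: a request with all-default parameters (repetition 1.0, frequency 0.0, presence 0.0, temperature 1.0, or 0.0 for greedy) skips the logits load and store entirely. A plain-Python sketch of the predicate (function name is mine):

def needs_kernel_work(rep: float, freq: float, pres: float, temp: float) -> bool:
    # Temperature 0.0 marks greedy sampling; it is remapped to 1.0, a no-op divide.
    temp = 1.0 if temp == 0.0 else temp
    use_penalty = (rep != 1.0) or (freq != 0.0) or (pres != 0.0)
    return use_penalty or (temp != 1.0)

assert not needs_kernel_work(1.0, 0.0, 0.0, 0.0)  # greedy defaults: early return
assert needs_kernel_work(1.0, 0.0, 0.0, 0.7)      # temperature only: still runs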
@@ -39,42 +44,54 @@ def _penalties_kernel(
     logits = tl.load(logits_ptr + batch_idx * logits_stride + block, mask=mask)
     logits = logits.to(tl.float32)

-    req_state_idx = tl.load(idx_mapping_ptr + batch_idx)
-    output_bin_counts = tl.load(
-        output_bin_counts_ptr + req_state_idx * output_bin_counts_stride + block,
-        mask=mask,
-    )
-
-    # Apply repetition penalties.
-    if use_rep_penalty:
-        prompt_bin_counts = tl.load(
-            prompt_bin_counts_ptr + req_state_idx * prompt_bin_counts_stride + block,
+    if use_penalty:
+        req_state_idx = tl.load(idx_mapping_ptr + batch_idx)
+        output_bin_counts = tl.load(
+            output_bin_counts_ptr + req_state_idx * output_bin_counts_stride + block,
             mask=mask,
         )
-        # If token appears in prompt or output, apply, otherwise use 1.0 for no-op.
-        scale = tl.where((prompt_bin_counts + output_bin_counts) > 0, rep_penalty, 1.0)
-        # If logits are positive, divide by penalty, otherwise multiply by penalty.
-        scale = tl.where(logits > 0, 1.0 / scale, scale)
-        logits *= scale
+        output_bin_mask = output_bin_counts > 0
+
+        # Apply repetition penalties.
+        if use_rep_penalty:
+            prompt_bin_counts = tl.load(
+                prompt_bin_counts_ptr
+                + req_state_idx * prompt_bin_counts_stride
+                + block,
+                mask=mask,
+            )
+            prompt_bin_mask = prompt_bin_counts > 0
+            # If token appears in prompt or output, apply, otherwise use 1.0 for no-op.
+            scale = tl.where(prompt_bin_mask | output_bin_mask, rep_penalty, 1.0)
+            # If logits are positive, divide by penalty, otherwise multiply by penalty.
+            logits *= tl.where(logits > 0, 1.0 / scale, scale)

-    # Apply frequency penalties.
-    logits -= freq_penalty * output_bin_counts
-    # Apply presence penalties.
-    logits -= pres_penalty * (output_bin_counts > 0)
+        # Apply frequency penalties.
+        logits -= freq_penalty * output_bin_counts
+        # Apply presence penalties.
+        logits -= pres_penalty * output_bin_mask
+
+    # Apply temperature.
+    logits = logits / temperature

     # Store back to logits.
     tl.store(logits_ptr + batch_idx * logits_stride + block, logits, mask=mask)


-def apply_penalties(logits: torch.Tensor, sampling_metadata: SamplingMetadata) -> None:
+def apply_penalties_and_temperature(
+    logits: torch.Tensor,
+    sampling_metadata: SamplingMetadata,
+) -> None:
     num_reqs, vocab_size = logits.shape
     BLOCK_SIZE = 8192
     num_blocks = triton.cdiv(vocab_size, BLOCK_SIZE)
-    _penalties_kernel[(num_reqs, num_blocks)](
+    _penalties_and_temperature_kernel[(num_reqs, num_blocks)](
         logits,
         logits.stride(0),
         sampling_metadata.repetition_penalty,
         sampling_metadata.frequency_penalty,
         sampling_metadata.presence_penalty,
+        sampling_metadata.temperature,
         sampling_metadata.idx_mapping,
         sampling_metadata.prompt_bin_counts,
         sampling_metadata.prompt_bin_counts.stride(0),
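The launch grid is two-dimensional: axis 0 is the request, axis 1 is an 8192-token slice of that request's vocab row, so each program touches one contiguous block of logits. A worked example with made-up sizes:

import triton

num_reqs = 32
vocab_size = 151_936                               # e.g., a ~152K-token vocabulary
BLOCK_SIZE = 8192
num_blocks = triton.cdiv(vocab_size, BLOCK_SIZE)   # ceil(151936 / 8192) = 19
grid = (num_reqs, num_blocks)                      # 32 x 19 = 608 programs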

vllm/v1/worker/gpu/sample/sampler.py

@@ -9,7 +9,7 @@ from vllm.v1.sample.ops.topk_topp_sampler import apply_top_k_top_p
 from vllm.v1.worker.gpu.sample.gumbel import gumbel_sample
 from vllm.v1.worker.gpu.sample.logprob import compute_topk_logprobs
 from vllm.v1.worker.gpu.sample.metadata import SamplingMetadata
-from vllm.v1.worker.gpu.sample.penalties import apply_penalties
+from vllm.v1.worker.gpu.sample.penalties import apply_penalties_and_temperature


 class Sampler:
@@ -26,22 +26,19 @@
         logits: torch.Tensor,
         sampling_metadata: SamplingMetadata,
     ) -> SamplerOutput:
+        sampled, processed_logits = self.sample(logits, sampling_metadata)
         if sampling_metadata.max_num_logprobs is not None:
-            if self.logprobs_mode == "processed_logprobs":
-                sampled, logits = self.sample(
-                    logits, sampling_metadata, return_logits=True
-                )
-            else:
-                assert self.logprobs_mode == "raw_logprobs"
-                sampled, _ = self.sample(logits, sampling_metadata, return_logits=False)
+            logits = (
+                processed_logits
+                if self.logprobs_mode == "processed_logprobs"
+                else logits
+            )
             logprobs_tensors = compute_topk_logprobs(
                 logits,
                 sampling_metadata.max_num_logprobs,
                 sampled,
             )
         else:
-            sampled, _ = self.sample(logits, sampling_metadata, return_logits=False)
             logprobs_tensors = None

         # These are GPU tensors.
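With sample() always returning the processed logits, forward() no longer branches on logprobs_mode to decide how to call it; it only picks which tensor feeds compute_topk_logprobs. A minimal sketch of that selection (helper name is mine):

import torch

def pick_logprobs_input(
    logprobs_mode: str,
    raw_logits: torch.Tensor,
    processed_logits: torch.Tensor,
) -> torch.Tensor:
    # processed_logprobs: logits after penalties, temperature, and top-k/top-p.
    # raw_logprobs: the untouched model output.
    return processed_logits if logprobs_mode == "processed_logprobs" else raw_logits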
@@ -58,16 +55,15 @@
         self,
         logits: torch.Tensor,
         sampling_metadata: SamplingMetadata,
-        return_logits: bool = False,
-    ) -> tuple[torch.Tensor, torch.Tensor | None]:
-        is_greedy = sampling_metadata.temperature == 0
-        temp = torch.where(is_greedy, 1.0, sampling_metadata.temperature)
-        logits = logits / temp.view(-1, 1)
+    ) -> tuple[torch.Tensor, torch.Tensor]:
+        # Copy logits to a new FP32 tensor.
+        logits = torch.empty_like(logits, dtype=torch.float32).copy_(logits)
+        # Apply penalties and temperature in place.
+        apply_penalties_and_temperature(logits, sampling_metadata)
         logits = apply_top_k_top_p(
             logits, sampling_metadata.top_k, sampling_metadata.top_p
         )
-        # Apply penalties in place.
-        apply_penalties(logits, sampling_metadata)

         sampled = gumbel_sample(
             logits,
@@ -76,4 +72,4 @@ class Sampler:
             sampling_metadata.pos,
             apply_temperature=False,
         )
-        return sampled, logits if return_logits else None
+        return sampled, logits
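The explicit empty_like().copy_() matters because the fused kernel mutates its input: it gives the kernel a fresh FP32 buffer to edit in place while the raw logits survive for raw_logprobs mode. A small illustration (values and dtype are mine):

import torch

raw = torch.randn(4, 32, dtype=torch.bfloat16)
processed = torch.empty_like(raw, dtype=torch.float32).copy_(raw)
processed /= 0.7                      # in-place processing on the FP32 copy
assert raw.dtype == torch.bfloat16    # raw logits remain untouched for logprobs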