mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-24 19:15:57 +08:00
[Model Runner V2] Fuse penalties and temperature into single kernel (#29720)
Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
This commit is contained in:
parent
04a797cd0e
commit
f223ed4181
@ -45,8 +45,9 @@ def _gumbel_sample_kernel(
|
|||||||
|
|
||||||
# Apply temperature.
|
# Apply temperature.
|
||||||
if APPLY_TEMPERATURE:
|
if APPLY_TEMPERATURE:
|
||||||
# NOTE(woosuk): Use div_rn to match the behavior of torch.
|
# NOTE(woosuk): Match the behavior of _penalties_and_temperature_kernel.
|
||||||
logits = tl.div_rn(logits, temp)
|
# E.g., if the kernel uses tl.div_rn, we should use tl.div_rn here too.
|
||||||
|
logits = logits / temp
|
||||||
|
|
||||||
# Apply gumbel noise.
|
# Apply gumbel noise.
|
||||||
logits = tl.where(mask, logits + gumbel_noise, float("-inf"))
|
logits = tl.where(mask, logits + gumbel_noise, float("-inf"))
|
||||||
|
|||||||
@ -7,12 +7,13 @@ from vllm.v1.worker.gpu.sample.metadata import SamplingMetadata
|
|||||||
|
|
||||||
|
|
||||||
@triton.jit
|
@triton.jit
|
||||||
def _penalties_kernel(
|
def _penalties_and_temperature_kernel(
|
||||||
logits_ptr,
|
logits_ptr,
|
||||||
logits_stride,
|
logits_stride,
|
||||||
repetition_penalty_ptr,
|
repetition_penalty_ptr,
|
||||||
frequency_penalty_ptr,
|
frequency_penalty_ptr,
|
||||||
presence_penalty_ptr,
|
presence_penalty_ptr,
|
||||||
|
temperature_ptr,
|
||||||
idx_mapping_ptr,
|
idx_mapping_ptr,
|
||||||
prompt_bin_counts_ptr,
|
prompt_bin_counts_ptr,
|
||||||
prompt_bin_counts_stride,
|
prompt_bin_counts_stride,
|
||||||
@ -25,12 +26,16 @@ def _penalties_kernel(
|
|||||||
rep_penalty = tl.load(repetition_penalty_ptr + batch_idx)
|
rep_penalty = tl.load(repetition_penalty_ptr + batch_idx)
|
||||||
freq_penalty = tl.load(frequency_penalty_ptr + batch_idx)
|
freq_penalty = tl.load(frequency_penalty_ptr + batch_idx)
|
||||||
pres_penalty = tl.load(presence_penalty_ptr + batch_idx)
|
pres_penalty = tl.load(presence_penalty_ptr + batch_idx)
|
||||||
|
temperature = tl.load(temperature_ptr + batch_idx)
|
||||||
|
temperature = tl.where(temperature == 0.0, 1.0, temperature)
|
||||||
|
|
||||||
use_rep_penalty = rep_penalty != 1.0
|
use_rep_penalty = rep_penalty != 1.0
|
||||||
use_freq_penalty = freq_penalty != 0.0
|
use_freq_penalty = freq_penalty != 0.0
|
||||||
use_pres_penalty = pres_penalty != 0.0
|
use_pres_penalty = pres_penalty != 0.0
|
||||||
if not (use_rep_penalty or use_freq_penalty or use_pres_penalty):
|
use_penalty = use_rep_penalty or use_freq_penalty or use_pres_penalty
|
||||||
# No penalties to apply. Early return.
|
use_temperature = temperature != 1.0
|
||||||
|
if not (use_penalty or use_temperature):
|
||||||
|
# Early return to avoid loading logits.
|
||||||
return
|
return
|
||||||
|
|
||||||
block_idx = tl.program_id(1)
|
block_idx = tl.program_id(1)
|
||||||
@ -39,42 +44,54 @@ def _penalties_kernel(
|
|||||||
logits = tl.load(logits_ptr + batch_idx * logits_stride + block, mask=mask)
|
logits = tl.load(logits_ptr + batch_idx * logits_stride + block, mask=mask)
|
||||||
logits = logits.to(tl.float32)
|
logits = logits.to(tl.float32)
|
||||||
|
|
||||||
req_state_idx = tl.load(idx_mapping_ptr + batch_idx)
|
if use_penalty:
|
||||||
output_bin_counts = tl.load(
|
req_state_idx = tl.load(idx_mapping_ptr + batch_idx)
|
||||||
output_bin_counts_ptr + req_state_idx * output_bin_counts_stride + block,
|
output_bin_counts = tl.load(
|
||||||
mask=mask,
|
output_bin_counts_ptr + req_state_idx * output_bin_counts_stride + block,
|
||||||
)
|
|
||||||
|
|
||||||
# Apply repetition penalties.
|
|
||||||
if use_rep_penalty:
|
|
||||||
prompt_bin_counts = tl.load(
|
|
||||||
prompt_bin_counts_ptr + req_state_idx * prompt_bin_counts_stride + block,
|
|
||||||
mask=mask,
|
mask=mask,
|
||||||
)
|
)
|
||||||
# If token appears in prompt or output, apply, otherwise use 1.0 for no-op.
|
output_bin_mask = output_bin_counts > 0
|
||||||
scale = tl.where((prompt_bin_counts + output_bin_counts) > 0, rep_penalty, 1.0)
|
|
||||||
# If logits are positive, divide by penalty, otherwise multiply by penalty.
|
# Apply repetition penalties.
|
||||||
scale = tl.where(logits > 0, 1.0 / scale, scale)
|
if use_rep_penalty:
|
||||||
logits *= scale
|
prompt_bin_counts = tl.load(
|
||||||
|
prompt_bin_counts_ptr
|
||||||
|
+ req_state_idx * prompt_bin_counts_stride
|
||||||
|
+ block,
|
||||||
|
mask=mask,
|
||||||
|
)
|
||||||
|
prompt_bin_mask = prompt_bin_counts > 0
|
||||||
|
# If token appears in prompt or output, apply, otherwise use 1.0 for no-op.
|
||||||
|
scale = tl.where(prompt_bin_mask | output_bin_mask, rep_penalty, 1.0)
|
||||||
|
# If logits are positive, divide by penalty, otherwise multiply by penalty.
|
||||||
|
logits *= tl.where(logits > 0, 1.0 / scale, scale)
|
||||||
|
|
||||||
|
# Apply frequency penalties.
|
||||||
|
logits -= freq_penalty * output_bin_counts
|
||||||
|
# Apply presence penalties.
|
||||||
|
logits -= pres_penalty * output_bin_mask
|
||||||
|
|
||||||
|
# Apply temperature.
|
||||||
|
logits = logits / temperature
|
||||||
|
|
||||||
# Apply frequency penalties.
|
|
||||||
logits -= freq_penalty * output_bin_counts
|
|
||||||
# Apply presence penalties.
|
|
||||||
logits -= pres_penalty * (output_bin_counts > 0)
|
|
||||||
# Store back to logits.
|
# Store back to logits.
|
||||||
tl.store(logits_ptr + batch_idx * logits_stride + block, logits, mask=mask)
|
tl.store(logits_ptr + batch_idx * logits_stride + block, logits, mask=mask)
|
||||||
|
|
||||||
|
|
||||||
def apply_penalties(logits: torch.Tensor, sampling_metadata: SamplingMetadata) -> None:
|
def apply_penalties_and_temperature(
|
||||||
|
logits: torch.Tensor,
|
||||||
|
sampling_metadata: SamplingMetadata,
|
||||||
|
) -> None:
|
||||||
num_reqs, vocab_size = logits.shape
|
num_reqs, vocab_size = logits.shape
|
||||||
BLOCK_SIZE = 8192
|
BLOCK_SIZE = 8192
|
||||||
num_blocks = triton.cdiv(vocab_size, BLOCK_SIZE)
|
num_blocks = triton.cdiv(vocab_size, BLOCK_SIZE)
|
||||||
_penalties_kernel[(num_reqs, num_blocks)](
|
_penalties_and_temperature_kernel[(num_reqs, num_blocks)](
|
||||||
logits,
|
logits,
|
||||||
logits.stride(0),
|
logits.stride(0),
|
||||||
sampling_metadata.repetition_penalty,
|
sampling_metadata.repetition_penalty,
|
||||||
sampling_metadata.frequency_penalty,
|
sampling_metadata.frequency_penalty,
|
||||||
sampling_metadata.presence_penalty,
|
sampling_metadata.presence_penalty,
|
||||||
|
sampling_metadata.temperature,
|
||||||
sampling_metadata.idx_mapping,
|
sampling_metadata.idx_mapping,
|
||||||
sampling_metadata.prompt_bin_counts,
|
sampling_metadata.prompt_bin_counts,
|
||||||
sampling_metadata.prompt_bin_counts.stride(0),
|
sampling_metadata.prompt_bin_counts.stride(0),
|
||||||
|
|||||||
@ -9,7 +9,7 @@ from vllm.v1.sample.ops.topk_topp_sampler import apply_top_k_top_p
|
|||||||
from vllm.v1.worker.gpu.sample.gumbel import gumbel_sample
|
from vllm.v1.worker.gpu.sample.gumbel import gumbel_sample
|
||||||
from vllm.v1.worker.gpu.sample.logprob import compute_topk_logprobs
|
from vllm.v1.worker.gpu.sample.logprob import compute_topk_logprobs
|
||||||
from vllm.v1.worker.gpu.sample.metadata import SamplingMetadata
|
from vllm.v1.worker.gpu.sample.metadata import SamplingMetadata
|
||||||
from vllm.v1.worker.gpu.sample.penalties import apply_penalties
|
from vllm.v1.worker.gpu.sample.penalties import apply_penalties_and_temperature
|
||||||
|
|
||||||
|
|
||||||
class Sampler:
|
class Sampler:
|
||||||
@ -26,22 +26,19 @@ class Sampler:
|
|||||||
logits: torch.Tensor,
|
logits: torch.Tensor,
|
||||||
sampling_metadata: SamplingMetadata,
|
sampling_metadata: SamplingMetadata,
|
||||||
) -> SamplerOutput:
|
) -> SamplerOutput:
|
||||||
|
sampled, processed_logits = self.sample(logits, sampling_metadata)
|
||||||
if sampling_metadata.max_num_logprobs is not None:
|
if sampling_metadata.max_num_logprobs is not None:
|
||||||
if self.logprobs_mode == "processed_logprobs":
|
logits = (
|
||||||
sampled, logits = self.sample(
|
processed_logits
|
||||||
logits, sampling_metadata, return_logits=True
|
if self.logprobs_mode == "processed_logprobs"
|
||||||
)
|
else logits
|
||||||
else:
|
)
|
||||||
assert self.logprobs_mode == "raw_logprobs"
|
|
||||||
sampled, _ = self.sample(logits, sampling_metadata, return_logits=False)
|
|
||||||
|
|
||||||
logprobs_tensors = compute_topk_logprobs(
|
logprobs_tensors = compute_topk_logprobs(
|
||||||
logits,
|
logits,
|
||||||
sampling_metadata.max_num_logprobs,
|
sampling_metadata.max_num_logprobs,
|
||||||
sampled,
|
sampled,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
sampled, _ = self.sample(logits, sampling_metadata, return_logits=False)
|
|
||||||
logprobs_tensors = None
|
logprobs_tensors = None
|
||||||
|
|
||||||
# These are GPU tensors.
|
# These are GPU tensors.
|
||||||
@ -58,16 +55,15 @@ class Sampler:
|
|||||||
self,
|
self,
|
||||||
logits: torch.Tensor,
|
logits: torch.Tensor,
|
||||||
sampling_metadata: SamplingMetadata,
|
sampling_metadata: SamplingMetadata,
|
||||||
return_logits: bool = False,
|
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||||
) -> tuple[torch.Tensor, torch.Tensor | None]:
|
# Copy logits to a new FP32 tensor.
|
||||||
is_greedy = sampling_metadata.temperature == 0
|
logits = torch.empty_like(logits, dtype=torch.float32).copy_(logits)
|
||||||
temp = torch.where(is_greedy, 1.0, sampling_metadata.temperature)
|
|
||||||
logits = logits / temp.view(-1, 1)
|
# Apply penalties and temperature in place.
|
||||||
|
apply_penalties_and_temperature(logits, sampling_metadata)
|
||||||
logits = apply_top_k_top_p(
|
logits = apply_top_k_top_p(
|
||||||
logits, sampling_metadata.top_k, sampling_metadata.top_p
|
logits, sampling_metadata.top_k, sampling_metadata.top_p
|
||||||
)
|
)
|
||||||
# Apply penalties in place.
|
|
||||||
apply_penalties(logits, sampling_metadata)
|
|
||||||
|
|
||||||
sampled = gumbel_sample(
|
sampled = gumbel_sample(
|
||||||
logits,
|
logits,
|
||||||
@ -76,4 +72,4 @@ class Sampler:
|
|||||||
sampling_metadata.pos,
|
sampling_metadata.pos,
|
||||||
apply_temperature=False,
|
apply_temperature=False,
|
||||||
)
|
)
|
||||||
return sampled, logits if return_logits else None
|
return sampled, logits
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user