diff --git a/vllm/v1/worker/gpu/sample/gumbel.py b/vllm/v1/worker/gpu/sample/gumbel.py
index 3e0d72e56939b..a95bf1e7a37a7 100644
--- a/vllm/v1/worker/gpu/sample/gumbel.py
+++ b/vllm/v1/worker/gpu/sample/gumbel.py
@@ -45,8 +45,9 @@ def _gumbel_sample_kernel(
 
     # Apply temperature.
     if APPLY_TEMPERATURE:
-        # NOTE(woosuk): Use div_rn to match the behavior of torch.
-        logits = tl.div_rn(logits, temp)
+        # NOTE(woosuk): Match the behavior of _penalties_and_temperature_kernel.
+        # E.g., if that kernel uses tl.div_rn, we should use tl.div_rn here too.
+        logits = logits / temp
 
     # Apply gumbel noise.
     logits = tl.where(mask, logits + gumbel_noise, float("-inf"))
diff --git a/vllm/v1/worker/gpu/sample/penalties.py b/vllm/v1/worker/gpu/sample/penalties.py
index 1607a75fd56ba..69cf9d26ec992 100644
--- a/vllm/v1/worker/gpu/sample/penalties.py
+++ b/vllm/v1/worker/gpu/sample/penalties.py
@@ -7,12 +7,13 @@
 from vllm.v1.worker.gpu.sample.metadata import SamplingMetadata
 
 
 @triton.jit
-def _penalties_kernel(
+def _penalties_and_temperature_kernel(
     logits_ptr,
     logits_stride,
     repetition_penalty_ptr,
     frequency_penalty_ptr,
     presence_penalty_ptr,
+    temperature_ptr,
     idx_mapping_ptr,
     prompt_bin_counts_ptr,
     prompt_bin_counts_stride,
@@ -25,12 +26,16 @@
     rep_penalty = tl.load(repetition_penalty_ptr + batch_idx)
     freq_penalty = tl.load(frequency_penalty_ptr + batch_idx)
     pres_penalty = tl.load(presence_penalty_ptr + batch_idx)
+    temperature = tl.load(temperature_ptr + batch_idx)
+    temperature = tl.where(temperature == 0.0, 1.0, temperature)
 
     use_rep_penalty = rep_penalty != 1.0
     use_freq_penalty = freq_penalty != 0.0
     use_pres_penalty = pres_penalty != 0.0
-    if not (use_rep_penalty or use_freq_penalty or use_pres_penalty):
-        # No penalties to apply. Early return.
+    use_penalty = use_rep_penalty or use_freq_penalty or use_pres_penalty
+    use_temperature = temperature != 1.0
+    if not (use_penalty or use_temperature):
+        # Early return to avoid loading logits.
         return
@@ -39,42 +44,54 @@
     block_idx = tl.program_id(1)
     block = block_idx * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
     mask = block < vocab_size
     logits = tl.load(logits_ptr + batch_idx * logits_stride + block, mask=mask)
     logits = logits.to(tl.float32)
 
-    req_state_idx = tl.load(idx_mapping_ptr + batch_idx)
-    output_bin_counts = tl.load(
-        output_bin_counts_ptr + req_state_idx * output_bin_counts_stride + block,
-        mask=mask,
-    )
-
-    # Apply repetition penalties.
-    if use_rep_penalty:
-        prompt_bin_counts = tl.load(
-            prompt_bin_counts_ptr + req_state_idx * prompt_bin_counts_stride + block,
+    if use_penalty:
+        req_state_idx = tl.load(idx_mapping_ptr + batch_idx)
+        output_bin_counts = tl.load(
+            output_bin_counts_ptr + req_state_idx * output_bin_counts_stride + block,
             mask=mask,
         )
-        # If token appears in prompt or output, apply, otherwise use 1.0 for no-op.
-        scale = tl.where((prompt_bin_counts + output_bin_counts) > 0, rep_penalty, 1.0)
-        # If logits are positive, divide by penalty, otherwise multiply by penalty.
-        scale = tl.where(logits > 0, 1.0 / scale, scale)
-        logits *= scale
+        output_bin_mask = output_bin_counts > 0
+
+        # Apply repetition penalties.
+        if use_rep_penalty:
+            prompt_bin_counts = tl.load(
+                prompt_bin_counts_ptr
+                + req_state_idx * prompt_bin_counts_stride
+                + block,
+                mask=mask,
+            )
+            prompt_bin_mask = prompt_bin_counts > 0
+            # If token appears in prompt or output, apply, otherwise use 1.0 for no-op.
+            # If logits are positive, divide by penalty, otherwise multiply by penalty.
+            scale = tl.where(prompt_bin_mask | output_bin_mask, rep_penalty, 1.0)
+            logits *= tl.where(logits > 0, 1.0 / scale, scale)
+
+        # Apply frequency penalties.
+        logits -= freq_penalty * output_bin_counts
+        # Apply presence penalties.
+        logits -= pres_penalty * output_bin_mask
+
+    # Apply temperature.
+    logits = logits / temperature
 
-    # Apply frequency penalties.
-    logits -= freq_penalty * output_bin_counts
-    # Apply presence penalties.
-    logits -= pres_penalty * (output_bin_counts > 0)
     # Store back to logits.
     tl.store(logits_ptr + batch_idx * logits_stride + block, logits, mask=mask)
 
 
-def apply_penalties(logits: torch.Tensor, sampling_metadata: SamplingMetadata) -> None:
+def apply_penalties_and_temperature(
+    logits: torch.Tensor,
+    sampling_metadata: SamplingMetadata,
+) -> None:
     num_reqs, vocab_size = logits.shape
     BLOCK_SIZE = 8192
     num_blocks = triton.cdiv(vocab_size, BLOCK_SIZE)
-    _penalties_kernel[(num_reqs, num_blocks)](
+    _penalties_and_temperature_kernel[(num_reqs, num_blocks)](
         logits,
         logits.stride(0),
         sampling_metadata.repetition_penalty,
         sampling_metadata.frequency_penalty,
         sampling_metadata.presence_penalty,
+        sampling_metadata.temperature,
         sampling_metadata.idx_mapping,
         sampling_metadata.prompt_bin_counts,
         sampling_metadata.prompt_bin_counts.stride(0),
diff --git a/vllm/v1/worker/gpu/sample/sampler.py b/vllm/v1/worker/gpu/sample/sampler.py
index 4e7c85a021cfb..3429dd3e4d0fb 100644
--- a/vllm/v1/worker/gpu/sample/sampler.py
+++ b/vllm/v1/worker/gpu/sample/sampler.py
@@ -9,7 +9,7 @@
 from vllm.v1.sample.ops.topk_topp_sampler import apply_top_k_top_p
 from vllm.v1.worker.gpu.sample.gumbel import gumbel_sample
 from vllm.v1.worker.gpu.sample.logprob import compute_topk_logprobs
 from vllm.v1.worker.gpu.sample.metadata import SamplingMetadata
-from vllm.v1.worker.gpu.sample.penalties import apply_penalties
+from vllm.v1.worker.gpu.sample.penalties import apply_penalties_and_temperature
 
 
 class Sampler:
@@ -26,22 +26,19 @@ class Sampler:
         logits: torch.Tensor,
         sampling_metadata: SamplingMetadata,
     ) -> SamplerOutput:
+        sampled, processed_logits = self.sample(logits, sampling_metadata)
         if sampling_metadata.max_num_logprobs is not None:
-            if self.logprobs_mode == "processed_logprobs":
-                sampled, logits = self.sample(
-                    logits, sampling_metadata, return_logits=True
-                )
-            else:
-                assert self.logprobs_mode == "raw_logprobs"
-                sampled, _ = self.sample(logits, sampling_metadata, return_logits=False)
-
+            logits = (
+                processed_logits
+                if self.logprobs_mode == "processed_logprobs"
+                else logits
+            )
             logprobs_tensors = compute_topk_logprobs(
                 logits,
                 sampling_metadata.max_num_logprobs,
                 sampled,
             )
         else:
-            sampled, _ = self.sample(logits, sampling_metadata, return_logits=False)
             logprobs_tensors = None
 
         # These are GPU tensors.
@@ -58,16 +55,15 @@ class Sampler:
         self,
         logits: torch.Tensor,
         sampling_metadata: SamplingMetadata,
-        return_logits: bool = False,
-    ) -> tuple[torch.Tensor, torch.Tensor | None]:
-        is_greedy = sampling_metadata.temperature == 0
-        temp = torch.where(is_greedy, 1.0, sampling_metadata.temperature)
-        logits = logits / temp.view(-1, 1)
+    ) -> tuple[torch.Tensor, torch.Tensor]:
+        # Copy logits to a new FP32 tensor.
+        logits = torch.empty_like(logits, dtype=torch.float32).copy_(logits)
+
+        # Apply penalties and temperature in place.
+        apply_penalties_and_temperature(logits, sampling_metadata)
         logits = apply_top_k_top_p(
             logits, sampling_metadata.top_k, sampling_metadata.top_p
         )
-        # Apply penalties in place.
-        apply_penalties(logits, sampling_metadata)
 
         sampled = gumbel_sample(
             logits,
@@ -76,4 +72,4 @@ class Sampler:
             sampling_metadata.pos,
             apply_temperature=False,
         )
-        return sampled, logits if return_logits else None
+        return sampled, logits
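
For reviewers, here is a minimal PyTorch reference sketch of what the fused `_penalties_and_temperature_kernel` computes per batch row, useful for validating the Triton kernel against eager PyTorch. It is a validation aid under stated assumptions, not vLLM API: the function name is hypothetical, the `idx_mapping` indirection is skipped (the bin-count tensors are assumed to already be indexed per batch row), and `prompt_bin_counts` / `output_bin_counts` are assumed to be per-request token-count histograms over the vocabulary.

```python
import torch


def reference_penalties_and_temperature(
    logits: torch.Tensor,              # [num_reqs, vocab_size]
    repetition_penalty: torch.Tensor,  # [num_reqs]
    frequency_penalty: torch.Tensor,   # [num_reqs]
    presence_penalty: torch.Tensor,    # [num_reqs]
    temperature: torch.Tensor,         # [num_reqs]; 0.0 means greedy
    prompt_bin_counts: torch.Tensor,   # [num_reqs, vocab_size] token counts
    output_bin_counts: torch.Tensor,   # [num_reqs, vocab_size] token counts
) -> torch.Tensor:
    # Work in FP32, like the kernel (logits.to(tl.float32)).
    logits = logits.float().clone()
    output_bin_mask = output_bin_counts > 0
    seen = (prompt_bin_counts > 0) | output_bin_mask
    # Repetition penalty: tokens seen in the prompt or output are scaled;
    # positive logits are divided by the penalty, negative ones multiplied.
    scale = torch.where(seen, repetition_penalty.view(-1, 1), torch.ones_like(logits))
    logits = torch.where(logits > 0, logits / scale, logits * scale)
    # Frequency penalty scales with occurrence count; presence is binary.
    logits -= frequency_penalty.view(-1, 1) * output_bin_counts
    logits -= presence_penalty.view(-1, 1) * output_bin_mask
    # Temperature: 0.0 (greedy) is treated as 1.0, matching the kernel's guard.
    temp = torch.where(temperature == 0.0, torch.ones_like(temperature), temperature)
    return logits / temp.view(-1, 1)
```

Folding temperature into the penalties kernel is what lets `gumbel_sample` run with `apply_temperature=False`, and the widened early return means a row with default penalties and a temperature of 0.0 or 1.0 is never loaded at all.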