diff --git a/vllm/v1/worker/gpu/sampler.py b/vllm/v1/worker/gpu/sampler.py index 499e9d3b1538d..c48ed2d8ca167 100644 --- a/vllm/v1/worker/gpu/sampler.py +++ b/vllm/v1/worker/gpu/sampler.py @@ -68,9 +68,10 @@ class Sampler: sampled = gumbel_sample( logits, - is_greedy, + sampling_metadata.temperature, sampling_metadata.seeds, sampling_metadata.pos, + apply_temperature=False, ) return sampled, logits if return_logits else None @@ -85,9 +86,10 @@ def _gumbel_sample_kernel( logits_stride, seeds_ptr, pos_ptr, - is_greedy_ptr, + temp_ptr, vocab_size, BLOCK_SIZE: tl.constexpr, + APPLY_TEMPERATURE: tl.constexpr, ): req_idx = tl.program_id(0) block_idx = tl.program_id(1) @@ -99,8 +101,8 @@ def _gumbel_sample_kernel( other=float("-inf"), ) - is_greedy = tl.load(is_greedy_ptr + req_idx) - if not is_greedy: + temp = tl.load(temp_ptr + req_idx) + if temp != 0.0: # Calculate the seed for gumbel noise. seed = tl.load(seeds_ptr + req_idx) pos = tl.load(pos_ptr + req_idx) @@ -111,6 +113,11 @@ def _gumbel_sample_kernel( gumbel_noise = -tl.log(-tl.log(r + 1e-20) + 1e-20) gumbel_noise = gumbel_noise.to(tl.float32) + # Apply temperature. + if APPLY_TEMPERATURE: + # NOTE(woosuk): Use div_rn to match the behavior of torch. + logits = tl.div_rn(logits, temp.to(tl.float32)) + # Apply gumbel noise. logits = tl.where(mask, logits + gumbel_noise, float("-inf")) @@ -123,9 +130,10 @@ def _gumbel_sample_kernel( def gumbel_sample( logits: torch.Tensor, # [num_reqs, vocab_size] - is_greedy: torch.Tensor, # [num_reqs] + temperature: torch.Tensor, # [num_reqs] seed: torch.Tensor, # [num_reqs] pos: torch.Tensor, # [num_reqs] + apply_temperature: bool, ) -> torch.Tensor: num_reqs, vocab_size = logits.shape BLOCK_SIZE = 1024 @@ -151,9 +159,10 @@ def gumbel_sample( logits.stride(0), seed, pos, - is_greedy, + temperature, vocab_size, BLOCK_SIZE=BLOCK_SIZE, + APPLY_TEMPERATURE=apply_temperature, ) # NOTE(woosuk): Use int64 for later indexing. max_block_idx = local_max.argmax(dim=-1, keepdim=True)