diff --git a/vllm/v1/sample/sampler.py b/vllm/v1/sample/sampler.py index ac32c90d67699..66cf48bc0f5e4 100644 --- a/vllm/v1/sample/sampler.py +++ b/vllm/v1/sample/sampler.py @@ -41,8 +41,6 @@ class Sampler(nn.Module): logits = self.apply_logits_bias(logits, sampling_metadata) # Apply penalties (e.g., min_tokens, freq_penalties). logits = self.apply_penalties(logits, sampling_metadata) - # Apply temperature. - logits = self.apply_temperature(logits, sampling_metadata.temperature) # Sample the next token. sampled = self.sample(logits, sampling_metadata) @@ -82,9 +80,21 @@ class Sampler(nn.Module): ) -> torch.Tensor: assert not (sampling_metadata.all_greedy and sampling_metadata.all_random) - if sampling_metadata.all_greedy: - return self.greedy_sample(logits) + if sampling_metadata.all_random: + greedy_sampled = None + else: + greedy_sampled = self.greedy_sample(logits) + if sampling_metadata.all_greedy: + return greedy_sampled + # Apply temperature. + logits = self.apply_temperature(logits, sampling_metadata.temperature) + + # Apply min_p. + if not sampling_metadata.no_min_p: + logits = self.apply_min_p(logits, sampling_metadata.min_p) + + # Apply top_k and/or top_p. random_sampled = self.topk_topp_sampler( logits, sampling_metadata.generators, @@ -94,13 +104,9 @@ class Sampler(nn.Module): sampling_metadata.top_p, ) - if not sampling_metadata.no_min_p: - logits = self.apply_min_p(logits, sampling_metadata.min_p) - - if sampling_metadata.all_random: + if greedy_sampled is None: return random_sampled - greedy_sampled = self.greedy_sample(logits) sampled = torch.where( sampling_metadata.temperature < _SAMPLING_EPS, greedy_sampled,