From 6a854c7a2bb5b8a2015bbd83d94d311b991ac45d Mon Sep 17 00:00:00 2001
From: Nick Hill
Date: Fri, 14 Feb 2025 18:10:53 -0800
Subject: [PATCH] [V1][Sampler] Don't apply temp for greedy-only (#13311)

Signed-off-by: Nick Hill
---
 vllm/v1/sample/sampler.py | 24 +++++++++++++++---------
 1 file changed, 15 insertions(+), 9 deletions(-)

diff --git a/vllm/v1/sample/sampler.py b/vllm/v1/sample/sampler.py
index ac32c90d67699..66cf48bc0f5e4 100644
--- a/vllm/v1/sample/sampler.py
+++ b/vllm/v1/sample/sampler.py
@@ -41,8 +41,6 @@ class Sampler(nn.Module):
         logits = self.apply_logits_bias(logits, sampling_metadata)
         # Apply penalties (e.g., min_tokens, freq_penalties).
         logits = self.apply_penalties(logits, sampling_metadata)
-        # Apply temperature.
-        logits = self.apply_temperature(logits, sampling_metadata.temperature)
 
         # Sample the next token.
         sampled = self.sample(logits, sampling_metadata)
@@ -82,9 +80,21 @@ class Sampler(nn.Module):
     ) -> torch.Tensor:
         assert not (sampling_metadata.all_greedy
                     and sampling_metadata.all_random)
-        if sampling_metadata.all_greedy:
-            return self.greedy_sample(logits)
+        if sampling_metadata.all_random:
+            greedy_sampled = None
+        else:
+            greedy_sampled = self.greedy_sample(logits)
+            if sampling_metadata.all_greedy:
+                return greedy_sampled
 
+        # Apply temperature.
+        logits = self.apply_temperature(logits, sampling_metadata.temperature)
+
+        # Apply min_p.
+        if not sampling_metadata.no_min_p:
+            logits = self.apply_min_p(logits, sampling_metadata.min_p)
+
+        # Apply top_k and/or top_p.
         random_sampled = self.topk_topp_sampler(
             logits,
             sampling_metadata.generators,
@@ -94,13 +104,9 @@ class Sampler(nn.Module):
             sampling_metadata.top_p,
         )
 
-        if not sampling_metadata.no_min_p:
-            logits = self.apply_min_p(logits, sampling_metadata.min_p)
-
-        if sampling_metadata.all_random:
+        if greedy_sampled is None:
             return random_sampled
 
-        greedy_sampled = self.greedy_sample(logits)
         sampled = torch.where(
            sampling_metadata.temperature < _SAMPLING_EPS,
            greedy_sampled,