mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-05-07 08:32:19 +08:00
[V1][Sampler] Don't apply temp for greedy-only (#13311)
Signed-off-by: Nick Hill <nhill@redhat.com>
This commit is contained in:
parent
e7eea5a520
commit
6a854c7a2b
@ -41,8 +41,6 @@ class Sampler(nn.Module):
|
|||||||
logits = self.apply_logits_bias(logits, sampling_metadata)
|
logits = self.apply_logits_bias(logits, sampling_metadata)
|
||||||
# Apply penalties (e.g., min_tokens, freq_penalties).
|
# Apply penalties (e.g., min_tokens, freq_penalties).
|
||||||
logits = self.apply_penalties(logits, sampling_metadata)
|
logits = self.apply_penalties(logits, sampling_metadata)
|
||||||
# Apply temperature.
|
|
||||||
logits = self.apply_temperature(logits, sampling_metadata.temperature)
|
|
||||||
# Sample the next token.
|
# Sample the next token.
|
||||||
sampled = self.sample(logits, sampling_metadata)
|
sampled = self.sample(logits, sampling_metadata)
|
||||||
|
|
||||||
@ -82,9 +80,21 @@ class Sampler(nn.Module):
|
|||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
assert not (sampling_metadata.all_greedy
|
assert not (sampling_metadata.all_greedy
|
||||||
and sampling_metadata.all_random)
|
and sampling_metadata.all_random)
|
||||||
if sampling_metadata.all_greedy:
|
if sampling_metadata.all_random:
|
||||||
return self.greedy_sample(logits)
|
greedy_sampled = None
|
||||||
|
else:
|
||||||
|
greedy_sampled = self.greedy_sample(logits)
|
||||||
|
if sampling_metadata.all_greedy:
|
||||||
|
return greedy_sampled
|
||||||
|
|
||||||
|
# Apply temperature.
|
||||||
|
logits = self.apply_temperature(logits, sampling_metadata.temperature)
|
||||||
|
|
||||||
|
# Apply min_p.
|
||||||
|
if not sampling_metadata.no_min_p:
|
||||||
|
logits = self.apply_min_p(logits, sampling_metadata.min_p)
|
||||||
|
|
||||||
|
# Apply top_k and/or top_p.
|
||||||
random_sampled = self.topk_topp_sampler(
|
random_sampled = self.topk_topp_sampler(
|
||||||
logits,
|
logits,
|
||||||
sampling_metadata.generators,
|
sampling_metadata.generators,
|
||||||
@ -94,13 +104,9 @@ class Sampler(nn.Module):
|
|||||||
sampling_metadata.top_p,
|
sampling_metadata.top_p,
|
||||||
)
|
)
|
||||||
|
|
||||||
if not sampling_metadata.no_min_p:
|
if greedy_sampled is None:
|
||||||
logits = self.apply_min_p(logits, sampling_metadata.min_p)
|
|
||||||
|
|
||||||
if sampling_metadata.all_random:
|
|
||||||
return random_sampled
|
return random_sampled
|
||||||
|
|
||||||
greedy_sampled = self.greedy_sample(logits)
|
|
||||||
sampled = torch.where(
|
sampled = torch.where(
|
||||||
sampling_metadata.temperature < _SAMPLING_EPS,
|
sampling_metadata.temperature < _SAMPLING_EPS,
|
||||||
greedy_sampled,
|
greedy_sampled,
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user