mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-05-03 20:57:57 +08:00
[Bugfix] fix apply_temperature to avoid nan in probs (#24734)
Signed-off-by: courage17340 <courage17340@163.com> Signed-off-by: yewentao256 <zhyanwentao@126.com>
This commit is contained in:
parent
d8ffa3c5f4
commit
94b78f576c
@ -128,8 +128,12 @@ class Sampler(nn.Module):
|
|||||||
self,
|
self,
|
||||||
logits: torch.Tensor,
|
logits: torch.Tensor,
|
||||||
temp: torch.Tensor,
|
temp: torch.Tensor,
|
||||||
|
all_random: bool,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
# Use in-place division to avoid creating a new tensor.
|
# Use in-place division to avoid creating a new tensor.
|
||||||
|
# Avoid division by zero if there are greedy requests.
|
||||||
|
if not all_random:
|
||||||
|
temp = torch.where(temp < _SAMPLING_EPS, 1.0, temp)
|
||||||
return logits.div_(temp.unsqueeze(dim=1))
|
return logits.div_(temp.unsqueeze(dim=1))
|
||||||
|
|
||||||
def greedy_sample(self, logits: torch.Tensor) -> torch.Tensor:
|
def greedy_sample(self, logits: torch.Tensor) -> torch.Tensor:
|
||||||
@ -164,7 +168,8 @@ class Sampler(nn.Module):
|
|||||||
assert sampling_metadata.temperature is not None
|
assert sampling_metadata.temperature is not None
|
||||||
|
|
||||||
# Apply temperature.
|
# Apply temperature.
|
||||||
logits = self.apply_temperature(logits, sampling_metadata.temperature)
|
logits = self.apply_temperature(logits, sampling_metadata.temperature,
|
||||||
|
sampling_metadata.all_random)
|
||||||
|
|
||||||
# Apply logits processors that only apply to random sampling
|
# Apply logits processors that only apply to random sampling
|
||||||
# (argmax invariant)
|
# (argmax invariant)
|
||||||
|
|||||||
@ -354,8 +354,8 @@ class InputBatch:
|
|||||||
and is_spec_decode_unsupported(sampling_params)):
|
and is_spec_decode_unsupported(sampling_params)):
|
||||||
self.spec_decode_unsupported_reqs.add(req_id)
|
self.spec_decode_unsupported_reqs.add(req_id)
|
||||||
if sampling_params.sampling_type == SamplingType.GREEDY:
|
if sampling_params.sampling_type == SamplingType.GREEDY:
|
||||||
# Avoid later division by zero.
|
# Should avoid division by zero later when apply_temperature.
|
||||||
self.temperature_cpu[req_index] = -1.0
|
self.temperature_cpu[req_index] = 0.0
|
||||||
self.greedy_reqs.add(req_id)
|
self.greedy_reqs.add(req_id)
|
||||||
else:
|
else:
|
||||||
self.temperature_cpu[req_index] = sampling_params.temperature
|
self.temperature_cpu[req_index] = sampling_params.temperature
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user