Use FP32 for log probabilities (#19)

Woosuk Kwon 2023-03-31 23:33:43 -07:00 committed by GitHub
parent e3f00d191e
commit a90c97d727


@@ -36,10 +36,11 @@ class Sampler(nn.Module):
     # Use in-place division to avoid creating a new tensor.
     logits.div_(t.unsqueeze(dim=1))
+    # We use float32 for probabilities and log probabilities.
     # Compute the probabilities.
     probs = torch.softmax(logits, dim=-1, dtype=torch.float)
     # Compute the log probabilities (before applying top-p).
-    logprobs = torch.log(probs, out=logits)
+    logprobs = torch.log(probs)
     # Apply top-p truncation.
     top_ps = _get_top_ps(input_metadata)
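
For context, a minimal self-contained sketch (with made-up tensor shapes, not the actual Sampler code) of why dropping `out=logits` matters: `torch.log(probs, out=logits)` writes the result into the logits tensor and therefore casts it to the logits dtype (typically float16), while `torch.log(probs)` keeps the float32 dtype of `probs`, preserving precision for the log probabilities of unlikely tokens.

```python
import torch

# Hypothetical half-precision logits for a batch of 2 and a vocab of 8.
logits = torch.randn(2, 8, dtype=torch.float16)

# softmax with dtype=torch.float upcasts, so probs is float32.
probs = torch.softmax(logits, dim=-1, dtype=torch.float)

# Old behavior: writing into `logits` downcasts the log probabilities to float16.
logprobs_old = torch.log(probs, out=logits)
assert logprobs_old.dtype == torch.float16

# New behavior: the log probabilities stay in float32.
logprobs_new = torch.log(probs)
assert logprobs_new.dtype == torch.float32
```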