diff --git a/cacheflow/models/sample.py b/cacheflow/models/sample.py index 87eeccb61634b..3b53f34f4c732 100644 --- a/cacheflow/models/sample.py +++ b/cacheflow/models/sample.py @@ -36,10 +36,11 @@ class Sampler(nn.Module): # Use in-place division to avoid creating a new tensor. logits.div_(t.unsqueeze(dim=1)) + # We use float32 for probabilities and log probabilities. # Compute the probabilities. probs = torch.softmax(logits, dim=-1, dtype=torch.float) # Compute the log probabilities (before applying top-p). - logprobs = torch.log(probs, out=logits) + logprobs = torch.log(probs) # Apply top-p truncation. top_ps = _get_top_ps(input_metadata)