mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-09 18:54:55 +08:00
[FIX] Minor bug fixes (#1035)
* [FIX] Minor bug fixes * Address review comments
This commit is contained in:
parent
ab019eea75
commit
f04908cae7
@ -82,8 +82,9 @@ class Sampler(nn.Module):
|
||||
# We use float32 for probabilities and log probabilities.
|
||||
# Compute the probabilities.
|
||||
probs = torch.softmax(logits, dim=-1, dtype=torch.float)
|
||||
# Compute the log probabilities (before applying top-p and top-k).
|
||||
logprobs = torch.log(probs)
|
||||
# Compute the log probabilities.
|
||||
# Use log_softmax to ensure numerical stability.
|
||||
logprobs = torch.log_softmax(logits, dim=-1, dtype=torch.float)
|
||||
|
||||
# Sample the next tokens.
|
||||
return _sample(probs, logprobs, input_metadata)
|
||||
|
||||
@ -350,7 +350,7 @@ class SequenceOutputs:
|
||||
|
||||
def __eq__(self, other: object) -> bool:
|
||||
if not isinstance(other, SequenceOutputs):
|
||||
return NotImplementedError()
|
||||
raise NotImplementedError()
|
||||
return (self.parent_seq_id == other.parent_seq_id
|
||||
and self.output_token == other.output_token
|
||||
and self.logprobs == other.logprobs)
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user