mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-10 01:05:01 +08:00
[FIX] Minor bug fixes (#1035)
* [FIX] Minor bug fixes
* Address review comments
This commit is contained in:
parent
ab019eea75
commit
f04908cae7
@ -82,8 +82,9 @@ class Sampler(nn.Module):
|
|||||||
# We use float32 for probabilities and log probabilities.
|
# We use float32 for probabilities and log probabilities.
|
||||||
# Compute the probabilities.
|
# Compute the probabilities.
|
||||||
probs = torch.softmax(logits, dim=-1, dtype=torch.float)
|
probs = torch.softmax(logits, dim=-1, dtype=torch.float)
|
||||||
# Compute the log probabilities (before applying top-p and top-k).
|
# Compute the log probabilities.
|
||||||
logprobs = torch.log(probs)
|
# Use log_softmax to ensure numerical stability.
|
||||||
|
logprobs = torch.log_softmax(logits, dim=-1, dtype=torch.float)
|
||||||
|
|
||||||
# Sample the next tokens.
|
# Sample the next tokens.
|
||||||
return _sample(probs, logprobs, input_metadata)
|
return _sample(probs, logprobs, input_metadata)
|
||||||
|
|||||||
@ -350,7 +350,7 @@ class SequenceOutputs:
|
|||||||
|
|
||||||
def __eq__(self, other: object) -> bool:
|
def __eq__(self, other: object) -> bool:
|
||||||
if not isinstance(other, SequenceOutputs):
|
if not isinstance(other, SequenceOutputs):
|
||||||
return NotImplementedError()
|
raise NotImplementedError()
|
||||||
return (self.parent_seq_id == other.parent_seq_id
|
return (self.parent_seq_id == other.parent_seq_id
|
||||||
and self.output_token == other.output_token
|
and self.output_token == other.output_token
|
||||||
and self.logprobs == other.logprobs)
|
and self.logprobs == other.logprobs)
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user