[Misc] Skip for logits_scale == 1.0 (#5291)

This commit is contained in:
Woosuk Kwon 2024-06-05 15:19:02 -07:00 committed by GitHub
parent 0f83ddd4d7
commit 6a7c7711a2
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -21,7 +21,7 @@ class LogitsProcessor(nn.Module):
def __init__(self,
vocab_size: int,
org_vocab_size: Optional[int] = None,
scale: Optional[float] = 1.0,
scale: float = 1.0,
logits_as_input: bool = False) -> None:
"""
Args:
@ -52,7 +52,8 @@ class LogitsProcessor(nn.Module):
logits = self._get_logits(hidden_states, embedding, embedding_bias)
if logits is not None:
logits *= self.scale
if self.scale != 1.0:
logits *= self.scale
# Apply logits processors (if any).
logits = _apply_logits_processors(logits, sampling_metadata)