Fix default length_penalty to 1.0 (#2667)

This commit is contained in:
zspo 2024-02-02 07:59:39 +08:00 committed by GitHub
parent 96b6f475dd
commit 0e163fce18
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -196,7 +196,7 @@ class Sequence:
return self.data.cumulative_logprob
def get_beam_search_score(self,
length_penalty: float = 0.0,
length_penalty: float = 1.0,
seq_len: Optional[int] = None,
eos_token_id: Optional[int] = None) -> float:
"""Calculate the beam search score with length penalty.