From 6a7c7711a2588ca4a5e713e5335122988f8c0a55 Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Wed, 5 Jun 2024 15:19:02 -0700 Subject: [PATCH] [Misc] Skip for logits_scale == 1.0 (#5291) --- vllm/model_executor/layers/logits_processor.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/vllm/model_executor/layers/logits_processor.py b/vllm/model_executor/layers/logits_processor.py index d450c46455d49..7eee599473a11 100644 --- a/vllm/model_executor/layers/logits_processor.py +++ b/vllm/model_executor/layers/logits_processor.py @@ -21,7 +21,7 @@ class LogitsProcessor(nn.Module): def __init__(self, vocab_size: int, org_vocab_size: Optional[int] = None, - scale: Optional[float] = 1.0, + scale: float = 1.0, logits_as_input: bool = False) -> None: """ Args: @@ -52,7 +52,8 @@ class LogitsProcessor(nn.Module): logits = self._get_logits(hidden_states, embedding, embedding_bias) if logits is not None: - logits *= self.scale + if self.scale != 1.0: + logits *= self.scale # Apply logits processors (if any). logits = _apply_logits_processors(logits, sampling_metadata)