[Misc] Change dummy profiling and BOS fallback warns to log once (#8820)

This commit is contained in:
Michael Goin 2024-09-26 19:18:14 -04:00 committed by GitHub
parent 93d364da34
commit b28d2104de
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 12 additions and 10 deletions

View File

@@ -8,6 +8,7 @@ from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.prompt_adapter.request import PromptAdapterRequest
from vllm.transformers_utils.tokenizer_group import BaseTokenizerGroup
from vllm.utils import print_warning_once
from .data import (EncoderDecoderLLMInputs, LLMInputs, PromptInputs,
SingletonPromptInputs)
@@ -71,20 +72,21 @@ class InputPreprocessor:
'''
if not self.is_encoder_decoder_model():
logger.warning("Using None for decoder start token id because "
"this is not an encoder/decoder model.")
print_warning_once("Using None for decoder start token id because "
"this is not an encoder/decoder model.")
return None
if (self.model_config is None or self.model_config.hf_config is None):
logger.warning("Using None for decoder start token id because "
"model config is not available.")
print_warning_once("Using None for decoder start token id because "
"model config is not available.")
return None
dec_start_token_id = getattr(self.model_config.hf_config,
'decoder_start_token_id', None)
if dec_start_token_id is None:
logger.warning("Falling back on <BOS> for decoder start token id "
"because decoder start token id is not available.")
print_warning_once("Falling back on <BOS> for decoder start token "
"id because decoder start token id is not "
"available.")
dec_start_token_id = self.get_bos_token_id()
return dec_start_token_id

View File

@@ -9,7 +9,7 @@ from transformers import PretrainedConfig
from typing_extensions import TypeVar
from vllm.logger import init_logger
from vllm.utils import get_allowed_kwarg_only_overrides
from vllm.utils import get_allowed_kwarg_only_overrides, print_warning_once
from .data import LLMInputs
@@ -235,9 +235,9 @@ class InputRegistry:
num_tokens = seq_data.prompt_token_ids
if len(num_tokens) < seq_len:
if is_encoder_data:
logger.warning(
"Expected at least %d dummy encoder tokens for profiling, "
"but found %d tokens instead.", seq_len, len(num_tokens))
print_warning_once(
f"Expected at least {seq_len} dummy encoder tokens for "
f"profiling, but found {len(num_tokens)} tokens instead.")
else:
raise AssertionError(
f"Expected at least {seq_len} dummy tokens for profiling, "