[Bugfix] Limit the default value of max_model_len when it is not specified by users (#27556)
Signed-off-by: shen-shanshan <467638484@qq.com>
commit a3e8611da5
parent 7c2bdb83dc
@@ -2112,20 +2112,13 @@ def _get_and_verify_max_len(
     if encoder_config and "max_seq_length" in encoder_config:
         derived_max_model_len = encoder_config["max_seq_length"]
 
-    # If the user specified a max length, make sure it is smaller than the
-    # derived length from the HF model config.
+    # If the user didn't specify `max_model_len`, then use that derived from
+    # the model config as a default value.
     if max_model_len is None:
         max_model_len = int(derived_max_model_len)
-        if current_platform.is_tpu():
-            logger.warning(
-                "--max-model-len is not specified, "
-                "it's currently using model's default length %s, "
-                "which might be too large."
-                "Please input with --max-model-len based on your "
-                "request input length and output length, to avoid "
-                "unnecessary degradation.",
-                max_model_len,
-            )
+        max_model_len = current_platform.check_max_model_len(max_model_len)
+    # If the user specified a max length, make sure it is smaller than the
+    # derived length from the HF model config.
     elif max_model_len > derived_max_model_len:
         # Some models might have a separate key for specifying model_max_length
         # that will be bigger than derived_max_model_len. We compare user input
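For readers skimming the diff: the effect of this hunk is that an unspecified --max-model-len now falls back to the config-derived length and is then routed through a new platform hook. The sketch below is a rough, self-contained illustration of that control flow, not vLLM's actual code; the stub platform class and the example lengths are made up, while the names derived_max_model_len and check_max_model_len come from the hunk above.

# Rough sketch of the new default-resolution flow (illustrative only).

class _SketchPlatform:
    """Stand-in for vLLM's current_platform; the base hook accepts the value as-is."""

    @classmethod
    def check_max_model_len(cls, max_model_len: int) -> int:
        return max_model_len


def resolve_max_model_len(max_model_len, derived_max_model_len, platform=_SketchPlatform):
    if max_model_len is None:
        # No user-supplied --max-model-len: start from the config-derived length...
        max_model_len = int(derived_max_model_len)
        # ...then let the active platform inspect or adjust the default.
        max_model_len = platform.check_max_model_len(max_model_len)
    return max_model_len


print(resolve_max_model_len(None, 131072))  # hypothetical derived length -> 131072
print(resolve_max_model_len(4096, 131072))  # user-specified value is kept -> 4096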
@@ -608,6 +608,13 @@ class Platform:
         """
         return None
 
+    @classmethod
+    def check_max_model_len(cls, max_model_len: int) -> int:
+        """
+        Check max_model_len for the current platform.
+        """
+        return max_model_len
+
 
 class UnspecifiedPlatform(Platform):
     _enum = PlatformEnum.UNSPECIFIED
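The base-class hook above is deliberately a no-op, so platforms that want to bound the unspecified default can override it. A hypothetical override might look like the following sketch; the class name and the 32768 ceiling are invented for illustration and are not part of this commit.

class _ExampleAcceleratorPlatform:  # illustrative stand-in for a Platform subclass
    # Assumed hardware-friendly ceiling for the *default* length; made up for this sketch.
    _DEFAULT_LEN_CEILING = 32768

    @classmethod
    def check_max_model_len(cls, max_model_len: int) -> int:
        if max_model_len > cls._DEFAULT_LEN_CEILING:
            print(
                f"derived max_model_len {max_model_len} exceeds "
                f"{cls._DEFAULT_LEN_CEILING}; capping the default"
            )
            return cls._DEFAULT_LEN_CEILING
        return max_model_len


print(_ExampleAcceleratorPlatform.check_max_model_len(131072))  # -> 32768
print(_ExampleAcceleratorPlatform.check_max_model_len(8192))    # -> 8192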
@@ -251,6 +251,22 @@ class TpuPlatform(Platform):
     def use_sync_weight_loader(cls) -> bool:
         return True
 
+    @classmethod
+    def check_max_model_len(cls, max_model_len: int) -> int:
+        """
+        Check max_model_len for the current platform.
+        """
+        logger.warning(
+            "--max-model-len is not specified, "
+            "it's currently using model's default length %d, "
+            "which might be too large."
+            "Please input with --max-model-len based on your "
+            "request input length and output length, to avoid "
+            "unnecessary degradation.",
+            max_model_len,
+        )
+        return max_model_len
+
 
 try:
     from tpu_inference.platforms import TpuPlatform as TpuInferencePlatform
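Note that the TPU override shown above does not cap anything itself: it re-emits the previous warning (now with %d rather than %s) and returns the value unchanged, so on TPU the behavioural change is essentially where the warning is issued from. Because the message is built from adjacent string literals, it renders with no space between "too large." and "Please", as the small sketch below shows; it uses the standard logging module rather than vLLM's logger, and the logger name is arbitrary.

import logging

logging.basicConfig(level=logging.WARNING)
logger = logging.getLogger("sketch")

max_model_len = 131072  # hypothetical derived default
logger.warning(
    "--max-model-len is not specified, "
    "it's currently using model's default length %d, "
    "which might be too large."
    "Please input with --max-model-len based on your "
    "request input length and output length, to avoid "
    "unnecessary degradation.",
    max_model_len,
)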