mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-06-10 03:02:15 +08:00
[Misc] Fix skipped max-model-len validation when deriving max model length from tokenizer config (#19660)
Signed-off-by: Ye (Charlotte) Qi <yeq@meta.com>
This commit is contained in:
parent
367871a469
commit
b692e9cd07
@ -438,3 +438,31 @@ def test_load_config_pt_load_map_location(pt_load_map_location):
|
|||||||
config = VllmConfig(load_config=load_config)
|
config = VllmConfig(load_config=load_config)
|
||||||
|
|
||||||
assert config.load_config.pt_load_map_location == pt_load_map_location
|
assert config.load_config.pt_load_map_location == pt_load_map_location
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
    ("model_id", "max_model_len", "expected_max_len", "should_raise"),
    [
        # No length requested: the derived limit (512) is expected back.
        ("BAAI/bge-reranker-base", None, 512, False),
        # A requested length under that limit is expected back unchanged.
        ("BAAI/bge-reranker-base", 256, 256, False),
        # A requested length over that limit is expected to be rejected.
        ("BAAI/bge-reranker-base", 513, 512, True),
    ],
)
def test_get_and_verify_max_len(model_id, max_model_len, expected_max_len,
                                should_raise):
    """Check ``ModelConfig.get_and_verify_max_len`` for each requested length.

    Each case builds a ``ModelConfig`` for ``model_id`` and then either
    expects a ``ValueError`` (``should_raise``) or compares the returned
    length against ``expected_max_len``.
    """
    cfg = ModelConfig(
        model_id,
        task="auto",
        tokenizer=model_id,
        tokenizer_mode="auto",
        trust_remote_code=False,
        seed=0,
        dtype="float16",
        revision=None,
    )

    # Failing cases only need to prove that validation rejects the request;
    # ``expected_max_len`` is not consulted for them.
    if should_raise:
        with pytest.raises(ValueError):
            cfg.get_and_verify_max_len(max_model_len)
        return

    resolved_len = cfg.get_and_verify_max_len(max_model_len)
    assert resolved_len == expected_max_len
|
||||||
|
|||||||
@ -1429,25 +1429,19 @@ class ModelConfig:
|
|||||||
return getattr(self.hf_config, "matryoshka_dimensions", None)
|
return getattr(self.hf_config, "matryoshka_dimensions", None)
|
||||||
|
|
||||||
def get_and_verify_max_len(self, max_model_len: Optional[int]):
    """Resolve and validate the maximum model length for this config.

    The tokenizer config is loaded first and passed into
    ``_get_and_verify_max_len`` so that it is considered as part of the
    same derivation/validation step as the HF config, instead of being
    applied to the result afterwards.

    Args:
        max_model_len: The requested maximum length. ``None`` is accepted
            and forwarded as-is; the helper's own signature declares this
            argument as ``Optional[int]``.

    Returns:
        The value produced by ``_get_and_verify_max_len``.
    """
    # May be None when no tokenizer config is available; the helper
    # declares the argument as Optional[dict].
    tokenizer_config = try_get_tokenizer_config(
        self.tokenizer,
        trust_remote_code=self.trust_remote_code,
        revision=self.tokenizer_revision)

    max_model_len = _get_and_verify_max_len(
        hf_config=self.hf_text_config,
        tokenizer_config=tokenizer_config,
        max_model_len=max_model_len,
        disable_sliding_window=self.disable_sliding_window,
        sliding_window_len=self.get_hf_config_sliding_window(),
        spec_target_max_model_len=self.spec_target_max_model_len,
        encoder_config=self.encoder_config)

    logger.info("Using max model len %s", max_model_len)
    return max_model_len
|
||||||
|
|
||||||
|
|
||||||
@ -3283,6 +3277,7 @@ def _get_and_verify_dtype(
|
|||||||
|
|
||||||
def _get_and_verify_max_len(
|
def _get_and_verify_max_len(
|
||||||
hf_config: PretrainedConfig,
|
hf_config: PretrainedConfig,
|
||||||
|
tokenizer_config: Optional[dict],
|
||||||
max_model_len: Optional[int],
|
max_model_len: Optional[int],
|
||||||
disable_sliding_window: bool,
|
disable_sliding_window: bool,
|
||||||
sliding_window_len: Optional[Union[int, list[Optional[int]]]],
|
sliding_window_len: Optional[Union[int, list[Optional[int]]]],
|
||||||
@ -3309,7 +3304,7 @@ def _get_and_verify_max_len(
|
|||||||
"max_seq_length",
|
"max_seq_length",
|
||||||
"seq_len",
|
"seq_len",
|
||||||
]
|
]
|
||||||
# Choose the smallest "max_length" from the possible keys.
|
# Choose the smallest "max_length" from the possible keys
|
||||||
max_len_key = None
|
max_len_key = None
|
||||||
for key in possible_keys:
|
for key in possible_keys:
|
||||||
max_len = getattr(hf_config, key, None)
|
max_len = getattr(hf_config, key, None)
|
||||||
@ -3332,6 +3327,13 @@ def _get_and_verify_max_len(
|
|||||||
derived_max_model_len = min(derived_max_model_len,
|
derived_max_model_len = min(derived_max_model_len,
|
||||||
sliding_window_len_min)
|
sliding_window_len_min)
|
||||||
|
|
||||||
|
# Consider model_max_length in tokenizer_config
|
||||||
|
if tokenizer_config:
|
||||||
|
tokenizer_model_max_length = tokenizer_config.get(
|
||||||
|
"model_max_length", derived_max_model_len)
|
||||||
|
derived_max_model_len = min(derived_max_model_len,
|
||||||
|
tokenizer_model_max_length)
|
||||||
|
|
||||||
# If none of the keys were found in the config, use a default and
|
# If none of the keys were found in the config, use a default and
|
||||||
# log a warning.
|
# log a warning.
|
||||||
if derived_max_model_len == float("inf"):
|
if derived_max_model_len == float("inf"):
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user