diff --git a/tests/test_config.py b/tests/test_config.py
index 715ef09dd3075..5d5c4453d30d2 100644
--- a/tests/test_config.py
+++ b/tests/test_config.py
@@ -438,3 +438,31 @@ def test_load_config_pt_load_map_location(pt_load_map_location):
     config = VllmConfig(load_config=load_config)
 
     assert config.load_config.pt_load_map_location == pt_load_map_location
+
+
+@pytest.mark.parametrize(
+    ("model_id", "max_model_len", "expected_max_len", "should_raise"), [
+        ("BAAI/bge-reranker-base", None, 512, False),
+        ("BAAI/bge-reranker-base", 256, 256, False),
+        ("BAAI/bge-reranker-base", 513, 512, True),
+    ])
+def test_get_and_verify_max_len(model_id, max_model_len, expected_max_len,
+                                should_raise):
+    """Test get_and_verify_max_len with different configurations."""
+    model_config = ModelConfig(
+        model_id,
+        task="auto",
+        tokenizer=model_id,
+        tokenizer_mode="auto",
+        trust_remote_code=False,
+        seed=0,
+        dtype="float16",
+        revision=None,
+    )
+
+    if should_raise:
+        with pytest.raises(ValueError):
+            model_config.get_and_verify_max_len(max_model_len)
+    else:
+        actual_max_len = model_config.get_and_verify_max_len(max_model_len)
+        assert actual_max_len == expected_max_len
diff --git a/vllm/config.py b/vllm/config.py
index 7217a659a5595..4c0c575ec3b55 100644
--- a/vllm/config.py
+++ b/vllm/config.py
@@ -1429,25 +1429,19 @@ class ModelConfig:
         return getattr(self.hf_config, "matryoshka_dimensions", None)
 
     def get_and_verify_max_len(self, max_model_len: int):
+        tokenizer_config = try_get_tokenizer_config(
+            self.tokenizer,
+            trust_remote_code=self.trust_remote_code,
+            revision=self.tokenizer_revision)
         max_model_len = _get_and_verify_max_len(
             hf_config=self.hf_text_config,
+            tokenizer_config=tokenizer_config,
             max_model_len=max_model_len,
             disable_sliding_window=self.disable_sliding_window,
             sliding_window_len=self.get_hf_config_sliding_window(),
             spec_target_max_model_len=self.spec_target_max_model_len,
             encoder_config=self.encoder_config)
-
-        tokenizer_config = try_get_tokenizer_config(
-            self.tokenizer,
-            trust_remote_code=self.trust_remote_code,
-            revision=self.tokenizer_revision)
-
-        if tokenizer_config is None:
-            return max_model_len
-
-        model_max_length = tokenizer_config.get("model_max_length",
-                                                max_model_len)
-        max_model_len = min(max_model_len, model_max_length)
+        logger.info("Using max model len %s", max_model_len)
         return max_model_len
 
 
@@ -3283,6 +3277,7 @@ def _get_and_verify_dtype(
 
 def _get_and_verify_max_len(
     hf_config: PretrainedConfig,
+    tokenizer_config: Optional[dict],
     max_model_len: Optional[int],
     disable_sliding_window: bool,
     sliding_window_len: Optional[Union[int, list[Optional[int]]]],
@@ -3309,7 +3304,7 @@ def _get_and_verify_max_len(
         "max_seq_length",
         "seq_len",
     ]
-    # Choose the smallest "max_length" from the possible keys.
+    # Choose the smallest "max_length" from the possible keys
     max_len_key = None
     for key in possible_keys:
         max_len = getattr(hf_config, key, None)
@@ -3332,6 +3327,13 @@
         derived_max_model_len = min(derived_max_model_len,
                                     sliding_window_len_min)
 
+    # Consider model_max_length in tokenizer_config
+    if tokenizer_config:
+        tokenizer_model_max_length = tokenizer_config.get(
+            "model_max_length", derived_max_model_len)
+        derived_max_model_len = min(derived_max_model_len,
+                                    tokenizer_model_max_length)
+
     # If none of the keys were found in the config, use a default and
     # log a warning.
     if derived_max_model_len == float("inf"):