diff --git a/vllm/config.py b/vllm/config.py index ddaff0710a3b8..d475cdbcb1c7c 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -532,16 +532,12 @@ class ModelConfig: self.config_format = ConfigFormat(self.config_format) hf_config = get_config(self.hf_config_path or self.model, - self.trust_remote_code, self.revision, - self.code_revision, self.config_format) - - if hf_overrides_kw: - logger.debug("Overriding HF config with %s", hf_overrides_kw) - hf_config.update(hf_overrides_kw) - if hf_overrides_fn: - logger.debug("Overriding HF config with %s", hf_overrides_fn) - hf_config = hf_overrides_fn(hf_config) - + self.trust_remote_code, + self.revision, + self.code_revision, + self.config_format, + hf_overrides_kw=hf_overrides_kw, + hf_overrides_fn=hf_overrides_fn) self.hf_config = hf_config self.hf_text_config = get_hf_text_config(self.hf_config) @@ -5052,4 +5048,4 @@ class SpeechToTextConfig: @property def allow_audio_chunking(self) -> bool: - return self.min_energy_split_window_size is not None \ No newline at end of file + return self.min_energy_split_window_size is not None diff --git a/vllm/transformers_utils/config.py b/vllm/transformers_utils/config.py index 411c970b2f0d8..cf3f519b027ca 100644 --- a/vllm/transformers_utils/config.py +++ b/vllm/transformers_utils/config.py @@ -305,6 +305,9 @@ def get_config( revision: Optional[str] = None, code_revision: Optional[str] = None, config_format: ConfigFormat = ConfigFormat.AUTO, + hf_overrides_kw: Optional[dict[str, Any]] = None, + hf_overrides_fn: Optional[Callable[[PretrainedConfig], + PretrainedConfig]] = None, **kwargs, ) -> PretrainedConfig: # Separate model folder from file path for GGUF models @@ -423,6 +426,13 @@ def get_config( model_type = MODEL_FOR_CAUSAL_LM_MAPPING_NAMES[config.model_type] config.update({"architectures": [model_type]}) + if hf_overrides_kw: + logger.debug("Overriding HF config with %s", hf_overrides_kw) + config.update(hf_overrides_kw) + if hf_overrides_fn: + logger.debug("Overriding HF config with %s", hf_overrides_fn) + config = hf_overrides_fn(config) + patch_rope_scaling(config) if trust_remote_code: