diff --git a/vllm/config.py b/vllm/config.py index 3bac36fcbbea..40beace3040c 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -2954,10 +2954,12 @@ def _get_and_verify_dtype( ) -> torch.dtype: # NOTE: getattr(config, "torch_dtype", torch.float32) is not correct # because config.torch_dtype can be None. - config_dtype = getattr(config.get_text_config(), "torch_dtype", None) + config_dtype = getattr(config, "torch_dtype", None) - # Fallback for multi-modal models if the root config + # Fallbacks for multi-modal models if the root config # does not define torch_dtype + if config_dtype is None: + config_dtype = getattr(config.get_text_config(), "torch_dtype", None) if config_dtype is None and hasattr(config, "vision_config"): config_dtype = getattr(config.vision_config, "torch_dtype", None)