diff --git a/vllm/model_executor/model_loader.py b/vllm/model_executor/model_loader.py
index 2652cc067549c..4cddf8c360ee6 100644
--- a/vllm/model_executor/model_loader.py
+++ b/vllm/model_executor/model_loader.py
@@ -1,4 +1,5 @@
 """Utilities for selecting and loading models."""
+import contextlib
 from typing import Type
 
 import torch
@@ -30,6 +31,26 @@ _MODEL_REGISTRY = {
 }
 
 
+@contextlib.contextmanager
+def _set_default_torch_dtype(dtype: torch.dtype):
+    """Temporarily set the default torch dtype to the given dtype.
+
+    The previous default dtype is restored when the ``with`` block exits,
+    including when the block exits because of an exception.
+
+    Args:
+        dtype: The dtype to install as the torch default inside the block.
+    """
+    old_dtype = torch.get_default_dtype()
+    torch.set_default_dtype(dtype)
+    try:
+        yield
+    finally:
+        # Restore even if the `with` body raises, so a failed model load
+        # does not leave `dtype` installed as the default afterwards.
+        torch.set_default_dtype(old_dtype)
+
+
 def _get_model_architecture(config: PretrainedConfig) -> Type[nn.Module]:
     architectures = getattr(config, "architectures", [])
     for arch in architectures:
@@ -42,19 +63,19 @@ def _get_model_architecture(config: PretrainedConfig) -> Type[nn.Module]:
 
 def get_model(model_config: ModelConfig) -> nn.Module:
+    """Instantiate the configured model on GPU and return it in eval mode."""
     model_class = _get_model_architecture(model_config.hf_config)
-    torch.set_default_dtype(model_config.dtype)
-
-    # Create a model instance.
-    # The weights will be initialized as empty tensors.
-    model = model_class(model_config.hf_config)
-    if model_config.use_dummy_weights:
-        model = model.cuda()
-        # NOTE(woosuk): For accurate performance evaluation, we assign
-        # random values to the weights.
-        initialize_dummy_weights(model)
-    else:
-        # Load the weights from the cached or downloaded files.
-        model.load_weights(model_config.model, model_config.download_dir,
-                           model_config.use_np_weights)
-        model = model.cuda()
+    with _set_default_torch_dtype(model_config.dtype):
+        # Create a model instance.
+        # The weights will be initialized as empty tensors.
+        model = model_class(model_config.hf_config)
+        if model_config.use_dummy_weights:
+            model = model.cuda()
+            # NOTE(woosuk): For accurate performance evaluation, we assign
+            # random values to the weights.
+            initialize_dummy_weights(model)
+        else:
+            # Load the weights from the cached or downloaded files.
+            model.load_weights(model_config.model, model_config.download_dir,
+                               model_config.use_np_weights)
+            model = model.cuda()
     return model.eval()