Set torch default dtype in a context manager (#971)

Signed-off-by: Antoni Baum <antoni.baum@protonmail.com>
This commit is contained in:
Antoni Baum 2023-09-06 23:39:37 -07:00 committed by GitHub
parent 320a622ec4
commit 005ba458b5
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -1,4 +1,5 @@
"""Utilities for selecting and loading models.""" """Utilities for selecting and loading models."""
import contextlib
from typing import Type from typing import Type
import torch import torch
@ -30,6 +31,15 @@ _MODEL_REGISTRY = {
} }
@contextlib.contextmanager
def _set_default_torch_dtype(dtype: torch.dtype):
"""Sets the default torch dtype to the given dtype."""
old_dtype = torch.get_default_dtype()
torch.set_default_dtype(dtype)
yield
torch.set_default_dtype(old_dtype)
def _get_model_architecture(config: PretrainedConfig) -> Type[nn.Module]: def _get_model_architecture(config: PretrainedConfig) -> Type[nn.Module]:
architectures = getattr(config, "architectures", []) architectures = getattr(config, "architectures", [])
for arch in architectures: for arch in architectures:
@ -42,19 +52,18 @@ def _get_model_architecture(config: PretrainedConfig) -> Type[nn.Module]:
def get_model(model_config: ModelConfig) -> nn.Module: def get_model(model_config: ModelConfig) -> nn.Module:
model_class = _get_model_architecture(model_config.hf_config) model_class = _get_model_architecture(model_config.hf_config)
torch.set_default_dtype(model_config.dtype) with _set_default_torch_dtype(model_config.dtype):
# Create a model instance.
# Create a model instance. # The weights will be initialized as empty tensors.
# The weights will be initialized as empty tensors. model = model_class(model_config.hf_config)
model = model_class(model_config.hf_config) if model_config.use_dummy_weights:
if model_config.use_dummy_weights: model = model.cuda()
model = model.cuda() # NOTE(woosuk): For accurate performance evaluation, we assign
# NOTE(woosuk): For accurate performance evaluation, we assign # random values to the weights.
# random values to the weights. initialize_dummy_weights(model)
initialize_dummy_weights(model) else:
else: # Load the weights from the cached or downloaded files.
# Load the weights from the cached or downloaded files. model.load_weights(model_config.model, model_config.download_dir,
model.load_weights(model_config.model, model_config.download_dir, model_config.use_np_weights)
model_config.use_np_weights) model = model.cuda()
model = model.cuda()
return model.eval() return model.eval()