Init model on GPU to reduce CPU memory footprint (#1796)

This commit is contained in:
ljss 2023-11-28 03:18:26 +08:00 committed by GitHub
parent 665cbcec4b
commit a8b150c595
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -87,9 +87,9 @@ def get_model(model_config: ModelConfig) -> nn.Module:
with _set_default_torch_dtype(model_config.dtype):
# Create a model instance.
# The weights will be initialized as empty tensors.
model = model_class(model_config.hf_config, linear_method)
with torch.device("cuda"):
model = model_class(model_config.hf_config, linear_method)
if model_config.load_format == "dummy":
model = model.cuda()
# NOTE(woosuk): For accurate performance evaluation, we assign
# random values to the weights.
initialize_dummy_weights(model)
@ -97,5 +97,4 @@ def get_model(model_config: ModelConfig) -> nn.Module:
# Load the weights from the cached or downloaded files.
model.load_weights(model_config.model, model_config.download_dir,
model_config.load_format, model_config.revision)
model = model.cuda()
return model.eval()