mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-10 16:15:40 +08:00
Init model on GPU to reduce CPU memory footprint (#1796)
This commit is contained in:
parent
665cbcec4b
commit
a8b150c595
@@ -87,9 +87,9 @@ def get_model(model_config: ModelConfig) -> nn.Module:
|
|||||||
with _set_default_torch_dtype(model_config.dtype):
|
with _set_default_torch_dtype(model_config.dtype):
|
||||||
# Create a model instance.
|
# Create a model instance.
|
||||||
# The weights will be initialized as empty tensors.
|
# The weights will be initialized as empty tensors.
|
||||||
model = model_class(model_config.hf_config, linear_method)
|
with torch.device("cuda"):
|
||||||
|
model = model_class(model_config.hf_config, linear_method)
|
||||||
if model_config.load_format == "dummy":
|
if model_config.load_format == "dummy":
|
||||||
model = model.cuda()
|
|
||||||
# NOTE(woosuk): For accurate performance evaluation, we assign
|
# NOTE(woosuk): For accurate performance evaluation, we assign
|
||||||
# random values to the weights.
|
# random values to the weights.
|
||||||
initialize_dummy_weights(model)
|
initialize_dummy_weights(model)
|
||||||
@@ -97,5 +97,4 @@ def get_model(model_config: ModelConfig) -> nn.Module:
|
|||||||
# Load the weights from the cached or downloaded files.
|
# Load the weights from the cached or downloaded files.
|
||||||
model.load_weights(model_config.model, model_config.download_dir,
|
model.load_weights(model_config.model, model_config.download_dir,
|
||||||
model_config.load_format, model_config.revision)
|
model_config.load_format, model_config.revision)
|
||||||
model = model.cuda()
|
|
||||||
return model.eval()
|
return model.eval()
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user