Support loading transformers models with named parameters (#16868)
Signed-off-by: Alex <alexwu@character.ai>
commit 6e74fd4945
parent dcbac4cb4b
@@ -166,6 +166,9 @@ class TransformersModel(nn.Module):
         # Initialize buffers (e.g. rotary embedding inverse frequency)
         self.init_buffers(self.model)
 
+        # Initialize parameters
+        self.init_parameters(self.model)
+
         # Move remaining meta tensors to device (should happen last)
         self.meta_to_empty(self.model)
 
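The three calls above run in a deliberate order: buffers first, named parameters next, and any tensors still left on the meta device last. As a minimal sketch of the underlying PyTorch pattern (not vLLM code), `torch.nn.Module.to_empty` performs the same meta-to-device materialization that the final `meta_to_empty` step relies on:

```python
import torch
import torch.nn as nn

# Constructing under torch.device("meta") allocates no real storage;
# to_empty then swaps every meta tensor for an uninitialized tensor
# on the target device. Values are garbage until weights are loaded.
with torch.device("meta"):
    layer = nn.Linear(1024, 1024)

layer = layer.to_empty(device="cpu")
print(layer.weight.device)  # cpu
```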
@@ -298,6 +301,25 @@ class TransformersModel(nn.Module):
         for child in module.children():
             self.init_buffers(child)
 
+    def init_parameters(self, module: nn.Module):
+        """
+        If a `parameter` is on the `meta` device, then its parent
+        `module` is the original module created by:
+
+        ```python
+        with torch.device("meta"):
+            self.model: PreTrainedModel = AutoModel.from_config(...)
+        ```
+        """
+        for name, param in module.named_parameters(recurse=False):
+            if param.device == torch.device("meta"):
+                new_param = nn.Parameter(
+                    torch.empty_like(param.data,
+                                     device=self.device_config.device))
+                setattr(module, name, new_param)
+        for child in module.children():
+            self.init_parameters(child)
+
     def meta_to_empty(self, module: nn.Module):
         tensors = list(chain(module.buffers(), module.parameters()))
         if tensors and all(t.device == torch.device("meta") for t in tensors):
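The recursion added here can be exercised outside vLLM. Below is a minimal, self-contained sketch of the same technique; `materialize_named_parameters` is a hypothetical stand-in for `init_parameters` that takes the target device as an argument rather than reading it from `self.device_config`:

```python
import torch
import torch.nn as nn


def materialize_named_parameters(module: nn.Module,
                                 device: torch.device) -> None:
    # Replace each meta-device parameter with an empty tensor of the
    # same shape and dtype on the target device, then recurse into
    # child modules, mirroring the diff above.
    for name, param in module.named_parameters(recurse=False):
        if param.device == torch.device("meta"):
            new_param = nn.Parameter(
                torch.empty_like(param.data, device=device),
                requires_grad=param.requires_grad)
            setattr(module, name, new_param)
    for child in module.children():
        materialize_named_parameters(child, device)


with torch.device("meta"):
    model = nn.Sequential(nn.Linear(4, 4), nn.ReLU(), nn.Linear(4, 2))

materialize_named_parameters(model, torch.device("cpu"))
assert all(p.device.type == "cpu" for p in model.parameters())
```

Note that `recurse=False` plus explicit recursion into children (rather than a single `named_parameters()` walk) is what lets `setattr` replace each parameter on its direct parent module.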
@@ -342,6 +364,7 @@ class TransformersModel(nn.Module):
     def load_weights(self, weights: Iterable[tuple[str,
                                                    torch.Tensor]]) -> set[str]:
         params_dict = dict(self.named_parameters())
+
         loaded_params = set[str]()
         for name, loaded_weight in weights:
             # Use "model" instead of base_model_prefix because
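The loop that follows matches checkpoint names against `params_dict` and copies tensors in place. Here is a hedged, plain-PyTorch sketch of that idea; `load_named_weights` is a hypothetical helper and omits the prefix remapping that the real `load_weights` performs (the truncated comment above concerns that remapping):

```python
from typing import Iterable

import torch
import torch.nn as nn


def load_named_weights(model: nn.Module,
                       weights: Iterable[tuple[str, torch.Tensor]]
                       ) -> set[str]:
    # Match each checkpoint name against the model's parameters and
    # copy the tensor in place, returning the names actually loaded.
    params_dict = dict(model.named_parameters())
    loaded_params = set[str]()
    for name, loaded_weight in weights:
        param = params_dict.get(name)
        if param is None:
            continue  # unmatched name, e.g. one that needs remapping
        with torch.no_grad():
            param.copy_(loaded_weight)
        loaded_params.add(name)
    return loaded_params


model = nn.Linear(2, 2)
ckpt = [("weight", torch.ones(2, 2)), ("bias", torch.zeros(2))]
print(load_named_weights(model, ckpt))  # {'weight', 'bias'}
```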