Transformers backend tweaks (#17365)
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
parent 88ad9ec6b2
commit 900edfa8d4
@@ -15,7 +15,6 @@
 # limitations under the License.
 """Wrapper around `transformers` models"""
 import re
-from itertools import chain
 from typing import Iterable, Literal, Optional, Union

 import torch
@@ -166,12 +165,9 @@ class TransformersModel(nn.Module):
         # Initialize buffers (e.g. rotary embedding inverse frequency)
         self.init_buffers(self.model)

-        # Initialize parameters
+        # Initialize any parameters that have not had their modules replaced
         self.init_parameters(self.model)

-        # Move remaining meta tensors to device (should happen last)
-        self.meta_to_empty(self.model)
-
         self.make_empty_intermediate_tensors = (
             make_empty_intermediate_tensors_factory(["hidden_states"],
                                                     config.hidden_size))
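The hunk above drops the catch-all meta_to_empty pass from the constructor. For background, the model here is built on PyTorch's meta device, where tensors carry shape and dtype but no storage, so every buffer and parameter must be materialized before first use. A minimal sketch in plain PyTorch (not vLLM code) of that mechanism:

    import torch
    import torch.nn as nn

    # Modules built under the meta device are shells: shape/dtype only.
    with torch.device("meta"):
        layer = nn.Linear(8, 8)

    print(layer.weight.device)   # meta
    # layer(torch.randn(1, 8))   # would fail: meta tensors hold no data

    # to_empty() allocates uninitialized storage on a real device; values
    # must still be filled in by init code or a weight load afterwards.
    layer.to_empty(device="cpu")
    print(layer.weight.device)   # cpu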
@@ -296,6 +292,14 @@ class TransformersModel(nn.Module):
         """
         for name, buffer in module.named_buffers(recurse=False):
             if buffer.device == torch.device("meta"):
+                if module == self.model:
+                    logger.warning(
+                        "To initialize buffers correctly, we instantiate the "
+                        "parent module and extract the value of the "
+                        "buffer from it. In this case, the parent module is "
+                        "the base model. Instantiating the entire model here "
+                        "risks GPU OOM. Could this buffer be moved to a child "
+                        "module?")
                 new_buffer = getattr(type(module)(self.config), name)
                 setattr(module, name, new_buffer)
         for child in module.children():
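The warning added in this hunk concerns how init_buffers recovers a buffer's value: a meta buffer holds no data, so a fresh instance of the owning module is constructed and the buffer is read from it (the `getattr(type(module)(self.config), name)` line). A hedged sketch of that pattern; RotaryStub is a made-up stand-in, not a vLLM or Transformers class:

    import torch
    import torch.nn as nn

    class RotaryStub(nn.Module):
        # Made-up stand-in for a module whose __init__ derives a buffer
        # (e.g. rotary inverse frequencies).
        def __init__(self, dim: int = 8):
            super().__init__()
            inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim))
            self.register_buffer("inv_freq", inv_freq)

    with torch.device("meta"):
        rotary = RotaryStub()        # inv_freq exists, but holds no data

    # Rebuild the owning module on a real device and lift the buffer out.
    # Cheap when the owner is a small child module; if the owner is the
    # whole base model, this re-instantiates everything, hence the warning.
    new_buffer = getattr(type(rotary)(), "inv_freq")
    setattr(rotary, "inv_freq", new_buffer)
    print(rotary.inv_freq.device)    # cpu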
@@ -320,14 +324,6 @@ class TransformersModel(nn.Module):
         for child in module.children():
             self.init_parameters(child)

-    def meta_to_empty(self, module: nn.Module):
-        tensors = list(chain(module.buffers(), module.parameters()))
-        if tensors and all(t.device == torch.device("meta") for t in tensors):
-            module.to_empty(device=self.device_config.device)
-            return  # We can stop recursing because to_empty is recursive
-        for child in module.children():
-            self.meta_to_empty(child)
-
     def get_input_embeddings(self) -> nn.Module:
         return self.model.get_input_embeddings()

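The deleted meta_to_empty helper could stop recursing as soon as it called to_empty because nn.Module.to_empty already walks the entire subtree. A minimal sketch in plain PyTorch (not vLLM code) of that property:

    import torch
    import torch.nn as nn

    with torch.device("meta"):
        model = nn.Sequential(nn.Linear(4, 4), nn.Linear(4, 4))

    # One call materializes every descendant's parameters and buffers.
    model.to_empty(device="cpu")
    assert all(p.device.type == "cpu" for p in model.parameters())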
||||