Transformers backend tweaks (#17365)

Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
This commit is contained in:
Harry Mellor 2025-04-29 17:08:03 +01:00 committed by GitHub
parent 88ad9ec6b2
commit 900edfa8d4
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -15,7 +15,6 @@
# limitations under the License.
"""Wrapper around `transformers` models"""
import re
from itertools import chain
from typing import Iterable, Literal, Optional, Union
import torch
@ -166,12 +165,9 @@ class TransformersModel(nn.Module):
# Initialize buffers (e.g. rotary embedding inverse frequency)
self.init_buffers(self.model)
# Initialize parameters
# Initialize any parameters that have not had their modules replaced
self.init_parameters(self.model)
# Move remaining meta tensors to device (should happen last)
self.meta_to_empty(self.model)
self.make_empty_intermediate_tensors = (
make_empty_intermediate_tensors_factory(["hidden_states"],
config.hidden_size))
@ -296,6 +292,14 @@ class TransformersModel(nn.Module):
"""
for name, buffer in module.named_buffers(recurse=False):
if buffer.device == torch.device("meta"):
if module == self.model:
logger.warning(
"To initialize buffers correctly, we instantiate the "
"parent module and extract the value of the "
"buffer from it. In this case, the parent module is "
"the base model. Instantiating the entire model here "
"risks GPU OOM. Could this buffer be moved to a child "
"module?")
new_buffer = getattr(type(module)(self.config), name)
setattr(module, name, new_buffer)
for child in module.children():
@ -320,14 +324,6 @@ class TransformersModel(nn.Module):
for child in module.children():
self.init_parameters(child)
def meta_to_empty(self, module: nn.Module):
tensors = list(chain(module.buffers(), module.parameters()))
if tensors and all(t.device == torch.device("meta") for t in tensors):
module.to_empty(device=self.device_config.device)
return # We can stop recursing because to_empty is recursive
for child in module.children():
self.meta_to_empty(child)
def get_input_embeddings(self) -> nn.Module:
    """Return the input embedding layer of the wrapped HF model."""
    embedding_layer = self.model.get_input_embeddings()
    return embedding_layer