diff --git a/vllm/model_executor/models/transformers.py b/vllm/model_executor/models/transformers.py
index ad7c07dc8cd2..7b946ad6aac7 100644
--- a/vllm/model_executor/models/transformers.py
+++ b/vllm/model_executor/models/transformers.py
@@ -15,7 +15,6 @@
 # limitations under the License.
 """Wrapper around `transformers` models"""
 import re
-from itertools import chain
 from typing import Iterable, Literal, Optional, Union
 
 import torch
@@ -166,12 +165,9 @@ class TransformersModel(nn.Module):
         # Initialize buffers (e.g. rotary embedding inverse frequency)
         self.init_buffers(self.model)
 
-        # Initialize parameters
+        # Initialize any parameters that have not had their modules replaced
         self.init_parameters(self.model)
 
-        # Move remaining meta tensors to device (should happen last)
-        self.meta_to_empty(self.model)
-
         self.make_empty_intermediate_tensors = (
             make_empty_intermediate_tensors_factory(["hidden_states"],
                                                     config.hidden_size))
@@ -296,6 +292,14 @@ class TransformersModel(nn.Module):
         """
         for name, buffer in module.named_buffers(recurse=False):
             if buffer.device == torch.device("meta"):
+                if module == self.model:
+                    logger.warning(
+                        "To initialize buffers correctly, we instantiate the "
+                        "parent module and extract the value of the "
+                        "buffer from it. In this case, the parent module is "
+                        "the base model. Instantiating the entire model here "
+                        "risks GPU OOM. Could this buffer be moved to a child "
+                        "module?")
                 new_buffer = getattr(type(module)(self.config), name)
                 setattr(module, name, new_buffer)
         for child in module.children():
@@ -320,14 +324,6 @@ class TransformersModel(nn.Module):
         for child in module.children():
             self.init_parameters(child)
 
-    def meta_to_empty(self, module: nn.Module):
-        tensors = list(chain(module.buffers(), module.parameters()))
-        if tensors and all(t.device == torch.device("meta") for t in tensors):
-            module.to_empty(device=self.device_config.device)
-            return  # We can stop recursing because to_empty is recursive
-        for child in module.children():
-            self.meta_to_empty(child)
-
     def get_input_embeddings(self) -> nn.Module:
         return self.model.get_input_embeddings()
 
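
Note on the two behavioral changes above, with a minimal self-contained sketch. The `MyBlock` module and its `inv_freq` buffer below are hypothetical stand-ins for illustration, not vLLM code.

First, the `init_buffers` pattern: a buffer left on the meta device has lost its value (buffers are often computed in `__init__` rather than loaded from checkpoints), so the code re-instantiates the buffer's owning module on a real device and copies the freshly computed value across. This is also why hitting this path on the base model itself risks OOM, as the new warning says: the "owning module" would be the entire model.

import torch
from torch import nn

class MyBlock(nn.Module):
    def __init__(self, config=None):
        super().__init__()
        # The buffer's value is computed in __init__, so a meta-device
        # copy carries no recoverable data.
        self.register_buffer("inv_freq", 1.0 / torch.arange(1, 5).float())

with torch.device("meta"):
    block = MyBlock()  # inv_freq lives on the meta device: value is lost

# Re-instantiate the owning module on a real device and copy the buffer
# over, mirroring the `type(module)(self.config)` trick in init_buffers.
if block.inv_freq.device == torch.device("meta"):
    setattr(block, "inv_freq", getattr(MyBlock(), "inv_freq"))

print(block.inv_freq)  # tensor([1.0000, 0.5000, 0.3333, 0.2500])

Second, the removal of `meta_to_empty`: `nn.Module.to_empty()` only allocates storage for meta tensors, it does not restore their values, so any computed buffer it materializes comes back as uninitialized memory. Reusing the hypothetical module above:

with torch.device("meta"):
    broken = MyBlock()

broken.to_empty(device="cpu")  # allocates real storage, restores nothing
print(broken.inv_freq)         # arbitrary uninitialized data, not 1/1..1/4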