mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-17 16:15:36 +08:00
Transformers backend tweaks (#17365)
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
This commit is contained in:
parent
88ad9ec6b2
commit
900edfa8d4
@ -15,7 +15,6 @@
|
|||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
"""Wrapper around `transformers` models"""
|
"""Wrapper around `transformers` models"""
|
||||||
import re
|
import re
|
||||||
from itertools import chain
|
|
||||||
from typing import Iterable, Literal, Optional, Union
|
from typing import Iterable, Literal, Optional, Union
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
@ -166,12 +165,9 @@ class TransformersModel(nn.Module):
|
|||||||
# Initialize buffers (e.g. rotary embedding inverse frequency)
|
# Initialize buffers (e.g. rotary embedding inverse frequency)
|
||||||
self.init_buffers(self.model)
|
self.init_buffers(self.model)
|
||||||
|
|
||||||
# Initialize parameters
|
# Initialize any parameters that have not had their modules replaced
|
||||||
self.init_parameters(self.model)
|
self.init_parameters(self.model)
|
||||||
|
|
||||||
# Move remaining meta tensors to device (should happen last)
|
|
||||||
self.meta_to_empty(self.model)
|
|
||||||
|
|
||||||
self.make_empty_intermediate_tensors = (
|
self.make_empty_intermediate_tensors = (
|
||||||
make_empty_intermediate_tensors_factory(["hidden_states"],
|
make_empty_intermediate_tensors_factory(["hidden_states"],
|
||||||
config.hidden_size))
|
config.hidden_size))
|
||||||
@ -296,6 +292,14 @@ class TransformersModel(nn.Module):
|
|||||||
"""
|
"""
|
||||||
for name, buffer in module.named_buffers(recurse=False):
|
for name, buffer in module.named_buffers(recurse=False):
|
||||||
if buffer.device == torch.device("meta"):
|
if buffer.device == torch.device("meta"):
|
||||||
|
if module == self.model:
|
||||||
|
logger.warning(
|
||||||
|
"To initialize buffers correctly, we instantiate the "
|
||||||
|
"parent module and extract the value of the "
|
||||||
|
"buffer from it. In this case, the parent module is "
|
||||||
|
"the base model. Instantiating the entire model here "
|
||||||
|
"risks GPU OOM. Could this buffer be moved to a child "
|
||||||
|
"module?")
|
||||||
new_buffer = getattr(type(module)(self.config), name)
|
new_buffer = getattr(type(module)(self.config), name)
|
||||||
setattr(module, name, new_buffer)
|
setattr(module, name, new_buffer)
|
||||||
for child in module.children():
|
for child in module.children():
|
||||||
@ -320,14 +324,6 @@ class TransformersModel(nn.Module):
|
|||||||
for child in module.children():
|
for child in module.children():
|
||||||
self.init_parameters(child)
|
self.init_parameters(child)
|
||||||
|
|
||||||
def meta_to_empty(self, module: nn.Module):
|
|
||||||
tensors = list(chain(module.buffers(), module.parameters()))
|
|
||||||
if tensors and all(t.device == torch.device("meta") for t in tensors):
|
|
||||||
module.to_empty(device=self.device_config.device)
|
|
||||||
return # We can stop recursing because to_empty is recursive
|
|
||||||
for child in module.children():
|
|
||||||
self.meta_to_empty(child)
|
|
||||||
|
|
||||||
def get_input_embeddings(self) -> nn.Module:
|
def get_input_embeddings(self) -> nn.Module:
|
||||||
return self.model.get_input_embeddings()
|
return self.model.get_input_embeddings()
|
||||||
|
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user