mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-01-07 17:23:09 +08:00
[Model] use AutoWeightsLoader for bloom (#18300)
Signed-off-by: calvin chen <120380290@qq.com>
This commit is contained in:
parent
f4a8a37465
commit
e1f5a71ed7
@ -43,7 +43,7 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata
|
||||
from vllm.sequence import IntermediateTensors
|
||||
|
||||
from .interfaces import SupportsPP, SupportsQuant, SupportsV0Only
|
||||
from .utils import (is_pp_missing_parameter,
|
||||
from .utils import (AutoWeightsLoader, is_pp_missing_parameter,
|
||||
make_empty_intermediate_tensors_factory, make_layers,
|
||||
maybe_prefix)
|
||||
|
||||
@ -229,6 +229,7 @@ class BloomModel(nn.Module):
|
||||
config = vllm_config.model_config.hf_config
|
||||
cache_config = vllm_config.cache_config
|
||||
quant_config = vllm_config.quant_config
|
||||
self.config = config
|
||||
|
||||
self.embed_dim = config.hidden_size
|
||||
|
||||
@ -278,6 +279,38 @@ class BloomModel(nn.Module):
|
||||
hidden_states = self.ln_f(hidden_states)
|
||||
return hidden_states
|
||||
|
||||
def load_weights(self, weights: Iterable[tuple[str,
|
||||
torch.Tensor]]) -> set[str]:
|
||||
params_dict = dict(self.named_parameters(remove_duplicate=False))
|
||||
loaded_params: set[str] = set()
|
||||
for name, loaded_weight in weights:
|
||||
if is_pp_missing_parameter(name, self):
|
||||
continue
|
||||
param = params_dict[name]
|
||||
|
||||
if "query_key_value" in name:
|
||||
# NOTE: BLOOM's fused QKV's output_dim has the shape of
|
||||
# (num_heads * 3 * head_size), while the
|
||||
# required shape is (3 * num_heads * head_size).
|
||||
# Thus, we need weight conversion.
|
||||
output_dim = getattr(param, "output_dim", None)
|
||||
num_heads = self.config.num_attention_heads
|
||||
if output_dim is not None:
|
||||
loaded_weight_shape = loaded_weight.shape
|
||||
loaded_weight = loaded_weight.view(
|
||||
loaded_weight_shape[:output_dim] + (num_heads, 3, -1) +
|
||||
loaded_weight_shape[output_dim + 1:])
|
||||
loaded_weight = loaded_weight.transpose(
|
||||
output_dim, output_dim + 1)
|
||||
loaded_weight = loaded_weight.reshape(loaded_weight_shape)
|
||||
|
||||
weight_loader = getattr(param, "weight_loader",
|
||||
default_weight_loader)
|
||||
weight_loader(param, loaded_weight)
|
||||
loaded_params.add(name)
|
||||
|
||||
return loaded_params
|
||||
|
||||
|
||||
class BloomForCausalLM(nn.Module, SupportsPP, SupportsV0Only, SupportsQuant):
|
||||
|
||||
@ -325,35 +358,15 @@ class BloomForCausalLM(nn.Module, SupportsPP, SupportsV0Only, SupportsQuant):
|
||||
|
||||
def load_weights(self, weights: Iterable[tuple[str,
|
||||
torch.Tensor]]) -> set[str]:
|
||||
params_dict = dict(self.named_parameters(remove_duplicate=False))
|
||||
loaded_params: set[str] = set()
|
||||
for name, loaded_weight in weights:
|
||||
if name == "lm_head.weight":
|
||||
continue
|
||||
if not name.startswith("transformer."):
|
||||
name = "transformer." + name
|
||||
if is_pp_missing_parameter(name, self):
|
||||
continue
|
||||
param = params_dict[name]
|
||||
loader = AutoWeightsLoader(self, skip_prefixes=["lm_head.weight"])
|
||||
weights = _add_transformer_prefix(weights)
|
||||
return loader.load_weights(weights)
|
||||
|
||||
if "query_key_value" in name:
|
||||
# NOTE: BLOOM's fused QKV's output_dim has the shape of
|
||||
# (num_heads * 3 * head_size), while the
|
||||
# required shape is (3 * num_heads * head_size).
|
||||
# Thus, we need weight conversion.
|
||||
output_dim = getattr(param, "output_dim", None)
|
||||
num_heads = self.config.num_attention_heads
|
||||
if output_dim is not None:
|
||||
loaded_weight_shape = loaded_weight.shape
|
||||
loaded_weight = loaded_weight.view(
|
||||
loaded_weight_shape[:output_dim] + (num_heads, 3, -1) +
|
||||
loaded_weight_shape[output_dim + 1:])
|
||||
loaded_weight = loaded_weight.transpose(
|
||||
output_dim, output_dim + 1)
|
||||
loaded_weight = loaded_weight.reshape(loaded_weight_shape)
|
||||
|
||||
weight_loader = getattr(param, "weight_loader",
|
||||
default_weight_loader)
|
||||
weight_loader(param, loaded_weight)
|
||||
loaded_params.add(name)
|
||||
return loaded_params
|
||||
def _add_transformer_prefix(
|
||||
weights: Iterable[tuple[str, torch.Tensor]]
|
||||
) -> Iterable[tuple[str, torch.Tensor]]:
|
||||
for name, tensor in weights:
|
||||
if not name.startswith('transformer.'):
|
||||
name = 'transformer.' + name
|
||||
yield name, tensor
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user