[Model] use AutoWeightsLoader for bloom (#18300)

Signed-off-by: calvin chen <120380290@qq.com>
Calvin Chen 2025-05-21 00:40:05 +08:00 committed by GitHub
parent f4a8a37465
commit e1f5a71ed7

vllm/model_executor/models/bloom.py

@@ -43,7 +43,7 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.sequence import IntermediateTensors
 from .interfaces import SupportsPP, SupportsQuant, SupportsV0Only
-from .utils import (is_pp_missing_parameter,
+from .utils import (AutoWeightsLoader, is_pp_missing_parameter,
                     make_empty_intermediate_tensors_factory, make_layers,
                     maybe_prefix)
@@ -229,6 +229,7 @@ class BloomModel(nn.Module):
         config = vllm_config.model_config.hf_config
         cache_config = vllm_config.cache_config
         quant_config = vllm_config.quant_config
+        self.config = config
         self.embed_dim = config.hidden_size
@@ -278,6 +279,38 @@ class BloomModel(nn.Module):
         hidden_states = self.ln_f(hidden_states)
         return hidden_states
 
+    def load_weights(self, weights: Iterable[tuple[str,
+                                                   torch.Tensor]]) -> set[str]:
+        params_dict = dict(self.named_parameters(remove_duplicate=False))
+        loaded_params: set[str] = set()
+        for name, loaded_weight in weights:
+            if is_pp_missing_parameter(name, self):
+                continue
+            param = params_dict[name]
+
+            if "query_key_value" in name:
+                # NOTE: BLOOM's fused QKV's output_dim has the shape of
+                # (num_heads * 3 * head_size), while the
+                # required shape is (3 * num_heads * head_size).
+                # Thus, we need weight conversion.
+                output_dim = getattr(param, "output_dim", None)
+                num_heads = self.config.num_attention_heads
+                if output_dim is not None:
+                    loaded_weight_shape = loaded_weight.shape
+                    loaded_weight = loaded_weight.view(
+                        loaded_weight_shape[:output_dim] + (num_heads, 3, -1) +
+                        loaded_weight_shape[output_dim + 1:])
+                    loaded_weight = loaded_weight.transpose(
+                        output_dim, output_dim + 1)
+                    loaded_weight = loaded_weight.reshape(loaded_weight_shape)
+
+            weight_loader = getattr(param, "weight_loader",
+                                    default_weight_loader)
+            weight_loader(param, loaded_weight)
+            loaded_params.add(name)
+        return loaded_params
+
 
 class BloomForCausalLM(nn.Module, SupportsPP, SupportsV0Only, SupportsQuant):
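
Illustration (not part of the diff): a minimal sketch of the layout conversion described by the NOTE in the new load_weights. A BLOOM checkpoint stores the fused QKV rows head-major (q, k, v interleaved per head), while the loader above expects them projection-major (all q heads, then all k heads, then all v heads); the view/transpose/reshape performs exactly that regrouping. The toy tensor and variable names below are hypothetical.

import torch

num_heads, head_size, hidden = 2, 3, 6
# Toy checkpoint weight: rows laid out head-major -> [h0_q, h0_k, h0_v, h1_q, h1_k, h1_v]
w = torch.arange(num_heads * 3 * head_size * hidden,
                 dtype=torch.float32).view(num_heads * 3 * head_size, hidden)

output_dim = 0  # the row axis is the layer's output dimension
shape = w.shape
converted = w.view(shape[:output_dim] + (num_heads, 3, -1) + shape[output_dim + 1:])
converted = converted.transpose(output_dim, output_dim + 1)  # (3, num_heads, head_size, hidden)
converted = converted.reshape(shape)  # rows now projection-major

# h0_q stays first; h1_q (previously at row 3 * head_size) now follows it directly.
assert torch.equal(converted[:head_size], w[:head_size])
assert torch.equal(converted[head_size:2 * head_size],
                   w[3 * head_size:4 * head_size])
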
@@ -325,35 +358,15 @@ class BloomForCausalLM(nn.Module, SupportsPP, SupportsV0Only, SupportsQuant):
 
     def load_weights(self, weights: Iterable[tuple[str,
                                                    torch.Tensor]]) -> set[str]:
-        params_dict = dict(self.named_parameters(remove_duplicate=False))
-        loaded_params: set[str] = set()
-        for name, loaded_weight in weights:
-            if name == "lm_head.weight":
-                continue
-            if not name.startswith("transformer."):
-                name = "transformer." + name
-            if is_pp_missing_parameter(name, self):
-                continue
-            param = params_dict[name]
-
-            if "query_key_value" in name:
-                # NOTE: BLOOM's fused QKV's output_dim has the shape of
-                # (num_heads * 3 * head_size), while the
-                # required shape is (3 * num_heads * head_size).
-                # Thus, we need weight conversion.
-                output_dim = getattr(param, "output_dim", None)
-                num_heads = self.config.num_attention_heads
-                if output_dim is not None:
-                    loaded_weight_shape = loaded_weight.shape
-                    loaded_weight = loaded_weight.view(
-                        loaded_weight_shape[:output_dim] + (num_heads, 3, -1) +
-                        loaded_weight_shape[output_dim + 1:])
-                    loaded_weight = loaded_weight.transpose(
-                        output_dim, output_dim + 1)
-                    loaded_weight = loaded_weight.reshape(loaded_weight_shape)
-
-            weight_loader = getattr(param, "weight_loader",
-                                    default_weight_loader)
-            weight_loader(param, loaded_weight)
-            loaded_params.add(name)
-        return loaded_params
+        loader = AutoWeightsLoader(self, skip_prefixes=["lm_head.weight"])
+        weights = _add_transformer_prefix(weights)
+        return loader.load_weights(weights)
+
+
+def _add_transformer_prefix(
+        weights: Iterable[tuple[str, torch.Tensor]]
+) -> Iterable[tuple[str, torch.Tensor]]:
+    for name, tensor in weights:
+        if not name.startswith('transformer.'):
+            name = 'transformer.' + name
+        yield name, tensor
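
Illustration (not part of the diff): a hypothetical usage sketch of the new loading path. HF BLOOM checkpoints do not always carry the "transformer." prefix used by the vLLM module hierarchy, so _add_transformer_prefix rewrites the names before AutoWeightsLoader sees the stream; the loader then dispatches each tensor to the matching submodule and skips lm_head.weight, just as the removed loop did. The checkpoint entries below are made up for demonstration.

import torch
from collections.abc import Iterable

def _add_transformer_prefix(
        weights: Iterable[tuple[str, torch.Tensor]]
) -> Iterable[tuple[str, torch.Tensor]]:
    # Same renaming as in the diff, repeated here so the snippet runs standalone.
    for name, tensor in weights:
        if not name.startswith('transformer.'):
            name = 'transformer.' + name
        yield name, tensor

checkpoint = [
    ("word_embeddings.weight", torch.zeros(4, 2)),
    ("h.0.self_attention.query_key_value.weight", torch.zeros(6, 2)),
    ("transformer.ln_f.weight", torch.zeros(2)),  # already prefixed, left untouched
]
print([name for name, _ in _add_transformer_prefix(checkpoint)])
# ['transformer.word_embeddings.weight',
#  'transformer.h.0.self_attention.query_key_value.weight',
#  'transformer.ln_f.weight']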