commit 51ba839555 (parent d1fb65bde3)

[Model] use AutoWeightsLoader for bart (#18299)

Signed-off-by: calvin chen <120380290@qq.com>
@@ -46,7 +46,7 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.sequence import IntermediateTensors
 
 from .interfaces import SupportsQuant, SupportsV0Only
-from .utils import maybe_prefix
+from .utils import AutoWeightsLoader, WeightsMapper, maybe_prefix
 
 logger = logging.get_logger(__name__)
 
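Note: AutoWeightsLoader walks the module tree and dispatches each checkpoint tensor to the submodule that owns it, while WeightsMapper renames Hugging Face checkpoint keys before loading. As a rough sketch of the renaming semantics implied by the hf_to_vllm_mapper config added later in this diff (illustrative only, not vLLM's WeightsMapper implementation; the helper name rename_hf_key is made up):

    # Illustrative sketch of the key renaming implied by hf_to_vllm_mapper.
    def rename_hf_key(key: str) -> str:
        prefix_map = {
            "decoder.": "model.decoder.",
            "encoder.": "model.encoder.",
            "shared.": "model.shared.",
        }
        substr_map = {"beta": "bias", "gamma": "weight", "LayerNorm": "layernorm"}
        for old, new in prefix_map.items():
            if key.startswith(old):
                key = new + key[len(old):]  # rewrite the leading prefix once
                break
        for old, new in substr_map.items():
            key = key.replace(old, new)  # rewrite substrings anywhere in the key
        return key

    # "encoder.LayerNorm.gamma" -> "model.encoder.layernorm.weight"
    print(rename_hf_key("encoder.LayerNorm.gamma"))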
@@ -700,7 +700,8 @@ class BartDecoder(nn.Module):
 
 class BartModel(nn.Module, SupportsQuant):
     _tied_weights_keys = [
-        "encoder.embed_tokens.weight", "decoder.embed_tokens.weight"
+        "encoder.embed_tokens.weight",
+        "decoder.embed_tokens.weight",
     ]
 
     def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
@@ -763,10 +764,54 @@ class BartModel(nn.Module, SupportsQuant):
 
         return decoder_outputs
 
+    def load_weights(self, weights: Iterable[tuple[str,
+                                                   torch.Tensor]]) -> set[str]:
+        stacked_params_mapping = [
+            # (param_name, shard_name, shard_id)
+            ("qkv_proj", "q_proj", "q"),
+            ("qkv_proj", "k_proj", "k"),
+            ("qkv_proj", "v_proj", "v"),
+        ]
+
+        other_weights = []
+        loaded_stacked_params = []
+        model_params_dict = dict(self.named_parameters())
+
+        for name, loaded_weight in weights:
+            for (param_name, weight_name, shard_id) in stacked_params_mapping:
+                if weight_name not in name:
+                    continue
+                name = name.replace(weight_name, param_name)
+                if name not in model_params_dict:
+                    continue
+                param = model_params_dict[name]
+                weight_loader = param.weight_loader
+                weight_loader(param, loaded_weight, shard_id)
+                loaded_stacked_params.append(name)
+                break
+            else:
+                if name in model_params_dict:
+                    other_weights.append((name, loaded_weight))
+
+        loader = AutoWeightsLoader(self)
+        loaded_params = loader.load_weights(other_weights)
+        loaded_params.update(loaded_stacked_params)
+        return loaded_params
+
 
 class BartForConditionalGeneration(nn.Module, SupportsV0Only, SupportsQuant):
     packed_modules_mapping = {"qkv_proj": ["q_proj", "k_proj", "v_proj"]}
     base_model_prefix = "model"
+    hf_to_vllm_mapper = WeightsMapper(
+        orig_to_new_prefix={
+            "decoder.": "model.decoder.",
+            "encoder.": "model.encoder.",
+            "shared.": "model.shared."
+        },
+        orig_to_new_substr={
+            "beta": "bias",
+            "gamma": "weight",
+            "LayerNorm": "layernorm",
+        },
+    )
 
     def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
 
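The new BartModel.load_weights routes q/k/v checkpoint tensors through the fused qkv_proj parameter's weight_loader with a shard id, and hands everything else to AutoWeightsLoader. A toy version of what such a shard-aware loader does (a sketch assuming equal q/k/v sizes and no tensor parallelism; vLLM's real QKVParallelLinear weight_loader additionally handles TP partitioning):

    import torch

    # Toy shard-aware loader: copies a q/k/v checkpoint tensor into its
    # slice of the fused qkv parameter, stacked [q; k; v] along dim 0.
    def toy_qkv_weight_loader(qkv_param: torch.Tensor,
                              loaded_weight: torch.Tensor,
                              shard_id: str) -> None:
        row = {"q": 0, "k": 1, "v": 2}[shard_id]
        n = qkv_param.shape[0] // 3
        qkv_param[row * n:(row + 1) * n].copy_(loaded_weight)

    qkv = torch.zeros(12, 4)
    toy_qkv_weight_loader(qkv, torch.ones(4, 4), "k")  # fills rows 4..7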
@@ -789,7 +834,6 @@ class BartForConditionalGeneration(nn.Module, SupportsV0Only, SupportsQuant):
         self.lm_head = BartParallelLMHead(config.vocab_size,
                                           config.d_model,
                                           embed_scale=embed_scale)
-
         self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
                                                 config.vocab_size)
 
@@ -828,61 +872,12 @@ class BartForConditionalGeneration(nn.Module, SupportsV0Only, SupportsQuant):
                                        sampling_metadata)
         return logits
 
-    stacked_params_mapping = {
-        "q_proj": {
-            "param_name": "qkv_proj",
-            "shard_id": "q",
-        },
-        "k_proj": {
-            "param_name": "qkv_proj",
-            "shard_id": "k",
-        },
-        "v_proj": {
-            "param_name": "qkv_proj",
-            "shard_id": "v",
-        },
-    }
-
-    params_mapping = {
-        "beta": "bias",
-        "gamma": "weight",
-        "LayerNorm": "layernorm",
-    }
-
-    def _rename_key(self, key: str):
-        prefix = f"{self.base_model_prefix}."
-        key = key[len(prefix):] if key.startswith(prefix) else key
-
-        for src, dst in self.params_mapping.items():
-            key = key.replace(src, dst)
-
-        return key
-
-    def _rename_stacked_param(
-        self,
-        name: str,
-    ) -> tuple[str, Optional[str]]:
-        for key, mapping in self.stacked_params_mapping.items():
-            if key in name:
-                name = name.replace(key, mapping["param_name"])
-                return name, mapping["shard_id"]
-        return name, None
-
-    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
-
-        model_params_dict = dict(self.model.named_parameters())
-        top_params_dict = dict(self.named_parameters())
-
+    def load_weights(self, weights: Iterable[tuple[str,
+                                                   torch.Tensor]]) -> set[str]:
         weights_tuple_list = list(weights)
 
         shared_embedding_weight = None
-        shared_embedding_shard_id = None
 
         for name, loaded_weight in weights_tuple_list:
-
-            name = self._rename_key(name)
-            name, shard_id = self._rename_stacked_param(name)
-
             if ('shared.weight' in name
                     or 'encoder.embed_tokens.weight' in name
                     or 'decoder.embed_tokens.weight' in name
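The incoming weights argument is an Iterable that the rewritten method consumes twice: once scanning for the shared embedding tensor, and once again inside AutoWeightsLoader. That is why it is materialized with list() up front; a bare generator would be exhausted after the first pass:

    # A generator yields its items once; the second pass sees nothing.
    gen = ((f"w{i}", i) for i in range(3))
    first = [name for name, _ in gen]
    second = [name for name, _ in gen]
    assert first == ["w0", "w1", "w2"] and second == []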
@@ -890,49 +885,24 @@ class BartForConditionalGeneration(nn.Module, SupportsV0Only, SupportsQuant):
                 assert shared_embedding_weight is None, (
                     "Conflicting embedding weights.")
                 shared_embedding_weight = loaded_weight
-                shared_embedding_shard_id = shard_id
-            else:
-                # Skip the specific downstream task weight.
-                if name.startswith('cls.'):
-                    continue
-                # use Pooler instead.
-                if name.startswith('pooler.'):
-                    continue
-                # Skip loading extra bias for GPTQ models.
-                if name.endswith(".bias") and name not in model_params_dict:
-                    continue
-
-                param = model_params_dict[name]
-                weight_loader = getattr(param, "weight_loader",
-                                        default_weight_loader)
-                if shard_id:
-                    weight_loader(param, loaded_weight, shard_id)
-                else:
-                    weight_loader(param, loaded_weight)
+        loader = AutoWeightsLoader(
+            self,
+            skip_prefixes=(["cls.", "pooler."]),
+        )
+        loaded_params = loader.load_weights(weights_tuple_list,
+                                            mapper=self.hf_to_vllm_mapper)
 
-        # Assign shared weight values
-        encoder_in_param = model_params_dict['encoder.embed_tokens.weight']
-        encoder_in_weight_loader = getattr(encoder_in_param, "weight_loader",
-                                           default_weight_loader)
+        if shared_embedding_weight is not None:
+            weight_loader = getattr(self.lm_head.weight, "weight_loader",
+                                    default_weight_loader)
+            weight_loader(self.lm_head.weight, shared_embedding_weight)
 
-        decoder_in_param = model_params_dict['decoder.embed_tokens.weight']
-        decoder_in_weight_loader = getattr(decoder_in_param, "weight_loader",
-                                           default_weight_loader)
+            self.model.encoder.embed_tokens.weight = self.lm_head.weight
+            self.model.decoder.embed_tokens.weight = self.lm_head.weight
+            loaded_params.update({
+                'model.encoder.embed_tokens.weight', 'lm_head.weight',
+                'model.decoder.embed_tokens.weight'
+            })
 
-        lm_head_in_param = top_params_dict['lm_head.weight']
-        lm_head_in_weight_loader = getattr(lm_head_in_param, "weight_loader",
-                                           default_weight_loader)
-
-        assert shared_embedding_weight is not None
-
-        if shared_embedding_shard_id:
-            encoder_in_weight_loader(encoder_in_param, shared_embedding_weight,
-                                     shared_embedding_shard_id)
-            decoder_in_weight_loader(decoder_in_param, shared_embedding_weight,
-                                     shared_embedding_shard_id)
-            lm_head_in_weight_loader(lm_head_in_param, shared_embedding_weight,
-                                     shared_embedding_shard_id)
-        else:
-            encoder_in_weight_loader(encoder_in_param, shared_embedding_weight)
-            decoder_in_weight_loader(decoder_in_param, shared_embedding_weight)
-            lm_head_in_weight_loader(lm_head_in_param, shared_embedding_weight)
+        return loaded_params
 
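After this load path runs, the shared checkpoint tensor is written into lm_head and the encoder/decoder embed_tokens parameters are re-pointed at that same tensor, so all three embeddings stay tied. A minimal sketch of checking that invariant (attribute paths follow this file's naming; `model` is assumed to be a loaded BartForConditionalGeneration instance):

    # Sketch: verify all three embeddings alias one tensor after loading.
    def embeddings_are_tied(model) -> bool:
        w = model.lm_head.weight
        return (model.model.encoder.embed_tokens.weight is w
                and model.model.decoder.embed_tokens.weight is w)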