mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-31 07:29:38 +08:00
[Model] Use autoweightloader for dbrx (#18251)
Signed-off-by: learner0810 <zhongjun.li@daocloud.io>
This commit is contained in:
parent
a5f8c111c2
commit
87d871470d
@ -26,7 +26,7 @@ from vllm.sequence import IntermediateTensors
|
||||
from vllm.transformers_utils.configs.dbrx import DbrxConfig
|
||||
|
||||
from .interfaces import SupportsPP
|
||||
from .utils import (is_pp_missing_parameter,
|
||||
from .utils import (AutoWeightsLoader, is_pp_missing_parameter,
|
||||
make_empty_intermediate_tensors_factory, make_layers,
|
||||
maybe_prefix)
|
||||
|
||||
@ -319,6 +319,7 @@ class DbrxModel(nn.Module):
|
||||
cache_config = vllm_config.cache_config
|
||||
quant_config = vllm_config.quant_config
|
||||
|
||||
self.quant_config = quant_config
|
||||
self.wte = VocabParallelEmbedding(
|
||||
config.vocab_size,
|
||||
config.d_model,
|
||||
@ -364,6 +365,55 @@ class DbrxModel(nn.Module):
|
||||
hidden_states = self.norm_f(hidden_states)
|
||||
return hidden_states
|
||||
|
||||
def load_weights(self, weights: Iterable[tuple[str,
|
||||
torch.Tensor]]) -> set[str]:
|
||||
expert_params_mapping = [(
|
||||
"w13" if weight_name in ["w1", "v1"] else "w2",
|
||||
f"mlp.{weight_name}",
|
||||
) for weight_name in ["w1", "v1", "w2"]]
|
||||
params_dict = dict(self.named_parameters(remove_duplicate=False))
|
||||
loaded_params: set[str] = set()
|
||||
|
||||
for name, loaded_weight in weights:
|
||||
if (self.quant_config is not None and
|
||||
(scale_name := self.quant_config.get_cache_scale(name))):
|
||||
# Loading kv cache quantization scales
|
||||
param = params_dict[scale_name]
|
||||
weight_loader = getattr(param, "weight_loader",
|
||||
default_weight_loader)
|
||||
loaded_weight = (loaded_weight if loaded_weight.dim() == 0 else
|
||||
loaded_weight[0])
|
||||
weight_loader(param, loaded_weight)
|
||||
loaded_params.add(scale_name)
|
||||
continue
|
||||
|
||||
if name.endswith(("w1", "w2", "v1")):
|
||||
name = name + "_weight"
|
||||
for param_name, weight_name in expert_params_mapping:
|
||||
if weight_name not in name:
|
||||
continue
|
||||
name = name.replace(weight_name, param_name)
|
||||
if is_pp_missing_parameter(name, self):
|
||||
continue
|
||||
param = params_dict[name]
|
||||
weight_loader = param.weight_loader
|
||||
weight_loader(param, loaded_weight, weight_name, name)
|
||||
break
|
||||
|
||||
else:
|
||||
if is_pp_missing_parameter(name, self):
|
||||
continue
|
||||
# Remapping the name of FP8 kv-scale.
|
||||
name = maybe_remap_kv_scale_name(name, params_dict)
|
||||
if name is None:
|
||||
continue
|
||||
param = params_dict[name]
|
||||
weight_loader = getattr(param, "weight_loader",
|
||||
default_weight_loader)
|
||||
weight_loader(param, loaded_weight)
|
||||
loaded_params.add(name)
|
||||
return loaded_params
|
||||
|
||||
|
||||
class DbrxForCausalLM(nn.Module, SupportsPP):
|
||||
|
||||
@ -417,49 +467,5 @@ class DbrxForCausalLM(nn.Module, SupportsPP):
|
||||
|
||||
def load_weights(self, weights: Iterable[tuple[str,
|
||||
torch.Tensor]]) -> set[str]:
|
||||
expert_params_mapping = [(
|
||||
"w13" if weight_name in ["w1", "v1"] else "w2",
|
||||
f"mlp.{weight_name}",
|
||||
) for weight_name in ["w1", "v1", "w2"]]
|
||||
params_dict = dict(self.named_parameters(remove_duplicate=False))
|
||||
loaded_params: set[str] = set()
|
||||
|
||||
for name, loaded_weight in weights:
|
||||
if (self.quant_config is not None and
|
||||
(scale_name := self.quant_config.get_cache_scale(name))):
|
||||
# Loading kv cache quantization scales
|
||||
param = params_dict[scale_name]
|
||||
weight_loader = getattr(param, "weight_loader",
|
||||
default_weight_loader)
|
||||
loaded_weight = (loaded_weight if loaded_weight.dim() == 0 else
|
||||
loaded_weight[0])
|
||||
weight_loader(param, loaded_weight)
|
||||
loaded_params.add(scale_name)
|
||||
continue
|
||||
|
||||
if name.endswith(("w1", "w2", "v1")):
|
||||
name = name + "_weight"
|
||||
for param_name, weight_name in expert_params_mapping:
|
||||
if weight_name not in name:
|
||||
continue
|
||||
name = name.replace(weight_name, param_name)
|
||||
if is_pp_missing_parameter(name, self):
|
||||
continue
|
||||
param = params_dict[name]
|
||||
weight_loader = param.weight_loader
|
||||
weight_loader(param, loaded_weight, weight_name, name)
|
||||
break
|
||||
|
||||
else:
|
||||
if is_pp_missing_parameter(name, self):
|
||||
continue
|
||||
# Remapping the name of FP8 kv-scale.
|
||||
name = maybe_remap_kv_scale_name(name, params_dict)
|
||||
if name is None:
|
||||
continue
|
||||
param = params_dict[name]
|
||||
weight_loader = getattr(param, "weight_loader",
|
||||
default_weight_loader)
|
||||
weight_loader(param, loaded_weight)
|
||||
loaded_params.add(name)
|
||||
return loaded_params
|
||||
loader = AutoWeightsLoader(self)
|
||||
return loader.load_weights(weights)
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user