mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-01-14 22:34:34 +08:00
[Bugfix][Model] Make Olmo2Model weight loading return loaded weights (#18504)
Signed-off-by: Shane A <shanea@allenai.org>
This commit is contained in:
parent
cf5984b2fe
commit
51797775c3
@ -314,7 +314,8 @@ class Olmo2Model(nn.Module):
|
||||
hidden_states = self.norm(hidden_states)
|
||||
return hidden_states
|
||||
|
||||
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
|
||||
def load_weights(self, weights: Iterable[tuple[str,
|
||||
torch.Tensor]]) -> set[str]:
|
||||
stacked_params_mapping = [
|
||||
# (param_name, shard_name, shard_id)
|
||||
("qkv_proj", "q_proj", "q"),
|
||||
@ -325,6 +326,7 @@ class Olmo2Model(nn.Module):
|
||||
]
|
||||
|
||||
params_dict = dict(self.named_parameters(remove_duplicate=False))
|
||||
loaded_params: set[str] = set()
|
||||
for name, loaded_weight in weights:
|
||||
if is_pp_missing_parameter(name, self):
|
||||
continue
|
||||
@ -347,6 +349,8 @@ class Olmo2Model(nn.Module):
|
||||
weight_loader = getattr(param, "weight_loader",
|
||||
default_weight_loader)
|
||||
weight_loader(param, loaded_weight)
|
||||
loaded_params.add(name)
|
||||
return loaded_params
|
||||
|
||||
|
||||
class Olmo2ForCausalLM(nn.Module, SupportsPP):
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user