From 51797775c3ffc1277be750eb046c8030f8eca280 Mon Sep 17 00:00:00 2001 From: Shane A Date: Wed, 21 May 2025 21:17:03 -0700 Subject: [PATCH] [Bugfix][Model] Make Olmo2Model weight loading return loaded weights (#18504) Signed-off-by: Shane A --- vllm/model_executor/models/olmo2.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/vllm/model_executor/models/olmo2.py b/vllm/model_executor/models/olmo2.py index 0a1fb10c186e5..33adacdae5f5b 100644 --- a/vllm/model_executor/models/olmo2.py +++ b/vllm/model_executor/models/olmo2.py @@ -314,7 +314,8 @@ class Olmo2Model(nn.Module): hidden_states = self.norm(hidden_states) return hidden_states - def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]): + def load_weights(self, weights: Iterable[tuple[str, + torch.Tensor]]) -> set[str]: stacked_params_mapping = [ # (param_name, shard_name, shard_id) ("qkv_proj", "q_proj", "q"), @@ -325,6 +326,7 @@ class Olmo2Model(nn.Module): ] params_dict = dict(self.named_parameters(remove_duplicate=False)) + loaded_params: set[str] = set() for name, loaded_weight in weights: if is_pp_missing_parameter(name, self): continue @@ -347,6 +349,8 @@ class Olmo2Model(nn.Module): weight_loader = getattr(param, "weight_loader", default_weight_loader) weight_loader(param, loaded_weight) + loaded_params.add(name) + return loaded_params class Olmo2ForCausalLM(nn.Module, SupportsPP):