diff --git a/vllm/model_executor/models/olmo2.py b/vllm/model_executor/models/olmo2.py index 0a1fb10c186e5..33adacdae5f5b 100644 --- a/vllm/model_executor/models/olmo2.py +++ b/vllm/model_executor/models/olmo2.py @@ -314,7 +314,8 @@ class Olmo2Model(nn.Module): hidden_states = self.norm(hidden_states) return hidden_states - def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]): + def load_weights(self, weights: Iterable[tuple[str, + torch.Tensor]]) -> set[str]: stacked_params_mapping = [ # (param_name, shard_name, shard_id) ("qkv_proj", "q_proj", "q"), @@ -325,6 +326,7 @@ class Olmo2Model(nn.Module): ] params_dict = dict(self.named_parameters(remove_duplicate=False)) + loaded_params: set[str] = set() for name, loaded_weight in weights: if is_pp_missing_parameter(name, self): continue @@ -347,6 +349,8 @@ class Olmo2Model(nn.Module): weight_loader = getattr(param, "weight_loader", default_weight_loader) weight_loader(param, loaded_weight) + loaded_params.add(name) + return loaded_params class Olmo2ForCausalLM(nn.Module, SupportsPP):