From c016c95b45765fcd432c533b16f6d17d77cc5f6d Mon Sep 17 00:00:00 2001 From: Harry Mellor <19981378+hmellor@users.noreply.github.com> Date: Tue, 23 Dec 2025 17:31:56 +0000 Subject: [PATCH] Use helper function instead of looping through attribute names (#29788) Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com> --- vllm/config/model.py | 19 ++++++++----------- vllm/config/utils.py | 10 +++++++--- 2 files changed, 15 insertions(+), 14 deletions(-) diff --git a/vllm/config/model.py b/vllm/config/model.py index 6e199adbf3ee6..e26b227de976c 100644 --- a/vllm/config/model.py +++ b/vllm/config/model.py @@ -1094,11 +1094,10 @@ class ModelConfig: # The size of inputs_embeds is usually identical to the size # of the hidden states, however there are exceptions, such as # embedding models like CLIP and SigLIP - for target_attr in ("projection_dim", "projection_size"): - if hasattr(self.hf_text_config, target_attr): - return getattr(self.hf_text_config, target_attr) - - return self.get_hidden_size() + names = ("projection_dim", "projection_size") + return getattr_iter( + self.hf_text_config, names, default_factory=self.get_hidden_size + ) @property def is_deepseek_mla(self) -> bool: @@ -1231,14 +1230,12 @@ class ModelConfig: # For ChatGLM: "multi_query_group_num", ] - for attr in attributes: - num_kv_heads = getattr(self.hf_text_config, attr, None) - if num_kv_heads is not None: - return num_kv_heads - # For non-grouped-query attention models, the number of KV heads is # equal to the number of attention heads. - return self.hf_text_config.num_attention_heads + default_factory = lambda: self.hf_text_config.num_attention_heads + return getattr_iter( + self.hf_text_config, attributes, default_factory=default_factory + ) def get_num_kv_heads(self, parallel_config: ParallelConfig) -> int: """Returns the number of KV heads per GPU.""" diff --git a/vllm/config/utils.py b/vllm/config/utils.py index 470296517deb1..614373782d12f 100644 --- a/vllm/config/utils.py +++ b/vllm/config/utils.py @@ -9,7 +9,7 @@ import inspect import json import pathlib import textwrap -from collections.abc import Iterable, Mapping, Sequence, Set +from collections.abc import Callable, Iterable, Mapping, Sequence, Set from dataclasses import MISSING, Field, dataclass, field, fields, is_dataclass, replace from itertools import pairwise from typing import TYPE_CHECKING, Any, Protocol, TypeVar @@ -74,7 +74,11 @@ def get_field(cls: ConfigType, name: str) -> Field: def getattr_iter( - object: object, names: Iterable[str], default: Any, warn: bool = False + object: object, + names: Iterable[str], + default: Any | None = None, + default_factory: Callable[[], Any] | None = None, + warn: bool = False, ) -> Any: """ A helper function that retrieves an attribute from an object which may @@ -96,7 +100,7 @@ def getattr_iter( names[0], ) return getattr(object, name) - return default + return default_factory() if default_factory is not None else default def contains_object_print(text: str) -> bool: