Use helper function instead of looping through attribute names (#29788)

Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
This commit is contained in:
Harry Mellor 2025-12-23 17:31:56 +00:00 committed by GitHub
parent 1339878e13
commit c016c95b45
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 15 additions and 14 deletions

View File

@ -1094,11 +1094,10 @@ class ModelConfig:
# The size of inputs_embeds is usually identical to the size
# of the hidden states; however, there are exceptions, such as
# embedding models like CLIP and SigLIP.
for target_attr in ("projection_dim", "projection_size"):
if hasattr(self.hf_text_config, target_attr):
return getattr(self.hf_text_config, target_attr)
return self.get_hidden_size()
names = ("projection_dim", "projection_size")
return getattr_iter(
self.hf_text_config, names, default_factory=self.get_hidden_size
)
@property
def is_deepseek_mla(self) -> bool:
@ -1231,14 +1230,12 @@ class ModelConfig:
# For ChatGLM:
"multi_query_group_num",
]
for attr in attributes:
num_kv_heads = getattr(self.hf_text_config, attr, None)
if num_kv_heads is not None:
return num_kv_heads
# For non-grouped-query attention models, the number of KV heads is
# equal to the number of attention heads.
return self.hf_text_config.num_attention_heads
default_factory = lambda: self.hf_text_config.num_attention_heads
return getattr_iter(
self.hf_text_config, attributes, default_factory=default_factory
)
def get_num_kv_heads(self, parallel_config: ParallelConfig) -> int:
"""Returns the number of KV heads per GPU."""

View File

@ -9,7 +9,7 @@ import inspect
import json
import pathlib
import textwrap
from collections.abc import Iterable, Mapping, Sequence, Set
from collections.abc import Callable, Iterable, Mapping, Sequence, Set
from dataclasses import MISSING, Field, dataclass, field, fields, is_dataclass, replace
from itertools import pairwise
from typing import TYPE_CHECKING, Any, Protocol, TypeVar
@ -74,7 +74,11 @@ def get_field(cls: ConfigType, name: str) -> Field:
def getattr_iter(
object: object, names: Iterable[str], default: Any, warn: bool = False
object: object,
names: Iterable[str],
default: Any | None = None,
default_factory: Callable[[], Any] | None = None,
warn: bool = False,
) -> Any:
"""
A helper function that retrieves an attribute from an object which may
@ -96,7 +100,7 @@ def getattr_iter(
names[0],
)
return getattr(object, name)
return default
return default_factory() if default_factory is not None else default
def contains_object_print(text: str) -> bool: