Use helper function instead of looping through attribute names (#29788)

Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
This commit is contained in:
Harry Mellor 2025-12-23 17:31:56 +00:00 committed by GitHub
parent 1339878e13
commit c016c95b45
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 15 additions and 14 deletions

View File

@ -1094,11 +1094,10 @@ class ModelConfig:
# The size of inputs_embeds is usually identical to the size # The size of inputs_embeds is usually identical to the size
# of the hidden states, however there are exceptions, such as # of the hidden states, however there are exceptions, such as
# embedding models like CLIP and SigLIP # embedding models like CLIP and SigLIP
for target_attr in ("projection_dim", "projection_size"): names = ("projection_dim", "projection_size")
if hasattr(self.hf_text_config, target_attr): return getattr_iter(
return getattr(self.hf_text_config, target_attr) self.hf_text_config, names, default_factory=self.get_hidden_size
)
return self.get_hidden_size()
@property @property
def is_deepseek_mla(self) -> bool: def is_deepseek_mla(self) -> bool:
@ -1231,14 +1230,12 @@ class ModelConfig:
# For ChatGLM: # For ChatGLM:
"multi_query_group_num", "multi_query_group_num",
] ]
for attr in attributes:
num_kv_heads = getattr(self.hf_text_config, attr, None)
if num_kv_heads is not None:
return num_kv_heads
# For non-grouped-query attention models, the number of KV heads is # For non-grouped-query attention models, the number of KV heads is
# equal to the number of attention heads. # equal to the number of attention heads.
return self.hf_text_config.num_attention_heads default_factory = lambda: self.hf_text_config.num_attention_heads
return getattr_iter(
self.hf_text_config, attributes, default_factory=default_factory
)
def get_num_kv_heads(self, parallel_config: ParallelConfig) -> int: def get_num_kv_heads(self, parallel_config: ParallelConfig) -> int:
"""Returns the number of KV heads per GPU.""" """Returns the number of KV heads per GPU."""

View File

@ -9,7 +9,7 @@ import inspect
import json import json
import pathlib import pathlib
import textwrap import textwrap
from collections.abc import Iterable, Mapping, Sequence, Set from collections.abc import Callable, Iterable, Mapping, Sequence, Set
from dataclasses import MISSING, Field, dataclass, field, fields, is_dataclass, replace from dataclasses import MISSING, Field, dataclass, field, fields, is_dataclass, replace
from itertools import pairwise from itertools import pairwise
from typing import TYPE_CHECKING, Any, Protocol, TypeVar from typing import TYPE_CHECKING, Any, Protocol, TypeVar
@ -74,7 +74,11 @@ def get_field(cls: ConfigType, name: str) -> Field:
def getattr_iter( def getattr_iter(
object: object, names: Iterable[str], default: Any, warn: bool = False object: object,
names: Iterable[str],
default: Any | None = None,
default_factory: Callable[[], Any] | None = None,
warn: bool = False,
) -> Any: ) -> Any:
""" """
A helper function that retrieves an attribute from an object which may A helper function that retrieves an attribute from an object which may
@ -96,7 +100,7 @@ def getattr_iter(
names[0], names[0],
) )
return getattr(object, name) return getattr(object, name)
return default return default_factory() if default_factory is not None else default
def contains_object_print(text: str) -> bool: def contains_object_print(text: str) -> bool: