mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-31 04:23:05 +08:00
Use helper function instead of looping through attribute names (#29788)
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
This commit is contained in:
parent
1339878e13
commit
c016c95b45
@ -1094,11 +1094,10 @@ class ModelConfig:
|
||||
# The size of inputs_embeds is usually identical to the size
|
||||
# of the hidden states, however there are exceptions, such as
|
||||
# embedding models like CLIP and SigLIP
|
||||
for target_attr in ("projection_dim", "projection_size"):
|
||||
if hasattr(self.hf_text_config, target_attr):
|
||||
return getattr(self.hf_text_config, target_attr)
|
||||
|
||||
return self.get_hidden_size()
|
||||
names = ("projection_dim", "projection_size")
|
||||
return getattr_iter(
|
||||
self.hf_text_config, names, default_factory=self.get_hidden_size
|
||||
)
|
||||
|
||||
@property
|
||||
def is_deepseek_mla(self) -> bool:
|
||||
@ -1231,14 +1230,12 @@ class ModelConfig:
|
||||
# For ChatGLM:
|
||||
"multi_query_group_num",
|
||||
]
|
||||
for attr in attributes:
|
||||
num_kv_heads = getattr(self.hf_text_config, attr, None)
|
||||
if num_kv_heads is not None:
|
||||
return num_kv_heads
|
||||
|
||||
# For non-grouped-query attention models, the number of KV heads is
|
||||
# equal to the number of attention heads.
|
||||
return self.hf_text_config.num_attention_heads
|
||||
default_factory = lambda: self.hf_text_config.num_attention_heads
|
||||
return getattr_iter(
|
||||
self.hf_text_config, attributes, default_factory=default_factory
|
||||
)
|
||||
|
||||
def get_num_kv_heads(self, parallel_config: ParallelConfig) -> int:
|
||||
"""Returns the number of KV heads per GPU."""
|
||||
|
||||
@ -9,7 +9,7 @@ import inspect
|
||||
import json
|
||||
import pathlib
|
||||
import textwrap
|
||||
from collections.abc import Iterable, Mapping, Sequence, Set
|
||||
from collections.abc import Callable, Iterable, Mapping, Sequence, Set
|
||||
from dataclasses import MISSING, Field, dataclass, field, fields, is_dataclass, replace
|
||||
from itertools import pairwise
|
||||
from typing import TYPE_CHECKING, Any, Protocol, TypeVar
|
||||
@ -74,7 +74,11 @@ def get_field(cls: ConfigType, name: str) -> Field:
|
||||
|
||||
|
||||
def getattr_iter(
|
||||
object: object, names: Iterable[str], default: Any, warn: bool = False
|
||||
object: object,
|
||||
names: Iterable[str],
|
||||
default: Any | None = None,
|
||||
default_factory: Callable[[], Any] | None = None,
|
||||
warn: bool = False,
|
||||
) -> Any:
|
||||
"""
|
||||
A helper function that retrieves an attribute from an object which may
|
||||
@ -96,7 +100,7 @@ def getattr_iter(
|
||||
names[0],
|
||||
)
|
||||
return getattr(object, name)
|
||||
return default
|
||||
return default_factory() if default_factory is not None else default
|
||||
|
||||
|
||||
def contains_object_print(text: str) -> bool:
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user