make get_model_arch_config not classmethod

Signed-off-by: Xingyu Liu <charlotteliu12x@gmail.com>
This commit is contained in:
Xingyu Liu 2025-12-22 23:12:50 -08:00
parent 5bde69c2b9
commit 1c3db5611a
2 changed files with 9 additions and 8 deletions

View File

@ -483,7 +483,11 @@ def dummy_hf_overrides(
"num_kv_shared_layers": 1, "num_kv_shared_layers": 1,
} }
model_arch_config = ModelConfig.get_model_arch_config(hf_config, text_config) class DummyConfig:
hf_config = hf_config
hf_text_config = text_config
model_arch_config = ModelConfig.get_model_arch_config(DummyConfig)
# Only set MoE related config when the model has MoE layers. # Only set MoE related config when the model has MoE layers.
# Otherwise all models detected as MoE by _get_transformers_backend_cls. # Otherwise all models detected as MoE by _get_transformers_backend_cls.
if model_arch_config.num_experts > 0: if model_arch_config.num_experts > 0:

View File

@ -484,9 +484,7 @@ class ModelConfig:
self.hf_image_processor_config = get_hf_image_processor_config( self.hf_image_processor_config = get_hf_image_processor_config(
self.model, hf_token=self.hf_token, revision=self.revision self.model, hf_token=self.hf_token, revision=self.revision
) )
self.model_arch_config = self.get_model_arch_config( self.model_arch_config = self.get_model_arch_config()
self.hf_config, self.hf_text_config
)
architectures = self.architectures architectures = self.architectures
registry = self.registry registry = self.registry
@ -604,14 +602,13 @@ class ModelConfig:
self._verify_cuda_graph() self._verify_cuda_graph()
self._verify_bnb_config() self._verify_bnb_config()
@classmethod
def get_model_arch_config( def get_model_arch_config(
cls, hf_config, hf_text_config self,
) -> ModelArchitectureConfig: ) -> ModelArchitectureConfig:
convertor_cls = MODEL_ARCH_CONFIG_CONVERTORS.get( convertor_cls = MODEL_ARCH_CONFIG_CONVERTORS.get(
hf_config.model_type, ModelArchConfigConvertorBase self.hf_config.model_type, ModelArchConfigConvertorBase
) )
convertor = convertor_cls(hf_config, hf_text_config) convertor = convertor_cls(self.hf_config, self.hf_text_config)
return convertor.convert() return convertor.convert()
@field_validator("tokenizer", "max_model_len", mode="wrap") @field_validator("tokenizer", "max_model_len", mode="wrap")