diff --git a/tests/config/test_model_arch_config.py b/tests/config/test_model_arch_config.py index 5531963b05dcd..da699fed040de 100644 --- a/tests/config/test_model_arch_config.py +++ b/tests/config/test_model_arch_config.py @@ -81,7 +81,7 @@ def _assert_model_arch_config( assert model_arch_config.is_deepseek_mla == expected["is_deepseek_mla"] torch_dtype = ModelArchConfigConvertorBase.get_torch_dtype( - model_config.hf_config, model_config.model_id, revision=model_config.revision + model_config.hf_config, model_config.model, revision=model_config.revision ) assert str(torch_dtype) == expected["dtype"] diff --git a/tests/models/utils.py b/tests/models/utils.py index 8e3a86f426d5d..63dd07b827dbb 100644 --- a/tests/models/utils.py +++ b/tests/models/utils.py @@ -483,8 +483,10 @@ def dummy_hf_overrides( "num_kv_shared_layers": 1, } + _hf_config = hf_config + class DummyConfig: - hf_config = hf_config + hf_config = _hf_config hf_text_config = text_config model_arch_config = ModelConfig.get_model_arch_config(DummyConfig)