Improve Transformers backend model loading QoL (#17039)

Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
Author: Harry Mellor <19981378+hmellor@users.noreply.github.com> (committed by GitHub)
Date: 2025-04-23 15:33:51 +01:00
commit 8e630d680e
parent af869f6dff

@@ -55,7 +55,10 @@ def resolve_transformers_arch(model_config: ModelConfig,
     #     "AutoModelFor<Task>": "<your-repo-name>--<config-name>",
     # },
     auto_modules = {
-        name: get_class_from_dynamic_module(module, model_config.model)
+        name:
+        get_class_from_dynamic_module(module,
+                                      model_config.model,
+                                      revision=model_config.revision)
         for name, module in sorted(auto_map.items(), key=lambda x: x[0])
     }
     custom_model_module = auto_modules.get("AutoModel")
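
This hunk threads the model's pinned revision through to get_class_from_dynamic_module, so for trust_remote_code models the dynamic modeling code is fetched from the same revision as the config and weights rather than the repo's default branch. A minimal sketch of the resulting call, using a hypothetical repo id and auto_map entry:

    from transformers.dynamic_module_utils import get_class_from_dynamic_module

    # "modeling_foo.FooForCausalLM" stands in for a real auto_map value
    # ("<module-file>.<class-name>"); "my-org/custom-model" is a
    # hypothetical repo id.
    model_cls = get_class_from_dynamic_module(
        "modeling_foo.FooForCausalLM",
        "my-org/custom-model",
        revision="v1.2",  # previously omitted; now forwarded from the model config
    )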
@@ -97,10 +100,10 @@ def get_model_architecture(
         architectures = ["QuantMixtralForCausalLM"]
 
     vllm_supported_archs = ModelRegistry.get_supported_archs()
-    is_vllm_supported = any(arch in vllm_supported_archs
-                            for arch in architectures)
-    if (not is_vllm_supported
-            or model_config.model_impl == ModelImpl.TRANSFORMERS):
+    vllm_not_supported = not any(arch in vllm_supported_archs
+                                 for arch in architectures)
+    if (model_config.model_impl == ModelImpl.TRANSFORMERS or
+            model_config.model_impl != ModelImpl.VLLM and vllm_not_supported):
         architectures = resolve_transformers_arch(model_config, architectures)
     model_cls, arch = ModelRegistry.resolve_model_cls(architectures)
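
This hunk reworks the backend selection: model_impl=transformers now always forces the Transformers backend (previously it only took effect as a fallback), model_impl=vllm never falls back, and model_impl=auto falls back only when none of the architectures is natively supported. A minimal sketch of the new decision, assuming ModelImpl is an enum with AUTO/VLLM/TRANSFORMERS members; note the operator precedence (and binds tighter than or), matching the parenthesization in the diff:

    from enum import Enum

    class ModelImpl(str, Enum):  # assumed members, mirroring vllm.config
        AUTO = "auto"
        VLLM = "vllm"
        TRANSFORMERS = "transformers"

    def use_transformers_backend(model_impl: ModelImpl,
                                 vllm_not_supported: bool) -> bool:
        # TRANSFORMERS: always route through the Transformers backend.
        # VLLM: never fall back, even for unsupported architectures.
        # AUTO: fall back only when vLLM supports none of the architectures.
        return (model_impl == ModelImpl.TRANSFORMERS
                or model_impl != ModelImpl.VLLM and vllm_not_supported)

    assert use_transformers_backend(ModelImpl.TRANSFORMERS, False)  # forced
    assert not use_transformers_backend(ModelImpl.VLLM, True)       # no fallback
    assert use_transformers_backend(ModelImpl.AUTO, True)           # fallback
    assert not use_transformers_backend(ModelImpl.AUTO, False)      # native vLLM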