[Misc] Check that the model can be inspected upon registration (#13743)

This commit is contained in:
Cyrus Leung 2025-02-25 16:18:19 +08:00 committed by GitHub
parent 03f48b3db6
commit 6724e79164
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -347,6 +347,10 @@ class _ModelRegistry:
when importing the model and thus the related error
:code:`RuntimeError: Cannot re-initialize CUDA in forked subprocess`.
"""
if not isinstance(model_arch, str):
msg = f"`model_arch` should be a string, not a {type(model_arch)}"
raise TypeError(msg)
if model_arch in self.models:
logger.warning(
"Model architecture %s is already registered, and will be "
@ -360,8 +364,18 @@ class _ModelRegistry:
raise ValueError(msg)
model = _LazyRegisteredModel(*split_str)
else:
try:
model.inspect_model_cls()
except Exception as exc:
msg = f"Unable to inspect model {model_cls}"
raise RuntimeError(msg) from exc
elif isinstance(model_cls, type) and issubclass(model_cls, nn.Module):
model = _RegisteredModel.from_model_cls(model_cls)
else:
msg = ("`model_cls` should be a string or PyTorch model class, "
f"not a {type(model_arch)}")
raise TypeError(msg)
self.models[model_arch] = model