From ee9442518d88925e96c36f1a039d61893d224e78 Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Mon, 13 Feb 2023 22:51:03 +0000 Subject: [PATCH] Fix get_model --- cacheflow/worker/models/model_utils.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/cacheflow/worker/models/model_utils.py b/cacheflow/worker/models/model_utils.py index a98eac0484a08..3878b87e6ea56 100644 --- a/cacheflow/worker/models/model_utils.py +++ b/cacheflow/worker/models/model_utils.py @@ -8,6 +8,7 @@ MODEL_CLASSES = { def get_model(model_name: str) -> nn.Module: - if model_name not in MODEL_CLASSES: - raise ValueError(f'Invalid model name: {model_name}') - return MODEL_CLASSES[model_name].from_pretrained(model_name) + for model_class, model in MODEL_CLASSES.items(): + if model_class in model_name: + return model.from_pretrained(model_name) + raise ValueError(f'Invalid model name: {model_name}')