[core][misc] keep compatibility for old-style classes (#10356)

Signed-off-by: youkaichao <youkaichao@gmail.com>
This commit is contained in:
youkaichao 2024-11-15 05:55:51 -08:00 committed by GitHub
parent f2056f726d
commit 3a763ba0c3
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -94,18 +94,34 @@ def _initialize_model(vllm_config: VllmConfig, prefix: str = "") -> nn.Module:
model_config = vllm_config.model_config model_config = vllm_config.model_config
model_class, _ = get_model_architecture(model_config) model_class, _ = get_model_architecture(model_config)
signatures = inspect.signature(model_class.__init__) signatures = inspect.signature(model_class.__init__)
# collect all kw-only parameters all_params = [param.name for param in signatures.parameters.values()]
kw_only_params = [ if "vllm_config" in all_params and "prefix" in all_params:
param.name for param in signatures.parameters.values() # new-style model class
if param.kind == inspect.Parameter.KEYWORD_ONLY return model_class(vllm_config=vllm_config, prefix=prefix)
] msg = ("vLLM model class should accept `vllm_config` and `prefix` as "
assert "vllm_config" in kw_only_params and "prefix" in kw_only_params, \ "input arguments. Possibly you have an old-style model class"
("vLLM model class must accept `vllm_config` and `prefix` as kw-only " " registered from out of tree and it is used for new vLLM version. "
"arguments. Possibly you have an old-style model class registered from " "Check https://docs.vllm.ai/en/latest/design/class_hierarchy.html "
"out of tree and it is used for new vLLM version. " "for the design and update the model class accordingly.")
"Please check https://docs.vllm.ai/en/latest/design/class_hierarchy.html " logger.warning(msg)
"for the design and update the model class accordingly.") logger.warning(
return model_class(vllm_config=vllm_config, prefix=prefix) "Trying to guess the arguments for old-style model class %s",
model_class)
# try to be compatible with old-style model class
kwargs = {}
if "prefix" in all_params:
kwargs["prefix"] = prefix
if "config" in all_params:
kwargs["config"] = model_config.hf_config
if "cache_config" in all_params:
kwargs["cache_config"] = vllm_config.cache_config
if "quant_config" in all_params:
kwargs["quant_config"] = vllm_config.quant_config
if "lora_config" in all_params:
kwargs["lora_config"] = vllm_config.lora_config
if "scheduler_config" in all_params:
kwargs["scheduler_config"] = vllm_config.scheduler_config
return model_class(**kwargs)
class BaseModelLoader(ABC): class BaseModelLoader(ABC):