# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import pytest
from torch import nn

from vllm.config import ModelConfig
from vllm.config.load import LoadConfig
from vllm.model_executor.model_loader import get_model_loader, register_model_loader
from vllm.model_executor.model_loader.base_loader import BaseModelLoader


@register_model_loader("custom_load_format")
class CustomModelLoader(BaseModelLoader):
    """Minimal loader registered under the "custom_load_format" load format.

    Both loading hooks are deliberate no-ops; the class exists only so the
    tests below can check that the registry hands back an instance of it.
    """

    def __init__(self, load_config: LoadConfig) -> None:
        """Forward ``load_config`` unchanged to the base loader."""
        super().__init__(load_config)

    def download_model(self, model_config: ModelConfig) -> None:
        """Do nothing; this stub loader has no model to download."""

    def load_weights(self, model: nn.Module, model_config: ModelConfig) -> None:
        """Do nothing; this stub loader leaves ``model`` untouched."""


def test_register_model_loader():
    """A registered load format resolves to the custom loader class."""
    config = LoadConfig(load_format="custom_load_format")
    loader = get_model_loader(config)
    assert isinstance(loader, CustomModelLoader)


def test_invalid_model_loader():
    """Registering a class that is not a BaseModelLoader must raise ValueError."""
    with pytest.raises(ValueError):

        # The decorator is applied inside the context manager so that the
        # expected ValueError is raised (and captured) at class-definition time.
        @register_model_loader("invalid_load_format")
        class InValidModelLoader:
            pass