mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-10 05:55:55 +08:00
38 lines
1.1 KiB
Python
38 lines
1.1 KiB
Python
# SPDX-License-Identifier: Apache-2.0
|
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
|
|
|
import pytest
|
|
from torch import nn
|
|
|
|
from vllm.config import LoadConfig, ModelConfig
|
|
from vllm.model_executor.model_loader import (get_model_loader,
|
|
register_model_loader)
|
|
from vllm.model_executor.model_loader.base_loader import BaseModelLoader
|
|
|
|
|
|
@register_model_loader("custom_load_format")
class CustomModelLoader(BaseModelLoader):
    """Stub loader registered under the ``"custom_load_format"`` key.

    Every hook is a no-op; the class only needs to exist so that
    ``test_register_model_loader`` can check which class
    ``get_model_loader`` returns for that format.
    """

    def __init__(self, load_config: LoadConfig) -> None:
        # No extra state of its own; defer entirely to the base class.
        super().__init__(load_config)

    def download_model(self, model_config: ModelConfig) -> None:
        """Intentionally a no-op."""

    def load_weights(self, model: nn.Module,
                     model_config: ModelConfig) -> None:
        """Intentionally a no-op; ``model`` is left untouched."""
|
|
|
|
|
|
def test_register_model_loader():
    """A format registered via the decorator resolves to its loader class."""
    config = LoadConfig(load_format="custom_load_format")
    loader = get_model_loader(config)
    assert isinstance(loader, CustomModelLoader)
|
|
|
|
|
|
def test_invalid_model_loader():
    """Registering a plain class (not derived from BaseModelLoader) raises.

    The decorator runs at class-definition time, so the expected
    ``ValueError`` must surface while the class statement below executes,
    inside the ``pytest.raises`` context.
    """
    with pytest.raises(ValueError):

        @register_model_loader("invalid_load_format")
        class NotAModelLoader:
            ...
|