mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-09 19:54:55 +08:00
Signed-off-by: Sanger Steel <sangersteel@gmail.com> Co-authored-by: Eta <esyra@coreweave.com>
103 lines
3.1 KiB
Python
103 lines
3.1 KiB
Python
# SPDX-License-Identifier: Apache-2.0
|
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
|
from typing import Callable
|
|
|
|
import pytest
|
|
|
|
from vllm import LLM, EngineArgs
|
|
from vllm.distributed import cleanup_dist_env_and_memory
|
|
from vllm.model_executor.model_loader import tensorizer as tensorizer_mod
|
|
from vllm.model_executor.model_loader.tensorizer import TensorizerConfig
|
|
from vllm.utils import get_distributed_init_method, get_ip, get_open_port
|
|
from vllm.v1.executor.abstract import UniProcExecutor
|
|
from vllm.worker.worker_base import WorkerWrapperBase
|
|
|
|
MODEL_REF = "facebook/opt-125m"
|
|
|
|
|
|
@pytest.fixture()
|
|
def model_ref():
|
|
return MODEL_REF
|
|
|
|
|
|
@pytest.fixture(autouse=True)
|
|
def allow_insecure_serialization(monkeypatch):
|
|
monkeypatch.setenv("VLLM_ALLOW_INSECURE_SERIALIZATION", "1")
|
|
|
|
|
|
@pytest.fixture(autouse=True)
|
|
def cleanup():
|
|
cleanup_dist_env_and_memory(shutdown_ray=True)
|
|
|
|
|
|
@pytest.fixture()
|
|
def just_serialize_model_tensors(model_ref, monkeypatch, tmp_path):
|
|
|
|
def noop(*args, **kwargs):
|
|
return None
|
|
|
|
args = EngineArgs(model=model_ref)
|
|
tc = TensorizerConfig(tensorizer_uri=f"{tmp_path}/model.tensors")
|
|
|
|
monkeypatch.setattr(tensorizer_mod, "serialize_extra_artifacts", noop)
|
|
|
|
tensorizer_mod.tensorize_vllm_model(args, tc)
|
|
yield tmp_path
|
|
|
|
|
|
@pytest.fixture(autouse=True)
|
|
def tensorizer_config():
|
|
config = TensorizerConfig(tensorizer_uri="vllm")
|
|
return config
|
|
|
|
|
|
@pytest.fixture()
|
|
def model_path(model_ref, tmp_path):
|
|
yield tmp_path / model_ref / "model.tensors"
|
|
|
|
|
|
def assert_from_collective_rpc(engine: LLM, closure: Callable,
|
|
closure_kwargs: dict):
|
|
res = engine.collective_rpc(method=closure, kwargs=closure_kwargs)
|
|
return all(res)
|
|
|
|
|
|
# This is an object pulled from tests/v1/engine/test_engine_core.py
|
|
# Modified to strip the `load_model` method from its `_init_executor`
|
|
# method. It's purely used as a dummy utility to run methods that test
|
|
# Tensorizer functionality
|
|
class DummyExecutor(UniProcExecutor):
|
|
|
|
def _init_executor(self) -> None:
|
|
"""Initialize the worker and load the model.
|
|
"""
|
|
self.driver_worker = WorkerWrapperBase(vllm_config=self.vllm_config,
|
|
rpc_rank=0)
|
|
distributed_init_method = get_distributed_init_method(
|
|
get_ip(), get_open_port())
|
|
local_rank = 0
|
|
# set local rank as the device index if specified
|
|
device_info = self.vllm_config.device_config.device.__str__().split(
|
|
":")
|
|
if len(device_info) > 1:
|
|
local_rank = int(device_info[1])
|
|
rank = 0
|
|
is_driver_worker = True
|
|
kwargs = dict(
|
|
vllm_config=self.vllm_config,
|
|
local_rank=local_rank,
|
|
rank=rank,
|
|
distributed_init_method=distributed_init_method,
|
|
is_driver_worker=is_driver_worker,
|
|
)
|
|
self.collective_rpc("init_worker", args=([kwargs], ))
|
|
self.collective_rpc("init_device")
|
|
|
|
@property
|
|
def max_concurrent_batches(self) -> int:
|
|
return 2
|
|
|
|
def shutdown(self):
|
|
if hasattr(self, 'thread_pool'):
|
|
self.thread_pool.shutdown(wait=False)
|