[V1] LoRA - Enable more V1 tests (#14315)

Signed-off-by: Varun Sundar Rabindranath <varun@neuralmagic.com>
Co-authored-by: Varun Sundar Rabindranath <varun@neuralmagic.com>
Authored by Varun Sundar Rabindranath on 2025-03-05 22:55:42 -05:00; committed by GitHub
parent f5f7f00cd9
commit 3dbd2d813a
5 changed files with 57 additions and 8 deletions


@@ -55,7 +55,6 @@ def v1(run_with_both_engines_lora):
     pass
 
 
-@pytest.mark.skip_v1
 @fork_new_process_for_each_test
 def test_chatglm3_lora(chatglm3_lora_files):
     llm = vllm.LLM(MODEL_PATH,
@@ -75,7 +74,6 @@ def test_chatglm3_lora(chatglm3_lora_files):
         assert output2[i] == EXPECTED_LORA_OUTPUT[i]
 
 
-@pytest.mark.skip_v1
 @multi_gpu_test(num_gpus=4)
 @fork_new_process_for_each_test
 def test_chatglm3_lora_tp4(chatglm3_lora_files):
@@ -97,7 +95,6 @@ def test_chatglm3_lora_tp4(chatglm3_lora_files):
         assert output2[i] == EXPECTED_LORA_OUTPUT[i]
 
 
-@pytest.mark.skip_v1
 @multi_gpu_test(num_gpus=4)
 @fork_new_process_for_each_test
 def test_chatglm3_lora_tp4_fully_sharded_loras(chatglm3_lora_files):
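
Note: the hunks above drop @pytest.mark.skip_v1 from the ChatGLM3 LoRA tests so they also run under the V1 engine. As a rough illustration only (not vLLM's actual conftest code), a marker like skip_v1 is typically enforced by a collection hook along these lines:

# Hypothetical conftest.py sketch: skip any test carrying the "skip_v1"
# marker when the V1 engine is selected. Removing the marker, as this
# commit does, lets the test run under both engines.
import os

import pytest


def pytest_collection_modifyitems(config, items):
    if os.environ.get("VLLM_USE_V1", "0") != "1":
        return  # V0 run: nothing to skip
    skip_v1 = pytest.mark.skip(reason="not yet supported on the V1 engine")
    for item in items:
        if "skip_v1" in item.keywords:
            item.add_marker(skip_v1)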


@@ -10,6 +10,14 @@ from vllm.platforms import current_platform
 
 MODEL_PATH = "mistralai/Mixtral-8x7B-Instruct-v0.1"
 
 
+@pytest.fixture(autouse=True)
+def v1(run_with_both_engines_lora):
+    # Simple autouse wrapper to run both engines for each test
+    # This can be promoted up to conftest.py to run for every
+    # test in a package
+    pass
+
+
 def do_sample(llm: vllm.LLM, lora_path: str, lora_id: int,
               prompts: list[str]) -> list[str]:
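
The do_sample signature above is only shown as diff context. A hedged sketch of what such a helper usually does in these LoRA tests (assumed body, illustrative sampling settings, not the file's actual implementation):

import vllm
from vllm.lora.request import LoRARequest


def do_sample_sketch(llm: vllm.LLM, lora_path: str, lora_id: int,
                     prompts: list[str]) -> list[str]:
    # Greedy decoding keeps outputs deterministic for exact-match asserts.
    sampling_params = vllm.SamplingParams(temperature=0, max_tokens=64)
    outputs = llm.generate(
        prompts,
        sampling_params,
        # lora_id == 0 means "no adapter", so the same helper can also
        # produce base-model outputs for comparison.
        lora_request=LoRARequest(str(lora_id), lora_id, lora_path)
        if lora_id else None,
    )
    return [output.outputs[0].text.strip() for output in outputs]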


@@ -12,6 +12,14 @@ from vllm.lora.request import LoRARequest
 from vllm.platforms import current_platform
 
 
+@pytest.fixture(autouse=True)
+def v1(run_with_both_engines_lora):
+    # Simple autouse wrapper to run both engines for each test
+    # This can be promoted up to conftest.py to run for every
+    # test in a package
+    pass
+
+
 @dataclass
 class TestConfig:
     model_path: str
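
The fixture's comment notes it could be promoted to conftest.py to cover a whole package. A minimal sketch of how a run_with_both_engines_lora-style fixture can be parametrized over both engines (assumed names and mechanism, not vLLM's actual implementation):

# Hypothetical conftest.py sketch: each dependent test runs twice,
# once per engine, by toggling VLLM_USE_V1.
import pytest


@pytest.fixture(params=["0", "1"], ids=["engine_v0", "engine_v1"])
def run_with_both_engines_lora(request, monkeypatch):
    monkeypatch.setenv("VLLM_USE_V1", request.param)
    yield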


@@ -4,6 +4,7 @@ import shutil
 from os import path
 from tempfile import TemporaryDirectory
 
+import pytest
 import torch
 from huggingface_hub import snapshot_download
 from safetensors.torch import load_file, save_file
@@ -21,6 +22,14 @@ VLLM_PLACEHOLDER = "<|reserved_special_token_0|>"
 PROMPT = "Tell me about a Fool's mate move in 20 words. Provide the moves!"
 
 
+@pytest.fixture(autouse=True)
+def v1(run_with_both_engines_lora):
+    # Simple autouse wrapper to run both engines for each test
+    # This can be promoted up to conftest.py to run for every
+    # test in a package
+    pass
+
+
 def llama3_1_8b_chess_lora_path():
     return snapshot_download(
         repo_id="mkopecki/chess-lora-adapter-llama-3.1-8b")
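
For reference, pulling that adapter and peeking at its safetensors weights looks roughly like this; the adapter_model.safetensors filename assumes the standard PEFT layout and is not confirmed by the diff:

# Sketch: download the chess LoRA adapter used above and inspect its weights.
from os import path

from huggingface_hub import snapshot_download
from safetensors.torch import load_file

adapter_dir = snapshot_download(
    repo_id="mkopecki/chess-lora-adapter-llama-3.1-8b")
# Standard PEFT adapters ship an adapter_model.safetensors (assumption here).
state_dict = load_file(path.join(adapter_dir, "adapter_model.safetensors"))
print(sorted(state_dict)[:5])  # a few lora_A / lora_B tensor names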


@@ -3,18 +3,45 @@
 import os
 import random
 import tempfile
+from typing import Union
 from unittest.mock import patch
 
+import pytest
+
+import vllm.envs as envs
 from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig,
                          ModelConfig, ParallelConfig, SchedulerConfig,
                          VllmConfig)
 from vllm.lora.models import LoRAMapping
 from vllm.lora.request import LoRARequest
+from vllm.v1.worker.gpu_worker import Worker as V1Worker
 from vllm.worker.worker import Worker
 
 
+@pytest.fixture(autouse=True)
+def v1(run_with_both_engines_lora):
+    # Simple autouse wrapper to run both engines for each test
+    # This can be promoted up to conftest.py to run for every
+    # test in a package
+    pass
+
+
 @patch.dict(os.environ, {"RANK": "0"})
 def test_worker_apply_lora(sql_lora_files):
+
+    def set_active_loras(worker: Union[Worker, V1Worker],
+                         lora_requests: list[LoRARequest]):
+        lora_mapping = LoRAMapping([], [])
+        if isinstance(worker, Worker):
+            # v0 case
+            worker.model_runner.set_active_loras(lora_requests, lora_mapping)
+        else:
+            # v1 case
+            worker.model_runner.lora_manager.set_active_adapters(
+                lora_requests, lora_mapping)
+
+    worker_cls = V1Worker if envs.VLLM_USE_V1 else Worker
+
     vllm_config = VllmConfig(
         model_config=ModelConfig(
             "meta-llama/Llama-2-7b-hf",
@@ -40,16 +67,17 @@ def test_worker_apply_lora(sql_lora_files):
         lora_config=LoRAConfig(max_lora_rank=8, max_cpu_loras=32,
                                max_loras=32),
     )
-    worker = Worker(
+
+    worker = worker_cls(
         vllm_config=vllm_config,
         local_rank=0,
         rank=0,
         distributed_init_method=f"file://{tempfile.mkstemp()[1]}",
     )
     worker.init_device()
     worker.load_model()
 
-    worker.model_runner.set_active_loras([], LoRAMapping([], []))
+    set_active_loras(worker, [])
     assert worker.list_loras() == set()
 
     n_loras = 32
@@ -57,7 +85,7 @@
         LoRARequest(str(i + 1), i + 1, sql_lora_files) for i in range(n_loras)
     ]
 
-    worker.model_runner.set_active_loras(lora_requests, LoRAMapping([], []))
+    set_active_loras(worker, lora_requests)
     assert worker.list_loras() == {
         lora_request.lora_int_id
         for lora_request in lora_requests
@@ -69,8 +97,7 @@
                                             k=random.randint(1, n_loras))
         random.shuffle(iter_lora_requests)
         iter_lora_requests = iter_lora_requests[:-random.randint(0, n_loras)]
-        worker.model_runner.set_active_loras(iter_lora_requests,
-                                             LoRAMapping([], []))
+        set_active_loras(worker, lora_requests)
         assert worker.list_loras().issuperset(
            {lora_request.lora_int_id
             for lora_request in iter_lora_requests})
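
Pulled out of the diff above, the engine-agnostic pattern the updated test relies on looks like this; it is consolidated here for readability, and the helper is test-local rather than a public API:

from typing import Union

import vllm.envs as envs
from vllm.lora.models import LoRAMapping
from vllm.lora.request import LoRARequest
from vllm.v1.worker.gpu_worker import Worker as V1Worker
from vllm.worker.worker import Worker

# Pick the worker implementation that matches the engine selected by
# VLLM_USE_V1, exactly as the test now does.
worker_cls = V1Worker if envs.VLLM_USE_V1 else Worker


def set_active_loras(worker: Union[Worker, V1Worker],
                     lora_requests: list[LoRARequest]) -> None:
    """Activate adapters through whichever interface the worker exposes."""
    lora_mapping = LoRAMapping([], [])
    if isinstance(worker, Worker):
        # V0: the model runner exposes set_active_loras directly.
        worker.model_runner.set_active_loras(lora_requests, lora_mapping)
    else:
        # V1: adapter activation goes through the model runner's lora_manager.
        worker.model_runner.lora_manager.set_active_adapters(
            lora_requests, lora_mapping)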