[LoRA] Cleanup LoRA unused code (#29611)

Signed-off-by: Jee Jee Li <pandaleefree@gmail.com>
Co-authored-by: Cyrus Leung <tlleungac@connect.ust.hk>
This commit is contained in:
Jee Jee Li 2025-11-29 14:52:58 +08:00 committed by GitHub
parent 4a80ad0a25
commit 39e63dec7c
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
46 changed files with 126 additions and 173 deletions

View File

@ -46,7 +46,6 @@ def create_test_prompts(
logprobs=1,
prompt_logprobs=1,
max_tokens=128,
stop_token_ids=[32003],
),
LoRARequest("sql-lora", 1, lora_path),
),
@ -57,7 +56,6 @@ def create_test_prompts(
logprobs=1,
prompt_logprobs=1,
max_tokens=128,
stop_token_ids=[32003],
),
LoRARequest("sql-lora2", 2, lora_path),
),
@ -98,7 +96,7 @@ def initialize_engine() -> LLMEngine:
# use the same rank, it is recommended to set this as low as possible.
# max_cpu_loras: controls the size of the CPU LoRA cache.
engine_args = EngineArgs(
model="meta-llama/Llama-2-7b-hf",
model="meta-llama/Llama-3.2-3B-Instruct",
enable_lora=True,
max_loras=1,
max_lora_rank=8,
@ -111,7 +109,7 @@ def initialize_engine() -> LLMEngine:
def main():
"""Main function that sets up and runs the prompt processing."""
engine = initialize_engine()
lora_path = snapshot_download(repo_id="yard1/llama-2-7b-sql-lora-test")
lora_path = snapshot_download(repo_id="jeeejeee/llama32-3b-text2sql-spider")
test_prompts = create_test_prompts(lora_path)
process_requests(engine, test_prompts)

View File

@ -188,11 +188,11 @@ number: "1" | "2"
@pytest.fixture(scope="session")
def zephyr_lora_files():
"""Download zephyr LoRA files once per test session."""
def qwen3_lora_files():
"""Download Qwen3 LoRA files once per test session."""
from huggingface_hub import snapshot_download
return snapshot_download(repo_id="typeof/zephyr-7b-beta-lora")
return snapshot_download(repo_id="charent/self_cognition_Alice")
@pytest.fixture(scope="session")

View File

@ -16,7 +16,7 @@ from vllm.version import __version__ as VLLM_VERSION
from ...utils import RemoteOpenAIServer
MODEL_NAME = "HuggingFaceH4/zephyr-7b-beta"
MODEL_NAME = "Qwen/Qwen3-0.6B"
@pytest.fixture(scope="module")

View File

@ -19,6 +19,14 @@ from ...utils import RemoteOpenAIServer
MODEL_NAME = "HuggingFaceH4/zephyr-7b-beta"
@pytest.fixture(scope="module")
def zephyr_lora_files():
"""Download zephyr LoRA files once per test session."""
from huggingface_hub import snapshot_download
return snapshot_download(repo_id="typeof/zephyr-7b-beta-lora")
@pytest.fixture(scope="module")
def server(zephyr_lora_files): # noqa: F811
args = [

View File

@ -8,7 +8,7 @@ import pytest_asyncio
from ...utils import RemoteOpenAIServer
# any model with a chat template should work here
MODEL_NAME = "HuggingFaceH4/zephyr-7b-beta"
MODEL_NAME = "Qwen/Qwen3-0.6B"
@pytest.fixture(scope="module")
@ -20,7 +20,6 @@ def server():
"--max-model-len",
"8192",
"--enforce-eager",
# lora config below
"--max-num-seqs",
"128",
"--enable-chunked-prefill",

View File

@ -13,9 +13,8 @@ import pytest_asyncio
from ...utils import RemoteOpenAIServer
# any model with a chat template should work here
MODEL_NAME = "HuggingFaceH4/zephyr-7b-beta"
# technically this needs Mistral-7B-v0.1 as base, but we're not testing
# generation quality here
MODEL_NAME = "Qwen/Qwen3-0.6B"
BADREQUEST_CASES = [
(
@ -33,11 +32,11 @@ BADREQUEST_CASES = [
@pytest.fixture(scope="module", params=[True])
def server_with_lora_modules_json(request, zephyr_lora_files):
def server_with_lora_modules_json(request, qwen3_lora_files):
# Define the json format LoRA module configurations
lora_module_1 = {
"name": "zephyr-lora",
"path": zephyr_lora_files,
"name": "qwen3-lora",
"path": qwen3_lora_files,
"base_model_name": MODEL_NAME,
}
@ -74,7 +73,7 @@ async def client(server_with_lora_modules_json):
@pytest.mark.asyncio
async def test_static_lora_lineage(client: openai.AsyncOpenAI, zephyr_lora_files):
async def test_static_lora_lineage(client: openai.AsyncOpenAI, qwen3_lora_files):
models = await client.models.list()
models = models.data
served_model = models[0]
@ -82,17 +81,17 @@ async def test_static_lora_lineage(client: openai.AsyncOpenAI, zephyr_lora_files
assert served_model.id == MODEL_NAME
assert served_model.root == MODEL_NAME
assert served_model.parent is None
assert all(lora_model.root == zephyr_lora_files for lora_model in lora_models)
assert all(lora_model.root == qwen3_lora_files for lora_model in lora_models)
assert all(lora_model.parent == MODEL_NAME for lora_model in lora_models)
assert lora_models[0].id == "zephyr-lora"
assert lora_models[0].id == "qwen3-lora"
@pytest.mark.asyncio
async def test_dynamic_lora_lineage(client: openai.AsyncOpenAI, zephyr_lora_files):
async def test_dynamic_lora_lineage(client: openai.AsyncOpenAI, qwen3_lora_files):
response = await client.post(
"load_lora_adapter",
cast_to=str,
body={"lora_name": "zephyr-lora-3", "lora_path": zephyr_lora_files},
body={"lora_name": "qwen3-lora-3", "lora_path": qwen3_lora_files},
)
# Ensure adapter loads before querying /models
assert "success" in response
@ -100,9 +99,9 @@ async def test_dynamic_lora_lineage(client: openai.AsyncOpenAI, zephyr_lora_file
models = await client.models.list()
models = models.data
dynamic_lora_model = models[-1]
assert dynamic_lora_model.root == zephyr_lora_files
assert dynamic_lora_model.root == qwen3_lora_files
assert dynamic_lora_model.parent == MODEL_NAME
assert dynamic_lora_model.id == "zephyr-lora-3"
assert dynamic_lora_model.id == "qwen3-lora-3"
@pytest.mark.asyncio
@ -134,7 +133,7 @@ async def test_dynamic_lora_invalid_files(client: openai.AsyncOpenAI, tmp_path):
async def test_dynamic_lora_badrequests(
client: openai.AsyncOpenAI,
tmp_path,
zephyr_lora_files,
qwen3_lora_files,
test_name: str,
config_change: dict,
expected_error: str,
@ -143,7 +142,7 @@ async def test_dynamic_lora_badrequests(
test_dir = tmp_path / test_name
# Copy adapter files
shutil.copytree(zephyr_lora_files, test_dir)
shutil.copytree(qwen3_lora_files, test_dir)
# Load and modify configuration
config_path = test_dir / "adapter_config.json"
@ -167,7 +166,7 @@ async def test_dynamic_lora_badrequests(
@pytest.mark.asyncio
async def test_multiple_lora_adapters(
client: openai.AsyncOpenAI, tmp_path, zephyr_lora_files
client: openai.AsyncOpenAI, tmp_path, qwen3_lora_files
):
"""Validate that many loras can be dynamically registered and inferenced
with concurrently"""
@ -178,7 +177,7 @@ async def test_multiple_lora_adapters(
await client.post(
"load_lora_adapter",
cast_to=str,
body={"lora_name": adapter_name, "lora_path": str(zephyr_lora_files)},
body={"lora_name": adapter_name, "lora_path": str(qwen3_lora_files)},
)
for _ in range(3):
await client.completions.create(
@ -199,7 +198,7 @@ async def test_multiple_lora_adapters(
@pytest.mark.asyncio
async def test_loading_invalid_adapters_does_not_break_others(
client: openai.AsyncOpenAI, tmp_path, zephyr_lora_files
client: openai.AsyncOpenAI, tmp_path, qwen3_lora_files
):
invalid_files = tmp_path / "invalid_files"
invalid_files.mkdir()
@ -215,7 +214,7 @@ async def test_loading_invalid_adapters_does_not_break_others(
while not stop_good_requests_event.is_set():
try:
batch = await client.completions.create(
model="zephyr-lora",
model="qwen3-lora",
prompt=["Hello there", "Foo bar bazz buzz"],
max_tokens=5,
)
@ -254,7 +253,7 @@ async def test_loading_invalid_adapters_does_not_break_others(
await client.post(
"load_lora_adapter",
cast_to=str,
body={"lora_name": "valid", "lora_path": zephyr_lora_files},
body={"lora_name": "valid", "lora_path": qwen3_lora_files},
)
await client.completions.create(
model="valid",
@ -267,7 +266,7 @@ async def test_loading_invalid_adapters_does_not_break_others(
async def test_beam_search_with_lora_adapters(
client: openai.AsyncOpenAI,
tmp_path,
zephyr_lora_files,
qwen3_lora_files,
):
"""Validate that async beam search can be used with lora."""
@ -275,7 +274,7 @@ async def test_beam_search_with_lora_adapters(
await client.post(
"load_lora_adapter",
cast_to=str,
body={"lora_name": adapter_name, "lora_path": str(zephyr_lora_files)},
body={"lora_name": adapter_name, "lora_path": str(qwen3_lora_files)},
)
for _ in range(3):
await client.completions.create(

View File

@ -8,13 +8,13 @@ import pytest_asyncio
from ...utils import RemoteOpenAIServer
# any model with a chat template should work here
MODEL_NAME = "HuggingFaceH4/zephyr-7b-beta"
MODEL_NAME = "Qwen/Qwen3-0.6B"
# technically this needs Mistral-7B-v0.1 as base, but we're not testing
# generation quality here
@pytest.fixture(scope="module")
def server(zephyr_lora_files):
def server(qwen3_lora_files):
args = [
# use half precision for speed and memory savings in CI environment
"--dtype",
@ -25,7 +25,7 @@ def server(zephyr_lora_files):
# lora config below
"--enable-lora",
"--lora-modules",
f"zephyr-lora={zephyr_lora_files}",
f"qwen3-lora={qwen3_lora_files}",
"--max-lora-rank",
"64",
"--max-cpu-loras",
@ -45,12 +45,12 @@ async def client(server):
@pytest.mark.asyncio
async def test_check_models(client: openai.AsyncOpenAI, zephyr_lora_files):
async def test_check_models(client: openai.AsyncOpenAI, qwen3_lora_files):
models = await client.models.list()
models = models.data
served_model = models[0]
lora_models = models[1:]
assert served_model.id == MODEL_NAME
assert served_model.root == MODEL_NAME
assert all(lora_model.root == zephyr_lora_files for lora_model in lora_models)
assert lora_models[0].id == "zephyr-lora"
assert all(lora_model.root == qwen3_lora_files for lora_model in lora_models)
assert lora_models[0].id == "qwen3-lora"

View File

@ -8,7 +8,7 @@ import pytest_asyncio
from ...utils import RemoteOpenAIServer
# any model with a chat template should work here
MODEL_NAME = "HuggingFaceH4/zephyr-7b-beta"
MODEL_NAME = "Qwen/Qwen3-0.6B"
@pytest.fixture(scope="module")
@ -110,8 +110,9 @@ async def test_single_completion(client: openai.AsyncOpenAI):
choice = completion.choices[0]
assert len(choice.text) >= 5
assert choice.finish_reason == "length"
# When using Qwen3-0.6B, prompt tokens=[9707, 11, 847, 829, 374]
assert completion.usage == openai.types.CompletionUsage(
completion_tokens=5, prompt_tokens=6, total_tokens=11
completion_tokens=5, prompt_tokens=5, total_tokens=10
)
# test using token IDs

View File

@ -11,11 +11,11 @@ from vllm.transformers_utils.tokenizer import get_tokenizer
from ...utils import RemoteOpenAIServer
MODEL_NAME = "HuggingFaceH4/zephyr-7b-beta"
MODEL_NAME = "Qwen/Qwen3-0.6B"
@pytest.fixture(scope="module")
def default_server_args(zephyr_lora_files):
def default_server_args(qwen3_lora_files):
return [
# use half precision for speed and memory savings in CI environment
"--dtype",
@ -28,7 +28,7 @@ def default_server_args(zephyr_lora_files):
# lora config
"--enable-lora",
"--lora-modules",
f"zephyr-lora={zephyr_lora_files}",
f"qwen3-lora={qwen3_lora_files}",
"--max-lora-rank",
"64",
"--max-cpu-loras",

View File

@ -10,7 +10,7 @@ from vllm.transformers_utils.tokenizer import get_tokenizer
from ...utils import RemoteOpenAIServer
# any model with a chat template should work here
MODEL_NAME = "HuggingFaceH4/zephyr-7b-beta"
MODEL_NAME = "Qwen/Qwen3-0.6B"
@pytest.fixture(scope="module")

View File

@ -10,7 +10,7 @@ from vllm.version import __version__ as VLLM_VERSION
from ...utils import RemoteOpenAIServer
MODEL_NAME = "HuggingFaceH4/zephyr-7b-beta"
MODEL_NAME = "Qwen/Qwen3-0.6B"
@pytest.fixture(scope="module")

View File

@ -9,7 +9,6 @@ import pytest_asyncio
from ...utils import RemoteOpenAIServer
# Model name constants used across tests
MODEL_NAME_ZEPHYR = "HuggingFaceH4/zephyr-7b-beta"
MODEL_NAME_SMOLLM = "HuggingFaceTB/SmolLM2-135M-Instruct"
LORA_ADAPTER_NAME_SMOLLM = "jekunz/smollm-135m-lora-fineweb-faroese"

View File

@ -154,23 +154,6 @@ def dummy_model_gate_up() -> nn.Module:
return model
@pytest.fixture(scope="session")
def llama_2_7b_base_huggingface_id():
# used as a base model for testing with sql lora adapter
return "meta-llama/Llama-2-7b-hf"
@pytest.fixture(scope="session")
def sql_lora_huggingface_id():
# huggingface repo id is used to test lora runtime downloading.
return "yard1/llama-2-7b-sql-lora-test"
@pytest.fixture(scope="session")
def sql_lora_files(sql_lora_huggingface_id):
return snapshot_download(repo_id=sql_lora_huggingface_id)
@pytest.fixture(scope="session")
def mixtral_lora_files():
# Note: this module has incorrect adapter_config.json to test
@ -256,8 +239,14 @@ def qwen3_lora_files():
@pytest.fixture(scope="session")
def llama32_lora_files():
return snapshot_download(repo_id="jeeejeee/llama32-3b-text2sql-spider")
def llama32_lora_huggingface_id():
# huggingface repo id is used to test lora runtime downloading.
return "jeeejeee/llama32-3b-text2sql-spider"
@pytest.fixture(scope="session")
def llama32_lora_files(llama32_lora_huggingface_id):
return snapshot_download(repo_id=llama32_lora_huggingface_id)
@pytest.fixture

View File

@ -26,8 +26,7 @@ def test_load_checkpoints(
chatglm3_lora_files,
):
packed_modules_mapping = BaiChuanBaseForCausalLM.packed_modules_mapping
embedding_modules = BaiChuanBaseForCausalLM.embedding_modules
embed_padding_modules = BaiChuanBaseForCausalLM.embedding_padding_modules
expected_lora_lst: list[str] = []
for module in BAICHUAN_LORA_MODULES:
if module in packed_modules_mapping:
@ -47,8 +46,7 @@ def test_load_checkpoints(
peft_helper=peft_helper,
lora_model_id=1,
device="cpu",
embedding_modules=embedding_modules,
embedding_padding_modules=embed_padding_modules,
model_vocab_size=64000,
)
elif lora_name == "baichuan7B-zero":
# Test that the target_modules contain prefix
@ -63,8 +61,7 @@ def test_load_checkpoints(
peft_helper=peft_helper,
lora_model_id=1,
device="cpu",
embedding_modules=embedding_modules,
embedding_padding_modules=embed_padding_modules,
model_vocab_size=64000,
)
elif lora_name == "baichuan7B-zero-regex":
# Test that the `target_modules` in the form of regular expressions,
@ -78,8 +75,7 @@ def test_load_checkpoints(
peft_helper=peft_helper,
lora_model_id=1,
device="cpu",
embedding_modules=embedding_modules,
embedding_padding_modules=embed_padding_modules,
model_vocab_size=64000,
)
else:
# For the baichuan7B model, load chatglm3-6b's LoRA,
@ -95,15 +91,13 @@ def test_load_checkpoints(
peft_helper=peft_helper,
lora_model_id=1,
device="cpu",
embedding_modules=embedding_modules,
embedding_padding_modules=embed_padding_modules,
model_vocab_size=64000,
)
def test_lora_weights_mapping(baichuan_lora_files):
packed_modules_mapping = BaiChuanBaseForCausalLM.packed_modules_mapping
embedding_modules = BaiChuanBaseForCausalLM.embedding_modules
embed_padding_modules = BaiChuanBaseForCausalLM.embedding_padding_modules
expected_lora_lst: list[str] = []
for module in BAICHUAN_LORA_MODULES:
if module in packed_modules_mapping:
@ -128,8 +122,7 @@ def test_lora_weights_mapping(baichuan_lora_files):
peft_helper=peft_helper,
lora_model_id=1,
device="cpu",
embedding_modules=embedding_modules,
embedding_padding_modules=embed_padding_modules,
model_vocab_size=64000,
weights_mapper=hf_to_vllm_mapper,
)
for name in lora_model.loras:

View File

@ -6,10 +6,10 @@ import pytest
from vllm.lora.models import LoRAModel
from vllm.lora.peft_helper import PEFTHelper
from vllm.lora.utils import get_adapter_absolute_path
from vllm.model_executor.models.llama import LlamaForCausalLM
from vllm.model_executor.models.qwen3 import Qwen3ForCausalLM
# Provide absolute path and huggingface lora ids
lora_fixture_name = ["sql_lora_files", "sql_lora_huggingface_id"]
lora_fixture_name = ["llama32_lora_files", "llama32_lora_huggingface_id"]
LLAMA_LORA_MODULES = [
"qkv_proj",
"o_proj",
@ -23,9 +23,8 @@ LLAMA_LORA_MODULES = [
@pytest.mark.parametrize("lora_fixture_name", lora_fixture_name)
def test_load_checkpoints_from_huggingface(lora_fixture_name, request):
lora_name = request.getfixturevalue(lora_fixture_name)
packed_modules_mapping = LlamaForCausalLM.packed_modules_mapping
embedding_modules = LlamaForCausalLM.embedding_modules
embed_padding_modules = LlamaForCausalLM.embedding_padding_modules
packed_modules_mapping = Qwen3ForCausalLM.packed_modules_mapping
expected_lora_lst: list[str] = []
for module in LLAMA_LORA_MODULES:
if module in packed_modules_mapping:
@ -43,8 +42,6 @@ def test_load_checkpoints_from_huggingface(lora_fixture_name, request):
peft_helper=peft_helper,
lora_model_id=1,
device="cpu",
embedding_modules=embedding_modules,
embedding_padding_modules=embed_padding_modules,
)
# Assertions to ensure the model is loaded correctly

View File

@ -34,7 +34,6 @@ EMBEDDING_MODULES = {
"lm_head": "output_embeddings",
}
EMBEDDING_PADDING_MODULES = ["lm_head"]
DEVICES = (
[f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2)]
@ -46,24 +45,22 @@ DEFAULT_DTYPE = torch.get_default_dtype()
@pytest.mark.parametrize("device", DEVICES)
def test_from_lora_tensors(sql_lora_files, device):
tensors = load_file(os.path.join(sql_lora_files, "adapter_model.safetensors"))
def test_from_lora_tensors(qwen3_lora_files, device):
tensors = load_file(os.path.join(qwen3_lora_files, "adapter_model.safetensors"))
peft_helper = PEFTHelper.from_local_dir(
sql_lora_files, max_position_embeddings=4096
qwen3_lora_files, max_position_embeddings=4096
)
lora_model = LoRAModel.from_lora_tensors(
1,
tensors,
peft_helper=peft_helper,
device=device,
embedding_modules=EMBEDDING_MODULES,
embedding_padding_modules=EMBEDDING_PADDING_MODULES,
)
for module_name, lora in lora_model.loras.items():
assert lora.module_name == module_name
assert lora.rank == 8
assert lora.lora_alpha == 16
assert lora.lora_alpha == 32
assert lora.lora_a is not None
assert lora.lora_b is not None
assert lora.lora_a.device == torch.device(device)
@ -430,7 +427,7 @@ def test_lru_cache_worker_adapter_manager(dist_init, dummy_model, device, tmp_pa
vllm_config.scheduler_config.max_num_seqs = 4
vllm_config.scheduler_config.max_num_batched_tokens = 2
worker_adapter_manager = LRUCacheWorkerLoRAManager(
vllm_config, device, EMBEDDING_MODULES, EMBEDDING_PADDING_MODULES
vllm_config, device, EMBEDDING_MODULES
)
worker_adapter_manager.max_num_seqs = 4
@ -533,9 +530,7 @@ def test_worker_adapter_manager(dist_init, dummy_model_gate_up, device, tmp_path
vllm_config.scheduler_config.max_num_seqs = 4
vllm_config.scheduler_config.max_num_batched_tokens = 2
worker_adapter_manager = WorkerLoRAManager(
vllm_config, device, EMBEDDING_MODULES, EMBEDDING_PADDING_MODULES
)
worker_adapter_manager = WorkerLoRAManager(vllm_config, device, EMBEDDING_MODULES)
worker_adapter_manager.vocab_size = dummy_model_gate_up.unpadded_vocab_size
worker_adapter_manager.create_lora_manager(dummy_model_gate_up)

View File

@ -40,7 +40,10 @@ EXPECTED_BASE_MODEL_OUTPUT = [
def generate_and_test(
llm: vllm.LLM, lora_path: str, lora_id: list[int | None] | int | None
llm: vllm.LLM,
lora_path: str,
lora_id: list[int | None] | int | None,
compare_lower: bool = False,
) -> None:
prompts = [
PROMPT_TEMPLATE.format(context="How many candidates are there?"),
@ -74,12 +77,18 @@ def generate_and_test(
for i in range(len(EXPECTED_LORA_OUTPUT)):
req_lora_id = lora_id[i] if isinstance(lora_id, list) else lora_id
generated_text = generated_texts[i]
expected_output = (
EXPECTED_LORA_OUTPUT[i]
if req_lora_id is not None
else EXPECTED_BASE_MODEL_OUTPUT[i]
)
assert generated_texts[i].startswith(expected_output)
if compare_lower:
generated_text = generated_text.lower()
expected_output = expected_output.lower()
assert generated_text.startswith(expected_output)
def test_olmoe_lora(olmoe_lora_files):
@ -146,6 +155,9 @@ def test_olmoe_lora_tp4(olmoe_lora_files, fully_sharded_loras):
tensor_parallel_size=4,
fully_sharded_loras=fully_sharded_loras,
)
generate_and_test(llm, olmoe_lora_files, lora_id=1)
generate_and_test(llm, olmoe_lora_files, lora_id=2)
generate_and_test(
llm, olmoe_lora_files, lora_id=1, compare_lower=fully_sharded_loras
)
generate_and_test(
llm, olmoe_lora_files, lora_id=2, compare_lower=fully_sharded_loras
)

View File

@ -25,31 +25,33 @@ ERROR_CASES = [
]
def test_peft_helper_pass(sql_lora_files, tmp_path):
def test_peft_helper_pass(llama32_lora_files, tmp_path):
peft_helper = PEFTHelper.from_local_dir(
sql_lora_files, max_position_embeddings=4096
llama32_lora_files, max_position_embeddings=4096
)
lora_config = LoRAConfig(max_lora_rank=16, max_cpu_loras=3, max_loras=2)
peft_helper.validate_legal(lora_config)
assert peft_helper.r == 8
assert peft_helper.lora_alpha == 16
assert peft_helper.target_modules == [
"q_proj",
"v_proj",
"k_proj",
"o_proj",
"gate_proj",
"up_proj",
assert peft_helper.lora_alpha == 32
target_modules = sorted(peft_helper.target_modules)
assert target_modules == [
"down_proj",
"embed_tokens",
"gate_proj",
"k_proj",
"lm_head",
"o_proj",
"q_proj",
"up_proj",
"v_proj",
]
assert peft_helper.vllm_max_position_embeddings == 4096
# test RSLoRA
rslora_config = dict(use_rslora=True)
test_dir = tmp_path / "test_rslora"
shutil.copytree(sql_lora_files, test_dir)
shutil.copytree(llama32_lora_files, test_dir)
# Load and modify configuration
config_path = test_dir / "adapter_config.json"
@ -70,14 +72,14 @@ def test_peft_helper_pass(sql_lora_files, tmp_path):
@pytest.mark.parametrize("test_name,config_change,expected_error", ERROR_CASES)
def test_peft_helper_error(
sql_lora_files,
llama32_lora_files,
tmp_path,
test_name: str,
config_change: dict,
expected_error: str,
):
test_dir = tmp_path / test_name
shutil.copytree(sql_lora_files, test_dir)
shutil.copytree(llama32_lora_files, test_dir)
# Load and modify configuration
config_path = test_dir / "adapter_config.json"

View File

@ -8,8 +8,8 @@ from huggingface_hub import snapshot_download
from vllm.plugins.lora_resolvers.filesystem_resolver import FilesystemResolver
MODEL_NAME = "mistralai/Mistral-7B-v0.1"
LORA_NAME = "typeof/zephyr-7b-beta-lora"
MODEL_NAME = "Qwen/Qwen3-0.6B"
LORA_NAME = "charent/self_cognition_Alice"
PA_NAME = "swapnilbp/llama_tweet_ptune"
@ -21,7 +21,7 @@ def adapter_cache(request, tmpdir_factory):
@pytest.fixture(scope="module")
def zephyr_lora_files():
def qwen3_lora_files():
return snapshot_download(repo_id=LORA_NAME)
@ -31,9 +31,9 @@ def pa_files():
@pytest.mark.asyncio
async def test_filesystem_resolver(adapter_cache, zephyr_lora_files):
async def test_filesystem_resolver(adapter_cache, qwen3_lora_files):
model_files = adapter_cache / LORA_NAME
shutil.copytree(zephyr_lora_files, model_files)
shutil.copytree(qwen3_lora_files, model_files)
fs_resolver = FilesystemResolver(adapter_cache)
assert fs_resolver is not None

View File

@ -103,7 +103,6 @@ class LoRAModel:
def check_lora_name(self, lora_name: str) -> bool:
return lora_name in self.loras
# (yard1): TODO see if we can derive target_embedding_padding automatically
@classmethod
def from_lora_tensors(
cls,
@ -112,9 +111,7 @@ class LoRAModel:
peft_helper: PEFTHelper,
device: str = "cuda",
dtype: torch.dtype | None = None,
target_embedding_padding: int | None = None,
embedding_modules: dict[str, str] | None = None,
embedding_padding_modules: list[str] | None = None,
model_vocab_size: int | None = None,
weights_mapper: WeightsMapper | None = None,
) -> "LoRAModel":
"""Create a LoRAModel from a dictionary of tensors."""
@ -132,22 +129,21 @@ class LoRAModel:
)
if is_lora_a:
if (
"lora_embedding_A" in tensor_name
and model_vocab_size is not None
and model_vocab_size != tensor.shape[1]
):
raise RuntimeError(
f"The embedding LoRA size({tensor.shape[1]}) must be consistent"
f" with the base model's vocabulary size({model_vocab_size})."
)
loras[module_name].lora_a = tensor.to(device=device, dtype=dtype)
if pin_memory:
loras[module_name].lora_a = loras[module_name].lora_a.pin_memory()
else:
loras[module_name].lora_b = tensor.to(device=device, dtype=dtype)
assert embedding_padding_modules is not None
if (
any(name in module_name for name in embedding_padding_modules)
and target_embedding_padding is not None
):
lora_b = loras[module_name].lora_b
assert target_embedding_padding >= lora_b.shape[0]
addition = target_embedding_padding - lora_b.shape[0]
loras[module_name].lora_b = torch.nn.functional.pad(
lora_b, (0, 0, 0, addition)
)
if pin_memory:
loras[module_name].lora_b = loras[module_name].lora_b.pin_memory()
@ -163,9 +159,7 @@ class LoRAModel:
lora_model_id: int | None = None,
device: str = "cuda",
dtype: torch.dtype | None = None,
target_embedding_padding: int | None = None,
embedding_modules: dict[str, str] | None = None,
embedding_padding_modules: list[str] | None = None,
model_vocab_size: int | None = None,
weights_mapper: WeightsMapper | None = None,
tensorizer_config_dict: dict | None = None,
) -> "LoRAModel":
@ -287,9 +281,7 @@ class LoRAModel:
peft_helper=peft_helper,
device=device,
dtype=dtype,
target_embedding_padding=target_embedding_padding,
embedding_modules=embedding_modules,
embedding_padding_modules=embedding_padding_modules,
model_vocab_size=model_vocab_size,
weights_mapper=weights_mapper,
)

View File

@ -34,12 +34,10 @@ class WorkerLoRAManager:
vllm_config: VllmConfig,
device: torch.device,
embedding_modules: dict[str, str],
embedding_padding_modules: list[str],
lora_model_cls: type[LoRAModel] = LoRAModel,
):
self._lora_model_cls = lora_model_cls
self.embedding_modules = embedding_modules
self.embedding_padding_modules = embedding_padding_modules
self._cached_dummy_lora: None | Literal[False] | LoRAModel = False
self.max_num_seqs = vllm_config.scheduler_config.max_num_seqs
self.max_num_batched_tokens = (
@ -121,9 +119,7 @@ class WorkerLoRAManager:
lora_model_id=lora_request.lora_int_id,
device="cpu",
dtype=self.lora_config.lora_dtype,
target_embedding_padding=self.vocab_size,
embedding_modules=self.embedding_modules,
embedding_padding_modules=self.embedding_padding_modules,
model_vocab_size=self.vocab_size,
tensorizer_config_dict=lora_request.tensorizer_config_dict,
weights_mapper=hf_to_vllm_mapper,
)

View File

@ -482,7 +482,6 @@ class ApertusForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
"embed_tokens": "input_embeddings",
"lm_head": "output_embeddings",
}
embedding_padding_modules = ["lm_head"]
def __init__(
self,

View File

@ -419,7 +419,6 @@ class BambaForCausalLM(
"embed_tokens": "input_embeddings",
"lm_head": "output_embeddings",
}
embedding_padding_modules = ["lm_head"]
@classmethod
def get_mamba_state_dtype_from_config(

View File

@ -457,7 +457,6 @@ class ExaoneForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
"wte": "input_embeddings",
"lm_head": "output_embeddings",
}
embedding_padding_modules = ["lm_head"]
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()

View File

@ -450,7 +450,6 @@ class Exaone4ForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
"embed_tokens": "input_embeddings",
"lm_head": "output_embeddings",
}
embedding_padding_modules = ["lm_head"]
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()

View File

@ -510,7 +510,6 @@ class FalconH1ForCausalLM(
"embed_tokens": "input_embeddings",
"lm_head": "output_embeddings",
}
embedding_padding_modules = ["lm_head"]
@classmethod
def get_mamba_state_dtype_from_config(

View File

@ -400,7 +400,6 @@ class GraniteForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
"embed_tokens": "input_embeddings",
"lm_head": "output_embeddings",
}
embedding_padding_modules = ["lm_head"]
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()

View File

@ -497,7 +497,6 @@ class GraniteMoeForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
"embed_tokens": "input_embeddings",
"lm_head": "output_embeddings",
}
embedding_padding_modules = ["lm_head"]
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()

View File

@ -601,7 +601,6 @@ class GraniteMoeHybridForCausalLM(
"embed_tokens": "input_embeddings",
"lm_head": "output_embeddings",
}
embedding_padding_modules = ["lm_head"]
@classmethod
def get_mamba_state_dtype_from_config(

View File

@ -263,7 +263,6 @@ class GraniteMoeSharedForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
"embed_tokens": "input_embeddings",
"lm_head": "output_embeddings",
}
embedding_padding_modules = ["lm_head"]
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()

View File

@ -347,7 +347,6 @@ class SupportsLoRA(Protocol):
# The `embedding_module` and `embedding_padding_modules`
# are empty by default.
embedding_modules: ClassVar[dict[str, str]] = {}
embedding_padding_modules: ClassVar[list[str]] = []
packed_modules_mapping: dict[str, list[str]] = {}
@ -359,7 +358,6 @@ class _SupportsLoRAType(Protocol):
packed_modules_mapping: dict[str, list[str]]
embedding_modules: dict[str, str]
embedding_padding_modules: list[str]
@overload
@ -379,7 +377,6 @@ def supports_lora(
lora_attrs = (
"packed_modules_mapping",
"embedding_modules",
"embedding_padding_modules",
)
missing_attrs = tuple(attr for attr in lora_attrs if not hasattr(model, attr))

View File

@ -480,7 +480,6 @@ class JambaForCausalLM(
"embed_tokens": "input_embeddings",
"lm_head": "output_embeddings",
}
embedding_padding_modules = ["lm_head"]
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
config = vllm_config.model_config.hf_config

View File

@ -422,7 +422,6 @@ class Lfm2ForCausalLM(
"embed_tokens": "input_embeddings",
"lm_head": "output_embeddings",
}
embedding_padding_modules = ["lm_head"]
@classmethod
def get_mamba_state_dtype_from_config(

View File

@ -602,7 +602,6 @@ class Lfm2MoeForCausalLM(
"embed_tokens": "input_embeddings",
"lm_head": "output_embeddings",
}
embedding_padding_modules = ["lm_head"]
@classmethod
def get_mamba_state_dtype_from_config(

View File

@ -528,7 +528,6 @@ class LlamaForCausalLM(
"embed_tokens": "input_embeddings",
"lm_head": "output_embeddings",
}
embedding_padding_modules = ["lm_head"]
# Mistral/Llama models can also be loaded with --load-format mistral
# from consolidated.safetensors checkpoints

View File

@ -568,7 +568,6 @@ class MiniCPMForCausalLM(nn.Module, SupportsLoRA, SupportsPP, SupportsEagle3):
"embed_tokens": "input_embeddings",
"lm_head": "output_embeddings",
}
embedding_padding_modules = ["lm_head"]
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()

View File

@ -305,7 +305,6 @@ class EagleMiniCPMForCausalLM(nn.Module, SupportsLoRA, SupportsPP, SupportsEagle
"embed_tokens": "input_embeddings",
"lm_head": "output_embeddings",
}
embedding_padding_modules = ["lm_head"]
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()

View File

@ -1741,5 +1741,4 @@ class MiniCPMV(MiniCPMVBaseModel, SupportsMultiModal, SupportsLoRA):
# so update values before init is called
cls.packed_modules_mapping.update(instance_cls.packed_modules_mapping)
cls.embedding_modules.update(instance_cls.embedding_modules)
cls.embedding_padding_modules += instance_cls.embedding_padding_modules
return instance_cls(vllm_config=vllm_config, prefix=prefix)

View File

@ -496,7 +496,6 @@ class MixtralForCausalLM(nn.Module, SupportsLoRA, SupportsPP, MixtureOfExperts):
"embed_tokens": "input_embeddings",
"lm_head": "output_embeddings",
}
embedding_padding_modules = ["lm_head"]
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()

View File

@ -439,7 +439,6 @@ class NemotronForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
"embed_tokens": "input_embeddings",
"lm_head": "output_embeddings",
}
embedding_padding_modules = ["lm_head"]
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()

View File

@ -713,7 +713,6 @@ class NemotronHForCausalLM(
"embed_tokens": "input_embeddings",
"lm_head": "output_embeddings",
}
embedding_padding_modules = ["lm_head"]
@classmethod
def get_mamba_state_dtype_from_config(

View File

@ -387,7 +387,6 @@ class DeciLMForCausalLM(nn.Module, SupportsLoRA, SupportsPP, HasNoOps):
"embed_tokens": "input_embeddings",
"lm_head": "output_embeddings",
}
embedding_padding_modules = ["lm_head"]
# Mistral/Llama models can also be loaded with --load-format mistral
# from consolidated.safetensors checkpoints

View File

@ -617,7 +617,6 @@ class PhiMoEForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
"embed_tokens": "input_embeddings",
"lm_head": "output_embeddings",
}
embedding_padding_modules = ["lm_head"]
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()

View File

@ -426,7 +426,6 @@ class SolarForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
"embed_tokens": "input_embeddings",
"lm_head": "output_embeddings",
}
embedding_padding_modules = ["lm_head"]
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()

View File

@ -93,7 +93,6 @@ ALL_ATTENTION_FUNCTIONS["vllm"] = vllm_flash_attention_forward
class Base(nn.Module, VllmModel, SupportsQuant, SupportsLoRA, SupportsPP):
embedding_padding_modules = ["lm_head"]
embedding_modules = ["embed_tokens"] # TODO transformers will have a util to get it
hf_to_vllm_mapper = WeightsMapper(
orig_to_new_prefix={

View File

@ -43,7 +43,6 @@ class LoRAModelRunnerMixin:
vllm_config,
device,
model.embedding_modules,
model.embedding_padding_modules,
)
return self.lora_manager.create_lora_manager(model)