From 5e53c89a74ff98b7b0a9299f8854d17eb8142e9e Mon Sep 17 00:00:00 2001
From: Sanger Steel
Date: Thu, 10 Jul 2025 15:07:06 -0400
Subject: [PATCH] [Bugfix] [CI] Fix Tensorizer LoRA test (#20760)

Signed-off-by: Sanger Steel
---
 tests/lora/test_llama_tp.py                    | 11 +++--------
 vllm/lora/peft_helper.py                       |  4 ++--
 vllm/model_executor/model_loader/tensorizer.py | 18 +++++++++---------
 3 files changed, 14 insertions(+), 19 deletions(-)

diff --git a/tests/lora/test_llama_tp.py b/tests/lora/test_llama_tp.py
index 9068d3c0e5369..bebf44b6dfd7c 100644
--- a/tests/lora/test_llama_tp.py
+++ b/tests/lora/test_llama_tp.py
@@ -4,8 +4,6 @@
 import subprocess
 import sys
 from typing import Union
-import pytest
-
 import vllm
 from vllm import LLM
 from vllm.lora.request import LoRARequest
@@ -151,8 +149,6 @@ def test_llama_lora_tp4_fully_sharded_loras(sql_lora_files):
     generate_and_test(llm, sql_lora_files)
 
 
-@pytest.mark.skip(reason=("Skipping this test as tensorizer is not "
-                          "working with LoRA as of #19619"))
 @multi_gpu_test(num_gpus=2)
 @create_new_process_for_each_test()
 def test_tp2_serialize_and_deserialize_lora(tmp_path, sql_lora_files,
@@ -189,7 +185,6 @@ def test_tp2_serialize_and_deserialize_lora(tmp_path, sql_lora_files,
     model_uri = tmp_path / "vllm" / model_ref / suffix / model_name
 
     tensorizer_config = TensorizerConfig(tensorizer_uri=str(model_uri))
-    tensorizer_config.lora_dir = tensorizer_config.tensorizer_dir
 
     loaded_vllm_model = LLM(model=model_ref,
                             load_format="tensorizer",
@@ -200,16 +195,16 @@ def test_tp2_serialize_and_deserialize_lora(tmp_path, sql_lora_files,
                             tensor_parallel_size=2,
                             max_loras=2)
 
-    tensorizer_config_dict = tensorizer_config.to_serializable()
+    tc_as_dict = tensorizer_config.to_serializable()
 
     print("lora adapter created")
     assert do_sample(loaded_vllm_model,
                      sql_lora_files,
-                     tensorizer_config_dict=tensorizer_config_dict,
+                     tensorizer_config_dict=tc_as_dict,
                      lora_id=0) == EXPECTED_NO_LORA_OUTPUT
 
     print("lora 1")
     assert do_sample(loaded_vllm_model,
                      sql_lora_files,
-                     tensorizer_config_dict=tensorizer_config_dict,
+                     tensorizer_config_dict=tc_as_dict,
                      lora_id=1) == EXPECTED_LORA_OUTPUT

diff --git a/vllm/lora/peft_helper.py b/vllm/lora/peft_helper.py
index e748a4a889244..24099bf479dec 100644
--- a/vllm/lora/peft_helper.py
+++ b/vllm/lora/peft_helper.py
@@ -102,7 +102,7 @@ class PEFTHelper:
             tensorizer_config = TensorizerConfig(**tensorizer_config_dict)
             tensorizer_args = tensorizer_config._construct_tensorizer_args()
             from tensorizer.stream_io import open_stream
-            lora_config_path = os.path.join(tensorizer_config.lora_dir,
+            lora_config_path = os.path.join(tensorizer_config.tensorizer_dir,
                                             "adapter_config.json")
             with open_stream(lora_config_path,
                              mode="rb",
@@ -110,7 +110,7 @@ class PEFTHelper:
                 config = json.load(f)
 
             logger.info("Successfully deserialized LoRA config from %s",
-                        tensorizer_config.lora_dir)
+                        tensorizer_config.tensorizer_dir)
 
         else:
             with open(lora_config_path) as f:

diff --git a/vllm/model_executor/model_loader/tensorizer.py b/vllm/model_executor/model_loader/tensorizer.py
index d716f60e5ffa2..3d491be3156b6 100644
--- a/vllm/model_executor/model_loader/tensorizer.py
+++ b/vllm/model_executor/model_loader/tensorizer.py
@@ -222,17 +222,17 @@ class TensorizerConfig(MutableMapping):
         self._is_sharded = isinstance(self.tensorizer_uri, str) \
             and re.search(r'%0\dd', self.tensorizer_uri) is not None
 
+        if self.tensorizer_dir and self.lora_dir:
+            raise ValueError(
+                "Only one of tensorizer_dir or lora_dir may be specified. "
+                "Use lora_dir exclusively when serializing LoRA adapters, "
+                "and tensorizer_dir or tensorizer_uri otherwise.")
         if self.tensorizer_dir and self.tensorizer_uri:
             logger.warning_once(
                 "Provided both tensorizer_dir and tensorizer_uri. "
                 "Inferring tensorizer_dir from tensorizer_uri as the "
                 "latter takes precedence.")
             self.tensorizer_dir = os.path.dirname(self.tensorizer_uri)
-        if self.tensorizer_dir and self.lora_dir:
-            raise ValueError(
-                "Only one of tensorizer_dir or lora_dir may be specified. "
-                "Use lora_dir exclusively when serializing LoRA adapters, "
-                "and tensorizer_dir or tensorizer_uri otherwise.")
         if not self.tensorizer_uri:
             if self.lora_dir:
                 self.tensorizer_uri = f"{self.lora_dir}/adapter_model.tensors"
@@ -695,7 +695,7 @@ def tensorize_lora_adapter(lora_path: str,
     needed to load a LoRA adapter are a safetensors-format file called
     adapter_model.safetensors and a json config file called adapter_config.json.
 
-    Serializes the files in the tensorizer_config.lora_dir
+    Serializes the files in the tensorizer_config.tensorizer_dir
     """
 
     import safetensors
@@ -725,13 +725,13 @@ def tensorize_lora_adapter(lora_path: str,
 
     tensorizer_args = tensorizer_config._construct_tensorizer_args()
 
-    with open_stream(f"{tensorizer_config.lora_dir}/adapter_config.json",
+    with open_stream(f"{tensorizer_config.tensorizer_dir}/adapter_config.json",
                      mode="wb+",
                      **tensorizer_args.stream_kwargs) as f:
 
         f.write(json.dumps(config).encode("utf-8"))
 
-    lora_uri = (f"{tensorizer_config.lora_dir}"
+    lora_uri = (f"{tensorizer_config.tensorizer_dir}"
                 f"/adapter_model.tensors")
     with open_stream(lora_uri, mode="wb+",
                      **tensorizer_args.stream_kwargs) as f:
@@ -740,4 +740,4 @@ def tensorize_lora_adapter(lora_path: str,
     serializer.close()
 
     logger.info("Successfully serialized LoRA files to %s",
-                str(tensorizer_config.lora_dir))
+                str(tensorizer_config.tensorizer_dir))