[Bugfix] [CI] Fix Tensorizer LoRA test (#20760)

Signed-off-by: Sanger Steel <sangersteel@gmail.com>
Author: Sanger Steel
Date: 2025-07-10 15:07:06 -04:00
Commit: 5e53c89a74 (parent c66e38ea4c), committed via GitHub
3 changed files with 14 additions and 19 deletions


@@ -4,8 +4,6 @@ import subprocess
 import sys
 from typing import Union
 
-import pytest
-
 import vllm
 from vllm import LLM
 from vllm.lora.request import LoRARequest
@@ -151,8 +149,6 @@ def test_llama_lora_tp4_fully_sharded_loras(sql_lora_files):
     generate_and_test(llm, sql_lora_files)
 
 
-@pytest.mark.skip(reason=("Skipping this test as tensorizer is not "
-                          "working with LoRA as of #19619"))
 @multi_gpu_test(num_gpus=2)
 @create_new_process_for_each_test()
 def test_tp2_serialize_and_deserialize_lora(tmp_path, sql_lora_files,
@@ -189,7 +185,6 @@ def test_tp2_serialize_and_deserialize_lora(tmp_path, sql_lora_files,
     model_uri = tmp_path / "vllm" / model_ref / suffix / model_name
     tensorizer_config = TensorizerConfig(tensorizer_uri=str(model_uri))
-    tensorizer_config.lora_dir = tensorizer_config.tensorizer_dir
 
     loaded_vllm_model = LLM(model=model_ref,
                             load_format="tensorizer",
@@ -200,16 +195,16 @@
                             tensor_parallel_size=2,
                             max_loras=2)
 
-    tensorizer_config_dict = tensorizer_config.to_serializable()
+    tc_as_dict = tensorizer_config.to_serializable()
 
     print("lora adapter created")
     assert do_sample(loaded_vllm_model,
                      sql_lora_files,
-                     tensorizer_config_dict=tensorizer_config_dict,
+                     tensorizer_config_dict=tc_as_dict,
                      lora_id=0) == EXPECTED_NO_LORA_OUTPUT
 
     print("lora 1")
     assert do_sample(loaded_vllm_model,
                      sql_lora_files,
-                     tensorizer_config_dict=tensorizer_config_dict,
+                     tensorizer_config_dict=tc_as_dict,
                      lora_id=1) == EXPECTED_LORA_OUTPUT
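
For readers following along, a hedged end-to-end sketch of the flow this test exercises: load a tensorizer-serialized model with LoRA enabled, then generate with and without an adapter. The model name, paths, and the `model_loader_extra_config` plumbing are assumptions drawn from the test above, not verified documentation.

```python
from vllm import LLM, SamplingParams
from vllm.lora.request import LoRARequest
from vllm.model_executor.model_loader.tensorizer import TensorizerConfig

# Point tensorizer at a previously serialized model; the path is a placeholder.
tensorizer_config = TensorizerConfig(
    tensorizer_uri="/tmp/vllm/llama/v1.0/model.tensors")

llm = LLM(
    model="meta-llama/Llama-2-7b-hf",  # assumed base model
    load_format="tensorizer",
    model_loader_extra_config=tensorizer_config.to_serializable(),
    enable_lora=True,
    max_loras=2,
    tensor_parallel_size=2,
)

# lora_id 0 means base model only; a nonzero id routes through the adapter,
# mirroring the EXPECTED_NO_LORA_OUTPUT / EXPECTED_LORA_OUTPUT assertions.
outputs = llm.generate(
    ["Write a SQL query that lists all users."],
    SamplingParams(temperature=0, max_tokens=64),
    lora_request=LoRARequest("sql_lora", 1, "/tmp/sql_lora_files"),
)
print(outputs[0].outputs[0].text)
```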


@@ -102,7 +102,7 @@ class PEFTHelper:
             tensorizer_config = TensorizerConfig(**tensorizer_config_dict)
             tensorizer_args = tensorizer_config._construct_tensorizer_args()
             from tensorizer.stream_io import open_stream
-            lora_config_path = os.path.join(tensorizer_config.lora_dir,
+            lora_config_path = os.path.join(tensorizer_config.tensorizer_dir,
                                             "adapter_config.json")
             with open_stream(lora_config_path,
                              mode="rb",
@@ -110,7 +110,7 @@ class PEFTHelper:
                 config = json.load(f)
             logger.info("Successfully deserialized LoRA config from %s",
-                        tensorizer_config.lora_dir)
+                        tensorizer_config.tensorizer_dir)
         else:
             with open(lora_config_path) as f:
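
A minimal sketch of the branch these two hunks edit, assuming the surrounding code matches what the context lines show: the adapter config now resolves against `tensorizer_dir` rather than the removed `lora_dir`, with a plain filesystem read as the non-tensorizer fallback. The helper's shape and the `stream_kwargs` usage are inferred from the diff, not from the full source.

```python
import json
import os

def load_adapter_config(lora_path: str, tensorizer_config=None) -> dict:
    """Read adapter_config.json, via a tensorizer stream when configured."""
    if tensorizer_config is not None:
        from tensorizer.stream_io import open_stream
        tensorizer_args = tensorizer_config._construct_tensorizer_args()
        # After this commit: resolve against tensorizer_dir, not lora_dir.
        lora_config_path = os.path.join(tensorizer_config.tensorizer_dir,
                                        "adapter_config.json")
        with open_stream(lora_config_path, mode="rb",
                         **tensorizer_args.stream_kwargs) as f:
            return json.load(f)
    with open(os.path.join(lora_path, "adapter_config.json")) as f:
        return json.load(f)
```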


@@ -222,17 +222,17 @@ class TensorizerConfig(MutableMapping):
         self._is_sharded = isinstance(self.tensorizer_uri, str) \
             and re.search(r'%0\dd', self.tensorizer_uri) is not None
 
-        if self.tensorizer_dir and self.lora_dir:
-            raise ValueError(
-                "Only one of tensorizer_dir or lora_dir may be specified. "
-                "Use lora_dir exclusively when serializing LoRA adapters, "
-                "and tensorizer_dir or tensorizer_uri otherwise.")
-
         if self.tensorizer_dir and self.tensorizer_uri:
             logger.warning_once(
                 "Provided both tensorizer_dir and tensorizer_uri. "
                 "Inferring tensorizer_dir from tensorizer_uri as the "
                 "latter takes precedence.")
             self.tensorizer_dir = os.path.dirname(self.tensorizer_uri)
 
+        if self.tensorizer_dir and self.lora_dir:
+            raise ValueError(
+                "Only one of tensorizer_dir or lora_dir may be specified. "
+                "Use lora_dir exclusively when serializing LoRA adapters, "
+                "and tensorizer_dir or tensorizer_uri otherwise.")
+
         if not self.tensorizer_uri:
             if self.lora_dir:
                 self.tensorizer_uri = f"{self.lora_dir}/adapter_model.tensors"
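
The relocated exclusivity check encodes the usage rule spelled out in its message. Below is a hedged sketch of those rules as they read from this hunk, with placeholder paths and the assumption that these checks run at construction time:

```python
from vllm.model_executor.model_loader.tensorizer import TensorizerConfig

# Serialization side: lora_dir alone is valid; per the fallback at the end
# of the hunk, tensorizer_uri defaults to f"{lora_dir}/adapter_model.tensors".
lora_cfg = TensorizerConfig(lora_dir="/tmp/adapters/sql_lora")

# Deserialization side: tensorizer_uri (or tensorizer_dir) alone is valid.
model_cfg = TensorizerConfig(tensorizer_uri="/tmp/vllm/llama/model.tensors")

# Mixing the two is rejected.
try:
    TensorizerConfig(tensorizer_dir="/tmp/vllm/llama",
                     lora_dir="/tmp/adapters/sql_lora")
except ValueError as err:
    print(err)  # "Only one of tensorizer_dir or lora_dir may be specified. ..."
```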
@@ -695,7 +695,7 @@ def tensorize_lora_adapter(lora_path: str,
     needed to load a LoRA adapter are a safetensors-format file called
     adapter_model.safetensors and a json config file called adapter_config.json.
 
-    Serializes the files in the tensorizer_config.lora_dir
+    Serializes the files in the tensorizer_config.tensorizer_dir
     """
     import safetensors
@@ -725,13 +725,13 @@ def tensorize_lora_adapter(lora_path: str,
     tensorizer_args = tensorizer_config._construct_tensorizer_args()
 
-    with open_stream(f"{tensorizer_config.lora_dir}/adapter_config.json",
+    with open_stream(f"{tensorizer_config.tensorizer_dir}/adapter_config.json",
                      mode="wb+",
                      **tensorizer_args.stream_kwargs) as f:
         f.write(json.dumps(config).encode("utf-8"))
 
-    lora_uri = (f"{tensorizer_config.lora_dir}"
+    lora_uri = (f"{tensorizer_config.tensorizer_dir}"
                 f"/adapter_model.tensors")
     with open_stream(lora_uri, mode="wb+",
                      **tensorizer_args.stream_kwargs) as f:
@@ -740,4 +740,4 @@ def tensorize_lora_adapter(lora_path: str,
     serializer.close()
 
     logger.info("Successfully serialized LoRA files to %s",
-                str(tensorizer_config.lora_dir))
+                str(tensorizer_config.tensorizer_dir))
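
Hypothetical usage of the serialization path edited above. The hunk headers only show the first parameter of `tensorize_lora_adapter`; the second argument and the resulting file layout are assumptions inferred from the function body in this diff:

```python
from vllm.model_executor.model_loader.tensorizer import (
    TensorizerConfig, tensorize_lora_adapter)

# lora_dir drives where the serialized adapter lands; path is a placeholder.
config = TensorizerConfig(lora_dir="/tmp/serialized/sql_lora")

# Serialize a local PEFT adapter (adapter_config.json + safetensors weights).
tensorize_lora_adapter("/path/to/peft/adapter", config)
```

Afterwards, `tensorizer_dir` should contain `adapter_config.json` and `adapter_model.tensors`, which the PEFTHelper branch in this commit reads back via `tensorizer_config.tensorizer_dir`.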