[Bugfix] Fix handling of Tensorizer arguments for LoadConfig (#20643)

Signed-off-by: Sanger Steel <sangersteel@gmail.com>
This commit is contained in:
Sanger Steel 2025-07-09 11:36:37 -04:00 committed by GitHub
parent efe73d0575
commit 4ac9c33f78
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 21 additions and 52 deletions

View File

@ -103,25 +103,6 @@ def write_keyfile(keyfile_path: str):
f.write(encryption_params.key)
@pytest.mark.skipif(not is_curl_installed(), reason="cURL is not installed")
def test_can_deserialize_s3(vllm_runner):
model_ref = "EleutherAI/pythia-1.4b"
tensorized_path = f"s3://tensorized/{model_ref}/fp16/model.tensors"
with vllm_runner(model_ref,
load_format="tensorizer",
model_loader_extra_config=TensorizerConfig(
tensorizer_uri=tensorized_path,
num_readers=1,
s3_endpoint="object.ord1.coreweave.com",
)) as loaded_hf_model:
deserialized_outputs = loaded_hf_model.generate(
prompts, sampling_params)
# noqa: E501
assert deserialized_outputs
@pytest.mark.skipif(not is_curl_installed(), reason="cURL is not installed")
def test_deserialized_encrypted_vllm_model_has_same_outputs(
model_ref, vllm_runner, tmp_path, model_path):

View File

@ -1003,41 +1003,27 @@ class EngineArgs:
override_attention_dtype=self.override_attention_dtype,
)
def valid_tensorizer_config_provided(self) -> bool:
"""
Checks if a parseable TensorizerConfig was passed to
self.model_loader_extra_config. It first checks if the config passed
is a dict or a TensorizerConfig object directly, and if the latter is
true (by checking that the object has TensorizerConfig's
.to_serializable() method), converts it in to a serializable dict
format
"""
if self.model_loader_extra_config:
if hasattr(self.model_loader_extra_config, "to_serializable"):
self.model_loader_extra_config = (
self.model_loader_extra_config.to_serializable())
for allowed_to_pass in ["tensorizer_uri", "tensorizer_dir"]:
try:
self.model_loader_extra_config[allowed_to_pass]
return False
except KeyError:
pass
return True
def validate_tensorizer_args(self):
from vllm.model_executor.model_loader.tensorizer import (
TensorizerConfig)
for key in self.model_loader_extra_config:
if key in TensorizerConfig._fields:
self.model_loader_extra_config["tensorizer_config"][
key] = self.model_loader_extra_config[key]
def create_load_config(self) -> LoadConfig:
if self.quantization == "bitsandbytes":
self.load_format = "bitsandbytes"
if (self.load_format == "tensorizer"
and self.valid_tensorizer_config_provided()):
logger.info("Inferring Tensorizer args from %s", self.model)
self.model_loader_extra_config = {"tensorizer_dir": self.model}
else:
logger.info(
"Using Tensorizer args from --model-loader-extra-config. "
"Note that you can now simply pass the S3 directory in the "
"model tag instead of providing the JSON string.")
if self.load_format == "tensorizer":
if hasattr(self.model_loader_extra_config, "to_serializable"):
self.model_loader_extra_config = (
self.model_loader_extra_config.to_serializable())
self.model_loader_extra_config["tensorizer_config"] = {}
self.model_loader_extra_config["tensorizer_config"][
"tensorizer_dir"] = self.model
self.validate_tensorizer_args()
return LoadConfig(
load_format=self.load_format,

View File

@ -223,9 +223,11 @@ class TensorizerConfig(MutableMapping):
and re.search(r'%0\dd', self.tensorizer_uri) is not None
if self.tensorizer_dir and self.tensorizer_uri:
raise ValueError(
"Either tensorizer_dir or tensorizer_uri must be provided, "
"not both.")
logger.warning_once(
"Provided both tensorizer_dir and tensorizer_uri. "
"Inferring tensorizer_dir from tensorizer_uri as the "
"latter takes precedence.")
self.tensorizer_dir = os.path.dirname(self.tensorizer_uri)
if self.tensorizer_dir and self.lora_dir:
raise ValueError(
"Only one of tensorizer_dir or lora_dir may be specified. "

View File

@ -43,7 +43,7 @@ class TensorizerLoader(BaseModelLoader):
else:
validate_config(load_config.model_loader_extra_config)
self.tensorizer_config = TensorizerConfig(
**load_config.model_loader_extra_config)
**load_config.model_loader_extra_config["tensorizer_config"])
def _verify_config(self, model_config: ModelConfig,
parallel_config: ParallelConfig):