[torchao] Add support for ModuleFqnToConfig using regex (#26001)

Signed-off-by: Jerry Zhang <jerryzh168@gmail.com>
This commit is contained in:
Jerry Zhang 2025-10-09 01:32:32 -07:00 committed by GitHub
parent cf4cd6c24f
commit a83ff278d6
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 38 additions and 3 deletions

View File

@@ -233,5 +233,22 @@ def test_opt_125m_float8_weight_only_safetensors_model_loading_with_params(vllm_
assert output
@pytest.mark.skipif(not TORCHAO_AVAILABLE, reason="torchao is not available")
@pytest.mark.skip(
    # NOTE: trailing spaces matter — these adjacent literals are concatenated
    # into a single reason string.
    reason="since torchao nightly is only compatible with torch nightly "
    "currently https://github.com/pytorch/ao/issues/2919, we'll have to skip "
    "torchao tests that require newer versions (0.14.0.dev+) for now"
)
def test_opt_125m_module_fqn_to_config_regex_model(vllm_runner):
    """Smoke-test a checkpoint whose ModuleFqnToConfig keys use the `re:`
    regex syntax (requires torchao >= 0.14.0.dev): the model must load and
    produce non-empty greedy generations.
    """
    # Drop any compiled-graph state left behind by earlier tests so this
    # model is compiled from scratch.
    torch._dynamo.reset()
    model_name = "torchao-testing/opt-125m-ModuleFqnToConfig-v1-regex-0.14.0.dev"
    with vllm_runner(
        model_name=model_name, dtype="bfloat16", pt_load_map_location="cuda:0"
    ) as llm:
        output = llm.generate_greedy(["The capital of France is"], max_tokens=32)
        assert output
if __name__ == "__main__":
    # Allow running this test module directly (``python <file>``) in
    # addition to the usual ``pytest`` CLI invocation.
    pytest.main([__file__])

View File

@@ -5,6 +5,7 @@ import json
from importlib.util import find_spec
from typing import Any, Optional
import regex as re
import torch
import torch.nn.functional as F
from packaging import version
@@ -192,9 +193,26 @@ class TorchAOConfig(QuantizationConfig):
module_fqn = prefix
if isinstance(self.torchao_config, ModuleFqnToConfig):
module_fqn_to_config = self.torchao_config.module_fqn_to_config
c = module_fqn_to_config.get(module_fqn) or module_fqn_to_config.get(
"_default", None
)
c = None
if module_fqn in module_fqn_to_config:
assert not module_fqn.startswith("re:"), (
"module fqn should not start with"
"`re:`, which is used for specifying regex"
)
c = module_fqn_to_config[module_fqn]
else:
for maybe_module_fqn_pattern in module_fqn_to_config:
if not maybe_module_fqn_pattern.startswith("re:"):
continue
elif re.fullmatch(maybe_module_fqn_pattern[3:], module_fqn):
# we'll apply the config for first fully matched pattern
c = module_fqn_to_config[maybe_module_fqn_pattern]
break
else:
# fallback to use default if no module specific
# config is provided
c = module_fqn_to_config.get("_default", None)
if c is not None:
current_torchao_config = TorchAOConfig(
c, self.skip_modules, self.is_checkpoint_torchao_serialized