Fix underscores in dict keys passed via CLI (#19030)

Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
This commit is contained in:
Harry Mellor 2025-06-03 15:39:24 +01:00 committed by GitHub
parent 4e68ae5e59
commit 476844d44c
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 21 additions and 3 deletions

View File

@ -259,11 +259,18 @@ def test_dict_args(parser):
"--model-name=something.something",
"--hf-overrides.key1",
"val1",
# Test nesting
"--hf-overrides.key2.key3",
"val2",
"--hf-overrides.key2.key4",
"val3",
# Test = sign
"--hf-overrides.key5=val4",
# Test underscore to dash conversion
"--hf_overrides.key_6",
"val5",
"--hf_overrides.key-7.key_8",
"val6",
]
parsed_args = parser.parse_args(args)
assert parsed_args.model_name == "something.something"
@ -274,6 +281,10 @@ def test_dict_args(parser):
"key4": "val3",
},
"key5": "val4",
"key_6": "val5",
"key-7": {
"key_8": "val6",
},
}

View File

@ -1456,17 +1456,24 @@ class FlexibleArgumentParser(ArgumentParser):
if '--config' in args:
args = self._pull_args_from_config(args)
def repl(match: re.Match) -> str:
"""Replaces underscores with dashes in the matched string."""
return match.group(0).replace("_", "-")
# Everything between the first -- and the first .
pattern = re.compile(r"(?<=--)[^\.]*")
# Convert underscores to dashes and vice versa in argument names
processed_args = []
for arg in args:
if arg.startswith('--'):
if '=' in arg:
key, value = arg.split('=', 1)
key = '--' + key[len('--'):].replace('_', '-')
key = pattern.sub(repl, key, count=1)
processed_args.append(f'{key}={value}')
else:
processed_args.append('--' +
arg[len('--'):].replace('_', '-'))
key = pattern.sub(repl, arg, count=1)
processed_args.append(key)
elif arg.startswith('-O') and arg != '-O' and len(arg) == 2:
# allow -O flag to be used without space, e.g. -O3
processed_args.append('-O')