From 090c856d7681f65143fece96f9dfd555c4b7d59b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Nicol=C3=B2=20Lucchesi?= Date: Mon, 7 Apr 2025 20:40:58 +0200 Subject: [PATCH] [Misc] Human-readable `max-model-len` cli arg (#16181) Signed-off-by: NickLucche Signed-off-by: DarkLight1337 Co-authored-by: Cyrus Leung --- tests/engine/test_arg_utils.py | 38 +++++++++++++++++++++++++- vllm/engine/arg_utils.py | 50 ++++++++++++++++++++++++++++++++-- 2 files changed, 85 insertions(+), 3 deletions(-) diff --git a/tests/engine/test_arg_utils.py b/tests/engine/test_arg_utils.py index 8698d124e73f..92387b46425e 100644 --- a/tests/engine/test_arg_utils.py +++ b/tests/engine/test_arg_utils.py @@ -1,6 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 -from argparse import ArgumentTypeError +from argparse import ArgumentError, ArgumentTypeError import pytest @@ -142,3 +142,39 @@ def test_composite_arg_parser(arg, expected, option): else: args = parser.parse_args([f"--{option}", arg]) assert getattr(args, option.replace("-", "_")) == expected + + +def test_human_readable_model_len(): + # `exit_on_error` disabled to test invalid values below + parser = EngineArgs.add_cli_args( + FlexibleArgumentParser(exit_on_error=False)) + + args = parser.parse_args([]) + assert args.max_model_len is None + + args = parser.parse_args(["--max-model-len", "1024"]) + assert args.max_model_len == 1024 + + # Lower + args = parser.parse_args(["--max-model-len", "1m"]) + assert args.max_model_len == 1_000_000 + args = parser.parse_args(["--max-model-len", "10k"]) + assert args.max_model_len == 10_000 + + # Capital + args = parser.parse_args(["--max-model-len", "3K"]) + assert args.max_model_len == 1024 * 3 + args = parser.parse_args(["--max-model-len", "10M"]) + assert args.max_model_len == 2**20 * 10 + + # Decimal values + args = parser.parse_args(["--max-model-len", "10.2k"]) + assert args.max_model_len == 10200 + # ..truncated to the nearest int + args = parser.parse_args(["--max-model-len", "10.212345k"]) + assert args.max_model_len == 10212 + + # Invalid (do not allow decimals with binary multipliers) + for invalid in ["1a", "pwd", "10.24", "1.23M"]: + with pytest.raises(ArgumentError): + args = parser.parse_args(["--max-model-len", invalid]) diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index 9ccfdf58cfd6..6d9f89faf71a 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -3,6 +3,7 @@ import argparse import dataclasses import json +import re import threading from dataclasses import dataclass from typing import (TYPE_CHECKING, Any, Dict, List, Literal, Mapping, Optional, @@ -368,10 +369,14 @@ class EngineArgs: 'data type. CUDA 11.8+ supports fp8 (=fp8_e4m3) and fp8_e5m2. ' 'ROCm (AMD GPU) supports fp8 (=fp8_e4m3)') parser.add_argument('--max-model-len', - type=int, + type=human_readable_int, default=EngineArgs.max_model_len, help='Model context length. If unspecified, will ' - 'be automatically derived from the model config.') + 'be automatically derived from the model config. ' + 'Supports k/m/g/K/M/G in human-readable format.\n' + 'Examples:\n' + '- 1k → 1000\n' + '- 1K → 1024\n') parser.add_argument( '--guided-decoding-backend', type=str, @@ -1740,6 +1745,47 @@ def _warn_or_fallback(feature_name: str) -> bool: return should_exit +def human_readable_int(value): + """Parse human-readable integers like '1k', '2M', etc. + Including decimal values with decimal multipliers. + + Examples: + - '1k' -> 1,000 + - '1K' -> 1,024 + - '25.6k' -> 25,600 + """ + value = value.strip() + match = re.fullmatch(r'(\d+(?:\.\d+)?)([kKmMgGtT])', value) + if match: + decimal_multiplier = { + 'k': 10**3, + 'm': 10**6, + 'g': 10**9, + } + binary_multiplier = { + 'K': 2**10, + 'M': 2**20, + 'G': 2**30, + } + + number, suffix = match.groups() + if suffix in decimal_multiplier: + mult = decimal_multiplier[suffix] + return int(float(number) * mult) + elif suffix in binary_multiplier: + mult = binary_multiplier[suffix] + # Do not allow decimals with binary multipliers + try: + return int(number) * mult + except ValueError as e: + raise argparse.ArgumentTypeError("Decimals are not allowed " \ + f"with binary suffixes like {suffix}. Did you mean to use " \ + f"{number}{suffix.lower()} instead?") from e + + # Regular plain number. + return int(value) + + # These functions are used by sphinx to build the documentation def _engine_args_parser(): return EngineArgs.add_cli_args(FlexibleArgumentParser())