[Bugfix] Fix inappropriate content of model_name tag in Prometheus metrics (#3937)

This commit is contained in:
DearPlanet 2024-05-05 06:39:34 +08:00 committed by GitHub
parent 021b1a2ab7
commit 4302987069
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 76 additions and 14 deletions

View File

@ -1,3 +1,5 @@
from typing import List
import pytest
from prometheus_client import REGISTRY
@ -76,6 +78,34 @@ def test_metric_counter_generation_tokens(
f"metric: {metric_count!r}")
@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("dtype", ["float"])
@pytest.mark.parametrize(
"served_model_name",
[None, [], ["ModelName0"], ["ModelName0", "ModelName1", "ModelName2"]])
def test_metric_set_tag_model_name(vllm_runner, model: str, dtype: str,
served_model_name: List[str]) -> None:
vllm_model = vllm_runner(model,
dtype=dtype,
disable_log_stats=False,
gpu_memory_utilization=0.3,
served_model_name=served_model_name)
stat_logger = vllm_model.model.llm_engine.stat_logger
metrics_tag_content = stat_logger.labels["model_name"]
del vllm_model
if served_model_name is None or served_model_name == []:
assert metrics_tag_content == model, (
f"Metrics tag model_name is wrong! expect: {model!r}\n"
f"actual: {metrics_tag_content!r}")
else:
assert metrics_tag_content == served_model_name[0], (
f"Metrics tag model_name is wrong! expect: "
f"{served_model_name[0]!r}\n"
f"actual: {metrics_tag_content!r}")
@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("dtype", ["half"])
@pytest.mark.parametrize("max_tokens", [4])

View File

@ -31,6 +31,8 @@ class ModelConfig:
Args:
model: Name or path of the huggingface model to use.
It is also used as the content for `model_name` tag in metrics
output when `served_model_name` is not specified.
tokenizer: Name or path of the huggingface tokenizer to use.
tokenizer_mode: Tokenizer mode. "auto" will use the fast tokenizer if
available, and "slow" will always use the slow tokenizer.
@ -69,6 +71,10 @@ class ModelConfig:
to eager mode
skip_tokenizer_init: If true, skip initialization of tokenizer and
detokenizer.
served_model_name: The model name used in metrics tag `model_name`,
matches the model name exposed via the APIs. If multiple model
names provided, the first name will be used. If not specified,
the model name will be the same as `model`.
"""
def __init__(
@ -90,6 +96,7 @@ class ModelConfig:
max_seq_len_to_capture: Optional[int] = None,
max_logprobs: int = 5,
skip_tokenizer_init: bool = False,
served_model_name: Optional[Union[str, List[str]]] = None,
) -> None:
self.model = model
self.tokenizer = tokenizer
@ -117,6 +124,8 @@ class ModelConfig:
self.dtype = _get_and_verify_dtype(self.hf_text_config, dtype)
self.max_model_len = _get_and_verify_max_len(self.hf_text_config,
max_model_len)
self.served_model_name = get_served_model_name(model,
served_model_name)
if not self.skip_tokenizer_init:
self._verify_tokenizer_mode()
self._verify_quantization()
@ -1150,6 +1159,22 @@ def _get_and_verify_max_len(
return int(max_model_len)
def get_served_model_name(model: str,
served_model_name: Optional[Union[str, List[str]]]):
"""
If the input is a non-empty list, the first model_name in
`served_model_name` is taken.
If the input is a non-empty string, it is used directly.
For cases where the input is either an empty string or an
empty list, the fallback is to use `self.model`.
"""
if not served_model_name:
return model
if isinstance(served_model_name, list):
return served_model_name[0]
return served_model_name
@dataclass
class DecodingConfig:
"""Dataclass which contains the decoding strategy of the engine"""

View File

@ -1,7 +1,7 @@
import argparse
import dataclasses
from dataclasses import dataclass
from typing import Optional
from typing import List, Optional, Union
from vllm.config import (CacheConfig, DecodingConfig, DeviceConfig,
EngineConfig, LoadConfig, LoRAConfig, ModelConfig,
@ -21,6 +21,7 @@ def nullable_str(val: str):
class EngineArgs:
"""Arguments for vLLM engine."""
model: str
served_model_name: Optional[Union[List[str]]] = None
tokenizer: Optional[str] = None
skip_tokenizer_init: bool = False
tokenizer_mode: str = 'auto'
@ -489,6 +490,21 @@ class EngineArgs:
'This should be a JSON string that will be '
'parsed into a dictionary.')
parser.add_argument(
"--served-model-name",
nargs="+",
type=str,
default=None,
help="The model name(s) used in the API. If multiple "
"names are provided, the server will respond to any "
"of the provided names. The model name in the model "
"field of a response will be the first name in this "
"list. If not specified, the model name will be the "
"same as the `--model` argument. Noted that this name(s)"
"will also be used in `model_name` tag content of "
"prometheus metrics, if multiple names provided, metrics"
"tag will take the first one.")
return parser
@classmethod
@ -508,7 +524,7 @@ class EngineArgs:
self.quantization, self.quantization_param_path,
self.enforce_eager, self.max_context_len_to_capture,
self.max_seq_len_to_capture, self.max_logprobs,
self.skip_tokenizer_init)
self.skip_tokenizer_init, self.served_model_name)
cache_config = CacheConfig(self.block_size,
self.gpu_memory_utilization,
self.swap_space, self.kv_cache_dtype,

View File

@ -106,7 +106,7 @@ class LLMEngine:
"tensor_parallel_size=%d, disable_custom_all_reduce=%s, "
"quantization=%s, enforce_eager=%s, kv_cache_dtype=%s, "
"quantization_param_path=%s, device_config=%s, "
"decoding_config=%r, seed=%d)",
"decoding_config=%r, seed=%d, served_model_name=%s)",
vllm.__version__,
model_config.model,
speculative_config,
@ -129,6 +129,7 @@ class LLMEngine:
device_config.device,
decoding_config,
model_config.seed,
model_config.served_model_name,
)
# TODO(woosuk): Print more configs in debug mode.
@ -219,7 +220,7 @@ class LLMEngine:
if self.log_stats:
self.stat_logger = StatLogger(
local_interval=_LOCAL_LOGGING_INTERVAL_SEC,
labels=dict(model_name=model_config.model),
labels=dict(model_name=model_config.served_model_name),
max_model_len=self.model_config.max_model_len)
self.stat_logger.info("cache_config", self.cache_config)

View File

@ -56,16 +56,6 @@ def make_arg_parser():
default=None,
help="If provided, the server will require this key "
"to be presented in the header.")
parser.add_argument("--served-model-name",
nargs="+",
type=nullable_str,
default=None,
help="The model name(s) used in the API. If multiple "
"names are provided, the server will respond to any "
"of the provided names. The model name in the model "
"field of a response will be the first name in this "
"list. If not specified, the model name will be the "
"same as the `--model` argument.")
parser.add_argument(
"--lora-modules",
type=nullable_str,