[Bugfix] Fix inappropriate content of model_name tag in Prometheus metrics (#3937)
commit 4302987069
parent 021b1a2ab7
@@ -1,3 +1,5 @@
+from typing import List
+
 import pytest
 
 from prometheus_client import REGISTRY
@@ -76,6 +78,34 @@ def test_metric_counter_generation_tokens(
         f"metric: {metric_count!r}")
 
 
+@pytest.mark.parametrize("model", MODELS)
+@pytest.mark.parametrize("dtype", ["float"])
+@pytest.mark.parametrize(
+    "served_model_name",
+    [None, [], ["ModelName0"], ["ModelName0", "ModelName1", "ModelName2"]])
+def test_metric_set_tag_model_name(vllm_runner, model: str, dtype: str,
+                                   served_model_name: List[str]) -> None:
+    vllm_model = vllm_runner(model,
+                             dtype=dtype,
+                             disable_log_stats=False,
+                             gpu_memory_utilization=0.3,
+                             served_model_name=served_model_name)
+    stat_logger = vllm_model.model.llm_engine.stat_logger
+    metrics_tag_content = stat_logger.labels["model_name"]
+
+    del vllm_model
+
+    if served_model_name is None or served_model_name == []:
+        assert metrics_tag_content == model, (
+            f"Metrics tag model_name is wrong! expect: {model!r}\n"
+            f"actual: {metrics_tag_content!r}")
+    else:
+        assert metrics_tag_content == served_model_name[0], (
+            f"Metrics tag model_name is wrong! expect: "
+            f"{served_model_name[0]!r}\n"
+            f"actual: {metrics_tag_content!r}")
+
+
 @pytest.mark.parametrize("model", MODELS)
 @pytest.mark.parametrize("dtype", ["half"])
 @pytest.mark.parametrize("max_tokens", [4])
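For reference, the model_name label could also be checked through prometheus_client's global REGISTRY (already imported at the top of this test file) instead of reading stat_logger.labels directly. A minimal sketch, not part of this diff; the metric name vllm:generation_tokens_total is an assumption modeled on vLLM's existing counters:

    from prometheus_client import REGISTRY

    def _model_name_labels(metric_name: str = "vllm:generation_tokens_total"):
        # Collect every model_name label value attached to the given metric.
        found = set()
        for metric in REGISTRY.collect():
            for sample in metric.samples:
                if sample.name.startswith(metric_name) and "model_name" in sample.labels:
                    found.add(sample.labels["model_name"])
        return found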
@@ -31,6 +31,8 @@ class ModelConfig:
 
     Args:
         model: Name or path of the huggingface model to use.
+            It is also used as the content for `model_name` tag in metrics
+            output when `served_model_name` is not specified.
         tokenizer: Name or path of the huggingface tokenizer to use.
         tokenizer_mode: Tokenizer mode. "auto" will use the fast tokenizer if
             available, and "slow" will always use the slow tokenizer.
@@ -69,6 +71,10 @@ class ModelConfig:
             to eager mode
         skip_tokenizer_init: If true, skip initialization of tokenizer and
             detokenizer.
+        served_model_name: The model name used in metrics tag `model_name`,
+            matching the model name exposed via the APIs. If multiple model
+            names are provided, the first name will be used. If not specified,
+            the model name will be the same as `model`.
     """
 
     def __init__(
@@ -90,6 +96,7 @@ class ModelConfig:
         max_seq_len_to_capture: Optional[int] = None,
         max_logprobs: int = 5,
         skip_tokenizer_init: bool = False,
+        served_model_name: Optional[Union[str, List[str]]] = None,
     ) -> None:
         self.model = model
         self.tokenizer = tokenizer
@@ -117,6 +124,8 @@ class ModelConfig:
         self.dtype = _get_and_verify_dtype(self.hf_text_config, dtype)
         self.max_model_len = _get_and_verify_max_len(self.hf_text_config,
                                                      max_model_len)
+        self.served_model_name = get_served_model_name(model,
+                                                       served_model_name)
         if not self.skip_tokenizer_init:
             self._verify_tokenizer_mode()
         self._verify_quantization()
@@ -1150,6 +1159,22 @@ def _get_and_verify_max_len(
     return int(max_model_len)
 
 
+def get_served_model_name(model: str,
+                          served_model_name: Optional[Union[str, List[str]]]):
+    """
+    If the input is a non-empty list, the first model_name in
+    `served_model_name` is taken.
+    If the input is a non-empty string, it is used directly.
+    For cases where the input is either an empty string or an
+    empty list, the fallback is to use `model`.
+    """
+    if not served_model_name:
+        return model
+    if isinstance(served_model_name, list):
+        return served_model_name[0]
+    return served_model_name
+
+
 @dataclass
 class DecodingConfig:
     """Dataclass which contains the decoding strategy of the engine"""
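A quick, illustrative sketch (not part of the diff) of the fallback rules documented above; the import path assumes the function is exposed from vllm.config as added here:

    from vllm.config import get_served_model_name

    # Empty or missing served_model_name falls back to the model identifier.
    assert get_served_model_name("facebook/opt-125m", None) == "facebook/opt-125m"
    assert get_served_model_name("facebook/opt-125m", []) == "facebook/opt-125m"
    # A list uses its first entry; a plain string is used as-is.
    assert get_served_model_name("facebook/opt-125m", ["ModelName0", "ModelName1"]) == "ModelName0"
    assert get_served_model_name("facebook/opt-125m", "ModelName0") == "ModelName0"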
@@ -1,7 +1,7 @@
 import argparse
 import dataclasses
 from dataclasses import dataclass
-from typing import Optional
+from typing import List, Optional, Union
 
 from vllm.config import (CacheConfig, DecodingConfig, DeviceConfig,
                          EngineConfig, LoadConfig, LoRAConfig, ModelConfig,
@@ -21,6 +21,7 @@ def nullable_str(val: str):
 class EngineArgs:
     """Arguments for vLLM engine."""
     model: str
+    served_model_name: Optional[Union[List[str]]] = None
     tokenizer: Optional[str] = None
     skip_tokenizer_init: bool = False
     tokenizer_mode: str = 'auto'
@@ -489,6 +490,21 @@ class EngineArgs:
             'This should be a JSON string that will be '
             'parsed into a dictionary.')
 
+        parser.add_argument(
+            "--served-model-name",
+            nargs="+",
+            type=str,
+            default=None,
+            help="The model name(s) used in the API. If multiple "
+            "names are provided, the server will respond to any "
+            "of the provided names. The model name in the model "
+            "field of a response will be the first name in this "
+            "list. If not specified, the model name will be the "
+            "same as the `--model` argument. Note that this "
+            "name(s) will also be used in the `model_name` tag "
+            "content of Prometheus metrics; if multiple names are "
+            "provided, the metrics tag will take the first one.")
+
         return parser
 
     @classmethod
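A standalone sketch (illustrative only, using a bare argparse parser rather than vLLM's) of how the nargs="+" flag above parses multiple names into a list whose first entry ends up in the metrics tag:

    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument("--served-model-name", nargs="+", type=str, default=None)
    args = parser.parse_args(["--served-model-name", "ModelName0", "ModelName1"])

    print(args.served_model_name)     # ['ModelName0', 'ModelName1']
    print(args.served_model_name[0])  # 'ModelName0' -> used for the model_name tag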
@@ -508,7 +524,7 @@ class EngineArgs:
             self.quantization, self.quantization_param_path,
             self.enforce_eager, self.max_context_len_to_capture,
             self.max_seq_len_to_capture, self.max_logprobs,
-            self.skip_tokenizer_init)
+            self.skip_tokenizer_init, self.served_model_name)
         cache_config = CacheConfig(self.block_size,
                                    self.gpu_memory_utilization,
                                    self.swap_space, self.kv_cache_dtype,
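Since EngineArgs is a dataclass, the same field can be set programmatically; a rough sketch (illustrative only) of how the value is forwarded into ModelConfig and, from there, through get_served_model_name:

    from vllm.engine.arg_utils import EngineArgs

    engine_args = EngineArgs(model="facebook/opt-125m",
                             served_model_name=["ModelName0", "ModelName1"])
    # The ModelConfig built from these args resolves the metrics tag to the
    # first name in the list, i.e. "ModelName0".
    print(engine_args.served_model_name[0])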
@@ -106,7 +106,7 @@ class LLMEngine:
             "tensor_parallel_size=%d, disable_custom_all_reduce=%s, "
             "quantization=%s, enforce_eager=%s, kv_cache_dtype=%s, "
             "quantization_param_path=%s, device_config=%s, "
-            "decoding_config=%r, seed=%d)",
+            "decoding_config=%r, seed=%d, served_model_name=%s)",
             vllm.__version__,
             model_config.model,
             speculative_config,
@@ -129,6 +129,7 @@ class LLMEngine:
             device_config.device,
             decoding_config,
             model_config.seed,
+            model_config.served_model_name,
         )
         # TODO(woosuk): Print more configs in debug mode.
 
@@ -219,7 +220,7 @@ class LLMEngine:
         if self.log_stats:
             self.stat_logger = StatLogger(
                 local_interval=_LOCAL_LOGGING_INTERVAL_SEC,
-                labels=dict(model_name=model_config.model),
+                labels=dict(model_name=model_config.served_model_name),
                 max_model_len=self.model_config.max_model_len)
             self.stat_logger.info("cache_config", self.cache_config)
 
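For context, a minimal prometheus_client sketch (not vLLM's actual StatLogger implementation) of what a fixed labels=dict(model_name=...) mapping amounts to when samples are emitted; the metric name is an assumption modeled on vLLM's counters:

    from prometheus_client import Counter

    labels = dict(model_name="ModelName0")  # served_model_name[0] after this fix
    generation_tokens = Counter("vllm:generation_tokens_total",
                                "Number of generation tokens processed.",
                                labelnames=list(labels.keys()))
    # Every sample carries the same model_name label value.
    generation_tokens.labels(**labels).inc(4)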
@@ -56,16 +56,6 @@ def make_arg_parser():
                         default=None,
                         help="If provided, the server will require this key "
                         "to be presented in the header.")
-    parser.add_argument("--served-model-name",
-                        nargs="+",
-                        type=nullable_str,
-                        default=None,
-                        help="The model name(s) used in the API. If multiple "
-                        "names are provided, the server will respond to any "
-                        "of the provided names. The model name in the model "
-                        "field of a response will be the first name in this "
-                        "list. If not specified, the model name will be the "
-                        "same as the `--model` argument.")
     parser.add_argument(
         "--lora-modules",
         type=nullable_str,