Usage Stats Collection (#2852)

This commit is contained in:
yhu422 2024-03-28 22:16:12 -07:00 committed by GitHub
parent 7bc94a0fdd
commit d8658c8cc1
15 changed files with 362 additions and 24 deletions

View File

@ -53,6 +53,8 @@ steps:
nvidia.com/gpu: "{{ step.num_gpus or default_num_gpu }}"
{% endif %}
env:
  - name: VLLM_USAGE_SOURCE
    value: ci-test
  - name: HF_TOKEN
    valueFrom:
      secretKeyRef:

View File

@ -132,5 +132,7 @@ RUN --mount=type=cache,target=/root/.cache/pip \
COPY --from=build /workspace/vllm/*.so /workspace/vllm/
COPY vllm vllm
ENV VLLM_USAGE_SOURCE production-docker-image
ENTRYPOINT ["python3", "-m", "vllm.entrypoints.openai.api_server"]
#################### OPENAI API SERVER ####################

View File

@ -73,6 +73,7 @@ Documentation
serving/deploying_with_docker
serving/distributed_serving
serving/metrics
serving/usage_stats
serving/integrations
.. toctree::

View File

@ -0,0 +1,57 @@
# Usage Stats Collection
vLLM collects anonymous usage data by default to help the engineering team better understand which hardware and model configurations are widely used. This data allows them to prioritize their efforts on the most common workloads. The collected data is transparent, does not contain any sensitive information, and will be publicly released for the community's benefit.
## What data is collected?
You can see the up-to-date list of data collected by vLLM in [usage_lib.py](https://github.com/vllm-project/vllm/blob/main/vllm/usage/usage_lib.py).
Here is an example as of v0.4.0:
```json
{
"uuid": "fbe880e9-084d-4cab-a395-8984c50f1109",
"provider": "GCP",
"num_cpu": 24,
"cpu_type": "Intel(R) Xeon(R) CPU @ 2.20GHz",
"cpu_family_model_stepping": "6,85,7",
"total_memory": 101261135872,
"architecture": "x86_64",
"platform": "Linux-5.10.0-28-cloud-amd64-x86_64-with-glibc2.31",
"gpu_count": 2,
"gpu_type": "NVIDIA L4",
"gpu_memory_per_device": 23580639232,
"model_architecture": "OPTForCausalLM",
"vllm_version": "0.3.2+cu123",
"context": "LLM_CLASS",
"log_time": 1711663373492490000,
"source": "production",
"dtype": "torch.float16",
"tensor_parallel_size": 1,
"block_size": 16,
"gpu_memory_utilization": 0.9,
"quantization": null,
"kv_cache_dtype": "auto",
"enable_lora": false,
"enable_prefix_caching": false,
"enforce_eager": false,
"disable_custom_all_reduce": true
}
```
You can preview the collected data by running the following command:
```bash
tail ~/.config/vllm/usage_stats.json
```
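Each report is appended to this file as a single JSON object per line (JSON Lines), so it is also easy to inspect programmatically. Below is a minimal sketch for reading the most recent record; it assumes the default config location (vLLM honors `XDG_CONFIG_HOME` if it is set):
```python
import json
from pathlib import Path

# Default location; adjust if XDG_CONFIG_HOME points elsewhere.
stats_path = Path.home() / ".config" / "vllm" / "usage_stats.json"

# Each non-empty line is one self-contained JSON report.
records = [
    json.loads(line)
    for line in stats_path.read_text().splitlines()
    if line.strip()
]

latest = records[-1]
print(latest.get("model_architecture"), latest.get("gpu_type"))
```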
## Opt-out of Usage Stats Collection
You can opt out of usage stats collection by setting the `VLLM_NO_USAGE_STATS` or `DO_NOT_TRACK` environment variable, or by creating a `~/.config/vllm/do_not_track` file:
```bash
# Any of the following methods can disable usage stats collection
export VLLM_NO_USAGE_STATS=1
export DO_NOT_TRACK=1
mkdir -p ~/.config/vllm && touch ~/.config/vllm/do_not_track
```
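Because the opt-out check is evaluated (and cached) the first time the engine starts, collection can also be disabled from Python by setting the environment variable before the engine is constructed. A minimal sketch, using an arbitrary example model:
```python
import os

# Must be set before the engine is created; the opt-out decision is cached
# on first use, so later changes have no effect for the current process.
os.environ["VLLM_NO_USAGE_STATS"] = "1"

from vllm import LLM

llm = LLM(model="facebook/opt-125m")  # example model; no usage stats are sent
```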

View File

@ -7,3 +7,6 @@ fastapi
uvicorn[standard]
pydantic >= 2.0 # Required for OpenAI server.
prometheus_client >= 0.18.0
requests
psutil
py-cpuinfo

View File

@ -2,6 +2,8 @@ cmake>=3.21
ninja # For faster builds.
typing-extensions>=4.8.0
starlette
requests
py-cpuinfo
psutil
ray >= 2.9
sentencepiece # Required for LLaMA tokenizer.

View File

@ -5,6 +5,9 @@ ray >= 2.9
sentencepiece # Required for LLaMA tokenizer.
numpy
torch == 2.1.2
requests
psutil
py-cpuinfo
transformers >= 4.39.1 # Required for StarCoder2 & Llava.
xformers == 0.0.23.post1 # Required for CUDA 12.1.
fastapi

View File

@ -16,6 +16,7 @@ from vllm.lora.request import LoRARequest
from vllm.outputs import RequestOutput
from vllm.sampling_params import SamplingParams
from vllm.sequence import MultiModalData
from vllm.usage.usage_lib import UsageContext
logger = init_logger(__name__)
ENGINE_ITERATION_TIMEOUT_S = int(
@ -319,9 +320,12 @@ class AsyncLLMEngine:
self._errored_with: Optional[BaseException] = None
@classmethod
def from_engine_args(cls,
engine_args: AsyncEngineArgs,
start_engine_loop: bool = True) -> "AsyncLLMEngine":
def from_engine_args(
cls,
engine_args: AsyncEngineArgs,
start_engine_loop: bool = True,
usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
) -> "AsyncLLMEngine":
"""Creates an async LLM engine from the engine arguments."""
# Create the engine configs.
engine_configs = engine_args.create_engine_configs()
@ -341,14 +345,17 @@ class AsyncLLMEngine:
from vllm.executor.gpu_executor import GPUExecutorAsync
executor_class = GPUExecutorAsync
# Create the async LLM engine.
engine = cls(parallel_config.worker_use_ray,
engine_args.engine_use_ray,
*engine_configs,
executor_class,
log_requests=not engine_args.disable_log_requests,
log_stats=not engine_args.disable_log_stats,
max_log_len=engine_args.max_log_len,
start_engine_loop=start_engine_loop)
engine = cls(
parallel_config.worker_use_ray,
engine_args.engine_use_ray,
*engine_configs,
executor_class,
log_requests=not engine_args.disable_log_requests,
log_stats=not engine_args.disable_log_stats,
max_log_len=engine_args.max_log_len,
start_engine_loop=start_engine_loop,
usage_context=usage_context,
)
return engine
@property

View File

@ -13,6 +13,7 @@ from vllm.engine.ray_utils import initialize_ray_cluster
from vllm.executor.executor_base import ExecutorBase
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.model_executor.model_loader import get_architecture_class_name
from vllm.outputs import RequestOutput
from vllm.sampling_params import SamplingParams
from vllm.sequence import (MultiModalData, SamplerOutput, Sequence,
@ -21,6 +22,8 @@ from vllm.sequence import (MultiModalData, SamplerOutput, Sequence,
from vllm.transformers_utils.detokenizer import Detokenizer
from vllm.transformers_utils.tokenizer_group import (BaseTokenizerGroup,
get_tokenizer_group)
from vllm.usage.usage_lib import (UsageContext, is_usage_stats_enabled,
usage_message)
from vllm.utils import Counter
logger = init_logger(__name__)
@ -53,6 +56,7 @@ class LLMEngine:
executor_class: The model executor class for managing distributed
execution.
log_stats: Whether to log statistics.
usage_context: The entry point from which the engine was created; used for usage stats collection.
"""
def __init__(
@ -66,6 +70,7 @@ class LLMEngine:
vision_language_config: Optional["VisionLanguageConfig"],
executor_class: Type[ExecutorBase],
log_stats: bool,
usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
) -> None:
logger.info(
f"Initializing an LLM engine (v{vllm.__version__}) with config: "
@ -108,6 +113,39 @@ class LLMEngine:
device_config, lora_config,
vision_language_config)
# If usage stats collection is enabled, collect relevant info.
if is_usage_stats_enabled():
usage_message.report_usage(
get_architecture_class_name(model_config),
usage_context,
extra_kvs={
# Common configuration
"dtype":
str(model_config.dtype),
"tensor_parallel_size":
parallel_config.tensor_parallel_size,
"block_size":
cache_config.block_size,
"gpu_memory_utilization":
cache_config.gpu_memory_utilization,
# Quantization
"quantization":
model_config.quantization,
"kv_cache_dtype":
cache_config.cache_dtype,
# Feature flags
"enable_lora":
bool(lora_config),
"enable_prefix_caching":
cache_config.enable_prefix_caching,
"enforce_eager":
model_config.enforce_eager,
"disable_custom_all_reduce":
parallel_config.disable_custom_all_reduce,
})
# Ping the tokenizer to ensure liveness if it runs in a
# different process.
self.tokenizer.ping()
@ -125,7 +163,11 @@ class LLMEngine:
self.stat_logger.info("cache_config", self.cache_config)
@classmethod
def from_engine_args(cls, engine_args: EngineArgs) -> "LLMEngine":
def from_engine_args(
cls,
engine_args: EngineArgs,
usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
) -> "LLMEngine":
"""Creates an LLM engine from the engine arguments."""
# Create the engine configs.
engine_configs = engine_args.create_engine_configs()
@ -147,9 +189,12 @@ class LLMEngine:
executor_class = GPUExecutor
# Create the LLM engine.
engine = cls(*engine_configs,
executor_class=executor_class,
log_stats=not engine_args.disable_log_stats)
engine = cls(
*engine_configs,
executor_class=executor_class,
log_stats=not engine_args.disable_log_stats,
usage_context=usage_context,
)
return engine
def __reduce__(self):

View File

@ -18,6 +18,7 @@ from fastapi.responses import JSONResponse, Response, StreamingResponse
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.engine.async_llm_engine import AsyncLLMEngine
from vllm.sampling_params import SamplingParams
from vllm.usage.usage_lib import UsageContext
from vllm.utils import random_uuid
TIMEOUT_KEEP_ALIVE = 5 # seconds.
@ -100,9 +101,9 @@ if __name__ == "__main__":
help="FastAPI root_path when app is behind a path based routing proxy")
parser = AsyncEngineArgs.add_cli_args(parser)
args = parser.parse_args()
engine_args = AsyncEngineArgs.from_cli_args(args)
engine = AsyncLLMEngine.from_engine_args(engine_args)
engine = AsyncLLMEngine.from_engine_args(
engine_args, usage_context=UsageContext.API_SERVER)
app.root_path = args.root_path
uvicorn.run(app,

View File

@ -10,6 +10,7 @@ from vllm.lora.request import LoRARequest
from vllm.outputs import RequestOutput
from vllm.sampling_params import SamplingParams
from vllm.sequence import MultiModalData
from vllm.usage.usage_lib import UsageContext
from vllm.utils import Counter
@ -108,7 +109,8 @@ class LLM:
disable_custom_all_reduce=disable_custom_all_reduce,
**kwargs,
)
self.llm_engine = LLMEngine.from_engine_args(engine_args)
self.llm_engine = LLMEngine.from_engine_args(
engine_args, usage_context=UsageContext.LLM_CLASS)
self.request_counter = Counter()
def get_tokenizer(

View File

@ -22,6 +22,7 @@ from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
from vllm.entrypoints.openai.serving_chat import OpenAIServingChat
from vllm.entrypoints.openai.serving_completion import OpenAIServingCompletion
from vllm.logger import init_logger
from vllm.usage.usage_lib import UsageContext
TIMEOUT_KEEP_ALIVE = 5 # seconds
@ -151,9 +152,9 @@ if __name__ == "__main__":
served_model = args.served_model_name
else:
served_model = args.model
engine_args = AsyncEngineArgs.from_cli_args(args)
engine = AsyncLLMEngine.from_engine_args(engine_args)
engine = AsyncLLMEngine.from_engine_args(
engine_args, usage_context=UsageContext.OPENAI_API_SERVER)
openai_serving_chat = OpenAIServingChat(engine, served_model,
args.response_role,
args.lora_modules,

View File

@ -1,6 +1,6 @@
"""Utilities for selecting and loading models."""
import contextlib
from typing import Type
from typing import Tuple, Type
import torch
import torch.nn as nn
@ -25,7 +25,8 @@ def _set_default_torch_dtype(dtype: torch.dtype):
torch.set_default_dtype(old_dtype)
def _get_model_architecture(model_config: ModelConfig) -> Type[nn.Module]:
def _get_model_architecture(
model_config: ModelConfig) -> Tuple[Type[nn.Module], str]:
architectures = getattr(model_config.hf_config, "architectures", [])
# Special handling for quantized Mixtral.
# FIXME(woosuk): This is a temporary hack.
@ -36,17 +37,21 @@ def _get_model_architecture(model_config: ModelConfig) -> Type[nn.Module]:
for arch in architectures:
model_cls = ModelRegistry.load_model_cls(arch)
if model_cls is not None:
return model_cls
return (model_cls, arch)
raise ValueError(
f"Model architectures {architectures} are not supported for now. "
f"Supported architectures: {ModelRegistry.get_supported_archs()}")
def get_architecture_class_name(model_config: ModelConfig) -> str:
return _get_model_architecture(model_config)[1]
def get_model(model_config: ModelConfig, device_config: DeviceConfig,
**kwargs) -> nn.Module:
lora_config = kwargs.get("lora_config", None)
vision_language_config = kwargs.get("vision_language_config", None)
model_class = _get_model_architecture(model_config)
model_class = _get_model_architecture(model_config)[0]
# Get the (maybe quantized) linear method.
linear_method = None

vllm/usage/__init__.py Normal file (0 lines)
View File

vllm/usage/usage_lib.py Normal file (207 lines)
View File

@ -0,0 +1,207 @@
import datetime
import json
import logging
import os
import platform
import time
from enum import Enum
from pathlib import Path
from threading import Thread
from typing import Any, Dict, Optional
from uuid import uuid4

import cpuinfo
import pkg_resources
import psutil
import requests
import torch

_config_home = os.getenv("XDG_CONFIG_HOME", os.path.expanduser("~/.config"))
_USAGE_STATS_JSON_PATH = os.path.join(_config_home, "vllm/usage_stats.json")
_USAGE_STATS_DO_NOT_TRACK_PATH = os.path.join(_config_home,
                                              "vllm/do_not_track")
_USAGE_STATS_ENABLED = None
_USAGE_STATS_SERVER = os.environ.get("VLLM_USAGE_STATS_SERVER",
                                     "https://stats.vllm.ai")


def is_usage_stats_enabled():
    """Determine whether or not we can send usage stats to the server.

    The logic is as follows:
    - By default, it should be enabled.
    - Two environment variables can disable it:
        - DO_NOT_TRACK=1
        - VLLM_NO_USAGE_STATS=1
    - A file in the home directory can disable it if it exists:
        - $HOME/.config/vllm/do_not_track
    """
    global _USAGE_STATS_ENABLED
    if _USAGE_STATS_ENABLED is None:
        do_not_track = os.environ.get("DO_NOT_TRACK", "0") == "1"
        no_usage_stats = os.environ.get("VLLM_NO_USAGE_STATS", "0") == "1"
        do_not_track_file = os.path.exists(_USAGE_STATS_DO_NOT_TRACK_PATH)
        _USAGE_STATS_ENABLED = not (do_not_track or no_usage_stats
                                    or do_not_track_file)
    return _USAGE_STATS_ENABLED


def _get_current_timestamp_ns() -> int:
    return int(datetime.datetime.now(datetime.timezone.utc).timestamp() * 1e9)


def _detect_cloud_provider() -> str:
    # Try detecting through vendor files.
    vendor_files = [
        "/sys/class/dmi/id/product_version", "/sys/class/dmi/id/bios_vendor",
        "/sys/class/dmi/id/product_name",
        "/sys/class/dmi/id/chassis_asset_tag", "/sys/class/dmi/id/sys_vendor"
    ]
    # Mapping of identifiable strings to cloud providers.
    cloud_identifiers = {
        "amazon": "AWS",
        "microsoft corporation": "AZURE",
        "google": "GCP",
        "oraclecloud": "OCI",
    }

    for vendor_file in vendor_files:
        path = Path(vendor_file)
        if path.is_file():
            file_content = path.read_text().lower()
            for identifier, provider in cloud_identifiers.items():
                if identifier in file_content:
                    return provider

    # Try detecting through environment variables.
    env_to_cloud_provider = {
        "RUNPOD_DC_ID": "RUNPOD",
    }
    for env_var, provider in env_to_cloud_provider.items():
        if os.environ.get(env_var):
            return provider

    return "UNKNOWN"


class UsageContext(str, Enum):
    UNKNOWN_CONTEXT = "UNKNOWN_CONTEXT"
    LLM_CLASS = "LLM_CLASS"
    API_SERVER = "API_SERVER"
    OPENAI_API_SERVER = "OPENAI_API_SERVER"
    ENGINE_CONTEXT = "ENGINE_CONTEXT"


class UsageMessage:
    """Collect platform information and send it to the usage stats server."""

    def __init__(self) -> None:
        # NOTE: vLLM's usage stats server only supports flat KV pairs.
        # Do not use nested fields.
        self.uuid = str(uuid4())

        # Environment Information
        self.provider: Optional[str] = None
        self.num_cpu: Optional[int] = None
        self.cpu_type: Optional[str] = None
        self.cpu_family_model_stepping: Optional[str] = None
        self.total_memory: Optional[int] = None
        self.architecture: Optional[str] = None
        self.platform: Optional[str] = None
        self.gpu_count: Optional[int] = None
        self.gpu_type: Optional[str] = None
        self.gpu_memory_per_device: Optional[int] = None

        # vLLM Information
        self.model_architecture: Optional[str] = None
        self.vllm_version: Optional[str] = None
        self.context: Optional[str] = None

        # Metadata
        self.log_time: Optional[int] = None
        self.source: Optional[str] = None

    def report_usage(self,
                     model_architecture: str,
                     usage_context: UsageContext,
                     extra_kvs: Optional[Dict[str, Any]] = None) -> None:
        # Collect and send the report on a daemon thread so that engine
        # startup is never blocked by stats collection or network I/O.
        t = Thread(target=self._report_usage_worker,
                   args=(model_architecture, usage_context, extra_kvs or {}),
                   daemon=True)
        t.start()

    def _report_usage_worker(self, model_architecture: str,
                             usage_context: UsageContext,
                             extra_kvs: Dict[str, Any]) -> None:
        self._report_usage_once(model_architecture, usage_context, extra_kvs)
        self._report_continuous_usage()

    def _report_usage_once(self, model_architecture: str,
                           usage_context: UsageContext,
                           extra_kvs: Dict[str, Any]) -> None:
        # Platform information
        if torch.cuda.is_available():
            device_property = torch.cuda.get_device_properties(0)
            self.gpu_count = torch.cuda.device_count()
            self.gpu_type = device_property.name
            self.gpu_memory_per_device = device_property.total_memory
        self.provider = _detect_cloud_provider()
        self.architecture = platform.machine()
        self.platform = platform.platform()
        self.total_memory = psutil.virtual_memory().total

        info = cpuinfo.get_cpu_info()
        self.num_cpu = info.get("count", None)
        self.cpu_type = info.get("brand_raw", "")
        self.cpu_family_model_stepping = ",".join([
            str(info.get("family", "")),
            str(info.get("model", "")),
            str(info.get("stepping", ""))
        ])

        # vLLM information
        self.context = usage_context.value
        self.vllm_version = pkg_resources.get_distribution("vllm").version
        self.model_architecture = model_architecture

        # Metadata
        self.log_time = _get_current_timestamp_ns()
        self.source = os.environ.get("VLLM_USAGE_SOURCE", "production")

        # vars(self) exposes all collected fields; extra_kvs adds the
        # engine-specific entries passed by the caller.
        data = vars(self)
        if extra_kvs:
            data.update(extra_kvs)

        self._write_to_file(data)
        self._send_to_server(data)

    def _report_continuous_usage(self):
        """Report usage every 10 minutes.

        This helps us collect more data points on the uptime of vLLM
        deployments and can also carry performance metrics over time.
        """
        while True:
            time.sleep(600)
            data = {"uuid": self.uuid, "log_time": _get_current_timestamp_ns()}

            self._write_to_file(data)
            self._send_to_server(data)

    def _send_to_server(self, data):
        try:
            requests.post(_USAGE_STATS_SERVER, json=data)
        except requests.exceptions.RequestException:
            # Silently ignore failures unless debug logging is enabled.
            logging.debug("Failed to send usage data to server")

    def _write_to_file(self, data):
        os.makedirs(os.path.dirname(_USAGE_STATS_JSON_PATH), exist_ok=True)
        Path(_USAGE_STATS_JSON_PATH).touch(exist_ok=True)
        with open(_USAGE_STATS_JSON_PATH, "a") as f:
            json.dump(data, f)
            f.write("\n")


usage_message = UsageMessage()
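
For reference, the new module exposes a small public surface: `UsageContext`, `is_usage_stats_enabled()`, and the `usage_message` singleton. A minimal sketch of how an entry point reports a one-off event, mirroring the `LLMEngine` integration above (the architecture name and extra fields are illustrative):
```python
from vllm.usage.usage_lib import (UsageContext, is_usage_stats_enabled,
                                  usage_message)

if is_usage_stats_enabled():
    usage_message.report_usage(
        "OPTForCausalLM",                        # illustrative architecture
        UsageContext.ENGINE_CONTEXT,
        extra_kvs={"tensor_parallel_size": 1},   # flat key/value pairs only
    )
```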