[Bugfix] Enable V1 usage stats (#16986)

Signed-off-by: mgoin <mgoin64@gmail.com>
Signed-off-by: Nick Hill <nhill@redhat.com>
Co-authored-by: Nick Hill <nhill@redhat.com>
Author: Michael Goin
Date:   2025-04-23 20:54:00 -06:00
Commit: ed50f46641 (parent: 46e678bcff)
5 changed files with 75 additions and 5 deletions

vllm/usage/usage_lib.py

@@ -19,6 +19,7 @@ import torch
 
 import vllm.envs as envs
 from vllm.connections import global_http_connection
+from vllm.utils import cuda_device_count_stateless, cuda_get_device_properties
 from vllm.version import __version__ as VLLM_VERSION
 
 _config_home = envs.VLLM_CONFIG_ROOT
@@ -168,10 +169,9 @@ class UsageMessage:
         # Platform information
         from vllm.platforms import current_platform
         if current_platform.is_cuda_alike():
-            device_property = torch.cuda.get_device_properties(0)
-            self.gpu_count = torch.cuda.device_count()
-            self.gpu_type = device_property.name
-            self.gpu_memory_per_device = device_property.total_memory
+            self.gpu_count = cuda_device_count_stateless()
+            self.gpu_type, self.gpu_memory_per_device = (
+                cuda_get_device_properties(0, ("name", "total_memory")))
         if current_platform.is_cuda():
             self.cuda_runtime = torch.version.cuda
         self.provider = _detect_cloud_provider()
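
Note: the motivation here appears to be that torch.cuda.get_device_properties()
initializes CUDA in the calling process as a side effect, which can break
fork-based worker startup in the V1 engine. A minimal sketch of that side
effect (assumes a CUDA-capable machine; not part of the diff):

    import torch

    assert not torch.cuda.is_initialized()
    torch.cuda.get_device_properties(0)  # lazily creates a CUDA context
    assert torch.cuda.is_initialized()   # fork()ed children may now misbehave

cuda_device_count_stateless() and the new cuda_get_device_properties() helper
(added in vllm/utils.py below) query the device without leaving CUDA
initialized in the current process.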

vllm/utils.py

@@ -38,11 +38,13 @@ from asyncio import FIRST_COMPLETED, AbstractEventLoop, Task
 from collections import UserDict, defaultdict
 from collections.abc import (AsyncGenerator, Awaitable, Generator, Hashable,
                              Iterable, Iterator, KeysView, Mapping)
+from concurrent.futures.process import ProcessPoolExecutor
 from dataclasses import dataclass, field
 from functools import cache, lru_cache, partial, wraps
 from types import MappingProxyType
 from typing import (TYPE_CHECKING, Any, Callable, Generic, Literal, NamedTuple,
-                    Optional, Tuple, Type, TypeVar, Union, cast, overload)
+                    Optional, Sequence, Tuple, Type, TypeVar, Union, cast,
+                    overload)
 from uuid import uuid4
 
 import cachetools
@@ -1235,6 +1237,22 @@ def cuda_is_initialized() -> bool:
     return torch.cuda.is_initialized()
 
 
+def cuda_get_device_properties(device,
+                               names: Sequence[str],
+                               init_cuda=False) -> tuple[Any, ...]:
+    """Get specified CUDA device property values without initializing CUDA in
+    the current process."""
+    if init_cuda or cuda_is_initialized():
+        props = torch.cuda.get_device_properties(device)
+        return tuple(getattr(props, name) for name in names)
+
+    # Run in subprocess to avoid initializing CUDA as a side effect.
+    mp_ctx = multiprocessing.get_context("fork")
+    with ProcessPoolExecutor(max_workers=1, mp_context=mp_ctx) as executor:
+        return executor.submit(cuda_get_device_properties, device, names,
+                               True).result()
+
+
 def weak_bind(bound_method: Callable[..., Any], ) -> Callable[..., None]:
     """Make an instance method that weakly references
     its associated instance and no-ops once that
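
A hedged usage sketch of the new helper (assumes a CUDA device at index 0 and
that CUDA is not yet initialized; property names follow
torch.cuda.get_device_properties()):

    import torch

    from vllm.utils import cuda_get_device_properties

    name, total_memory = cuda_get_device_properties(0, ("name", "total_memory"))
    print(name, total_memory)
    # The calling process still has no CUDA context; the properties were
    # read in a short-lived fork()ed child instead.
    assert not torch.cuda.is_initialized()

Design note: when CUDA is uninitialized, the function submits itself to a
single-worker ProcessPoolExecutor with init_cuda=True, so the CUDA context is
created (and discarded) entirely inside the child process.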

vllm/v1/engine/async_llm.py

@@ -36,6 +36,7 @@ from vllm.v1.executor.abstract import Executor
 from vllm.v1.metrics.loggers import (LoggingStatLogger, PrometheusStatLogger,
                                      StatLoggerBase)
 from vllm.v1.metrics.stats import IterationStats, SchedulerStats
+from vllm.v1.utils import report_usage_stats
 
 logger = init_logger(__name__)
@@ -114,6 +115,9 @@ class AsyncLLM(EngineClient):
         except RuntimeError:
             pass
 
+        # If usage stat is enabled, collect relevant info.
+        report_usage_stats(vllm_config, usage_context)
+
     @classmethod
     def from_vllm_config(
         cls,

vllm/v1/engine/llm_engine.py

@@ -28,6 +28,7 @@ from vllm.v1.engine.output_processor import OutputProcessor
 from vllm.v1.engine.parallel_sampling import ParentRequest
 from vllm.v1.engine.processor import Processor
 from vllm.v1.executor.abstract import Executor
+from vllm.v1.utils import report_usage_stats
 
 logger = init_logger(__name__)
@@ -99,6 +100,9 @@ class LLMEngine:
         # for v0 compatibility
         self.model_executor = self.engine_core.engine_core.model_executor  # type: ignore
 
+        # If usage stat is enabled, collect relevant info.
+        report_usage_stats(vllm_config, usage_context)
+
     @classmethod
     def from_vllm_config(
         cls,

vllm/v1/utils.py

@@ -12,6 +12,8 @@ import torch
 
 from vllm.logger import init_logger
 from vllm.model_executor.models.utils import extract_layer_index
+from vllm.usage.usage_lib import (UsageContext, is_usage_stats_enabled,
+                                  usage_message)
 from vllm.utils import get_mp_context, kill_process_tree
 
 if TYPE_CHECKING:
@@ -201,3 +203,45 @@ def copy_slice(from_tensor: torch.Tensor, to_tensor: torch.Tensor,
     Returns the sliced target tensor.
     """
     return to_tensor[:length].copy_(from_tensor[:length], non_blocking=True)
+
+
+def report_usage_stats(vllm_config, usage_context: UsageContext) -> None:
+    """Report usage statistics if enabled."""
+
+    if not is_usage_stats_enabled():
+        return
+
+    from vllm.model_executor.model_loader import get_architecture_class_name
+
+    usage_message.report_usage(
+        get_architecture_class_name(vllm_config.model_config),
+        usage_context,
+        extra_kvs={
+            # Common configuration
+            "dtype":
+            str(vllm_config.model_config.dtype),
+            "tensor_parallel_size":
+            vllm_config.parallel_config.tensor_parallel_size,
+            "block_size":
+            vllm_config.cache_config.block_size,
+            "gpu_memory_utilization":
+            vllm_config.cache_config.gpu_memory_utilization,
+
+            # Quantization
+            "quantization":
+            vllm_config.model_config.quantization,
+            "kv_cache_dtype":
+            str(vllm_config.cache_config.cache_dtype),
+
+            # Feature flags
+            "enable_lora":
+            bool(vllm_config.lora_config),
+            "enable_prompt_adapter":
+            bool(vllm_config.prompt_adapter_config),
+            "enable_prefix_caching":
+            vllm_config.cache_config.enable_prefix_caching,
+            "enforce_eager":
+            vllm_config.model_config.enforce_eager,
+            "disable_custom_all_reduce":
+            vllm_config.parallel_config.disable_custom_all_reduce,
+        })