mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-10 04:15:01 +08:00
[Bugfix] Enable V1 usage stats (#16986)
Signed-off-by: mgoin <mgoin64@gmail.com> Signed-off-by: Nick Hill <nhill@redhat.com> Co-authored-by: Nick Hill <nhill@redhat.com>
This commit is contained in:
parent
46e678bcff
commit
ed50f46641
@ -19,6 +19,7 @@ import torch
|
||||
|
||||
import vllm.envs as envs
|
||||
from vllm.connections import global_http_connection
|
||||
from vllm.utils import cuda_device_count_stateless, cuda_get_device_properties
|
||||
from vllm.version import __version__ as VLLM_VERSION
|
||||
|
||||
_config_home = envs.VLLM_CONFIG_ROOT
|
||||
@ -168,10 +169,9 @@ class UsageMessage:
|
||||
# Platform information
|
||||
from vllm.platforms import current_platform
|
||||
if current_platform.is_cuda_alike():
|
||||
device_property = torch.cuda.get_device_properties(0)
|
||||
self.gpu_count = torch.cuda.device_count()
|
||||
self.gpu_type = device_property.name
|
||||
self.gpu_memory_per_device = device_property.total_memory
|
||||
self.gpu_count = cuda_device_count_stateless()
|
||||
self.gpu_type, self.gpu_memory_per_device = (
|
||||
cuda_get_device_properties(0, ("name", "total_memory")))
|
||||
if current_platform.is_cuda():
|
||||
self.cuda_runtime = torch.version.cuda
|
||||
self.provider = _detect_cloud_provider()
|
||||
|
||||
@ -38,11 +38,13 @@ from asyncio import FIRST_COMPLETED, AbstractEventLoop, Task
|
||||
from collections import UserDict, defaultdict
|
||||
from collections.abc import (AsyncGenerator, Awaitable, Generator, Hashable,
|
||||
Iterable, Iterator, KeysView, Mapping)
|
||||
from concurrent.futures.process import ProcessPoolExecutor
|
||||
from dataclasses import dataclass, field
|
||||
from functools import cache, lru_cache, partial, wraps
|
||||
from types import MappingProxyType
|
||||
from typing import (TYPE_CHECKING, Any, Callable, Generic, Literal, NamedTuple,
|
||||
Optional, Tuple, Type, TypeVar, Union, cast, overload)
|
||||
Optional, Sequence, Tuple, Type, TypeVar, Union, cast,
|
||||
overload)
|
||||
from uuid import uuid4
|
||||
|
||||
import cachetools
|
||||
@ -1235,6 +1237,22 @@ def cuda_is_initialized() -> bool:
|
||||
return torch.cuda.is_initialized()
|
||||
|
||||
|
||||
def cuda_get_device_properties(device,
|
||||
names: Sequence[str],
|
||||
init_cuda=False) -> tuple[Any, ...]:
|
||||
"""Get specified CUDA device property values without initializing CUDA in
|
||||
the current process."""
|
||||
if init_cuda or cuda_is_initialized():
|
||||
props = torch.cuda.get_device_properties(device)
|
||||
return tuple(getattr(props, name) for name in names)
|
||||
|
||||
# Run in subprocess to avoid initializing CUDA as a side effect.
|
||||
mp_ctx = multiprocessing.get_context("fork")
|
||||
with ProcessPoolExecutor(max_workers=1, mp_context=mp_ctx) as executor:
|
||||
return executor.submit(cuda_get_device_properties, device, names,
|
||||
True).result()
|
||||
|
||||
|
||||
def weak_bind(bound_method: Callable[..., Any], ) -> Callable[..., None]:
|
||||
"""Make an instance method that weakly references
|
||||
its associated instance and no-ops once that
|
||||
|
||||
@ -36,6 +36,7 @@ from vllm.v1.executor.abstract import Executor
|
||||
from vllm.v1.metrics.loggers import (LoggingStatLogger, PrometheusStatLogger,
|
||||
StatLoggerBase)
|
||||
from vllm.v1.metrics.stats import IterationStats, SchedulerStats
|
||||
from vllm.v1.utils import report_usage_stats
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
@ -114,6 +115,9 @@ class AsyncLLM(EngineClient):
|
||||
except RuntimeError:
|
||||
pass
|
||||
|
||||
# If usage stat is enabled, collect relevant info.
|
||||
report_usage_stats(vllm_config, usage_context)
|
||||
|
||||
@classmethod
|
||||
def from_vllm_config(
|
||||
cls,
|
||||
|
||||
@ -28,6 +28,7 @@ from vllm.v1.engine.output_processor import OutputProcessor
|
||||
from vllm.v1.engine.parallel_sampling import ParentRequest
|
||||
from vllm.v1.engine.processor import Processor
|
||||
from vllm.v1.executor.abstract import Executor
|
||||
from vllm.v1.utils import report_usage_stats
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
@ -99,6 +100,9 @@ class LLMEngine:
|
||||
# for v0 compatibility
|
||||
self.model_executor = self.engine_core.engine_core.model_executor # type: ignore
|
||||
|
||||
# If usage stat is enabled, collect relevant info.
|
||||
report_usage_stats(vllm_config, usage_context)
|
||||
|
||||
@classmethod
|
||||
def from_vllm_config(
|
||||
cls,
|
||||
|
||||
@ -12,6 +12,8 @@ import torch
|
||||
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.models.utils import extract_layer_index
|
||||
from vllm.usage.usage_lib import (UsageContext, is_usage_stats_enabled,
|
||||
usage_message)
|
||||
from vllm.utils import get_mp_context, kill_process_tree
|
||||
|
||||
if TYPE_CHECKING:
|
||||
@ -201,3 +203,45 @@ def copy_slice(from_tensor: torch.Tensor, to_tensor: torch.Tensor,
|
||||
Returns the sliced target tensor.
|
||||
"""
|
||||
return to_tensor[:length].copy_(from_tensor[:length], non_blocking=True)
|
||||
|
||||
|
||||
def report_usage_stats(vllm_config, usage_context: UsageContext) -> None:
|
||||
"""Report usage statistics if enabled."""
|
||||
|
||||
if not is_usage_stats_enabled():
|
||||
return
|
||||
|
||||
from vllm.model_executor.model_loader import get_architecture_class_name
|
||||
|
||||
usage_message.report_usage(
|
||||
get_architecture_class_name(vllm_config.model_config),
|
||||
usage_context,
|
||||
extra_kvs={
|
||||
# Common configuration
|
||||
"dtype":
|
||||
str(vllm_config.model_config.dtype),
|
||||
"tensor_parallel_size":
|
||||
vllm_config.parallel_config.tensor_parallel_size,
|
||||
"block_size":
|
||||
vllm_config.cache_config.block_size,
|
||||
"gpu_memory_utilization":
|
||||
vllm_config.cache_config.gpu_memory_utilization,
|
||||
|
||||
# Quantization
|
||||
"quantization":
|
||||
vllm_config.model_config.quantization,
|
||||
"kv_cache_dtype":
|
||||
str(vllm_config.cache_config.cache_dtype),
|
||||
|
||||
# Feature flags
|
||||
"enable_lora":
|
||||
bool(vllm_config.lora_config),
|
||||
"enable_prompt_adapter":
|
||||
bool(vllm_config.prompt_adapter_config),
|
||||
"enable_prefix_caching":
|
||||
vllm_config.cache_config.enable_prefix_caching,
|
||||
"enforce_eager":
|
||||
vllm_config.model_config.enforce_eager,
|
||||
"disable_custom_all_reduce":
|
||||
vllm_config.parallel_config.disable_custom_all_reduce,
|
||||
})
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user