mirror of
https://git.datalinker.icu/vllm-project/vllm.git
[Bugfix] Enable V1 usage stats (#16986)
Signed-off-by: mgoin <mgoin64@gmail.com>
Signed-off-by: Nick Hill <nhill@redhat.com>
Co-authored-by: Nick Hill <nhill@redhat.com>

parent 46e678bcff
commit ed50f46641
vllm/usage/usage_lib.py
@@ -19,6 +19,7 @@ import torch
 
 import vllm.envs as envs
 from vllm.connections import global_http_connection
+from vllm.utils import cuda_device_count_stateless, cuda_get_device_properties
 from vllm.version import __version__ as VLLM_VERSION
 
 _config_home = envs.VLLM_CONFIG_ROOT
@@ -168,10 +169,9 @@ class UsageMessage:
         # Platform information
         from vllm.platforms import current_platform
         if current_platform.is_cuda_alike():
-            device_property = torch.cuda.get_device_properties(0)
-            self.gpu_count = torch.cuda.device_count()
-            self.gpu_type = device_property.name
-            self.gpu_memory_per_device = device_property.total_memory
+            self.gpu_count = cuda_device_count_stateless()
+            self.gpu_type, self.gpu_memory_per_device = (
+                cuda_get_device_properties(0, ("name", "total_memory")))
         if current_platform.is_cuda():
             self.cuda_runtime = torch.version.cuda
         self.provider = _detect_cloud_provider()
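For context, a minimal sketch of what the rewritten branch above now does; the device values in the comment are illustrative and not taken from this diff:

# Illustrative only: gather GPU info the way UsageMessage now does.
from vllm.utils import cuda_device_count_stateless, cuda_get_device_properties

gpu_count = cuda_device_count_stateless()
gpu_type, gpu_memory_per_device = cuda_get_device_properties(
    0, ("name", "total_memory"))  # e.g. ("NVIDIA A100-SXM4-80GB", 85899345920)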
vllm/utils.py
@@ -38,11 +38,13 @@ from asyncio import FIRST_COMPLETED, AbstractEventLoop, Task
 from collections import UserDict, defaultdict
 from collections.abc import (AsyncGenerator, Awaitable, Generator, Hashable,
                              Iterable, Iterator, KeysView, Mapping)
+from concurrent.futures.process import ProcessPoolExecutor
 from dataclasses import dataclass, field
 from functools import cache, lru_cache, partial, wraps
 from types import MappingProxyType
 from typing import (TYPE_CHECKING, Any, Callable, Generic, Literal, NamedTuple,
-                    Optional, Tuple, Type, TypeVar, Union, cast, overload)
+                    Optional, Sequence, Tuple, Type, TypeVar, Union, cast,
+                    overload)
 from uuid import uuid4
 
 import cachetools
@@ -1235,6 +1237,22 @@ def cuda_is_initialized() -> bool:
     return torch.cuda.is_initialized()
 
 
+def cuda_get_device_properties(device,
+                               names: Sequence[str],
+                               init_cuda=False) -> tuple[Any, ...]:
+    """Get specified CUDA device property values without initializing CUDA in
+    the current process."""
+    if init_cuda or cuda_is_initialized():
+        props = torch.cuda.get_device_properties(device)
+        return tuple(getattr(props, name) for name in names)
+
+    # Run in subprocess to avoid initializing CUDA as a side effect.
+    mp_ctx = multiprocessing.get_context("fork")
+    with ProcessPoolExecutor(max_workers=1, mp_context=mp_ctx) as executor:
+        return executor.submit(cuda_get_device_properties, device, names,
+                               True).result()
+
+
 def weak_bind(bound_method: Callable[..., Any], ) -> Callable[..., None]:
     """Make an instance method that weakly references
     its associated instance and no-ops once that
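A hedged usage sketch of the helper added above: when CUDA is not yet initialized in the calling process, the query runs in a forked single-worker ProcessPoolExecutor so the parent stays CUDA-free; the property names are attributes of the object returned by torch.cuda.get_device_properties (e.g. "name", "total_memory").

# Sketch: query device 0 without initializing CUDA in this process
# (assumes CUDA was not already initialized here).
from vllm.utils import cuda_get_device_properties, cuda_is_initialized

name, total_memory = cuda_get_device_properties(0, ("name", "total_memory"))
print(name, total_memory)
print(cuda_is_initialized())  # still False: the forked worker did the CUDA work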
vllm/v1/engine/async_llm.py
@@ -36,6 +36,7 @@ from vllm.v1.executor.abstract import Executor
 from vllm.v1.metrics.loggers import (LoggingStatLogger, PrometheusStatLogger,
                                      StatLoggerBase)
 from vllm.v1.metrics.stats import IterationStats, SchedulerStats
+from vllm.v1.utils import report_usage_stats
 
 logger = init_logger(__name__)
 
@@ -114,6 +115,9 @@ class AsyncLLM(EngineClient):
         except RuntimeError:
             pass
 
+        # If usage stat is enabled, collect relevant info.
+        report_usage_stats(vllm_config, usage_context)
+
     @classmethod
     def from_vllm_config(
         cls,
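The call added above is effectively a no-op when usage stats are disabled, since report_usage_stats() returns early via is_usage_stats_enabled(). A minimal sketch of checking the opt-out, assuming vLLM's usual VLLM_NO_USAGE_STATS environment switch rather than anything introduced by this diff:

# Sketch: with usage stats disabled, the new call in AsyncLLM.__init__ does nothing.
import os
os.environ["VLLM_NO_USAGE_STATS"] = "1"  # assumed opt-out switch; set before the first check

from vllm.usage.usage_lib import is_usage_stats_enabled
assert not is_usage_stats_enabled()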
vllm/v1/engine/llm_engine.py
@@ -28,6 +28,7 @@ from vllm.v1.engine.output_processor import OutputProcessor
 from vllm.v1.engine.parallel_sampling import ParentRequest
 from vllm.v1.engine.processor import Processor
 from vllm.v1.executor.abstract import Executor
+from vllm.v1.utils import report_usage_stats
 
 logger = init_logger(__name__)
 
@@ -99,6 +100,9 @@ class LLMEngine:
         # for v0 compatibility
         self.model_executor = self.engine_core.engine_core.model_executor  # type: ignore
 
+        # If usage stat is enabled, collect relevant info.
+        report_usage_stats(vllm_config, usage_context)
+
     @classmethod
     def from_vllm_config(
         cls,
vllm/v1/utils.py
@@ -12,6 +12,8 @@ import torch
 
 from vllm.logger import init_logger
 from vllm.model_executor.models.utils import extract_layer_index
+from vllm.usage.usage_lib import (UsageContext, is_usage_stats_enabled,
+                                  usage_message)
 from vllm.utils import get_mp_context, kill_process_tree
 
 if TYPE_CHECKING:
@@ -201,3 +203,45 @@ def copy_slice(from_tensor: torch.Tensor, to_tensor: torch.Tensor,
     Returns the sliced target tensor.
     """
     return to_tensor[:length].copy_(from_tensor[:length], non_blocking=True)
+
+
+def report_usage_stats(vllm_config, usage_context: UsageContext) -> None:
+    """Report usage statistics if enabled."""
+
+    if not is_usage_stats_enabled():
+        return
+
+    from vllm.model_executor.model_loader import get_architecture_class_name
+
+    usage_message.report_usage(
+        get_architecture_class_name(vllm_config.model_config),
+        usage_context,
+        extra_kvs={
+            # Common configuration
+            "dtype":
+            str(vllm_config.model_config.dtype),
+            "tensor_parallel_size":
+            vllm_config.parallel_config.tensor_parallel_size,
+            "block_size":
+            vllm_config.cache_config.block_size,
+            "gpu_memory_utilization":
+            vllm_config.cache_config.gpu_memory_utilization,
+
+            # Quantization
+            "quantization":
+            vllm_config.model_config.quantization,
+            "kv_cache_dtype":
+            str(vllm_config.cache_config.cache_dtype),
+
+            # Feature flags
+            "enable_lora":
+            bool(vllm_config.lora_config),
+            "enable_prompt_adapter":
+            bool(vllm_config.prompt_adapter_config),
+            "enable_prefix_caching":
+            vllm_config.cache_config.enable_prefix_caching,
+            "enforce_eager":
+            vllm_config.model_config.enforce_eager,
+            "disable_custom_all_reduce":
+            vllm_config.parallel_config.disable_custom_all_reduce,
+        })
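A hedged sketch of invoking the new helper directly. The EngineArgs import path and the UsageContext.ENGINE_CONTEXT member are assumptions about the surrounding vLLM API at this commit, used only for illustration; the V1 engines already make this call themselves during construction, as shown above.

# Sketch: what AsyncLLM/LLMEngine now do right after construction (illustrative).
from vllm.engine.arg_utils import EngineArgs          # assumed import path
from vllm.usage.usage_lib import UsageContext         # ENGINE_CONTEXT member assumed
from vllm.v1.utils import report_usage_stats

vllm_config = EngineArgs(model="facebook/opt-125m").create_engine_config()
report_usage_stats(vllm_config, UsageContext.ENGINE_CONTEXT)  # no-op if stats are disabled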