mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-24 01:15:01 +08:00
76 lines
2.7 KiB
Python
76 lines
2.7 KiB
Python
# SPDX-License-Identifier: Apache-2.0
|
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
|
|
|
import hashlib
|
|
from dataclasses import field
|
|
from typing import Any, Literal
|
|
|
|
import torch
|
|
from pydantic import ConfigDict, SkipValidation
|
|
from pydantic.dataclasses import dataclass
|
|
|
|
from vllm.config.utils import config
|
|
|
|
# Device names accepted by the `device` field of `DeviceConfig` below.
# "auto" makes `DeviceConfig.__post_init__` take the device type from
# `vllm.platforms.current_platform` instead of using the given name.
Device = Literal["auto", "cuda", "cpu", "tpu", "xpu"]
|
|
|
|
|
|
@config
@dataclass(config=ConfigDict(arbitrary_types_allowed=True))
class DeviceConfig:
    """Configuration for the device to use for vLLM execution."""

    device: SkipValidation[Device | torch.device | None] = "auto"
    """Device type for vLLM execution.
    This parameter is deprecated and will be
    removed in a future release.
    It will now be set automatically based
    on the current platform."""
    device_type: str = field(init=False)
    """Device type from the current platform. This is set in
    `__post_init__`."""

    def compute_hash(self) -> str:
        """
        WARNING: Whenever a new field is added to this config,
        ensure that it is included in the factors list if
        it affects the computation graph.

        Provide a hash that uniquely identifies all the configs
        that affect the structure of the computation
        graph from input ids/embeddings to the final hidden states,
        excluding anything before input ids/embeddings and after
        the final hidden states.

        Returns:
            The hex MD5 digest of the (currently empty) factor list, so the
            value is the same for every `DeviceConfig` instance.
        """
        # no factors to consider.
        # the device/platform information will be summarized
        # by torch/vllm automatically.
        factors: list[Any] = []
        # MD5 is used only as a fingerprint here, hence usedforsecurity=False.
        return hashlib.md5(
            str(factors).encode(), usedforsecurity=False
        ).hexdigest()

    def __post_init__(self):
        """Resolve `device_type` and normalize `device` after construction.

        `device_type` is taken from the current platform when `device` is
        "auto" or `None`, from the string itself when `device` is a string,
        and from `.type` when `device` is a `torch.device`. Afterwards
        `device` is rewritten to `None` for "tpu" and to
        `torch.device(device_type)` for every other device type.

        `None` is accepted because this method itself stores `None` in
        `device` for TPU, so a config rebuilt from an already-initialized
        one must not fail.

        Raises:
            RuntimeError: If automatic detection yields an empty device type.
            TypeError: If `device` is not a string, a `torch.device` or
                `None`. Field validation is skipped for `device`, so such
                values reach this method unchecked.
        """
        if self.device is None or self.device == "auto":
            # Automated device type detection.
            # Imported lazily, at the point of use.
            from vllm.platforms import current_platform

            self.device_type = current_platform.device_type
            if not self.device_type:
                raise RuntimeError(
                    "Failed to infer device type, please set "
                    "the environment variable `VLLM_LOGGING_LEVEL=DEBUG` "
                    "to turn on verbose logging to help debug the issue."
                )
        elif isinstance(self.device, str):
            # Device type is assigned explicitly
            self.device_type = self.device
        elif isinstance(self.device, torch.device):
            # Device type is assigned explicitly
            self.device_type = self.device.type
        else:
            # Without this branch `device_type` would never be assigned and
            # the check below would fail with an unrelated error.
            raise TypeError(
                "Unsupported type for `device`: "
                f"{type(self.device).__name__}; expected a device string, "
                "a `torch.device`, or None."
            )

        # Some device types require processing inputs on CPU
        if self.device_type in ["tpu"]:
            self.device = None
        else:
            # Set device with device type
            self.device = torch.device(self.device_type)
|