vllm/vllm/config/device.py
Harry Mellor 8fcaaf6a16
Update Optional[x] -> x | None and Union[x, y] to x | y (#26633)
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
2025-10-12 09:51:31 -07:00

76 lines
2.7 KiB
Python

# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import hashlib
from dataclasses import field
from typing import Any, Literal
import torch
from pydantic import ConfigDict, SkipValidation
from pydantic.dataclasses import dataclass
from vllm.config.utils import config
# String device names accepted by `DeviceConfig.device` (which also accepts
# a `torch.device` or None). "auto" selects the device type reported by the
# current platform.
Device = Literal[
    "auto",
    "cuda",
    "cpu",
    "tpu",
    "xpu",
]
@config
@dataclass(config=ConfigDict(arbitrary_types_allowed=True))
class DeviceConfig:
    """Configuration for the device to use for vLLM execution."""

    device: SkipValidation[Device | torch.device | None] = "auto"
    """Device type for vLLM execution.
    This parameter is deprecated and will be
    removed in a future release.
    It will now be set automatically based
    on the current platform."""
    device_type: str = field(init=False)
    """Device type from the current platform. This is set in
    `__post_init__`."""

    def compute_hash(self) -> str:
        """
        WARNING: Whenever a new field is added to this config,
        ensure that it is included in the factors list if
        it affects the computation graph.

        Provide a hash that uniquely identifies all the configs
        that affect the structure of the computation
        graph from input ids/embeddings to the final hidden states,
        excluding anything before input ids/embeddings and after
        the final hidden states.
        """
        # no factors to consider.
        # the device/platform information will be summarized
        # by torch/vllm automatically.
        factors: list[Any] = []
        hash_str = hashlib.md5(str(factors).encode(), usedforsecurity=False).hexdigest()
        return hash_str

    def __post_init__(self):
        """Resolve ``device_type`` and normalize ``device``.

        ``"auto"`` and ``None`` both select the device type reported by the
        current platform. ``None`` must be handled here because the field
        annotation permits it, and because this method itself stores
        ``None`` for TPU. Re-creating a config from its own field values
        (for example via ``dataclasses.replace``) can therefore pass ``None``
        back in. Previously that case left ``device_type`` unset and failed
        later with an ``AttributeError``.

        After this runs, ``device`` is ``None`` for TPU and a
        ``torch.device`` built from ``device_type`` otherwise.

        Raises:
            RuntimeError: If the device type cannot be inferred from the
                current platform.
            TypeError: If ``device`` is not a string, a ``torch.device``,
                or ``None``.
        """
        if self.device is None or self.device == "auto":
            # Automated device type detection
            from vllm.platforms import current_platform

            self.device_type = current_platform.device_type
            if not self.device_type:
                raise RuntimeError(
                    "Failed to infer device type, please set "
                    "the environment variable `VLLM_LOGGING_LEVEL=DEBUG` "
                    "to turn on verbose logging to help debug the issue."
                )
        elif isinstance(self.device, str):
            # Device type is assigned explicitly as a string; it is not
            # validated against `Device` because the field uses
            # SkipValidation.
            self.device_type = self.device
        elif isinstance(self.device, torch.device):
            # NOTE(review): only `.type` is kept, so an explicit index
            # (e.g. torch.device("cuda:1")) is dropped when `device` is
            # rebuilt below — confirm this is intended.
            self.device_type = self.device.type
        else:
            # Fail loudly here instead of with an AttributeError on the
            # unset `device_type` further down.
            raise TypeError(
                f"Unsupported device specification {self.device!r}; expected "
                "a device type string, a `torch.device`, or None."
            )

        # Some device types require processing inputs on CPU
        if self.device_type in ("tpu",):
            self.device = None
        else:
            # Set device with device type
            self.device = torch.device(self.device_type)