diff --git a/vllm/utils.py b/vllm/utils.py
index 780269f7e8ff5..442b7945d3209 100644
--- a/vllm/utils.py
+++ b/vllm/utils.py
@@ -444,3 +444,8 @@ def maybe_expand_dim(tensor: torch.Tensor,
     if tensor.ndim < target_dims:
         tensor = tensor.view(-1, *([size] * (target_dims - tensor.ndim)))
     return tensor
+
+
+def get_dtype_size(dtype: torch.dtype) -> int:
+    """Get the size of the data type in bytes."""
+    return torch.tensor([], dtype=dtype).element_size()
diff --git a/vllm/worker/cache_engine.py b/vllm/worker/cache_engine.py
index 27d1727cd16a3..bdc758cb8f03f 100644
--- a/vllm/worker/cache_engine.py
+++ b/vllm/worker/cache_engine.py
@@ -6,7 +6,8 @@ import torch
 
 from vllm.attention import get_attn_backend
 from vllm.config import CacheConfig, ModelConfig, ParallelConfig
 from vllm.logger import init_logger
-from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, is_pin_memory_available
+from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, is_pin_memory_available,
+                        get_dtype_size)
 
 logger = init_logger(__name__)
 
@@ -98,9 +99,5 @@ class CacheEngine:
             dtype = model_config.dtype
         else:
             dtype = STR_DTYPE_TO_TORCH_DTYPE[cache_dtype]
-        dtype_size = _get_dtype_size(dtype)
+        dtype_size = get_dtype_size(dtype)
         return dtype_size * total
-
-
-def _get_dtype_size(dtype: torch.dtype) -> int:
-    return torch.tensor([], dtype=dtype).element_size()