Add get_dtype_size

This commit is contained in:
Woosuk Kwon 2024-04-01 06:33:06 +00:00
parent 5083aa9092
commit 27c592b97b
2 changed files with 8 additions and 6 deletions

View File

@ -444,3 +444,8 @@ def maybe_expand_dim(tensor: torch.Tensor,
if tensor.ndim < target_dims:
tensor = tensor.view(-1, *([size] * (target_dims - tensor.ndim)))
return tensor
def get_dtype_size(dtype: torch.dtype) -> int:
"""Get the size of the data type in bytes."""
return torch.tensor([], dtype=dtype).element_size()

View File

@ -6,7 +6,8 @@ import torch
from vllm.attention import get_attn_backend
from vllm.config import CacheConfig, ModelConfig, ParallelConfig
from vllm.logger import init_logger
from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, is_pin_memory_available
from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, is_pin_memory_available,
get_dtype_size)
logger = init_logger(__name__)
@ -98,9 +99,5 @@ class CacheEngine:
dtype = model_config.dtype
else:
dtype = STR_DTYPE_TO_TORCH_DTYPE[cache_dtype]
dtype_size = _get_dtype_size(dtype)
dtype_size = get_dtype_size(dtype)
return dtype_size * total
def _get_dtype_size(dtype: torch.dtype) -> int:
return torch.tensor([], dtype=dtype).element_size()