mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-03-20 01:21:19 +08:00
Add get_dtype_size
This commit is contained in:
parent
5083aa9092
commit
27c592b97b
@ -444,3 +444,8 @@ def maybe_expand_dim(tensor: torch.Tensor,
|
||||
if tensor.ndim < target_dims:
|
||||
tensor = tensor.view(-1, *([size] * (target_dims - tensor.ndim)))
|
||||
return tensor
|
||||
|
||||
|
||||
def get_dtype_size(dtype: torch.dtype) -> int:
|
||||
"""Get the size of the data type in bytes."""
|
||||
return torch.tensor([], dtype=dtype).element_size()
|
||||
|
||||
@ -6,7 +6,8 @@ import torch
|
||||
from vllm.attention import get_attn_backend
|
||||
from vllm.config import CacheConfig, ModelConfig, ParallelConfig
|
||||
from vllm.logger import init_logger
|
||||
from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, is_pin_memory_available
|
||||
from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, is_pin_memory_available,
|
||||
get_dtype_size)
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
@ -98,9 +99,5 @@ class CacheEngine:
|
||||
dtype = model_config.dtype
|
||||
else:
|
||||
dtype = STR_DTYPE_TO_TORCH_DTYPE[cache_dtype]
|
||||
dtype_size = _get_dtype_size(dtype)
|
||||
dtype_size = get_dtype_size(dtype)
|
||||
return dtype_size * total
|
||||
|
||||
|
||||
def _get_dtype_size(dtype: torch.dtype) -> int:
|
||||
return torch.tensor([], dtype=dtype).element_size()
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user