mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-05-16 04:28:00 +08:00
Add get_dtype_size
This commit is contained in:
parent
5083aa9092
commit
27c592b97b
@ -444,3 +444,8 @@ def maybe_expand_dim(tensor: torch.Tensor,
|
|||||||
if tensor.ndim < target_dims:
|
if tensor.ndim < target_dims:
|
||||||
tensor = tensor.view(-1, *([size] * (target_dims - tensor.ndim)))
|
tensor = tensor.view(-1, *([size] * (target_dims - tensor.ndim)))
|
||||||
return tensor
|
return tensor
|
||||||
|
|
||||||
|
|
||||||
|
def get_dtype_size(dtype: torch.dtype) -> int:
|
||||||
|
"""Get the size of the data type in bytes."""
|
||||||
|
return torch.tensor([], dtype=dtype).element_size()
|
||||||
|
|||||||
@ -6,7 +6,8 @@ import torch
|
|||||||
from vllm.attention import get_attn_backend
|
from vllm.attention import get_attn_backend
|
||||||
from vllm.config import CacheConfig, ModelConfig, ParallelConfig
|
from vllm.config import CacheConfig, ModelConfig, ParallelConfig
|
||||||
from vllm.logger import init_logger
|
from vllm.logger import init_logger
|
||||||
from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, is_pin_memory_available
|
from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, is_pin_memory_available,
|
||||||
|
get_dtype_size)
|
||||||
|
|
||||||
logger = init_logger(__name__)
|
logger = init_logger(__name__)
|
||||||
|
|
||||||
@ -98,9 +99,5 @@ class CacheEngine:
|
|||||||
dtype = model_config.dtype
|
dtype = model_config.dtype
|
||||||
else:
|
else:
|
||||||
dtype = STR_DTYPE_TO_TORCH_DTYPE[cache_dtype]
|
dtype = STR_DTYPE_TO_TORCH_DTYPE[cache_dtype]
|
||||||
dtype_size = _get_dtype_size(dtype)
|
dtype_size = get_dtype_size(dtype)
|
||||||
return dtype_size * total
|
return dtype_size * total
|
||||||
|
|
||||||
|
|
||||||
def _get_dtype_size(dtype: torch.dtype) -> int:
|
|
||||||
return torch.tensor([], dtype=dtype).element_size()
|
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user