[Core] Rework dtype resolution (#18751)

Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
Cyrus Leung 2025-06-01 11:04:23 +08:00 committed by GitHub
parent 1bc86a3da1
commit 6aa8f9a4e7
13 changed files with 314 additions and 119 deletions

View File

@@ -60,7 +60,6 @@ def _fix_prompt_embed_outputs(
@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("backend", ["FLASH_ATTN"])
@pytest.mark.parametrize("dtype", ["half"])
@pytest.mark.parametrize("max_tokens", [5])
@pytest.mark.parametrize("enforce_eager", [False])
@pytest.mark.parametrize("enable_prompt_embeds", [True, False])
@@ -69,7 +68,6 @@ def test_models(
hf_runner,
model: str,
backend: str,
dtype: str,
max_tokens: int,
enforce_eager: bool,
enable_prompt_embeds: bool,
@@ -97,7 +95,7 @@ def test_models(
str(i) for i in range(1024)) + " are:"
example_prompts = [prompt]
with hf_runner(model, dtype=dtype) as hf_model:
with hf_runner(model) as hf_model:
hf_outputs = hf_model.generate_greedy(example_prompts, max_tokens)
if enable_prompt_embeds:
with torch.no_grad():
@@ -106,7 +104,6 @@ def test_models(
with VllmRunner(model,
max_model_len=8192,
dtype=dtype,
enforce_eager=enforce_eager,
enable_prompt_embeds=enable_prompt_embeds,
gpu_memory_utilization=0.7) as vllm_model:

View File

@@ -324,7 +324,12 @@ class HfRunner:
trust_remote_code=trust_remote_code,
)
self.device = self.get_default_device()
self.dtype = torch_dtype = _get_and_verify_dtype(self.config, dtype)
self.dtype = torch_dtype = _get_and_verify_dtype(
self.model_name,
self.config,
dtype=dtype,
is_pooling_model=is_sentence_transformer or is_cross_encoder,
)
model_kwargs = model_kwargs if model_kwargs is not None else {}
model_kwargs.setdefault("torch_dtype", torch_dtype)

View File

@@ -102,21 +102,18 @@ def mteb_test_embed_models(hf_runner,
vllm_main_score = run_mteb_embed_task(VllmMtebEncoder(vllm_model),
MTEB_EMBED_TASKS)
vllm_dtype = vllm_model.model.llm_engine.model_config.dtype
model_dtype = getattr(
vllm_model.model.llm_engine.model_config.hf_config, "torch_dtype",
vllm_dtype)
with set_default_torch_dtype(model_dtype) and hf_runner(
with set_default_torch_dtype(vllm_dtype) and hf_runner(
model_info.name, is_sentence_transformer=True,
dtype=model_dtype) as hf_model:
dtype=vllm_dtype) as hf_model:
if hf_model_callback is not None:
hf_model_callback(hf_model)
st_main_score = run_mteb_embed_task(hf_model, MTEB_EMBED_TASKS)
print("VLLM:", vllm_dtype, vllm_main_score)
print("SentenceTransformer:", model_dtype, st_main_score)
print("VLLM:", vllm_main_score)
print("SentenceTransformers:", st_main_score)
print("Difference:", st_main_score - vllm_main_score)
assert st_main_score == pytest.approx(vllm_main_score, abs=MTEB_EMBED_TOL)

View File

@@ -43,6 +43,6 @@ def test_models(
# the tolerance value of 1e-2 is selected based on the
# half datatype tests in
# tests/models/embedding/language/test_embedding.py
# tests/models/language/pooling/test_embedding.py
assert torch.allclose(hf_output, vllm_output,
1e-3 if dtype == "float" else 1e-2)

View File

@@ -30,13 +30,11 @@ from ...utils import check_embeddings_close
pytest.param("sentence-transformers/stsb-roberta-base-v2"),
],
)
@pytest.mark.parametrize("dtype", ["half"])
def test_models(
hf_runner,
vllm_runner,
example_prompts,
model,
dtype: str,
monkeypatch,
) -> None:
@@ -58,13 +56,11 @@ def test_models(
# So we need to strip the input texts to avoid test failing.
example_prompts = [str(s).strip() for s in example_prompts]
with hf_runner(model, dtype=dtype,
is_sentence_transformer=True) as hf_model:
with hf_runner(model, is_sentence_transformer=True) as hf_model:
hf_outputs = hf_model.encode(example_prompts)
with vllm_runner(model,
task="embed",
dtype=dtype,
max_model_len=None,
**vllm_extra_kwargs) as vllm_model:
vllm_outputs = vllm_model.encode(example_prompts)

View File

@@ -100,6 +100,7 @@ def run_test(
with vllm_runner(
model,
dtype="half",
max_model_len=448,
tensor_parallel_size=tensor_parallel_size,
distributed_executor_backend=distributed_executor_backend,

View File

@@ -40,7 +40,7 @@ def _test_processing_correctness(
tokenizer_mode=model_info.tokenizer_mode,
trust_remote_code=model_info.trust_remote_code,
seed=0,
dtype="float16",
dtype="auto",
revision=None,
hf_overrides=model_info.hf_overrides,
)

View File

@@ -103,7 +103,7 @@ class TestTwoTokenBadWord:
add_special_tokens=False)[0]
def test_two_token_bad_word(self, vllm_runner):
with vllm_runner(self.MODEL) as llm:
with vllm_runner(self.MODEL, dtype="half") as llm:
output_token_ids = self._generate(llm)
assert output_token_ids[:2] == [
self.target_token_id1, self.target_token_id2

View File

@@ -17,7 +17,8 @@ from vllm_test_utils.monitor import monitor
from vllm.config import ParallelConfig, VllmConfig, set_current_vllm_config
from vllm.utils import (CacheInfo, FlexibleArgumentParser, LRUCache,
MemorySnapshot, PlaceholderModule, StoreBoolean,
bind_kv_cache, deprecate_kwargs, get_open_port,
bind_kv_cache, common_broadcastable_dtype,
deprecate_kwargs, get_open_port, is_lossless_cast,
make_zmq_path, make_zmq_socket, memory_profiling,
merge_async_iterators, sha256, split_zmq_path,
supports_kw, swap_dict_values)
@@ -567,12 +568,65 @@ def test_lru_cache():
assert 6 in cache
# yapf: disable
@pytest.mark.parametrize(
("src_dtype", "tgt_dtype", "expected_result"),
[
# Different precision_levels
(torch.bool, torch.int8, True),
(torch.bool, torch.float16, True),
(torch.bool, torch.complex32, True),
(torch.int64, torch.bool, False),
(torch.int64, torch.float16, True),
(torch.int64, torch.complex32, True),
(torch.float64, torch.bool, False),
(torch.float64, torch.int8, False),
(torch.float64, torch.complex32, True),
(torch.complex128, torch.bool, False),
(torch.complex128, torch.int8, False),
(torch.complex128, torch.float16, False),
# precision_level=0
(torch.bool, torch.bool, True),
# precision_level=1
(torch.int8, torch.int16, True),
(torch.int16, torch.int8, False),
(torch.uint8, torch.int8, False),
(torch.int8, torch.uint8, False),
# precision_level=2
(torch.float16, torch.float32, True),
(torch.float32, torch.float16, False),
(torch.bfloat16, torch.float32, True),
(torch.float32, torch.bfloat16, False),
# precision_level=3
(torch.complex32, torch.complex64, True),
(torch.complex64, torch.complex32, False),
],
)
# yapf: enable
def test_is_lossless_cast(src_dtype, tgt_dtype, expected_result):
assert is_lossless_cast(src_dtype, tgt_dtype) == expected_result
# yapf: disable
@pytest.mark.parametrize(
("dtypes", "expected_result"),
[
([torch.bool], torch.bool),
([torch.bool, torch.int8], torch.int8),
([torch.bool, torch.int8, torch.float16], torch.float16),
([torch.bool, torch.int8, torch.float16, torch.complex32], torch.complex32), # noqa: E501
],
)
# yapf: enable
def test_common_broadcastable_dtype(dtypes, expected_result):
assert common_broadcastable_dtype(dtypes) == expected_result
def test_placeholder_module_error_handling():
placeholder = PlaceholderModule("placeholder_1234")
def build_ctx():
return pytest.raises(ModuleNotFoundError,
match="No module named")
return pytest.raises(ModuleNotFoundError, match="No module named")
with build_ctx():
int(placeholder)
@@ -608,6 +662,7 @@ def test_placeholder_module_error_handling():
_ = placeholder_attr.module
# yapf: disable
@pytest.mark.parametrize(
"obj,key1,key2",
[
@@ -618,6 +673,7 @@ def test_placeholder_module_error_handling():
# Tests for both keys do not exist
({1: "a", 2: "b"}, 3, 4),
])
# yapf: enable
def test_swap_dict_values(obj, key1, key2):
original_obj = obj.copy()
swap_dict_values(obj, key1, key2)
@@ -631,19 +687,19 @@ def test_swap_dict_values(obj, key1, key2):
assert key1 not in obj
def test_model_specification(parser_with_config,
cli_config_file,
def test_model_specification(parser_with_config, cli_config_file,
cli_config_file_with_model):
# Test model in CLI takes precedence over config
args = parser_with_config.parse_args([
'serve', 'cli-model', '--config', cli_config_file_with_model
])
args = parser_with_config.parse_args(
['serve', 'cli-model', '--config', cli_config_file_with_model])
assert args.model_tag == 'cli-model'
assert args.served_model_name == 'mymodel'
# Test model from config file works
args = parser_with_config.parse_args([
'serve', '--config', cli_config_file_with_model,
'serve',
'--config',
cli_config_file_with_model,
])
assert args.model == 'config-model'
assert args.served_model_name == 'mymodel'
@@ -655,16 +711,18 @@ def test_model_specification(parser_with_config,
# Test using --model option raises error
with pytest.raises(
ValueError,
match=(
"With `vllm serve`, you should provide the model as a positional "
"argument or in a config file instead of via the `--model` option."
),
match=
("With `vllm serve`, you should provide the model as a positional "
"argument or in a config file instead of via the `--model` option."),
):
parser_with_config.parse_args(['serve', '--model', 'my-model'])
# Test other config values are preserved
args = parser_with_config.parse_args([
'serve', 'cli-model', '--config', cli_config_file_with_model,
'serve',
'cli-model',
'--config',
cli_config_file_with_model,
])
assert args.tensor_parallel_size == 2
assert args.trust_remote_code is True
@@ -682,7 +740,8 @@ def test_sha256(input: tuple, output: int):
assert hash != 0
bytes = pickle.dumps(input, protocol=pickle.HIGHEST_PROTOCOL)
assert hash == int.from_bytes(hashlib.sha256(bytes).digest(), byteorder="big")
assert hash == int.from_bytes(hashlib.sha256(bytes).digest(),
byteorder="big")
# hashing again, returns the same value
assert hash == sha256(input)
@@ -698,8 +757,7 @@ def test_sha256(input: tuple, output: int):
("tcp://127.0.0.1:5555", ("tcp", "127.0.0.1", "5555")),
("tcp://[::1]:5555", ("tcp", "::1", "5555")), # IPv6 address
("inproc://some_identifier", ("inproc", "some_identifier", "")),
]
)
])
def test_split_zmq_path(path, expected):
assert split_zmq_path(path) == expected
@@ -711,8 +769,7 @@ def test_split_zmq_path(path, expected):
"tcp://127.0.0.1", # Missing port
"tcp://[::1]", # Missing port for IPv6
"tcp://:5555", # Missing host
]
)
])
def test_split_zmq_path_invalid(invalid_path):
with pytest.raises(ValueError):
split_zmq_path(invalid_path)
@@ -734,7 +791,8 @@ def test_make_zmq_socket_ipv6():
zsock: zmq.Socket = make_zmq_socket(ctx, ipv6_path, socket_type)
# Verify that the IPV6 option is set
assert zsock.getsockopt(zmq.IPV6) == 1, "IPV6 option should be enabled for IPv6 addresses"
assert zsock.getsockopt(
zmq.IPV6) == 1, "IPV6 option should be enabled for IPv6 addresses"
# Clean up
zsock.close()

View File

@@ -24,6 +24,7 @@ import torch
from pydantic import (ConfigDict, SkipValidation, TypeAdapter, field_validator,
model_validator)
from pydantic.dataclasses import dataclass
from safetensors.torch import _TYPES as _SAFETENSORS_TO_TORCH_DTYPE
from torch.distributed import ProcessGroup, ReduceOp
from transformers import PretrainedConfig
from typing_extensions import deprecated, runtime_checkable
@@ -42,15 +43,16 @@ from vllm.transformers_utils.config import (
ConfigFormat, get_config, get_hf_image_processor_config,
get_hf_text_config, get_pooling_config,
get_sentence_transformer_tokenizer_config, is_encoder_decoder,
try_get_generation_config, uses_mrope)
try_get_generation_config, try_get_safetensors_metadata, uses_mrope)
from vllm.transformers_utils.s3_utils import S3Model
from vllm.transformers_utils.utils import is_s3, maybe_model_redirect
from vllm.utils import (DEFAULT_MAX_NUM_BATCHED_TOKENS,
MULTIMODAL_MODEL_MAX_NUM_BATCHED_TOKENS,
POOLING_MODEL_MAX_NUM_BATCHED_TOKENS, GiB_bytes,
LayerBlockType, cuda_device_count_stateless,
get_cpu_memory, get_open_port, is_torch_equal_or_newer,
random_uuid, resolve_obj_by_qualname)
LayerBlockType, common_broadcastable_dtype,
cuda_device_count_stateless, get_cpu_memory,
get_open_port, is_torch_equal_or_newer, random_uuid,
resolve_obj_by_qualname)
if TYPE_CHECKING:
from _typeshed import DataclassInstance
@@ -540,7 +542,24 @@ class ModelConfig:
self.encoder_config = self._get_encoder_config()
self.hf_image_processor_config = get_hf_image_processor_config(
self.model, hf_token=self.hf_token, revision=self.revision)
self.dtype = _get_and_verify_dtype(self.hf_config, self.dtype)
supported_tasks, task = self._resolve_task(self.task)
self.supported_tasks = supported_tasks
self.task = task
if self.task in ("draft", "generate"):
self.truncation_side = "left"
else:
self.truncation_side = "right"
self.pooler_config = self._init_pooler_config()
self.dtype = _get_and_verify_dtype(
self.model,
self.hf_config,
self.dtype,
is_pooling_model=self.runner_type == "pooling",
revision=self.revision,
)
# Workaround for Gemma 2 which uses interleaved sliding window
# attention, but it's not specified in its config. TODO: remove this
@@ -597,16 +616,6 @@ class ModelConfig:
raise ValueError(
"`override_neuron_config` is only supported on Neuron.")
supported_tasks, task = self._resolve_task(self.task)
self.supported_tasks = supported_tasks
self.task = task
if self.task in ("draft", "generate"):
self.truncation_side = "left"
else:
self.truncation_side = "right"
self.pooler_config = self._init_pooler_config()
self._verify_quantization()
self._verify_cuda_graph()
self._verify_bnb_config()
@@ -692,7 +701,6 @@ class ModelConfig:
self.model, self.revision)
def _init_pooler_config(self) -> Optional["PoolerConfig"]:
if self.runner_type == "pooling":
if isinstance(self.override_pooler_config, dict):
self.override_pooler_config = PoolerConfig(
@@ -3074,13 +3082,37 @@ _STR_DTYPE_TO_TORCH_DTYPE = {
"bfloat16": torch.bfloat16,
}
_ROCM_NOT_SUPPORTED_DTYPE: list[str] = [] #
# model_type -> reason
_FLOAT16_NOT_SUPPORTED_MODELS = {
"gemma2": "Numerical instability. Please use bfloat16 or float32 instead.",
"gemma3": "Numerical instability. Please use bfloat16 or float32 instead.",
"plamo2": "Numerical instability. Please use bfloat16 or float32 instead.",
"glm4": "Numerical instability. Please use bfloat16 or float32 instead.",
}
def _get_and_verify_dtype(
def _is_valid_dtype(model_type: str, dtype: torch.dtype):
if model_type in _FLOAT16_NOT_SUPPORTED_MODELS and dtype == torch.float16: # noqa: E501, SIM103
return False
return True
def _check_valid_dtype(model_type: str, dtype: torch.dtype):
if model_type in _FLOAT16_NOT_SUPPORTED_MODELS and dtype == torch.float16:
reason = _FLOAT16_NOT_SUPPORTED_MODELS[model_type]
raise ValueError(f"The model type {model_type!r} "
f"does not support float16. Reason: {reason}")
return True
def _find_dtype(
model_id: str,
config: PretrainedConfig,
dtype: Union[str, torch.dtype],
) -> torch.dtype:
*,
revision: Optional[str],
):
# NOTE: getattr(config, "torch_dtype", torch.float32) is not correct
# because config.torch_dtype can be None.
config_dtype = getattr(config, "torch_dtype", None)
@@ -3092,75 +3124,111 @@ def _get_and_verify_dtype(
if config_dtype is None and hasattr(config, "vision_config"):
config_dtype = getattr(config.vision_config, "torch_dtype", None)
# Try to read the dtype of the weights if they are in safetensors format
if config_dtype is None:
repo_mt = try_get_safetensors_metadata(model_id, revision=revision)
if repo_mt and (files_mt := repo_mt.files_metadata):
param_dtypes: set[torch.dtype] = {
_SAFETENSORS_TO_TORCH_DTYPE[dtype_str]
for file_mt in files_mt.values()
for dtype_str in file_mt.parameter_count
if dtype_str in _SAFETENSORS_TO_TORCH_DTYPE
}
if param_dtypes:
return common_broadcastable_dtype(param_dtypes)
if config_dtype is None:
config_dtype = torch.float32
return config_dtype
def _resolve_auto_dtype(
model_type: str,
config_dtype: torch.dtype,
*,
is_pooling_model: bool,
):
from vllm.platforms import current_platform
supported_dtypes = [
dtype for dtype in current_platform.supported_dtypes
if _is_valid_dtype(model_type, dtype)
]
if is_pooling_model and torch.float16 in supported_dtypes:
preferred_dtype = torch.float16
else:
preferred_dtype = supported_dtypes[0]
# Downcast for float32 models
if config_dtype == torch.float32:
config_dtype = preferred_dtype
if config_dtype in supported_dtypes:
return config_dtype
# Ensure device compatibility
device_name = current_platform.get_device_name()
device_capability = current_platform.get_device_capability()
if device_capability is None:
device_str = f"{device_name!r}"
else:
version_str = device_capability.as_version_str()
device_str = f"{device_name!r} (with compute capability {version_str})"
logger.warning(
"Your device %s doesn't support %s. "
"Falling back to %s for compatibility.",
device_str,
config_dtype,
preferred_dtype,
)
return preferred_dtype
def _get_and_verify_dtype(
model_id: str,
config: PretrainedConfig,
dtype: Union[str, torch.dtype],
*,
is_pooling_model: bool,
revision: Optional[str] = None,
) -> torch.dtype:
config_dtype = _find_dtype(model_id, config, revision=revision)
model_type = config.model_type
if isinstance(dtype, str):
dtype = dtype.lower()
if dtype == "auto":
# Set default dtype from model config
if config_dtype == torch.float32:
# Following common practice, we use float16 for float32 models
torch_dtype = torch.float16
else:
torch_dtype = config_dtype
if config.model_type == "plamo2":
logger.warning(
"For PLaMo2, we cast models to bfloat16 instead of using "
"float16 by default. This is because float16 does not work."
torch_dtype = _resolve_auto_dtype(
model_type,
config_dtype,
is_pooling_model=is_pooling_model,
)
torch_dtype = torch.bfloat16
# Deal with torch dtype fallback for device compatibility.
from vllm.platforms import current_platform
if torch_dtype not in current_platform.supported_dtypes:
device_name = current_platform.get_device_name()
if ((capability := current_platform.get_device_capability())
is None):
compute_str = ""
else:
version_str = capability.as_version_str()
compute_str = f" (with compute capability {version_str})"
fallback_dtype = current_platform.supported_dtypes[0]
logger.warning(
"Your %s device%s doesn't support %s. " \
"Falling back to %s for compatibility.",
device_name, compute_str, torch_dtype, fallback_dtype
)
torch_dtype = fallback_dtype
if current_platform.is_hpu() and torch_dtype == torch.float16:
logger.warning(
"For HPU, we cast models to bfloat16 instead of "
"using float16 by default. Please specify `dtype` if you "
"want to use float16.")
torch_dtype = torch.bfloat16
elif dtype == "float16" and config.model_type == "plamo2":
logger.warning(
"For PLaMo2, using float16 is unstable and might cause "
"unexpected behavior. Please use bfloat16 or float32 instead.")
torch_dtype = torch.float16
else:
if dtype not in _STR_DTYPE_TO_TORCH_DTYPE:
raise ValueError(f"Unknown dtype: {dtype}")
raise ValueError(f"Unknown dtype: {dtype!r}")
torch_dtype = _STR_DTYPE_TO_TORCH_DTYPE[dtype]
elif isinstance(dtype, torch.dtype):
torch_dtype = dtype
else:
raise ValueError(f"Unknown dtype: {dtype}")
# Verify the dtype.
_check_valid_dtype(model_type, torch_dtype)
if torch_dtype != config_dtype:
if torch_dtype == torch.float32:
# Upcasting to float32 is allowed.
logger.info("Upcasting %s to %s.", config_dtype, torch_dtype)
pass
elif config_dtype == torch.float32:
# Downcasting from float32 to float16 or bfloat16 is allowed.
logger.info("Downcasting %s to %s.", config_dtype, torch_dtype)
pass
else:
# Casting between float16 and bfloat16 is allowed with a warning.
logger.warning("Casting %s to %s.", config_dtype, torch_dtype)
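
For readers following the new control flow: _find_dtype recovers the checkpoint dtype, _resolve_auto_dtype picks a concrete dtype when the user passes "auto", and _check_valid_dtype rejects known-bad combinations such as float16 for Gemma 2. Below is a minimal standalone sketch of the "auto" branch; resolve_auto_dtype and the supported argument are illustrative stand-ins for the real _resolve_auto_dtype and current_platform.supported_dtypes, not vLLM APIs.

import torch

# Model types that cannot run in float16 (mirrors _FLOAT16_NOT_SUPPORTED_MODELS).
_FLOAT16_UNSUPPORTED = {"gemma2", "gemma3", "plamo2", "glm4"}


def resolve_auto_dtype(model_type: str,
                       config_dtype: torch.dtype,
                       supported: list[torch.dtype],
                       is_pooling_model: bool) -> torch.dtype:
    # Keep only dtypes that both the platform and the model type can use.
    candidates = [
        dt for dt in supported
        if not (dt == torch.float16 and model_type in _FLOAT16_UNSUPPORTED)
    ]
    # Pooling models prefer float16 when available; otherwise take the
    # platform's first supported dtype.
    if is_pooling_model and torch.float16 in candidates:
        preferred = torch.float16
    else:
        preferred = candidates[0]
    # float32 checkpoints are downcast to the preferred dtype.
    if config_dtype == torch.float32:
        config_dtype = preferred
    # Keep the checkpoint dtype if usable; otherwise fall back
    # (the real implementation logs a warning here).
    return config_dtype if config_dtype in candidates else preferred


# Example: a float32 pooling checkpoint on a GPU supporting bf16/fp16/fp32.
assert resolve_auto_dtype("bert", torch.float32,
                          [torch.bfloat16, torch.float16, torch.float32],
                          is_pooling_model=True) == torch.float16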

View File

@@ -28,7 +28,7 @@ class CpuPlatform(Platform):
dispatch_key: str = "CPU"
@property
def supported_dtypes(self) -> list:
def supported_dtypes(self) -> list[torch.dtype]:
if self.get_cpu_architecture() == CpuArchEnum.POWERPC:
return [torch.bfloat16, torch.float32]
elif sys.platform.startswith(

View File

@@ -4,12 +4,12 @@ import enum
import json
import os
import time
from functools import cache
from functools import cache, partial
from pathlib import Path
from typing import Any, Callable, Literal, Optional, Union
from typing import Any, Callable, Literal, Optional, TypeVar, Union
import huggingface_hub
from huggingface_hub import hf_hub_download
from huggingface_hub import get_safetensors_metadata, hf_hub_download
from huggingface_hub import list_repo_files as hf_list_repo_files
from huggingface_hub import try_to_load_from_cache
from huggingface_hub.utils import (EntryNotFoundError, HfHubHTTPError,
@@ -93,10 +93,15 @@ class ConfigFormat(str, enum.Enum):
MISTRAL = "mistral"
def with_retry(func: Callable[[], Any],
_R = TypeVar("_R")
def with_retry(
func: Callable[[], _R],
log_msg: str,
max_retries: int = 2,
retry_delay: int = 2):
retry_delay: int = 2,
) -> _R:
for attempt in range(max_retries):
try:
return func()
@@ -109,6 +114,8 @@ def with_retry(func: Callable[[], Any],
time.sleep(retry_delay)
retry_delay *= 2
raise AssertionError("Should not be reached")
# @cache doesn't cache exceptions
@cache
@@ -840,3 +847,22 @@ def get_cross_encoder_activation_function(config: PretrainedConfig):
return resolve_obj_by_qualname(function_name)()
else:
return nn.Sigmoid() if config.num_labels == 1 else nn.Identity()
def try_get_safetensors_metadata(
model: str,
*,
revision: Optional[str] = None,
):
get_safetensors_metadata_partial = partial(
get_safetensors_metadata,
model,
revision=revision,
token=os.getenv('HF_TOKEN', None),
)
try:
return with_retry(get_safetensors_metadata_partial,
"Error retrieving safetensors")
except Exception:
return None
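
As a cross-reference, the _find_dtype helper shown earlier in this commit consumes this function when a checkpoint's config carries no torch_dtype: the safetensors metadata reports a parameter count per dtype string, which is mapped through _SAFETENSORS_TO_TORCH_DTYPE and collapsed with common_broadcastable_dtype. A rough usage sketch follows; it hits the Hugging Face Hub at runtime, and the repo id is purely illustrative.

from safetensors.torch import _TYPES as _SAFETENSORS_TO_TORCH_DTYPE

from vllm.transformers_utils.config import try_get_safetensors_metadata
from vllm.utils import common_broadcastable_dtype

# Illustrative repo id; any safetensors checkpoint on the Hub works here.
repo_mt = try_get_safetensors_metadata("BAAI/bge-base-en-v1.5")
if repo_mt and repo_mt.files_metadata:
    param_dtypes = {
        _SAFETENSORS_TO_TORCH_DTYPE[dtype_str]
        # One metadata entry per *.safetensors file in the repo.
        for file_mt in repo_mt.files_metadata.values()
        # parameter_count maps dtype strings such as "F32" to counts.
        for dtype_str in file_mt.parameter_count
        if dtype_str in _SAFETENSORS_TO_TORCH_DTYPE
    }
    if param_dtypes:
        print(common_broadcastable_dtype(param_dtypes))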

View File

@@ -37,8 +37,8 @@ from argparse import (Action, ArgumentDefaultsHelpFormatter, ArgumentParser,
_ArgumentGroup)
from asyncio import FIRST_COMPLETED, AbstractEventLoop, Task
from collections import UserDict, defaultdict
from collections.abc import (AsyncGenerator, Awaitable, Generator, Hashable,
Iterable, Iterator, KeysView, Mapping)
from collections.abc import (AsyncGenerator, Awaitable, Collection, Generator,
Hashable, Iterable, Iterator, KeysView, Mapping)
from concurrent.futures.process import ProcessPoolExecutor
from dataclasses import dataclass, field
from functools import cache, lru_cache, partial, wraps
@@ -979,6 +979,53 @@ def get_dtype_size(dtype: torch.dtype) -> int:
return torch.tensor([], dtype=dtype).element_size()
# bool = 0, int = 1, float = 2, complex = 3
def _get_precision_level(dtype: torch.dtype) -> int:
# NOTE: Complex dtypes return `is_floating_point=False`
return ((dtype != torch.bool) + dtype.is_floating_point +
dtype.is_complex * 2)
def is_lossless_cast(src_dtype: torch.dtype, tgt_dtype: torch.dtype):
"""
Test whether it is lossless to cast a tensor from
`src_dtype` to `tgt_dtype`.
"""
if src_dtype == tgt_dtype:
return True
src_level = _get_precision_level(src_dtype)
tgt_level = _get_precision_level(tgt_dtype)
if src_level < tgt_level:
return True
if src_level > tgt_level:
return False
# Compare integral types
if not src_dtype.is_floating_point and not src_dtype.is_complex:
src_info = torch.iinfo(src_dtype)
tgt_info = torch.iinfo(tgt_dtype)
return src_info.min >= tgt_info.min and src_info.max <= tgt_info.max
# Compare floating-point types
src_info = torch.finfo(src_dtype)
tgt_info = torch.finfo(tgt_dtype)
return (src_info.min >= tgt_info.min and src_info.max <= tgt_info.max
and src_info.resolution >= tgt_info.resolution)
def common_broadcastable_dtype(dtypes: Collection[torch.dtype]):
"""
Get the common `dtype` where all of the other `dtypes` can be
cast to it without losing any information.
"""
return max(
dtypes,
key=lambda dtype: sum(is_lossless_cast(dt, dtype) for dt in dtypes),
)
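
A quick usage sketch of the two helpers added above: is_lossless_cast orders dtypes first by precision level and then by range/resolution, and common_broadcastable_dtype takes max over the candidates with the number of lossless casts into each candidate as the key, so the winner is the member that every other dtype can be cast to without loss.

import torch

from vllm.utils import common_broadcastable_dtype, is_lossless_cast

assert is_lossless_cast(torch.float16, torch.float32)        # widening: lossless
assert not is_lossless_cast(torch.float32, torch.bfloat16)   # narrowing: lossy

# float16 is the member that bool and int8 both cast to losslessly.
assert common_broadcastable_dtype(
    {torch.bool, torch.int8, torch.float16}) == torch.float16
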
# `collections` helpers
def is_list_of(
value: object,