[Core] Rework dtype resolution (#18751)
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
parent 1bc86a3da1
commit 6aa8f9a4e7
@@ -60,7 +60,6 @@ def _fix_prompt_embed_outputs(

@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("backend", ["FLASH_ATTN"])
@pytest.mark.parametrize("dtype", ["half"])
@pytest.mark.parametrize("max_tokens", [5])
@pytest.mark.parametrize("enforce_eager", [False])
@pytest.mark.parametrize("enable_prompt_embeds", [True, False])
@@ -69,7 +68,6 @@ def test_models(
    hf_runner,
    model: str,
    backend: str,
    dtype: str,
    max_tokens: int,
    enforce_eager: bool,
    enable_prompt_embeds: bool,
@@ -97,7 +95,7 @@ def test_models(
        str(i) for i in range(1024)) + " are:"
    example_prompts = [prompt]

    with hf_runner(model, dtype=dtype) as hf_model:
    with hf_runner(model) as hf_model:
        hf_outputs = hf_model.generate_greedy(example_prompts, max_tokens)
        if enable_prompt_embeds:
            with torch.no_grad():
@@ -106,7 +104,6 @@ def test_models(

    with VllmRunner(model,
                    max_model_len=8192,
                    dtype=dtype,
                    enforce_eager=enforce_eager,
                    enable_prompt_embeds=enable_prompt_embeds,
                    gpu_memory_utilization=0.7) as vllm_model:

@@ -324,7 +324,12 @@ class HfRunner:
            trust_remote_code=trust_remote_code,
        )
        self.device = self.get_default_device()
        self.dtype = torch_dtype = _get_and_verify_dtype(self.config, dtype)
        self.dtype = torch_dtype = _get_and_verify_dtype(
            self.model_name,
            self.config,
            dtype=dtype,
            is_pooling_model=is_sentence_transformer or is_cross_encoder,
        )

        model_kwargs = model_kwargs if model_kwargs is not None else {}
        model_kwargs.setdefault("torch_dtype", torch_dtype)

@@ -102,21 +102,18 @@ def mteb_test_embed_models(hf_runner,
        vllm_main_score = run_mteb_embed_task(VllmMtebEncoder(vllm_model),
                                              MTEB_EMBED_TASKS)
        vllm_dtype = vllm_model.model.llm_engine.model_config.dtype
        model_dtype = getattr(
            vllm_model.model.llm_engine.model_config.hf_config, "torch_dtype",
            vllm_dtype)

    with set_default_torch_dtype(model_dtype) and hf_runner(
    with set_default_torch_dtype(vllm_dtype) and hf_runner(
            model_info.name, is_sentence_transformer=True,
            dtype=model_dtype) as hf_model:
            dtype=vllm_dtype) as hf_model:

        if hf_model_callback is not None:
            hf_model_callback(hf_model)

        st_main_score = run_mteb_embed_task(hf_model, MTEB_EMBED_TASKS)

    print("VLLM:", vllm_dtype, vllm_main_score)
    print("SentenceTransformer:", model_dtype, st_main_score)
    print("VLLM:", vllm_main_score)
    print("SentenceTransformers:", st_main_score)
    print("Difference:", st_main_score - vllm_main_score)

    assert st_main_score == pytest.approx(vllm_main_score, abs=MTEB_EMBED_TOL)

@@ -43,6 +43,6 @@ def test_models(

    # the tolerance value of 1e-2 is selected based on the
    # half datatype tests in
    # tests/models/embedding/language/test_embedding.py
    # tests/models/language/pooling/test_embedding.py
    assert torch.allclose(hf_output, vllm_output,
                          1e-3 if dtype == "float" else 1e-2)

@@ -30,13 +30,11 @@ from ...utils import check_embeddings_close
        pytest.param("sentence-transformers/stsb-roberta-base-v2"),
    ],
)
@pytest.mark.parametrize("dtype", ["half"])
def test_models(
    hf_runner,
    vllm_runner,
    example_prompts,
    model,
    dtype: str,
    monkeypatch,
) -> None:

@@ -58,13 +56,11 @@ def test_models(
    # So we need to strip the input texts to avoid test failing.
    example_prompts = [str(s).strip() for s in example_prompts]

    with hf_runner(model, dtype=dtype,
                   is_sentence_transformer=True) as hf_model:
    with hf_runner(model, is_sentence_transformer=True) as hf_model:
        hf_outputs = hf_model.encode(example_prompts)

    with vllm_runner(model,
                     task="embed",
                     dtype=dtype,
                     max_model_len=None,
                     **vllm_extra_kwargs) as vllm_model:
        vllm_outputs = vllm_model.encode(example_prompts)

@@ -100,6 +100,7 @@ def run_test(

    with vllm_runner(
            model,
            dtype="half",
            max_model_len=448,
            tensor_parallel_size=tensor_parallel_size,
            distributed_executor_backend=distributed_executor_backend,

@@ -40,7 +40,7 @@ def _test_processing_correctness(
        tokenizer_mode=model_info.tokenizer_mode,
        trust_remote_code=model_info.trust_remote_code,
        seed=0,
        dtype="float16",
        dtype="auto",
        revision=None,
        hf_overrides=model_info.hf_overrides,
    )

@@ -103,7 +103,7 @@ class TestTwoTokenBadWord:
                                             add_special_tokens=False)[0]

    def test_two_token_bad_word(self, vllm_runner):
        with vllm_runner(self.MODEL) as llm:
        with vllm_runner(self.MODEL, dtype="half") as llm:
            output_token_ids = self._generate(llm)
            assert output_token_ids[:2] == [
                self.target_token_id1, self.target_token_id2

@@ -17,7 +17,8 @@ from vllm_test_utils.monitor import monitor
from vllm.config import ParallelConfig, VllmConfig, set_current_vllm_config
from vllm.utils import (CacheInfo, FlexibleArgumentParser, LRUCache,
                        MemorySnapshot, PlaceholderModule, StoreBoolean,
                        bind_kv_cache, deprecate_kwargs, get_open_port,
                        bind_kv_cache, common_broadcastable_dtype,
                        deprecate_kwargs, get_open_port, is_lossless_cast,
                        make_zmq_path, make_zmq_socket, memory_profiling,
                        merge_async_iterators, sha256, split_zmq_path,
                        supports_kw, swap_dict_values)
@@ -567,12 +568,65 @@ def test_lru_cache():
    assert 6 in cache


# yapf: disable
@pytest.mark.parametrize(
    ("src_dtype", "tgt_dtype", "expected_result"),
    [
        # Different precision_levels
        (torch.bool, torch.int8, True),
        (torch.bool, torch.float16, True),
        (torch.bool, torch.complex32, True),
        (torch.int64, torch.bool, False),
        (torch.int64, torch.float16, True),
        (torch.int64, torch.complex32, True),
        (torch.float64, torch.bool, False),
        (torch.float64, torch.int8, False),
        (torch.float64, torch.complex32, True),
        (torch.complex128, torch.bool, False),
        (torch.complex128, torch.int8, False),
        (torch.complex128, torch.float16, False),
        # precision_level=0
        (torch.bool, torch.bool, True),
        # precision_level=1
        (torch.int8, torch.int16, True),
        (torch.int16, torch.int8, False),
        (torch.uint8, torch.int8, False),
        (torch.int8, torch.uint8, False),
        # precision_level=2
        (torch.float16, torch.float32, True),
        (torch.float32, torch.float16, False),
        (torch.bfloat16, torch.float32, True),
        (torch.float32, torch.bfloat16, False),
        # precision_level=3
        (torch.complex32, torch.complex64, True),
        (torch.complex64, torch.complex32, False),
    ],
)
# yapf: enable
def test_is_lossless_cast(src_dtype, tgt_dtype, expected_result):
    assert is_lossless_cast(src_dtype, tgt_dtype) == expected_result


# yapf: disable
@pytest.mark.parametrize(
    ("dtypes", "expected_result"),
    [
        ([torch.bool], torch.bool),
        ([torch.bool, torch.int8], torch.int8),
        ([torch.bool, torch.int8, torch.float16], torch.float16),
        ([torch.bool, torch.int8, torch.float16, torch.complex32], torch.complex32),  # noqa: E501
    ],
)
# yapf: enable
def test_common_broadcastable_dtype(dtypes, expected_result):
    assert common_broadcastable_dtype(dtypes) == expected_result

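For reference, a minimal standalone sketch of how the two helpers under test behave (assuming only a PyTorch install; the dtype choices below are illustrative, not exhaustive):

import torch

from vllm.utils import common_broadcastable_dtype, is_lossless_cast

# Widening within the same precision level (int8 -> int16) loses nothing.
assert is_lossless_cast(torch.int8, torch.int16)
# Narrowing float32 -> float16 can drop information, so it is not lossless.
assert not is_lossless_cast(torch.float32, torch.float16)

# The common broadcastable dtype is the member every other dtype can be
# cast to without loss; for this set that is float16.
assert common_broadcastable_dtype(
    {torch.bool, torch.int8, torch.float16}) == torch.float16
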
def test_placeholder_module_error_handling():
    placeholder = PlaceholderModule("placeholder_1234")

    def build_ctx():
        return pytest.raises(ModuleNotFoundError,
                             match="No module named")
        return pytest.raises(ModuleNotFoundError, match="No module named")

    with build_ctx():
        int(placeholder)
@@ -608,6 +662,7 @@ def test_placeholder_module_error_handling():
        _ = placeholder_attr.module


# yapf: disable
@pytest.mark.parametrize(
    "obj,key1,key2",
    [
@@ -618,6 +673,7 @@ def test_placeholder_module_error_handling():
        # Tests for both keys do not exist
        ({1: "a", 2: "b"}, 3, 4),
    ])
# yapf: enable
def test_swap_dict_values(obj, key1, key2):
    original_obj = obj.copy()
    swap_dict_values(obj, key1, key2)
@@ -631,19 +687,19 @@ def test_swap_dict_values(obj, key1, key2):
        assert key1 not in obj


def test_model_specification(parser_with_config,
                             cli_config_file,
def test_model_specification(parser_with_config, cli_config_file,
                             cli_config_file_with_model):
    # Test model in CLI takes precedence over config
    args = parser_with_config.parse_args([
        'serve', 'cli-model', '--config', cli_config_file_with_model
    ])
    args = parser_with_config.parse_args(
        ['serve', 'cli-model', '--config', cli_config_file_with_model])
    assert args.model_tag == 'cli-model'
    assert args.served_model_name == 'mymodel'

    # Test model from config file works
    args = parser_with_config.parse_args([
        'serve', '--config', cli_config_file_with_model,
        'serve',
        '--config',
        cli_config_file_with_model,
    ])
    assert args.model == 'config-model'
    assert args.served_model_name == 'mymodel'
@@ -654,17 +710,19 @@ def test_model_specification(parser_with_config,

    # Test using --model option raises error
    with pytest.raises(
            ValueError,
            match=(
                "With `vllm serve`, you should provide the model as a positional "
                "argument or in a config file instead of via the `--model` option."
            ),
            ValueError,
            match=
        ("With `vllm serve`, you should provide the model as a positional "
         "argument or in a config file instead of via the `--model` option."),
    ):
        parser_with_config.parse_args(['serve', '--model', 'my-model'])

    # Test other config values are preserved
    args = parser_with_config.parse_args([
        'serve', 'cli-model', '--config', cli_config_file_with_model,
        'serve',
        'cli-model',
        '--config',
        cli_config_file_with_model,
    ])
    assert args.tensor_parallel_size == 2
    assert args.trust_remote_code is True
@@ -673,7 +731,7 @@ def test_model_specification(parser_with_config,


@pytest.mark.parametrize("input", [(), ("abc", ), (None, ),
                                   (None, bool, [1, 2, 3])])
                                    (None, bool, [1, 2, 3])])
@pytest.mark.parametrize("output", [0, 1, 2])
def test_sha256(input: tuple, output: int):
    hash = sha256(input)
@@ -682,7 +740,8 @@ def test_sha256(input: tuple, output: int):
    assert hash != 0

    bytes = pickle.dumps(input, protocol=pickle.HIGHEST_PROTOCOL)
    assert hash == int.from_bytes(hashlib.sha256(bytes).digest(), byteorder="big")
    assert hash == int.from_bytes(hashlib.sha256(bytes).digest(),
                                  byteorder="big")

    # hashing again, returns the same value
    assert hash == sha256(input)
@@ -698,8 +757,7 @@ def test_sha256(input: tuple, output: int):
        ("tcp://127.0.0.1:5555", ("tcp", "127.0.0.1", "5555")),
        ("tcp://[::1]:5555", ("tcp", "::1", "5555")),  # IPv6 address
        ("inproc://some_identifier", ("inproc", "some_identifier", "")),
    ]
)
    ])
def test_split_zmq_path(path, expected):
    assert split_zmq_path(path) == expected

@@ -711,8 +769,7 @@ def test_split_zmq_path(path, expected):
        "tcp://127.0.0.1",  # Missing port
        "tcp://[::1]",  # Missing port for IPv6
        "tcp://:5555",  # Missing host
    ]
)
    ])
def test_split_zmq_path_invalid(invalid_path):
    with pytest.raises(ValueError):
        split_zmq_path(invalid_path)
@@ -734,7 +791,8 @@ def test_make_zmq_socket_ipv6():
    zsock: zmq.Socket = make_zmq_socket(ctx, ipv6_path, socket_type)

    # Verify that the IPV6 option is set
    assert zsock.getsockopt(zmq.IPV6) == 1, "IPV6 option should be enabled for IPv6 addresses"
    assert zsock.getsockopt(
        zmq.IPV6) == 1, "IPV6 option should be enabled for IPv6 addresses"

    # Clean up
    zsock.close()

vllm/config.py (202 changed lines)
@@ -24,6 +24,7 @@ import torch
from pydantic import (ConfigDict, SkipValidation, TypeAdapter, field_validator,
                      model_validator)
from pydantic.dataclasses import dataclass
from safetensors.torch import _TYPES as _SAFETENSORS_TO_TORCH_DTYPE
from torch.distributed import ProcessGroup, ReduceOp
from transformers import PretrainedConfig
from typing_extensions import deprecated, runtime_checkable
@@ -42,15 +43,16 @@ from vllm.transformers_utils.config import (
    ConfigFormat, get_config, get_hf_image_processor_config,
    get_hf_text_config, get_pooling_config,
    get_sentence_transformer_tokenizer_config, is_encoder_decoder,
    try_get_generation_config, uses_mrope)
    try_get_generation_config, try_get_safetensors_metadata, uses_mrope)
from vllm.transformers_utils.s3_utils import S3Model
from vllm.transformers_utils.utils import is_s3, maybe_model_redirect
from vllm.utils import (DEFAULT_MAX_NUM_BATCHED_TOKENS,
                        MULTIMODAL_MODEL_MAX_NUM_BATCHED_TOKENS,
                        POOLING_MODEL_MAX_NUM_BATCHED_TOKENS, GiB_bytes,
                        LayerBlockType, cuda_device_count_stateless,
                        get_cpu_memory, get_open_port, is_torch_equal_or_newer,
                        random_uuid, resolve_obj_by_qualname)
                        LayerBlockType, common_broadcastable_dtype,
                        cuda_device_count_stateless, get_cpu_memory,
                        get_open_port, is_torch_equal_or_newer, random_uuid,
                        resolve_obj_by_qualname)

if TYPE_CHECKING:
    from _typeshed import DataclassInstance
@@ -540,7 +542,24 @@ class ModelConfig:
        self.encoder_config = self._get_encoder_config()
        self.hf_image_processor_config = get_hf_image_processor_config(
            self.model, hf_token=self.hf_token, revision=self.revision)
        self.dtype = _get_and_verify_dtype(self.hf_config, self.dtype)

        supported_tasks, task = self._resolve_task(self.task)
        self.supported_tasks = supported_tasks
        self.task = task
        if self.task in ("draft", "generate"):
            self.truncation_side = "left"
        else:
            self.truncation_side = "right"

        self.pooler_config = self._init_pooler_config()

        self.dtype = _get_and_verify_dtype(
            self.model,
            self.hf_config,
            self.dtype,
            is_pooling_model=self.runner_type == "pooling",
            revision=self.revision,
        )

        # Workaround for Gemma 2 which uses interleaved sliding window
        # attention, but it's not specified in its config. TODO: remove this
@@ -597,16 +616,6 @@ class ModelConfig:
                raise ValueError(
                    "`override_neuron_config` is only supported on Neuron.")

        supported_tasks, task = self._resolve_task(self.task)
        self.supported_tasks = supported_tasks
        self.task = task
        if self.task in ("draft", "generate"):
            self.truncation_side = "left"
        else:
            self.truncation_side = "right"

        self.pooler_config = self._init_pooler_config()

        self._verify_quantization()
        self._verify_cuda_graph()
        self._verify_bnb_config()
@@ -692,7 +701,6 @@ class ModelConfig:
            self.model, self.revision)

    def _init_pooler_config(self) -> Optional["PoolerConfig"]:

        if self.runner_type == "pooling":
            if isinstance(self.override_pooler_config, dict):
                self.override_pooler_config = PoolerConfig(
@@ -3074,13 +3082,37 @@ _STR_DTYPE_TO_TORCH_DTYPE = {
    "bfloat16": torch.bfloat16,
}

_ROCM_NOT_SUPPORTED_DTYPE: list[str] = []  #
# model_type -> reason
_FLOAT16_NOT_SUPPORTED_MODELS = {
    "gemma2": "Numerical instability. Please use bfloat16 or float32 instead.",
    "gemma3": "Numerical instability. Please use bfloat16 or float32 instead.",
    "plamo2": "Numerical instability. Please use bfloat16 or float32 instead.",
    "glm4": "Numerical instability. Please use bfloat16 or float32 instead.",
}


def _get_and_verify_dtype(
def _is_valid_dtype(model_type: str, dtype: torch.dtype):
    if model_type in _FLOAT16_NOT_SUPPORTED_MODELS and dtype == torch.float16:  # noqa: E501, SIM103
        return False

    return True


def _check_valid_dtype(model_type: str, dtype: torch.dtype):
    if model_type in _FLOAT16_NOT_SUPPORTED_MODELS and dtype == torch.float16:
        reason = _FLOAT16_NOT_SUPPORTED_MODELS[model_type]
        raise ValueError(f"The model type {model_type!r} "
                         f"does not support float16. Reason: {reason}")

    return True

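A minimal sketch of the intended failure mode. The helper is a private symbol of vllm.config, so importing it directly is for illustration only; the model type string is one of the keys recorded above:

import pytest
import torch

from vllm.config import _check_valid_dtype

# Explicitly requesting float16 for a model type listed in
# _FLOAT16_NOT_SUPPORTED_MODELS is rejected with the recorded reason.
with pytest.raises(ValueError, match="does not support float16"):
    _check_valid_dtype("gemma2", torch.float16)

# Any other combination passes through.
assert _check_valid_dtype("gemma2", torch.bfloat16)
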
def _find_dtype(
    model_id: str,
    config: PretrainedConfig,
    dtype: Union[str, torch.dtype],
) -> torch.dtype:
    *,
    revision: Optional[str],
):
    # NOTE: getattr(config, "torch_dtype", torch.float32) is not correct
    # because config.torch_dtype can be None.
    config_dtype = getattr(config, "torch_dtype", None)
@@ -3092,75 +3124,111 @@ def _get_and_verify_dtype(
    if config_dtype is None and hasattr(config, "vision_config"):
        config_dtype = getattr(config.vision_config, "torch_dtype", None)

    # Try to read the dtype of the weights if they are in safetensors format
    if config_dtype is None:
        repo_mt = try_get_safetensors_metadata(model_id, revision=revision)

        if repo_mt and (files_mt := repo_mt.files_metadata):
            param_dtypes: set[torch.dtype] = {
                _SAFETENSORS_TO_TORCH_DTYPE[dtype_str]
                for file_mt in files_mt.values()
                for dtype_str in file_mt.parameter_count
                if dtype_str in _SAFETENSORS_TO_TORCH_DTYPE
            }

            if param_dtypes:
                return common_broadcastable_dtype(param_dtypes)

    if config_dtype is None:
        config_dtype = torch.float32

    return config_dtype


def _resolve_auto_dtype(
    model_type: str,
    config_dtype: torch.dtype,
    *,
    is_pooling_model: bool,
):
    from vllm.platforms import current_platform

    supported_dtypes = [
        dtype for dtype in current_platform.supported_dtypes
        if _is_valid_dtype(model_type, dtype)
    ]

    if is_pooling_model and torch.float16 in supported_dtypes:
        preferred_dtype = torch.float16
    else:
        preferred_dtype = supported_dtypes[0]

    # Downcast for float32 models
    if config_dtype == torch.float32:
        config_dtype = preferred_dtype

    if config_dtype in supported_dtypes:
        return config_dtype

    # Ensure device compatibility
    device_name = current_platform.get_device_name()
    device_capability = current_platform.get_device_capability()

    if device_capability is None:
        device_str = f"{device_name!r}"
    else:
        version_str = device_capability.as_version_str()
        device_str = f"{device_name!r} (with compute capability {version_str})"

    logger.warning(
        "Your device %s doesn't support %s. "
        "Falling back to %s for compatibility.",
        device_str,
        config_dtype,
        preferred_dtype,
    )

    return preferred_dtype

|
||||
model_id: str,
|
||||
config: PretrainedConfig,
|
||||
dtype: Union[str, torch.dtype],
|
||||
*,
|
||||
is_pooling_model: bool,
|
||||
revision: Optional[str] = None,
|
||||
) -> torch.dtype:
|
||||
config_dtype = _find_dtype(model_id, config, revision=revision)
|
||||
model_type = config.model_type
|
||||
|
||||
if isinstance(dtype, str):
|
||||
dtype = dtype.lower()
|
||||
if dtype == "auto":
|
||||
# Set default dtype from model config
|
||||
if config_dtype == torch.float32:
|
||||
# Following common practice, we use float16 for float32 models
|
||||
torch_dtype = torch.float16
|
||||
else:
|
||||
torch_dtype = config_dtype
|
||||
|
||||
if config.model_type == "plamo2":
|
||||
logger.warning(
|
||||
"For PLaMo2, we cast models to bfloat16 instead of using "
|
||||
"float16 by default. This is because float16 does not work."
|
||||
)
|
||||
torch_dtype = torch.bfloat16
|
||||
|
||||
# Deal with torch dtype fallback for device compatibility.
|
||||
from vllm.platforms import current_platform
|
||||
if torch_dtype not in current_platform.supported_dtypes:
|
||||
device_name = current_platform.get_device_name()
|
||||
|
||||
if ((capability := current_platform.get_device_capability())
|
||||
is None):
|
||||
compute_str = ""
|
||||
else:
|
||||
version_str = capability.as_version_str()
|
||||
compute_str = f" (with compute capability {version_str})"
|
||||
fallback_dtype = current_platform.supported_dtypes[0]
|
||||
logger.warning(
|
||||
"Your %s device%s doesn't support %s. " \
|
||||
"Falling back to %s for compatibility.",
|
||||
device_name, compute_str, torch_dtype, fallback_dtype
|
||||
)
|
||||
torch_dtype = fallback_dtype
|
||||
|
||||
if current_platform.is_hpu() and torch_dtype == torch.float16:
|
||||
logger.warning(
|
||||
"For HPU, we cast models to bfloat16 instead of "
|
||||
"using float16 by default. Please specify `dtype` if you "
|
||||
"want to use float16.")
|
||||
torch_dtype = torch.bfloat16
|
||||
elif dtype == "float16" and config.model_type == "plamo2":
|
||||
logger.warning(
|
||||
"For PLaMo2, using float16 is unstable and might cause "
|
||||
"unexpected behavior. Please use bfloat16 or float32 instead.")
|
||||
torch_dtype = torch.float16
|
||||
torch_dtype = _resolve_auto_dtype(
|
||||
model_type,
|
||||
config_dtype,
|
||||
is_pooling_model=is_pooling_model,
|
||||
)
|
||||
else:
|
||||
if dtype not in _STR_DTYPE_TO_TORCH_DTYPE:
|
||||
raise ValueError(f"Unknown dtype: {dtype}")
|
||||
raise ValueError(f"Unknown dtype: {dtype!r}")
|
||||
torch_dtype = _STR_DTYPE_TO_TORCH_DTYPE[dtype]
|
||||
elif isinstance(dtype, torch.dtype):
|
||||
torch_dtype = dtype
|
||||
else:
|
||||
raise ValueError(f"Unknown dtype: {dtype}")
|
||||
|
||||
# Verify the dtype.
|
||||
_check_valid_dtype(model_type, torch_dtype)
|
||||
|
||||
if torch_dtype != config_dtype:
|
||||
if torch_dtype == torch.float32:
|
||||
# Upcasting to float32 is allowed.
|
||||
logger.info("Upcasting %s to %s.", config_dtype, torch_dtype)
|
||||
pass
|
||||
elif config_dtype == torch.float32:
|
||||
# Downcasting from float32 to float16 or bfloat16 is allowed.
|
||||
logger.info("Downcasting %s to %s.", config_dtype, torch_dtype)
|
||||
pass
|
||||
else:
|
||||
# Casting between float16 and bfloat16 is allowed with a warning.
|
||||
logger.warning("Casting %s to %s.", config_dtype, torch_dtype)
|
||||
|
||||
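An end-user sketch of the resolved behaviour (the model names and runtime environment here are assumptions, not part of the change):

from vllm import LLM

# "auto" now resolves the dtype from config.torch_dtype, or from the
# safetensors metadata when the config does not record a dtype, and only
# then falls back to the platform's preferred dtype.
llm = LLM(model="facebook/opt-125m", dtype="auto")

# Explicitly requesting float16 for a model type listed in
# _FLOAT16_NOT_SUPPORTED_MODELS (e.g. gemma2) raises a ValueError naming
# the reason, instead of silently producing unstable outputs.
# llm = LLM(model="google/gemma-2-2b-it", dtype="float16")  # ValueError
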
@@ -28,7 +28,7 @@ class CpuPlatform(Platform):
    dispatch_key: str = "CPU"

    @property
    def supported_dtypes(self) -> list:
    def supported_dtypes(self) -> list[torch.dtype]:
        if self.get_cpu_architecture() == CpuArchEnum.POWERPC:
            return [torch.bfloat16, torch.float32]
        elif sys.platform.startswith(

@@ -4,12 +4,12 @@ import enum
import json
import os
import time
from functools import cache
from functools import cache, partial
from pathlib import Path
from typing import Any, Callable, Literal, Optional, Union
from typing import Any, Callable, Literal, Optional, TypeVar, Union

import huggingface_hub
from huggingface_hub import hf_hub_download
from huggingface_hub import get_safetensors_metadata, hf_hub_download
from huggingface_hub import list_repo_files as hf_list_repo_files
from huggingface_hub import try_to_load_from_cache
from huggingface_hub.utils import (EntryNotFoundError, HfHubHTTPError,
@@ -93,10 +93,15 @@ class ConfigFormat(str, enum.Enum):
    MISTRAL = "mistral"


def with_retry(func: Callable[[], Any],
               log_msg: str,
               max_retries: int = 2,
               retry_delay: int = 2):
_R = TypeVar("_R")


def with_retry(
    func: Callable[[], _R],
    log_msg: str,
    max_retries: int = 2,
    retry_delay: int = 2,
) -> _R:
    for attempt in range(max_retries):
        try:
            return func()
@@ -109,6 +114,8 @@ def with_retry(func: Callable[[], Any],
            time.sleep(retry_delay)
            retry_delay *= 2

    raise AssertionError("Should not be reached")


# @cache doesn't cache exceptions
@cache
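A small usage sketch of the retyped helper (the flaky callable is hypothetical, and the import path assumes the helper lives in vllm.transformers_utils.config):

from functools import partial

from vllm.transformers_utils.config import with_retry

def flaky_lookup(name: str) -> int:
    # Placeholder for a network call that may raise transiently.
    return len(name)

# The partial fixes the arguments, so with_retry only sees a zero-argument
# callable; the return type is now carried through via the TypeVar.
result: int = with_retry(partial(flaky_lookup, "demo"),
                         "Error looking up demo")
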
@@ -840,3 +847,22 @@ def get_cross_encoder_activation_function(config: PretrainedConfig):
        return resolve_obj_by_qualname(function_name)()
    else:
        return nn.Sigmoid() if config.num_labels == 1 else nn.Identity()


def try_get_safetensors_metadata(
    model: str,
    *,
    revision: Optional[str] = None,
):
    get_safetensors_metadata_partial = partial(
        get_safetensors_metadata,
        model,
        revision=revision,
        token=os.getenv('HF_TOKEN', None),
    )

    try:
        return with_retry(get_safetensors_metadata_partial,
                          "Error retrieving safetensors")
    except Exception:
        return None
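A sketch of what the metadata lookup yields and how it feeds dtype detection (the model id is an assumption and a network connection is required):

from vllm.transformers_utils.config import try_get_safetensors_metadata

repo_mt = try_get_safetensors_metadata("facebook/opt-125m")
if repo_mt is not None:
    for filename, file_mt in repo_mt.files_metadata.items():
        # parameter_count maps safetensors dtype strings (e.g. "F16", "BF16")
        # to the number of parameters stored in that dtype; _find_dtype maps
        # these strings to torch dtypes and broadcasts them to a single one.
        print(filename, dict(file_mt.parameter_count))
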
@@ -37,8 +37,8 @@ from argparse import (Action, ArgumentDefaultsHelpFormatter, ArgumentParser,
                      _ArgumentGroup)
from asyncio import FIRST_COMPLETED, AbstractEventLoop, Task
from collections import UserDict, defaultdict
from collections.abc import (AsyncGenerator, Awaitable, Generator, Hashable,
                             Iterable, Iterator, KeysView, Mapping)
from collections.abc import (AsyncGenerator, Awaitable, Collection, Generator,
                             Hashable, Iterable, Iterator, KeysView, Mapping)
from concurrent.futures.process import ProcessPoolExecutor
from dataclasses import dataclass, field
from functools import cache, lru_cache, partial, wraps
@@ -979,6 +979,53 @@ def get_dtype_size(dtype: torch.dtype) -> int:
    return torch.tensor([], dtype=dtype).element_size()


# bool = 0, int = 1, float = 2, complex = 3
def _get_precision_level(dtype: torch.dtype) -> int:
    # NOTE: Complex dtypes return `is_floating_point=False`
    return ((dtype != torch.bool) + dtype.is_floating_point +
            dtype.is_complex * 2)


def is_lossless_cast(src_dtype: torch.dtype, tgt_dtype: torch.dtype):
    """
    Test whether it is lossless to cast a tensor from
    `src_dtype` to `tgt_dtype`.
    """
    if src_dtype == tgt_dtype:
        return True

    src_level = _get_precision_level(src_dtype)
    tgt_level = _get_precision_level(tgt_dtype)

    if src_level < tgt_level:
        return True
    if src_level > tgt_level:
        return False

    # Compare integral types
    if not src_dtype.is_floating_point and not src_dtype.is_complex:
        src_info = torch.iinfo(src_dtype)
        tgt_info = torch.iinfo(tgt_dtype)
        return src_info.min >= tgt_info.min and src_info.max <= tgt_info.max

    # Compare floating-point types
    src_info = torch.finfo(src_dtype)
    tgt_info = torch.finfo(tgt_dtype)
    return (src_info.min >= tgt_info.min and src_info.max <= tgt_info.max
            and src_info.resolution >= tgt_info.resolution)


def common_broadcastable_dtype(dtypes: Collection[torch.dtype]):
    """
    Get the common `dtype` where all of the other `dtypes` can be
    cast to it without losing any information.
    """
    return max(
        dtypes,
        key=lambda dtype: sum(is_lossless_cast(dt, dtype) for dt in dtypes),
    )


# `collections` helpers
def is_list_of(
    value: object,