mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-01-01 09:41:48 +08:00
[Misc] Add backup hash algorithm for FIPS constrained environments (#28795)
Signed-off-by: George D. Torres <gdavtor@gmail.com> Signed-off-by: George D. Torres <41129492+geodavic@users.noreply.github.com> Signed-off-by: Russell Bryant <rbryant@redhat.com> Co-authored-by: Russell Bryant <rbryant@redhat.com>
This commit is contained in:
parent
12866af748
commit
56531b79cc
@ -1,7 +1,6 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import hashlib
|
||||
import inspect
|
||||
import os
|
||||
import pickle
|
||||
@ -14,6 +13,7 @@ import vllm.envs as envs
|
||||
from vllm.config import VllmConfig, get_current_vllm_config
|
||||
from vllm.config.utils import hash_factors
|
||||
from vllm.logger import init_logger
|
||||
from vllm.utils.hashing import safe_hash
|
||||
|
||||
try:
|
||||
from torch._dynamo.aot_compile import SerializableCallable
|
||||
@ -160,7 +160,7 @@ def _compute_code_hash_with_content(file_contents: dict[str, str]) -> str:
|
||||
# e.g. exec(). We can't actually check these.
|
||||
continue
|
||||
hash_content.append(content)
|
||||
return hashlib.md5(
|
||||
return safe_hash(
|
||||
"\n".join(hash_content).encode(), usedforsecurity=False
|
||||
).hexdigest()
|
||||
|
||||
|
||||
@ -2,7 +2,6 @@
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import contextlib
|
||||
import copy
|
||||
import hashlib
|
||||
import os
|
||||
from collections.abc import Callable
|
||||
from contextlib import ExitStack
|
||||
@ -16,6 +15,7 @@ import torch.fx as fx
|
||||
import vllm.envs as envs
|
||||
from vllm.compilation.counter import compilation_counter
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.utils.hashing import safe_hash
|
||||
from vllm.utils.torch_utils import is_torch_equal_or_newer
|
||||
|
||||
|
||||
@ -197,9 +197,9 @@ class InductorStandaloneAdaptor(CompilerInterface):
|
||||
|
||||
def compute_hash(self, vllm_config: VllmConfig) -> str:
|
||||
factors = get_inductor_factors()
|
||||
hash_str = hashlib.md5(
|
||||
str(factors).encode(), usedforsecurity=False
|
||||
).hexdigest()[:10]
|
||||
hash_str = safe_hash(str(factors).encode(), usedforsecurity=False).hexdigest()[
|
||||
:10
|
||||
]
|
||||
return hash_str
|
||||
|
||||
def initialize_cache(
|
||||
@ -286,9 +286,9 @@ class InductorAdaptor(CompilerInterface):
|
||||
|
||||
def compute_hash(self, vllm_config: VllmConfig) -> str:
|
||||
factors = get_inductor_factors()
|
||||
hash_str = hashlib.md5(
|
||||
str(factors).encode(), usedforsecurity=False
|
||||
).hexdigest()[:10]
|
||||
hash_str = safe_hash(str(factors).encode(), usedforsecurity=False).hexdigest()[
|
||||
:10
|
||||
]
|
||||
return hash_str
|
||||
|
||||
def initialize_cache(
|
||||
|
||||
@ -1,7 +1,6 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import hashlib
|
||||
from dataclasses import field
|
||||
from typing import Any, Literal
|
||||
|
||||
@ -10,6 +9,7 @@ from pydantic import ConfigDict, SkipValidation
|
||||
from pydantic.dataclasses import dataclass
|
||||
|
||||
from vllm.config.utils import config
|
||||
from vllm.utils.hashing import safe_hash
|
||||
|
||||
Device = Literal["auto", "cuda", "cpu", "tpu", "xpu"]
|
||||
|
||||
@ -45,7 +45,7 @@ class DeviceConfig:
|
||||
# the device/platform information will be summarized
|
||||
# by torch/vllm automatically.
|
||||
factors: list[Any] = []
|
||||
hash_str = hashlib.md5(str(factors).encode(), usedforsecurity=False).hexdigest()
|
||||
hash_str = safe_hash(str(factors).encode(), usedforsecurity=False).hexdigest()
|
||||
return hash_str
|
||||
|
||||
def __post_init__(self):
|
||||
|
||||
@ -1,7 +1,6 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import hashlib
|
||||
import uuid
|
||||
from dataclasses import field
|
||||
from typing import Any, Literal, get_args
|
||||
@ -9,6 +8,7 @@ from typing import Any, Literal, get_args
|
||||
from pydantic.dataclasses import dataclass
|
||||
|
||||
from vllm.config.utils import config
|
||||
from vllm.utils.hashing import safe_hash
|
||||
|
||||
KVProducer = Literal["kv_producer", "kv_both"]
|
||||
KVConsumer = Literal["kv_consumer", "kv_both"]
|
||||
@ -79,7 +79,7 @@ class KVTransferConfig:
|
||||
# no factors to consider.
|
||||
# this config will not affect the computation graph.
|
||||
factors: list[Any] = []
|
||||
hash_str = hashlib.md5(str(factors).encode(), usedforsecurity=False).hexdigest()
|
||||
hash_str = safe_hash(str(factors).encode(), usedforsecurity=False).hexdigest()
|
||||
return hash_str
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
|
||||
@ -1,7 +1,6 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import hashlib
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
from pydantic import Field, field_validator
|
||||
@ -9,6 +8,7 @@ from pydantic.dataclasses import dataclass
|
||||
|
||||
from vllm.config.utils import config
|
||||
from vllm.logger import init_logger
|
||||
from vllm.utils.hashing import safe_hash
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from vllm.model_executor.model_loader import LoadFormats
|
||||
@ -104,7 +104,7 @@ class LoadConfig:
|
||||
# no factors to consider.
|
||||
# this config will not affect the computation graph.
|
||||
factors: list[Any] = []
|
||||
hash_str = hashlib.md5(str(factors).encode(), usedforsecurity=False).hexdigest()
|
||||
hash_str = safe_hash(str(factors).encode(), usedforsecurity=False).hexdigest()
|
||||
return hash_str
|
||||
|
||||
@field_validator("load_format", mode="after")
|
||||
|
||||
@ -1,7 +1,6 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import hashlib
|
||||
from typing import TYPE_CHECKING, Any, Literal
|
||||
|
||||
import torch
|
||||
@ -11,6 +10,7 @@ from typing_extensions import Self
|
||||
|
||||
from vllm.config.utils import config
|
||||
from vllm.logger import init_logger
|
||||
from vllm.utils.hashing import safe_hash
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from vllm.config import ModelConfig
|
||||
@ -74,7 +74,7 @@ class LoRAConfig:
|
||||
factors.append(self.fully_sharded_loras)
|
||||
factors.append(self.lora_dtype)
|
||||
|
||||
hash_str = hashlib.md5(str(factors).encode(), usedforsecurity=False).hexdigest()
|
||||
hash_str = safe_hash(str(factors).encode(), usedforsecurity=False).hexdigest()
|
||||
return hash_str
|
||||
|
||||
@model_validator(mode="after")
|
||||
|
||||
@ -1,7 +1,6 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import hashlib
|
||||
from collections.abc import Mapping
|
||||
from typing import TYPE_CHECKING, Any, Literal, TypeAlias
|
||||
|
||||
@ -9,6 +8,7 @@ from pydantic import ConfigDict, Field, field_validator, model_validator
|
||||
from pydantic.dataclasses import dataclass
|
||||
|
||||
from vllm.config.utils import config
|
||||
from vllm.utils.hashing import safe_hash
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from vllm.attention.backends.registry import AttentionBackendEnum
|
||||
@ -216,7 +216,7 @@ class MultiModalConfig:
|
||||
if self.mm_encoder_attn_backend is not None
|
||||
else None
|
||||
]
|
||||
hash_str = hashlib.md5(str(factors).encode(), usedforsecurity=False).hexdigest()
|
||||
hash_str = safe_hash(str(factors).encode(), usedforsecurity=False).hexdigest()
|
||||
return hash_str
|
||||
|
||||
def get_limit_per_prompt(self, modality: str) -> int:
|
||||
|
||||
@ -1,7 +1,6 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import hashlib
|
||||
from functools import cached_property
|
||||
from typing import Any, Literal, cast
|
||||
|
||||
@ -11,6 +10,7 @@ from pydantic.dataclasses import dataclass
|
||||
|
||||
from vllm import version
|
||||
from vllm.config.utils import config
|
||||
from vllm.utils.hashing import safe_hash
|
||||
|
||||
DetailedTraceModules = Literal["model", "worker", "all"]
|
||||
|
||||
@ -78,7 +78,7 @@ class ObservabilityConfig:
|
||||
# no factors to consider.
|
||||
# this config will not affect the computation graph.
|
||||
factors: list[Any] = []
|
||||
hash_str = hashlib.md5(str(factors).encode(), usedforsecurity=False).hexdigest()
|
||||
hash_str = safe_hash(str(factors).encode(), usedforsecurity=False).hexdigest()
|
||||
return hash_str
|
||||
|
||||
@field_validator("show_hidden_metrics_for_version")
|
||||
|
||||
@ -1,13 +1,13 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import hashlib
|
||||
from typing import Any
|
||||
|
||||
from pydantic.dataclasses import dataclass
|
||||
|
||||
from vllm.config.utils import config
|
||||
from vllm.logger import init_logger
|
||||
from vllm.utils.hashing import safe_hash
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
@ -102,7 +102,7 @@ class PoolerConfig:
|
||||
# no factors to consider.
|
||||
# this config will not affect the computation graph.
|
||||
factors: list[Any] = []
|
||||
hash_str = hashlib.md5(str(factors).encode(), usedforsecurity=False).hexdigest()
|
||||
hash_str = safe_hash(str(factors).encode(), usedforsecurity=False).hexdigest()
|
||||
return hash_str
|
||||
|
||||
|
||||
|
||||
@ -1,7 +1,6 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import hashlib
|
||||
from collections.abc import Callable
|
||||
from dataclasses import InitVar
|
||||
from typing import TYPE_CHECKING, Any, ClassVar, Literal, cast
|
||||
@ -12,6 +11,7 @@ from typing_extensions import Self, deprecated
|
||||
|
||||
from vllm.config.utils import config
|
||||
from vllm.logger import init_logger
|
||||
from vllm.utils.hashing import safe_hash
|
||||
from vllm.utils.import_utils import resolve_obj_by_qualname
|
||||
|
||||
if TYPE_CHECKING:
|
||||
@ -178,7 +178,7 @@ class SchedulerConfig:
|
||||
# no factors to consider.
|
||||
# this config will not affect the computation graph.
|
||||
factors: list[Any] = []
|
||||
hash_str = hashlib.md5(str(factors).encode(), usedforsecurity=False).hexdigest()
|
||||
hash_str = safe_hash(str(factors).encode(), usedforsecurity=False).hexdigest()
|
||||
return hash_str
|
||||
|
||||
@field_validator("scheduler_cls", "async_scheduling", mode="wrap")
|
||||
|
||||
@ -2,7 +2,6 @@
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import ast
|
||||
import hashlib
|
||||
from typing import TYPE_CHECKING, Any, Literal, get_args
|
||||
|
||||
from pydantic import Field, SkipValidation, model_validator
|
||||
@ -13,6 +12,7 @@ from vllm.config.model import ModelConfig
|
||||
from vllm.config.parallel import ParallelConfig
|
||||
from vllm.config.utils import config
|
||||
from vllm.logger import init_logger
|
||||
from vllm.utils.hashing import safe_hash
|
||||
from vllm.utils.import_utils import LazyLoader, has_arctic_inference
|
||||
|
||||
if TYPE_CHECKING:
|
||||
@ -162,7 +162,7 @@ class SpeculativeConfig:
|
||||
# Eagle3 affects the computation graph because it returns intermediate
|
||||
# hidden states in addition to the final hidden state.
|
||||
factors.append(self.method == "eagle3")
|
||||
hash_str = hashlib.md5(str(factors).encode(), usedforsecurity=False).hexdigest()
|
||||
hash_str = safe_hash(str(factors).encode(), usedforsecurity=False).hexdigest()
|
||||
return hash_str
|
||||
|
||||
@staticmethod
|
||||
|
||||
@ -1,7 +1,6 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import hashlib
|
||||
from typing import Any, Literal
|
||||
|
||||
from pydantic import model_validator
|
||||
@ -9,6 +8,7 @@ from pydantic.dataclasses import dataclass
|
||||
from typing_extensions import Self
|
||||
|
||||
from vllm.config.utils import config
|
||||
from vllm.utils.hashing import safe_hash
|
||||
|
||||
StructuredOutputsBackend = Literal[
|
||||
"auto", "xgrammar", "guidance", "outlines", "lm-format-enforcer"
|
||||
@ -58,7 +58,7 @@ class StructuredOutputsConfig:
|
||||
# no factors to consider.
|
||||
# this config will not affect the computation graph.
|
||||
factors: list[Any] = []
|
||||
hash_str = hashlib.md5(str(factors).encode(), usedforsecurity=False).hexdigest()
|
||||
hash_str = safe_hash(str(factors).encode(), usedforsecurity=False).hexdigest()
|
||||
return hash_str
|
||||
|
||||
@model_validator(mode="after")
|
||||
|
||||
@ -3,7 +3,6 @@
|
||||
|
||||
import copy
|
||||
import getpass
|
||||
import hashlib
|
||||
import json
|
||||
import os
|
||||
import tempfile
|
||||
@ -25,6 +24,7 @@ from vllm.config.speculative import EagleModelTypes
|
||||
from vllm.logger import enable_trace_function_call, init_logger
|
||||
from vllm.transformers_utils.runai_utils import is_runai_obj_uri
|
||||
from vllm.utils import random_uuid
|
||||
from vllm.utils.hashing import safe_hash
|
||||
|
||||
from .cache import CacheConfig
|
||||
from .compilation import CompilationConfig, CompilationMode, CUDAGraphMode
|
||||
@ -193,7 +193,7 @@ class VllmConfig:
|
||||
vllm_factors.append("None")
|
||||
if self.additional_config:
|
||||
if isinstance(additional_config := self.additional_config, dict):
|
||||
additional_config_hash = hashlib.md5(
|
||||
additional_config_hash = safe_hash(
|
||||
json.dumps(additional_config, sort_keys=True).encode(),
|
||||
usedforsecurity=False,
|
||||
).hexdigest()
|
||||
@ -204,9 +204,9 @@ class VllmConfig:
|
||||
vllm_factors.append("None")
|
||||
factors.append(vllm_factors)
|
||||
|
||||
hash_str = hashlib.md5(
|
||||
str(factors).encode(), usedforsecurity=False
|
||||
).hexdigest()[:10]
|
||||
hash_str = safe_hash(str(factors).encode(), usedforsecurity=False).hexdigest()[
|
||||
:10
|
||||
]
|
||||
return hash_str
|
||||
|
||||
def pad_for_cudagraph(self, batch_size: int) -> int:
|
||||
|
||||
@ -1,6 +1,5 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import hashlib
|
||||
import os
|
||||
from dataclasses import dataclass, field
|
||||
from typing import TYPE_CHECKING, Any, Optional
|
||||
@ -15,6 +14,7 @@ from vllm.distributed.kv_transfer.kv_connector.v1.base import (
|
||||
KVConnectorRole,
|
||||
)
|
||||
from vllm.logger import init_logger
|
||||
from vllm.utils.hashing import safe_hash
|
||||
from vllm.v1.attention.backends.mla.common import MLACommonMetadata
|
||||
from vllm.v1.core.sched.output import SchedulerOutput
|
||||
|
||||
@ -423,7 +423,7 @@ class SharedStorageConnector(KVConnectorBase_V1):
|
||||
if mm_hashes:
|
||||
mm_str = "-".join(mm_hashes)
|
||||
token_bytes += mm_str.encode("utf-8")
|
||||
input_ids_hash = hashlib.md5(token_bytes, usedforsecurity=False).hexdigest()
|
||||
input_ids_hash = safe_hash(token_bytes, usedforsecurity=False).hexdigest()
|
||||
|
||||
foldername = os.path.join(self._storage_path, input_ids_hash)
|
||||
if create_folder:
|
||||
|
||||
@ -5,7 +5,6 @@ Whenever you add an architecture to this page, please also update
|
||||
`tests/models/registry.py` with example HuggingFace models for it.
|
||||
"""
|
||||
|
||||
import hashlib
|
||||
import importlib
|
||||
import json
|
||||
import os
|
||||
@ -32,6 +31,7 @@ from vllm.config import (
|
||||
from vllm.logger import init_logger
|
||||
from vllm.logging_utils import logtime
|
||||
from vllm.transformers_utils.dynamic_module import try_get_class_from_dynamic_module
|
||||
from vllm.utils.hashing import safe_hash
|
||||
|
||||
from .interfaces import (
|
||||
has_inner_state,
|
||||
@ -654,7 +654,7 @@ class _LazyRegisteredModel(_BaseRegisteredModel):
|
||||
|
||||
if model_path.exists():
|
||||
with open(model_path, "rb") as f:
|
||||
module_hash = hashlib.md5(f.read(), usedforsecurity=False).hexdigest()
|
||||
module_hash = safe_hash(f.read(), usedforsecurity=False).hexdigest()
|
||||
|
||||
mi = self._load_modelinfo_from_cache(module_hash)
|
||||
if mi is not None:
|
||||
|
||||
@ -5,6 +5,7 @@ from __future__ import annotations
|
||||
|
||||
import hashlib
|
||||
import pickle
|
||||
from _hashlib import HASH, UnsupportedDigestmodError
|
||||
from collections.abc import Callable
|
||||
from typing import Any
|
||||
|
||||
@ -61,3 +62,20 @@ def get_hash_fn_by_name(hash_fn_name: str) -> Callable[[Any], bytes]:
|
||||
return sha256_cbor
|
||||
|
||||
raise ValueError(f"Unsupported hash function: {hash_fn_name}")
|
||||
|
||||
|
||||
def safe_hash(data: bytes, usedforsecurity: bool = True) -> HASH:
|
||||
"""Hash for configs, defaulting to md5 but falling back to sha256
|
||||
in FIPS constrained environments.
|
||||
|
||||
Args:
|
||||
data: bytes
|
||||
usedforsecurity: Whether the hash is used for security purposes
|
||||
|
||||
Returns:
|
||||
Hash object
|
||||
"""
|
||||
try:
|
||||
return hashlib.md5(data, usedforsecurity=usedforsecurity)
|
||||
except (UnsupportedDigestmodError, ValueError):
|
||||
return hashlib.sha256(data)
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user