[Misc] Add backup hash algorithm for FIPS constrained environments (#28795)

Signed-off-by: George D. Torres <gdavtor@gmail.com>
Signed-off-by: George D. Torres <41129492+geodavic@users.noreply.github.com>
Signed-off-by: Russell Bryant <rbryant@redhat.com>
Co-authored-by: Russell Bryant <rbryant@redhat.com>
Authored by George D. Torres on 2025-11-25 18:50:22 -06:00; committed by GitHub
parent 12866af748
commit 56531b79cc
16 changed files with 56 additions and 38 deletions
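In short, the commit swaps direct hashlib.md5 calls in the config and compiler hashing paths for a new safe_hash helper that falls back to SHA-256 when MD5 is unavailable (as on FIPS-constrained systems). A minimal before/after sketch of the substitution pattern; the factors value is illustrative and not taken verbatim from any one file below:

import hashlib
from vllm.utils.hashing import safe_hash

factors = ["example-factor"]  # illustrative; each config's compute_hash() builds its own list

# Before: direct MD5, which some FIPS-constrained builds refuse to provide
hash_str = hashlib.md5(str(factors).encode(), usedforsecurity=False).hexdigest()

# After: safe_hash() tries MD5 first and silently falls back to SHA-256
hash_str = safe_hash(str(factors).encode(), usedforsecurity=False).hexdigest()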


@ -1,7 +1,6 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import hashlib
import inspect
import os
import pickle
@ -14,6 +13,7 @@ import vllm.envs as envs
from vllm.config import VllmConfig, get_current_vllm_config
from vllm.config.utils import hash_factors
from vllm.logger import init_logger
from vllm.utils.hashing import safe_hash
try:
from torch._dynamo.aot_compile import SerializableCallable
@ -160,7 +160,7 @@ def _compute_code_hash_with_content(file_contents: dict[str, str]) -> str:
# e.g. exec(). We can't actually check these.
continue
hash_content.append(content)
return hashlib.md5(
return safe_hash(
"\n".join(hash_content).encode(), usedforsecurity=False
).hexdigest()


@ -2,7 +2,6 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import contextlib
import copy
import hashlib
import os
from collections.abc import Callable
from contextlib import ExitStack
@ -16,6 +15,7 @@ import torch.fx as fx
import vllm.envs as envs
from vllm.compilation.counter import compilation_counter
from vllm.config import VllmConfig
from vllm.utils.hashing import safe_hash
from vllm.utils.torch_utils import is_torch_equal_or_newer
@ -197,9 +197,9 @@ class InductorStandaloneAdaptor(CompilerInterface):
def compute_hash(self, vllm_config: VllmConfig) -> str:
factors = get_inductor_factors()
hash_str = hashlib.md5(
str(factors).encode(), usedforsecurity=False
).hexdigest()[:10]
hash_str = safe_hash(str(factors).encode(), usedforsecurity=False).hexdigest()[
:10
]
return hash_str
def initialize_cache(
@ -286,9 +286,9 @@ class InductorAdaptor(CompilerInterface):
def compute_hash(self, vllm_config: VllmConfig) -> str:
factors = get_inductor_factors()
hash_str = hashlib.md5(
str(factors).encode(), usedforsecurity=False
).hexdigest()[:10]
hash_str = safe_hash(str(factors).encode(), usedforsecurity=False).hexdigest()[
:10
]
return hash_str
def initialize_cache(


@ -1,7 +1,6 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import hashlib
from dataclasses import field
from typing import Any, Literal
@ -10,6 +9,7 @@ from pydantic import ConfigDict, SkipValidation
from pydantic.dataclasses import dataclass
from vllm.config.utils import config
from vllm.utils.hashing import safe_hash
Device = Literal["auto", "cuda", "cpu", "tpu", "xpu"]
@ -45,7 +45,7 @@ class DeviceConfig:
# the device/platform information will be summarized
# by torch/vllm automatically.
factors: list[Any] = []
hash_str = hashlib.md5(str(factors).encode(), usedforsecurity=False).hexdigest()
hash_str = safe_hash(str(factors).encode(), usedforsecurity=False).hexdigest()
return hash_str
def __post_init__(self):


@ -1,7 +1,6 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import hashlib
import uuid
from dataclasses import field
from typing import Any, Literal, get_args
@ -9,6 +8,7 @@ from typing import Any, Literal, get_args
from pydantic.dataclasses import dataclass
from vllm.config.utils import config
from vllm.utils.hashing import safe_hash
KVProducer = Literal["kv_producer", "kv_both"]
KVConsumer = Literal["kv_consumer", "kv_both"]
@ -79,7 +79,7 @@ class KVTransferConfig:
# no factors to consider.
# this config will not affect the computation graph.
factors: list[Any] = []
hash_str = hashlib.md5(str(factors).encode(), usedforsecurity=False).hexdigest()
hash_str = safe_hash(str(factors).encode(), usedforsecurity=False).hexdigest()
return hash_str
def __post_init__(self) -> None:


@ -1,7 +1,6 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import hashlib
from typing import TYPE_CHECKING, Any
from pydantic import Field, field_validator
@ -9,6 +8,7 @@ from pydantic.dataclasses import dataclass
from vllm.config.utils import config
from vllm.logger import init_logger
from vllm.utils.hashing import safe_hash
if TYPE_CHECKING:
from vllm.model_executor.model_loader import LoadFormats
@ -104,7 +104,7 @@ class LoadConfig:
# no factors to consider.
# this config will not affect the computation graph.
factors: list[Any] = []
hash_str = hashlib.md5(str(factors).encode(), usedforsecurity=False).hexdigest()
hash_str = safe_hash(str(factors).encode(), usedforsecurity=False).hexdigest()
return hash_str
@field_validator("load_format", mode="after")


@ -1,7 +1,6 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import hashlib
from typing import TYPE_CHECKING, Any, Literal
import torch
@ -11,6 +10,7 @@ from typing_extensions import Self
from vllm.config.utils import config
from vllm.logger import init_logger
from vllm.utils.hashing import safe_hash
if TYPE_CHECKING:
from vllm.config import ModelConfig
@ -74,7 +74,7 @@ class LoRAConfig:
factors.append(self.fully_sharded_loras)
factors.append(self.lora_dtype)
hash_str = hashlib.md5(str(factors).encode(), usedforsecurity=False).hexdigest()
hash_str = safe_hash(str(factors).encode(), usedforsecurity=False).hexdigest()
return hash_str
@model_validator(mode="after")


@ -1,7 +1,6 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import hashlib
from collections.abc import Mapping
from typing import TYPE_CHECKING, Any, Literal, TypeAlias
@ -9,6 +8,7 @@ from pydantic import ConfigDict, Field, field_validator, model_validator
from pydantic.dataclasses import dataclass
from vllm.config.utils import config
from vllm.utils.hashing import safe_hash
if TYPE_CHECKING:
from vllm.attention.backends.registry import AttentionBackendEnum
@ -216,7 +216,7 @@ class MultiModalConfig:
if self.mm_encoder_attn_backend is not None
else None
]
hash_str = hashlib.md5(str(factors).encode(), usedforsecurity=False).hexdigest()
hash_str = safe_hash(str(factors).encode(), usedforsecurity=False).hexdigest()
return hash_str
def get_limit_per_prompt(self, modality: str) -> int:


@ -1,7 +1,6 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import hashlib
from functools import cached_property
from typing import Any, Literal, cast
@ -11,6 +10,7 @@ from pydantic.dataclasses import dataclass
from vllm import version
from vllm.config.utils import config
from vllm.utils.hashing import safe_hash
DetailedTraceModules = Literal["model", "worker", "all"]
@ -78,7 +78,7 @@ class ObservabilityConfig:
# no factors to consider.
# this config will not affect the computation graph.
factors: list[Any] = []
hash_str = hashlib.md5(str(factors).encode(), usedforsecurity=False).hexdigest()
hash_str = safe_hash(str(factors).encode(), usedforsecurity=False).hexdigest()
return hash_str
@field_validator("show_hidden_metrics_for_version")


@ -1,13 +1,13 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import hashlib
from typing import Any
from pydantic.dataclasses import dataclass
from vllm.config.utils import config
from vllm.logger import init_logger
from vllm.utils.hashing import safe_hash
logger = init_logger(__name__)
@ -102,7 +102,7 @@ class PoolerConfig:
# no factors to consider.
# this config will not affect the computation graph.
factors: list[Any] = []
hash_str = hashlib.md5(str(factors).encode(), usedforsecurity=False).hexdigest()
hash_str = safe_hash(str(factors).encode(), usedforsecurity=False).hexdigest()
return hash_str


@ -1,7 +1,6 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import hashlib
from collections.abc import Callable
from dataclasses import InitVar
from typing import TYPE_CHECKING, Any, ClassVar, Literal, cast
@ -12,6 +11,7 @@ from typing_extensions import Self, deprecated
from vllm.config.utils import config
from vllm.logger import init_logger
from vllm.utils.hashing import safe_hash
from vllm.utils.import_utils import resolve_obj_by_qualname
if TYPE_CHECKING:
@ -178,7 +178,7 @@ class SchedulerConfig:
# no factors to consider.
# this config will not affect the computation graph.
factors: list[Any] = []
hash_str = hashlib.md5(str(factors).encode(), usedforsecurity=False).hexdigest()
hash_str = safe_hash(str(factors).encode(), usedforsecurity=False).hexdigest()
return hash_str
@field_validator("scheduler_cls", "async_scheduling", mode="wrap")


@ -2,7 +2,6 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import ast
import hashlib
from typing import TYPE_CHECKING, Any, Literal, get_args
from pydantic import Field, SkipValidation, model_validator
@ -13,6 +12,7 @@ from vllm.config.model import ModelConfig
from vllm.config.parallel import ParallelConfig
from vllm.config.utils import config
from vllm.logger import init_logger
from vllm.utils.hashing import safe_hash
from vllm.utils.import_utils import LazyLoader, has_arctic_inference
if TYPE_CHECKING:
@ -162,7 +162,7 @@ class SpeculativeConfig:
# Eagle3 affects the computation graph because it returns intermediate
# hidden states in addition to the final hidden state.
factors.append(self.method == "eagle3")
hash_str = hashlib.md5(str(factors).encode(), usedforsecurity=False).hexdigest()
hash_str = safe_hash(str(factors).encode(), usedforsecurity=False).hexdigest()
return hash_str
@staticmethod


@ -1,7 +1,6 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import hashlib
from typing import Any, Literal
from pydantic import model_validator
@ -9,6 +8,7 @@ from pydantic.dataclasses import dataclass
from typing_extensions import Self
from vllm.config.utils import config
from vllm.utils.hashing import safe_hash
StructuredOutputsBackend = Literal[
"auto", "xgrammar", "guidance", "outlines", "lm-format-enforcer"
@ -58,7 +58,7 @@ class StructuredOutputsConfig:
# no factors to consider.
# this config will not affect the computation graph.
factors: list[Any] = []
hash_str = hashlib.md5(str(factors).encode(), usedforsecurity=False).hexdigest()
hash_str = safe_hash(str(factors).encode(), usedforsecurity=False).hexdigest()
return hash_str
@model_validator(mode="after")


@ -3,7 +3,6 @@
import copy
import getpass
import hashlib
import json
import os
import tempfile
@ -25,6 +24,7 @@ from vllm.config.speculative import EagleModelTypes
from vllm.logger import enable_trace_function_call, init_logger
from vllm.transformers_utils.runai_utils import is_runai_obj_uri
from vllm.utils import random_uuid
from vllm.utils.hashing import safe_hash
from .cache import CacheConfig
from .compilation import CompilationConfig, CompilationMode, CUDAGraphMode
@ -193,7 +193,7 @@ class VllmConfig:
vllm_factors.append("None")
if self.additional_config:
if isinstance(additional_config := self.additional_config, dict):
additional_config_hash = hashlib.md5(
additional_config_hash = safe_hash(
json.dumps(additional_config, sort_keys=True).encode(),
usedforsecurity=False,
).hexdigest()
@ -204,9 +204,9 @@ class VllmConfig:
vllm_factors.append("None")
factors.append(vllm_factors)
hash_str = hashlib.md5(
str(factors).encode(), usedforsecurity=False
).hexdigest()[:10]
hash_str = safe_hash(str(factors).encode(), usedforsecurity=False).hexdigest()[
:10
]
return hash_str
def pad_for_cudagraph(self, batch_size: int) -> int:


@ -1,6 +1,5 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import hashlib
import os
from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Any, Optional
@ -15,6 +14,7 @@ from vllm.distributed.kv_transfer.kv_connector.v1.base import (
KVConnectorRole,
)
from vllm.logger import init_logger
from vllm.utils.hashing import safe_hash
from vllm.v1.attention.backends.mla.common import MLACommonMetadata
from vllm.v1.core.sched.output import SchedulerOutput
@ -423,7 +423,7 @@ class SharedStorageConnector(KVConnectorBase_V1):
if mm_hashes:
mm_str = "-".join(mm_hashes)
token_bytes += mm_str.encode("utf-8")
input_ids_hash = hashlib.md5(token_bytes, usedforsecurity=False).hexdigest()
input_ids_hash = safe_hash(token_bytes, usedforsecurity=False).hexdigest()
foldername = os.path.join(self._storage_path, input_ids_hash)
if create_folder:


@ -5,7 +5,6 @@ Whenever you add an architecture to this page, please also update
`tests/models/registry.py` with example HuggingFace models for it.
"""
import hashlib
import importlib
import json
import os
@ -32,6 +31,7 @@ from vllm.config import (
from vllm.logger import init_logger
from vllm.logging_utils import logtime
from vllm.transformers_utils.dynamic_module import try_get_class_from_dynamic_module
from vllm.utils.hashing import safe_hash
from .interfaces import (
has_inner_state,
@ -654,7 +654,7 @@ class _LazyRegisteredModel(_BaseRegisteredModel):
if model_path.exists():
with open(model_path, "rb") as f:
module_hash = hashlib.md5(f.read(), usedforsecurity=False).hexdigest()
module_hash = safe_hash(f.read(), usedforsecurity=False).hexdigest()
mi = self._load_modelinfo_from_cache(module_hash)
if mi is not None:


@ -5,6 +5,7 @@ from __future__ import annotations
import hashlib
import pickle
from _hashlib import HASH, UnsupportedDigestmodError
from collections.abc import Callable
from typing import Any
@ -61,3 +62,20 @@ def get_hash_fn_by_name(hash_fn_name: str) -> Callable[[Any], bytes]:
return sha256_cbor
raise ValueError(f"Unsupported hash function: {hash_fn_name}")
def safe_hash(data: bytes, usedforsecurity: bool = True) -> HASH:
"""Hash for configs, defaulting to md5 but falling back to sha256
in FIPS constrained environments.
Args:
data: bytes
usedforsecurity: Whether the hash is used for security purposes
Returns:
Hash object
"""
try:
return hashlib.md5(data, usedforsecurity=usedforsecurity)
except (UnsupportedDigestmodError, ValueError):
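# MD5 may be unavailable under FIPS enforcement; SHA-256 is always available.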
return hashlib.sha256(data)
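
For reference, a short usage sketch mirroring the call sites above (the factor value is illustrative):

from vllm.utils.hashing import safe_hash

factors = ["example-factor"]
digest = safe_hash(str(factors).encode(), usedforsecurity=False).hexdigest()
# On most systems this is a 32-character MD5 hex digest; under FIPS enforcement
# the fallback yields a 64-character SHA-256 digest. Call sites that truncate,
# e.g. hexdigest()[:10] in the Inductor adaptors and VllmConfig.compute_hash,
# still get a 10-character key either way.
print(digest[:10])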