diff --git a/.github/workflows/cleanup_pr_body.yml b/.github/workflows/cleanup_pr_body.yml index c3e132a536a42..861290ea43c87 100644 --- a/.github/workflows/cleanup_pr_body.yml +++ b/.github/workflows/cleanup_pr_body.yml @@ -13,7 +13,7 @@ jobs: steps: - name: Checkout repository - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 + uses: actions/checkout@1af3b93b6815bc44a9784bd300feb67ff0d1eeb3 # v6.0.0 - name: Set up Python uses: actions/setup-python@e797f83bcb11b83ae66e0230d6156d7c80228e7c # v6.0.0 diff --git a/.github/workflows/macos-smoke-test.yml b/.github/workflows/macos-smoke-test.yml index a183033c9adde..3a12c4b3a8300 100644 --- a/.github/workflows/macos-smoke-test.yml +++ b/.github/workflows/macos-smoke-test.yml @@ -12,7 +12,7 @@ jobs: timeout-minutes: 30 steps: - - uses: actions/checkout@v4 + - uses: actions/checkout@v6 - uses: astral-sh/setup-uv@v7 with: diff --git a/.github/workflows/pre-commit.yml b/.github/workflows/pre-commit.yml index e21d13b8161f3..d5e70f30ef638 100644 --- a/.github/workflows/pre-commit.yml +++ b/.github/workflows/pre-commit.yml @@ -16,7 +16,7 @@ jobs: pre-commit: runs-on: ubuntu-latest steps: - - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 + - uses: actions/checkout@1af3b93b6815bc44a9784bd300feb67ff0d1eeb3 # v6.0.0 - uses: actions/setup-python@e797f83bcb11b83ae66e0230d6156d7c80228e7c # v6.0.0 with: python-version: "3.12" diff --git a/csrc/attention/merge_attn_states.cu b/csrc/attention/merge_attn_states.cu index 229d9862fb670..27d1e990c611e 100644 --- a/csrc/attention/merge_attn_states.cu +++ b/csrc/attention/merge_attn_states.cu @@ -16,7 +16,8 @@ __global__ void merge_attn_states_kernel( scalar_t* output, float* output_lse, const scalar_t* prefix_output, const float* prefix_lse, const scalar_t* suffix_output, const float* suffix_lse, const uint num_tokens, const uint num_heads, - const uint head_size) { + const uint head_size, const uint prefix_head_stride, + const uint output_head_stride) { using pack_128b_t = uint4; const uint pack_size = 16 / sizeof(scalar_t); const uint threads_per_head = head_size / pack_size; @@ -34,11 +35,13 @@ __global__ void merge_attn_states_kernel( const uint head_idx = token_head_idx % num_heads; const uint pack_offset = pack_idx * pack_size; // (0~15)*8, etc. - const uint head_offset = - token_idx * num_heads * head_size + head_idx * head_size; - const scalar_t* prefix_head_ptr = prefix_output + head_offset; - const scalar_t* suffix_head_ptr = suffix_output + head_offset; - scalar_t* output_head_ptr = output + head_offset; + const uint src_head_offset = token_idx * num_heads * prefix_head_stride + + head_idx * prefix_head_stride; + const uint dst_head_offset = token_idx * num_heads * output_head_stride + + head_idx * output_head_stride; + const scalar_t* prefix_head_ptr = prefix_output + src_head_offset; + const scalar_t* suffix_head_ptr = suffix_output + src_head_offset; + scalar_t* output_head_ptr = output + dst_head_offset; float p_lse = prefix_lse[head_idx * num_tokens + token_idx]; float s_lse = suffix_lse[head_idx * num_tokens + token_idx]; @@ -140,7 +143,7 @@ __global__ void merge_attn_states_kernel( reinterpret_cast(prefix_lse.data_ptr()), \ reinterpret_cast(suffix_output.data_ptr()), \ reinterpret_cast(suffix_lse.data_ptr()), num_tokens, \ - num_heads, head_size); \ + num_heads, head_size, prefix_head_stride, output_head_stride); \ } /*@brief Merges the attention states from prefix and suffix @@ -166,17 +169,11 @@ void merge_attn_states_launcher(torch::Tensor& output, const uint num_tokens = output.size(0); const uint num_heads = output.size(1); const uint head_size = output.size(2); + const uint prefix_head_stride = prefix_output.stride(1); + const uint output_head_stride = output.stride(1); const uint pack_size = 16 / sizeof(scalar_t); TORCH_CHECK(head_size % pack_size == 0, "headsize must be multiple of pack_size:", pack_size); - TORCH_CHECK(output.stride(-2) == head_size && output.stride(-1) == 1, - "output heads must be contiguous in memory"); - TORCH_CHECK( - prefix_output.stride(-2) == head_size && prefix_output.stride(-1) == 1, - "prefix_output heads must be contiguous in memory"); - TORCH_CHECK( - suffix_output.stride(-2) == head_size && suffix_output.stride(-1) == 1, - "suffix_output heads must be contiguous in memory"); float* output_lse_ptr = nullptr; if (output_lse.has_value()) { output_lse_ptr = output_lse.value().data_ptr(); diff --git a/csrc/ops.h b/csrc/ops.h index f8bdc61aaa8ec..4bb7857b15032 100644 --- a/csrc/ops.h +++ b/csrc/ops.h @@ -52,14 +52,13 @@ void paged_attention_v2( const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size, const int64_t blocksparse_head_sliding_step); -#ifndef USE_ROCM void merge_attn_states(torch::Tensor& output, std::optional output_lse, const torch::Tensor& prefix_output, const torch::Tensor& prefix_lse, const torch::Tensor& suffix_output, const torch::Tensor& suffix_lse); - +#ifndef USE_ROCM void convert_vertical_slash_indexes( torch::Tensor& block_count, // [BATCH, N_HEADS, NUM_ROWS] torch::Tensor& block_offset, // [BATCH, N_HEADS, NUM_ROWS, NNZ_S] diff --git a/csrc/torch_bindings.cpp b/csrc/torch_bindings.cpp index 14913bef13125..e9c96bb8b56cf 100644 --- a/csrc/torch_bindings.cpp +++ b/csrc/torch_bindings.cpp @@ -63,7 +63,6 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { " int blocksparse_head_sliding_step) -> ()"); ops.impl("paged_attention_v2", torch::kCUDA, &paged_attention_v2); -#ifndef USE_ROCM // Merge attn states // Implements section 2.2 of https://www.arxiv.org/pdf/2501.01005 // can be used to combine partial attention results (in the split-KV case) @@ -76,7 +75,7 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { " Tensor suffix_output," " Tensor suffix_lse) -> ()"); ops.impl("merge_attn_states", torch::kCUDA, &merge_attn_states); - +#ifndef USE_ROCM ops.def( "convert_vertical_slash_indexes(" " Tensor! block_count, Tensor! block_offset, " diff --git a/vllm/attention/ops/triton_merge_attn_states.py b/vllm/attention/ops/triton_merge_attn_states.py index 3c87a24afd9c7..74e4d778ded87 100644 --- a/vllm/attention/ops/triton_merge_attn_states.py +++ b/vllm/attention/ops/triton_merge_attn_states.py @@ -20,7 +20,11 @@ def merge_attn_states( num_query_heads = output.shape[1] head_size = output.shape[2] padded_head_size = triton.next_power_of_2(head_size) - + # We assume the output stride on num_head is not always as same as the + # `suffix_output` and `prefix_output`, as them might be padded by the attention + # backend. + prefix_head_stride = prefix_output.stride(1) + output_head_stride = output.stride(1) # TODO(woosuk): Use CUDA kernel instead of Triton to minimize CPU overhead. merge_attn_states_kernel[(num_tokens, num_query_heads)]( output, @@ -29,6 +33,8 @@ def merge_attn_states( prefix_lse, suffix_output, suffix_lse, + prefix_head_stride, + output_head_stride, head_size, padded_head_size, output_lse is not None, @@ -43,6 +49,8 @@ def merge_attn_states_kernel( prefix_lse, # [NUM_HEADS, NUM_TOKENS] suffix_output, # [NUM_TOKENS, NUM_HEADS, HEAD_SIZE] suffix_lse, # [NUM_HEADS, NUM_TOKENS] + prefix_head_stride, + output_head_stride, HEAD_SIZE: tl.constexpr, PADDED_HEAD_SIZE: tl.constexpr, OUTPUT_LSE: tl.constexpr, @@ -79,15 +87,15 @@ def merge_attn_states_kernel( head_mask = head_arange < HEAD_SIZE p_out = tl.load( prefix_output - + token_idx * num_heads * HEAD_SIZE - + head_idx * HEAD_SIZE + + token_idx * num_heads * prefix_head_stride + + head_idx * prefix_head_stride + head_arange, mask=head_mask, ) s_out = tl.load( suffix_output - + token_idx * num_heads * HEAD_SIZE - + head_idx * HEAD_SIZE + + token_idx * num_heads * prefix_head_stride + + head_idx * prefix_head_stride + head_arange, mask=head_mask, ) @@ -99,7 +107,10 @@ def merge_attn_states_kernel( s_scale = s_se / out_se out = p_out * p_scale + s_out * s_scale tl.store( - output + token_idx * num_heads * HEAD_SIZE + head_idx * HEAD_SIZE + head_arange, + output + + token_idx * num_heads * output_head_stride + + head_idx * output_head_stride + + head_arange, out, mask=head_mask, ) diff --git a/vllm/compilation/caching.py b/vllm/compilation/caching.py index 6297d9f995aa4..ce482572b401b 100644 --- a/vllm/compilation/caching.py +++ b/vllm/compilation/caching.py @@ -1,7 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -import hashlib import inspect import os import pickle @@ -14,6 +13,7 @@ import vllm.envs as envs from vllm.config import VllmConfig, get_current_vllm_config from vllm.config.utils import hash_factors from vllm.logger import init_logger +from vllm.utils.hashing import safe_hash try: from torch._dynamo.aot_compile import SerializableCallable @@ -160,7 +160,7 @@ def _compute_code_hash_with_content(file_contents: dict[str, str]) -> str: # e.g. exec(). We can't actually check these. continue hash_content.append(content) - return hashlib.md5( + return safe_hash( "\n".join(hash_content).encode(), usedforsecurity=False ).hexdigest() diff --git a/vllm/compilation/compiler_interface.py b/vllm/compilation/compiler_interface.py index 11cf0f85c1787..7deaba1a99fad 100644 --- a/vllm/compilation/compiler_interface.py +++ b/vllm/compilation/compiler_interface.py @@ -2,7 +2,6 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import contextlib import copy -import hashlib import os from collections.abc import Callable from contextlib import ExitStack @@ -16,6 +15,7 @@ import torch.fx as fx import vllm.envs as envs from vllm.compilation.counter import compilation_counter from vllm.config import VllmConfig +from vllm.utils.hashing import safe_hash from vllm.utils.torch_utils import is_torch_equal_or_newer @@ -197,9 +197,9 @@ class InductorStandaloneAdaptor(CompilerInterface): def compute_hash(self, vllm_config: VllmConfig) -> str: factors = get_inductor_factors() - hash_str = hashlib.md5( - str(factors).encode(), usedforsecurity=False - ).hexdigest()[:10] + hash_str = safe_hash(str(factors).encode(), usedforsecurity=False).hexdigest()[ + :10 + ] return hash_str def initialize_cache( @@ -286,9 +286,9 @@ class InductorAdaptor(CompilerInterface): def compute_hash(self, vllm_config: VllmConfig) -> str: factors = get_inductor_factors() - hash_str = hashlib.md5( - str(factors).encode(), usedforsecurity=False - ).hexdigest()[:10] + hash_str = safe_hash(str(factors).encode(), usedforsecurity=False).hexdigest()[ + :10 + ] return hash_str def initialize_cache( diff --git a/vllm/config/device.py b/vllm/config/device.py index e85cd15de8cf4..85662ddff76b7 100644 --- a/vllm/config/device.py +++ b/vllm/config/device.py @@ -1,7 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -import hashlib from dataclasses import field from typing import Any, Literal @@ -10,6 +9,7 @@ from pydantic import ConfigDict, SkipValidation from pydantic.dataclasses import dataclass from vllm.config.utils import config +from vllm.utils.hashing import safe_hash Device = Literal["auto", "cuda", "cpu", "tpu", "xpu"] @@ -45,7 +45,7 @@ class DeviceConfig: # the device/platform information will be summarized # by torch/vllm automatically. factors: list[Any] = [] - hash_str = hashlib.md5(str(factors).encode(), usedforsecurity=False).hexdigest() + hash_str = safe_hash(str(factors).encode(), usedforsecurity=False).hexdigest() return hash_str def __post_init__(self): diff --git a/vllm/config/kv_transfer.py b/vllm/config/kv_transfer.py index dfd7ef63712a3..88f8b91c292bb 100644 --- a/vllm/config/kv_transfer.py +++ b/vllm/config/kv_transfer.py @@ -1,7 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -import hashlib import uuid from dataclasses import field from typing import Any, Literal, get_args @@ -9,6 +8,7 @@ from typing import Any, Literal, get_args from pydantic.dataclasses import dataclass from vllm.config.utils import config +from vllm.utils.hashing import safe_hash KVProducer = Literal["kv_producer", "kv_both"] KVConsumer = Literal["kv_consumer", "kv_both"] @@ -79,7 +79,7 @@ class KVTransferConfig: # no factors to consider. # this config will not affect the computation graph. factors: list[Any] = [] - hash_str = hashlib.md5(str(factors).encode(), usedforsecurity=False).hexdigest() + hash_str = safe_hash(str(factors).encode(), usedforsecurity=False).hexdigest() return hash_str def __post_init__(self) -> None: diff --git a/vllm/config/load.py b/vllm/config/load.py index e424f8c5edb62..579a0bc31020e 100644 --- a/vllm/config/load.py +++ b/vllm/config/load.py @@ -1,7 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -import hashlib from typing import TYPE_CHECKING, Any from pydantic import Field, field_validator @@ -9,6 +8,7 @@ from pydantic.dataclasses import dataclass from vllm.config.utils import config from vllm.logger import init_logger +from vllm.utils.hashing import safe_hash if TYPE_CHECKING: from vllm.model_executor.model_loader import LoadFormats @@ -104,7 +104,7 @@ class LoadConfig: # no factors to consider. # this config will not affect the computation graph. factors: list[Any] = [] - hash_str = hashlib.md5(str(factors).encode(), usedforsecurity=False).hexdigest() + hash_str = safe_hash(str(factors).encode(), usedforsecurity=False).hexdigest() return hash_str @field_validator("load_format", mode="after") diff --git a/vllm/config/lora.py b/vllm/config/lora.py index 072e0ec2104f5..6a8fd6359aadd 100644 --- a/vllm/config/lora.py +++ b/vllm/config/lora.py @@ -1,7 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -import hashlib from typing import TYPE_CHECKING, Any, Literal import torch @@ -11,6 +10,7 @@ from typing_extensions import Self from vllm.config.utils import config from vllm.logger import init_logger +from vllm.utils.hashing import safe_hash if TYPE_CHECKING: from vllm.config import ModelConfig @@ -74,7 +74,7 @@ class LoRAConfig: factors.append(self.fully_sharded_loras) factors.append(self.lora_dtype) - hash_str = hashlib.md5(str(factors).encode(), usedforsecurity=False).hexdigest() + hash_str = safe_hash(str(factors).encode(), usedforsecurity=False).hexdigest() return hash_str @model_validator(mode="after") diff --git a/vllm/config/multimodal.py b/vllm/config/multimodal.py index 00a81a319bf72..590bc4dcd0760 100644 --- a/vllm/config/multimodal.py +++ b/vllm/config/multimodal.py @@ -1,7 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -import hashlib from collections.abc import Mapping from typing import TYPE_CHECKING, Any, Literal, TypeAlias @@ -9,6 +8,7 @@ from pydantic import ConfigDict, Field, field_validator, model_validator from pydantic.dataclasses import dataclass from vllm.config.utils import config +from vllm.utils.hashing import safe_hash if TYPE_CHECKING: from vllm.attention.backends.registry import AttentionBackendEnum @@ -216,7 +216,7 @@ class MultiModalConfig: if self.mm_encoder_attn_backend is not None else None ] - hash_str = hashlib.md5(str(factors).encode(), usedforsecurity=False).hexdigest() + hash_str = safe_hash(str(factors).encode(), usedforsecurity=False).hexdigest() return hash_str def get_limit_per_prompt(self, modality: str) -> int: diff --git a/vllm/config/observability.py b/vllm/config/observability.py index 564c4f7aed419..ff35e12fe20ed 100644 --- a/vllm/config/observability.py +++ b/vllm/config/observability.py @@ -1,7 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -import hashlib from functools import cached_property from typing import Any, Literal, cast @@ -11,6 +10,7 @@ from pydantic.dataclasses import dataclass from vllm import version from vllm.config.utils import config +from vllm.utils.hashing import safe_hash DetailedTraceModules = Literal["model", "worker", "all"] @@ -78,7 +78,7 @@ class ObservabilityConfig: # no factors to consider. # this config will not affect the computation graph. factors: list[Any] = [] - hash_str = hashlib.md5(str(factors).encode(), usedforsecurity=False).hexdigest() + hash_str = safe_hash(str(factors).encode(), usedforsecurity=False).hexdigest() return hash_str @field_validator("show_hidden_metrics_for_version") diff --git a/vllm/config/parallel.py b/vllm/config/parallel.py index 913e97250d3d3..7ba1da5db3849 100644 --- a/vllm/config/parallel.py +++ b/vllm/config/parallel.py @@ -593,9 +593,10 @@ class ParallelConfig: "max_parallel_loading_workers is currently " "not supported and will be ignored." ) - if self.distributed_executor_backend != "mp" and self.nnodes > 1: + if self.distributed_executor_backend not in ("mp", "uni") and self.nnodes > 1: raise ValueError( - "nnodes > 1 can only be set when distributed exectuor backend is mp." + "nnodes > 1 can only be set when distributed executor " + "backend is mp or uni." ) @property diff --git a/vllm/config/pooler.py b/vllm/config/pooler.py index 6bece8d0785bd..85950bbcd666f 100644 --- a/vllm/config/pooler.py +++ b/vllm/config/pooler.py @@ -1,13 +1,13 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -import hashlib from typing import Any from pydantic.dataclasses import dataclass from vllm.config.utils import config from vllm.logger import init_logger +from vllm.utils.hashing import safe_hash logger = init_logger(__name__) @@ -102,7 +102,7 @@ class PoolerConfig: # no factors to consider. # this config will not affect the computation graph. factors: list[Any] = [] - hash_str = hashlib.md5(str(factors).encode(), usedforsecurity=False).hexdigest() + hash_str = safe_hash(str(factors).encode(), usedforsecurity=False).hexdigest() return hash_str diff --git a/vllm/config/scheduler.py b/vllm/config/scheduler.py index b6078706daacf..2cf42d57ec217 100644 --- a/vllm/config/scheduler.py +++ b/vllm/config/scheduler.py @@ -1,7 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -import hashlib from collections.abc import Callable from dataclasses import InitVar from typing import TYPE_CHECKING, Any, ClassVar, Literal, cast @@ -12,6 +11,7 @@ from typing_extensions import Self, deprecated from vllm.config.utils import config from vllm.logger import init_logger +from vllm.utils.hashing import safe_hash from vllm.utils.import_utils import resolve_obj_by_qualname if TYPE_CHECKING: @@ -178,7 +178,7 @@ class SchedulerConfig: # no factors to consider. # this config will not affect the computation graph. factors: list[Any] = [] - hash_str = hashlib.md5(str(factors).encode(), usedforsecurity=False).hexdigest() + hash_str = safe_hash(str(factors).encode(), usedforsecurity=False).hexdigest() return hash_str @field_validator("scheduler_cls", "async_scheduling", mode="wrap") diff --git a/vllm/config/speculative.py b/vllm/config/speculative.py index d7c019c73d598..80d53a543f149 100644 --- a/vllm/config/speculative.py +++ b/vllm/config/speculative.py @@ -2,7 +2,6 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import ast -import hashlib from typing import TYPE_CHECKING, Any, Literal, get_args from pydantic import Field, SkipValidation, model_validator @@ -13,6 +12,7 @@ from vllm.config.model import ModelConfig from vllm.config.parallel import ParallelConfig from vllm.config.utils import config from vllm.logger import init_logger +from vllm.utils.hashing import safe_hash from vllm.utils.import_utils import LazyLoader, has_arctic_inference if TYPE_CHECKING: @@ -162,7 +162,7 @@ class SpeculativeConfig: # Eagle3 affects the computation graph because it returns intermediate # hidden states in addition to the final hidden state. factors.append(self.method == "eagle3") - hash_str = hashlib.md5(str(factors).encode(), usedforsecurity=False).hexdigest() + hash_str = safe_hash(str(factors).encode(), usedforsecurity=False).hexdigest() return hash_str @staticmethod diff --git a/vllm/config/structured_outputs.py b/vllm/config/structured_outputs.py index 9530d3d81e15d..1b32675c3dbd2 100644 --- a/vllm/config/structured_outputs.py +++ b/vllm/config/structured_outputs.py @@ -1,7 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -import hashlib from typing import Any, Literal from pydantic import model_validator @@ -9,6 +8,7 @@ from pydantic.dataclasses import dataclass from typing_extensions import Self from vllm.config.utils import config +from vllm.utils.hashing import safe_hash StructuredOutputsBackend = Literal[ "auto", "xgrammar", "guidance", "outlines", "lm-format-enforcer" @@ -58,7 +58,7 @@ class StructuredOutputsConfig: # no factors to consider. # this config will not affect the computation graph. factors: list[Any] = [] - hash_str = hashlib.md5(str(factors).encode(), usedforsecurity=False).hexdigest() + hash_str = safe_hash(str(factors).encode(), usedforsecurity=False).hexdigest() return hash_str @model_validator(mode="after") diff --git a/vllm/config/vllm.py b/vllm/config/vllm.py index 8a3599416bc72..9342564aa3d3f 100644 --- a/vllm/config/vllm.py +++ b/vllm/config/vllm.py @@ -3,7 +3,6 @@ import copy import getpass -import hashlib import json import os import tempfile @@ -25,6 +24,7 @@ from vllm.config.speculative import EagleModelTypes from vllm.logger import enable_trace_function_call, init_logger from vllm.transformers_utils.runai_utils import is_runai_obj_uri from vllm.utils import random_uuid +from vllm.utils.hashing import safe_hash from .cache import CacheConfig from .compilation import CompilationConfig, CompilationMode, CUDAGraphMode @@ -193,7 +193,7 @@ class VllmConfig: vllm_factors.append("None") if self.additional_config: if isinstance(additional_config := self.additional_config, dict): - additional_config_hash = hashlib.md5( + additional_config_hash = safe_hash( json.dumps(additional_config, sort_keys=True).encode(), usedforsecurity=False, ).hexdigest() @@ -204,9 +204,9 @@ class VllmConfig: vllm_factors.append("None") factors.append(vllm_factors) - hash_str = hashlib.md5( - str(factors).encode(), usedforsecurity=False - ).hexdigest()[:10] + hash_str = safe_hash(str(factors).encode(), usedforsecurity=False).hexdigest()[ + :10 + ] return hash_str def pad_for_cudagraph(self, batch_size: int) -> int: diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/shared_storage_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/shared_storage_connector.py index 016d1d45b3593..4611b4d1ff7b8 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/shared_storage_connector.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/shared_storage_connector.py @@ -1,6 +1,5 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -import hashlib import os from dataclasses import dataclass, field from typing import TYPE_CHECKING, Any, Optional @@ -15,6 +14,7 @@ from vllm.distributed.kv_transfer.kv_connector.v1.base import ( KVConnectorRole, ) from vllm.logger import init_logger +from vllm.utils.hashing import safe_hash from vllm.v1.attention.backends.mla.common import MLACommonMetadata from vllm.v1.core.sched.output import SchedulerOutput @@ -423,7 +423,7 @@ class SharedStorageConnector(KVConnectorBase_V1): if mm_hashes: mm_str = "-".join(mm_hashes) token_bytes += mm_str.encode("utf-8") - input_ids_hash = hashlib.md5(token_bytes, usedforsecurity=False).hexdigest() + input_ids_hash = safe_hash(token_bytes, usedforsecurity=False).hexdigest() foldername = os.path.join(self._storage_path, input_ids_hash) if create_folder: diff --git a/vllm/distributed/parallel_state.py b/vllm/distributed/parallel_state.py index f81612fd1f4a3..69c28e278f2d2 100644 --- a/vllm/distributed/parallel_state.py +++ b/vllm/distributed/parallel_state.py @@ -51,6 +51,7 @@ from vllm.distributed.utils import StatelessProcessGroup from vllm.logger import init_logger from vllm.utils.import_utils import resolve_obj_by_qualname from vllm.utils.network_utils import get_distributed_init_method +from vllm.utils.system_utils import suppress_stdout from vllm.utils.torch_utils import ( direct_register_custom_op, supports_custom_op, @@ -329,7 +330,8 @@ class GroupCoordinator: ) # a group with `gloo` backend, to allow direct coordination between # processes through the CPU. - cpu_group = torch.distributed.new_group(ranks, backend="gloo") + with suppress_stdout(): + cpu_group = torch.distributed.new_group(ranks, backend="gloo") if self.rank in ranks: self.ranks = ranks self.world_size = len(ranks) diff --git a/vllm/distributed/utils.py b/vllm/distributed/utils.py index debf69c49b7d9..242ce393e4dc8 100644 --- a/vllm/distributed/utils.py +++ b/vllm/distributed/utils.py @@ -30,6 +30,7 @@ from torch.distributed.rendezvous import rendezvous import vllm.envs as envs from vllm.logger import init_logger from vllm.utils.network_utils import get_tcp_uri +from vllm.utils.system_utils import suppress_stdout from vllm.utils.torch_utils import is_torch_equal_or_newer logger = init_logger(__name__) @@ -427,33 +428,34 @@ def init_gloo_process_group( Stateless init ProcessGroup with gloo backend compatible with different torch versions. """ - if is_torch_equal_or_newer("2.6"): - pg = ProcessGroup( - prefix_store, - group_rank, - group_size, - ) - else: - options = ProcessGroup.Options(backend="gloo") - pg = ProcessGroup( - prefix_store, - group_rank, - group_size, - options, - ) - from torch.distributed.distributed_c10d import ProcessGroupGloo + with suppress_stdout(): + if is_torch_equal_or_newer("2.6"): + pg = ProcessGroup( + prefix_store, + group_rank, + group_size, + ) + else: + options = ProcessGroup.Options(backend="gloo") + pg = ProcessGroup( + prefix_store, + group_rank, + group_size, + options, + ) + from torch.distributed.distributed_c10d import ProcessGroupGloo - backend_class = ProcessGroupGloo( - prefix_store, group_rank, group_size, timeout=timeout - ) - backend_type = ProcessGroup.BackendType.GLOO - device = torch.device("cpu") - if is_torch_equal_or_newer("2.6"): - # _set_default_backend is supported in torch >= 2.6 - pg._set_default_backend(backend_type) - backend_class._set_sequence_number_for_group() + backend_class = ProcessGroupGloo( + prefix_store, group_rank, group_size, timeout=timeout + ) + backend_type = ProcessGroup.BackendType.GLOO + device = torch.device("cpu") + if is_torch_equal_or_newer("2.6"): + # _set_default_backend is supported in torch >= 2.6 + pg._set_default_backend(backend_type) + backend_class._set_sequence_number_for_group() - pg._register_backend(device, backend_type, backend_class) + pg._register_backend(device, backend_type, backend_class) return pg diff --git a/vllm/model_executor/layers/fused_moe/gpt_oss_triton_kernels_moe.py b/vllm/model_executor/layers/fused_moe/gpt_oss_triton_kernels_moe.py index badedfc54c382..128507639fdfd 100644 --- a/vllm/model_executor/layers/fused_moe/gpt_oss_triton_kernels_moe.py +++ b/vllm/model_executor/layers/fused_moe/gpt_oss_triton_kernels_moe.py @@ -1,6 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project + import torch import vllm.model_executor.layers.fused_moe.modular_kernel as mk @@ -12,6 +13,7 @@ from vllm.model_executor.layers.fused_moe.config import ( from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import ( TopKWeightAndReduceNoOP, ) +from vllm.model_executor.layers.fused_moe.utils import _resize_cache from vllm.triton_utils import tl, triton from vllm.utils.import_utils import has_triton_kernels @@ -88,14 +90,17 @@ def triton_kernel_moe_forward( gating_output, topk, sm_first=not renormalize ) + output = torch.empty_like(hidden_states) + return triton_kernel_fused_experts( - None, + output, hidden_states, w1, w2, routing_data, gather_idx, scatter_idx, + topk=topk, activation=activation, quant_config=quant_config, apply_router_weight_on_input=apply_router_weight_on_input, @@ -113,6 +118,7 @@ def triton_kernel_fused_experts( routing_data, # RoutingData gather_indx, # GatherIndx scatter_indx, # ScatterIndx + topk: int, activation: str = "silu", quant_config: FusedMoEQuantConfig | None = None, swiglu_alpha: float = 1.702, @@ -120,6 +126,7 @@ def triton_kernel_fused_experts( apply_router_weight_on_input: bool = False, global_num_experts: int = -1, expert_map: torch.Tensor | None = None, + intermediate_cache: torch.Tensor | None = None, a1q_scale: torch.Tensor | None = None, ) -> torch.Tensor: if quant_config is None: @@ -131,14 +138,30 @@ def triton_kernel_fused_experts( assert quant_config.w2_bias is None or quant_config.w2_bias.dtype == torch.float32 # Shape check, only check non-mxfp4 + assert hidden_states.ndim == 2 assert hidden_states.shape[-1] == w1.shape[-2] assert w2.shape[-1] == w1.shape[1] + batch_dim = 1 + M, K = hidden_states.shape[-2:] E, _, N = w1.shape if global_num_experts == -1: global_num_experts = E + if intermediate_cache is None: + intermediate_cache = torch.empty( + (batch_dim, M * topk, N // 2), + device=hidden_states.device, + dtype=hidden_states.dtype, + ) + + # Add batch_dim to output buffer because matmul_ogs expects 3D output + intermediate_cache = _resize_cache( + intermediate_cache, (batch_dim, M * topk, N // 2) + ) + output_tensor = _resize_cache(output_tensor, (batch_dim, M, K)) + act = FusedActivation( FnSpecs("swiglu", triton_kernels.swiglu.swiglu_fn, ("alpha", "limit")), (swiglu_alpha, swiglu_limit), @@ -146,7 +169,7 @@ def triton_kernel_fused_experts( ) gammas = routing_data.gate_scal if routing_data else None - intermediate_cache1 = matmul_ogs( + matmul_ogs( hidden_states, w1, quant_config.w1_bias, @@ -155,10 +178,11 @@ def triton_kernel_fused_experts( precision_config=quant_config.w1_precision, gammas=gammas if apply_router_weight_on_input else None, fused_activation=act, + y=intermediate_cache, ) - intermediate_cache3 = matmul_ogs( - intermediate_cache1, + matmul_ogs( + intermediate_cache.view(M * topk, N // 2), w2, quant_config.w2_bias, routing_data, @@ -167,7 +191,8 @@ def triton_kernel_fused_experts( gammas=None if apply_router_weight_on_input else gammas, y=output_tensor, ) - return intermediate_cache3 + output_tensor = output_tensor.view(M, K) + return output_tensor def make_routing_data( @@ -221,6 +246,42 @@ class BaseOAITritonExperts(mk.FusedMoEPermuteExpertsUnpermute): def supports_expert_map(self) -> bool: return True + def moe_problem_size( + self, + a1: torch.Tensor, + w1: torch.Tensor, + w2: torch.Tensor, + topk_ids: torch.Tensor, + ) -> tuple[int, int, int, int, int]: + """ + Extract the MoE problem size from the given tensor arguments: + - a: The hidden states, input to the MoE layer. + - w1: The first set of expert weights. + - w2: The second set of expert weights. + - topk_ids: The topk ids. + Note: extracting the problem shape from the weight and activation + tensors is not obvious. It needs to be done this way specifically + due to subtle issues with particular kernels, e.g. the int4 kernels + divide the trailing dimension by two, so it's not "correct" to + extract N or K from the trailing dimension of w1 or w2. Similarly, + some kernels transpose the weights, so this needs to be kept in mind. + Note: This implementation covers most cases. However, if experts + require a specialized implementation, like MarlinExperts, they are free + to override this function. + """ + assert w1.dim() == 3 and w2.dim() == 3 + E, _, N = w1.size() + K = a1.size(-1) + + assert a1.dim() == 2 + assert topk_ids.size(0) == a1.size(0), f"{topk_ids.size(0)} != {a1.size(0)}" + M = a1.size(0) + + assert topk_ids.dim() == 2 + topk = topk_ids.size(1) + + return E, M, N, K, topk + def finalize_weight_and_reduce_impl(self) -> mk.TopKWeightAndReduce: # Weight application and reduction happens in the fused_experts kernel. return TopKWeightAndReduceNoOP() @@ -263,8 +324,8 @@ class OAITritonExperts(BaseOAITritonExperts): expert_tokens_meta: mk.ExpertTokensMetadata | None, ) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]]: # workspace are allocated inside the kernel - workspace1 = (M, K) - workspace2 = (0, 0) + workspace1 = (0, 0) + workspace2 = (M * topk, N // 2) output = (M, K) return (workspace1, workspace2, output) @@ -297,20 +358,21 @@ class OAITritonExperts(BaseOAITritonExperts): topk_ids, topk_weights, local_num_experts ) - experts_output = triton_kernel_fused_experts( - None, + topk = topk_ids.size(1) + triton_kernel_fused_experts( + output, hidden_states, w1, w2, routing_data, gather_indx, scatter_indx, + topk=topk, activation=activation, quant_config=self.quant_config, apply_router_weight_on_input=False, global_num_experts=local_num_experts, expert_map=None, # applied already + intermediate_cache=workspace2, a1q_scale=a1q_scale, ) - - output.copy_(experts_output, non_blocking=True) diff --git a/vllm/model_executor/models/registry.py b/vllm/model_executor/models/registry.py index a0d8a78a2ae76..53644f9cb8788 100644 --- a/vllm/model_executor/models/registry.py +++ b/vllm/model_executor/models/registry.py @@ -5,7 +5,6 @@ Whenever you add an architecture to this page, please also update `tests/models/registry.py` with example HuggingFace models for it. """ -import hashlib import importlib import json import os @@ -32,6 +31,7 @@ from vllm.config import ( from vllm.logger import init_logger from vllm.logging_utils import logtime from vllm.transformers_utils.dynamic_module import try_get_class_from_dynamic_module +from vllm.utils.hashing import safe_hash from .interfaces import ( has_inner_state, @@ -654,7 +654,7 @@ class _LazyRegisteredModel(_BaseRegisteredModel): if model_path.exists(): with open(model_path, "rb") as f: - module_hash = hashlib.md5(f.read(), usedforsecurity=False).hexdigest() + module_hash = safe_hash(f.read(), usedforsecurity=False).hexdigest() mi = self._load_modelinfo_from_cache(module_hash) if mi is not None: diff --git a/vllm/utils/hashing.py b/vllm/utils/hashing.py index 49f4f13d115f3..edf1e9cb34e56 100644 --- a/vllm/utils/hashing.py +++ b/vllm/utils/hashing.py @@ -5,6 +5,7 @@ from __future__ import annotations import hashlib import pickle +from _hashlib import HASH, UnsupportedDigestmodError from collections.abc import Callable from typing import Any @@ -61,3 +62,20 @@ def get_hash_fn_by_name(hash_fn_name: str) -> Callable[[Any], bytes]: return sha256_cbor raise ValueError(f"Unsupported hash function: {hash_fn_name}") + + +def safe_hash(data: bytes, usedforsecurity: bool = True) -> HASH: + """Hash for configs, defaulting to md5 but falling back to sha256 + in FIPS constrained environments. + + Args: + data: bytes + usedforsecurity: Whether the hash is used for security purposes + + Returns: + Hash object + """ + try: + return hashlib.md5(data, usedforsecurity=usedforsecurity) + except (UnsupportedDigestmodError, ValueError): + return hashlib.sha256(data) diff --git a/vllm/utils/system_utils.py b/vllm/utils/system_utils.py index cc872040b6c5f..a4eb8f4d4fd7d 100644 --- a/vllm/utils/system_utils.py +++ b/vllm/utils/system_utils.py @@ -56,6 +56,39 @@ def set_env_var(key: str, value: str) -> Iterator[None]: os.environ[key] = old +@contextlib.contextmanager +def suppress_stdout(): + """ + Suppress stdout from C libraries at the file descriptor level. + + Only suppresses stdout, not stderr, to preserve error messages. + Suppression is disabled when VLLM_LOGGING_LEVEL is set to DEBUG. + + Example: + with suppress_stdout(): + # C library calls that would normally print to stdout + torch.distributed.new_group(ranks, backend="gloo") + """ + # Don't suppress if logging level is DEBUG + if envs.VLLM_LOGGING_LEVEL == "DEBUG": + yield + return + + stdout_fd = sys.stdout.fileno() + stdout_dup = os.dup(stdout_fd) + devnull_fd = os.open(os.devnull, os.O_WRONLY) + + try: + sys.stdout.flush() + os.dup2(devnull_fd, stdout_fd) + yield + finally: + sys.stdout.flush() + os.dup2(stdout_dup, stdout_fd) + os.close(stdout_dup) + os.close(devnull_fd) + + # File path utilities diff --git a/vllm/v1/attention/backends/mla/common.py b/vllm/v1/attention/backends/mla/common.py index 87a3aac21d2c3..d94ed9183f639 100755 --- a/vllm/v1/attention/backends/mla/common.py +++ b/vllm/v1/attention/backends/mla/common.py @@ -1238,15 +1238,13 @@ class MLACommonBaseImpl(MLAAttentionImpl[A], Generic[A]): def _v_up_proj(self, x: torch.Tensor, out: torch.Tensor): # Convert from (B, N, L) to (N, B, L) x = x.view(-1, self.num_heads, self.kv_lora_rank).transpose(0, 1) + if self.is_aiter_triton_fp8_bmm_enabled: + out = out.view(-1, self.num_heads, self.v_head_dim) # Multiply + Transpose (N, B, L) x (N, L, V)->(N, B, V)->(B, N, V) x = rocm_aiter_ops.triton_fp8_bmm( - x, self.W_V, self.W_V_scale, group_size=128, transpose_bm=True + x, self.W_V, self.W_V_scale, group_size=128, transpose_bm=True, YQ=out ) - # Convert from (B, N, V) to (B, N * V) - x = x.reshape(-1, self.num_heads * self.v_head_dim) - # Copy result - out.copy_(x) else: # Convert from (B, N * V) to (N, B, V) out = out.view(-1, self.num_heads, self.v_head_dim).transpose(0, 1) @@ -1824,7 +1822,8 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]): kv_c_and_k_pe_cache: torch.Tensor, attn_metadata: MLACommonMetadata, k_scale: torch.Tensor, - ) -> torch.Tensor: + output: torch.Tensor, + ) -> None: # TODO (zyongye): Prefill function here assert attn_metadata.prefill is not None assert self.dcp_world_size is not None @@ -1837,7 +1836,7 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]): k = torch.cat((k_nope, k_pe.expand((*k_nope.shape[:-1], -1))), dim=-1) - output = self._run_prefill_new_tokens( + output_prefill = self._run_prefill_new_tokens( prefill=attn_metadata.prefill, q=q, k=k, @@ -1846,7 +1845,7 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]): ) if has_context: - suffix_output, suffix_lse = output + suffix_output, suffix_lse = output_prefill if self.dcp_world_size > 1: context_output, context_lse = ( self._context_parallel_compute_prefill_context( @@ -1862,7 +1861,12 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]): q, kv_c_and_k_pe_cache, attn_metadata, k_scale ) - output = torch.empty_like(suffix_output) + # unpad if necessary + if self._pad_v: + context_output = context_output[..., : v.shape[-1]] + suffix_output = suffix_output[..., : v.shape[-1]] + + output = output.view(-1, self.num_heads, self.v_head_dim) merge_attn_states( output=output, prefix_output=context_output, @@ -1870,12 +1874,9 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]): suffix_output=suffix_output, suffix_lse=suffix_lse, ) - - # unpad if necessary - if self._pad_v: - output = output[..., : v.shape[-1]] - - return output.flatten(start_dim=-2) + else: + output_prefill = output_prefill[..., : v.shape[-1]].flatten(start_dim=-2) + output.copy_(output_prefill) @abstractmethod def _forward_decode( @@ -1970,13 +1971,14 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]): kv_cache = kv_cache.view(current_platform.fp8_dtype()) if has_prefill: - output[num_decode_tokens:] = self._forward_prefill( + self._forward_prefill( prefill_q, prefill_k_c_normed, prefill_k_pe, kv_cache, attn_metadata, layer._k_scale, + output=output[num_decode_tokens:], ) if has_decode: diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index bb44c5ad84cc1..d3c61794f8b0d 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -2789,7 +2789,7 @@ class GPUModelRunner( # returns True. before returning early here we call # dummy run to ensure coordinate_batch_across_dp # is called into to avoid out of sync issues. - self._dummy_run(1) + self._dummy_run(self._get_num_input_tokens(1)) if not has_kv_transfer_group(): # Return empty ModelRunnerOutput if no work to do. return EMPTY_MODEL_RUNNER_OUTPUT @@ -3460,6 +3460,10 @@ class GPUModelRunner( scope="local", ) prepare_communication_buffer_for_model(self.model) + if (drafter := getattr(self, "drafter", None)) and ( + drafter_model := getattr(drafter, "model", None) + ): + prepare_communication_buffer_for_model(drafter_model) mm_config = self.model_config.multimodal_config self.is_multimodal_pruning_enabled = ( supports_multimodal_pruning(self.get_model())