Merge branch 'main' into pynccl_symm_fix

This commit is contained in:
Amir Samani 2025-11-26 02:06:05 -08:00 committed by GitHub
commit 2d9628c411
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
30 changed files with 252 additions and 122 deletions

View File

@ -13,7 +13,7 @@ jobs:
steps:
- name: Checkout repository
uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
uses: actions/checkout@1af3b93b6815bc44a9784bd300feb67ff0d1eeb3 # v6.0.0
- name: Set up Python
uses: actions/setup-python@e797f83bcb11b83ae66e0230d6156d7c80228e7c # v6.0.0

View File

@ -12,7 +12,7 @@ jobs:
timeout-minutes: 30
steps:
- uses: actions/checkout@v4
- uses: actions/checkout@v6
- uses: astral-sh/setup-uv@v7
with:

View File

@ -16,7 +16,7 @@ jobs:
pre-commit:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
- uses: actions/checkout@1af3b93b6815bc44a9784bd300feb67ff0d1eeb3 # v6.0.0
- uses: actions/setup-python@e797f83bcb11b83ae66e0230d6156d7c80228e7c # v6.0.0
with:
python-version: "3.12"

View File

@ -16,7 +16,8 @@ __global__ void merge_attn_states_kernel(
scalar_t* output, float* output_lse, const scalar_t* prefix_output,
const float* prefix_lse, const scalar_t* suffix_output,
const float* suffix_lse, const uint num_tokens, const uint num_heads,
const uint head_size) {
const uint head_size, const uint prefix_head_stride,
const uint output_head_stride) {
using pack_128b_t = uint4;
const uint pack_size = 16 / sizeof(scalar_t);
const uint threads_per_head = head_size / pack_size;
@ -34,11 +35,13 @@ __global__ void merge_attn_states_kernel(
const uint head_idx = token_head_idx % num_heads;
const uint pack_offset = pack_idx * pack_size; // (0~15)*8, etc.
const uint head_offset =
token_idx * num_heads * head_size + head_idx * head_size;
const scalar_t* prefix_head_ptr = prefix_output + head_offset;
const scalar_t* suffix_head_ptr = suffix_output + head_offset;
scalar_t* output_head_ptr = output + head_offset;
const uint src_head_offset = token_idx * num_heads * prefix_head_stride +
head_idx * prefix_head_stride;
const uint dst_head_offset = token_idx * num_heads * output_head_stride +
head_idx * output_head_stride;
const scalar_t* prefix_head_ptr = prefix_output + src_head_offset;
const scalar_t* suffix_head_ptr = suffix_output + src_head_offset;
scalar_t* output_head_ptr = output + dst_head_offset;
float p_lse = prefix_lse[head_idx * num_tokens + token_idx];
float s_lse = suffix_lse[head_idx * num_tokens + token_idx];
@ -140,7 +143,7 @@ __global__ void merge_attn_states_kernel(
reinterpret_cast<float*>(prefix_lse.data_ptr()), \
reinterpret_cast<scalar_t*>(suffix_output.data_ptr()), \
reinterpret_cast<float*>(suffix_lse.data_ptr()), num_tokens, \
num_heads, head_size); \
num_heads, head_size, prefix_head_stride, output_head_stride); \
}
/*@brief Merges the attention states from prefix and suffix
@ -166,17 +169,11 @@ void merge_attn_states_launcher(torch::Tensor& output,
const uint num_tokens = output.size(0);
const uint num_heads = output.size(1);
const uint head_size = output.size(2);
const uint prefix_head_stride = prefix_output.stride(1);
const uint output_head_stride = output.stride(1);
const uint pack_size = 16 / sizeof(scalar_t);
TORCH_CHECK(head_size % pack_size == 0,
"headsize must be multiple of pack_size:", pack_size);
TORCH_CHECK(output.stride(-2) == head_size && output.stride(-1) == 1,
"output heads must be contiguous in memory");
TORCH_CHECK(
prefix_output.stride(-2) == head_size && prefix_output.stride(-1) == 1,
"prefix_output heads must be contiguous in memory");
TORCH_CHECK(
suffix_output.stride(-2) == head_size && suffix_output.stride(-1) == 1,
"suffix_output heads must be contiguous in memory");
float* output_lse_ptr = nullptr;
if (output_lse.has_value()) {
output_lse_ptr = output_lse.value().data_ptr<float>();

View File

@ -52,14 +52,13 @@ void paged_attention_v2(
const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size,
const int64_t blocksparse_head_sliding_step);
#ifndef USE_ROCM
void merge_attn_states(torch::Tensor& output,
std::optional<torch::Tensor> output_lse,
const torch::Tensor& prefix_output,
const torch::Tensor& prefix_lse,
const torch::Tensor& suffix_output,
const torch::Tensor& suffix_lse);
#ifndef USE_ROCM
void convert_vertical_slash_indexes(
torch::Tensor& block_count, // [BATCH, N_HEADS, NUM_ROWS]
torch::Tensor& block_offset, // [BATCH, N_HEADS, NUM_ROWS, NNZ_S]

View File

@ -63,7 +63,6 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
" int blocksparse_head_sliding_step) -> ()");
ops.impl("paged_attention_v2", torch::kCUDA, &paged_attention_v2);
#ifndef USE_ROCM
// Merge attn states
// Implements section 2.2 of https://www.arxiv.org/pdf/2501.01005
// can be used to combine partial attention results (in the split-KV case)
@ -76,7 +75,7 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
" Tensor suffix_output,"
" Tensor suffix_lse) -> ()");
ops.impl("merge_attn_states", torch::kCUDA, &merge_attn_states);
#ifndef USE_ROCM
ops.def(
"convert_vertical_slash_indexes("
" Tensor! block_count, Tensor! block_offset, "

View File

@ -20,7 +20,11 @@ def merge_attn_states(
num_query_heads = output.shape[1]
head_size = output.shape[2]
padded_head_size = triton.next_power_of_2(head_size)
# We assume the output stride on num_head is not always as same as the
# `suffix_output` and `prefix_output`, as them might be padded by the attention
# backend.
prefix_head_stride = prefix_output.stride(1)
output_head_stride = output.stride(1)
# TODO(woosuk): Use CUDA kernel instead of Triton to minimize CPU overhead.
merge_attn_states_kernel[(num_tokens, num_query_heads)](
output,
@ -29,6 +33,8 @@ def merge_attn_states(
prefix_lse,
suffix_output,
suffix_lse,
prefix_head_stride,
output_head_stride,
head_size,
padded_head_size,
output_lse is not None,
@ -43,6 +49,8 @@ def merge_attn_states_kernel(
prefix_lse, # [NUM_HEADS, NUM_TOKENS]
suffix_output, # [NUM_TOKENS, NUM_HEADS, HEAD_SIZE]
suffix_lse, # [NUM_HEADS, NUM_TOKENS]
prefix_head_stride,
output_head_stride,
HEAD_SIZE: tl.constexpr,
PADDED_HEAD_SIZE: tl.constexpr,
OUTPUT_LSE: tl.constexpr,
@ -79,15 +87,15 @@ def merge_attn_states_kernel(
head_mask = head_arange < HEAD_SIZE
p_out = tl.load(
prefix_output
+ token_idx * num_heads * HEAD_SIZE
+ head_idx * HEAD_SIZE
+ token_idx * num_heads * prefix_head_stride
+ head_idx * prefix_head_stride
+ head_arange,
mask=head_mask,
)
s_out = tl.load(
suffix_output
+ token_idx * num_heads * HEAD_SIZE
+ head_idx * HEAD_SIZE
+ token_idx * num_heads * prefix_head_stride
+ head_idx * prefix_head_stride
+ head_arange,
mask=head_mask,
)
@ -99,7 +107,10 @@ def merge_attn_states_kernel(
s_scale = s_se / out_se
out = p_out * p_scale + s_out * s_scale
tl.store(
output + token_idx * num_heads * HEAD_SIZE + head_idx * HEAD_SIZE + head_arange,
output
+ token_idx * num_heads * output_head_stride
+ head_idx * output_head_stride
+ head_arange,
out,
mask=head_mask,
)

View File

@ -1,7 +1,6 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import hashlib
import inspect
import os
import pickle
@ -14,6 +13,7 @@ import vllm.envs as envs
from vllm.config import VllmConfig, get_current_vllm_config
from vllm.config.utils import hash_factors
from vllm.logger import init_logger
from vllm.utils.hashing import safe_hash
try:
from torch._dynamo.aot_compile import SerializableCallable
@ -160,7 +160,7 @@ def _compute_code_hash_with_content(file_contents: dict[str, str]) -> str:
# e.g. exec(). We can't actually check these.
continue
hash_content.append(content)
return hashlib.md5(
return safe_hash(
"\n".join(hash_content).encode(), usedforsecurity=False
).hexdigest()

View File

@ -2,7 +2,6 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import contextlib
import copy
import hashlib
import os
from collections.abc import Callable
from contextlib import ExitStack
@ -16,6 +15,7 @@ import torch.fx as fx
import vllm.envs as envs
from vllm.compilation.counter import compilation_counter
from vllm.config import VllmConfig
from vllm.utils.hashing import safe_hash
from vllm.utils.torch_utils import is_torch_equal_or_newer
@ -197,9 +197,9 @@ class InductorStandaloneAdaptor(CompilerInterface):
def compute_hash(self, vllm_config: VllmConfig) -> str:
factors = get_inductor_factors()
hash_str = hashlib.md5(
str(factors).encode(), usedforsecurity=False
).hexdigest()[:10]
hash_str = safe_hash(str(factors).encode(), usedforsecurity=False).hexdigest()[
:10
]
return hash_str
def initialize_cache(
@ -286,9 +286,9 @@ class InductorAdaptor(CompilerInterface):
def compute_hash(self, vllm_config: VllmConfig) -> str:
factors = get_inductor_factors()
hash_str = hashlib.md5(
str(factors).encode(), usedforsecurity=False
).hexdigest()[:10]
hash_str = safe_hash(str(factors).encode(), usedforsecurity=False).hexdigest()[
:10
]
return hash_str
def initialize_cache(

View File

@ -1,7 +1,6 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import hashlib
from dataclasses import field
from typing import Any, Literal
@ -10,6 +9,7 @@ from pydantic import ConfigDict, SkipValidation
from pydantic.dataclasses import dataclass
from vllm.config.utils import config
from vllm.utils.hashing import safe_hash
Device = Literal["auto", "cuda", "cpu", "tpu", "xpu"]
@ -45,7 +45,7 @@ class DeviceConfig:
# the device/platform information will be summarized
# by torch/vllm automatically.
factors: list[Any] = []
hash_str = hashlib.md5(str(factors).encode(), usedforsecurity=False).hexdigest()
hash_str = safe_hash(str(factors).encode(), usedforsecurity=False).hexdigest()
return hash_str
def __post_init__(self):

View File

@ -1,7 +1,6 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import hashlib
import uuid
from dataclasses import field
from typing import Any, Literal, get_args
@ -9,6 +8,7 @@ from typing import Any, Literal, get_args
from pydantic.dataclasses import dataclass
from vllm.config.utils import config
from vllm.utils.hashing import safe_hash
KVProducer = Literal["kv_producer", "kv_both"]
KVConsumer = Literal["kv_consumer", "kv_both"]
@ -79,7 +79,7 @@ class KVTransferConfig:
# no factors to consider.
# this config will not affect the computation graph.
factors: list[Any] = []
hash_str = hashlib.md5(str(factors).encode(), usedforsecurity=False).hexdigest()
hash_str = safe_hash(str(factors).encode(), usedforsecurity=False).hexdigest()
return hash_str
def __post_init__(self) -> None:

View File

@ -1,7 +1,6 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import hashlib
from typing import TYPE_CHECKING, Any
from pydantic import Field, field_validator
@ -9,6 +8,7 @@ from pydantic.dataclasses import dataclass
from vllm.config.utils import config
from vllm.logger import init_logger
from vllm.utils.hashing import safe_hash
if TYPE_CHECKING:
from vllm.model_executor.model_loader import LoadFormats
@ -104,7 +104,7 @@ class LoadConfig:
# no factors to consider.
# this config will not affect the computation graph.
factors: list[Any] = []
hash_str = hashlib.md5(str(factors).encode(), usedforsecurity=False).hexdigest()
hash_str = safe_hash(str(factors).encode(), usedforsecurity=False).hexdigest()
return hash_str
@field_validator("load_format", mode="after")

View File

@ -1,7 +1,6 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import hashlib
from typing import TYPE_CHECKING, Any, Literal
import torch
@ -11,6 +10,7 @@ from typing_extensions import Self
from vllm.config.utils import config
from vllm.logger import init_logger
from vllm.utils.hashing import safe_hash
if TYPE_CHECKING:
from vllm.config import ModelConfig
@ -74,7 +74,7 @@ class LoRAConfig:
factors.append(self.fully_sharded_loras)
factors.append(self.lora_dtype)
hash_str = hashlib.md5(str(factors).encode(), usedforsecurity=False).hexdigest()
hash_str = safe_hash(str(factors).encode(), usedforsecurity=False).hexdigest()
return hash_str
@model_validator(mode="after")

View File

@ -1,7 +1,6 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import hashlib
from collections.abc import Mapping
from typing import TYPE_CHECKING, Any, Literal, TypeAlias
@ -9,6 +8,7 @@ from pydantic import ConfigDict, Field, field_validator, model_validator
from pydantic.dataclasses import dataclass
from vllm.config.utils import config
from vllm.utils.hashing import safe_hash
if TYPE_CHECKING:
from vllm.attention.backends.registry import AttentionBackendEnum
@ -216,7 +216,7 @@ class MultiModalConfig:
if self.mm_encoder_attn_backend is not None
else None
]
hash_str = hashlib.md5(str(factors).encode(), usedforsecurity=False).hexdigest()
hash_str = safe_hash(str(factors).encode(), usedforsecurity=False).hexdigest()
return hash_str
def get_limit_per_prompt(self, modality: str) -> int:

View File

@ -1,7 +1,6 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import hashlib
from functools import cached_property
from typing import Any, Literal, cast
@ -11,6 +10,7 @@ from pydantic.dataclasses import dataclass
from vllm import version
from vllm.config.utils import config
from vllm.utils.hashing import safe_hash
DetailedTraceModules = Literal["model", "worker", "all"]
@ -78,7 +78,7 @@ class ObservabilityConfig:
# no factors to consider.
# this config will not affect the computation graph.
factors: list[Any] = []
hash_str = hashlib.md5(str(factors).encode(), usedforsecurity=False).hexdigest()
hash_str = safe_hash(str(factors).encode(), usedforsecurity=False).hexdigest()
return hash_str
@field_validator("show_hidden_metrics_for_version")

View File

@ -593,9 +593,10 @@ class ParallelConfig:
"max_parallel_loading_workers is currently "
"not supported and will be ignored."
)
if self.distributed_executor_backend != "mp" and self.nnodes > 1:
if self.distributed_executor_backend not in ("mp", "uni") and self.nnodes > 1:
raise ValueError(
"nnodes > 1 can only be set when distributed exectuor backend is mp."
"nnodes > 1 can only be set when distributed executor "
"backend is mp or uni."
)
@property

View File

@ -1,13 +1,13 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import hashlib
from typing import Any
from pydantic.dataclasses import dataclass
from vllm.config.utils import config
from vllm.logger import init_logger
from vllm.utils.hashing import safe_hash
logger = init_logger(__name__)
@ -102,7 +102,7 @@ class PoolerConfig:
# no factors to consider.
# this config will not affect the computation graph.
factors: list[Any] = []
hash_str = hashlib.md5(str(factors).encode(), usedforsecurity=False).hexdigest()
hash_str = safe_hash(str(factors).encode(), usedforsecurity=False).hexdigest()
return hash_str

View File

@ -1,7 +1,6 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import hashlib
from collections.abc import Callable
from dataclasses import InitVar
from typing import TYPE_CHECKING, Any, ClassVar, Literal, cast
@ -12,6 +11,7 @@ from typing_extensions import Self, deprecated
from vllm.config.utils import config
from vllm.logger import init_logger
from vllm.utils.hashing import safe_hash
from vllm.utils.import_utils import resolve_obj_by_qualname
if TYPE_CHECKING:
@ -178,7 +178,7 @@ class SchedulerConfig:
# no factors to consider.
# this config will not affect the computation graph.
factors: list[Any] = []
hash_str = hashlib.md5(str(factors).encode(), usedforsecurity=False).hexdigest()
hash_str = safe_hash(str(factors).encode(), usedforsecurity=False).hexdigest()
return hash_str
@field_validator("scheduler_cls", "async_scheduling", mode="wrap")

View File

@ -2,7 +2,6 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import ast
import hashlib
from typing import TYPE_CHECKING, Any, Literal, get_args
from pydantic import Field, SkipValidation, model_validator
@ -13,6 +12,7 @@ from vllm.config.model import ModelConfig
from vllm.config.parallel import ParallelConfig
from vllm.config.utils import config
from vllm.logger import init_logger
from vllm.utils.hashing import safe_hash
from vllm.utils.import_utils import LazyLoader, has_arctic_inference
if TYPE_CHECKING:
@ -162,7 +162,7 @@ class SpeculativeConfig:
# Eagle3 affects the computation graph because it returns intermediate
# hidden states in addition to the final hidden state.
factors.append(self.method == "eagle3")
hash_str = hashlib.md5(str(factors).encode(), usedforsecurity=False).hexdigest()
hash_str = safe_hash(str(factors).encode(), usedforsecurity=False).hexdigest()
return hash_str
@staticmethod

View File

@ -1,7 +1,6 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import hashlib
from typing import Any, Literal
from pydantic import model_validator
@ -9,6 +8,7 @@ from pydantic.dataclasses import dataclass
from typing_extensions import Self
from vllm.config.utils import config
from vllm.utils.hashing import safe_hash
StructuredOutputsBackend = Literal[
"auto", "xgrammar", "guidance", "outlines", "lm-format-enforcer"
@ -58,7 +58,7 @@ class StructuredOutputsConfig:
# no factors to consider.
# this config will not affect the computation graph.
factors: list[Any] = []
hash_str = hashlib.md5(str(factors).encode(), usedforsecurity=False).hexdigest()
hash_str = safe_hash(str(factors).encode(), usedforsecurity=False).hexdigest()
return hash_str
@model_validator(mode="after")

View File

@ -3,7 +3,6 @@
import copy
import getpass
import hashlib
import json
import os
import tempfile
@ -25,6 +24,7 @@ from vllm.config.speculative import EagleModelTypes
from vllm.logger import enable_trace_function_call, init_logger
from vllm.transformers_utils.runai_utils import is_runai_obj_uri
from vllm.utils import random_uuid
from vllm.utils.hashing import safe_hash
from .cache import CacheConfig
from .compilation import CompilationConfig, CompilationMode, CUDAGraphMode
@ -193,7 +193,7 @@ class VllmConfig:
vllm_factors.append("None")
if self.additional_config:
if isinstance(additional_config := self.additional_config, dict):
additional_config_hash = hashlib.md5(
additional_config_hash = safe_hash(
json.dumps(additional_config, sort_keys=True).encode(),
usedforsecurity=False,
).hexdigest()
@ -204,9 +204,9 @@ class VllmConfig:
vllm_factors.append("None")
factors.append(vllm_factors)
hash_str = hashlib.md5(
str(factors).encode(), usedforsecurity=False
).hexdigest()[:10]
hash_str = safe_hash(str(factors).encode(), usedforsecurity=False).hexdigest()[
:10
]
return hash_str
def pad_for_cudagraph(self, batch_size: int) -> int:

View File

@ -1,6 +1,5 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import hashlib
import os
from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Any, Optional
@ -15,6 +14,7 @@ from vllm.distributed.kv_transfer.kv_connector.v1.base import (
KVConnectorRole,
)
from vllm.logger import init_logger
from vllm.utils.hashing import safe_hash
from vllm.v1.attention.backends.mla.common import MLACommonMetadata
from vllm.v1.core.sched.output import SchedulerOutput
@ -423,7 +423,7 @@ class SharedStorageConnector(KVConnectorBase_V1):
if mm_hashes:
mm_str = "-".join(mm_hashes)
token_bytes += mm_str.encode("utf-8")
input_ids_hash = hashlib.md5(token_bytes, usedforsecurity=False).hexdigest()
input_ids_hash = safe_hash(token_bytes, usedforsecurity=False).hexdigest()
foldername = os.path.join(self._storage_path, input_ids_hash)
if create_folder:

View File

@ -51,6 +51,7 @@ from vllm.distributed.utils import StatelessProcessGroup
from vllm.logger import init_logger
from vllm.utils.import_utils import resolve_obj_by_qualname
from vllm.utils.network_utils import get_distributed_init_method
from vllm.utils.system_utils import suppress_stdout
from vllm.utils.torch_utils import (
direct_register_custom_op,
supports_custom_op,
@ -329,7 +330,8 @@ class GroupCoordinator:
)
# a group with `gloo` backend, to allow direct coordination between
# processes through the CPU.
cpu_group = torch.distributed.new_group(ranks, backend="gloo")
with suppress_stdout():
cpu_group = torch.distributed.new_group(ranks, backend="gloo")
if self.rank in ranks:
self.ranks = ranks
self.world_size = len(ranks)

View File

@ -30,6 +30,7 @@ from torch.distributed.rendezvous import rendezvous
import vllm.envs as envs
from vllm.logger import init_logger
from vllm.utils.network_utils import get_tcp_uri
from vllm.utils.system_utils import suppress_stdout
from vllm.utils.torch_utils import is_torch_equal_or_newer
logger = init_logger(__name__)
@ -427,33 +428,34 @@ def init_gloo_process_group(
Stateless init ProcessGroup with gloo backend compatible with
different torch versions.
"""
if is_torch_equal_or_newer("2.6"):
pg = ProcessGroup(
prefix_store,
group_rank,
group_size,
)
else:
options = ProcessGroup.Options(backend="gloo")
pg = ProcessGroup(
prefix_store,
group_rank,
group_size,
options,
)
from torch.distributed.distributed_c10d import ProcessGroupGloo
with suppress_stdout():
if is_torch_equal_or_newer("2.6"):
pg = ProcessGroup(
prefix_store,
group_rank,
group_size,
)
else:
options = ProcessGroup.Options(backend="gloo")
pg = ProcessGroup(
prefix_store,
group_rank,
group_size,
options,
)
from torch.distributed.distributed_c10d import ProcessGroupGloo
backend_class = ProcessGroupGloo(
prefix_store, group_rank, group_size, timeout=timeout
)
backend_type = ProcessGroup.BackendType.GLOO
device = torch.device("cpu")
if is_torch_equal_or_newer("2.6"):
# _set_default_backend is supported in torch >= 2.6
pg._set_default_backend(backend_type)
backend_class._set_sequence_number_for_group()
backend_class = ProcessGroupGloo(
prefix_store, group_rank, group_size, timeout=timeout
)
backend_type = ProcessGroup.BackendType.GLOO
device = torch.device("cpu")
if is_torch_equal_or_newer("2.6"):
# _set_default_backend is supported in torch >= 2.6
pg._set_default_backend(backend_type)
backend_class._set_sequence_number_for_group()
pg._register_backend(device, backend_type, backend_class)
pg._register_backend(device, backend_type, backend_class)
return pg

View File

@ -1,6 +1,7 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import torch
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
@ -12,6 +13,7 @@ from vllm.model_executor.layers.fused_moe.config import (
from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import (
TopKWeightAndReduceNoOP,
)
from vllm.model_executor.layers.fused_moe.utils import _resize_cache
from vllm.triton_utils import tl, triton
from vllm.utils.import_utils import has_triton_kernels
@ -88,14 +90,17 @@ def triton_kernel_moe_forward(
gating_output, topk, sm_first=not renormalize
)
output = torch.empty_like(hidden_states)
return triton_kernel_fused_experts(
None,
output,
hidden_states,
w1,
w2,
routing_data,
gather_idx,
scatter_idx,
topk=topk,
activation=activation,
quant_config=quant_config,
apply_router_weight_on_input=apply_router_weight_on_input,
@ -113,6 +118,7 @@ def triton_kernel_fused_experts(
routing_data, # RoutingData
gather_indx, # GatherIndx
scatter_indx, # ScatterIndx
topk: int,
activation: str = "silu",
quant_config: FusedMoEQuantConfig | None = None,
swiglu_alpha: float = 1.702,
@ -120,6 +126,7 @@ def triton_kernel_fused_experts(
apply_router_weight_on_input: bool = False,
global_num_experts: int = -1,
expert_map: torch.Tensor | None = None,
intermediate_cache: torch.Tensor | None = None,
a1q_scale: torch.Tensor | None = None,
) -> torch.Tensor:
if quant_config is None:
@ -131,14 +138,30 @@ def triton_kernel_fused_experts(
assert quant_config.w2_bias is None or quant_config.w2_bias.dtype == torch.float32
# Shape check, only check non-mxfp4
assert hidden_states.ndim == 2
assert hidden_states.shape[-1] == w1.shape[-2]
assert w2.shape[-1] == w1.shape[1]
batch_dim = 1
M, K = hidden_states.shape[-2:]
E, _, N = w1.shape
if global_num_experts == -1:
global_num_experts = E
if intermediate_cache is None:
intermediate_cache = torch.empty(
(batch_dim, M * topk, N // 2),
device=hidden_states.device,
dtype=hidden_states.dtype,
)
# Add batch_dim to output buffer because matmul_ogs expects 3D output
intermediate_cache = _resize_cache(
intermediate_cache, (batch_dim, M * topk, N // 2)
)
output_tensor = _resize_cache(output_tensor, (batch_dim, M, K))
act = FusedActivation(
FnSpecs("swiglu", triton_kernels.swiglu.swiglu_fn, ("alpha", "limit")),
(swiglu_alpha, swiglu_limit),
@ -146,7 +169,7 @@ def triton_kernel_fused_experts(
)
gammas = routing_data.gate_scal if routing_data else None
intermediate_cache1 = matmul_ogs(
matmul_ogs(
hidden_states,
w1,
quant_config.w1_bias,
@ -155,10 +178,11 @@ def triton_kernel_fused_experts(
precision_config=quant_config.w1_precision,
gammas=gammas if apply_router_weight_on_input else None,
fused_activation=act,
y=intermediate_cache,
)
intermediate_cache3 = matmul_ogs(
intermediate_cache1,
matmul_ogs(
intermediate_cache.view(M * topk, N // 2),
w2,
quant_config.w2_bias,
routing_data,
@ -167,7 +191,8 @@ def triton_kernel_fused_experts(
gammas=None if apply_router_weight_on_input else gammas,
y=output_tensor,
)
return intermediate_cache3
output_tensor = output_tensor.view(M, K)
return output_tensor
def make_routing_data(
@ -221,6 +246,42 @@ class BaseOAITritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
def supports_expert_map(self) -> bool:
return True
def moe_problem_size(
self,
a1: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
topk_ids: torch.Tensor,
) -> tuple[int, int, int, int, int]:
"""
Extract the MoE problem size from the given tensor arguments:
- a: The hidden states, input to the MoE layer.
- w1: The first set of expert weights.
- w2: The second set of expert weights.
- topk_ids: The topk ids.
Note: extracting the problem shape from the weight and activation
tensors is not obvious. It needs to be done this way specifically
due to subtle issues with particular kernels, e.g. the int4 kernels
divide the trailing dimension by two, so it's not "correct" to
extract N or K from the trailing dimension of w1 or w2. Similarly,
some kernels transpose the weights, so this needs to be kept in mind.
Note: This implementation covers most cases. However, if experts
require a specialized implementation, like MarlinExperts, they are free
to override this function.
"""
assert w1.dim() == 3 and w2.dim() == 3
E, _, N = w1.size()
K = a1.size(-1)
assert a1.dim() == 2
assert topk_ids.size(0) == a1.size(0), f"{topk_ids.size(0)} != {a1.size(0)}"
M = a1.size(0)
assert topk_ids.dim() == 2
topk = topk_ids.size(1)
return E, M, N, K, topk
def finalize_weight_and_reduce_impl(self) -> mk.TopKWeightAndReduce:
# Weight application and reduction happens in the fused_experts kernel.
return TopKWeightAndReduceNoOP()
@ -263,8 +324,8 @@ class OAITritonExperts(BaseOAITritonExperts):
expert_tokens_meta: mk.ExpertTokensMetadata | None,
) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]]:
# workspace are allocated inside the kernel
workspace1 = (M, K)
workspace2 = (0, 0)
workspace1 = (0, 0)
workspace2 = (M * topk, N // 2)
output = (M, K)
return (workspace1, workspace2, output)
@ -297,20 +358,21 @@ class OAITritonExperts(BaseOAITritonExperts):
topk_ids, topk_weights, local_num_experts
)
experts_output = triton_kernel_fused_experts(
None,
topk = topk_ids.size(1)
triton_kernel_fused_experts(
output,
hidden_states,
w1,
w2,
routing_data,
gather_indx,
scatter_indx,
topk=topk,
activation=activation,
quant_config=self.quant_config,
apply_router_weight_on_input=False,
global_num_experts=local_num_experts,
expert_map=None, # applied already
intermediate_cache=workspace2,
a1q_scale=a1q_scale,
)
output.copy_(experts_output, non_blocking=True)

View File

@ -5,7 +5,6 @@ Whenever you add an architecture to this page, please also update
`tests/models/registry.py` with example HuggingFace models for it.
"""
import hashlib
import importlib
import json
import os
@ -32,6 +31,7 @@ from vllm.config import (
from vllm.logger import init_logger
from vllm.logging_utils import logtime
from vllm.transformers_utils.dynamic_module import try_get_class_from_dynamic_module
from vllm.utils.hashing import safe_hash
from .interfaces import (
has_inner_state,
@ -654,7 +654,7 @@ class _LazyRegisteredModel(_BaseRegisteredModel):
if model_path.exists():
with open(model_path, "rb") as f:
module_hash = hashlib.md5(f.read(), usedforsecurity=False).hexdigest()
module_hash = safe_hash(f.read(), usedforsecurity=False).hexdigest()
mi = self._load_modelinfo_from_cache(module_hash)
if mi is not None:

View File

@ -5,6 +5,7 @@ from __future__ import annotations
import hashlib
import pickle
from _hashlib import HASH, UnsupportedDigestmodError
from collections.abc import Callable
from typing import Any
@ -61,3 +62,20 @@ def get_hash_fn_by_name(hash_fn_name: str) -> Callable[[Any], bytes]:
return sha256_cbor
raise ValueError(f"Unsupported hash function: {hash_fn_name}")
def safe_hash(data: bytes, usedforsecurity: bool = True) -> HASH:
"""Hash for configs, defaulting to md5 but falling back to sha256
in FIPS constrained environments.
Args:
data: bytes
usedforsecurity: Whether the hash is used for security purposes
Returns:
Hash object
"""
try:
return hashlib.md5(data, usedforsecurity=usedforsecurity)
except (UnsupportedDigestmodError, ValueError):
return hashlib.sha256(data)

View File

@ -56,6 +56,39 @@ def set_env_var(key: str, value: str) -> Iterator[None]:
os.environ[key] = old
@contextlib.contextmanager
def suppress_stdout():
"""
Suppress stdout from C libraries at the file descriptor level.
Only suppresses stdout, not stderr, to preserve error messages.
Suppression is disabled when VLLM_LOGGING_LEVEL is set to DEBUG.
Example:
with suppress_stdout():
# C library calls that would normally print to stdout
torch.distributed.new_group(ranks, backend="gloo")
"""
# Don't suppress if logging level is DEBUG
if envs.VLLM_LOGGING_LEVEL == "DEBUG":
yield
return
stdout_fd = sys.stdout.fileno()
stdout_dup = os.dup(stdout_fd)
devnull_fd = os.open(os.devnull, os.O_WRONLY)
try:
sys.stdout.flush()
os.dup2(devnull_fd, stdout_fd)
yield
finally:
sys.stdout.flush()
os.dup2(stdout_dup, stdout_fd)
os.close(stdout_dup)
os.close(devnull_fd)
# File path utilities

View File

@ -1238,15 +1238,13 @@ class MLACommonBaseImpl(MLAAttentionImpl[A], Generic[A]):
def _v_up_proj(self, x: torch.Tensor, out: torch.Tensor):
# Convert from (B, N, L) to (N, B, L)
x = x.view(-1, self.num_heads, self.kv_lora_rank).transpose(0, 1)
if self.is_aiter_triton_fp8_bmm_enabled:
out = out.view(-1, self.num_heads, self.v_head_dim)
# Multiply + Transpose (N, B, L) x (N, L, V)->(N, B, V)->(B, N, V)
x = rocm_aiter_ops.triton_fp8_bmm(
x, self.W_V, self.W_V_scale, group_size=128, transpose_bm=True
x, self.W_V, self.W_V_scale, group_size=128, transpose_bm=True, YQ=out
)
# Convert from (B, N, V) to (B, N * V)
x = x.reshape(-1, self.num_heads * self.v_head_dim)
# Copy result
out.copy_(x)
else:
# Convert from (B, N * V) to (N, B, V)
out = out.view(-1, self.num_heads, self.v_head_dim).transpose(0, 1)
@ -1824,7 +1822,8 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]):
kv_c_and_k_pe_cache: torch.Tensor,
attn_metadata: MLACommonMetadata,
k_scale: torch.Tensor,
) -> torch.Tensor:
output: torch.Tensor,
) -> None:
# TODO (zyongye): Prefill function here
assert attn_metadata.prefill is not None
assert self.dcp_world_size is not None
@ -1837,7 +1836,7 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]):
k = torch.cat((k_nope, k_pe.expand((*k_nope.shape[:-1], -1))), dim=-1)
output = self._run_prefill_new_tokens(
output_prefill = self._run_prefill_new_tokens(
prefill=attn_metadata.prefill,
q=q,
k=k,
@ -1846,7 +1845,7 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]):
)
if has_context:
suffix_output, suffix_lse = output
suffix_output, suffix_lse = output_prefill
if self.dcp_world_size > 1:
context_output, context_lse = (
self._context_parallel_compute_prefill_context(
@ -1862,7 +1861,12 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]):
q, kv_c_and_k_pe_cache, attn_metadata, k_scale
)
output = torch.empty_like(suffix_output)
# unpad if necessary
if self._pad_v:
context_output = context_output[..., : v.shape[-1]]
suffix_output = suffix_output[..., : v.shape[-1]]
output = output.view(-1, self.num_heads, self.v_head_dim)
merge_attn_states(
output=output,
prefix_output=context_output,
@ -1870,12 +1874,9 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]):
suffix_output=suffix_output,
suffix_lse=suffix_lse,
)
# unpad if necessary
if self._pad_v:
output = output[..., : v.shape[-1]]
return output.flatten(start_dim=-2)
else:
output_prefill = output_prefill[..., : v.shape[-1]].flatten(start_dim=-2)
output.copy_(output_prefill)
@abstractmethod
def _forward_decode(
@ -1970,13 +1971,14 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]):
kv_cache = kv_cache.view(current_platform.fp8_dtype())
if has_prefill:
output[num_decode_tokens:] = self._forward_prefill(
self._forward_prefill(
prefill_q,
prefill_k_c_normed,
prefill_k_pe,
kv_cache,
attn_metadata,
layer._k_scale,
output=output[num_decode_tokens:],
)
if has_decode:

View File

@ -2789,7 +2789,7 @@ class GPUModelRunner(
# returns True. before returning early here we call
# dummy run to ensure coordinate_batch_across_dp
# is called into to avoid out of sync issues.
self._dummy_run(1)
self._dummy_run(self._get_num_input_tokens(1))
if not has_kv_transfer_group():
# Return empty ModelRunnerOutput if no work to do.
return EMPTY_MODEL_RUNNER_OUTPUT
@ -3460,6 +3460,10 @@ class GPUModelRunner(
scope="local",
)
prepare_communication_buffer_for_model(self.model)
if (drafter := getattr(self, "drafter", None)) and (
drafter_model := getattr(drafter, "model", None)
):
prepare_communication_buffer_for_model(drafter_model)
mm_config = self.model_config.multimodal_config
self.is_multimodal_pruning_enabled = (
supports_multimodal_pruning(self.get_model())