mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-03-30 09:07:16 +08:00
Merge branch 'main' into pynccl_symm_fix
This commit is contained in:
commit
2d9628c411
2
.github/workflows/cleanup_pr_body.yml
vendored
2
.github/workflows/cleanup_pr_body.yml
vendored
@ -13,7 +13,7 @@ jobs:
|
||||
|
||||
steps:
|
||||
- name: Checkout repository
|
||||
uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
|
||||
uses: actions/checkout@1af3b93b6815bc44a9784bd300feb67ff0d1eeb3 # v6.0.0
|
||||
|
||||
- name: Set up Python
|
||||
uses: actions/setup-python@e797f83bcb11b83ae66e0230d6156d7c80228e7c # v6.0.0
|
||||
|
||||
2
.github/workflows/macos-smoke-test.yml
vendored
2
.github/workflows/macos-smoke-test.yml
vendored
@ -12,7 +12,7 @@ jobs:
|
||||
timeout-minutes: 30
|
||||
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
- uses: actions/checkout@v6
|
||||
|
||||
- uses: astral-sh/setup-uv@v7
|
||||
with:
|
||||
|
||||
2
.github/workflows/pre-commit.yml
vendored
2
.github/workflows/pre-commit.yml
vendored
@ -16,7 +16,7 @@ jobs:
|
||||
pre-commit:
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
|
||||
- uses: actions/checkout@1af3b93b6815bc44a9784bd300feb67ff0d1eeb3 # v6.0.0
|
||||
- uses: actions/setup-python@e797f83bcb11b83ae66e0230d6156d7c80228e7c # v6.0.0
|
||||
with:
|
||||
python-version: "3.12"
|
||||
|
||||
@ -16,7 +16,8 @@ __global__ void merge_attn_states_kernel(
|
||||
scalar_t* output, float* output_lse, const scalar_t* prefix_output,
|
||||
const float* prefix_lse, const scalar_t* suffix_output,
|
||||
const float* suffix_lse, const uint num_tokens, const uint num_heads,
|
||||
const uint head_size) {
|
||||
const uint head_size, const uint prefix_head_stride,
|
||||
const uint output_head_stride) {
|
||||
using pack_128b_t = uint4;
|
||||
const uint pack_size = 16 / sizeof(scalar_t);
|
||||
const uint threads_per_head = head_size / pack_size;
|
||||
@ -34,11 +35,13 @@ __global__ void merge_attn_states_kernel(
|
||||
const uint head_idx = token_head_idx % num_heads;
|
||||
|
||||
const uint pack_offset = pack_idx * pack_size; // (0~15)*8, etc.
|
||||
const uint head_offset =
|
||||
token_idx * num_heads * head_size + head_idx * head_size;
|
||||
const scalar_t* prefix_head_ptr = prefix_output + head_offset;
|
||||
const scalar_t* suffix_head_ptr = suffix_output + head_offset;
|
||||
scalar_t* output_head_ptr = output + head_offset;
|
||||
const uint src_head_offset = token_idx * num_heads * prefix_head_stride +
|
||||
head_idx * prefix_head_stride;
|
||||
const uint dst_head_offset = token_idx * num_heads * output_head_stride +
|
||||
head_idx * output_head_stride;
|
||||
const scalar_t* prefix_head_ptr = prefix_output + src_head_offset;
|
||||
const scalar_t* suffix_head_ptr = suffix_output + src_head_offset;
|
||||
scalar_t* output_head_ptr = output + dst_head_offset;
|
||||
|
||||
float p_lse = prefix_lse[head_idx * num_tokens + token_idx];
|
||||
float s_lse = suffix_lse[head_idx * num_tokens + token_idx];
|
||||
@ -140,7 +143,7 @@ __global__ void merge_attn_states_kernel(
|
||||
reinterpret_cast<float*>(prefix_lse.data_ptr()), \
|
||||
reinterpret_cast<scalar_t*>(suffix_output.data_ptr()), \
|
||||
reinterpret_cast<float*>(suffix_lse.data_ptr()), num_tokens, \
|
||||
num_heads, head_size); \
|
||||
num_heads, head_size, prefix_head_stride, output_head_stride); \
|
||||
}
|
||||
|
||||
/*@brief Merges the attention states from prefix and suffix
|
||||
@ -166,17 +169,11 @@ void merge_attn_states_launcher(torch::Tensor& output,
|
||||
const uint num_tokens = output.size(0);
|
||||
const uint num_heads = output.size(1);
|
||||
const uint head_size = output.size(2);
|
||||
const uint prefix_head_stride = prefix_output.stride(1);
|
||||
const uint output_head_stride = output.stride(1);
|
||||
const uint pack_size = 16 / sizeof(scalar_t);
|
||||
TORCH_CHECK(head_size % pack_size == 0,
|
||||
"headsize must be multiple of pack_size:", pack_size);
|
||||
TORCH_CHECK(output.stride(-2) == head_size && output.stride(-1) == 1,
|
||||
"output heads must be contiguous in memory");
|
||||
TORCH_CHECK(
|
||||
prefix_output.stride(-2) == head_size && prefix_output.stride(-1) == 1,
|
||||
"prefix_output heads must be contiguous in memory");
|
||||
TORCH_CHECK(
|
||||
suffix_output.stride(-2) == head_size && suffix_output.stride(-1) == 1,
|
||||
"suffix_output heads must be contiguous in memory");
|
||||
float* output_lse_ptr = nullptr;
|
||||
if (output_lse.has_value()) {
|
||||
output_lse_ptr = output_lse.value().data_ptr<float>();
|
||||
|
||||
@ -52,14 +52,13 @@ void paged_attention_v2(
|
||||
const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size,
|
||||
const int64_t blocksparse_head_sliding_step);
|
||||
|
||||
#ifndef USE_ROCM
|
||||
void merge_attn_states(torch::Tensor& output,
|
||||
std::optional<torch::Tensor> output_lse,
|
||||
const torch::Tensor& prefix_output,
|
||||
const torch::Tensor& prefix_lse,
|
||||
const torch::Tensor& suffix_output,
|
||||
const torch::Tensor& suffix_lse);
|
||||
|
||||
#ifndef USE_ROCM
|
||||
void convert_vertical_slash_indexes(
|
||||
torch::Tensor& block_count, // [BATCH, N_HEADS, NUM_ROWS]
|
||||
torch::Tensor& block_offset, // [BATCH, N_HEADS, NUM_ROWS, NNZ_S]
|
||||
|
||||
@ -63,7 +63,6 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
|
||||
" int blocksparse_head_sliding_step) -> ()");
|
||||
ops.impl("paged_attention_v2", torch::kCUDA, &paged_attention_v2);
|
||||
|
||||
#ifndef USE_ROCM
|
||||
// Merge attn states
|
||||
// Implements section 2.2 of https://www.arxiv.org/pdf/2501.01005
|
||||
// can be used to combine partial attention results (in the split-KV case)
|
||||
@ -76,7 +75,7 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
|
||||
" Tensor suffix_output,"
|
||||
" Tensor suffix_lse) -> ()");
|
||||
ops.impl("merge_attn_states", torch::kCUDA, &merge_attn_states);
|
||||
|
||||
#ifndef USE_ROCM
|
||||
ops.def(
|
||||
"convert_vertical_slash_indexes("
|
||||
" Tensor! block_count, Tensor! block_offset, "
|
||||
|
||||
@ -20,7 +20,11 @@ def merge_attn_states(
|
||||
num_query_heads = output.shape[1]
|
||||
head_size = output.shape[2]
|
||||
padded_head_size = triton.next_power_of_2(head_size)
|
||||
|
||||
# We assume the output stride on num_head is not always as same as the
|
||||
# `suffix_output` and `prefix_output`, as them might be padded by the attention
|
||||
# backend.
|
||||
prefix_head_stride = prefix_output.stride(1)
|
||||
output_head_stride = output.stride(1)
|
||||
# TODO(woosuk): Use CUDA kernel instead of Triton to minimize CPU overhead.
|
||||
merge_attn_states_kernel[(num_tokens, num_query_heads)](
|
||||
output,
|
||||
@ -29,6 +33,8 @@ def merge_attn_states(
|
||||
prefix_lse,
|
||||
suffix_output,
|
||||
suffix_lse,
|
||||
prefix_head_stride,
|
||||
output_head_stride,
|
||||
head_size,
|
||||
padded_head_size,
|
||||
output_lse is not None,
|
||||
@ -43,6 +49,8 @@ def merge_attn_states_kernel(
|
||||
prefix_lse, # [NUM_HEADS, NUM_TOKENS]
|
||||
suffix_output, # [NUM_TOKENS, NUM_HEADS, HEAD_SIZE]
|
||||
suffix_lse, # [NUM_HEADS, NUM_TOKENS]
|
||||
prefix_head_stride,
|
||||
output_head_stride,
|
||||
HEAD_SIZE: tl.constexpr,
|
||||
PADDED_HEAD_SIZE: tl.constexpr,
|
||||
OUTPUT_LSE: tl.constexpr,
|
||||
@ -79,15 +87,15 @@ def merge_attn_states_kernel(
|
||||
head_mask = head_arange < HEAD_SIZE
|
||||
p_out = tl.load(
|
||||
prefix_output
|
||||
+ token_idx * num_heads * HEAD_SIZE
|
||||
+ head_idx * HEAD_SIZE
|
||||
+ token_idx * num_heads * prefix_head_stride
|
||||
+ head_idx * prefix_head_stride
|
||||
+ head_arange,
|
||||
mask=head_mask,
|
||||
)
|
||||
s_out = tl.load(
|
||||
suffix_output
|
||||
+ token_idx * num_heads * HEAD_SIZE
|
||||
+ head_idx * HEAD_SIZE
|
||||
+ token_idx * num_heads * prefix_head_stride
|
||||
+ head_idx * prefix_head_stride
|
||||
+ head_arange,
|
||||
mask=head_mask,
|
||||
)
|
||||
@ -99,7 +107,10 @@ def merge_attn_states_kernel(
|
||||
s_scale = s_se / out_se
|
||||
out = p_out * p_scale + s_out * s_scale
|
||||
tl.store(
|
||||
output + token_idx * num_heads * HEAD_SIZE + head_idx * HEAD_SIZE + head_arange,
|
||||
output
|
||||
+ token_idx * num_heads * output_head_stride
|
||||
+ head_idx * output_head_stride
|
||||
+ head_arange,
|
||||
out,
|
||||
mask=head_mask,
|
||||
)
|
||||
|
||||
@ -1,7 +1,6 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import hashlib
|
||||
import inspect
|
||||
import os
|
||||
import pickle
|
||||
@ -14,6 +13,7 @@ import vllm.envs as envs
|
||||
from vllm.config import VllmConfig, get_current_vllm_config
|
||||
from vllm.config.utils import hash_factors
|
||||
from vllm.logger import init_logger
|
||||
from vllm.utils.hashing import safe_hash
|
||||
|
||||
try:
|
||||
from torch._dynamo.aot_compile import SerializableCallable
|
||||
@ -160,7 +160,7 @@ def _compute_code_hash_with_content(file_contents: dict[str, str]) -> str:
|
||||
# e.g. exec(). We can't actually check these.
|
||||
continue
|
||||
hash_content.append(content)
|
||||
return hashlib.md5(
|
||||
return safe_hash(
|
||||
"\n".join(hash_content).encode(), usedforsecurity=False
|
||||
).hexdigest()
|
||||
|
||||
|
||||
@ -2,7 +2,6 @@
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import contextlib
|
||||
import copy
|
||||
import hashlib
|
||||
import os
|
||||
from collections.abc import Callable
|
||||
from contextlib import ExitStack
|
||||
@ -16,6 +15,7 @@ import torch.fx as fx
|
||||
import vllm.envs as envs
|
||||
from vllm.compilation.counter import compilation_counter
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.utils.hashing import safe_hash
|
||||
from vllm.utils.torch_utils import is_torch_equal_or_newer
|
||||
|
||||
|
||||
@ -197,9 +197,9 @@ class InductorStandaloneAdaptor(CompilerInterface):
|
||||
|
||||
def compute_hash(self, vllm_config: VllmConfig) -> str:
|
||||
factors = get_inductor_factors()
|
||||
hash_str = hashlib.md5(
|
||||
str(factors).encode(), usedforsecurity=False
|
||||
).hexdigest()[:10]
|
||||
hash_str = safe_hash(str(factors).encode(), usedforsecurity=False).hexdigest()[
|
||||
:10
|
||||
]
|
||||
return hash_str
|
||||
|
||||
def initialize_cache(
|
||||
@ -286,9 +286,9 @@ class InductorAdaptor(CompilerInterface):
|
||||
|
||||
def compute_hash(self, vllm_config: VllmConfig) -> str:
|
||||
factors = get_inductor_factors()
|
||||
hash_str = hashlib.md5(
|
||||
str(factors).encode(), usedforsecurity=False
|
||||
).hexdigest()[:10]
|
||||
hash_str = safe_hash(str(factors).encode(), usedforsecurity=False).hexdigest()[
|
||||
:10
|
||||
]
|
||||
return hash_str
|
||||
|
||||
def initialize_cache(
|
||||
|
||||
@ -1,7 +1,6 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import hashlib
|
||||
from dataclasses import field
|
||||
from typing import Any, Literal
|
||||
|
||||
@ -10,6 +9,7 @@ from pydantic import ConfigDict, SkipValidation
|
||||
from pydantic.dataclasses import dataclass
|
||||
|
||||
from vllm.config.utils import config
|
||||
from vllm.utils.hashing import safe_hash
|
||||
|
||||
Device = Literal["auto", "cuda", "cpu", "tpu", "xpu"]
|
||||
|
||||
@ -45,7 +45,7 @@ class DeviceConfig:
|
||||
# the device/platform information will be summarized
|
||||
# by torch/vllm automatically.
|
||||
factors: list[Any] = []
|
||||
hash_str = hashlib.md5(str(factors).encode(), usedforsecurity=False).hexdigest()
|
||||
hash_str = safe_hash(str(factors).encode(), usedforsecurity=False).hexdigest()
|
||||
return hash_str
|
||||
|
||||
def __post_init__(self):
|
||||
|
||||
@ -1,7 +1,6 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import hashlib
|
||||
import uuid
|
||||
from dataclasses import field
|
||||
from typing import Any, Literal, get_args
|
||||
@ -9,6 +8,7 @@ from typing import Any, Literal, get_args
|
||||
from pydantic.dataclasses import dataclass
|
||||
|
||||
from vllm.config.utils import config
|
||||
from vllm.utils.hashing import safe_hash
|
||||
|
||||
KVProducer = Literal["kv_producer", "kv_both"]
|
||||
KVConsumer = Literal["kv_consumer", "kv_both"]
|
||||
@ -79,7 +79,7 @@ class KVTransferConfig:
|
||||
# no factors to consider.
|
||||
# this config will not affect the computation graph.
|
||||
factors: list[Any] = []
|
||||
hash_str = hashlib.md5(str(factors).encode(), usedforsecurity=False).hexdigest()
|
||||
hash_str = safe_hash(str(factors).encode(), usedforsecurity=False).hexdigest()
|
||||
return hash_str
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
|
||||
@ -1,7 +1,6 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import hashlib
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
from pydantic import Field, field_validator
|
||||
@ -9,6 +8,7 @@ from pydantic.dataclasses import dataclass
|
||||
|
||||
from vllm.config.utils import config
|
||||
from vllm.logger import init_logger
|
||||
from vllm.utils.hashing import safe_hash
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from vllm.model_executor.model_loader import LoadFormats
|
||||
@ -104,7 +104,7 @@ class LoadConfig:
|
||||
# no factors to consider.
|
||||
# this config will not affect the computation graph.
|
||||
factors: list[Any] = []
|
||||
hash_str = hashlib.md5(str(factors).encode(), usedforsecurity=False).hexdigest()
|
||||
hash_str = safe_hash(str(factors).encode(), usedforsecurity=False).hexdigest()
|
||||
return hash_str
|
||||
|
||||
@field_validator("load_format", mode="after")
|
||||
|
||||
@ -1,7 +1,6 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import hashlib
|
||||
from typing import TYPE_CHECKING, Any, Literal
|
||||
|
||||
import torch
|
||||
@ -11,6 +10,7 @@ from typing_extensions import Self
|
||||
|
||||
from vllm.config.utils import config
|
||||
from vllm.logger import init_logger
|
||||
from vllm.utils.hashing import safe_hash
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from vllm.config import ModelConfig
|
||||
@ -74,7 +74,7 @@ class LoRAConfig:
|
||||
factors.append(self.fully_sharded_loras)
|
||||
factors.append(self.lora_dtype)
|
||||
|
||||
hash_str = hashlib.md5(str(factors).encode(), usedforsecurity=False).hexdigest()
|
||||
hash_str = safe_hash(str(factors).encode(), usedforsecurity=False).hexdigest()
|
||||
return hash_str
|
||||
|
||||
@model_validator(mode="after")
|
||||
|
||||
@ -1,7 +1,6 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import hashlib
|
||||
from collections.abc import Mapping
|
||||
from typing import TYPE_CHECKING, Any, Literal, TypeAlias
|
||||
|
||||
@ -9,6 +8,7 @@ from pydantic import ConfigDict, Field, field_validator, model_validator
|
||||
from pydantic.dataclasses import dataclass
|
||||
|
||||
from vllm.config.utils import config
|
||||
from vllm.utils.hashing import safe_hash
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from vllm.attention.backends.registry import AttentionBackendEnum
|
||||
@ -216,7 +216,7 @@ class MultiModalConfig:
|
||||
if self.mm_encoder_attn_backend is not None
|
||||
else None
|
||||
]
|
||||
hash_str = hashlib.md5(str(factors).encode(), usedforsecurity=False).hexdigest()
|
||||
hash_str = safe_hash(str(factors).encode(), usedforsecurity=False).hexdigest()
|
||||
return hash_str
|
||||
|
||||
def get_limit_per_prompt(self, modality: str) -> int:
|
||||
|
||||
@ -1,7 +1,6 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import hashlib
|
||||
from functools import cached_property
|
||||
from typing import Any, Literal, cast
|
||||
|
||||
@ -11,6 +10,7 @@ from pydantic.dataclasses import dataclass
|
||||
|
||||
from vllm import version
|
||||
from vllm.config.utils import config
|
||||
from vllm.utils.hashing import safe_hash
|
||||
|
||||
DetailedTraceModules = Literal["model", "worker", "all"]
|
||||
|
||||
@ -78,7 +78,7 @@ class ObservabilityConfig:
|
||||
# no factors to consider.
|
||||
# this config will not affect the computation graph.
|
||||
factors: list[Any] = []
|
||||
hash_str = hashlib.md5(str(factors).encode(), usedforsecurity=False).hexdigest()
|
||||
hash_str = safe_hash(str(factors).encode(), usedforsecurity=False).hexdigest()
|
||||
return hash_str
|
||||
|
||||
@field_validator("show_hidden_metrics_for_version")
|
||||
|
||||
@ -593,9 +593,10 @@ class ParallelConfig:
|
||||
"max_parallel_loading_workers is currently "
|
||||
"not supported and will be ignored."
|
||||
)
|
||||
if self.distributed_executor_backend != "mp" and self.nnodes > 1:
|
||||
if self.distributed_executor_backend not in ("mp", "uni") and self.nnodes > 1:
|
||||
raise ValueError(
|
||||
"nnodes > 1 can only be set when distributed exectuor backend is mp."
|
||||
"nnodes > 1 can only be set when distributed executor "
|
||||
"backend is mp or uni."
|
||||
)
|
||||
|
||||
@property
|
||||
|
||||
@ -1,13 +1,13 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import hashlib
|
||||
from typing import Any
|
||||
|
||||
from pydantic.dataclasses import dataclass
|
||||
|
||||
from vllm.config.utils import config
|
||||
from vllm.logger import init_logger
|
||||
from vllm.utils.hashing import safe_hash
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
@ -102,7 +102,7 @@ class PoolerConfig:
|
||||
# no factors to consider.
|
||||
# this config will not affect the computation graph.
|
||||
factors: list[Any] = []
|
||||
hash_str = hashlib.md5(str(factors).encode(), usedforsecurity=False).hexdigest()
|
||||
hash_str = safe_hash(str(factors).encode(), usedforsecurity=False).hexdigest()
|
||||
return hash_str
|
||||
|
||||
|
||||
|
||||
@ -1,7 +1,6 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import hashlib
|
||||
from collections.abc import Callable
|
||||
from dataclasses import InitVar
|
||||
from typing import TYPE_CHECKING, Any, ClassVar, Literal, cast
|
||||
@ -12,6 +11,7 @@ from typing_extensions import Self, deprecated
|
||||
|
||||
from vllm.config.utils import config
|
||||
from vllm.logger import init_logger
|
||||
from vllm.utils.hashing import safe_hash
|
||||
from vllm.utils.import_utils import resolve_obj_by_qualname
|
||||
|
||||
if TYPE_CHECKING:
|
||||
@ -178,7 +178,7 @@ class SchedulerConfig:
|
||||
# no factors to consider.
|
||||
# this config will not affect the computation graph.
|
||||
factors: list[Any] = []
|
||||
hash_str = hashlib.md5(str(factors).encode(), usedforsecurity=False).hexdigest()
|
||||
hash_str = safe_hash(str(factors).encode(), usedforsecurity=False).hexdigest()
|
||||
return hash_str
|
||||
|
||||
@field_validator("scheduler_cls", "async_scheduling", mode="wrap")
|
||||
|
||||
@ -2,7 +2,6 @@
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import ast
|
||||
import hashlib
|
||||
from typing import TYPE_CHECKING, Any, Literal, get_args
|
||||
|
||||
from pydantic import Field, SkipValidation, model_validator
|
||||
@ -13,6 +12,7 @@ from vllm.config.model import ModelConfig
|
||||
from vllm.config.parallel import ParallelConfig
|
||||
from vllm.config.utils import config
|
||||
from vllm.logger import init_logger
|
||||
from vllm.utils.hashing import safe_hash
|
||||
from vllm.utils.import_utils import LazyLoader, has_arctic_inference
|
||||
|
||||
if TYPE_CHECKING:
|
||||
@ -162,7 +162,7 @@ class SpeculativeConfig:
|
||||
# Eagle3 affects the computation graph because it returns intermediate
|
||||
# hidden states in addition to the final hidden state.
|
||||
factors.append(self.method == "eagle3")
|
||||
hash_str = hashlib.md5(str(factors).encode(), usedforsecurity=False).hexdigest()
|
||||
hash_str = safe_hash(str(factors).encode(), usedforsecurity=False).hexdigest()
|
||||
return hash_str
|
||||
|
||||
@staticmethod
|
||||
|
||||
@ -1,7 +1,6 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import hashlib
|
||||
from typing import Any, Literal
|
||||
|
||||
from pydantic import model_validator
|
||||
@ -9,6 +8,7 @@ from pydantic.dataclasses import dataclass
|
||||
from typing_extensions import Self
|
||||
|
||||
from vllm.config.utils import config
|
||||
from vllm.utils.hashing import safe_hash
|
||||
|
||||
StructuredOutputsBackend = Literal[
|
||||
"auto", "xgrammar", "guidance", "outlines", "lm-format-enforcer"
|
||||
@ -58,7 +58,7 @@ class StructuredOutputsConfig:
|
||||
# no factors to consider.
|
||||
# this config will not affect the computation graph.
|
||||
factors: list[Any] = []
|
||||
hash_str = hashlib.md5(str(factors).encode(), usedforsecurity=False).hexdigest()
|
||||
hash_str = safe_hash(str(factors).encode(), usedforsecurity=False).hexdigest()
|
||||
return hash_str
|
||||
|
||||
@model_validator(mode="after")
|
||||
|
||||
@ -3,7 +3,6 @@
|
||||
|
||||
import copy
|
||||
import getpass
|
||||
import hashlib
|
||||
import json
|
||||
import os
|
||||
import tempfile
|
||||
@ -25,6 +24,7 @@ from vllm.config.speculative import EagleModelTypes
|
||||
from vllm.logger import enable_trace_function_call, init_logger
|
||||
from vllm.transformers_utils.runai_utils import is_runai_obj_uri
|
||||
from vllm.utils import random_uuid
|
||||
from vllm.utils.hashing import safe_hash
|
||||
|
||||
from .cache import CacheConfig
|
||||
from .compilation import CompilationConfig, CompilationMode, CUDAGraphMode
|
||||
@ -193,7 +193,7 @@ class VllmConfig:
|
||||
vllm_factors.append("None")
|
||||
if self.additional_config:
|
||||
if isinstance(additional_config := self.additional_config, dict):
|
||||
additional_config_hash = hashlib.md5(
|
||||
additional_config_hash = safe_hash(
|
||||
json.dumps(additional_config, sort_keys=True).encode(),
|
||||
usedforsecurity=False,
|
||||
).hexdigest()
|
||||
@ -204,9 +204,9 @@ class VllmConfig:
|
||||
vllm_factors.append("None")
|
||||
factors.append(vllm_factors)
|
||||
|
||||
hash_str = hashlib.md5(
|
||||
str(factors).encode(), usedforsecurity=False
|
||||
).hexdigest()[:10]
|
||||
hash_str = safe_hash(str(factors).encode(), usedforsecurity=False).hexdigest()[
|
||||
:10
|
||||
]
|
||||
return hash_str
|
||||
|
||||
def pad_for_cudagraph(self, batch_size: int) -> int:
|
||||
|
||||
@ -1,6 +1,5 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import hashlib
|
||||
import os
|
||||
from dataclasses import dataclass, field
|
||||
from typing import TYPE_CHECKING, Any, Optional
|
||||
@ -15,6 +14,7 @@ from vllm.distributed.kv_transfer.kv_connector.v1.base import (
|
||||
KVConnectorRole,
|
||||
)
|
||||
from vllm.logger import init_logger
|
||||
from vllm.utils.hashing import safe_hash
|
||||
from vllm.v1.attention.backends.mla.common import MLACommonMetadata
|
||||
from vllm.v1.core.sched.output import SchedulerOutput
|
||||
|
||||
@ -423,7 +423,7 @@ class SharedStorageConnector(KVConnectorBase_V1):
|
||||
if mm_hashes:
|
||||
mm_str = "-".join(mm_hashes)
|
||||
token_bytes += mm_str.encode("utf-8")
|
||||
input_ids_hash = hashlib.md5(token_bytes, usedforsecurity=False).hexdigest()
|
||||
input_ids_hash = safe_hash(token_bytes, usedforsecurity=False).hexdigest()
|
||||
|
||||
foldername = os.path.join(self._storage_path, input_ids_hash)
|
||||
if create_folder:
|
||||
|
||||
@ -51,6 +51,7 @@ from vllm.distributed.utils import StatelessProcessGroup
|
||||
from vllm.logger import init_logger
|
||||
from vllm.utils.import_utils import resolve_obj_by_qualname
|
||||
from vllm.utils.network_utils import get_distributed_init_method
|
||||
from vllm.utils.system_utils import suppress_stdout
|
||||
from vllm.utils.torch_utils import (
|
||||
direct_register_custom_op,
|
||||
supports_custom_op,
|
||||
@ -329,7 +330,8 @@ class GroupCoordinator:
|
||||
)
|
||||
# a group with `gloo` backend, to allow direct coordination between
|
||||
# processes through the CPU.
|
||||
cpu_group = torch.distributed.new_group(ranks, backend="gloo")
|
||||
with suppress_stdout():
|
||||
cpu_group = torch.distributed.new_group(ranks, backend="gloo")
|
||||
if self.rank in ranks:
|
||||
self.ranks = ranks
|
||||
self.world_size = len(ranks)
|
||||
|
||||
@ -30,6 +30,7 @@ from torch.distributed.rendezvous import rendezvous
|
||||
import vllm.envs as envs
|
||||
from vllm.logger import init_logger
|
||||
from vllm.utils.network_utils import get_tcp_uri
|
||||
from vllm.utils.system_utils import suppress_stdout
|
||||
from vllm.utils.torch_utils import is_torch_equal_or_newer
|
||||
|
||||
logger = init_logger(__name__)
|
||||
@ -427,33 +428,34 @@ def init_gloo_process_group(
|
||||
Stateless init ProcessGroup with gloo backend compatible with
|
||||
different torch versions.
|
||||
"""
|
||||
if is_torch_equal_or_newer("2.6"):
|
||||
pg = ProcessGroup(
|
||||
prefix_store,
|
||||
group_rank,
|
||||
group_size,
|
||||
)
|
||||
else:
|
||||
options = ProcessGroup.Options(backend="gloo")
|
||||
pg = ProcessGroup(
|
||||
prefix_store,
|
||||
group_rank,
|
||||
group_size,
|
||||
options,
|
||||
)
|
||||
from torch.distributed.distributed_c10d import ProcessGroupGloo
|
||||
with suppress_stdout():
|
||||
if is_torch_equal_or_newer("2.6"):
|
||||
pg = ProcessGroup(
|
||||
prefix_store,
|
||||
group_rank,
|
||||
group_size,
|
||||
)
|
||||
else:
|
||||
options = ProcessGroup.Options(backend="gloo")
|
||||
pg = ProcessGroup(
|
||||
prefix_store,
|
||||
group_rank,
|
||||
group_size,
|
||||
options,
|
||||
)
|
||||
from torch.distributed.distributed_c10d import ProcessGroupGloo
|
||||
|
||||
backend_class = ProcessGroupGloo(
|
||||
prefix_store, group_rank, group_size, timeout=timeout
|
||||
)
|
||||
backend_type = ProcessGroup.BackendType.GLOO
|
||||
device = torch.device("cpu")
|
||||
if is_torch_equal_or_newer("2.6"):
|
||||
# _set_default_backend is supported in torch >= 2.6
|
||||
pg._set_default_backend(backend_type)
|
||||
backend_class._set_sequence_number_for_group()
|
||||
backend_class = ProcessGroupGloo(
|
||||
prefix_store, group_rank, group_size, timeout=timeout
|
||||
)
|
||||
backend_type = ProcessGroup.BackendType.GLOO
|
||||
device = torch.device("cpu")
|
||||
if is_torch_equal_or_newer("2.6"):
|
||||
# _set_default_backend is supported in torch >= 2.6
|
||||
pg._set_default_backend(backend_type)
|
||||
backend_class._set_sequence_number_for_group()
|
||||
|
||||
pg._register_backend(device, backend_type, backend_class)
|
||||
pg._register_backend(device, backend_type, backend_class)
|
||||
return pg
|
||||
|
||||
|
||||
|
||||
@ -1,6 +1,7 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
|
||||
import torch
|
||||
|
||||
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
|
||||
@ -12,6 +13,7 @@ from vllm.model_executor.layers.fused_moe.config import (
|
||||
from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import (
|
||||
TopKWeightAndReduceNoOP,
|
||||
)
|
||||
from vllm.model_executor.layers.fused_moe.utils import _resize_cache
|
||||
from vllm.triton_utils import tl, triton
|
||||
from vllm.utils.import_utils import has_triton_kernels
|
||||
|
||||
@ -88,14 +90,17 @@ def triton_kernel_moe_forward(
|
||||
gating_output, topk, sm_first=not renormalize
|
||||
)
|
||||
|
||||
output = torch.empty_like(hidden_states)
|
||||
|
||||
return triton_kernel_fused_experts(
|
||||
None,
|
||||
output,
|
||||
hidden_states,
|
||||
w1,
|
||||
w2,
|
||||
routing_data,
|
||||
gather_idx,
|
||||
scatter_idx,
|
||||
topk=topk,
|
||||
activation=activation,
|
||||
quant_config=quant_config,
|
||||
apply_router_weight_on_input=apply_router_weight_on_input,
|
||||
@ -113,6 +118,7 @@ def triton_kernel_fused_experts(
|
||||
routing_data, # RoutingData
|
||||
gather_indx, # GatherIndx
|
||||
scatter_indx, # ScatterIndx
|
||||
topk: int,
|
||||
activation: str = "silu",
|
||||
quant_config: FusedMoEQuantConfig | None = None,
|
||||
swiglu_alpha: float = 1.702,
|
||||
@ -120,6 +126,7 @@ def triton_kernel_fused_experts(
|
||||
apply_router_weight_on_input: bool = False,
|
||||
global_num_experts: int = -1,
|
||||
expert_map: torch.Tensor | None = None,
|
||||
intermediate_cache: torch.Tensor | None = None,
|
||||
a1q_scale: torch.Tensor | None = None,
|
||||
) -> torch.Tensor:
|
||||
if quant_config is None:
|
||||
@ -131,14 +138,30 @@ def triton_kernel_fused_experts(
|
||||
assert quant_config.w2_bias is None or quant_config.w2_bias.dtype == torch.float32
|
||||
|
||||
# Shape check, only check non-mxfp4
|
||||
assert hidden_states.ndim == 2
|
||||
assert hidden_states.shape[-1] == w1.shape[-2]
|
||||
assert w2.shape[-1] == w1.shape[1]
|
||||
|
||||
batch_dim = 1
|
||||
M, K = hidden_states.shape[-2:]
|
||||
E, _, N = w1.shape
|
||||
|
||||
if global_num_experts == -1:
|
||||
global_num_experts = E
|
||||
|
||||
if intermediate_cache is None:
|
||||
intermediate_cache = torch.empty(
|
||||
(batch_dim, M * topk, N // 2),
|
||||
device=hidden_states.device,
|
||||
dtype=hidden_states.dtype,
|
||||
)
|
||||
|
||||
# Add batch_dim to output buffer because matmul_ogs expects 3D output
|
||||
intermediate_cache = _resize_cache(
|
||||
intermediate_cache, (batch_dim, M * topk, N // 2)
|
||||
)
|
||||
output_tensor = _resize_cache(output_tensor, (batch_dim, M, K))
|
||||
|
||||
act = FusedActivation(
|
||||
FnSpecs("swiglu", triton_kernels.swiglu.swiglu_fn, ("alpha", "limit")),
|
||||
(swiglu_alpha, swiglu_limit),
|
||||
@ -146,7 +169,7 @@ def triton_kernel_fused_experts(
|
||||
)
|
||||
gammas = routing_data.gate_scal if routing_data else None
|
||||
|
||||
intermediate_cache1 = matmul_ogs(
|
||||
matmul_ogs(
|
||||
hidden_states,
|
||||
w1,
|
||||
quant_config.w1_bias,
|
||||
@ -155,10 +178,11 @@ def triton_kernel_fused_experts(
|
||||
precision_config=quant_config.w1_precision,
|
||||
gammas=gammas if apply_router_weight_on_input else None,
|
||||
fused_activation=act,
|
||||
y=intermediate_cache,
|
||||
)
|
||||
|
||||
intermediate_cache3 = matmul_ogs(
|
||||
intermediate_cache1,
|
||||
matmul_ogs(
|
||||
intermediate_cache.view(M * topk, N // 2),
|
||||
w2,
|
||||
quant_config.w2_bias,
|
||||
routing_data,
|
||||
@ -167,7 +191,8 @@ def triton_kernel_fused_experts(
|
||||
gammas=None if apply_router_weight_on_input else gammas,
|
||||
y=output_tensor,
|
||||
)
|
||||
return intermediate_cache3
|
||||
output_tensor = output_tensor.view(M, K)
|
||||
return output_tensor
|
||||
|
||||
|
||||
def make_routing_data(
|
||||
@ -221,6 +246,42 @@ class BaseOAITritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
|
||||
def supports_expert_map(self) -> bool:
|
||||
return True
|
||||
|
||||
def moe_problem_size(
|
||||
self,
|
||||
a1: torch.Tensor,
|
||||
w1: torch.Tensor,
|
||||
w2: torch.Tensor,
|
||||
topk_ids: torch.Tensor,
|
||||
) -> tuple[int, int, int, int, int]:
|
||||
"""
|
||||
Extract the MoE problem size from the given tensor arguments:
|
||||
- a: The hidden states, input to the MoE layer.
|
||||
- w1: The first set of expert weights.
|
||||
- w2: The second set of expert weights.
|
||||
- topk_ids: The topk ids.
|
||||
Note: extracting the problem shape from the weight and activation
|
||||
tensors is not obvious. It needs to be done this way specifically
|
||||
due to subtle issues with particular kernels, e.g. the int4 kernels
|
||||
divide the trailing dimension by two, so it's not "correct" to
|
||||
extract N or K from the trailing dimension of w1 or w2. Similarly,
|
||||
some kernels transpose the weights, so this needs to be kept in mind.
|
||||
Note: This implementation covers most cases. However, if experts
|
||||
require a specialized implementation, like MarlinExperts, they are free
|
||||
to override this function.
|
||||
"""
|
||||
assert w1.dim() == 3 and w2.dim() == 3
|
||||
E, _, N = w1.size()
|
||||
K = a1.size(-1)
|
||||
|
||||
assert a1.dim() == 2
|
||||
assert topk_ids.size(0) == a1.size(0), f"{topk_ids.size(0)} != {a1.size(0)}"
|
||||
M = a1.size(0)
|
||||
|
||||
assert topk_ids.dim() == 2
|
||||
topk = topk_ids.size(1)
|
||||
|
||||
return E, M, N, K, topk
|
||||
|
||||
def finalize_weight_and_reduce_impl(self) -> mk.TopKWeightAndReduce:
|
||||
# Weight application and reduction happens in the fused_experts kernel.
|
||||
return TopKWeightAndReduceNoOP()
|
||||
@ -263,8 +324,8 @@ class OAITritonExperts(BaseOAITritonExperts):
|
||||
expert_tokens_meta: mk.ExpertTokensMetadata | None,
|
||||
) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]]:
|
||||
# workspace are allocated inside the kernel
|
||||
workspace1 = (M, K)
|
||||
workspace2 = (0, 0)
|
||||
workspace1 = (0, 0)
|
||||
workspace2 = (M * topk, N // 2)
|
||||
output = (M, K)
|
||||
return (workspace1, workspace2, output)
|
||||
|
||||
@ -297,20 +358,21 @@ class OAITritonExperts(BaseOAITritonExperts):
|
||||
topk_ids, topk_weights, local_num_experts
|
||||
)
|
||||
|
||||
experts_output = triton_kernel_fused_experts(
|
||||
None,
|
||||
topk = topk_ids.size(1)
|
||||
triton_kernel_fused_experts(
|
||||
output,
|
||||
hidden_states,
|
||||
w1,
|
||||
w2,
|
||||
routing_data,
|
||||
gather_indx,
|
||||
scatter_indx,
|
||||
topk=topk,
|
||||
activation=activation,
|
||||
quant_config=self.quant_config,
|
||||
apply_router_weight_on_input=False,
|
||||
global_num_experts=local_num_experts,
|
||||
expert_map=None, # applied already
|
||||
intermediate_cache=workspace2,
|
||||
a1q_scale=a1q_scale,
|
||||
)
|
||||
|
||||
output.copy_(experts_output, non_blocking=True)
|
||||
|
||||
@ -5,7 +5,6 @@ Whenever you add an architecture to this page, please also update
|
||||
`tests/models/registry.py` with example HuggingFace models for it.
|
||||
"""
|
||||
|
||||
import hashlib
|
||||
import importlib
|
||||
import json
|
||||
import os
|
||||
@ -32,6 +31,7 @@ from vllm.config import (
|
||||
from vllm.logger import init_logger
|
||||
from vllm.logging_utils import logtime
|
||||
from vllm.transformers_utils.dynamic_module import try_get_class_from_dynamic_module
|
||||
from vllm.utils.hashing import safe_hash
|
||||
|
||||
from .interfaces import (
|
||||
has_inner_state,
|
||||
@ -654,7 +654,7 @@ class _LazyRegisteredModel(_BaseRegisteredModel):
|
||||
|
||||
if model_path.exists():
|
||||
with open(model_path, "rb") as f:
|
||||
module_hash = hashlib.md5(f.read(), usedforsecurity=False).hexdigest()
|
||||
module_hash = safe_hash(f.read(), usedforsecurity=False).hexdigest()
|
||||
|
||||
mi = self._load_modelinfo_from_cache(module_hash)
|
||||
if mi is not None:
|
||||
|
||||
@ -5,6 +5,7 @@ from __future__ import annotations
|
||||
|
||||
import hashlib
|
||||
import pickle
|
||||
from _hashlib import HASH, UnsupportedDigestmodError
|
||||
from collections.abc import Callable
|
||||
from typing import Any
|
||||
|
||||
@ -61,3 +62,20 @@ def get_hash_fn_by_name(hash_fn_name: str) -> Callable[[Any], bytes]:
|
||||
return sha256_cbor
|
||||
|
||||
raise ValueError(f"Unsupported hash function: {hash_fn_name}")
|
||||
|
||||
|
||||
def safe_hash(data: bytes, usedforsecurity: bool = True) -> HASH:
|
||||
"""Hash for configs, defaulting to md5 but falling back to sha256
|
||||
in FIPS constrained environments.
|
||||
|
||||
Args:
|
||||
data: bytes
|
||||
usedforsecurity: Whether the hash is used for security purposes
|
||||
|
||||
Returns:
|
||||
Hash object
|
||||
"""
|
||||
try:
|
||||
return hashlib.md5(data, usedforsecurity=usedforsecurity)
|
||||
except (UnsupportedDigestmodError, ValueError):
|
||||
return hashlib.sha256(data)
|
||||
|
||||
@ -56,6 +56,39 @@ def set_env_var(key: str, value: str) -> Iterator[None]:
|
||||
os.environ[key] = old
|
||||
|
||||
|
||||
@contextlib.contextmanager
|
||||
def suppress_stdout():
|
||||
"""
|
||||
Suppress stdout from C libraries at the file descriptor level.
|
||||
|
||||
Only suppresses stdout, not stderr, to preserve error messages.
|
||||
Suppression is disabled when VLLM_LOGGING_LEVEL is set to DEBUG.
|
||||
|
||||
Example:
|
||||
with suppress_stdout():
|
||||
# C library calls that would normally print to stdout
|
||||
torch.distributed.new_group(ranks, backend="gloo")
|
||||
"""
|
||||
# Don't suppress if logging level is DEBUG
|
||||
if envs.VLLM_LOGGING_LEVEL == "DEBUG":
|
||||
yield
|
||||
return
|
||||
|
||||
stdout_fd = sys.stdout.fileno()
|
||||
stdout_dup = os.dup(stdout_fd)
|
||||
devnull_fd = os.open(os.devnull, os.O_WRONLY)
|
||||
|
||||
try:
|
||||
sys.stdout.flush()
|
||||
os.dup2(devnull_fd, stdout_fd)
|
||||
yield
|
||||
finally:
|
||||
sys.stdout.flush()
|
||||
os.dup2(stdout_dup, stdout_fd)
|
||||
os.close(stdout_dup)
|
||||
os.close(devnull_fd)
|
||||
|
||||
|
||||
# File path utilities
|
||||
|
||||
|
||||
|
||||
@ -1238,15 +1238,13 @@ class MLACommonBaseImpl(MLAAttentionImpl[A], Generic[A]):
|
||||
def _v_up_proj(self, x: torch.Tensor, out: torch.Tensor):
|
||||
# Convert from (B, N, L) to (N, B, L)
|
||||
x = x.view(-1, self.num_heads, self.kv_lora_rank).transpose(0, 1)
|
||||
|
||||
if self.is_aiter_triton_fp8_bmm_enabled:
|
||||
out = out.view(-1, self.num_heads, self.v_head_dim)
|
||||
# Multiply + Transpose (N, B, L) x (N, L, V)->(N, B, V)->(B, N, V)
|
||||
x = rocm_aiter_ops.triton_fp8_bmm(
|
||||
x, self.W_V, self.W_V_scale, group_size=128, transpose_bm=True
|
||||
x, self.W_V, self.W_V_scale, group_size=128, transpose_bm=True, YQ=out
|
||||
)
|
||||
# Convert from (B, N, V) to (B, N * V)
|
||||
x = x.reshape(-1, self.num_heads * self.v_head_dim)
|
||||
# Copy result
|
||||
out.copy_(x)
|
||||
else:
|
||||
# Convert from (B, N * V) to (N, B, V)
|
||||
out = out.view(-1, self.num_heads, self.v_head_dim).transpose(0, 1)
|
||||
@ -1824,7 +1822,8 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]):
|
||||
kv_c_and_k_pe_cache: torch.Tensor,
|
||||
attn_metadata: MLACommonMetadata,
|
||||
k_scale: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
output: torch.Tensor,
|
||||
) -> None:
|
||||
# TODO (zyongye): Prefill function here
|
||||
assert attn_metadata.prefill is not None
|
||||
assert self.dcp_world_size is not None
|
||||
@ -1837,7 +1836,7 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]):
|
||||
|
||||
k = torch.cat((k_nope, k_pe.expand((*k_nope.shape[:-1], -1))), dim=-1)
|
||||
|
||||
output = self._run_prefill_new_tokens(
|
||||
output_prefill = self._run_prefill_new_tokens(
|
||||
prefill=attn_metadata.prefill,
|
||||
q=q,
|
||||
k=k,
|
||||
@ -1846,7 +1845,7 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]):
|
||||
)
|
||||
|
||||
if has_context:
|
||||
suffix_output, suffix_lse = output
|
||||
suffix_output, suffix_lse = output_prefill
|
||||
if self.dcp_world_size > 1:
|
||||
context_output, context_lse = (
|
||||
self._context_parallel_compute_prefill_context(
|
||||
@ -1862,7 +1861,12 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]):
|
||||
q, kv_c_and_k_pe_cache, attn_metadata, k_scale
|
||||
)
|
||||
|
||||
output = torch.empty_like(suffix_output)
|
||||
# unpad if necessary
|
||||
if self._pad_v:
|
||||
context_output = context_output[..., : v.shape[-1]]
|
||||
suffix_output = suffix_output[..., : v.shape[-1]]
|
||||
|
||||
output = output.view(-1, self.num_heads, self.v_head_dim)
|
||||
merge_attn_states(
|
||||
output=output,
|
||||
prefix_output=context_output,
|
||||
@ -1870,12 +1874,9 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]):
|
||||
suffix_output=suffix_output,
|
||||
suffix_lse=suffix_lse,
|
||||
)
|
||||
|
||||
# unpad if necessary
|
||||
if self._pad_v:
|
||||
output = output[..., : v.shape[-1]]
|
||||
|
||||
return output.flatten(start_dim=-2)
|
||||
else:
|
||||
output_prefill = output_prefill[..., : v.shape[-1]].flatten(start_dim=-2)
|
||||
output.copy_(output_prefill)
|
||||
|
||||
@abstractmethod
|
||||
def _forward_decode(
|
||||
@ -1970,13 +1971,14 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]):
|
||||
kv_cache = kv_cache.view(current_platform.fp8_dtype())
|
||||
|
||||
if has_prefill:
|
||||
output[num_decode_tokens:] = self._forward_prefill(
|
||||
self._forward_prefill(
|
||||
prefill_q,
|
||||
prefill_k_c_normed,
|
||||
prefill_k_pe,
|
||||
kv_cache,
|
||||
attn_metadata,
|
||||
layer._k_scale,
|
||||
output=output[num_decode_tokens:],
|
||||
)
|
||||
|
||||
if has_decode:
|
||||
|
||||
@ -2789,7 +2789,7 @@ class GPUModelRunner(
|
||||
# returns True. before returning early here we call
|
||||
# dummy run to ensure coordinate_batch_across_dp
|
||||
# is called into to avoid out of sync issues.
|
||||
self._dummy_run(1)
|
||||
self._dummy_run(self._get_num_input_tokens(1))
|
||||
if not has_kv_transfer_group():
|
||||
# Return empty ModelRunnerOutput if no work to do.
|
||||
return EMPTY_MODEL_RUNNER_OUTPUT
|
||||
@ -3460,6 +3460,10 @@ class GPUModelRunner(
|
||||
scope="local",
|
||||
)
|
||||
prepare_communication_buffer_for_model(self.model)
|
||||
if (drafter := getattr(self, "drafter", None)) and (
|
||||
drafter_model := getattr(drafter, "model", None)
|
||||
):
|
||||
prepare_communication_buffer_for_model(drafter_model)
|
||||
mm_config = self.model_config.multimodal_config
|
||||
self.is_multimodal_pruning_enabled = (
|
||||
supports_multimodal_pruning(self.get_model())
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user