Kernel-override Determinism [1/n] (#25603)
Signed-off-by: Bram Wasti <bwasti@meta.com>
This commit is contained in:
parent 4778b42660
commit dc48ba0c75
csrc/core/batch_invariant.hpp (new file, 16 lines)
@@ -0,0 +1,16 @@
#pragma once
#include <cstdlib>
#include <string>
#include <cctype>

namespace vllm {

// vllm_kernel_override_batch_invariant(); returns true
// if env VLLM_KERNEL_OVERRIDE_BATCH_INVARIANT=1
inline bool vllm_kernel_override_batch_invariant() {
  std::string env_key = "VLLM_KERNEL_OVERRIDE_BATCH_INVARIANT";
  const char* val = std::getenv(env_key.c_str());
  return (val && std::atoi(val) != 0) ? 1 : 0;
}

} // namespace vllm
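The header above only defines the switch; flipping it is a matter of exporting the environment variable before the engine starts. A minimal sketch (not part of the diff; the model name is borrowed from the test added later in this commit, everything else is illustrative):

```python
# Illustrative only: export the override flag before vLLM initializes so both
# the C++ helper above and the Python-side check added in this commit see it.
import os

os.environ["VLLM_KERNEL_OVERRIDE_BATCH_INVARIANT"] = "1"

from vllm import LLM, SamplingParams

llm = LLM(model="Qwen/Qwen3-1.7B", enforce_eager=True)
out = llm.generate(["There once was a "],
                   SamplingParams(temperature=0.0, max_tokens=32))
print(out[0].outputs[0].text)
```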
@@ -1,6 +1,7 @@
 #include "type_convert.cuh"
 #include "dispatch_utils.h"
 #include "cub_helpers.h"
+#include "core/batch_invariant.hpp"

 #include <torch/cuda.h>
 #include <c10/cuda/CUDAGuard.h>
@@ -413,7 +414,9 @@ void fused_add_rms_norm(torch::Tensor& input, // [..., hidden_size]
       wt_ptr % req_alignment_bytes == 0;
   bool offsets_are_multiple_of_vector_width =
       hidden_size % vector_width == 0 && input_stride % vector_width == 0;
-  if (ptrs_are_aligned && offsets_are_multiple_of_vector_width) {
+  bool batch_invariant_launch = vllm::vllm_kernel_override_batch_invariant();
+  if (ptrs_are_aligned && offsets_are_multiple_of_vector_width &&
+      !batch_invariant_launch) {
     LAUNCH_FUSED_ADD_RMS_NORM(8);
   } else {
     LAUNCH_FUSED_ADD_RMS_NORM(0);
@@ -459,7 +462,8 @@ void poly_norm(torch::Tensor& out, // [..., hidden_size]
   auto inp_ptr = reinterpret_cast<std::uintptr_t>(input.data_ptr());
   auto out_ptr = reinterpret_cast<std::uintptr_t>(out.data_ptr());
   bool ptrs_are_aligned = inp_ptr % 16 == 0 && out_ptr % 16 == 0;
-  if (ptrs_are_aligned && hidden_size % 8 == 0) {
+  bool batch_invariant_launch = vllm::vllm_kernel_override_batch_invariant();
+  if (ptrs_are_aligned && hidden_size % 8 == 0 && !batch_invariant_launch) {
     LAUNCH_FUSED_POLY_NORM(8);
   } else {
     LAUNCH_FUSED_POLY_NORM(0);
@@ -9,6 +9,7 @@
 #include "quantization/fp8/common.cuh"
 #include "dispatch_utils.h"
 #include "cub_helpers.h"
+#include "core/batch_invariant.hpp"

 #include <torch/cuda.h>
 #include <c10/cuda/CUDAGuard.h>
@@ -240,7 +241,9 @@ void fused_add_rms_norm_static_fp8_quant(
   auto wt_ptr = reinterpret_cast<std::uintptr_t>(weight.data_ptr());
   bool ptrs_are_aligned =
       inp_ptr % 16 == 0 && res_ptr % 16 == 0 && wt_ptr % 16 == 0;
-  if (ptrs_are_aligned && hidden_size % 8 == 0 && input_stride % 8 == 0) {
+  bool batch_invariant_launch = vllm::vllm_kernel_override_batch_invariant();
+  if (ptrs_are_aligned && hidden_size % 8 == 0 && input_stride % 8 == 0 &&
+      !batch_invariant_launch) {
     LAUNCH_FUSED_ADD_RMS_NORM(8);
   } else {
     LAUNCH_FUSED_ADD_RMS_NORM(0);
@@ -21,6 +21,7 @@
 #include <c10/cuda/CUDAGuard.h>
 #include "../cuda_compat.h"
 #include "../cub_helpers.h"
+#include "../core/batch_invariant.hpp"

 #define MAX(a, b) ((a) > (b) ? (a) : (b))
 #define MIN(a, b) ((a) < (b) ? (a) : (b))
@@ -405,7 +406,8 @@ void topkGatingSoftmaxLauncherHelper(const float* input, const bool* finished, f
   using Constants = detail::TopkConstants<EXPERTS, BYTES_PER_LDG, WARP_SIZE_PARAM>;
   static constexpr int VPT = Constants::VPT;
   static constexpr int ROWS_PER_WARP = Constants::ROWS_PER_WARP;
-  const int num_warps = (num_rows + ROWS_PER_WARP - 1) / ROWS_PER_WARP;
+  const bool batch_invariant_launch = vllm::vllm_kernel_override_batch_invariant();
+  const int num_warps = batch_invariant_launch ? 32 : (num_rows + ROWS_PER_WARP - 1) / ROWS_PER_WARP;
   const int num_blocks = (num_warps + WARPS_PER_TB - 1) / WARPS_PER_TB;

   dim3 block_dim(WARP_SIZE_PARAM, WARPS_PER_TB);
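The kernel changes above all follow one pattern: when the override flag is set, each launcher skips the configuration that depends on problem size (vector width 8, or a warp count derived from num_rows) and uses a fixed one instead, so the reduction and tiling order no longer shifts with batch size. A tiny illustration of why that ordering matters (not from the diff, just a generic floating-point check):

```python
# Floating-point addition is not associative, so summing the same values in a
# different grouping can change the low-order bits. Batch-size-dependent launch
# configurations change that grouping; a fixed configuration does not.
import torch

torch.manual_seed(0)
x = torch.randn(1 << 16, dtype=torch.float32)

s_flat = x.sum()
s_chunked = torch.stack([chunk.sum() for chunk in x.chunk(64)]).sum()

# The two sums may disagree in the last bits even though the inputs are identical.
print(s_flat.item(), s_chunked.item(), bool(torch.equal(s_flat, s_chunked)))
```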
tests/v1/generation/test_batch_invariance.py (new file, 290 lines)
@@ -0,0 +1,290 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import contextlib
import os
import random
import string

import pytest
import torch

from vllm import LLM, SamplingParams


def _random_prompt(min_words: int = 1024, max_words: int = 1024 * 2) -> str:
    # Lightweight random prompt generator to vary prompt lengths and content.
    vocab = [
        "alpha",
        "bravo",
        "charlie",
        "delta",
        "echo",
        "foxtrot",
        "golf",
        "hotel",
        "india",
        "juliet",
        "kilo",
        "lima",
        "mike",
        "november",
        "oscar",
        "papa",
        "quebec",
        "romeo",
        "sierra",
        "tango",
        "uniform",
        "victor",
        "whiskey",
        "xray",
        "yankee",
        "zulu",
    ]
    n = random.randint(min_words, max_words)
    words = random.choices(vocab, k=n)

    # Add some noise and punctuation variability
    if random.random() < 0.5:
        words[0] = words[0].capitalize()
    if random.random() < 0.2:
        words.append("".join(random.choices(string.ascii_lowercase, k=5)))
    punct = random.choice([".", "?", "!", "...", ""])
    return " ".join(words) + punct
@pytest.mark.timeout(1000)
def test_v1_generation_is_deterministic_across_batch_sizes_with_needle():
    """
    Ensures that the same request (the 'needle' prompt) yields identical output
    whether run alone (bs=1) or mixed into a larger batch (e.g., bs=64),
    using the high-level v1 LLM() API only (no manual batching).

    Strategy:
    - Create two LLM engines with identical config except max_num_seqs: 1 vs N.
    - Compute a baseline output for the needle prompt with the bs=1 engine.
    - For many trials, generate a batch (size N) where the needle appears at a
      random position among random filler prompts using the bs=N engine.
    - Track how many trials match vs mismatch, and report totals at the end.
      The test fails if any mismatches occur, but we still dump pass/fail
      counts.

    Notes:
    - Use seeded stochastic sampling with a fixed seed to test determinism.
    - Outputs are intentionally longer and sampled at higher temperature/top_p
      to produce a more random-sounding phrase, yet remain deterministic by
      seed.
    - Keep max_tokens and max_model_len bounded for speed and memory use.
    """
    random.seed(12345)

    # Allow overrides from environment (useful for CI tuning)
    # "facebook/opt-125m" is too small, doesn't reliably test determinism
    model = os.getenv("VLLM_TEST_MODEL", "Qwen/Qwen3-1.7B")
    num_trials = int(os.getenv("VLLM_NEEDLE_TRIALS", "5"))
    batch_size = int(os.getenv("VLLM_NEEDLE_BATCH_SIZE", "64"))
    assert batch_size >= 2, "Batch size should be >= 2 to mix needle."

    # Keep GPU memory usage low to avoid startup allocation failures.
    gpu_mem_util = float(os.getenv("VLLM_GPU_MEMORY_UTILIZATION", "0.3"))
    max_model_len = int(os.getenv("VLLM_MAX_MODEL_LEN", "4096"))
    swap_space_gb = int(os.getenv("VLLM_SWAP_SPACE_GB", "4"))

    # Sampling parameters: longer outputs with a more random-sounding
    # continuation, but still deterministic due to fixed seed.
    temperature = float(os.getenv("VLLM_NEEDLE_TEMPERATURE", "0.0"))
    top_p = float(os.getenv("VLLM_NEEDLE_TOP_P", "0.95"))
    max_tokens = int(os.getenv("VLLM_NEEDLE_MAX_TOKENS", "128"))

    sampling = SamplingParams(
        temperature=temperature,
        top_p=top_p,
        max_tokens=max_tokens,
        seed=20240919,
    )

    needle_prompt = "There once was a "

    llm_bs1 = None
    llm_bsN = None
    try:
        # Engine with bs=1 behavior
        llm_bs1 = LLM_with_max_seqs(
            model=model,
            max_num_seqs=1,
            gpu_memory_utilization=gpu_mem_util,
            max_model_len=max_model_len,
            swap_space=swap_space_gb,
        )

        # Baseline generation for the needle prompt alone.
        baseline_out = llm_bs1.generate([needle_prompt], sampling)
        assert len(baseline_out) == 1
        assert len(baseline_out[0].outputs) >= 1
        baseline_text = baseline_out[0].outputs[0].text

        # Engine with larger batch limit (e.g., 64)
        llm_bsN = LLM_with_max_seqs(
            model=model,
            max_num_seqs=batch_size,
            gpu_memory_utilization=gpu_mem_util,
            max_model_len=max_model_len,
            swap_space=swap_space_gb,
        )

        mismatches = 0

        for trial in range(num_trials):
            # Create a batch of size `batch_size` and insert the needle at
            # a random index
            prompts: list[str] = []
            needle_pos = random.randint(0, batch_size - 1)
            for i in range(batch_size):
                if i == needle_pos:
                    prompts.append(needle_prompt)
                else:
                    prompts.append(_random_prompt())

            # Generate with the larger-batch engine
            outputs = llm_bsN.generate(prompts, sampling)
            # Find the needle output by position
            needle_output = outputs[needle_pos]
            assert needle_output.prompt == needle_prompt
            assert len(needle_output.outputs) >= 1
            text = needle_output.outputs[0].text

            if text != baseline_text:
                mismatches += 1

        passes = num_trials - mismatches
        # Dump how many passed vs failed
        print(f"[determinism] total={num_trials}, passed={passes}, "
              f"failed={mismatches}, batch_size={batch_size}")

        if mismatches > 0:
            pytest.fail(
                f"Nondeterministic outputs detected: {mismatches} failed out "
                f"of {num_trials} trials (batch_size={batch_size}).")

    finally:
        # Ensure engines are shut down to free GPU/VRAM across test sessions
        if llm_bs1 is not None:
            with contextlib.suppress(Exception):
                llm_bs1.shutdown()
        if llm_bsN is not None:
            with contextlib.suppress(Exception):
                llm_bsN.shutdown()
def _extract_step_logprobs(request_output):
    if getattr(request_output, "outputs", None):
        inner = request_output.outputs[0]
        if hasattr(inner, "logprobs") and inner.logprobs is not None:
            t = torch.tensor(
                [
                    inner.logprobs[i][tid].logprob
                    for i, tid in enumerate(inner.token_ids)
                ],
                dtype=torch.float32,
            )
            return t

    return None
@pytest.mark.skipif(
    not torch.cuda.is_available(),
    reason="Requires CUDA to match production inference path.",
)
def test_logprobs_bitwise_batch_invariance_bs1_vs_bs2():

    #model_name = os.getenv("VLLM_TEST_MODEL", "facebook/opt-125m")
    model_name = os.getenv("VLLM_TEST_MODEL", "Qwen/Qwen3-1.7B")
    tp_size = int(os.getenv("VLLM_TEST_TP_SIZE", "1"))

    # Force float32 to avoid precision-induced differences.
    llm = LLM(
        model=model_name,
        tensor_parallel_size=tp_size,
        enforce_eager=True,  # helps reduce nondeterminism from some backends
    )

    prompts = [
        "The capital of France is",
        "The capital of Germany is",
    ]

    sp = SamplingParams(
        temperature=0.0,
        top_p=1.0,
        max_tokens=8,
        # Seed shouldn't matter at temperature=0, but keeping it stable anyway.
        seed=1234,
        logprobs=5,
    )

    # BS=1: run prompts individually and collect logprobs per step.
    bs1_logprobs_per_prompt = []
    for p in prompts:
        outs = llm.generate([p], sp, use_tqdm=False)
        assert len(outs) == 1
        step_logprobs = _extract_step_logprobs(outs[0])
        if step_logprobs is None:
            pytest.skip("Logits are not available on RequestOutput; "
                        "enable logprobs return to run this test.")
        bs1_logprobs_per_prompt.append(step_logprobs)

    # BS=2: run prompts in a batch and collect logprobs per step for each
    # prompt.
    outs_batched = llm.generate(prompts, sp, use_tqdm=False)
    assert len(outs_batched) == len(prompts)
    bs2_logprobs_per_prompt = []
    for o in outs_batched:
        step_logprobs = _extract_step_logprobs(o)
        if step_logprobs is None:
            pytest.skip("Logits are not available on RequestOutput; "
                        "enable logprobs return to run this test.")
        bs2_logprobs_per_prompt.append(step_logprobs)

    # Compare step-by-step logprobs for each prompt between BS=1 and BS=2 runs.
    for i, (logprobs_bs1, logprobs_bs2) in enumerate(
            zip(bs1_logprobs_per_prompt, bs2_logprobs_per_prompt)):
        assert len(logprobs_bs1) == len(logprobs_bs2), (
            f"Different number of generation steps for prompt index {i}: "
            f"{len(logprobs_bs1)} (BS=1) vs {len(logprobs_bs2)} (BS=2)")
        for t, (a, b) in enumerate(zip(logprobs_bs1, logprobs_bs2)):
            assert a.shape == b.shape, (
                f"Logits shape mismatch at prompt {i}, step {t}: "
                f"{a.shape} vs {b.shape}")
            # Bitwise exact equality.
            assert torch.equal(
                a, b), (f"Bitwise logprobs mismatch at prompt {i}, step {t} "
                        f"(dtype={a.dtype}, shape={a.shape}).")


def LLM_with_max_seqs(
    model: str,
    max_num_seqs: int,
    gpu_memory_utilization: float,
    max_model_len: int,
    swap_space: int,
) -> LLM:
    """
    Helper to construct an LLM with a specific max_num_seqs (batch-size limit)
    using the high-level v1 LLM API, while constraining memory usage.
    """
    return LLM(
        model=model,
        max_num_seqs=max_num_seqs,
        # Constrain GPU memory pool so test can run even on busy GPUs.
        gpu_memory_utilization=gpu_memory_utilization,
        # Keep KV cache footprint small while allowing longer outputs.
        max_model_len=max_model_len,
        # Allow some CPU offload if needed.
        swap_space=swap_space,
        # Keep things lean and CI-friendly.
        dtype="float16",
        # Single-GPU by default; override externally if desired.
        tensor_parallel_size=int(os.getenv("VLLM_TP_SIZE", "1")),
        trust_remote_code=os.getenv("VLLM_TRUST_REMOTE_CODE", "0") == "1",
    )
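The knobs in these tests are plain environment variables read at the top of each test function, so a local run can be driven entirely from the environment. A hedged sketch of one way to invoke them (only the variable names and the test path come from the diff; the values and the pytest invocation are illustrative):

```python
# Illustrative local runner for the new determinism tests.
import os

import pytest

os.environ.setdefault("VLLM_KERNEL_OVERRIDE_BATCH_INVARIANT", "1")
os.environ.setdefault("VLLM_TEST_MODEL", "Qwen/Qwen3-1.7B")
os.environ.setdefault("VLLM_NEEDLE_TRIALS", "2")
os.environ.setdefault("VLLM_NEEDLE_BATCH_SIZE", "16")

raise SystemExit(pytest.main(["-q", "tests/v1/generation/test_batch_invariance.py"]))
```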
vllm/model_executor/layers/batch_invariant.py (new file, 561 lines)
@@ -0,0 +1,561 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import contextlib
import os
from collections import namedtuple
from collections.abc import Callable
from typing import Any, Union

import torch
import triton
import triton.language as tl


def _matmul_launch_metadata(grid: Callable[..., Any], kernel: Any,
                            args: dict[str, Any]) -> dict[str, Any]:
    ret = {}
    m, n, k = args["M"], args["N"], args["K"]
    ret["name"] = f"{kernel.name} [M={m}, N={n}, K={k}]"
    if "tiles_per_update" in args:
        ret["name"] = (f"{kernel.name} [M={m}, N={n}, K={k}, "
                       f"tiles_per_update={args['tiles_per_update']:02}]")
    if "c_ptr" in args:
        bytes_per_elem = args["c_ptr"].element_size()
    else:
        bytes_per_elem = 1 if args["FP8_OUTPUT"] else 2
    ret[f"flops{bytes_per_elem * 8}"] = 2.0 * m * n * k
    ret["bytes"] = bytes_per_elem * (m * k + n * k + m * n)
    return ret
@triton.jit
def _compute_pid(tile_id, num_pid_in_group, num_pid_m, GROUP_SIZE_M, NUM_SMS):
    group_id = tile_id // num_pid_in_group
    first_pid_m = group_id * GROUP_SIZE_M
    group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
    pid_m = first_pid_m + (tile_id % group_size_m)
    pid_n = (tile_id % num_pid_in_group) // group_size_m
    return pid_m, pid_n


@triton.jit(launch_metadata=_matmul_launch_metadata)
def matmul_kernel_persistent(
    a_ptr,
    b_ptr,
    c_ptr,  #
    bias_ptr,
    M,
    N,
    K,  #
    stride_am,
    stride_ak,
    stride_bk,
    stride_bn,
    stride_cm,
    stride_cn,
    BLOCK_SIZE_M: tl.constexpr,  #
    BLOCK_SIZE_N: tl.constexpr,  #
    BLOCK_SIZE_K: tl.constexpr,  #
    GROUP_SIZE_M: tl.constexpr,  #
    NUM_SMS: tl.constexpr,  #
    A_LARGE: tl.constexpr,
    B_LARGE: tl.constexpr,
    C_LARGE: tl.constexpr,
    HAS_BIAS: tl.constexpr,
):
    start_pid = tl.program_id(axis=0)
    num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
    num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
    k_tiles = tl.cdiv(K, BLOCK_SIZE_K)
    num_tiles = num_pid_m * num_pid_n

    tile_id_c = start_pid - NUM_SMS

    offs_k_for_mask = tl.arange(0, BLOCK_SIZE_K)
    num_pid_in_group = GROUP_SIZE_M * num_pid_n

    for tile_id in tl.range(start_pid, num_tiles, NUM_SMS, flatten=True):
        pid_m, pid_n = _compute_pid(tile_id, num_pid_in_group, num_pid_m,
                                    GROUP_SIZE_M, NUM_SMS)
        start_m = pid_m * BLOCK_SIZE_M
        start_n = pid_n * BLOCK_SIZE_N
        offs_am = start_m + tl.arange(0, BLOCK_SIZE_M)
        offs_bn = start_n + tl.arange(0, BLOCK_SIZE_N)
        if A_LARGE:
            offs_am = offs_am.to(tl.int64)
        if B_LARGE:
            offs_bn = offs_bn.to(tl.int64)
        offs_am = tl.where(offs_am < M, offs_am, 0)
        offs_bn = tl.where(offs_bn < N, offs_bn, 0)
        offs_am = tl.max_contiguous(tl.multiple_of(offs_am, BLOCK_SIZE_M),
                                    BLOCK_SIZE_M)
        offs_bn = tl.max_contiguous(tl.multiple_of(offs_bn, BLOCK_SIZE_N),
                                    BLOCK_SIZE_N)

        accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
        for ki in range(k_tiles):
            if A_LARGE or B_LARGE:
                offs_k = ki * BLOCK_SIZE_K + tl.arange(0, BLOCK_SIZE_K).to(
                    tl.int64)
            else:
                offs_k = ki * BLOCK_SIZE_K + tl.arange(0, BLOCK_SIZE_K)
            a_ptrs = a_ptr + (offs_am[:, None] * stride_am +
                              offs_k[None, :] * stride_ak)
            b_ptrs = b_ptr + (offs_k[:, None] * stride_bk +
                              offs_bn[None, :] * stride_bn)

            a = tl.load(a_ptrs,
                        mask=offs_k_for_mask[None, :] < K - ki * BLOCK_SIZE_K,
                        other=0.0)
            b = tl.load(b_ptrs,
                        mask=offs_k_for_mask[:, None] < K - ki * BLOCK_SIZE_K,
                        other=0.0)
            accumulator = tl.dot(a, b, accumulator)

        tile_id_c += NUM_SMS
        pid_m, pid_n = _compute_pid(tile_id_c, num_pid_in_group, num_pid_m,
                                    GROUP_SIZE_M, NUM_SMS)
        offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
        offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
        if C_LARGE:
            offs_cm = offs_cm.to(tl.int64)
            offs_cn = offs_cn.to(tl.int64)
        c_ptrs = c_ptr + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[
            None, :]
        c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N)
        if HAS_BIAS:
            bias_ptrs = bias_ptr + offs_cn
            bias = tl.load(bias_ptrs, mask=offs_cn < N,
                           other=0.0).to(tl.float32)
            accumulator += bias
        if c_ptr.dtype.element_ty == tl.float8e4nv:
            c = accumulator.to(tl.float8e4nv)
        else:
            c = accumulator.to(tl.float16)
        tl.store(c_ptrs, c, mask=c_mask)
def matmul_persistent(a: torch.Tensor,
                      b: torch.Tensor,
                      bias: Union[torch.Tensor, None] = None):
    # Check constraints.
    assert a.shape[1] == b.shape[0], "Incompatible dimensions"
    assert a.dtype == b.dtype, "Incompatible dtypes"
    assert bias is None or bias.dim() == 1, (
        "Currently assuming bias is 1D, let Horace know if you run into this")
    NUM_SMS = torch.cuda.get_device_properties("cuda").multi_processor_count
    M, K = a.shape
    K, N = b.shape
    dtype = a.dtype
    # Allocates output.
    c = torch.empty((M, N), device=a.device, dtype=dtype)

    # 1D launch kernel where each block gets its own program.
    def grid(META):
        return (min(
            NUM_SMS,
            triton.cdiv(M, META["BLOCK_SIZE_M"]) *
            triton.cdiv(N, META["BLOCK_SIZE_N"])), )

    configs = {
        torch.bfloat16: {
            "BLOCK_SIZE_M": 128,
            "BLOCK_SIZE_N": 128,
            "BLOCK_SIZE_K": 64,
            "GROUP_SIZE_M": 8,
            "num_stages": 3,
            "num_warps": 8,
        },
        torch.float16: {
            "BLOCK_SIZE_M": 128,
            "BLOCK_SIZE_N": 256,
            "BLOCK_SIZE_K": 64,
            "GROUP_SIZE_M": 8,
            "num_stages": 3,
            "num_warps": 8,
        },
        torch.float32: {
            "BLOCK_SIZE_M": 128,
            "BLOCK_SIZE_N": 128,
            "BLOCK_SIZE_K": 32,
            "GROUP_SIZE_M": 8,
            "num_stages": 3,
            "num_warps": 8,
        },
    }
    # print(a.device, b.device, c.device)
    matmul_kernel_persistent[grid](
        a,
        b,
        c,  #
        bias,
        M,
        N,
        K,  #
        a.stride(0),
        a.stride(1),  #
        b.stride(0),
        b.stride(1),  #
        c.stride(0),
        c.stride(1),  #
        NUM_SMS=NUM_SMS,  #
        A_LARGE=a.numel() > 2**31,
        B_LARGE=b.numel() > 2**31,
        C_LARGE=c.numel() > 2**31,
        HAS_BIAS=bias is not None,
        **configs[dtype],
    )
    return c
@triton.jit
def _log_softmax_kernel(
    input_ptr,
    output_ptr,
    input_row_stride,
    output_row_stride,
    n_cols,
    BLOCK_SIZE: tl.constexpr,
):
    """
    Compute log_softmax along the last dimension of a 2D tensor.
    Each block handles one row of the input tensor.
    """
    # Get the row index for this block
    row_idx = tl.program_id(0).to(tl.int64)

    # Compute base pointers for input and output rows
    row_start_ptr = input_ptr + row_idx * input_row_stride
    output_row_start_ptr = output_ptr + row_idx * output_row_stride

    # Step 1: Find maximum value in the row for numerical stability
    max_val = -float("inf")
    for col_offset in range(0, n_cols, BLOCK_SIZE):
        col_idx = col_offset + tl.arange(0, BLOCK_SIZE)
        mask = col_idx < n_cols

        # Load values
        vals = tl.load(row_start_ptr + col_idx, mask=mask, other=-float("inf"))

        # Update maximum
        max_val = tl.max(tl.maximum(vals, max_val))

    # Step 2: Compute sum of exp(x - max_val)
    sum_exp = 0.0
    for col_offset in range(0, n_cols, BLOCK_SIZE):
        col_idx = col_offset + tl.arange(0, BLOCK_SIZE)
        mask = col_idx < n_cols

        # Load values
        vals = tl.load(row_start_ptr + col_idx, mask=mask, other=0.0)

        # Compute exp(x - max_val) and accumulate
        exp_vals = tl.exp(vals - max_val)
        sum_exp += tl.sum(tl.where(mask, exp_vals, 0.0))

    # Compute log(sum_exp)
    log_sum_exp = tl.log(sum_exp)

    # Step 3: Compute final log_softmax values: x - max_val - log_sum_exp
    for col_offset in range(0, n_cols, BLOCK_SIZE):
        col_idx = col_offset + tl.arange(0, BLOCK_SIZE)
        mask = col_idx < n_cols

        # Load values
        vals = tl.load(row_start_ptr + col_idx, mask=mask)

        # Compute log_softmax
        output = vals - max_val - log_sum_exp

        # Store results
        tl.store(output_row_start_ptr + col_idx, output, mask=mask)
def log_softmax(input: torch.Tensor, dim: int = -1) -> torch.Tensor:
    """
    Compute log_softmax using Triton kernel.

    Args:
        input: Input tensor
        dim: Dimension along which to compute log_softmax
            (only -1 or last dim supported)

    Returns:
        Tensor with log_softmax applied along the specified dimension
    """
    if dim != -1 and dim != input.ndim - 1:
        raise ValueError("This implementation only supports log_softmax along "
                         "the last dimension")

    # Flatten all dimensions except the last one
    original_shape = input.shape
    input_2d = input.reshape(-1, input.shape[-1])
    input_2d = input_2d.contiguous()

    n_rows, n_cols = input_2d.shape

    # Allocate output tensor
    output = torch.empty_like(input_2d)

    # Choose block size based on the number of columns
    BLOCK_SIZE = 1024

    # Launch kernel with one block per row
    grid = (n_rows, )
    _log_softmax_kernel[grid](
        input_2d,
        output,
        input_2d.stride(0),
        output.stride(0),
        n_cols,
        BLOCK_SIZE=BLOCK_SIZE,
    )
    # Reshape output back to original shape
    return output.reshape(original_shape)
@triton.jit
def mean_kernel(
    input_ptr,
    output_ptr,
    input_stride0,
    input_stride1,
    input_stride2,
    output_stride0,
    output_stride1,
    M,  # size before reduction dim
    N,  # size of reduction dim
    K,  # size after reduction dim
    BLOCK_SIZE: tl.constexpr,
):
    """
    Kernel for computing mean along a single dimension.
    Input is viewed as (M, N, K) where N is the dimension being reduced.
    """
    # Program ID gives us which output element we're computing
    pid = tl.program_id(0)

    # Compute output indices
    m_idx = pid // K
    k_idx = pid % K

    # Bounds check
    if m_idx >= M or k_idx >= K:
        return

    # Accumulate sum across reduction dimension
    acc = 0.0
    for n_start in range(0, N, BLOCK_SIZE):
        n_offsets = n_start + tl.arange(0, BLOCK_SIZE)
        mask = n_offsets < N

        # Calculate input indices
        input_idx = m_idx * input_stride0 + n_offsets * input_stride1 \
            + k_idx * input_stride2

        # Load and accumulate
        vals = tl.load(input_ptr + input_idx, mask=mask, other=0.0)
        acc += tl.sum(vals)

    # Compute mean and store
    mean_val = acc / N
    output_idx = m_idx * output_stride0 + k_idx * output_stride1
    tl.store(output_ptr + output_idx, mean_val)
def mean_dim(input: torch.Tensor,
             dim: int,
             keepdim: bool = False,
             dtype: Union[torch.dtype, None] = None) -> torch.Tensor:
    """
    Triton implementation of torch.mean with single dimension reduction.

    Args:
        input: Input tensor
        dim: Single dimension along which to compute mean
        keepdim: Whether to keep the reduced dimension
        dtype: Output dtype. If None, uses input dtype
            (or float32 for integer inputs)

    Returns:
        Tensor with mean values along specified dimension
    """
    # Validate inputs
    assert input.is_cuda, "Input must be a CUDA tensor"
    assert -input.ndim <= dim < input.ndim, (
        f"Invalid dimension {dim} for tensor with {input.ndim} dimensions")

    # Handle negative dim
    if dim < 0:
        dim = dim + input.ndim

    # Handle dtype
    if dtype is None:
        if input.dtype in [torch.int8, torch.int16, torch.int32, torch.int64]:
            dtype = torch.float32
        else:
            dtype = input.dtype

    # Convert input to appropriate dtype if needed
    if input.dtype != dtype:
        input = input.to(dtype)

    # Get input shape and strides
    shape = list(input.shape)

    # Calculate dimensions for kernel
    M = 1
    for i in range(dim):
        M *= shape[i]

    N = shape[dim]

    K = 1
    for i in range(dim + 1, len(shape)):
        K *= shape[i]

    # Reshape input to 3D view (M, N, K)
    input_3d = input.reshape(M, N, K)

    # Create output shape
    if keepdim:
        output_shape = shape.copy()
        output_shape[dim] = 1
    else:
        output_shape = shape[:dim] + shape[dim + 1:]

    # Create output tensor
    output = torch.empty(output_shape, dtype=dtype, device=input.device)

    # Reshape output for kernel
    if keepdim:
        output_2d = output.reshape(M, 1, K).squeeze(1)
    else:
        output_2d = output.reshape(M, K)

    # Launch kernel
    grid = (M * K, )
    BLOCK_SIZE = 1024

    mean_kernel[grid](
        input_3d,
        output_2d,
        input_3d.stride(0),
        input_3d.stride(1),
        input_3d.stride(2),
        output_2d.stride(0),
        output_2d.stride(1) if output_2d.ndim > 1 else 0,
        M,
        N,
        K,
        BLOCK_SIZE,
    )

    return output
def mm_batch_invariant(a, b):
    return matmul_persistent(a, b)


def addmm_batch_invariant(bias, a, b):
    return matmul_persistent(a, b, bias=bias)


def _log_softmax_batch_invariant(input, dim, _half_to_float):
    assert not _half_to_float, "not implemented"
    return log_softmax(input, dim=dim)


def mean_batch_invariant(input,
                         dim,
                         keepdim=False,
                         dtype: Union[torch.dtype, None] = None):
    assert dtype is None or dtype == torch.float32, \
        f"unsupported dtype: {dtype}"

    result = input.to(torch.float32)

    # Sort dimensions to reduce from largest to smallest to handle shifting dims
    # during iterative reduction.
    sorted_dims = sorted([d % input.ndim for d in dim], reverse=True)

    # Iteratively apply a deterministic mean.
    for d in sorted_dims:
        result = mean_dim(result, dim=d, keepdim=True)

    if not keepdim:
        # Squeeze the reduced dimensions.
        for d in sorted_dims:
            result = result.squeeze(d)

    return result
_batch_invariant_MODE = False
_batch_invariant_LIB = None


def is_batch_invariant_mode_enabled():
    return _batch_invariant_MODE


def enable_batch_invariant_mode():
    global _batch_invariant_MODE, _batch_invariant_LIB
    if _batch_invariant_MODE:
        return

    _batch_invariant_MODE = True
    _batch_invariant_LIB = torch.library.Library("aten", "IMPL")
    _batch_invariant_LIB.impl("aten::mm", mm_batch_invariant, "CUDA")
    _batch_invariant_LIB.impl("aten::addmm", addmm_batch_invariant, "CUDA")
    _batch_invariant_LIB.impl("aten::_log_softmax",
                              _log_softmax_batch_invariant, "CUDA")
    _batch_invariant_LIB.impl("aten::mean.dim", mean_batch_invariant, "CUDA")


def disable_batch_invariant_mode():
    global _batch_invariant_MODE, _batch_invariant_LIB
    if _batch_invariant_LIB is not None:
        _batch_invariant_LIB._destroy()
    _batch_invariant_MODE = False
    _batch_invariant_LIB = None


@contextlib.contextmanager
def set_batch_invariant_mode(enabled: bool = True):
    global _batch_invariant_MODE, _batch_invariant_LIB
    old_data = (_batch_invariant_MODE, _batch_invariant_LIB)
    if enabled:
        enable_batch_invariant_mode()
    else:
        disable_batch_invariant_mode()
    yield
    if _batch_invariant_LIB is not None:
        _batch_invariant_LIB._destroy()
    _batch_invariant_MODE, _batch_invariant_LIB = old_data


AttentionBlockSize = namedtuple("AttentionBlockSize", ["block_m", "block_n"])


def get_batch_invariant_attention_block_size() -> AttentionBlockSize:
    return AttentionBlockSize(block_m=16, block_n=16)


def vllm_kernel_override_batch_invariant():
    env_key = "VLLM_KERNEL_OVERRIDE_BATCH_INVARIANT"
    is_overridden = False
    val = os.getenv(env_key, "0")
    try:
        is_overridden = int(val) != 0
    except ValueError:
        is_overridden = False
    return is_overridden


def init_batch_invariance():
    # this will hit all the csrc overrides as well
    if vllm_kernel_override_batch_invariant():
        os.environ["VLLM_ATTENTION_BACKEND"] = "FLEX_ATTENTION"
        enable_batch_invariant_mode()
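As an aside (not part of the file above): the context manager gives a direct way to route a single matmul through the batch-invariant Triton kernel instead of the default cuBLAS path. A minimal usage sketch, assuming the module path added by this commit:

```python
# Sketch: inside the context, aten::mm on CUDA dispatches to matmul_persistent,
# so identical inputs always see the same tiling and reduction order.
import torch

from vllm.model_executor.layers.batch_invariant import set_batch_invariant_mode

a = torch.randn(128, 256, device="cuda", dtype=torch.float16)
b = torch.randn(256, 512, device="cuda", dtype=torch.float16)

with set_batch_invariant_mode():
    c_invariant = a @ b  # Triton persistent matmul
c_default = a @ b        # default cuBLAS path again

# The two results should be close, but they need not match bitwise.
print(torch.allclose(c_invariant, c_default, atol=1e-2, rtol=1e-2))
```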
@@ -18,6 +18,8 @@ from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
                                               is_quantized_kv_cache)
 from vllm.config import VllmConfig
 from vllm.logger import init_logger
+from vllm.model_executor.layers.batch_invariant import (
+    vllm_kernel_override_batch_invariant)
 from vllm.utils import cdiv, is_torch_equal_or_newer
 from vllm.v1.attention.backends.utils import (AttentionMetadataBuilder,
                                               CommonAttentionMetadata)
@@ -839,6 +841,11 @@ def get_kernel_options(query, block_m, block_n,
     kernel_options: dict[str, Union[int, bool]] = {
         "FORCE_USE_FLEX_ATTENTION": True,
     }
+    if vllm_kernel_override_batch_invariant():
+        kernel_options["BLOCK_M"] = 16
+        kernel_options["BLOCK_N"] = 16
+        kernel_options["IS_DIVISIBLE"] = False
+        return kernel_options
     if use_direct_build:
         kernel_options["BLOCK_M"] = block_m
         kernel_options["BLOCK_N"] = block_n
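Side note: the hard-coded 16x16 block shape in get_kernel_options mirrors get_batch_invariant_attention_block_size() from the new module; a small, illustrative consistency check:

```python
# The namedtuple from the new module carries the same 16x16 block shape that
# get_kernel_options pins when the override flag is set.
from vllm.model_executor.layers.batch_invariant import (
    get_batch_invariant_attention_block_size)

block = get_batch_invariant_attention_block_size()
assert (block.block_m, block.block_n) == (16, 16)
```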
@@ -192,6 +192,9 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
         from vllm.model_executor.models.utils import set_cpu_offload_max_bytes
         set_cpu_offload_max_bytes(
             int(self.cache_config.cpu_offload_gb * 1024**3))
+        from vllm.model_executor.layers.batch_invariant import (
+            init_batch_invariance)
+        init_batch_invariance()

         model_config = self.model_config
         cache_config = self.cache_config
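Taken together, the wiring is: GPUModelRunner calls init_batch_invariance() at startup, which checks the environment flag, pins the attention backend, and installs the aten overrides. A short sanity sketch of that behavior (illustrative, but it follows directly from the code above):

```python
# With the flag exported, init_batch_invariance() pins FlexAttention and
# enables the batch-invariant aten overrides.
import os

os.environ["VLLM_KERNEL_OVERRIDE_BATCH_INVARIANT"] = "1"

from vllm.model_executor.layers.batch_invariant import (
    init_batch_invariance, is_batch_invariant_mode_enabled)

init_batch_invariance()
assert os.environ["VLLM_ATTENTION_BACKEND"] == "FLEX_ATTENTION"
assert is_batch_invariant_mode_enabled()
```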