From dc48ba0c750e176f6314504cc0e8370a46ed01a8 Mon Sep 17 00:00:00 2001 From: Bram Wasti Date: Fri, 26 Sep 2025 19:59:09 -0400 Subject: [PATCH] Kernel-override Determinism [1/n] (#25603) Signed-off-by: Bram Wasti --- csrc/core/batch_invariant.hpp | 16 + csrc/layernorm_kernels.cu | 8 +- csrc/layernorm_quant_kernels.cu | 5 +- csrc/moe/topk_softmax_kernels.cu | 4 +- tests/v1/generation/test_batch_invariance.py | 290 +++++++++ vllm/model_executor/layers/batch_invariant.py | 561 ++++++++++++++++++ vllm/v1/attention/backends/flex_attention.py | 7 + vllm/v1/worker/gpu_model_runner.py | 3 + 8 files changed, 890 insertions(+), 4 deletions(-) create mode 100644 csrc/core/batch_invariant.hpp create mode 100644 tests/v1/generation/test_batch_invariance.py create mode 100644 vllm/model_executor/layers/batch_invariant.py diff --git a/csrc/core/batch_invariant.hpp b/csrc/core/batch_invariant.hpp new file mode 100644 index 000000000000..19e422e4b80c --- /dev/null +++ b/csrc/core/batch_invariant.hpp @@ -0,0 +1,16 @@ +#pragma once +#include +#include +#include + +namespace vllm { + +// vllm_kernel_override_batch_invariant(); returns true +// if env VLLM_KERNEL_OVERRIDE_BATCH_INVARIANT=1 +inline bool vllm_kernel_override_batch_invariant() { + std::string env_key = "VLLM_KERNEL_OVERRIDE_BATCH_INVARIANT"; + const char* val = std::getenv(env_key.c_str()); + return (val && std::atoi(val) != 0) ? 1 : 0; +} + +} // namespace vllm diff --git a/csrc/layernorm_kernels.cu b/csrc/layernorm_kernels.cu index 93c73d58390e..6c3685f6f7cd 100644 --- a/csrc/layernorm_kernels.cu +++ b/csrc/layernorm_kernels.cu @@ -1,6 +1,7 @@ #include "type_convert.cuh" #include "dispatch_utils.h" #include "cub_helpers.h" +#include "core/batch_invariant.hpp" #include #include @@ -413,7 +414,9 @@ void fused_add_rms_norm(torch::Tensor& input, // [..., hidden_size] wt_ptr % req_alignment_bytes == 0; bool offsets_are_multiple_of_vector_width = hidden_size % vector_width == 0 && input_stride % vector_width == 0; - if (ptrs_are_aligned && offsets_are_multiple_of_vector_width) { + bool batch_invariant_launch = vllm::vllm_kernel_override_batch_invariant(); + if (ptrs_are_aligned && offsets_are_multiple_of_vector_width && + !batch_invariant_launch) { LAUNCH_FUSED_ADD_RMS_NORM(8); } else { LAUNCH_FUSED_ADD_RMS_NORM(0); @@ -459,7 +462,8 @@ void poly_norm(torch::Tensor& out, // [..., hidden_size] auto inp_ptr = reinterpret_cast(input.data_ptr()); auto out_ptr = reinterpret_cast(out.data_ptr()); bool ptrs_are_aligned = inp_ptr % 16 == 0 && out_ptr % 16 == 0; - if (ptrs_are_aligned && hidden_size % 8 == 0) { + bool batch_invariant_launch = vllm::vllm_kernel_override_batch_invariant(); + if (ptrs_are_aligned && hidden_size % 8 == 0 && !batch_invariant_launch) { LAUNCH_FUSED_POLY_NORM(8); } else { LAUNCH_FUSED_POLY_NORM(0); diff --git a/csrc/layernorm_quant_kernels.cu b/csrc/layernorm_quant_kernels.cu index be134089bd6d..58c3d9c0981a 100644 --- a/csrc/layernorm_quant_kernels.cu +++ b/csrc/layernorm_quant_kernels.cu @@ -9,6 +9,7 @@ #include "quantization/fp8/common.cuh" #include "dispatch_utils.h" #include "cub_helpers.h" +#include "core/batch_invariant.hpp" #include #include @@ -240,7 +241,9 @@ void fused_add_rms_norm_static_fp8_quant( auto wt_ptr = reinterpret_cast(weight.data_ptr()); bool ptrs_are_aligned = inp_ptr % 16 == 0 && res_ptr % 16 == 0 && wt_ptr % 16 == 0; - if (ptrs_are_aligned && hidden_size % 8 == 0 && input_stride % 8 == 0) { + bool batch_invariant_launch = vllm::vllm_kernel_override_batch_invariant(); + if (ptrs_are_aligned && hidden_size % 8 
== 0 && input_stride % 8 == 0 && + !batch_invariant_launch) { LAUNCH_FUSED_ADD_RMS_NORM(8); } else { LAUNCH_FUSED_ADD_RMS_NORM(0); diff --git a/csrc/moe/topk_softmax_kernels.cu b/csrc/moe/topk_softmax_kernels.cu index 53573ada86ba..eca021f1c186 100644 --- a/csrc/moe/topk_softmax_kernels.cu +++ b/csrc/moe/topk_softmax_kernels.cu @@ -21,6 +21,7 @@ #include #include "../cuda_compat.h" #include "../cub_helpers.h" +#include "../core/batch_invariant.hpp" #define MAX(a, b) ((a) > (b) ? (a) : (b)) #define MIN(a, b) ((a) < (b) ? (a) : (b)) @@ -405,7 +406,8 @@ void topkGatingSoftmaxLauncherHelper(const float* input, const bool* finished, f using Constants = detail::TopkConstants; static constexpr int VPT = Constants::VPT; static constexpr int ROWS_PER_WARP = Constants::ROWS_PER_WARP; - const int num_warps = (num_rows + ROWS_PER_WARP - 1) / ROWS_PER_WARP; + const bool batch_invariant_launch = vllm::vllm_kernel_override_batch_invariant(); + const int num_warps = batch_invariant_launch ? 32 : (num_rows + ROWS_PER_WARP - 1) / ROWS_PER_WARP; const int num_blocks = (num_warps + WARPS_PER_TB - 1) / WARPS_PER_TB; dim3 block_dim(WARP_SIZE_PARAM, WARPS_PER_TB); diff --git a/tests/v1/generation/test_batch_invariance.py b/tests/v1/generation/test_batch_invariance.py new file mode 100644 index 000000000000..b864f9a31836 --- /dev/null +++ b/tests/v1/generation/test_batch_invariance.py @@ -0,0 +1,290 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import contextlib +import os +import random +import string + +import pytest +import torch + +from vllm import LLM, SamplingParams + + +def _random_prompt(min_words: int = 1024, max_words: int = 1024 * 2) -> str: + # Lightweight random prompt generator to vary prompt lengths and content. + vocab = [ + "alpha", + "bravo", + "charlie", + "delta", + "echo", + "foxtrot", + "golf", + "hotel", + "india", + "juliet", + "kilo", + "lima", + "mike", + "november", + "oscar", + "papa", + "quebec", + "romeo", + "sierra", + "tango", + "uniform", + "victor", + "whiskey", + "xray", + "yankee", + "zulu", + ] + n = random.randint(min_words, max_words) + words = random.choices(vocab, k=n) + + # Add some noise and punctuation variability + if random.random() < 0.5: + words[0] = words[0].capitalize() + if random.random() < 0.2: + words.append("".join(random.choices(string.ascii_lowercase, k=5))) + punct = random.choice([".", "?", "!", "...", ""]) + return " ".join(words) + punct + + +@pytest.mark.timeout(1000) +def test_v1_generation_is_deterministic_across_batch_sizes_with_needle(): + """ + Ensures that the same request (the 'needle' prompt) yields identical output + whether run alone (bs=1) or mixed into a larger batch (e.g., bs=64), + using the high-level v1 LLM() API only (no manual batching). + + Strategy: + - Create two LLM engines with identical config except max_num_seqs: 1 vs N. + - Compute a baseline output for the needle prompt with the bs=1 engine. + - For many trials, generate a batch (size N) where the needle appears at a + random position among random filler prompts using the bs=N engine. + - Track how many trials match vs mismatch, and report totals at the end. + The test fails if any mismatches occur, but we still dump pass/fail + counts. + + Notes: + - Use seeded stochastic sampling with a fixed seed to test determinism. + - Outputs are intentionally longer and sampled at higher temperature/top_p + to produce a more random-sounding phrase, yet remain deterministic by + seed. 
+ - Keep max_tokens and max_model_len bounded for speed and memory use. + """ + random.seed(12345) + + # Allow overrides from environment (useful for CI tuning) + # "facebook/opt-125m" is too small, doesn't reliably test determinism + model = os.getenv("VLLM_TEST_MODEL", "Qwen/Qwen3-1.7B") + num_trials = int(os.getenv("VLLM_NEEDLE_TRIALS", "5")) + batch_size = int(os.getenv("VLLM_NEEDLE_BATCH_SIZE", "64")) + assert batch_size >= 2, "Batch size should be >= 2 to mix needle." + + # Keep GPU memory usage low to avoid startup allocation failures. + gpu_mem_util = float(os.getenv("VLLM_GPU_MEMORY_UTILIZATION", "0.3")) + max_model_len = int(os.getenv("VLLM_MAX_MODEL_LEN", "4096")) + swap_space_gb = int(os.getenv("VLLM_SWAP_SPACE_GB", "4")) + + # Sampling parameters: longer outputs with a more random-sounding + # continuation,but still deterministic due to fixed seed. + temperature = float(os.getenv("VLLM_NEEDLE_TEMPERATURE", "0.0")) + top_p = float(os.getenv("VLLM_NEEDLE_TOP_P", "0.95")) + max_tokens = int(os.getenv("VLLM_NEEDLE_MAX_TOKENS", "128")) + + sampling = SamplingParams( + temperature=temperature, + top_p=top_p, + max_tokens=max_tokens, + seed=20240919, + ) + + needle_prompt = ("There once was a ") + + llm_bs1 = None + llm_bsN = None + try: + # Engine with bs=1 behavior + llm_bs1 = LLM_with_max_seqs( + model=model, + max_num_seqs=1, + gpu_memory_utilization=gpu_mem_util, + max_model_len=max_model_len, + swap_space=swap_space_gb, + ) + + # Baseline generation for the needle prompt alone. + baseline_out = llm_bs1.generate([needle_prompt], sampling) + assert len(baseline_out) == 1 + assert len(baseline_out[0].outputs) >= 1 + baseline_text = baseline_out[0].outputs[0].text + + # Engine with larger batch limit (e.g., 64) + llm_bsN = LLM_with_max_seqs( + model=model, + max_num_seqs=batch_size, + gpu_memory_utilization=gpu_mem_util, + max_model_len=max_model_len, + swap_space=swap_space_gb, + ) + + mismatches = 0 + + for trial in range(num_trials): + # Create a batch of size `batch_size` and insert the needle at + # a random index + prompts: list[str] = [] + needle_pos = random.randint(0, batch_size - 1) + for i in range(batch_size): + if i == needle_pos: + prompts.append(needle_prompt) + else: + prompts.append(_random_prompt()) + + # Generate with the larger-batch engine + outputs = llm_bsN.generate(prompts, sampling) + # Find the needle output by position + needle_output = outputs[needle_pos] + assert needle_output.prompt == needle_prompt + assert len(needle_output.outputs) >= 1 + text = needle_output.outputs[0].text + + if text != baseline_text: + mismatches += 1 + + passes = num_trials - mismatches + # Dump how many passed vs failed + print(f"[determinism] total={num_trials}, passed={passes}, " + f"failed={mismatches}, batch_size={batch_size}") + + if mismatches > 0: + pytest.fail( + f"Nondeterministic outputs detected: {mismatches} failed out " + f"of {num_trials} trials (batch_size={batch_size}).") + + finally: + # Ensure engines are shutdown to free GPU/VRAM across test sessions + if llm_bs1 is not None: + with contextlib.suppress(Exception): + llm_bs1.shutdown() + if llm_bsN is not None: + with contextlib.suppress(Exception): + llm_bsN.shutdown() + + +def _extract_step_logprobs(request_output): + if getattr(request_output, "outputs", None): + inner = request_output.outputs[0] + if hasattr(inner, "logprobs") and inner.logprobs is not None: + t = torch.tensor( + [ + inner.logprobs[i][tid].logprob + for i, tid in enumerate(inner.token_ids) + ], + dtype=torch.float32, + ) + return t + + 
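+    # No per-step logprobs were attached to this output (e.g. when
+    # SamplingParams.logprobs is not set); callers treat a None return
+    # as a signal to skip the bitwise comparison.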
return None + + +@pytest.mark.skipif( + not torch.cuda.is_available(), + reason="Requires CUDA to match production inference path.", +) +def test_logprobs_bitwise_batch_invariance_bs1_vs_bs2(): + + #model_name = os.getenv("VLLM_TEST_MODEL", "facebook/opt-125m") + model_name = os.getenv("VLLM_TEST_MODEL", "Qwen/Qwen3-1.7B") + tp_size = int(os.getenv("VLLM_TEST_TP_SIZE", "1")) + + # Force float32 to avoid precision-induced differences. + llm = LLM( + model=model_name, + tensor_parallel_size=tp_size, + enforce_eager=True, # helps reduce nondeterminism from some backends + ) + + prompts = [ + "The capital of France is", + "The capital of Germany is", + ] + + sp = SamplingParams( + temperature=0.0, + top_p=1.0, + max_tokens=8, + # Seed shouldn't matter at temperature=0, but keeping it stable anyway. + seed=1234, + logprobs=5, + ) + + # BS=1: run prompts individually and collect logprobs per step. + bs1_logprobs_per_prompt = [] + for p in prompts: + outs = llm.generate([p], sp, use_tqdm=False) + assert len(outs) == 1 + step_logprobs = _extract_step_logprobs(outs[0]) + if step_logprobs is None: + pytest.skip("Logits are not available on RequestOutput; " + "enable logprobs return to run this test.") + bs1_logprobs_per_prompt.append(step_logprobs) + + # BS=2: run prompts in a batch and collect logprobs per step for each + # prompt. + outs_batched = llm.generate(prompts, sp, use_tqdm=False) + assert len(outs_batched) == len(prompts) + bs2_logprobs_per_prompt = [] + for o in outs_batched: + step_logprobs = _extract_step_logprobs(o) + if step_logprobs is None: + pytest.skip("Logits are not available on RequestOutput; " + "enable logprobs return to run this test.") + bs2_logprobs_per_prompt.append(step_logprobs) + + # Compare step-by-step logprobs for each prompt between BS=1 and BS=2 runs. + for i, (logprobs_bs1, logprobs_bs2) in enumerate( + zip(bs1_logprobs_per_prompt, bs2_logprobs_per_prompt)): + assert len(logprobs_bs1) == len(logprobs_bs2), ( + f"Different number of generation steps for prompt index {i}: " + f"{len(logprobs_bs1)} (BS=1) vs {len(logprobs_bs2)} (BS=2)") + for t, (a, b) in enumerate(zip(logprobs_bs1, logprobs_bs2)): + assert a.shape == b.shape, ( + f"Logits shape mismatch at prompt {i}, step {t}: " + f"{a.shape} vs {b.shape}") + # Bitwise exact equality. + assert torch.equal( + a, b), (f"Bitwise logprobs mismatch at prompt {i}, step {t} " + f"(dtype={a.dtype}, shape={a.shape}).") + + +def LLM_with_max_seqs( + model: str, + max_num_seqs: int, + gpu_memory_utilization: float, + max_model_len: int, + swap_space: int, +) -> LLM: + """ + Helper to construct an LLM with a specific max_num_seqs (batch-size limit) + using the high-level v1 LLM API, while constraining memory usage. + """ + return LLM( + model=model, + max_num_seqs=max_num_seqs, + # Constrain GPU memory pool so test can run even on busy GPUs. + gpu_memory_utilization=gpu_memory_utilization, + # Keep KV cache footprint small while allowing longer outputs. + max_model_len=max_model_len, + # Allow some CPU offload if needed. + swap_space=swap_space, + # Keep things lean and CI-friendly. + dtype="float16", + # Single-GPU by default; override externally if desired. 
+ tensor_parallel_size=int(os.getenv("VLLM_TP_SIZE", "1")), + trust_remote_code=os.getenv("VLLM_TRUST_REMOTE_CODE", "0") == "1", + ) diff --git a/vllm/model_executor/layers/batch_invariant.py b/vllm/model_executor/layers/batch_invariant.py new file mode 100644 index 000000000000..ae2c842af698 --- /dev/null +++ b/vllm/model_executor/layers/batch_invariant.py @@ -0,0 +1,561 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import contextlib +import os +from collections import namedtuple +from collections.abc import Callable +from typing import Any, Union + +import torch +import triton +import triton.language as tl + + +def _matmul_launch_metadata(grid: Callable[..., Any], kernel: Any, + args: dict[str, Any]) -> dict[str, Any]: + ret = {} + m, n, k = args["M"], args["N"], args["K"] + ret["name"] = f"{kernel.name} [M={m}, N={n}, K={k}]" + if "tiles_per_update" in args: + ret["name"] = (f"{kernel.name} [M={m}, N={n}, K={k}, " + f"tiles_per_update={args['tiles_per_update']:02}]") + if "c_ptr" in args: + bytes_per_elem = args["c_ptr"].element_size() + else: + bytes_per_elem = 1 if args["FP8_OUTPUT"] else 2 + ret[f"flops{bytes_per_elem * 8}"] = 2.0 * m * n * k + ret["bytes"] = bytes_per_elem * (m * k + n * k + m * n) + return ret + + +@triton.jit +def _compute_pid(tile_id, num_pid_in_group, num_pid_m, GROUP_SIZE_M, NUM_SMS): + group_id = tile_id // num_pid_in_group + first_pid_m = group_id * GROUP_SIZE_M + group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) + pid_m = first_pid_m + (tile_id % group_size_m) + pid_n = (tile_id % num_pid_in_group) // group_size_m + return pid_m, pid_n + + +@triton.jit(launch_metadata=_matmul_launch_metadata) +def matmul_kernel_persistent( + a_ptr, + b_ptr, + c_ptr, # + bias_ptr, + M, + N, + K, # + stride_am, + stride_ak, + stride_bk, + stride_bn, + stride_cm, + stride_cn, + BLOCK_SIZE_M: tl.constexpr, # + BLOCK_SIZE_N: tl.constexpr, # + BLOCK_SIZE_K: tl.constexpr, # + GROUP_SIZE_M: tl.constexpr, # + NUM_SMS: tl.constexpr, # + A_LARGE: tl.constexpr, + B_LARGE: tl.constexpr, + C_LARGE: tl.constexpr, + HAS_BIAS: tl.constexpr, +): + start_pid = tl.program_id(axis=0) + num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) + num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) + k_tiles = tl.cdiv(K, BLOCK_SIZE_K) + num_tiles = num_pid_m * num_pid_n + + tile_id_c = start_pid - NUM_SMS + + offs_k_for_mask = tl.arange(0, BLOCK_SIZE_K) + num_pid_in_group = GROUP_SIZE_M * num_pid_n + + for tile_id in tl.range(start_pid, num_tiles, NUM_SMS, flatten=True): + pid_m, pid_n = _compute_pid(tile_id, num_pid_in_group, num_pid_m, + GROUP_SIZE_M, NUM_SMS) + start_m = pid_m * BLOCK_SIZE_M + start_n = pid_n * BLOCK_SIZE_N + offs_am = start_m + tl.arange(0, BLOCK_SIZE_M) + offs_bn = start_n + tl.arange(0, BLOCK_SIZE_N) + if A_LARGE: + offs_am = offs_am.to(tl.int64) + if B_LARGE: + offs_bn = offs_bn.to(tl.int64) + offs_am = tl.where(offs_am < M, offs_am, 0) + offs_bn = tl.where(offs_bn < N, offs_bn, 0) + offs_am = tl.max_contiguous(tl.multiple_of(offs_am, BLOCK_SIZE_M), + BLOCK_SIZE_M) + offs_bn = tl.max_contiguous(tl.multiple_of(offs_bn, BLOCK_SIZE_N), + BLOCK_SIZE_N) + + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + for ki in range(k_tiles): + if A_LARGE or B_LARGE: + offs_k = ki * BLOCK_SIZE_K + tl.arange(0, BLOCK_SIZE_K).to( + tl.int64) + else: + offs_k = ki * BLOCK_SIZE_K + tl.arange(0, BLOCK_SIZE_K) + a_ptrs = a_ptr + (offs_am[:, None] * stride_am + + offs_k[None, :] * stride_ak) + b_ptrs = b_ptr + (offs_k[:, None] * 
stride_bk + + offs_bn[None, :] * stride_bn) + + a = tl.load(a_ptrs, + mask=offs_k_for_mask[None, :] < K - ki * BLOCK_SIZE_K, + other=0.0) + b = tl.load(b_ptrs, + mask=offs_k_for_mask[:, None] < K - ki * BLOCK_SIZE_K, + other=0.0) + accumulator = tl.dot(a, b, accumulator) + + tile_id_c += NUM_SMS + pid_m, pid_n = _compute_pid(tile_id_c, num_pid_in_group, num_pid_m, + GROUP_SIZE_M, NUM_SMS) + offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + if C_LARGE: + offs_cm = offs_cm.to(tl.int64) + offs_cn = offs_cn.to(tl.int64) + c_ptrs = c_ptr + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[ + None, :] + c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N) + if HAS_BIAS: + bias_ptrs = bias_ptr + offs_cn + bias = tl.load(bias_ptrs, mask=offs_cn < N, + other=0.0).to(tl.float32) + accumulator += bias + if c_ptr.dtype.element_ty == tl.float8e4nv: + c = accumulator.to(tl.float8e4nv) + else: + c = accumulator.to(tl.float16) + tl.store(c_ptrs, c, mask=c_mask) + + +def matmul_persistent(a: torch.Tensor, + b: torch.Tensor, + bias: Union[torch.Tensor, None] = None): + # Check constraints. + assert a.shape[1] == b.shape[0], "Incompatible dimensions" + assert a.dtype == b.dtype, "Incompatible dtypes" + assert bias is None or bias.dim() == 1, ( + "Currently assuming bias is 1D, let Horace know if you run into this") + NUM_SMS = torch.cuda.get_device_properties("cuda").multi_processor_count + M, K = a.shape + K, N = b.shape + dtype = a.dtype + # Allocates output. + c = torch.empty((M, N), device=a.device, dtype=dtype) + + # 1D launch kernel where each block gets its own program. + def grid(META): + return (min( + NUM_SMS, + triton.cdiv(M, META["BLOCK_SIZE_M"]) * + triton.cdiv(N, META["BLOCK_SIZE_N"])), ) + + configs = { + torch.bfloat16: { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 8, + "num_stages": 3, + "num_warps": 8, + }, + torch.float16: { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 8, + "num_stages": 3, + "num_warps": 8, + }, + torch.float32: { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 32, + "GROUP_SIZE_M": 8, + "num_stages": 3, + "num_warps": 8, + }, + } + # print(a.device, b.device, c.device) + matmul_kernel_persistent[grid]( + a, + b, + c, # + bias, + M, + N, + K, # + a.stride(0), + a.stride(1), # + b.stride(0), + b.stride(1), # + c.stride(0), + c.stride(1), # + NUM_SMS=NUM_SMS, # + A_LARGE=a.numel() > 2**31, + B_LARGE=b.numel() > 2**31, + C_LARGE=c.numel() > 2**31, + HAS_BIAS=bias is not None, + **configs[dtype], + ) + return c + + +@triton.jit +def _log_softmax_kernel( + input_ptr, + output_ptr, + input_row_stride, + output_row_stride, + n_cols, + BLOCK_SIZE: tl.constexpr, +): + """ + Compute log_softmax along the last dimension of a 2D tensor. + Each block handles one row of the input tensor. 
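+    A single program walks its row in fixed BLOCK_SIZE chunks, so the
+    reduction order (and therefore the rounded result) does not depend on
+    how many rows are in the batch.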
+ """ + # Get the row index for this block + row_idx = tl.program_id(0).to(tl.int64) + + # Compute base pointers for input and output rows + row_start_ptr = input_ptr + row_idx * input_row_stride + output_row_start_ptr = output_ptr + row_idx * output_row_stride + + # Step 1: Find maximum value in the row for numerical stability + max_val = -float("inf") + for col_offset in range(0, n_cols, BLOCK_SIZE): + col_idx = col_offset + tl.arange(0, BLOCK_SIZE) + mask = col_idx < n_cols + + # Load values + vals = tl.load(row_start_ptr + col_idx, mask=mask, other=-float("inf")) + + # Update maximum + max_val = tl.max(tl.maximum(vals, max_val)) + + # Step 2: Compute sum of exp(x - max_val) + sum_exp = 0.0 + for col_offset in range(0, n_cols, BLOCK_SIZE): + col_idx = col_offset + tl.arange(0, BLOCK_SIZE) + mask = col_idx < n_cols + + # Load values + vals = tl.load(row_start_ptr + col_idx, mask=mask, other=0.0) + + # Compute exp(x - max_val) and accumulate + exp_vals = tl.exp(vals - max_val) + sum_exp += tl.sum(tl.where(mask, exp_vals, 0.0)) + + # Compute log(sum_exp) + log_sum_exp = tl.log(sum_exp) + + # Step 3: Compute final log_softmax values: x - max_val - log_sum_exp + for col_offset in range(0, n_cols, BLOCK_SIZE): + col_idx = col_offset + tl.arange(0, BLOCK_SIZE) + mask = col_idx < n_cols + + # Load values + vals = tl.load(row_start_ptr + col_idx, mask=mask) + + # Compute log_softmax + output = vals - max_val - log_sum_exp + + # Store results + tl.store(output_row_start_ptr + col_idx, output, mask=mask) + + +def log_softmax(input: torch.Tensor, dim: int = -1) -> torch.Tensor: + """ + Compute log_softmax using Triton kernel. + + Args: + input: Input tensor + dim: Dimension along which to compute log_softmax + (only -1 or last dim supported) + >> Stashed changes + Returns: + Tensor with log_softmax applied along the specified dimension + """ + if dim != -1 and dim != input.ndim - 1: + raise ValueError("This implementation only supports log_softmax along " + "the last dimension") + + # Flatten all dimensions except the last one + original_shape = input.shape + input_2d = input.reshape(-1, input.shape[-1]) + input_2d = input_2d.contiguous() + + n_rows, n_cols = input_2d.shape + + # Allocate output tensor + output = torch.empty_like(input_2d) + + # Choose block size based on the number of columns + BLOCK_SIZE = 1024 + + # Launch kernel with one block per row + grid = (n_rows, ) + _log_softmax_kernel[grid]( + input_2d, + output, + input_2d.stride(0), + output.stride(0), + n_cols, + BLOCK_SIZE=BLOCK_SIZE, + ) + # Reshape output back to original shape + return output.reshape(original_shape) + + +@triton.jit +def mean_kernel( + input_ptr, + output_ptr, + input_stride0, + input_stride1, + input_stride2, + output_stride0, + output_stride1, + M, # size before reduction dim + N, # size of reduction dim + K, # size after reduction dim + BLOCK_SIZE: tl.constexpr, +): + """ + Kernel for computing mean along a single dimension. + Input is viewed as (M, N, K) where N is the dimension being reduced. 
+ """ + # Program ID gives us which output element we're computing + pid = tl.program_id(0) + + # Compute output indices + m_idx = pid // K + k_idx = pid % K + + # Bounds check + if m_idx >= M or k_idx >= K: + return + + # Accumulate sum across reduction dimension + acc = 0.0 + for n_start in range(0, N, BLOCK_SIZE): + n_offsets = n_start + tl.arange(0, BLOCK_SIZE) + mask = n_offsets < N + + # Calculate input indices + input_idx = m_idx * input_stride0 + n_offsets * input_stride1 \ + + k_idx * input_stride2 + + # Load and accumulate + vals = tl.load(input_ptr + input_idx, mask=mask, other=0.0) + acc += tl.sum(vals) + + # Compute mean and store + mean_val = acc / N + output_idx = m_idx * output_stride0 + k_idx * output_stride1 + tl.store(output_ptr + output_idx, mean_val) + + +def mean_dim(input: torch.Tensor, + dim: int, + keepdim: bool = False, + dtype: Union[torch.dtype, None] = None) -> torch.Tensor: + """ + Triton implementation of torch.mean with single dimension reduction. + + Args: + input: Input tensor + dim: Single dimension along which to compute mean + keepdim: Whether to keep the reduced dimension + dtype: Output dtype. If None, uses input dtype + (or float32 for integer inputs) + + Returns: + Tensor with mean values along specified dimension + """ + # Validate inputs + assert input.is_cuda, "Input must be a CUDA tensor" + assert -input.ndim <= dim < input.ndim, ( + f"Invalid dimension {dim} for tensor with {input.ndim} dimensions") + + # Handle negative dim + if dim < 0: + dim = dim + input.ndim + + # Handle dtype + if dtype is None: + if input.dtype in [torch.int8, torch.int16, torch.int32, torch.int64]: + dtype = torch.float32 + else: + dtype = input.dtype + + # Convert input to appropriate dtype if needed + if input.dtype != dtype: + input = input.to(dtype) + + # Get input shape and strides + shape = list(input.shape) + + # Calculate dimensions for kernel + M = 1 + for i in range(dim): + M *= shape[i] + + N = shape[dim] + + K = 1 + for i in range(dim + 1, len(shape)): + K *= shape[i] + + # Reshape input to 3D view (M, N, K) + input_3d = input.reshape(M, N, K) + + # Create output shape + if keepdim: + output_shape = shape.copy() + output_shape[dim] = 1 + else: + output_shape = shape[:dim] + shape[dim + 1:] + + # Create output tensor + output = torch.empty(output_shape, dtype=dtype, device=input.device) + + # Reshape output for kernel + if keepdim: + output_2d = output.reshape(M, 1, K).squeeze(1) + else: + output_2d = output.reshape(M, K) + + # Launch kernel + grid = (M * K, ) + BLOCK_SIZE = 1024 + + mean_kernel[grid]( + input_3d, + output_2d, + input_3d.stride(0), + input_3d.stride(1), + input_3d.stride(2), + output_2d.stride(0), + output_2d.stride(1) if output_2d.ndim > 1 else 0, + M, + N, + K, + BLOCK_SIZE, + ) + + return output + + +def mm_batch_invariant(a, b): + return matmul_persistent(a, b) + + +def addmm_batch_invariant(bias, a, b): + return matmul_persistent(a, b, bias=bias) + + +def _log_softmax_batch_invariant(input, dim, _half_to_float): + assert not _half_to_float, "not implemented" + return log_softmax(input, dim=dim) + + +def mean_batch_invariant(input, + dim, + keepdim=False, + dtype: Union[torch.dtype, None] = None): + assert dtype is None or dtype == torch.float32, \ + f"unsupported dtype: {dtype}" + + result = input.to(torch.float32) + + # Sort dimensions to reduce from largest to smallest to handle shifting dims + # during iterative reduction. + sorted_dims = sorted([d % input.ndim for d in dim], reverse=True) + + # Iteratively apply a deterministic mean. 
+ for d in sorted_dims: + result = mean_dim(result, dim=d, keepdim=True) + + if not keepdim: + # Squeeze the reduced dimensions. + for d in sorted_dims: + result = result.squeeze(d) + + return result + + +_batch_invariant_MODE = False +_batch_invariant_LIB = None + + +def is_batch_invariant_mode_enabled(): + return _batch_invariant_MODE + + +def enable_batch_invariant_mode(): + global _batch_invariant_MODE, _batch_invariant_LIB + if _batch_invariant_MODE: + return + + _batch_invariant_MODE = True + _batch_invariant_LIB = torch.library.Library("aten", "IMPL") + _batch_invariant_LIB.impl("aten::mm", mm_batch_invariant, "CUDA") + _batch_invariant_LIB.impl("aten::addmm", addmm_batch_invariant, "CUDA") + _batch_invariant_LIB.impl("aten::_log_softmax", + _log_softmax_batch_invariant, "CUDA") + _batch_invariant_LIB.impl("aten::mean.dim", mean_batch_invariant, "CUDA") + + +def disable_batch_invariant_mode(): + global _batch_invariant_MODE, _batch_invariant_LIB + if _batch_invariant_LIB is not None: + _batch_invariant_LIB._destroy() + _batch_invariant_MODE = False + _batch_invariant_LIB = None + + +@contextlib.contextmanager +def set_batch_invariant_mode(enabled: bool = True): + global _batch_invariant_MODE, _batch_invariant_LIB + old_data = (_batch_invariant_MODE, _batch_invariant_LIB) + if enabled: + enable_batch_invariant_mode() + else: + disable_batch_invariant_mode() + yield + if _batch_invariant_LIB is not None: + _batch_invariant_LIB._destroy() + _batch_invariant_MODE, _batch_invariant_LIB = old_data + + +AttentionBlockSize = namedtuple("AttentionBlockSize", ["block_m", "block_n"]) + + +def get_batch_invariant_attention_block_size() -> AttentionBlockSize: + return AttentionBlockSize(block_m=16, block_n=16) + + +def vllm_kernel_override_batch_invariant(): + env_key = "VLLM_KERNEL_OVERRIDE_BATCH_INVARIANT" + is_overridden = False + val = os.getenv(env_key, "0") + try: + is_overridden = int(val) != 0 + except ValueError: + is_overridden = False + return is_overridden + + +def init_batch_invariance(): + # this will hit all the csrc overrides as well + if vllm_kernel_override_batch_invariant(): + os.environ["VLLM_ATTENTION_BACKEND"] = "FLEX_ATTENTION" + enable_batch_invariant_mode() diff --git a/vllm/v1/attention/backends/flex_attention.py b/vllm/v1/attention/backends/flex_attention.py index c3358bfa74e9..807b8d987a2d 100644 --- a/vllm/v1/attention/backends/flex_attention.py +++ b/vllm/v1/attention/backends/flex_attention.py @@ -18,6 +18,8 @@ from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl, is_quantized_kv_cache) from vllm.config import VllmConfig from vllm.logger import init_logger +from vllm.model_executor.layers.batch_invariant import ( + vllm_kernel_override_batch_invariant) from vllm.utils import cdiv, is_torch_equal_or_newer from vllm.v1.attention.backends.utils import (AttentionMetadataBuilder, CommonAttentionMetadata) @@ -839,6 +841,11 @@ def get_kernel_options(query, block_m, block_n, kernel_options: dict[str, Union[int, bool]] = { "FORCE_USE_FLEX_ATTENTION": True, } + if vllm_kernel_override_batch_invariant(): + kernel_options["BLOCK_M"] = 16 + kernel_options["BLOCK_N"] = 16 + kernel_options["IS_DIVISIBLE"] = False + return kernel_options if use_direct_build: kernel_options["BLOCK_M"] = block_m kernel_options["BLOCK_N"] = block_n diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 4fd4f9128c6e..f87a327d02a5 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -192,6 +192,9 @@ 
class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
         from vllm.model_executor.models.utils import set_cpu_offload_max_bytes
         set_cpu_offload_max_bytes(
             int(self.cache_config.cpu_offload_gb * 1024**3))
+        from vllm.model_executor.layers.batch_invariant import (
+            init_batch_invariance)
+        init_batch_invariance()
         model_config = self.model_config
         cache_config = self.cache_config
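
For reference, a minimal sketch (not part of the patch) of how the override is meant to be exercised. It assumes a CUDA device with this patch applied and uses only symbols introduced above (set_batch_invariant_mode in vllm/model_executor/layers/batch_invariant.py and the VLLM_KERNEL_OVERRIDE_BATCH_INVARIANT environment variable); the final assertion illustrates the intended batch-invariance property rather than a guaranteed API contract.

import os
import torch

# Setting the env var before engine startup is what init_batch_invariance()
# and the csrc-side vllm_kernel_override_batch_invariant() check; for a pure
# torch-level demo, the context manager below is sufficient on its own.
os.environ["VLLM_KERNEL_OVERRIDE_BATCH_INVARIANT"] = "1"

from vllm.model_executor.layers.batch_invariant import set_batch_invariant_mode

a = torch.randn(64, 256, device="cuda", dtype=torch.float16)
b = torch.randn(256, 128, device="cuda", dtype=torch.float16)

with set_batch_invariant_mode():
    # aten::mm on CUDA is routed to the persistent Triton matmul, whose
    # K-loop order does not depend on the number of rows in `a`.
    full = torch.mm(a, b)        # "batch" of 64 rows
    single = torch.mm(a[:1], b)  # the same first row on its own

# Expected to match bitwise, mirroring the bs=1 vs bs=N checks in
# tests/v1/generation/test_batch_invariance.py.
assert torch.equal(full[:1], single)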