Deepseek-v3 Batch Invariant on 8xH100 (#26609)

Signed-off-by: Bram Wasti <bwasti@meta.com>
Co-authored-by: Wentao Ye <44945378+yewentao256@users.noreply.github.com>
Bram Wasti 2025-10-15 22:06:02 -07:00 committed by GitHub
parent 785d8b6410
commit 7d8975de84
21 changed files with 1567 additions and 102 deletions

View File

@ -3,56 +3,73 @@
import contextlib
import os
import random
import pytest
import torch
from vllm import LLM, SamplingParams
from vllm.platforms import current_platform
@pytest.fixture(autouse=True)
def enable_batch_invariant_mode():
"""Automatically enable batch invariant kernel overrides for all tests."""
old_value = os.environ.get("VLLM_KERNEL_OVERRIDE_BATCH_INVARIANT")
os.environ["VLLM_KERNEL_OVERRIDE_BATCH_INVARIANT"] = "1"
yield
# Restore original value after test
if old_value is None:
os.environ.pop("VLLM_KERNEL_OVERRIDE_BATCH_INVARIANT", None)
else:
os.environ["VLLM_KERNEL_OVERRIDE_BATCH_INVARIANT"] = old_value
def _random_prompt(min_words: int = 1024, max_words: int = 1024 * 2) -> str:
# Generate more realistic prompts that will actually produce varied tokens
# Use a mix of common English text patterns
prompt_templates = [
# Question-answer style
"Question: What is the capital of France?\nAnswer: The capital of France is",
"Q: How does photosynthesis work?\nA: Photosynthesis is the process by which",
"User: Can you explain quantum mechanics?\nAssistant: Quantum mechanics is",
# Story/narrative style
"Once upon a time in a distant galaxy, there lived",
"The old man walked slowly down the street, remembering",
"In the year 2157, humanity finally discovered",
# Technical/code style
"To implement a binary search tree in Python, first we need to",
"The algorithm works by iterating through the array and",
"Here's how to optimize database queries using indexing:",
# Factual/informative style
"The Renaissance was a period in European history that",
"Climate change is caused by several factors including",
"The human brain contains approximately 86 billion neurons which",
# Conversational style
"I've been thinking about getting a new laptop because",
"Yesterday I went to the store and bought",
"My favorite thing about summer is definitely",
]
# Pick a random template
base_prompt = random.choice(prompt_templates)
# Add some padding to vary the length if needed
if min_words > 50:
# For longer prompts, repeat context
padding_text = (
" This is an interesting topic that deserves more explanation. "
* (min_words // 50)
)
base_prompt = base_prompt + padding_text
return base_prompt
@pytest.mark.skipif(
not current_platform.has_device_capability(90),
reason="Batch invariance tests only supported on Hopper (SM90)",
)
@pytest.mark.timeout(1000)
def test_v1_generation_is_deterministic_across_batch_sizes_with_needle():
"""
@ -91,7 +108,6 @@ def test_v1_generation_is_deterministic_across_batch_sizes_with_needle():
# Keep GPU memory usage low to avoid startup allocation failures.
gpu_mem_util = float(os.getenv("VLLM_GPU_MEMORY_UTILIZATION", "0.4"))
max_model_len = int(os.getenv("VLLM_MAX_MODEL_LEN", "5120"))
# Sampling parameters: longer outputs with a more random-sounding
# continuation, but still deterministic due to fixed seed.
@ -117,7 +133,6 @@ def test_v1_generation_is_deterministic_across_batch_sizes_with_needle():
max_num_seqs=max_batch_size,
gpu_memory_utilization=gpu_mem_util,
max_model_len=max_model_len,
)
# Baseline generation for the needle prompt alone.
@ -132,7 +147,6 @@ def test_v1_generation_is_deterministic_across_batch_sizes_with_needle():
max_num_seqs=max_batch_size,
gpu_memory_utilization=gpu_mem_util,
max_model_len=max_model_len,
)
mismatches = 0
@ -195,16 +209,21 @@ def _extract_step_logprobs(request_output):
],
dtype=torch.float32,
)
return t, inner.token_ids
return None, None
@pytest.mark.skipif(
not current_platform.has_device_capability(90),
reason="Batch invariance tests only supported on Hopper (SM90)",
)
@pytest.mark.skipif(
not torch.cuda.is_available(),
reason="Requires CUDA to match production inference path.",
)
@pytest.mark.parametrize("backend", ["FLASH_ATTN", "FLASHINFER"])
@pytest.mark.forked
def test_logprobs_bitwise_batch_invariance_bs1_vs_bsN(backend):
backend = os.getenv("VLLM_ATTENTION_BACKEND", backend)
os.environ["VLLM_ATTENTION_BACKEND"] = backend
@ -214,69 +233,754 @@ def test_logprobs_bitwise_batch_invariance_bs1_vs_bsN(backend):
model_name = os.getenv("VLLM_TEST_MODEL", "Qwen/Qwen3-1.7B")
tp_size = int(os.getenv("VLLM_TEST_TP_SIZE", "1"))
# For batch invariance, disable custom all-reduce to ensure deterministic
# all-reduce operations (custom all-reduce may not be deterministic)
from vllm.model_executor.layers.batch_invariant import (
vllm_kernel_override_batch_invariant,
)
disable_custom_ar = vllm_kernel_override_batch_invariant()
if disable_custom_ar:
print(f"\n{'=' * 80}")
print(f"BATCH INVARIANCE MODE: Disabling custom all-reduce (TP={tp_size})")
print(f"{'=' * 80}\n")
llm = LLM(
model=model_name,
tensor_parallel_size=tp_size,
enforce_eager=True,
enable_prefix_caching=False,
max_num_seqs=32,
max_model_len=8192,
dtype="bfloat16", # not everything is supported
)
# Use more realistic prompts for better token generation
prompts = [_random_prompt(10, 50) for i in range(32)]
sp = SamplingParams(
temperature=0.6,
top_p=1.0,
max_tokens=8,
# Fixed seed keeps sampling reproducible even at temperature > 0.
seed=1234,
logprobs=5,
)
# BS=1: run prompts individually and collect logprobs per step.
print("\n" + "=" * 80)
print("STARTING BS=1 RUNS (each prompt individually)")
print("=" * 80 + "\n")
bs1_logprobs_per_prompt = []
bs1_tokens_per_prompt = []
for idx, p in enumerate(prompts):
print(f"\n[BS=1] Running prompt {idx}/{len(prompts)} - Preview: {p[:80]}...")
outs = llm.generate([p], sp, use_tqdm=False)
assert len(outs) == 1
step_logprobs, token_ids = _extract_step_logprobs(outs[0])
if step_logprobs is None:
pytest.skip(
"Logits are not available on RequestOutput; "
"enable logprobs return to run this test."
)
bs1_logprobs_per_prompt.append(step_logprobs)
bs1_tokens_per_prompt.append(token_ids)
print(f"[BS=1] Prompt {idx} generated tokens: {token_ids}")
# BS=N: run prompts in a batch and collect logprobs per step for each
# prompt.
print("\n" + "=" * 80)
print(f"STARTING BS={len(prompts)} RUN (all prompts batched)")
print("=" * 80 + "\n")
outs_batched = llm.generate(prompts, sp, use_tqdm=False)
assert len(outs_batched) == len(prompts)
bsN_logprobs_per_prompt = []
bsN_tokens_per_prompt = []
print(f"\n[BS={len(prompts)}] Processing batched outputs...")
for idx, o in enumerate(outs_batched):
tokens = o.outputs[0].token_ids if o.outputs else "N/A"
print(f"[BS={len(prompts)}] Prompt {idx} generated tokens: {tokens}")
step_logprobs, token_ids = _extract_step_logprobs(o)
if step_logprobs is None:
pytest.skip(
"Logits are not available on RequestOutput; "
"enable logprobs return to run this test."
)
bsN_logprobs_per_prompt.append(step_logprobs)
bsN_tokens_per_prompt.append(token_ids)
# Compare step-by-step logprobs for each prompt between BS=1 and BS=N runs.
failed_prompts = []
for i, (logprobs_bs1, logprobs_bsN, tokens_bs1, tokens_bsN) in enumerate(
zip(
bs1_logprobs_per_prompt,
bsN_logprobs_per_prompt,
bs1_tokens_per_prompt,
bsN_tokens_per_prompt,
)
):
if len(logprobs_bs1) != len(logprobs_bsN):
reason = (
f"Different number of steps: {len(logprobs_bs1)} (BS=1) "
f"vs {len(logprobs_bsN)} (BS=N)"
)
failed_prompts.append(
{
"prompt_idx": i,
"step": "all",
"reason": reason,
"prompt_preview": prompts[i][:100],
"bs1_tokens": tokens_bs1,
"bsN_tokens": tokens_bsN,
}
)
continue
# Check if tokens match first
if tokens_bs1 != tokens_bsN:
failed_prompts.append(
{
"prompt_idx": i,
"step": "sampling",
"reason": "Different tokens sampled",
"prompt_preview": prompts[i][:100],
"bs1_tokens": tokens_bs1,
"bsN_tokens": tokens_bsN,
"bs1_all_logprobs": [
logprobs_bs1[s].tolist() for s in range(len(logprobs_bs1))
],
"bsN_all_logprobs": [
logprobs_bsN[s].tolist() for s in range(len(logprobs_bsN))
],
}
)
continue
for t, (a, b) in enumerate(zip(logprobs_bs1, logprobs_bsN)):
if a.shape != b.shape:
failed_prompts.append(
{
"prompt_idx": i,
"step": t,
"reason": f"Shape mismatch: {a.shape} vs {b.shape}",
"prompt_preview": prompts[i][:100],
"bs1_tokens": tokens_bs1,
"bsN_tokens": tokens_bsN,
}
)
break
if not torch.equal(a, b):
max_diff = torch.abs(a - b).max().item()
# Print which token failed
print(f"\n[DIVERGENCE] Prompt {i}, Token {t}: max_diff={max_diff:.6e}")
bs1_tok = tokens_bs1[t] if t < len(tokens_bs1) else "N/A"
bsN_tok = tokens_bsN[t] if t < len(tokens_bsN) else "N/A"
print(f" Token IDs: bs1={bs1_tok}, bsN={bsN_tok}")
print(f" BS=1 logprob: {a.tolist()}")
print(f" BS=N logprob: {b.tolist()}")
failed_prompts.append(
{
"prompt_idx": i,
"step": t,
"reason": f"Bitwise mismatch (max_diff={max_diff:.6e})",
"prompt_preview": prompts[i][:100],
"bs1_tokens": tokens_bs1,
"bsN_tokens": tokens_bsN,
"bs1_all_logprobs": [
logprobs_bs1[s].tolist() for s in range(len(logprobs_bs1))
],
"bsN_all_logprobs": [
logprobs_bsN[s].tolist() for s in range(len(logprobs_bsN))
],
}
)
break
# Print summary of all failures
if failed_prompts:
print(f"\n{'=' * 80}")
fail_msg = (
f"BATCH INVARIANCE FAILURES: {len(failed_prompts)}/"
f"{len(prompts)} prompts failed"
)
print(fail_msg)
print(f"{'=' * 80}")
for fail in failed_prompts:
print(f"\nPrompt {fail['prompt_idx']} (step {fail['step']}):")
print(f" Reason: {fail['reason']}")
print(f" Preview: {fail['prompt_preview']}...")
# Always show the tokens
if "bs1_tokens" in fail:
print(f" BS=1 tokens: {fail['bs1_tokens']}")
if "bsN_tokens" in fail:
print(f" BS=N tokens: {fail['bsN_tokens']}")
if "bs1_all_logprobs" in fail:
print(f" BS=1 logprobs for all {len(fail['bs1_all_logprobs'])} steps:")
for step_idx, logprobs in enumerate(fail["bs1_all_logprobs"]):
print(f" Step {step_idx}: {logprobs}")
print(f" BS=N logprobs for all {len(fail['bsN_all_logprobs'])} steps:")
for step_idx, logprobs in enumerate(fail["bsN_all_logprobs"]):
print(f" Step {step_idx}: {logprobs}")
print(f"{'=' * 80}\n")
# Fail the test with summary
msg = (
f"Batch invariance violated in {len(failed_prompts)}/"
f"{len(prompts)} prompts. See output above for details."
)
pytest.fail(msg)
@pytest.mark.skipif(
not current_platform.has_device_capability(90),
reason="Batch invariance tests only supported on Hopper (SM90)",
)
def test_simple_generation():
"""
Simple test that runs the model with a basic prompt and prints the output.
Useful for quick smoke testing and debugging.
"""
model = os.getenv("VLLM_TEST_MODEL", "Qwen/Qwen3-1.7B")
llm = LLM(
model=model,
max_num_seqs=1,
tensor_parallel_size=int(os.getenv("VLLM_TP_SIZE", "1")),
enforce_eager=True,
gpu_memory_utilization=0.9,
max_model_len=2048,
dtype="bfloat16",
enable_prefix_caching=False,
)
prompt = "the capital of france is"
sampling_params = SamplingParams(
temperature=0.0,
max_tokens=20,
)
print(f"\n{'=' * 80}")
print("Running simple generation test")
print(f"Prompt: '{prompt}'")
print(f"{'=' * 80}\n")
try:
outputs = llm.generate([prompt], sampling_params)
assert len(outputs) == 1
output_text = outputs[0].outputs[0].text
print(f"Output: '{output_text}'")
print(f"\n{'=' * 80}")
print(f"Full completion: '{prompt}{output_text}'")
print(f"{'=' * 80}\n")
finally:
with contextlib.suppress(Exception):
llm.shutdown()
@pytest.mark.skipif(
not current_platform.has_device_capability(90),
reason="Batch invariance tests only supported on Hopper (SM90)",
)
@pytest.mark.skipif(
not torch.cuda.is_available(),
reason="Requires CUDA to match production inference path.",
)
@pytest.mark.parametrize("backend", ["FLASH_ATTN", "FLASHINFER"])
@pytest.mark.forked
def test_logprobs_WITHOUT_batch_invariance_should_FAIL(backend):
"""
This test is the inverse of test_logprobs_bitwise_batch_invariance_bs1_vs_bsN.
It DISABLES batch invariance mode and expects to see non-deterministic behavior
between BS=1 and BS=N runs. This demonstrates that batch invariance is actually
doing something useful.
The test will PASS if we detect differences (proving batch invariance matters).
The test will FAIL if everything matches (suggesting batch invariance isn't needed).
"""
backend = os.getenv("VLLM_ATTENTION_BACKEND", backend)
os.environ["VLLM_ATTENTION_BACKEND"] = backend
# CRITICAL: Disable batch invariance for this test
old_value = os.environ.get("VLLM_KERNEL_OVERRIDE_BATCH_INVARIANT")
os.environ["VLLM_KERNEL_OVERRIDE_BATCH_INVARIANT"] = "0"
try:
seed = int(os.getenv("VLLM_TEST_SEED", "12345"))
random.seed(seed)
model_name = os.getenv("VLLM_TEST_MODEL", "Qwen/Qwen3-1.7B")
tp_size = int(os.getenv("VLLM_TEST_TP_SIZE", "1"))
print(f"\n{'=' * 80}")
print("BATCH INVARIANCE DISABLED: Expecting non-deterministic behavior")
print(f"{'=' * 80}\n")
llm = LLM(
model=model_name,
tensor_parallel_size=tp_size,
enable_prefix_caching=False,
max_num_seqs=32,
max_model_len=8192,
dtype="bfloat16",
)
# Use more realistic prompts for better token generation
prompts = [_random_prompt(10, 50) for i in range(32)]
sp = SamplingParams(
temperature=0.6,
top_p=1.0,
max_tokens=8,
seed=1234,
logprobs=5,
)
# BS=1: run prompts individually and collect logprobs per step.
print("\n" + "=" * 80)
print("STARTING BS=1 RUNS (each prompt individually)")
print("=" * 80 + "\n")
bs1_logprobs_per_prompt = []
bs1_tokens_per_prompt = []
for idx, p in enumerate(prompts):
print(
f"\n[BS=1] Running prompt {idx}/{len(prompts)} - Preview: {p[:80]}..."
)
outs = llm.generate([p], sp, use_tqdm=False)
assert len(outs) == 1
step_logprobs, token_ids = _extract_step_logprobs(outs[0])
if step_logprobs is None:
pytest.skip(
"Logits are not available on RequestOutput; "
"enable logprobs return to run this test."
)
bs1_logprobs_per_prompt.append(step_logprobs)
bs1_tokens_per_prompt.append(token_ids)
print(f"[BS=1] Prompt {idx} generated tokens: {token_ids}")
# BS=N: run prompts in a batch and collect logprobs per step for each prompt.
print("\n" + "=" * 80)
print(f"STARTING BS={len(prompts)} RUN (all prompts batched)")
print("=" * 80 + "\n")
outs_batched = llm.generate(prompts, sp, use_tqdm=False)
assert len(outs_batched) == len(prompts)
bsN_logprobs_per_prompt = []
bsN_tokens_per_prompt = []
print(f"\n[BS={len(prompts)}] Processing batched outputs...")
for idx, o in enumerate(outs_batched):
tokens = o.outputs[0].token_ids if o.outputs else "N/A"
print(f"[BS={len(prompts)}] Prompt {idx} generated tokens: {tokens}")
step_logprobs, token_ids = _extract_step_logprobs(o)
if step_logprobs is None:
pytest.skip(
"Logits are not available on RequestOutput; "
"enable logprobs return to run this test."
)
bsN_logprobs_per_prompt.append(step_logprobs)
bsN_tokens_per_prompt.append(token_ids)
# Compare step-by-step logprobs for each prompt between BS=1 and BS=N runs.
differences_found = []
for i, (logprobs_bs1, logprobs_bsN, tokens_bs1, tokens_bsN) in enumerate(
zip(
bs1_logprobs_per_prompt,
bsN_logprobs_per_prompt,
bs1_tokens_per_prompt,
bsN_tokens_per_prompt,
)
):
if len(logprobs_bs1) != len(logprobs_bsN):
reason = (
f"Different number of steps: {len(logprobs_bs1)} (BS=1) "
f"vs {len(logprobs_bsN)} (BS=N)"
)
differences_found.append(
{
"prompt_idx": i,
"step": "all",
"reason": reason,
"prompt_preview": prompts[i][:100],
"bs1_tokens": tokens_bs1,
"bsN_tokens": tokens_bsN,
}
)
continue
# Check if tokens match first
if tokens_bs1 != tokens_bsN:
differences_found.append(
{
"prompt_idx": i,
"step": "sampling",
"reason": "Different tokens sampled",
"prompt_preview": prompts[i][:100],
"bs1_tokens": tokens_bs1,
"bsN_tokens": tokens_bsN,
}
)
continue
for t, (a, b) in enumerate(zip(logprobs_bs1, logprobs_bsN)):
if a.shape != b.shape:
differences_found.append(
{
"prompt_idx": i,
"step": t,
"reason": f"Shape mismatch: {a.shape} vs {b.shape}",
"prompt_preview": prompts[i][:100],
"bs1_tokens": tokens_bs1,
"bsN_tokens": tokens_bsN,
}
)
break
if not torch.equal(a, b):
max_diff = torch.abs(a - b).max().item()
print(
f"\n[EXPECTED DIVERGENCE FOUND] Prompt {i}, "
f"Token {t}: max_diff={max_diff:.6e}"
)
bs1_tok = tokens_bs1[t] if t < len(tokens_bs1) else "N/A"
bsN_tok = tokens_bsN[t] if t < len(tokens_bsN) else "N/A"
print(f" Token IDs: bs1={bs1_tok}, bsN={bsN_tok}")
print(f" BS=1 logprob: {a.tolist()}")
print(f" BS=N logprob: {b.tolist()}")
differences_found.append(
{
"prompt_idx": i,
"step": t,
"reason": f"Bitwise mismatch (max_diff={max_diff:.6e})",
"prompt_preview": prompts[i][:100],
"bs1_tokens": tokens_bs1,
"bsN_tokens": tokens_bsN,
}
)
break
# Print summary
print(f"\n{'=' * 80}")
if differences_found:
success_msg = (
f"✓ SUCCESS: Batch invariance is doing something! "
f"Found {len(differences_found)}/{len(prompts)} prompts "
f"with differences when batch invariance was DISABLED."
)
print(success_msg)
print(f"{'=' * 80}")
for diff in differences_found:
print(f"\nPrompt {diff['prompt_idx']} (step {diff['step']}):")
print(f" Reason: {diff['reason']}")
print(f" Preview: {diff['prompt_preview']}...")
if "bs1_tokens" in diff:
print(f" BS=1 tokens: {diff['bs1_tokens']}")
if "bsN_tokens" in diff:
print(f" BS=N tokens: {diff['bsN_tokens']}")
print(f"{'=' * 80}\n")
# Test PASSES because we found differences (batch invariance matters!)
return
else:
# Test FAILS because everything matched even without batch invariance
fail_msg = (
f"✗ UNEXPECTED: All {len(prompts)} prompts matched "
f"between BS=1 and BS=N even with batch invariance DISABLED. "
f"This suggests batch invariance might not be necessary, "
f"or the test needs more sensitive prompts."
)
print(fail_msg)
print(f"{'=' * 80}\n")
pytest.fail(fail_msg)
finally:
# Restore original value
if old_value is None:
os.environ.pop("VLLM_KERNEL_OVERRIDE_BATCH_INVARIANT", None)
else:
os.environ["VLLM_KERNEL_OVERRIDE_BATCH_INVARIANT"] = old_value
@pytest.mark.skipif(
not current_platform.has_device_capability(90),
reason="Batch invariance tests only supported on Hopper (SM90)",
)
@pytest.mark.skipif(
not torch.cuda.is_available(),
reason="Requires CUDA to match production inference path.",
)
@pytest.mark.parametrize("backend", ["FLASH_ATTN"])
@pytest.mark.forked
def test_decode_logprobs_match_prefill_logprobs(backend):
"""
Test that verifies decode logprobs match prefill logprobs.
For each decoded token at position i:
1. Run decode to generate N tokens and collect their logprobs
2. For each position i in [0, N):
- Take prefix = prompt + tokens[0:i]
- Run prefill(prefix + tokens[i]) to get logprob of tokens[i]
- Verify prefill logprob matches decode logprob bitwise
This ensures that the logprobs from decode are consistent with what
we would get if we ran prefill on each prefix.
"""
backend = os.getenv("VLLM_ATTENTION_BACKEND", backend)
os.environ["VLLM_ATTENTION_BACKEND"] = backend
seed = int(os.getenv("VLLM_TEST_SEED", "12345"))
random.seed(seed)
model_name = os.getenv("VLLM_TEST_MODEL", "Qwen/Qwen3-1.7B")
tp_size = int(os.getenv("VLLM_TEST_TP_SIZE", "1"))
from vllm.model_executor.layers.batch_invariant import (
vllm_kernel_override_batch_invariant,
)
disable_custom_ar = vllm_kernel_override_batch_invariant()
if disable_custom_ar:
print(f"\n{'=' * 80}")
print(f"BATCH INVARIANCE MODE: Disabling custom all-reduce (TP={tp_size})")
print(f"{'=' * 80}\n")
llm = LLM(
model=model_name,
tensor_parallel_size=tp_size,
enable_prefix_caching=False,
max_num_seqs=32,
max_model_len=8192,
dtype="bfloat16",
)
# Use a few test prompts
num_test_prompts = int(os.getenv("VLLM_DECODE_PREFILL_NUM_PROMPTS", "4"))
prompts = [_random_prompt(10, 50) for _ in range(num_test_prompts)]
# Generate longer sequences to test multiple decode steps
max_tokens = int(os.getenv("VLLM_DECODE_PREFILL_MAX_TOKENS", "16"))
sp = SamplingParams(
temperature=0.0, # Greedy for determinism
max_tokens=max_tokens,
logprobs=5,
)
print("\n" + "=" * 80)
print("STEP 1: Running decode to generate tokens and collect logprobs")
print("=" * 80 + "\n")
# Step 1: Run decode and collect logprobs
decode_outputs = llm.generate(prompts, sp, use_tqdm=False)
failed_comparisons = []
for prompt_idx, (prompt, decode_output) in enumerate(zip(prompts, decode_outputs)):
print(f"\n[Prompt {prompt_idx}] Testing: {prompt[:80]}...")
# Extract decode logprobs and tokens
decode_logprobs, token_ids = _extract_step_logprobs(decode_output)
if decode_logprobs is None:
pytest.skip(
"Logprobs are not available on RequestOutput; "
"enable logprobs return to run this test."
)
print(f"[Prompt {prompt_idx}] Generated {len(token_ids)} tokens: {token_ids}")
print(f"[Prompt {prompt_idx}] Decode logprobs: {decode_logprobs.tolist()}")
# Step 2: For each token position, run prefill and compare
print(f"\n[Prompt {prompt_idx}] Verifying each token via prefill...")
for token_idx in range(len(token_ids)):
# Construct the prefix up to (but not including) this token
current_token = token_ids[token_idx]
# We need to detokenize to get the text prefix
# For this, we'll use the tokenizer from the LLM
# However, the LLM API doesn't expose tokenizer easily, so we'll
# construct the prefix by decoding from the original prompt
# Get text up to this point by using the output text
# This is approximate but should work for verification
if token_idx == 0:
prefix_prompt = prompt
else:
# Use the partial output text up to this token
# We'll need to construct this from the full output
prefix_output = decode_output.outputs[0]
# Get the text for tokens 0 to token_idx-1
# Unfortunately, we don't have per-token text, so we'll use
# a different approach: run prefill with prompt + tokens[0:token_idx]
# Actually, we need to get the actual text. Let's use a workaround:
# Run a generation with max_tokens = token_idx to get that prefix
prefix_sp = SamplingParams(
temperature=0.0,
max_tokens=token_idx,
logprobs=1,
)
prefix_output = llm.generate([prompt], prefix_sp, use_tqdm=False)[0]
prefix_prompt = prompt + prefix_output.outputs[0].text
# Now run prefill with max_tokens=1 to get the logprob of the next token
prefill_sp = SamplingParams(
temperature=0.0,
max_tokens=1,
logprobs=5,
)
print(
f" [Token {token_idx}] Running prefill for prefix "
f"(len={len(prefix_prompt)})..."
)
prefill_output = llm.generate([prefix_prompt], prefill_sp, use_tqdm=False)[
0
]
prefill_logprobs, prefill_token_ids = _extract_step_logprobs(prefill_output)
if prefill_logprobs is None:
print(f" [Token {token_idx}] Warning: No prefill logprobs available")
continue
# The first token from prefill should match the current token
prefill_token = prefill_token_ids[0]
prefill_logprob = prefill_logprobs[0].item()
decode_logprob = decode_logprobs[token_idx].item()
print(
f" [Token {token_idx}] Decode token: {current_token}, "
f"logprob: {decode_logprob:.8f}"
)
print(
f" [Token {token_idx}] Prefill token: {prefill_token}, "
f"logprob: {prefill_logprob:.8f}"
)
# Check if tokens match
if current_token != prefill_token:
failed_comparisons.append(
{
"prompt_idx": prompt_idx,
"token_idx": token_idx,
"reason": "Token mismatch",
"decode_token": current_token,
"prefill_token": prefill_token,
"decode_logprob": decode_logprob,
"prefill_logprob": prefill_logprob,
"prompt_text": prompt[:100],
"prefix_text": prefix_prompt[:100],
}
)
print(f" [Token {token_idx}] ✗ TOKEN MISMATCH!")
continue
# Check if logprobs match bitwise
if decode_logprob != prefill_logprob:
diff = abs(decode_logprob - prefill_logprob)
failed_comparisons.append(
{
"prompt_idx": prompt_idx,
"token_idx": token_idx,
"reason": "Logprob mismatch",
"decode_token": current_token,
"prefill_token": prefill_token,
"decode_logprob": decode_logprob,
"prefill_logprob": prefill_logprob,
"diff": diff,
"prompt_text": prompt[:100],
"prefix_text": prefix_prompt[:100],
"decode_all_tokens": token_ids,
"decode_all_logprobs": decode_logprobs.tolist(),
}
)
print(f" [Token {token_idx}] ✗ LOGPROB MISMATCH! diff={diff:.8e}")
else:
print(f" [Token {token_idx}] ✓ Match (bitwise equal)")
# Print summary
print(f"\n{'=' * 80}")
if failed_comparisons:
print(f"DECODE-PREFILL MISMATCH: {len(failed_comparisons)} failures detected")
print(f"{'=' * 80}")
# Group failures by prompt for better readability
failures_by_prompt: dict[int, list[dict]] = {}
for fail in failed_comparisons:
pid = fail["prompt_idx"]
if pid not in failures_by_prompt:
failures_by_prompt[pid] = []
failures_by_prompt[pid].append(fail)
for prompt_idx, failures in failures_by_prompt.items():
print(f"\n{'=' * 80}")
print(f"PROMPT {prompt_idx}: {failures[0]['prompt_text']}...")
print(f"{'=' * 80}")
print(f"Total failures for this prompt: {len(failures)}")
# Show where mismatches occur (which token positions)
mismatch_positions = [f["token_idx"] for f in failures]
print(f"Mismatch at token positions: {mismatch_positions}")
# Show first few failures in detail
for i, fail in enumerate(failures[:5]): # Show first 5 failures per prompt
print(f"\n [Failure {i + 1}] Token position {fail['token_idx']}:")
print(f" Reason: {fail['reason']}")
print(f" Prefix text: '{fail['prefix_text']}...'")
print(
f" Decode: token={fail['decode_token']}, "
f"logprob={fail['decode_logprob']:.10f}"
)
print(
f" Prefill: token={fail['prefill_token']}, "
f"logprob={fail['prefill_logprob']:.10f}"
)
if "diff" in fail:
print(f" Difference: {fail['diff']:.10e}")
# Show in hex to see bitwise difference
import struct
decode_hex = struct.pack("f", fail["decode_logprob"]).hex()
prefill_hex = struct.pack("f", fail["prefill_logprob"]).hex()
print(f" Decode logprob (hex): 0x{decode_hex}")
print(f" Prefill logprob (hex): 0x{prefill_hex}")
# If we have all tokens/logprobs, show the context
if "decode_all_tokens" in fail and "decode_all_logprobs" in fail:
token_idx = fail["token_idx"]
all_tokens = fail["decode_all_tokens"]
all_logprobs = fail["decode_all_logprobs"]
# Show context: 2 tokens before and after
start = max(0, token_idx - 2)
end = min(len(all_tokens), token_idx + 3)
print(f" Context (tokens {start} to {end - 1}):")
for j in range(start, end):
marker = " <-- MISMATCH" if j == token_idx else ""
print(
f" [{j}] token={all_tokens[j]}, "
f"logprob={all_logprobs[j]:.8f}{marker}"
)
if len(failures) > 5:
print(f"\n ... and {len(failures) - 5} more failures for this prompt")
print(f"\n{'=' * 80}\n")
pytest.fail(
f"Decode logprobs do not match prefill logprobs: "
f"{len(failed_comparisons)} mismatches found."
)
else:
print("✓ SUCCESS: All decode logprobs match prefill logprobs bitwise!")
print(f"{'=' * 80}\n")
def LLM_with_max_seqs(
@ -284,7 +988,6 @@ def LLM_with_max_seqs(
max_num_seqs: int,
gpu_memory_utilization: float,
max_model_len: int,
) -> LLM:
"""
Helper to construct an LLM with a specific max_num_seqs (batch-size limit)
@ -293,17 +996,10 @@ def LLM_with_max_seqs(
return LLM(
model=model,
max_num_seqs=max_num_seqs,
gpu_memory_utilization=gpu_memory_utilization,
max_model_len=max_model_len,
dtype="bfloat16",
tensor_parallel_size=int(os.getenv("VLLM_TP_SIZE", "1")),
trust_remote_code=os.getenv("VLLM_TRUST_REMOTE_CODE", "0") == "1",
enable_prefix_caching=False,
enforce_eager=True,
# Enable for MOE models
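For readers who want to exercise the behavior these tests cover outside of pytest, a minimal sketch follows. It relies only on the VLLM_KERNEL_OVERRIDE_BATCH_INVARIANT flag and the LLM/SamplingParams API used above; the model name and sampling values are illustrative, not prescribed by this PR.

import os

# Set before vLLM initializes so the batch-invariant kernel overrides are installed.
os.environ["VLLM_KERNEL_OVERRIDE_BATCH_INVARIANT"] = "1"

from vllm import LLM, SamplingParams

llm = LLM(
    model="Qwen/Qwen3-1.7B",       # illustrative; any supported model
    enforce_eager=True,             # batch-invariant mode forces this anyway
    enable_prefix_caching=False,
    dtype="bfloat16",
)
sp = SamplingParams(temperature=0.0, max_tokens=8, logprobs=5)

# With batch invariance enabled, the logprobs for this prompt should be
# bitwise identical whether it runs alone (BS=1) or inside a larger batch.
out = llm.generate(["The capital of France is"], sp, use_tqdm=False)
print(out[0].outputs[0].text)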

View File

@ -0,0 +1,346 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Test batch-invariant RMS normalization against standard implementations.
This test compares the Triton-based batch-invariant RMS norm implementation
with the standard CUDA-based implementation to ensure numerical accuracy.
"""
import pytest
import torch
from vllm.model_executor.layers.batch_invariant import rms_norm as triton_rms_norm
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.platforms import current_platform
@pytest.mark.skipif(
not current_platform.has_device_capability(90),
reason="Batch invariance tests only supported on Hopper (SM90)",
)
@pytest.mark.skipif(
not torch.cuda.is_available(), reason="Requires CUDA for RMS norm kernels"
)
@pytest.mark.parametrize("batch_size", [1, 4, 16, 64])
@pytest.mark.parametrize("hidden_size", [512, 2048, 4096, 8192])
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
@pytest.mark.parametrize("eps", [1e-6, 1e-5])
def test_rms_norm_batch_invariant_vs_standard(
batch_size: int, hidden_size: int, dtype: torch.dtype, eps: float
):
"""
Compare batch-invariant Triton RMS norm against standard CUDA implementation.
Tests that the Triton-based batch-invariant RMS norm produces numerically
equivalent results to the standard CUDA implementation across various
configurations.
"""
device = torch.device("cuda")
# Create test input and weight
torch.manual_seed(42)
input_tensor = torch.randn(batch_size, hidden_size, dtype=dtype, device=device)
weight = torch.randn(hidden_size, dtype=dtype, device=device)
# Standard implementation (CUDA ops)
rms_norm_layer = RMSNorm(hidden_size, eps=eps, dtype=dtype).to(device)
rms_norm_layer.weight.data = weight.clone()
standard_output = rms_norm_layer.forward_cuda(input_tensor)
# Batch-invariant implementation (Triton)
triton_output = triton_rms_norm(input_tensor, weight, eps=eps)
# Compare outputs
# Use looser tolerance for bfloat16 due to its lower precision
if dtype == torch.bfloat16:
rtol, atol = 1e-1, 1e-1 # 10% relative tolerance for bfloat16
else:
rtol, atol = 1e-2, 1e-2 # 1% for float16/float32
torch.testing.assert_close(
triton_output,
standard_output,
rtol=rtol,
atol=atol,
msg=f"RMS norm mismatch for batch_size={batch_size}, "
f"hidden_size={hidden_size}, "
f"dtype={dtype}, eps={eps}",
)
@pytest.mark.skipif(
not current_platform.has_device_capability(90),
reason="Batch invariance tests only supported on Hopper (SM90)",
)
@pytest.mark.skipif(
not torch.cuda.is_available(), reason="Requires CUDA for RMS norm kernels"
)
@pytest.mark.parametrize("batch_size", [1, 16, 128])
@pytest.mark.parametrize("seq_len", [1, 32, 512])
@pytest.mark.parametrize("hidden_size", [2048, 4096])
def test_rms_norm_3d_input(batch_size: int, seq_len: int, hidden_size: int):
"""
Test RMS norm with 3D input tensors (batch, seq_len, hidden_size).
Ensures that the batch-invariant RMS norm correctly handles multi-dimensional
inputs that are common in transformer models.
"""
device = torch.device("cuda")
dtype = torch.bfloat16
eps = 1e-6
torch.manual_seed(42)
input_tensor = torch.randn(
batch_size, seq_len, hidden_size, dtype=dtype, device=device
)
weight = torch.randn(hidden_size, dtype=dtype, device=device)
# Standard implementation
rms_norm_layer = RMSNorm(hidden_size, eps=eps, dtype=dtype).to(device)
rms_norm_layer.weight.data = weight.clone()
standard_output = rms_norm_layer.forward_cuda(input_tensor)
# Batch-invariant implementation
triton_output = triton_rms_norm(input_tensor, weight, eps=eps)
# Use looser tolerance for bfloat16
rtol, atol = 1e-1, 1e-1 # 10% tolerance for bfloat16
torch.testing.assert_close(
triton_output,
standard_output,
rtol=rtol,
atol=atol,
msg=f"RMS norm mismatch for 3D input with batch_size={batch_size}, "
f"seq_len={seq_len}, hidden_size={hidden_size}",
)
@pytest.mark.skipif(
not current_platform.has_device_capability(90),
reason="Batch invariance tests only supported on Hopper (SM90)",
)
@pytest.mark.skipif(
not torch.cuda.is_available(), reason="Requires CUDA for RMS norm kernels"
)
def test_rms_norm_numerical_stability():
"""
Test RMS norm numerical stability with extreme values.
Ensures that both implementations handle edge cases like very small or large
values without producing NaN or Inf.
"""
device = torch.device("cuda")
dtype = torch.float16
eps = 1e-6
hidden_size = 2048
# Test cases with extreme values
test_cases = [
# Very small values
torch.ones(4, hidden_size, dtype=dtype, device=device) * 1e-5,
# Very large values
torch.ones(4, hidden_size, dtype=dtype, device=device) * 1e4,
# Mixed small and large
torch.randn(4, hidden_size, dtype=dtype, device=device) * 100,
# Values near zero
torch.randn(4, hidden_size, dtype=dtype, device=device) * 1e-6,
]
weight = torch.ones(hidden_size, dtype=dtype, device=device)
for idx, input_tensor in enumerate(test_cases):
# Standard implementation
rms_norm_layer = RMSNorm(hidden_size, eps=eps, dtype=dtype).to(device)
rms_norm_layer.weight.data = weight.clone()
standard_output = rms_norm_layer.forward_cuda(input_tensor)
# Batch-invariant implementation
triton_output = triton_rms_norm(input_tensor, weight, eps=eps)
# Check for NaN or Inf
assert not torch.isnan(standard_output).any(), (
f"Standard RMS norm produced NaN for test case {idx}"
)
assert not torch.isinf(standard_output).any(), (
f"Standard RMS norm produced Inf for test case {idx}"
)
assert not torch.isnan(triton_output).any(), (
f"Triton RMS norm produced NaN for test case {idx}"
)
assert not torch.isinf(triton_output).any(), (
f"Triton RMS norm produced Inf for test case {idx}"
)
# Compare outputs - very lenient for extreme values with float16
torch.testing.assert_close(
triton_output,
standard_output,
rtol=2e-1, # 20% tolerance for extreme values
atol=2e-1,
msg=f"RMS norm mismatch for extreme value test case {idx}",
)
@pytest.mark.skipif(
not current_platform.has_device_capability(90),
reason="Batch invariance tests only supported on Hopper (SM90)",
)
@pytest.mark.skipif(
not torch.cuda.is_available(), reason="Requires CUDA for RMS norm kernels"
)
def test_rms_norm_formula():
"""
Test that RMS norm follows the correct mathematical formula.
Verifies: output = input / sqrt(mean(input^2) + eps) * weight
"""
device = torch.device("cuda")
dtype = torch.float32 # Use float32 for higher precision in formula check
eps = 1e-6
hidden_size = 1024
torch.manual_seed(42)
input_tensor = torch.randn(8, hidden_size, dtype=dtype, device=device)
weight = torch.randn(hidden_size, dtype=dtype, device=device)
# Compute expected output using the formula
variance = (input_tensor.pow(2).mean(dim=-1, keepdim=True)).to(dtype)
expected_output = input_tensor * torch.rsqrt(variance + eps) * weight
# Batch-invariant implementation
triton_output = triton_rms_norm(input_tensor, weight, eps=eps)
# Compare against formula
torch.testing.assert_close(
triton_output,
expected_output,
rtol=1e-4,
atol=1e-4,
msg="Triton RMS norm doesn't match expected formula",
)
@pytest.mark.skipif(
not current_platform.has_device_capability(90),
reason="Batch invariance tests only supported on Hopper (SM90)",
)
@pytest.mark.skipif(
not torch.cuda.is_available(), reason="Requires CUDA for RMS norm kernels"
)
@pytest.mark.parametrize("hidden_size", [128, 1024, 4096, 16384])
def test_rms_norm_different_hidden_sizes(hidden_size: int):
"""
Test RMS norm with various hidden sizes to ensure block size handling.
The Triton kernel uses a fixed BLOCK_SIZE=1024, so this tests that it
correctly handles hidden sizes both smaller and larger than the block size.
"""
device = torch.device("cuda")
dtype = torch.bfloat16
eps = 1e-6
batch_size = 16
torch.manual_seed(42)
input_tensor = torch.randn(batch_size, hidden_size, dtype=dtype, device=device)
weight = torch.randn(hidden_size, dtype=dtype, device=device)
# Standard implementation
rms_norm_layer = RMSNorm(hidden_size, eps=eps, dtype=dtype).to(device)
rms_norm_layer.weight.data = weight.clone()
standard_output = rms_norm_layer.forward_cuda(input_tensor)
# Batch-invariant implementation
triton_output = triton_rms_norm(input_tensor, weight, eps=eps)
# Use looser tolerance for bfloat16
rtol, atol = 1e-1, 1e-1 # 10% tolerance for bfloat16
torch.testing.assert_close(
triton_output,
standard_output,
rtol=rtol,
atol=atol,
msg=f"RMS norm mismatch for hidden_size={hidden_size}",
)
@pytest.mark.skipif(
not current_platform.has_device_capability(90),
reason="Batch invariance tests only supported on Hopper (SM90)",
)
@pytest.mark.skipif(
not torch.cuda.is_available(), reason="Requires CUDA for RMS norm kernels"
)
def test_rms_norm_determinism():
"""
Test that batch-invariant RMS norm produces deterministic results.
Runs the same input through the kernel multiple times and verifies
identical outputs.
"""
device = torch.device("cuda")
dtype = torch.bfloat16
eps = 1e-6
hidden_size = 4096
batch_size = 32
torch.manual_seed(42)
input_tensor = torch.randn(batch_size, hidden_size, dtype=dtype, device=device)
weight = torch.randn(hidden_size, dtype=dtype, device=device)
# Run multiple times
outputs = []
for _ in range(5):
output = triton_rms_norm(input_tensor.clone(), weight, eps=eps)
outputs.append(output)
# All outputs should be identical
reference = outputs[0]
for idx, output in enumerate(outputs[1:], start=1):
torch.testing.assert_close(
output,
reference,
rtol=0.0,
atol=0.0,
msg=f"RMS norm not deterministic: run {idx} differs from reference",
)
if __name__ == "__main__":
# Run a quick smoke test
print("Running quick smoke test of RMS norm implementations...")
device = torch.device("cuda")
batch_size = 8
hidden_size = 4096
dtype = torch.bfloat16
eps = 1e-6
torch.manual_seed(42)
input_tensor = torch.randn(batch_size, hidden_size, dtype=dtype, device=device)
weight = torch.randn(hidden_size, dtype=dtype, device=device)
# Standard implementation
rms_norm_layer = RMSNorm(hidden_size, eps=eps, dtype=dtype).to(device)
rms_norm_layer.weight.data = weight.clone()
standard_output = rms_norm_layer.forward_cuda(input_tensor)
# Batch-invariant implementation
triton_output = triton_rms_norm(input_tensor, weight, eps=eps)
# Compare
max_diff = (triton_output - standard_output).abs().max().item()
mean_diff = (triton_output - standard_output).abs().mean().item()
print(f"Max difference: {max_diff:.6e}")
print(f"Mean difference: {mean_diff:.6e}")
print(f"Standard output sample: {standard_output[0, :5].tolist()}")
print(f"Triton output sample: {triton_output[0, :5].tolist()}")
if max_diff < 1e-3:
print("✓ Smoke test passed!")
else:
print("✗ Smoke test failed - differences too large")
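As a cross-check on the formula exercised in test_rms_norm_formula above, here is a minimal eager-PyTorch reference that runs on CPU; it is a sketch for comparison only, not the Triton kernel under test.

import torch

def rms_norm_reference(x: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    # output = x / sqrt(mean(x^2) + eps) * weight, reduced over the last dimension.
    variance = x.float().pow(2).mean(dim=-1, keepdim=True)
    return (x.float() * torch.rsqrt(variance + eps)).to(x.dtype) * weight

x = torch.randn(8, 1024)
w = torch.randn(1024)
print(rms_norm_reference(x, w).shape)  # torch.Size([8, 1024])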

View File

@ -3,6 +3,7 @@
import hashlib
import inspect
import os
import pickle
from unittest.mock import patch
@ -168,7 +169,8 @@ def _compute_code_hash(files: set[str]) -> str:
)
file_contents = {}
for filepath in files:
# Skip files that don't exist (e.g., <string>, <frozen modules>, etc.)
if not os.path.isfile(filepath):
file_contents[filepath] = ""
else:
with open(filepath) as f:
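The hunk above guards the code-hash helper against pseudo-paths. A standalone sketch of the same pattern (the function name and hashing details here are illustrative, not the vLLM helper itself):

import hashlib
import os

def compute_code_hash(files: set[str]) -> str:
    # Hash real files; pseudo-paths like "<string>" or "<frozen ...>" hash as empty content.
    digest = hashlib.sha256()
    for filepath in sorted(files):
        contents = ""
        if os.path.isfile(filepath):
            with open(filepath) as f:
                contents = f.read()
        digest.update(filepath.encode())
        digest.update(contents.encode())
    return digest.hexdigest()

print(compute_code_hash({__file__, "<string>"}))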

View File

@ -20,6 +20,9 @@ from vllm.config.pooler import PoolerConfig
from vllm.config.scheduler import RunnerType
from vllm.config.utils import assert_hashable, config, getattr_iter
from vllm.logger import init_logger
from vllm.model_executor.layers.batch_invariant import (
vllm_kernel_override_batch_invariant,
)
from vllm.platforms import current_platform
from vllm.transformers_utils.config import (
ConfigFormat,
@ -419,6 +422,10 @@ class ModelConfig:
skip_mm_profiling: bool | None,
video_pruning_rate: float | None,
) -> None:
# Enable batch invariance settings if requested
if vllm_kernel_override_batch_invariant():
self.enforce_eager = True
# Set the default seed to 0 in V1.
# NOTE(woosuk): In V0, we set the default seed to None because the
# driver worker shares the same process as the user process, and thus
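The gate used throughout this change is the VLLM_KERNEL_OVERRIDE_BATCH_INVARIANT environment flag. A minimal stand-in for that check is sketched below (an assumption based on how the flag is toggled in the tests, not the exact vLLM implementation):

import os

def batch_invariant_requested() -> bool:
    # Mirrors the env flag toggled in the tests above.
    return os.environ.get("VLLM_KERNEL_OVERRIDE_BATCH_INVARIANT", "0") == "1"

# When the flag is set, this change forces determinism-friendly defaults,
# e.g. ModelConfig.enforce_eager = True here and, below,
# ParallelConfig.disable_custom_all_reduce = True.
print(batch_invariant_requested())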

View File

@ -14,6 +14,9 @@ from typing_extensions import Self
import vllm.envs as envs
from vllm.config.utils import config
from vllm.logger import init_logger
from vllm.model_executor.layers.batch_invariant import (
vllm_kernel_override_batch_invariant,
)
from vllm.platforms import current_platform
from vllm.utils import cuda_device_count_stateless, get_open_ports_list
@ -560,7 +563,10 @@ class ParallelConfig:
def _verify_args(self) -> Self:
# Lazy import to avoid circular import
from vllm.executor.executor_base import ExecutorBase
from vllm.platforms import current_platform
# Enable batch invariance settings if requested
if vllm_kernel_override_batch_invariant():
self.disable_custom_all_reduce = True
if (
self.distributed_executor_backend is not None

View File

@ -19,6 +19,9 @@ import torch.multiprocessing as mp
import vllm.envs as envs
from vllm.distributed.device_communicators.cuda_wrapper import CudaRTLibrary
from vllm.logger import init_logger
from vllm.model_executor.layers.batch_invariant import (
vllm_kernel_override_batch_invariant,
)
from vllm.utils import cuda_device_count_stateless, update_environment_variables
logger = init_logger(__name__)
@ -71,6 +74,9 @@ def should_nccl_symm_mem_allreduce(world_size: int, input_tensor: torch.Tensor)
is_symmetric_memory_enabled,
)
if vllm_kernel_override_batch_invariant():
return False
if not is_symmetric_memory_enabled():
return False
if world_size < NCCL_SYMM_MEM_ALL_REDUCE_CONFIG["min_world_size"]:

View File

@ -9,6 +9,9 @@ from vllm.distributed.device_communicators.all_reduce_utils import (
SYMM_MEM_ALL_REDUCE_MAX_SIZES,
)
from vllm.logger import init_logger
from vllm.model_executor.layers.batch_invariant import (
vllm_kernel_override_batch_invariant,
)
from vllm.platforms import current_platform
try:
@ -100,6 +103,8 @@ class SymmMemCommunicator:
return
self.force_multimem = force_multimem
self.disabled = False
if vllm_kernel_override_batch_invariant():
self.disabled = True
def should_use_symm_mem(self, inp: torch.Tensor):
if self.disabled:

View File

@ -1694,7 +1694,7 @@ class EngineArgs:
) -> None:
"""Set Default Arguments for V1 Engine."""
# V1 uses chunked prefills and prefix caching by default
# for non-pooling tasks.
# For pooling tasks the default is False
if model_config.runner_type != "pooling":

View File

@ -395,7 +395,6 @@ def mean_dim(
Tensor with mean values along specified dimension
"""
# Validate inputs
assert input.is_cuda, "Input must be a CUDA tensor"
assert -input.ndim <= dim < input.ndim, (
f"Invalid dimension {dim} for tensor with {input.ndim} dimensions"
)
@ -470,6 +469,45 @@ def mm_batch_invariant(a, b):
return matmul_persistent(a, b)
def matmul_batch_invariant(a, b, *, out=None):
# torch.matmul can handle various dimensions
# For 2D x 2D, it's the same as mm
if a.ndim == 2 and b.ndim == 2:
result = matmul_persistent(a, b)
if out is not None:
out.copy_(result)
return out
return result
elif a.ndim == 3 and b.ndim == 3:
# Handle batched case like bmm
return bmm_batch_invariant(a, b, out=out)
else:
raise ValueError(
f"matmul_batch_invariant currently only supports 2D x 2D and 3D x 3D, "
f"got shapes {a.shape} and {b.shape}"
)
def bmm_batch_invariant(a, b, *, out=None):
# Batched matrix multiply: (B, M, K) x (B, K, N) -> (B, M, N)
# Process each batch separately with our persistent kernel
if a.ndim == 3 and b.ndim == 3:
results = []
for i in range(a.shape[0]):
results.append(matmul_persistent(a[i], b[i]))
result = torch.stack(results, dim=0)
if out is not None:
out.copy_(result)
return out
return result
else:
raise ValueError(
f"bmm_batch_invariant expects 3D tensors, "
f"got shapes {a.shape} and {b.shape}"
)
def addmm_batch_invariant(bias, a, b):
return matmul_persistent(a, b, bias=bias)
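The point of bmm_batch_invariant is that every batch element goes through the same 2D kernel regardless of batch size, so a sequence is computed identically whether it is batched with others or not. A small CPU sketch of that property, with a plain per-slice matmul standing in for matmul_persistent:

import torch

def bmm_per_slice(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    # (B, M, K) x (B, K, N) -> (B, M, N), one 2D matmul per batch element.
    return torch.stack([a[i] @ b[i] for i in range(a.shape[0])], dim=0)

a = torch.randn(4, 8, 16)
b = torch.randn(4, 16, 32)
full = bmm_per_slice(a, b)
alone = bmm_per_slice(a[:1], b[:1])
# Batch invariance goal: element 0 is identical whether computed in a batch of 4 or alone.
print(torch.equal(full[:1], alone))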
@ -479,11 +517,24 @@ def _log_softmax_batch_invariant(input, dim, _half_to_float):
return log_softmax(input, dim=dim)
def softmax_batch_invariant(input, dim, dtype=None):
# Compute softmax in a deterministic way
# First subtract max for numerical stability (standard practice)
input_max = torch.amax(input, dim=dim, keepdim=True)
input = input - input_max
exp_x = torch.exp(input)
sum_exp_x = torch.sum(exp_x, dim=dim, keepdim=True)
return exp_x / sum_exp_x
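A quick numerical sanity check that this explicit max-subtracted softmax agrees with torch.softmax (a CPU sketch, not part of the dispatched kernel path):

import torch

def softmax_explicit(x: torch.Tensor, dim: int) -> torch.Tensor:
    x = x - torch.amax(x, dim=dim, keepdim=True)  # shift for numerical stability
    e = torch.exp(x)
    return e / e.sum(dim=dim, keepdim=True)

x = torch.randn(4, 128)
print(torch.allclose(softmax_explicit(x, -1), torch.softmax(x, dim=-1), atol=1e-6))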
def mean_batch_invariant(input, dim, keepdim=False, dtype: torch.dtype | None = None):
assert dtype is None or dtype == torch.float32, f"unsupported dtype: {dtype}"
result = input.to(torch.float32)
if len(dim) == 0:
dim = [i for i in range(len(input.shape))]
# Sort dimensions to reduce from largest to smallest to handle shifting dims
# during iterative reduction.
sorted_dims = sorted([d % input.ndim for d in dim], reverse=True)
@ -500,8 +551,134 @@ def mean_batch_invariant(input, dim, keepdim=False, dtype: torch.dtype | None =
return result
@triton.jit
def _rms_norm_kernel(
input_ptr,
weight_ptr,
output_ptr,
input_row_stride,
output_row_stride,
n_cols,
eps,
BLOCK_SIZE: tl.constexpr,
):
"""
Compute RMS normalization along the last dimension of a 2D tensor.
RMS Norm: y = x / sqrt(mean(x^2) + eps) * weight
Each block handles one row of the input tensor.
"""
row_idx = tl.program_id(0).to(tl.int64)
row_start_ptr = input_ptr + row_idx * input_row_stride
output_row_start_ptr = output_ptr + row_idx * output_row_stride
# Step 1: Compute sum of squares in float32 to avoid overflow
sum_sq = tl.zeros([1], dtype=tl.float32)
for col_offset in range(0, n_cols, BLOCK_SIZE):
col_idx = col_offset + tl.arange(0, BLOCK_SIZE)
mask = col_idx < n_cols
vals = tl.load(row_start_ptr + col_idx, mask=mask, other=0.0)
# Convert to float32 for accumulation to prevent overflow
vals_f32 = vals.to(tl.float32)
sq_vals = vals_f32 * vals_f32
sum_sq += tl.sum(tl.where(mask, sq_vals, 0.0))
# Step 2: Compute RMS (root mean square) in float32
mean_sq = sum_sq / n_cols
rms = tl.sqrt(mean_sq + eps)
inv_rms = 1.0 / rms
# Step 3: Normalize and apply weight
for col_offset in range(0, n_cols, BLOCK_SIZE):
col_idx = col_offset + tl.arange(0, BLOCK_SIZE)
mask = col_idx < n_cols
vals = tl.load(row_start_ptr + col_idx, mask=mask, other=0.0)
weight = tl.load(weight_ptr + col_idx, mask=mask, other=1.0)
# Compute in float32 then convert back to input dtype
vals_f32 = vals.to(tl.float32)
weight_f32 = weight.to(tl.float32)
output_f32 = vals_f32 * inv_rms * weight_f32
output = output_f32.to(vals.dtype)
tl.store(output_row_start_ptr + col_idx, output, mask=mask)
def rms_norm(
input: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6
) -> torch.Tensor:
"""
Compute RMS normalization using Triton kernel.
RMS Norm normalizes the input by the root mean square and scales by weight:
output = input / sqrt(mean(input^2) + eps) * weight
Args:
input: Input tensor of shape (..., hidden_size)
weight: Weight tensor of shape (hidden_size,)
eps: Small constant for numerical stability
Returns:
Tensor with RMS normalization applied along the last dimension
"""
assert weight.dim() == 1, "Weight must be 1-dimensional"
assert input.shape[-1] == weight.shape[0], (
f"Input last dimension ({input.shape[-1]}) must match "
f"weight dimension ({weight.shape[0]})"
)
# Flatten all dimensions except the last one
original_shape = input.shape
input_2d = input.reshape(-1, input.shape[-1])
input_2d = input_2d.contiguous()
weight = weight.contiguous()
n_rows, n_cols = input_2d.shape
output = torch.empty_like(input_2d)
BLOCK_SIZE = 1024
grid = (n_rows,)
_rms_norm_kernel[grid](
input_2d,
weight,
output,
input_2d.stride(0),
output.stride(0),
n_cols,
eps,
BLOCK_SIZE=BLOCK_SIZE,
)
return output.reshape(original_shape)
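Usage sketch for the wrapper above: leading dimensions are flattened internally, so 2D and 3D inputs behave the same. This assumes the PR is applied and a CUDA device is available, since the kernel is Triton.

import torch
from vllm.model_executor.layers.batch_invariant import rms_norm

x = torch.randn(2, 16, 4096, dtype=torch.bfloat16, device="cuda")
w = torch.randn(4096, dtype=torch.bfloat16, device="cuda")
y = rms_norm(x, w, eps=1e-6)
print(y.shape)  # torch.Size([2, 16, 4096])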
def rms_norm_batch_invariant(
input: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6
) -> torch.Tensor:
"""
Batch-invariant wrapper for RMS normalization.
This function provides a deterministic, batch-invariant implementation
of RMS normalization for use with the batch_invariant mode.
Args:
input: Input tensor of shape (..., hidden_size)
weight: Weight tensor of shape (hidden_size,)
eps: Small constant for numerical stability
Returns:
RMS normalized tensor
"""
return rms_norm(input, weight, eps=eps)
def linear_batch_invariant(input, weight, bias=None):
output = mm_batch_invariant(input, weight.t())
if bias is not None:
output = output + bias
return output
_batch_invariant_MODE = False
_batch_invariant_LIB = None
_original_torch_bmm = None
def is_batch_invariant_mode_enabled():
@ -509,7 +686,7 @@ def is_batch_invariant_mode_enabled():
def enable_batch_invariant_mode():
global _batch_invariant_MODE, _batch_invariant_LIB, _original_torch_bmm
if _batch_invariant_MODE:
return
@ -517,16 +694,28 @@ def enable_batch_invariant_mode():
_batch_invariant_LIB = torch.library.Library("aten", "IMPL")
_batch_invariant_LIB.impl("aten::mm", mm_batch_invariant, "CUDA")
_batch_invariant_LIB.impl("aten::addmm", addmm_batch_invariant, "CUDA")
_batch_invariant_LIB.impl("aten::matmul", matmul_batch_invariant, "CUDA")
_batch_invariant_LIB.impl("aten::bmm", bmm_batch_invariant, "CUDA")
_batch_invariant_LIB.impl("aten::linear", linear_batch_invariant, "CUDA")
_batch_invariant_LIB.impl(
"aten::_log_softmax", _log_softmax_batch_invariant, "CUDA"
)
_batch_invariant_LIB.impl("aten::softmax", softmax_batch_invariant, "CUDA")
_batch_invariant_LIB.impl("aten::_softmax", softmax_batch_invariant, "CUDA")
_batch_invariant_LIB.impl("aten::mean.dim", mean_batch_invariant, "CUDA")
# Also monkeypatch torch.bmm directly as a fallback
_original_torch_bmm = torch.bmm
torch.bmm = bmm_batch_invariant
def disable_batch_invariant_mode():
global _batch_invariant_MODE, _batch_invariant_LIB, _original_torch_bmm
if _batch_invariant_LIB is not None:
_batch_invariant_LIB._destroy()
if _original_torch_bmm is not None:
torch.bmm = _original_torch_bmm
_original_torch_bmm = None
_batch_invariant_MODE = False
_batch_invariant_LIB = None
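A hedged usage sketch of the enable/disable pair above, wrapped in try/finally so the aten overrides and the torch.bmm monkeypatch are always restored; the context-manager helper here is illustrative, not part of the PR.

import contextlib
import torch
from vllm.model_executor.layers.batch_invariant import (
    disable_batch_invariant_mode,
    enable_batch_invariant_mode,
)

@contextlib.contextmanager
def batch_invariant_kernels():
    enable_batch_invariant_mode()       # registers CUDA impls for mm/addmm/bmm/softmax/...
    try:
        yield
    finally:
        disable_batch_invariant_mode()  # restores the aten library and torch.bmm

with batch_invariant_kernels():
    a = torch.randn(64, 64, device="cuda")
    b = torch.randn(64, 64, device="cuda")
    c = a @ b  # routed through the batch-invariant matmul override
print(c.shape)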
@ -563,17 +752,55 @@ def vllm_kernel_override_batch_invariant():
return is_overridden
def override_envs_for_invariance():
curr_attn_backend = envs.VLLM_ATTENTION_BACKEND
supported_backends = [
"FLASH_ATTN", # best supported backend
"FLEX_ATTENTION",
"FLASHINFER",
"FLASH_ATTN_MLA",
"TRITON_MLA",
# Not yet supported MLA backends
# "FLASHMLA",
# "FLASHINFER_MLA",
]
if curr_attn_backend not in supported_backends:
warning = (
"Forcibly updating attention backend to"
f" {supported_backends[0]} for batch_invariant. "
f" Supported backends: {supported_backends}."
)
logger.warning_once(warning)
os.environ["VLLM_ATTENTION_BACKEND"] = supported_backends[0]
if os.environ["VLLM_ATTENTION_BACKEND"] != supported_backends[0]:
warning = (
"You are using a decode-invariant form of batch invariance. "
"This will not be invariant between prefill and decode."
)
logger.warning_once(warning)
os.environ["VLLM_ALLREDUCE_USE_SYMM_MEM"] = "0"
os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"
# NCCL determinism settings
os.environ["NCCL_LAUNCH_MODE"] = "GROUP"
os.environ["NCCL_COLLNET_ENABLE"] = "0"
os.environ["NCCL_NVLS_ENABLE"] = "0"
os.environ["NCCL_P2P_NET_DISABLE"] = "1"
os.environ["NCCL_MIN_NCHANNELS"] = "1"
os.environ["NCCL_MAX_NCHANNELS"] = "1"
os.environ["NCCL_PROTO"] = "Simple"
os.environ["NCCL_ALGO"] = "allreduce:tree"
os.environ["NCCL_NTHREADS"] = "1"
os.environ["NCCL_SOCKET_NTHREADS"] = "1"
def init_batch_invariance():
# this will hit all the csrc overrides as well
if vllm_kernel_override_batch_invariant():
override_envs_for_invariance()
enable_batch_invariant_mode()
# Disable TF32 for batch invariance - it causes non-deterministic rounding
torch.backends.cuda.matmul.allow_tf32 = False
torch.backends.cudnn.allow_tf32 = False
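A small sketch of how the init path is expected to behave when the flag is set; the environment variable must be set before vLLM initializes, and the assertions simply restate what init_batch_invariance configures above.

import os
os.environ["VLLM_KERNEL_OVERRIDE_BATCH_INVARIANT"] = "1"

import torch
from vllm.model_executor.layers.batch_invariant import init_batch_invariance

init_batch_invariance()

# TF32 is disabled and a supported attention backend has been forced.
assert torch.backends.cuda.matmul.allow_tf32 is False
assert torch.backends.cudnn.allow_tf32 is False
print(os.environ.get("VLLM_ATTENTION_BACKEND"))  # e.g. FLASH_ATTN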

View File

@ -15,6 +15,9 @@ import vllm.envs as envs
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from vllm import _custom_ops as ops
from vllm.logger import init_logger
from vllm.model_executor.layers.batch_invariant import (
vllm_kernel_override_batch_invariant,
)
from vllm.model_executor.layers.fused_moe.config import (
FUSED_MOE_UNQUANTIZED_CONFIG,
FusedMoEQuantConfig,
@ -837,6 +840,10 @@ def get_moe_configs(
be picked and the associated configuration chosen to invoke the kernel.
"""
# Avoid optimizing for the batch invariant case. Use default config
if vllm_kernel_override_batch_invariant():
return None
# First look up if an optimized configuration is available in the configs
# directory
block_shape = [block_n, block_k] if block_n and block_k else None
@ -969,6 +976,15 @@ def get_default_config(
dtype: str | None,
block_shape: list[int] | None = None,
) -> dict[str, int]:
if vllm_kernel_override_batch_invariant():
config = {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 32,
"GROUP_SIZE_M": 8,
}
return config
if dtype == "fp8_w8a8" and block_shape is not None:
# Block-wise quant: BLOCK_SIZE_N must be divisible by block_shape[0]
# BLOCK_SIZE_K must be divisible by block_shape[1]
@ -1118,7 +1134,10 @@ def fused_topk_bias(
scores_for_choice = scores.view(
-1, n_routed_experts
) + e_score_correction_bias.unsqueeze(0)
# For batch invariance, use sorted=True to ensure deterministic expert selection
use_sorted = vllm_kernel_override_batch_invariant()
topk_indices = torch.topk(scores_for_choice, k=topk, dim=-1, sorted=use_sorted)[1]
topk_weights = scores.gather(1, topk_indices)
if renormalize:
topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
@ -1179,7 +1198,10 @@ def grouped_topk(
group_scores = (
scores.view(num_token, num_expert_group, -1).max(dim=-1).values
) # [n, n_group]
# For batch invariance, use sorted=True to ensure deterministic expert selection
use_sorted = vllm_kernel_override_batch_invariant()
group_idx = torch.topk(group_scores, k=topk_group, dim=-1, sorted=use_sorted)[
1
] # [n, top_k_group]
group_mask = torch.zeros_like(group_scores) # [n, n_group]
@ -1192,11 +1214,13 @@ def grouped_topk(
tmp_scores = scores.masked_fill(~score_mask.bool(), float("-inf")) # [n, e]
if e_score_correction_bias is not None:
topk_ids = torch.topk(tmp_scores, k=topk, dim=-1, sorted=use_sorted)[1]
# Use original unbiased scores for the routing weights
topk_weights = original_scores.gather(1, topk_ids)
else:
topk_weights, topk_ids = torch.topk(
tmp_scores, k=topk, dim=-1, sorted=use_sorted
)
if renormalize:
topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
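Why sorted=True matters here: with sorted=False, torch.topk does not guarantee the order of the returned indices, so tied expert scores can come back in a batch-shape-dependent order and perturb the downstream gather/renormalize. A toy CPU illustration of the difference in guarantees (the unsorted ordering depends on the backend, which is exactly the point):

import torch

scores = torch.tensor([[0.5, 0.9, 0.5, 0.9, 0.1]])

# sorted=True: indices are returned in descending-score order, a single canonical order.
print(torch.topk(scores, k=2, dim=-1, sorted=True).indices)

# sorted=False: the selected set is the same, but the ordering is unspecified,
# which is enough to make expert routing differ between batch shapes.
print(torch.topk(scores, k=2, dim=-1, sorted=False).indices)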

View File

@ -8,6 +8,10 @@ import torch.nn.functional as F
import vllm.envs as envs
from vllm.model_executor.custom_op import CustomOp
from vllm.model_executor.layers.batch_invariant import (
rms_norm_batch_invariant,
vllm_kernel_override_batch_invariant,
)
from vllm.platforms import current_platform
from vllm.utils import direct_register_custom_op
@ -21,6 +25,8 @@ def rms_norm(
) -> torch.Tensor:
from vllm import _custom_ops as ops
if vllm_kernel_override_batch_invariant():
return rms_norm_batch_invariant(x, weight, variance_epsilon)
out = torch.empty_like(x)
ops.rms_norm(
out,
@ -39,6 +45,10 @@ def fused_add_rms_norm(
) -> tuple[torch.Tensor, torch.Tensor]:
from vllm import _custom_ops as ops
if vllm_kernel_override_batch_invariant():
return rms_norm_batch_invariant(
x + residual, weight, variance_epsilon
), x + residual
ops.fused_add_rms_norm(
x,
residual,

View File

@ -160,6 +160,7 @@ class MultiHeadLatentAttentionWrapper(CustomOp):
k_pe,
output_shape=(hidden_states.shape[0], self.num_heads * self.v_head_dim),
)
return self.o_proj(attn_out)[0]
def forward_cuda(self, *args, **kwargs):

View File

@ -14,6 +14,9 @@ import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from vllm import _custom_ops as ops
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.logger import init_logger
from vllm.model_executor.layers.batch_invariant import (
vllm_kernel_override_batch_invariant,
)
from vllm.model_executor.layers.fused_moe import (
FusedMoE,
FusedMoEActivationFormat,
@ -353,6 +356,8 @@ class Fp8LinearMethod(LinearMethodBase):
# Disable marlin for rocm
if current_platform.is_rocm():
self.use_marlin = False
if vllm_kernel_override_batch_invariant():
self.use_marlin = False
self.use_aiter_and_is_supported = check_aiter_fp8_linear_support()
@ -534,6 +539,66 @@ class Fp8LinearMethod(LinearMethodBase):
x: torch.Tensor,
bias: torch.Tensor | None = None,
) -> torch.Tensor:
# If batch invariant mode is enabled, dequantize and use BF16 compute
if vllm_kernel_override_batch_invariant():
# Dequantize FP8 weights to BF16
weight_fp8 = layer.weight.to(torch.bfloat16)
weight_scale = layer.weight_scale.to(torch.bfloat16)
# Handle different quantization granularities
if self.block_quant:
# Block-wise quantization:
# - Weight is NOT transposed, shape is [N, K] (output_size, input_size)
# - Scale has shape [num_blocks_k, num_blocks_n] (TRANSPOSED!)
assert self.weight_block_size is not None
block_n, block_k = self.weight_block_size # Note: order is [N, K]
N, K = weight_fp8.shape
# Scale is stored transposed: [num_blocks_k, num_blocks_n]
# We need to transpose it to [num_blocks_n, num_blocks_k] first
weight_scale = weight_scale.t()
# Expand scale to match weight dimensions
# scale_expanded should have shape [N, K]
scale_expanded = weight_scale.repeat_interleave(
block_n, dim=0
).repeat_interleave(block_k, dim=1)
# Trim to exact weight size (in case of padding)
scale_expanded = scale_expanded[:N, :K]
weight_bf16 = weight_fp8 * scale_expanded
else:
# Per-tensor quantization: weight IS transposed to [K, N]
# scale should be scalar or [1] or per-output-channel [N]
if weight_scale.numel() == 1:
# Per-tensor: simple scalar multiplication
weight_bf16 = weight_fp8 * weight_scale
else:
# Multiple scales (fused modules such as QKV):
# weight is [K, N] and the scale may have one entry per logical
# sub-weight, so infer the broadcast that matches; if none does,
# fall back to direct multiplication.
if (
weight_scale.dim() == 1
and weight_scale.shape[0] == weight_fp8.shape[0]
):
# Per-row scaling
weight_bf16 = weight_fp8 * weight_scale.unsqueeze(1)
else:
# Fallback
weight_bf16 = weight_fp8 * weight_scale
# For block quant, weight is [N, K], for per-tensor it's [K, N]
# F.linear expects weight to be [N, K], so:
if self.block_quant:
# Already in correct shape [N, K]
output = torch.nn.functional.linear(x, weight_bf16, bias)
else:
# Need to transpose back: [K, N] -> [N, K]
output = torch.nn.functional.linear(x, weight_bf16.t(), bias)
return output
if self.use_marlin:
return apply_fp8_marlin_linear(
input=x,

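A toy shape check of the block-wise scale expansion in the hunk above (all sizes invented; a bf16 tensor stands in for the FP8 weight):

import torch

N, K, block_n, block_k = 256, 512, 128, 128
weight_q = torch.randn(N, K, dtype=torch.bfloat16)        # stand-in for the FP8 weight, [N, K]
scale = torch.rand(K // block_k, N // block_n,
                   dtype=torch.bfloat16)                  # stored transposed: [num_blocks_k, num_blocks_n]
scale = scale.t()                                         # -> [num_blocks_n, num_blocks_k]
scale_expanded = scale.repeat_interleave(block_n, dim=0)  # expand each block scale to full rows
scale_expanded = scale_expanded.repeat_interleave(block_k, dim=1)[:N, :K]
dequant = weight_q * scale_expanded                       # one scale per block_n x block_k tile
assert dequant.shape == (N, K)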
View File

@ -216,6 +216,7 @@ class TransformerBlock(torch.nn.Module):
else:
hidden_states, residual = self.input_layernorm(hidden_states, residual)
hidden_states = self.attn(hidden_states, positions)
# Fully Connected
hidden_states, residual = self.post_attention_layernorm(hidden_states, residual)
output = self.mlp(hidden_states)

View File

@ -31,10 +31,12 @@ if is_flash_attn_varlen_func_available():
get_scheduler_metadata,
reshape_and_cache_flash,
)
from vllm.config import VllmConfig, get_layers_from_vllm_config
from vllm.distributed.parallel_state import get_dcp_group
from vllm.logger import init_logger
from vllm.model_executor.layers.batch_invariant import (
vllm_kernel_override_batch_invariant,
)
from vllm.utils import cdiv
from vllm.v1.attention.backends.utils import (
AttentionCGSupport,
@ -306,6 +308,9 @@ class FlashAttentionMetadataBuilder(AttentionMetadataBuilder[FlashAttentionMetad
# we only set num_splits when using cuda graphs.
max_num_splits = self.max_num_splits
if vllm_kernel_override_batch_invariant():
max_num_splits = 1
def schedule(
batch_size, cu_query_lens, max_query_len, seqlens, max_seq_len, causal
):
@ -478,6 +483,9 @@ class FlashAttentionImpl(AttentionImpl):
self.attn_type = attn_type
self.vllm_flash_attn_version = get_flash_attn_version()
# Cache the batch invariant result for use in forward passes
self.batch_invariant_enabled = vllm_kernel_override_batch_invariant()
if is_quantized_kv_cache(self.kv_cache_dtype) and not flash_attn_supports_fp8():
raise NotImplementedError(
"FlashAttention does not support fp8 kv-cache on this device."
@ -810,6 +818,7 @@ class FlashAttentionImpl(AttentionImpl):
q_descale=layer._q_scale.expand(descale_shape),
k_descale=layer._k_scale.expand(descale_shape),
v_descale=layer._v_scale.expand(descale_shape),
num_splits=1 if self.batch_invariant_enabled else 0,
)
return output
@ -954,6 +963,7 @@ def cascade_attention(
# s_aux is incorporated into prefix_lse inside the GPU kernel,
# enabling its effect during the final attention merge.
s_aux=s_aux,
num_splits=1 if vllm_kernel_override_batch_invariant() else 0,
)
descale_shape = (cu_query_lens.shape[0] - 1, key_cache.shape[-2])
@ -978,6 +988,7 @@ def cascade_attention(
q_descale=q_descale.expand(descale_shape) if q_descale is not None else None,
k_descale=k_descale.expand(descale_shape) if k_descale is not None else None,
v_descale=v_descale.expand(descale_shape) if v_descale is not None else None,
num_splits=1 if vllm_kernel_override_batch_invariant() else 0,
)
# Merge prefix and suffix outputs, and store the result in output.
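All the num_splits / max_num_splits overrides above address the same issue: split-KV attention merges partial results in an order that depends on the split count, and floating-point addition is not associative, so the merged value can change with how the scheduler splits the work. A quick CPU-only illustration:

import torch

torch.manual_seed(0)
x = torch.randn(1 << 16)
full = x.sum()                                              # one reduction order
split = torch.stack([c.sum() for c in x.chunk(4)]).sum()    # a different split/merge order
print(torch.equal(full, split))                             # often False, differing in the last ulp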

View File

@ -211,6 +211,9 @@ from vllm.attention.utils.fa_utils import get_flash_attn_version
from vllm.config import VllmConfig, get_current_vllm_config
from vllm.distributed.parallel_state import get_dcp_group, is_global_first_rank
from vllm.logger import init_logger
from vllm.model_executor.layers.batch_invariant import (
vllm_kernel_override_batch_invariant,
)
from vllm.model_executor.layers.linear import (
ColumnParallelLinear,
LinearBase,
@ -1187,6 +1190,7 @@ class MLACommonBaseImpl(MLAAttentionImpl[A], Generic[A]):
def _v_up_proj(self, x: torch.Tensor, out: torch.Tensor):
# Convert from (B, N, L) to (N, B, L)
x = x.view(-1, self.num_heads, self.kv_lora_rank).transpose(0, 1)
if is_rocm_aiter_fp8bmm_enabled():
# Multiply + Transpose (N, B, L) x (N, L, V)->(N, B, V)->(B, N, V)
x = aiter_triton_fp8_bmm(
@ -1279,6 +1283,8 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]):
# ROCm leverages the upstream flash_attn, which takes a parameter
# called "return_attn_probs" instead of return_softmax_lse
kwargs["return_attn_probs"] = return_softmax_lse
if vllm_kernel_override_batch_invariant():
kwargs["num_splits"] = 1
attn_out = self.flash_attn_varlen_func(
q=q,
@ -1841,9 +1847,11 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]):
if has_decode:
assert attn_metadata.decode is not None
decode_q_nope, decode_q_pe = decode_q.split(
[self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1
)
# Convert from (B, N, P) to (N, B, P)
decode_q_nope = decode_q_nope.transpose(0, 1)
@ -1868,17 +1876,18 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]):
# Pads the head_dim if necessary (for the underlying kernel)
N, B, P = decode_q_nope.shape
_, _, L = self.W_UK_T.shape
if self.q_pad_num_heads is not None:
decode_ql_nope = decode_q_nope.new_empty(
(self.q_pad_num_heads, B, L)
)
decode_ql_nope.resize_((N, B, L))
else:
decode_ql_nope = decode_q_nope.new_empty((N, B, L))
# Multiply (N, B, P) x (N, P, L) -> (N, B, L)
torch.bmm(decode_q_nope, self.W_UK_T, out=decode_ql_nope)
# Convert from (N, B, L) to (B, N, L)
decode_ql_nope = decode_ql_nope.transpose(0, 1)
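A shape walk-through of the decode q_nope projection above with toy sizes (the new_empty + resize_ padding only reserves a larger allocation for the kernel; the logical shape stays (N, B, L)):

import torch

N, B, P, L = 16, 2, 128, 512      # heads, decode tokens, qk_nope_head_dim, kv_lora_rank
decode_q_nope = torch.randn(N, B, P)
W_UK_T = torch.randn(N, P, L)
decode_ql_nope = torch.bmm(decode_q_nope, W_UK_T)   # (N, B, P) x (N, P, L) -> (N, B, L)
decode_ql_nope = decode_ql_nope.transpose(0, 1)     # -> (B, N, L)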

View File

@ -18,6 +18,9 @@ from vllm.attention.utils.fa_utils import (
)
from vllm.config import VllmConfig
from vllm.logger import init_logger
from vllm.model_executor.layers.batch_invariant import (
vllm_kernel_override_batch_invariant,
)
from vllm.v1.attention.backends.mla.common import (
MLACommonBackend,
MLACommonDecodeMetadata,
@ -107,6 +110,9 @@ class FlashAttnMLAMetadataBuilder(MLACommonMetadataBuilder[FlashAttnMLAMetadata]
# pre-allocated during capture.
self.max_num_splits = envs.VLLM_FLASH_ATTN_MAX_NUM_SPLITS_FOR_CUDA_GRAPH
if vllm_kernel_override_batch_invariant():
self.max_num_splits = 1
def _schedule_decode(
self, num_reqs, cu_query_lens, max_query_len, seqlens, max_seq_len, causal
):
@ -175,7 +181,10 @@ class FlashAttnMLAMetadataBuilder(MLACommonMetadataBuilder[FlashAttnMLAMetadata]
# we only set num_splits when using cuda graphs.
max_num_splits = self.max_num_splits
return FlashAttnMLADecodeMetadata(
if vllm_kernel_override_batch_invariant():
max_num_splits = 1
metadata = FlashAttnMLADecodeMetadata(
block_table=block_table_tensor,
seq_lens=seq_lens_device,
query_start_loc=query_start_loc_device,
@ -185,6 +194,7 @@ class FlashAttnMLAMetadataBuilder(MLACommonMetadataBuilder[FlashAttnMLAMetadata]
max_num_splits=max_num_splits,
dcp_tot_seq_lens=dcp_tot_seq_lens_device,
)
return metadata
class FlashAttnMLAImpl(MLACommonImpl[FlashAttnMLAMetadata]):

View File

@ -14,6 +14,9 @@ from vllm.attention.ops.flashmla import (
)
from vllm.config import VllmConfig
from vllm.logger import init_logger
from vllm.model_executor.layers.batch_invariant import (
vllm_kernel_override_batch_invariant,
)
from vllm.v1.attention.backends.mla.common import (
MLACommonBackend,
MLACommonDecodeMetadata,
@ -223,19 +226,50 @@ class FlashMLAImpl(MLACommonImpl[FlashMLAMetadata]):
if type(q) is tuple:
q = torch.cat(q, dim=-1)
# mypy assertion: q is now always a tensor
assert isinstance(q, torch.Tensor)
num_decodes = attn_metadata.num_decodes
q = reshape_query_for_spec_decode(q, num_decodes)
tile_scheduler_metadata = attn_metadata.decode.tile_scheduler_metadata
num_splits = attn_metadata.decode.num_splits
if vllm_kernel_override_batch_invariant():
device = q.device
dtype = torch.int32
B = q.shape[0]
# block_table shape: [batch_size, max_num_blocks_per_seq]
# The number of blocks per sequence is in the second dimension
topk = attn_metadata.decode.block_table.shape[-1]
B_TOPK = 64
assert topk % B_TOPK == 0, f"topk ({topk}) must be divisible by {B_TOPK}"
end_block_idx = topk // B_TOPK
# Single partition => num_sm_parts = 1
# TileSchedulerMetaDataSize = 8, layout:
# [begin_idx, begin_block_idx, end_idx, end_block_idx,
# begin_n_split_idx, _, _, _]
tile_scheduler_metadata = torch.zeros((1, 8), dtype=dtype, device=device)
tile_scheduler_metadata[0, 0] = 0 # begin_idx
tile_scheduler_metadata[0, 1] = 0 # sched_begin_block_idx
tile_scheduler_metadata[0, 2] = B - 1 # end_idx
tile_scheduler_metadata[0, 3] = end_block_idx
tile_scheduler_metadata[0, 4] = 0 # begin_n_split_idx
# fields [5..7] stay 0
# Non-split path ignores num_splits, but the API requires it:
# zeros of length B+1
num_splits = torch.zeros((B + 1,), dtype=dtype, device=device)
o, lse = flash_mla_with_kvcache(
q=q,
k_cache=kv_c_and_k_pe_cache.unsqueeze(-2), # Add head dim of 1
block_table=attn_metadata.decode.block_table,
cache_seqlens=attn_metadata.decode.seq_lens,
head_dim_v=self.kv_lora_rank,
tile_scheduler_metadata=attn_metadata.decode.tile_scheduler_metadata,
num_splits=attn_metadata.decode.num_splits,
tile_scheduler_metadata=tile_scheduler_metadata,
num_splits=num_splits,
softmax_scale=self.scale,
causal=True,
descale_q=layer._q_scale.reshape(1),

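A hedged sketch of the single-partition metadata built above, wrapped as a helper for readability; the field layout is taken from the comment in the hunk and is otherwise an assumption about FlashMLA's scheduler format:

import torch

def single_partition_tile_metadata(
    batch_size: int, blocks_per_seq: int, device: torch.device
) -> tuple[torch.Tensor, torch.Tensor]:
    # One SM partition covers every request, so the split/merge order inside
    # flash_mla_with_kvcache cannot change with batch composition.
    # Assumed layout: [begin_idx, begin_block_idx, end_idx, end_block_idx,
    #                  begin_n_split_idx, _, _, _]
    assert blocks_per_seq % 64 == 0, "block count must be a multiple of 64"
    meta = torch.zeros((1, 8), dtype=torch.int32, device=device)
    meta[0, 2] = batch_size - 1        # end_idx: last request of the partition
    meta[0, 3] = blocks_per_seq // 64  # end_block_idx, in 64-block tiles
    num_splits = torch.zeros((batch_size + 1,), dtype=torch.int32, device=device)
    return meta, num_splits

meta, num_splits = single_partition_tile_metadata(4, 128, torch.device("cpu"))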
View File

@ -13,6 +13,9 @@ from vllm.attention.backends.abstract import (
from vllm.attention.ops.triton_decode_attention import decode_attention_fwd
from vllm.attention.ops.triton_flash_attention import triton_attention
from vllm.logger import init_logger
from vllm.model_executor.layers.batch_invariant import (
vllm_kernel_override_batch_invariant,
)
from vllm.platforms import current_platform
from vllm.triton_utils import HAS_TRITON
from vllm.v1.attention.backends.mla.common import (
@ -158,7 +161,9 @@ class TritonMLAImpl(MLACommonImpl[MLACommonMetadata]):
B, q_num_heads, self.kv_lora_rank, dtype=q.dtype, device=q.device
)
lse = torch.zeros(B, q_num_heads, dtype=q.dtype, device=q.device)
num_kv_splits = 4 # TODO: heuristic
# For batch invariance, use only 1 split to ensure deterministic reduction
num_kv_splits = 1 if vllm_kernel_override_batch_invariant() else 4
# TODO(lucas) Allocate ahead of time
attn_logits = torch.empty(

View File

@ -231,9 +231,6 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
from vllm.model_executor.models.utils import set_cpu_offload_max_bytes
set_cpu_offload_max_bytes(int(self.cache_config.cpu_offload_gb * 1024**3))
from vllm.model_executor.layers.batch_invariant import init_batch_invariance
init_batch_invariance()
model_config = self.model_config
cache_config = self.cache_config
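The init_batch_invariance() call appears to move from GPUModelRunner into worker startup (next file), so the switch is resolved once per process, before distributed init and before any layer picks its kernels. A hypothetical stand-in showing the intent (the env var name and behavior are assumptions; the real implementation lives in vllm.model_executor.layers.batch_invariant):

import os

_BATCH_INVARIANT = False

def init_batch_invariance_sketch() -> None:
    # Resolve the switch once, early in worker startup, so every layer built
    # afterwards sees the same answer from vllm_kernel_override_batch_invariant().
    global _BATCH_INVARIANT
    _BATCH_INVARIANT = (
        os.environ.get("VLLM_KERNEL_OVERRIDE_BATCH_INVARIANT", "0") == "1"
    )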

View File

@ -764,6 +764,9 @@ def init_worker_distributed_environment(
) -> None:
"""Initialize the distributed environment."""
parallel_config = vllm_config.parallel_config
from vllm.model_executor.layers.batch_invariant import init_batch_invariance
init_batch_invariance()
set_custom_all_reduce(not parallel_config.disable_custom_all_reduce)
init_distributed_environment(