Deepseek-v3 Batch Invariant on 8xH100 (#26609)
Signed-off-by: Bram Wasti <bwasti@meta.com>
Co-authored-by: Wentao Ye <44945378+yewentao256@users.noreply.github.com>
parent 785d8b6410
commit 7d8975de84
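This change threads a batch-invariant mode through vLLM (deterministic RMS norm, matmul/bmm, softmax, and all-reduce selection) and adds tests that require bitwise-identical logprobs between BS=1 and BS=N runs. As a rough sketch of how the mode is driven, using the same environment variable and model defaults that appear in the diff below (assumes a Hopper GPU and this branch installed; illustrative, not part of the commit):

    import os

    # Same switch the new tests set in their autouse fixture.
    os.environ["VLLM_KERNEL_OVERRIDE_BATCH_INVARIANT"] = "1"

    from vllm import LLM, SamplingParams

    llm = LLM(
        model="Qwen/Qwen3-1.7B",  # default test model used below
        enforce_eager=True,
        enable_prefix_caching=False,
        dtype="bfloat16",
    )
    sp = SamplingParams(temperature=0.0, max_tokens=8, logprobs=5)
    out = llm.generate(["the capital of france is"], sp)
    print(out[0].outputs[0].text)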
@@ -3,56 +3,73 @@
import contextlib
import os
import random
import string

import pytest
import torch

from vllm import LLM, SamplingParams
from vllm.platforms import current_platform


@pytest.fixture(autouse=True)
def enable_batch_invariant_mode():
    """Automatically enable batch invariant kernel overrides for all tests."""
    old_value = os.environ.get("VLLM_KERNEL_OVERRIDE_BATCH_INVARIANT")
    os.environ["VLLM_KERNEL_OVERRIDE_BATCH_INVARIANT"] = "1"
    yield
    # Restore original value after test
    if old_value is None:
        os.environ.pop("VLLM_KERNEL_OVERRIDE_BATCH_INVARIANT", None)
    else:
        os.environ["VLLM_KERNEL_OVERRIDE_BATCH_INVARIANT"] = old_value

def _random_prompt(min_words: int = 1024, max_words: int = 1024 * 2) -> str:
|
||||
# Lightweight random prompt generator to vary prompt lengths and content.
|
||||
vocab = [
|
||||
"alpha",
|
||||
"bravo",
|
||||
"charlie",
|
||||
"delta",
|
||||
"echo",
|
||||
"foxtrot",
|
||||
"golf",
|
||||
"hotel",
|
||||
"india",
|
||||
"juliet",
|
||||
"kilo",
|
||||
"lima",
|
||||
"mike",
|
||||
"november",
|
||||
"oscar",
|
||||
"papa",
|
||||
"quebec",
|
||||
"romeo",
|
||||
"sierra",
|
||||
"tango",
|
||||
"uniform",
|
||||
"victor",
|
||||
"whiskey",
|
||||
"xray",
|
||||
"yankee",
|
||||
"zulu",
|
||||
# Generate more realistic prompts that will actually produce varied tokens
|
||||
# Use a mix of common English text patterns
|
||||
|
||||
prompt_templates = [
|
||||
# Question-answer style
|
||||
"Question: What is the capital of France?\nAnswer: The capital of France is",
|
||||
"Q: How does photosynthesis work?\nA: Photosynthesis is the process by which",
|
||||
"User: Can you explain quantum mechanics?\nAssistant: Quantum mechanics is",
|
||||
# Story/narrative style
|
||||
"Once upon a time in a distant galaxy, there lived",
|
||||
"The old man walked slowly down the street, remembering",
|
||||
"In the year 2157, humanity finally discovered",
|
||||
# Technical/code style
|
||||
"To implement a binary search tree in Python, first we need to",
|
||||
"The algorithm works by iterating through the array and",
|
||||
"Here's how to optimize database queries using indexing:",
|
||||
# Factual/informative style
|
||||
"The Renaissance was a period in European history that",
|
||||
"Climate change is caused by several factors including",
|
||||
"The human brain contains approximately 86 billion neurons which",
|
||||
# Conversational style
|
||||
"I've been thinking about getting a new laptop because",
|
||||
"Yesterday I went to the store and bought",
|
||||
"My favorite thing about summer is definitely",
|
||||
]
|
||||
n = random.randint(min_words, max_words)
|
||||
words = random.choices(vocab, k=n)
|
||||
|
||||
# Add some noise and punctuation variability
|
||||
if random.random() < 0.5:
|
||||
words[0] = words[0].capitalize()
|
||||
if random.random() < 0.2:
|
||||
words.append("".join(random.choices(string.ascii_lowercase, k=5)))
|
||||
punct = random.choice([".", "?", "!", "...", ""])
|
||||
return " ".join(words) + punct
|
||||
# Pick a random template
|
||||
base_prompt = random.choice(prompt_templates)
|
||||
|
||||
# Add some padding to vary the length if needed
|
||||
if min_words > 50:
|
||||
# For longer prompts, repeat context
|
||||
padding_text = (
|
||||
" This is an interesting topic that deserves more explanation. "
|
||||
* (min_words // 50)
|
||||
)
|
||||
base_prompt = base_prompt + padding_text
|
||||
|
||||
return base_prompt
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
not current_platform.has_device_capability(90),
|
||||
reason="Batch invariance tests only supported on Hopper (SM90)",
|
||||
)
|
||||
@pytest.mark.timeout(1000)
|
||||
def test_v1_generation_is_deterministic_across_batch_sizes_with_needle():
|
||||
"""
|
||||
@@ -91,7 +108,6 @@ def test_v1_generation_is_deterministic_across_batch_sizes_with_needle():
|
||||
# Keep GPU memory usage low to avoid startup allocation failures.
|
||||
gpu_mem_util = float(os.getenv("VLLM_GPU_MEMORY_UTILIZATION", "0.4"))
|
||||
max_model_len = int(os.getenv("VLLM_MAX_MODEL_LEN", "5120"))
|
||||
swap_space_gb = int(os.getenv("VLLM_SWAP_SPACE_GB", "4"))
|
||||
|
||||
# Sampling parameters: longer outputs with a more random-sounding
|
||||
# continuation, but still deterministic due to fixed seed.
|
||||
@@ -117,7 +133,6 @@ def test_v1_generation_is_deterministic_across_batch_sizes_with_needle():
|
||||
max_num_seqs=max_batch_size,
|
||||
gpu_memory_utilization=gpu_mem_util,
|
||||
max_model_len=max_model_len,
|
||||
swap_space=swap_space_gb,
|
||||
)
|
||||
|
||||
# Baseline generation for the needle prompt alone.
|
||||
@@ -132,7 +147,6 @@ def test_v1_generation_is_deterministic_across_batch_sizes_with_needle():
|
||||
max_num_seqs=max_batch_size,
|
||||
gpu_memory_utilization=gpu_mem_util,
|
||||
max_model_len=max_model_len,
|
||||
swap_space=swap_space_gb,
|
||||
)
|
||||
|
||||
mismatches = 0
|
||||
@@ -195,16 +209,21 @@ def _extract_step_logprobs(request_output):
|
||||
],
|
||||
dtype=torch.float32,
|
||||
)
|
||||
return t
|
||||
return t, inner.token_ids
|
||||
|
||||
return None
|
||||
return None, None
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
not current_platform.has_device_capability(90),
|
||||
reason="Batch invariance tests only supported on Hopper (SM90)",
|
||||
)
|
||||
@pytest.mark.skipif(
|
||||
not torch.cuda.is_available(),
|
||||
reason="Requires CUDA to match production inference path.",
|
||||
)
|
||||
@pytest.mark.parametrize("backend", ["FLEX_ATTENTION", "FLASHINFER"])
|
||||
@pytest.mark.parametrize("backend", ["FLASH_ATTN", "FLASHINFER"])
|
||||
@pytest.mark.forked
|
||||
def test_logprobs_bitwise_batch_invariance_bs1_vs_bsN(backend):
|
||||
backend = os.getenv("VLLM_ATTENTION_BACKEND", backend)
|
||||
os.environ["VLLM_ATTENTION_BACKEND"] = backend
|
||||
@@ -214,69 +233,754 @@ def test_logprobs_bitwise_batch_invariance_bs1_vs_bsN(backend):
|
||||
model_name = os.getenv("VLLM_TEST_MODEL", "Qwen/Qwen3-1.7B")
|
||||
tp_size = int(os.getenv("VLLM_TEST_TP_SIZE", "1"))
|
||||
|
||||
# Force float32 to avoid precision-induced differences.
|
||||
# For batch invariance, disable custom all-reduce to ensure deterministic
|
||||
# all-reduce operations (custom all-reduce may not be deterministic)
|
||||
from vllm.model_executor.layers.batch_invariant import (
|
||||
vllm_kernel_override_batch_invariant,
|
||||
)
|
||||
|
||||
disable_custom_ar = vllm_kernel_override_batch_invariant()
|
||||
|
||||
if disable_custom_ar:
|
||||
print(f"\n{'=' * 80}")
|
||||
print(f"BATCH INVARIANCE MODE: Disabling custom all-reduce (TP={tp_size})")
|
||||
print(f"{'=' * 80}\n")
|
||||
|
||||
llm = LLM(
|
||||
model=model_name,
|
||||
tensor_parallel_size=tp_size,
|
||||
enforce_eager=True,
|
||||
enable_prefix_caching=False,
|
||||
max_num_seqs=32,
|
||||
max_model_len=8192,
|
||||
dtype="bfloat16", # not everything is supported
|
||||
)
|
||||
|
||||
prompts = [_random_prompt(10, 1024) for i in range(100)]
|
||||
# Use more realistic prompts for better token generation
|
||||
prompts = [_random_prompt(10, 50) for i in range(32)]
|
||||
|
||||
sp = SamplingParams(
|
||||
temperature=0.6,
|
||||
top_p=1.0,
|
||||
max_tokens=8,
|
||||
# Seed shouldn't matter at temperature=0, but keeping it stable anyway.
|
||||
seed=1234,
|
||||
logprobs=5,
|
||||
)
|
||||
|
||||
# BS=1: run prompts individually and collect logprobs per step.
|
||||
print("\n" + "=" * 80)
|
||||
print("STARTING BS=1 RUNS (each prompt individually)")
|
||||
print("=" * 80 + "\n")
|
||||
|
||||
bs1_logprobs_per_prompt = []
|
||||
for p in prompts:
|
||||
bs1_tokens_per_prompt = []
|
||||
for idx, p in enumerate(prompts):
|
||||
print(f"\n[BS=1] Running prompt {idx}/{len(prompts)} - Preview: {p[:80]}...")
|
||||
outs = llm.generate([p], sp, use_tqdm=False)
|
||||
assert len(outs) == 1
|
||||
step_logprobs = _extract_step_logprobs(outs[0])
|
||||
step_logprobs, token_ids = _extract_step_logprobs(outs[0])
|
||||
if step_logprobs is None:
|
||||
pytest.skip(
|
||||
"Logits are not available on RequestOutput; "
|
||||
"enable logprobs return to run this test."
|
||||
)
|
||||
bs1_logprobs_per_prompt.append(step_logprobs)
|
||||
bs1_tokens_per_prompt.append(token_ids)
|
||||
print(f"[BS=1] Prompt {idx} generated tokens: {token_ids}")
|
||||
|
||||
# BS=N: run prompts in a batch and collect logprobs per step for each
|
||||
# prompt.
|
||||
print("\n" + "=" * 80)
|
||||
print(f"STARTING BS={len(prompts)} RUN (all prompts batched)")
|
||||
print("=" * 80 + "\n")
|
||||
|
||||
outs_batched = llm.generate(prompts, sp, use_tqdm=False)
|
||||
assert len(outs_batched) == len(prompts)
|
||||
bsN_logprobs_per_prompt = []
|
||||
for o in outs_batched:
|
||||
step_logprobs = _extract_step_logprobs(o)
|
||||
bsN_tokens_per_prompt = []
|
||||
|
||||
print(f"\n[BS={len(prompts)}] Processing batched outputs...")
|
||||
for idx, o in enumerate(outs_batched):
|
||||
tokens = o.outputs[0].token_ids if o.outputs else "N/A"
|
||||
print(f"[BS={len(prompts)}] Prompt {idx} generated tokens: {tokens}")
|
||||
step_logprobs, token_ids = _extract_step_logprobs(o)
|
||||
if step_logprobs is None:
|
||||
pytest.skip(
|
||||
"Logits are not available on RequestOutput; "
|
||||
"enable logprobs return to run this test."
|
||||
)
|
||||
bsN_logprobs_per_prompt.append(step_logprobs)
|
||||
bsN_tokens_per_prompt.append(token_ids)
|
||||
|
||||
# Compare step-by-step logprobs for each prompt between BS=1 and BS=N runs.
|
||||
for i, (logprobs_bs1, logprobs_bsN) in enumerate(
|
||||
zip(bs1_logprobs_per_prompt, bsN_logprobs_per_prompt)
|
||||
):
|
||||
assert len(logprobs_bs1) == len(logprobs_bsN), (
|
||||
f"Different number of generation steps for prompt index {i}: "
|
||||
f"{len(logprobs_bs1)} (BS=1) vs {len(logprobs_bsN)} (BS=N)"
|
||||
failed_prompts = []
|
||||
for i, (logprobs_bs1, logprobs_bsN, tokens_bs1, tokens_bsN) in enumerate(
|
||||
zip(
|
||||
bs1_logprobs_per_prompt,
|
||||
bsN_logprobs_per_prompt,
|
||||
bs1_tokens_per_prompt,
|
||||
bsN_tokens_per_prompt,
|
||||
)
|
||||
):
|
||||
if len(logprobs_bs1) != len(logprobs_bsN):
|
||||
reason = (
|
||||
f"Different number of steps: {len(logprobs_bs1)} (BS=1) "
|
||||
f"vs {len(logprobs_bsN)} (BS=N)"
|
||||
)
|
||||
failed_prompts.append(
|
||||
{
|
||||
"prompt_idx": i,
|
||||
"step": "all",
|
||||
"reason": reason,
|
||||
"prompt_preview": prompts[i][:100],
|
||||
"bs1_tokens": tokens_bs1,
|
||||
"bsN_tokens": tokens_bsN,
|
||||
}
|
||||
)
|
||||
continue
|
||||
|
||||
# Check if tokens match first
|
||||
if tokens_bs1 != tokens_bsN:
|
||||
failed_prompts.append(
|
||||
{
|
||||
"prompt_idx": i,
|
||||
"step": "sampling",
|
||||
"reason": "Different tokens sampled",
|
||||
"prompt_preview": prompts[i][:100],
|
||||
"bs1_tokens": tokens_bs1,
|
||||
"bsN_tokens": tokens_bsN,
|
||||
"bs1_all_logprobs": [
|
||||
logprobs_bs1[s].tolist() for s in range(len(logprobs_bs1))
|
||||
],
|
||||
"bsN_all_logprobs": [
|
||||
logprobs_bsN[s].tolist() for s in range(len(logprobs_bsN))
|
||||
],
|
||||
}
|
||||
)
|
||||
continue
|
||||
|
||||
for t, (a, b) in enumerate(zip(logprobs_bs1, logprobs_bsN)):
|
||||
assert a.shape == b.shape, (
|
||||
f"Logits shape mismatch at prompt {i}, step {t}: {a.shape} vs {b.shape}"
|
||||
if a.shape != b.shape:
|
||||
failed_prompts.append(
|
||||
{
|
||||
"prompt_idx": i,
|
||||
"step": t,
|
||||
"reason": f"Shape mismatch: {a.shape} vs {b.shape}",
|
||||
"prompt_preview": prompts[i][:100],
|
||||
"bs1_tokens": tokens_bs1,
|
||||
"bsN_tokens": tokens_bsN,
|
||||
}
|
||||
)
|
||||
break
|
||||
|
||||
if not torch.equal(a, b):
|
||||
max_diff = torch.abs(a - b).max().item()
|
||||
# Print which token failed
|
||||
print(f"\n[DIVERGENCE] Prompt {i}, Token {t}: max_diff={max_diff:.6e}")
|
||||
bs1_tok = tokens_bs1[t] if t < len(tokens_bs1) else "N/A"
|
||||
bsN_tok = tokens_bsN[t] if t < len(tokens_bsN) else "N/A"
|
||||
print(f" Token IDs: bs1={bs1_tok}, bsN={bsN_tok}")
|
||||
print(f" BS=1 logprob: {a.tolist()}")
|
||||
print(f" BS=N logprob: {b.tolist()}")
|
||||
failed_prompts.append(
|
||||
{
|
||||
"prompt_idx": i,
|
||||
"step": t,
|
||||
"reason": f"Bitwise mismatch (max_diff={max_diff:.6e})",
|
||||
"prompt_preview": prompts[i][:100],
|
||||
"bs1_tokens": tokens_bs1,
|
||||
"bsN_tokens": tokens_bsN,
|
||||
"bs1_all_logprobs": [
|
||||
logprobs_bs1[s].tolist() for s in range(len(logprobs_bs1))
|
||||
],
|
||||
"bsN_all_logprobs": [
|
||||
logprobs_bsN[s].tolist() for s in range(len(logprobs_bsN))
|
||||
],
|
||||
}
|
||||
)
|
||||
break
|
||||
|
||||
# Print summary of all failures
|
||||
if failed_prompts:
|
||||
print(f"\n{'=' * 80}")
|
||||
fail_msg = (
|
||||
f"BATCH INVARIANCE FAILURES: {len(failed_prompts)}/"
|
||||
f"{len(prompts)} prompts failed"
|
||||
)
|
||||
print(fail_msg)
|
||||
print(f"{'=' * 80}")
|
||||
for fail in failed_prompts:
|
||||
print(f"\nPrompt {fail['prompt_idx']} (step {fail['step']}):")
|
||||
print(f" Reason: {fail['reason']}")
|
||||
print(f" Preview: {fail['prompt_preview']}...")
|
||||
|
||||
# Always show the tokens
|
||||
if "bs1_tokens" in fail:
|
||||
print(f" BS=1 tokens: {fail['bs1_tokens']}")
|
||||
if "bsN_tokens" in fail:
|
||||
print(f" BS=N tokens: {fail['bsN_tokens']}")
|
||||
|
||||
if "bs1_all_logprobs" in fail:
|
||||
print(f" BS=1 logprobs for all {len(fail['bs1_all_logprobs'])} steps:")
|
||||
for step_idx, logprobs in enumerate(fail["bs1_all_logprobs"]):
|
||||
print(f" Step {step_idx}: {logprobs}")
|
||||
print(f" BS=N logprobs for all {len(fail['bsN_all_logprobs'])} steps:")
|
||||
for step_idx, logprobs in enumerate(fail["bsN_all_logprobs"]):
|
||||
print(f" Step {step_idx}: {logprobs}")
|
||||
print(f"{'=' * 80}\n")
|
||||
|
||||
# Fail the test with summary
|
||||
msg = (
|
||||
f"Batch invariance violated in {len(failed_prompts)}/"
|
||||
f"{len(prompts)} prompts. See output above for details."
|
||||
)
|
||||
pytest.fail(msg)
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
not current_platform.has_device_capability(90),
|
||||
reason="Batch invariance tests only supported on Hopper (SM90)",
|
||||
)
|
||||
def test_simple_generation():
|
||||
"""
|
||||
Simple test that runs the model with a basic prompt and prints the output.
|
||||
Useful for quick smoke testing and debugging.
|
||||
"""
|
||||
model = os.getenv("VLLM_TEST_MODEL", "Qwen/Qwen3-1.7B")
|
||||
|
||||
llm = LLM(
|
||||
model=model,
|
||||
max_num_seqs=1,
|
||||
tensor_parallel_size=int(os.getenv("VLLM_TP_SIZE", "1")),
|
||||
enforce_eager=True,
|
||||
gpu_memory_utilization=0.9,
|
||||
max_model_len=2048,
|
||||
dtype="bfloat16",
|
||||
enable_prefix_caching=False,
|
||||
)
|
||||
|
||||
prompt = "the capital of france is"
|
||||
sampling_params = SamplingParams(
|
||||
temperature=0.0,
|
||||
max_tokens=20,
|
||||
)
|
||||
|
||||
print(f"\n{'=' * 80}")
|
||||
print("Running simple generation test")
|
||||
print(f"Prompt: '{prompt}'")
|
||||
print(f"{'=' * 80}\n")
|
||||
|
||||
try:
|
||||
outputs = llm.generate([prompt], sampling_params)
|
||||
|
||||
assert len(outputs) == 1
|
||||
output_text = outputs[0].outputs[0].text
|
||||
|
||||
print(f"Output: '{output_text}'")
|
||||
print(f"\n{'=' * 80}")
|
||||
print(f"Full completion: '{prompt}{output_text}'")
|
||||
print(f"{'=' * 80}\n")
|
||||
|
||||
finally:
|
||||
with contextlib.suppress(Exception):
|
||||
llm.shutdown()
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
not current_platform.has_device_capability(90),
|
||||
reason="Batch invariance tests only supported on Hopper (SM90)",
|
||||
)
|
||||
@pytest.mark.skipif(
|
||||
not torch.cuda.is_available(),
|
||||
reason="Requires CUDA to match production inference path.",
|
||||
)
|
||||
@pytest.mark.parametrize("backend", ["FLASH_ATTN", "FLASHINFER"])
|
||||
@pytest.mark.forked
|
||||
def test_logprobs_WITHOUT_batch_invariance_should_FAIL(backend):
|
||||
"""
|
||||
This test is the inverse of test_logprobs_bitwise_batch_invariance_bs1_vs_bsN.
|
||||
It DISABLES batch invariance mode and expects to see non-deterministic behavior
|
||||
between BS=1 and BS=N runs. This demonstrates that batch invariance is actually
|
||||
doing something useful.
|
||||
|
||||
The test will PASS if we detect differences (proving batch invariance matters).
|
||||
The test will FAIL if everything matches (suggesting batch invariance isn't needed).
|
||||
"""
|
||||
backend = os.getenv("VLLM_ATTENTION_BACKEND", backend)
|
||||
os.environ["VLLM_ATTENTION_BACKEND"] = backend
|
||||
|
||||
# CRITICAL: Disable batch invariance for this test
|
||||
old_value = os.environ.get("VLLM_KERNEL_OVERRIDE_BATCH_INVARIANT")
|
||||
os.environ["VLLM_KERNEL_OVERRIDE_BATCH_INVARIANT"] = "0"
|
||||
|
||||
try:
|
||||
seed = int(os.getenv("VLLM_TEST_SEED", "12345"))
|
||||
random.seed(seed)
|
||||
model_name = os.getenv("VLLM_TEST_MODEL", "Qwen/Qwen3-1.7B")
|
||||
tp_size = int(os.getenv("VLLM_TEST_TP_SIZE", "1"))
|
||||
|
||||
print(f"\n{'=' * 80}")
|
||||
print("BATCH INVARIANCE DISABLED: Expecting non-deterministic behavior")
|
||||
print(f"{'=' * 80}\n")
|
||||
|
||||
llm = LLM(
|
||||
model=model_name,
|
||||
tensor_parallel_size=tp_size,
|
||||
enable_prefix_caching=False,
|
||||
max_num_seqs=32,
|
||||
max_model_len=8192,
|
||||
dtype="bfloat16",
|
||||
)
|
||||
|
||||
# Use more realistic prompts for better token generation
|
||||
prompts = [_random_prompt(10, 50) for i in range(32)]
|
||||
|
||||
sp = SamplingParams(
|
||||
temperature=0.6,
|
||||
top_p=1.0,
|
||||
max_tokens=8,
|
||||
seed=1234,
|
||||
logprobs=5,
|
||||
)
|
||||
|
||||
# BS=1: run prompts individually and collect logprobs per step.
|
||||
print("\n" + "=" * 80)
|
||||
print("STARTING BS=1 RUNS (each prompt individually)")
|
||||
print("=" * 80 + "\n")
|
||||
|
||||
bs1_logprobs_per_prompt = []
|
||||
bs1_tokens_per_prompt = []
|
||||
for idx, p in enumerate(prompts):
|
||||
print(
|
||||
f"\n[BS=1] Running prompt {idx}/{len(prompts)} - Preview: {p[:80]}..."
|
||||
)
|
||||
# Bitwise exact equality.
|
||||
assert torch.equal(a, b), (
|
||||
f"Bitwise logprobs mismatch at prompt {i}, step {t} "
|
||||
f"(dtype={a.dtype}, shape={a.shape})."
|
||||
outs = llm.generate([p], sp, use_tqdm=False)
|
||||
assert len(outs) == 1
|
||||
step_logprobs, token_ids = _extract_step_logprobs(outs[0])
|
||||
if step_logprobs is None:
|
||||
pytest.skip(
|
||||
"Logits are not available on RequestOutput; "
|
||||
"enable logprobs return to run this test."
|
||||
)
|
||||
bs1_logprobs_per_prompt.append(step_logprobs)
|
||||
bs1_tokens_per_prompt.append(token_ids)
|
||||
print(f"[BS=1] Prompt {idx} generated tokens: {token_ids}")
|
||||
|
||||
# BS=N: run prompts in a batch and collect logprobs per step for each prompt.
|
||||
print("\n" + "=" * 80)
|
||||
print(f"STARTING BS={len(prompts)} RUN (all prompts batched)")
|
||||
print("=" * 80 + "\n")
|
||||
|
||||
outs_batched = llm.generate(prompts, sp, use_tqdm=False)
|
||||
assert len(outs_batched) == len(prompts)
|
||||
bsN_logprobs_per_prompt = []
|
||||
bsN_tokens_per_prompt = []
|
||||
|
||||
print(f"\n[BS={len(prompts)}] Processing batched outputs...")
|
||||
for idx, o in enumerate(outs_batched):
|
||||
tokens = o.outputs[0].token_ids if o.outputs else "N/A"
|
||||
print(f"[BS={len(prompts)}] Prompt {idx} generated tokens: {tokens}")
|
||||
step_logprobs, token_ids = _extract_step_logprobs(o)
|
||||
if step_logprobs is None:
|
||||
pytest.skip(
|
||||
"Logits are not available on RequestOutput; "
|
||||
"enable logprobs return to run this test."
|
||||
)
|
||||
bsN_logprobs_per_prompt.append(step_logprobs)
|
||||
bsN_tokens_per_prompt.append(token_ids)
|
||||
|
||||
# Compare step-by-step logprobs for each prompt between BS=1 and BS=N runs.
|
||||
differences_found = []
|
||||
for i, (logprobs_bs1, logprobs_bsN, tokens_bs1, tokens_bsN) in enumerate(
|
||||
zip(
|
||||
bs1_logprobs_per_prompt,
|
||||
bsN_logprobs_per_prompt,
|
||||
bs1_tokens_per_prompt,
|
||||
bsN_tokens_per_prompt,
|
||||
)
|
||||
):
|
||||
if len(logprobs_bs1) != len(logprobs_bsN):
|
||||
reason = (
|
||||
f"Different number of steps: {len(logprobs_bs1)} (BS=1) "
|
||||
f"vs {len(logprobs_bsN)} (BS=N)"
|
||||
)
|
||||
differences_found.append(
|
||||
{
|
||||
"prompt_idx": i,
|
||||
"step": "all",
|
||||
"reason": reason,
|
||||
"prompt_preview": prompts[i][:100],
|
||||
"bs1_tokens": tokens_bs1,
|
||||
"bsN_tokens": tokens_bsN,
|
||||
}
|
||||
)
|
||||
continue
|
||||
|
||||
# Check if tokens match first
|
||||
if tokens_bs1 != tokens_bsN:
|
||||
differences_found.append(
|
||||
{
|
||||
"prompt_idx": i,
|
||||
"step": "sampling",
|
||||
"reason": "Different tokens sampled",
|
||||
"prompt_preview": prompts[i][:100],
|
||||
"bs1_tokens": tokens_bs1,
|
||||
"bsN_tokens": tokens_bsN,
|
||||
}
|
||||
)
|
||||
continue
|
||||
|
||||
for t, (a, b) in enumerate(zip(logprobs_bs1, logprobs_bsN)):
|
||||
if a.shape != b.shape:
|
||||
differences_found.append(
|
||||
{
|
||||
"prompt_idx": i,
|
||||
"step": t,
|
||||
"reason": f"Shape mismatch: {a.shape} vs {b.shape}",
|
||||
"prompt_preview": prompts[i][:100],
|
||||
"bs1_tokens": tokens_bs1,
|
||||
"bsN_tokens": tokens_bsN,
|
||||
}
|
||||
)
|
||||
break
|
||||
|
||||
if not torch.equal(a, b):
|
||||
max_diff = torch.abs(a - b).max().item()
|
||||
print(
|
||||
f"\n[EXPECTED DIVERGENCE FOUND] Prompt {i}, "
|
||||
f"Token {t}: max_diff={max_diff:.6e}"
|
||||
)
|
||||
bs1_tok = tokens_bs1[t] if t < len(tokens_bs1) else "N/A"
|
||||
bsN_tok = tokens_bsN[t] if t < len(tokens_bsN) else "N/A"
|
||||
print(f" Token IDs: bs1={bs1_tok}, bsN={bsN_tok}")
|
||||
print(f" BS=1 logprob: {a.tolist()}")
|
||||
print(f" BS=N logprob: {b.tolist()}")
|
||||
differences_found.append(
|
||||
{
|
||||
"prompt_idx": i,
|
||||
"step": t,
|
||||
"reason": f"Bitwise mismatch (max_diff={max_diff:.6e})",
|
||||
"prompt_preview": prompts[i][:100],
|
||||
"bs1_tokens": tokens_bs1,
|
||||
"bsN_tokens": tokens_bsN,
|
||||
}
|
||||
)
|
||||
break
|
||||
|
||||
# Print summary
|
||||
print(f"\n{'=' * 80}")
|
||||
if differences_found:
|
||||
success_msg = (
|
||||
f"✓ SUCCESS: Batch invariance is doing something! "
|
||||
f"Found {len(differences_found)}/{len(prompts)} prompts "
|
||||
f"with differences when batch invariance was DISABLED."
|
||||
)
|
||||
print(success_msg)
|
||||
print(f"{'=' * 80}")
|
||||
for diff in differences_found:
|
||||
print(f"\nPrompt {diff['prompt_idx']} (step {diff['step']}):")
|
||||
print(f" Reason: {diff['reason']}")
|
||||
print(f" Preview: {diff['prompt_preview']}...")
|
||||
if "bs1_tokens" in diff:
|
||||
print(f" BS=1 tokens: {diff['bs1_tokens']}")
|
||||
if "bsN_tokens" in diff:
|
||||
print(f" BS=N tokens: {diff['bsN_tokens']}")
|
||||
print(f"{'=' * 80}\n")
|
||||
# Test PASSES because we found differences (batch invariance matters!)
|
||||
return
|
||||
else:
|
||||
# Test FAILS because everything matched even without batch invariance
|
||||
fail_msg = (
|
||||
f"✗ UNEXPECTED: All {len(prompts)} prompts matched "
|
||||
f"between BS=1 and BS=N even with batch invariance DISABLED. "
|
||||
f"This suggests batch invariance might not be necessary, "
|
||||
f"or the test needs more sensitive prompts."
|
||||
)
|
||||
print(fail_msg)
|
||||
print(f"{'=' * 80}\n")
|
||||
pytest.fail(fail_msg)
|
||||
|
||||
finally:
|
||||
# Restore original value
|
||||
if old_value is None:
|
||||
os.environ.pop("VLLM_KERNEL_OVERRIDE_BATCH_INVARIANT", None)
|
||||
else:
|
||||
os.environ["VLLM_KERNEL_OVERRIDE_BATCH_INVARIANT"] = old_value
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
not current_platform.has_device_capability(90),
|
||||
reason="Batch invariance tests only supported on Hopper (SM90)",
|
||||
)
|
||||
@pytest.mark.skipif(
|
||||
not torch.cuda.is_available(),
|
||||
reason="Requires CUDA to match production inference path.",
|
||||
)
|
||||
@pytest.mark.parametrize("backend", ["FLASH_ATTN"])
|
||||
@pytest.mark.forked
|
||||
def test_decode_logprobs_match_prefill_logprobs(backend):
|
||||
"""
|
||||
Test that verifies decode logprobs match prefill logprobs.
|
||||
|
||||
For each decoded token at position i:
|
||||
1. Run decode to generate N tokens and collect their logprobs
|
||||
2. For each position i in [0, N):
|
||||
- Take prefix = prompt + tokens[0:i]
|
||||
- Run prefill(prefix + tokens[i]) to get logprob of tokens[i]
|
||||
- Verify prefill logprob matches decode logprob bitwise
|
||||
|
||||
This ensures that the logprobs from decode are consistent with what
|
||||
we would get if we ran prefill on each prefix.
|
||||
"""
|
||||
backend = os.getenv("VLLM_ATTENTION_BACKEND", backend)
|
||||
os.environ["VLLM_ATTENTION_BACKEND"] = backend
|
||||
|
||||
seed = int(os.getenv("VLLM_TEST_SEED", "12345"))
|
||||
random.seed(seed)
|
||||
model_name = os.getenv("VLLM_TEST_MODEL", "Qwen/Qwen3-1.7B")
|
||||
tp_size = int(os.getenv("VLLM_TEST_TP_SIZE", "1"))
|
||||
|
||||
from vllm.model_executor.layers.batch_invariant import (
|
||||
vllm_kernel_override_batch_invariant,
|
||||
)
|
||||
|
||||
disable_custom_ar = vllm_kernel_override_batch_invariant()
|
||||
|
||||
if disable_custom_ar:
|
||||
print(f"\n{'=' * 80}")
|
||||
print(f"BATCH INVARIANCE MODE: Disabling custom all-reduce (TP={tp_size})")
|
||||
print(f"{'=' * 80}\n")
|
||||
|
||||
llm = LLM(
|
||||
model=model_name,
|
||||
tensor_parallel_size=tp_size,
|
||||
enable_prefix_caching=False,
|
||||
max_num_seqs=32,
|
||||
max_model_len=8192,
|
||||
dtype="bfloat16",
|
||||
)
|
||||
|
||||
# Use a few test prompts
|
||||
num_test_prompts = int(os.getenv("VLLM_DECODE_PREFILL_NUM_PROMPTS", "4"))
|
||||
prompts = [_random_prompt(10, 50) for _ in range(num_test_prompts)]
|
||||
|
||||
# Generate longer sequences to test multiple decode steps
|
||||
max_tokens = int(os.getenv("VLLM_DECODE_PREFILL_MAX_TOKENS", "16"))
|
||||
|
||||
sp = SamplingParams(
|
||||
temperature=0.0, # Greedy for determinism
|
||||
max_tokens=max_tokens,
|
||||
logprobs=5,
|
||||
)
|
||||
|
||||
print("\n" + "=" * 80)
|
||||
print("STEP 1: Running decode to generate tokens and collect logprobs")
|
||||
print("=" * 80 + "\n")
|
||||
|
||||
# Step 1: Run decode and collect logprobs
|
||||
decode_outputs = llm.generate(prompts, sp, use_tqdm=False)
|
||||
|
||||
failed_comparisons = []
|
||||
|
||||
for prompt_idx, (prompt, decode_output) in enumerate(zip(prompts, decode_outputs)):
|
||||
print(f"\n[Prompt {prompt_idx}] Testing: {prompt[:80]}...")
|
||||
|
||||
# Extract decode logprobs and tokens
|
||||
decode_logprobs, token_ids = _extract_step_logprobs(decode_output)
|
||||
if decode_logprobs is None:
|
||||
pytest.skip(
|
||||
"Logprobs are not available on RequestOutput; "
|
||||
"enable logprobs return to run this test."
|
||||
)
|
||||
|
||||
print(f"[Prompt {prompt_idx}] Generated {len(token_ids)} tokens: {token_ids}")
|
||||
print(f"[Prompt {prompt_idx}] Decode logprobs: {decode_logprobs.tolist()}")
|
||||
|
||||
# Step 2: For each token position, run prefill and compare
|
||||
print(f"\n[Prompt {prompt_idx}] Verifying each token via prefill...")
|
||||
|
||||
for token_idx in range(len(token_ids)):
|
||||
# Construct the prefix up to (but not including) this token
|
||||
current_token = token_ids[token_idx]
|
||||
|
||||
# We need to detokenize to get the text prefix
|
||||
# For this, we'll use the tokenizer from the LLM
|
||||
# However, the LLM API doesn't expose tokenizer easily, so we'll
|
||||
# construct the prefix by decoding from the original prompt
|
||||
|
||||
# Get text up to this point by using the output text
|
||||
# This is approximate but should work for verification
|
||||
if token_idx == 0:
|
||||
prefix_prompt = prompt
|
||||
else:
|
||||
# Use the partial output text up to this token
|
||||
# We'll need to construct this from the full output
|
||||
prefix_output = decode_output.outputs[0]
|
||||
# Get the text for tokens 0 to token_idx-1
|
||||
# Unfortunately, we don't have per-token text, so we'll use
|
||||
# a different approach: run prefill with prompt + tokens[0:token_idx]
|
||||
|
||||
# Actually, we need to get the actual text. Let's use a workaround:
|
||||
# Run a generation with max_tokens = token_idx to get that prefix
|
||||
prefix_sp = SamplingParams(
|
||||
temperature=0.0,
|
||||
max_tokens=token_idx,
|
||||
logprobs=1,
|
||||
)
|
||||
prefix_output = llm.generate([prompt], prefix_sp, use_tqdm=False)[0]
|
||||
prefix_prompt = prompt + prefix_output.outputs[0].text
|
||||
|
||||
# Now run prefill with max_tokens=1 to get the logprob of the next token
|
||||
prefill_sp = SamplingParams(
|
||||
temperature=0.0,
|
||||
max_tokens=1,
|
||||
logprobs=5,
|
||||
)
|
||||
|
||||
print(
|
||||
f" [Token {token_idx}] Running prefill for prefix "
|
||||
f"(len={len(prefix_prompt)})..."
|
||||
)
|
||||
prefill_output = llm.generate([prefix_prompt], prefill_sp, use_tqdm=False)[
|
||||
0
|
||||
]
|
||||
prefill_logprobs, prefill_token_ids = _extract_step_logprobs(prefill_output)
|
||||
|
||||
if prefill_logprobs is None:
|
||||
print(f" [Token {token_idx}] Warning: No prefill logprobs available")
|
||||
continue
|
||||
|
||||
# The first token from prefill should match the current token
|
||||
prefill_token = prefill_token_ids[0]
|
||||
prefill_logprob = prefill_logprobs[0].item()
|
||||
decode_logprob = decode_logprobs[token_idx].item()
|
||||
|
||||
print(
|
||||
f" [Token {token_idx}] Decode token: {current_token}, "
|
||||
f"logprob: {decode_logprob:.8f}"
|
||||
)
|
||||
print(
|
||||
f" [Token {token_idx}] Prefill token: {prefill_token}, "
|
||||
f"logprob: {prefill_logprob:.8f}"
|
||||
)
|
||||
|
||||
# Check if tokens match
|
||||
if current_token != prefill_token:
|
||||
failed_comparisons.append(
|
||||
{
|
||||
"prompt_idx": prompt_idx,
|
||||
"token_idx": token_idx,
|
||||
"reason": "Token mismatch",
|
||||
"decode_token": current_token,
|
||||
"prefill_token": prefill_token,
|
||||
"decode_logprob": decode_logprob,
|
||||
"prefill_logprob": prefill_logprob,
|
||||
"prompt_text": prompt[:100],
|
||||
"prefix_text": prefix_prompt[:100],
|
||||
}
|
||||
)
|
||||
print(f" [Token {token_idx}] ✗ TOKEN MISMATCH!")
|
||||
continue
|
||||
|
||||
# Check if logprobs match bitwise
|
||||
if decode_logprob != prefill_logprob:
|
||||
diff = abs(decode_logprob - prefill_logprob)
|
||||
failed_comparisons.append(
|
||||
{
|
||||
"prompt_idx": prompt_idx,
|
||||
"token_idx": token_idx,
|
||||
"reason": "Logprob mismatch",
|
||||
"decode_token": current_token,
|
||||
"prefill_token": prefill_token,
|
||||
"decode_logprob": decode_logprob,
|
||||
"prefill_logprob": prefill_logprob,
|
||||
"diff": diff,
|
||||
"prompt_text": prompt[:100],
|
||||
"prefix_text": prefix_prompt[:100],
|
||||
"decode_all_tokens": token_ids,
|
||||
"decode_all_logprobs": decode_logprobs.tolist(),
|
||||
}
|
||||
)
|
||||
print(f" [Token {token_idx}] ✗ LOGPROB MISMATCH! diff={diff:.8e}")
|
||||
else:
|
||||
print(f" [Token {token_idx}] ✓ Match (bitwise equal)")
|
||||
|
||||
# Print summary
|
||||
print(f"\n{'=' * 80}")
|
||||
if failed_comparisons:
|
||||
print(f"DECODE-PREFILL MISMATCH: {len(failed_comparisons)} failures detected")
|
||||
print(f"{'=' * 80}")
|
||||
|
||||
# Group failures by prompt for better readability
|
||||
failures_by_prompt: dict[int, list[dict]] = {}
|
||||
for fail in failed_comparisons:
|
||||
pid = fail["prompt_idx"]
|
||||
if pid not in failures_by_prompt:
|
||||
failures_by_prompt[pid] = []
|
||||
failures_by_prompt[pid].append(fail)
|
||||
|
||||
for prompt_idx, failures in failures_by_prompt.items():
|
||||
print(f"\n{'=' * 80}")
|
||||
print(f"PROMPT {prompt_idx}: {failures[0]['prompt_text']}...")
|
||||
print(f"{'=' * 80}")
|
||||
print(f"Total failures for this prompt: {len(failures)}")
|
||||
|
||||
# Show where mismatches occur (which token positions)
|
||||
mismatch_positions = [f["token_idx"] for f in failures]
|
||||
print(f"Mismatch at token positions: {mismatch_positions}")
|
||||
|
||||
# Show first few failures in detail
|
||||
for i, fail in enumerate(failures[:5]): # Show first 5 failures per prompt
|
||||
print(f"\n [Failure {i + 1}] Token position {fail['token_idx']}:")
|
||||
print(f" Reason: {fail['reason']}")
|
||||
print(f" Prefix text: '{fail['prefix_text']}...'")
|
||||
print(
|
||||
f" Decode: token={fail['decode_token']}, "
|
||||
f"logprob={fail['decode_logprob']:.10f}"
|
||||
)
|
||||
print(
|
||||
f" Prefill: token={fail['prefill_token']}, "
|
||||
f"logprob={fail['prefill_logprob']:.10f}"
|
||||
)
|
||||
if "diff" in fail:
|
||||
print(f" Difference: {fail['diff']:.10e}")
|
||||
# Show in hex to see bitwise difference
|
||||
import struct
|
||||
|
||||
decode_hex = struct.pack("f", fail["decode_logprob"]).hex()
|
||||
prefill_hex = struct.pack("f", fail["prefill_logprob"]).hex()
|
||||
print(f" Decode logprob (hex): 0x{decode_hex}")
|
||||
print(f" Prefill logprob (hex): 0x{prefill_hex}")
|
||||
|
||||
# If we have all tokens/logprobs, show the context
|
||||
if "decode_all_tokens" in fail and "decode_all_logprobs" in fail:
|
||||
token_idx = fail["token_idx"]
|
||||
all_tokens = fail["decode_all_tokens"]
|
||||
all_logprobs = fail["decode_all_logprobs"]
|
||||
|
||||
# Show context: 2 tokens before and after
|
||||
start = max(0, token_idx - 2)
|
||||
end = min(len(all_tokens), token_idx + 3)
|
||||
|
||||
print(f" Context (tokens {start} to {end - 1}):")
|
||||
for j in range(start, end):
|
||||
marker = " <-- MISMATCH" if j == token_idx else ""
|
||||
print(
|
||||
f" [{j}] token={all_tokens[j]}, "
|
||||
f"logprob={all_logprobs[j]:.8f}{marker}"
|
||||
)
|
||||
|
||||
if len(failures) > 5:
|
||||
print(f"\n ... and {len(failures) - 5} more failures for this prompt")
|
||||
|
||||
print(f"\n{'=' * 80}\n")
|
||||
|
||||
pytest.fail(
|
||||
f"Decode logprobs do not match prefill logprobs: "
|
||||
f"{len(failed_comparisons)} mismatches found."
|
||||
)
|
||||
else:
|
||||
print("✓ SUCCESS: All decode logprobs match prefill logprobs bitwise!")
|
||||
print(f"{'=' * 80}\n")
|
||||
|
||||
|
||||
def LLM_with_max_seqs(
|
||||
@@ -284,7 +988,6 @@ def LLM_with_max_seqs(
|
||||
max_num_seqs: int,
|
||||
gpu_memory_utilization: float,
|
||||
max_model_len: int,
|
||||
swap_space: int,
|
||||
) -> LLM:
|
||||
"""
|
||||
Helper to construct an LLM with a specific max_num_seqs (batch-size limit)
|
||||
@@ -293,17 +996,10 @@ def LLM_with_max_seqs(
|
||||
return LLM(
|
||||
model=model,
|
||||
max_num_seqs=max_num_seqs,
|
||||
# Constrain GPU memory pool so test can run even on busy GPUs.
|
||||
gpu_memory_utilization=gpu_memory_utilization,
|
||||
# Keep KV cache footprint small while allowing longer outputs.
|
||||
max_model_len=max_model_len,
|
||||
# Allow some CPU offload if needed.
|
||||
swap_space=swap_space,
|
||||
# Keep things lean and CI-friendly.
|
||||
dtype="auto",
|
||||
# Single-GPU by default; override externally if desired.
|
||||
dtype="bfloat16",
|
||||
tensor_parallel_size=int(os.getenv("VLLM_TP_SIZE", "1")),
|
||||
trust_remote_code=os.getenv("VLLM_TRUST_REMOTE_CODE", "0") == "1",
|
||||
enable_prefix_caching=False,
|
||||
enforce_eager=True,
|
||||
# Enable for MOE models
|
||||
|
||||
tests/v1/generation/test_rms_norm_batch_invariant.py (new file, 346 lines added)
@@ -0,0 +1,346 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Test batch-invariant RMS normalization against standard implementations.

This test compares the Triton-based batch-invariant RMS norm implementation
with the standard CUDA-based implementation to ensure numerical accuracy.
"""

import pytest
import torch

from vllm.model_executor.layers.batch_invariant import rms_norm as triton_rms_norm
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.platforms import current_platform

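# The tests in this file follow one pattern: run the standard CUDA RMSNorm and
# the Triton batch-invariant kernel on the same input, then compare within a
# dtype-dependent tolerance. Condensed, the pattern looks like this (sketch
# only, mirroring the calls used below; assumes a Hopper GPU):
#
#     x = torch.randn(4, 4096, dtype=torch.bfloat16, device="cuda")
#     w = torch.randn(4096, dtype=torch.bfloat16, device="cuda")
#     layer = RMSNorm(4096, eps=1e-6, dtype=torch.bfloat16).to("cuda")
#     layer.weight.data = w.clone()
#     ref = layer.forward_cuda(x)            # standard CUDA path
#     out = triton_rms_norm(x, w, eps=1e-6)  # batch-invariant Triton path
#     torch.testing.assert_close(out, ref, rtol=1e-1, atol=1e-1)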
@pytest.mark.skipif(
|
||||
not current_platform.has_device_capability(90),
|
||||
reason="Batch invariance tests only supported on Hopper (SM90)",
|
||||
)
|
||||
@pytest.mark.skipif(
|
||||
not torch.cuda.is_available(), reason="Requires CUDA for RMS norm kernels"
|
||||
)
|
||||
@pytest.mark.parametrize("batch_size", [1, 4, 16, 64])
|
||||
@pytest.mark.parametrize("hidden_size", [512, 2048, 4096, 8192])
|
||||
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
|
||||
@pytest.mark.parametrize("eps", [1e-6, 1e-5])
|
||||
def test_rms_norm_batch_invariant_vs_standard(
|
||||
batch_size: int, hidden_size: int, dtype: torch.dtype, eps: float
|
||||
):
|
||||
"""
|
||||
Compare batch-invariant Triton RMS norm against standard CUDA implementation.
|
||||
|
||||
Tests that the Triton-based batch-invariant RMS norm produces numerically
|
||||
equivalent results to the standard CUDA implementation across various
|
||||
configurations.
|
||||
"""
|
||||
device = torch.device("cuda")
|
||||
|
||||
# Create test input and weight
|
||||
torch.manual_seed(42)
|
||||
input_tensor = torch.randn(batch_size, hidden_size, dtype=dtype, device=device)
|
||||
weight = torch.randn(hidden_size, dtype=dtype, device=device)
|
||||
|
||||
# Standard implementation (CUDA ops)
|
||||
rms_norm_layer = RMSNorm(hidden_size, eps=eps, dtype=dtype).to(device)
|
||||
rms_norm_layer.weight.data = weight.clone()
|
||||
|
||||
standard_output = rms_norm_layer.forward_cuda(input_tensor)
|
||||
|
||||
# Batch-invariant implementation (Triton)
|
||||
triton_output = triton_rms_norm(input_tensor, weight, eps=eps)
|
||||
|
||||
# Compare outputs
|
||||
# Use looser tolerance for bfloat16 due to its lower precision
|
||||
if dtype == torch.bfloat16:
|
||||
rtol, atol = 1e-1, 1e-1 # 10% relative tolerance for bfloat16
|
||||
else:
|
||||
rtol, atol = 1e-2, 1e-2 # 1% for float16/float32
|
||||
|
||||
torch.testing.assert_close(
|
||||
triton_output,
|
||||
standard_output,
|
||||
rtol=rtol,
|
||||
atol=atol,
|
||||
msg=f"RMS norm mismatch for batch_size={batch_size}, "
|
||||
f"hidden_size={hidden_size}, "
|
||||
f"dtype={dtype}, eps={eps}",
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
not current_platform.has_device_capability(90),
|
||||
reason="Batch invariance tests only supported on Hopper (SM90)",
|
||||
)
|
||||
@pytest.mark.skipif(
|
||||
not torch.cuda.is_available(), reason="Requires CUDA for RMS norm kernels"
|
||||
)
|
||||
@pytest.mark.parametrize("batch_size", [1, 16, 128])
|
||||
@pytest.mark.parametrize("seq_len", [1, 32, 512])
|
||||
@pytest.mark.parametrize("hidden_size", [2048, 4096])
|
||||
def test_rms_norm_3d_input(batch_size: int, seq_len: int, hidden_size: int):
|
||||
"""
|
||||
Test RMS norm with 3D input tensors (batch, seq_len, hidden_size).
|
||||
|
||||
Ensures that the batch-invariant RMS norm correctly handles multi-dimensional
|
||||
inputs that are common in transformer models.
|
||||
"""
|
||||
device = torch.device("cuda")
|
||||
dtype = torch.bfloat16
|
||||
eps = 1e-6
|
||||
|
||||
torch.manual_seed(42)
|
||||
input_tensor = torch.randn(
|
||||
batch_size, seq_len, hidden_size, dtype=dtype, device=device
|
||||
)
|
||||
weight = torch.randn(hidden_size, dtype=dtype, device=device)
|
||||
|
||||
# Standard implementation
|
||||
rms_norm_layer = RMSNorm(hidden_size, eps=eps, dtype=dtype).to(device)
|
||||
rms_norm_layer.weight.data = weight.clone()
|
||||
standard_output = rms_norm_layer.forward_cuda(input_tensor)
|
||||
|
||||
# Batch-invariant implementation
|
||||
triton_output = triton_rms_norm(input_tensor, weight, eps=eps)
|
||||
|
||||
# Use looser tolerance for bfloat16
|
||||
rtol, atol = 1e-1, 1e-1 # 10% tolerance for bfloat16
|
||||
|
||||
torch.testing.assert_close(
|
||||
triton_output,
|
||||
standard_output,
|
||||
rtol=rtol,
|
||||
atol=atol,
|
||||
msg=f"RMS norm mismatch for 3D input with batch_size={batch_size}, "
|
||||
f"seq_len={seq_len}, hidden_size={hidden_size}",
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
not current_platform.has_device_capability(90),
|
||||
reason="Batch invariance tests only supported on Hopper (SM90)",
|
||||
)
|
||||
@pytest.mark.skipif(
|
||||
not torch.cuda.is_available(), reason="Requires CUDA for RMS norm kernels"
|
||||
)
|
||||
def test_rms_norm_numerical_stability():
|
||||
"""
|
||||
Test RMS norm numerical stability with extreme values.
|
||||
|
||||
Ensures that both implementations handle edge cases like very small or large
|
||||
values without producing NaN or Inf.
|
||||
"""
|
||||
device = torch.device("cuda")
|
||||
dtype = torch.float16
|
||||
eps = 1e-6
|
||||
hidden_size = 2048
|
||||
|
||||
# Test cases with extreme values
|
||||
test_cases = [
|
||||
# Very small values
|
||||
torch.ones(4, hidden_size, dtype=dtype, device=device) * 1e-5,
|
||||
# Very large values
|
||||
torch.ones(4, hidden_size, dtype=dtype, device=device) * 1e4,
|
||||
# Mixed small and large
|
||||
torch.randn(4, hidden_size, dtype=dtype, device=device) * 100,
|
||||
# Values near zero
|
||||
torch.randn(4, hidden_size, dtype=dtype, device=device) * 1e-6,
|
||||
]
|
||||
|
||||
weight = torch.ones(hidden_size, dtype=dtype, device=device)
|
||||
|
||||
for idx, input_tensor in enumerate(test_cases):
|
||||
# Standard implementation
|
||||
rms_norm_layer = RMSNorm(hidden_size, eps=eps, dtype=dtype).to(device)
|
||||
rms_norm_layer.weight.data = weight.clone()
|
||||
standard_output = rms_norm_layer.forward_cuda(input_tensor)
|
||||
|
||||
# Batch-invariant implementation
|
||||
triton_output = triton_rms_norm(input_tensor, weight, eps=eps)
|
||||
|
||||
# Check for NaN or Inf
|
||||
assert not torch.isnan(standard_output).any(), (
|
||||
f"Standard RMS norm produced NaN for test case {idx}"
|
||||
)
|
||||
assert not torch.isinf(standard_output).any(), (
|
||||
f"Standard RMS norm produced Inf for test case {idx}"
|
||||
)
|
||||
assert not torch.isnan(triton_output).any(), (
|
||||
f"Triton RMS norm produced NaN for test case {idx}"
|
||||
)
|
||||
assert not torch.isinf(triton_output).any(), (
|
||||
f"Triton RMS norm produced Inf for test case {idx}"
|
||||
)
|
||||
|
||||
# Compare outputs - very lenient for extreme values with float16
|
||||
torch.testing.assert_close(
|
||||
triton_output,
|
||||
standard_output,
|
||||
rtol=2e-1, # 20% tolerance for extreme values
|
||||
atol=2e-1,
|
||||
msg=f"RMS norm mismatch for extreme value test case {idx}",
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
not current_platform.has_device_capability(90),
|
||||
reason="Batch invariance tests only supported on Hopper (SM90)",
|
||||
)
|
||||
@pytest.mark.skipif(
|
||||
not torch.cuda.is_available(), reason="Requires CUDA for RMS norm kernels"
|
||||
)
|
||||
def test_rms_norm_formula():
|
||||
"""
|
||||
Test that RMS norm follows the correct mathematical formula.
|
||||
|
||||
Verifies: output = input / sqrt(mean(input^2) + eps) * weight
|
||||
"""
|
||||
device = torch.device("cuda")
|
||||
dtype = torch.float32 # Use float32 for higher precision in formula check
|
||||
eps = 1e-6
|
||||
hidden_size = 1024
|
||||
|
||||
torch.manual_seed(42)
|
||||
input_tensor = torch.randn(8, hidden_size, dtype=dtype, device=device)
|
||||
weight = torch.randn(hidden_size, dtype=dtype, device=device)
|
||||
|
||||
# Compute expected output using the formula
|
||||
variance = (input_tensor.pow(2).mean(dim=-1, keepdim=True)).to(dtype)
|
||||
expected_output = input_tensor * torch.rsqrt(variance + eps) * weight
|
||||
|
||||
# Batch-invariant implementation
|
||||
triton_output = triton_rms_norm(input_tensor, weight, eps=eps)
|
||||
|
||||
# Compare against formula
|
||||
torch.testing.assert_close(
|
||||
triton_output,
|
||||
expected_output,
|
||||
rtol=1e-4,
|
||||
atol=1e-4,
|
||||
msg="Triton RMS norm doesn't match expected formula",
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
not current_platform.has_device_capability(90),
|
||||
reason="Batch invariance tests only supported on Hopper (SM90)",
|
||||
)
|
||||
@pytest.mark.skipif(
|
||||
not torch.cuda.is_available(), reason="Requires CUDA for RMS norm kernels"
|
||||
)
|
||||
@pytest.mark.parametrize("hidden_size", [128, 1024, 4096, 16384])
|
||||
def test_rms_norm_different_hidden_sizes(hidden_size: int):
|
||||
"""
|
||||
Test RMS norm with various hidden sizes to ensure block size handling.
|
||||
|
||||
The Triton kernel uses a fixed BLOCK_SIZE=1024, so this tests that it
|
||||
correctly handles hidden sizes both smaller and larger than the block size.
|
||||
"""
|
||||
device = torch.device("cuda")
|
||||
dtype = torch.bfloat16
|
||||
eps = 1e-6
|
||||
batch_size = 16
|
||||
|
||||
torch.manual_seed(42)
|
||||
input_tensor = torch.randn(batch_size, hidden_size, dtype=dtype, device=device)
|
||||
weight = torch.randn(hidden_size, dtype=dtype, device=device)
|
||||
|
||||
# Standard implementation
|
||||
rms_norm_layer = RMSNorm(hidden_size, eps=eps, dtype=dtype).to(device)
|
||||
rms_norm_layer.weight.data = weight.clone()
|
||||
standard_output = rms_norm_layer.forward_cuda(input_tensor)
|
||||
|
||||
# Batch-invariant implementation
|
||||
triton_output = triton_rms_norm(input_tensor, weight, eps=eps)
|
||||
|
||||
# Use looser tolerance for bfloat16
|
||||
rtol, atol = 1e-1, 1e-1 # 10% tolerance for bfloat16
|
||||
|
||||
torch.testing.assert_close(
|
||||
triton_output,
|
||||
standard_output,
|
||||
rtol=rtol,
|
||||
atol=atol,
|
||||
msg=f"RMS norm mismatch for hidden_size={hidden_size}",
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
not current_platform.has_device_capability(90),
|
||||
reason="Batch invariance tests only supported on Hopper (SM90)",
|
||||
)
|
||||
@pytest.mark.skipif(
|
||||
not torch.cuda.is_available(), reason="Requires CUDA for RMS norm kernels"
|
||||
)
|
||||
def test_rms_norm_determinism():
|
||||
"""
|
||||
Test that batch-invariant RMS norm produces deterministic results.
|
||||
|
||||
Runs the same input through the kernel multiple times and verifies
|
||||
identical outputs.
|
||||
"""
|
||||
device = torch.device("cuda")
|
||||
dtype = torch.bfloat16
|
||||
eps = 1e-6
|
||||
hidden_size = 4096
|
||||
batch_size = 32
|
||||
|
||||
torch.manual_seed(42)
|
||||
input_tensor = torch.randn(batch_size, hidden_size, dtype=dtype, device=device)
|
||||
weight = torch.randn(hidden_size, dtype=dtype, device=device)
|
||||
|
||||
# Run multiple times
|
||||
outputs = []
|
||||
for _ in range(5):
|
||||
output = triton_rms_norm(input_tensor.clone(), weight, eps=eps)
|
||||
outputs.append(output)
|
||||
|
||||
# All outputs should be identical
|
||||
reference = outputs[0]
|
||||
for idx, output in enumerate(outputs[1:], start=1):
|
||||
torch.testing.assert_close(
|
||||
output,
|
||||
reference,
|
||||
rtol=0.0,
|
||||
atol=0.0,
|
||||
msg=f"RMS norm not deterministic: run {idx} differs from reference",
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# Run a quick smoke test
|
||||
print("Running quick smoke test of RMS norm implementations...")
|
||||
|
||||
device = torch.device("cuda")
|
||||
batch_size = 8
|
||||
hidden_size = 4096
|
||||
dtype = torch.bfloat16
|
||||
eps = 1e-6
|
||||
|
||||
torch.manual_seed(42)
|
||||
input_tensor = torch.randn(batch_size, hidden_size, dtype=dtype, device=device)
|
||||
weight = torch.randn(hidden_size, dtype=dtype, device=device)
|
||||
|
||||
# Standard implementation
|
||||
rms_norm_layer = RMSNorm(hidden_size, eps=eps, dtype=dtype).to(device)
|
||||
rms_norm_layer.weight.data = weight.clone()
|
||||
standard_output = rms_norm_layer.forward_cuda(input_tensor)
|
||||
|
||||
# Batch-invariant implementation
|
||||
triton_output = triton_rms_norm(input_tensor, weight, eps=eps)
|
||||
|
||||
# Compare
|
||||
max_diff = (triton_output - standard_output).abs().max().item()
|
||||
mean_diff = (triton_output - standard_output).abs().mean().item()
|
||||
|
||||
print(f"Max difference: {max_diff:.6e}")
|
||||
print(f"Mean difference: {mean_diff:.6e}")
|
||||
print(f"Standard output sample: {standard_output[0, :5].tolist()}")
|
||||
print(f"Triton output sample: {triton_output[0, :5].tolist()}")
|
||||
|
||||
if max_diff < 1e-3:
|
||||
print("✓ Smoke test passed!")
|
||||
else:
|
||||
print("✗ Smoke test failed - differences too large")
|
||||
@@ -3,6 +3,7 @@
|
||||
|
||||
import hashlib
|
||||
import inspect
|
||||
import os
|
||||
import pickle
|
||||
from unittest.mock import patch
|
||||
|
||||
@@ -168,7 +169,8 @@ def _compute_code_hash(files: set[str]) -> str:
|
||||
)
|
||||
file_contents = {}
|
||||
for filepath in files:
|
||||
if filepath == "<string>":
|
||||
# Skip files that don't exist (e.g., <string>, <frozen modules>, etc.)
|
||||
if not os.path.isfile(filepath):
|
||||
file_contents[filepath] = ""
|
||||
else:
|
||||
with open(filepath) as f:
|
||||
|
||||
@@ -20,6 +20,9 @@ from vllm.config.pooler import PoolerConfig
|
||||
from vllm.config.scheduler import RunnerType
|
||||
from vllm.config.utils import assert_hashable, config, getattr_iter
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.batch_invariant import (
|
||||
vllm_kernel_override_batch_invariant,
|
||||
)
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.transformers_utils.config import (
|
||||
ConfigFormat,
|
||||
@@ -419,6 +422,10 @@ class ModelConfig:
|
||||
skip_mm_profiling: bool | None,
|
||||
video_pruning_rate: float | None,
|
||||
) -> None:
|
||||
# Enable batch invariance settings if requested
|
||||
if vllm_kernel_override_batch_invariant():
|
||||
self.enforce_eager = True
|
||||
|
||||
# Set the default seed to 0 in V1.
|
||||
# NOTE(woosuk): In V0, we set the default seed to None because the
|
||||
# driver worker shares the same process as the user process, and thus
|
||||
|
||||
@@ -14,6 +14,9 @@ from typing_extensions import Self
|
||||
import vllm.envs as envs
|
||||
from vllm.config.utils import config
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.batch_invariant import (
|
||||
vllm_kernel_override_batch_invariant,
|
||||
)
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.utils import cuda_device_count_stateless, get_open_ports_list
|
||||
|
||||
@@ -560,7 +563,10 @@ class ParallelConfig:
|
||||
def _verify_args(self) -> Self:
|
||||
# Lazy import to avoid circular import
|
||||
from vllm.executor.executor_base import ExecutorBase
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
# Enable batch invariance settings if requested
|
||||
if vllm_kernel_override_batch_invariant():
|
||||
self.disable_custom_all_reduce = True
|
||||
|
||||
if (
|
||||
self.distributed_executor_backend is not None
|
||||
|
||||
@@ -19,6 +19,9 @@ import torch.multiprocessing as mp
|
||||
import vllm.envs as envs
|
||||
from vllm.distributed.device_communicators.cuda_wrapper import CudaRTLibrary
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.batch_invariant import (
|
||||
vllm_kernel_override_batch_invariant,
|
||||
)
|
||||
from vllm.utils import cuda_device_count_stateless, update_environment_variables
|
||||
|
||||
logger = init_logger(__name__)
|
||||
@@ -71,6 +74,9 @@ def should_nccl_symm_mem_allreduce(world_size: int, input_tensor: torch.Tensor)
|
||||
is_symmetric_memory_enabled,
|
||||
)
|
||||
|
||||
if vllm_kernel_override_batch_invariant():
|
||||
return False
|
||||
|
||||
if not is_symmetric_memory_enabled():
|
||||
return False
|
||||
if world_size < NCCL_SYMM_MEM_ALL_REDUCE_CONFIG["min_world_size"]:
|
||||
|
||||
@@ -9,6 +9,9 @@ from vllm.distributed.device_communicators.all_reduce_utils import (
|
||||
SYMM_MEM_ALL_REDUCE_MAX_SIZES,
|
||||
)
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.batch_invariant import (
|
||||
vllm_kernel_override_batch_invariant,
|
||||
)
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
try:
|
||||
@@ -100,6 +103,8 @@ class SymmMemCommunicator:
|
||||
return
|
||||
self.force_multimem = force_multimem
|
||||
self.disabled = False
|
||||
if vllm_kernel_override_batch_invariant():
|
||||
self.disabled = True
|
||||
|
||||
def should_use_symm_mem(self, inp: torch.Tensor):
|
||||
if self.disabled:
|
||||
|
||||
@@ -1694,7 +1694,7 @@ class EngineArgs:
|
||||
) -> None:
|
||||
"""Set Default Arguments for V1 Engine."""
|
||||
|
||||
# V1 always uses chunked prefills and prefix caching
|
||||
# V1 uses chunked prefills and prefix caching by default
|
||||
# for non-pooling tasks.
|
||||
# For pooling tasks the default is False
|
||||
if model_config.runner_type != "pooling":
|
||||
|
||||
@@ -395,7 +395,6 @@ def mean_dim(
|
||||
Tensor with mean values along specified dimension
|
||||
"""
|
||||
# Validate inputs
|
||||
assert input.is_cuda, "Input must be a CUDA tensor"
|
||||
assert -input.ndim <= dim < input.ndim, (
|
||||
f"Invalid dimension {dim} for tensor with {input.ndim} dimensions"
|
||||
)
|
||||
@@ -470,6 +469,45 @@ def mm_batch_invariant(a, b):
|
||||
    return matmul_persistent(a, b)

def matmul_batch_invariant(a, b, *, out=None):
    # torch.matmul can handle various dimensions
    # For 2D x 2D, it's the same as mm
    if a.ndim == 2 and b.ndim == 2:
        result = matmul_persistent(a, b)
        if out is not None:
            out.copy_(result)
            return out
        return result
    elif a.ndim == 3 and b.ndim == 3:
        # Handle batched case like bmm
        return bmm_batch_invariant(a, b, out=out)
    else:
        raise ValueError(
            f"matmul_batch_invariant currently only supports 2D x 2D and 3D x 3D, "
            f"got shapes {a.shape} and {b.shape}"
        )


def bmm_batch_invariant(a, b, *, out=None):
    # Batched matrix multiply: (B, M, K) x (B, K, N) -> (B, M, N)
    # Process each batch separately with our persistent kernel
    if a.ndim == 3 and b.ndim == 3:
        results = []
        for i in range(a.shape[0]):
            results.append(matmul_persistent(a[i], b[i]))
        result = torch.stack(results, dim=0)

        if out is not None:
            out.copy_(result)
            return out
        return result
    else:
        raise ValueError(
            f"bmm_batch_invariant expects 3D tensors, "
            f"got shapes {a.shape} and {b.shape}"
        )


def addmm_batch_invariant(bias, a, b):
    return matmul_persistent(a, b, bias=bias)

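# For intuition on why these wrappers exist (illustrative sketch, not part of
# the change): the batch size can change which tiling and reduction order
# cuBLAS picks, so the same row is not guaranteed to come out bitwise identical
# when multiplied alone versus inside a larger batch. Routing every product
# through the same persistent Triton matmul removes that batch-size dependence.
# Assumes a CUDA device:
#
#     a = torch.randn(32, 2048, device="cuda", dtype=torch.bfloat16)
#     w = torch.randn(2048, 2048, device="cuda", dtype=torch.bfloat16)
#     row0_alone = a[:1] @ w
#     row0_in_batch = (a @ w)[:1]
#     torch.equal(row0_alone, row0_in_batch)  # may be False without overrides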
@@ -479,11 +517,24 @@ def _log_softmax_batch_invariant(input, dim, _half_to_float):
|
||||
    return log_softmax(input, dim=dim)

def softmax_batch_invariant(input, dim, dtype=None):
    # Compute softmax in a deterministic way
    # First subtract max for numerical stability (standard practice)
    input_max = torch.amax(input, dim=dim, keepdim=True)
    input = input - input_max
    exp_x = torch.exp(input)
    sum_exp_x = torch.sum(exp_x, dim=dim, keepdim=True)
    return exp_x / sum_exp_x

def mean_batch_invariant(input, dim, keepdim=False, dtype: torch.dtype | None = None):
    assert dtype is None or dtype == torch.float32, f"unsupported dtype: {dtype}"

    result = input.to(torch.float32)

    if len(dim) == 0:
        dim = [i for i in range(len(input.shape))]

    # Sort dimensions to reduce from largest to smallest to handle shifting dims
    # during iterative reduction.
    sorted_dims = sorted([d % input.ndim for d in dim], reverse=True)
@@ -500,8 +551,134 @@ def mean_batch_invariant(input, dim, keepdim=False, dtype: torch.dtype | None =
|
||||
    return result

@triton.jit
def _rms_norm_kernel(
    input_ptr,
    weight_ptr,
    output_ptr,
    input_row_stride,
    output_row_stride,
    n_cols,
    eps,
    BLOCK_SIZE: tl.constexpr,
):
    """
    Compute RMS normalization along the last dimension of a 2D tensor.
    RMS Norm: y = x / sqrt(mean(x^2) + eps) * weight
    Each block handles one row of the input tensor.
    """
    row_idx = tl.program_id(0).to(tl.int64)
    row_start_ptr = input_ptr + row_idx * input_row_stride
    output_row_start_ptr = output_ptr + row_idx * output_row_stride

    # Step 1: Compute sum of squares in float32 to avoid overflow
    sum_sq = tl.zeros([1], dtype=tl.float32)
    for col_offset in range(0, n_cols, BLOCK_SIZE):
        col_idx = col_offset + tl.arange(0, BLOCK_SIZE)
        mask = col_idx < n_cols

        vals = tl.load(row_start_ptr + col_idx, mask=mask, other=0.0)
        # Convert to float32 for accumulation to prevent overflow
        vals_f32 = vals.to(tl.float32)
        sq_vals = vals_f32 * vals_f32
        sum_sq += tl.sum(tl.where(mask, sq_vals, 0.0))

    # Step 2: Compute RMS (root mean square) in float32
    mean_sq = sum_sq / n_cols
    rms = tl.sqrt(mean_sq + eps)
    inv_rms = 1.0 / rms

    # Step 3: Normalize and apply weight
    for col_offset in range(0, n_cols, BLOCK_SIZE):
        col_idx = col_offset + tl.arange(0, BLOCK_SIZE)
        mask = col_idx < n_cols
        vals = tl.load(row_start_ptr + col_idx, mask=mask, other=0.0)
        weight = tl.load(weight_ptr + col_idx, mask=mask, other=1.0)
        # Compute in float32 then convert back to input dtype
        vals_f32 = vals.to(tl.float32)
        weight_f32 = weight.to(tl.float32)
        output_f32 = vals_f32 * inv_rms * weight_f32
        output = output_f32.to(vals.dtype)
        tl.store(output_row_start_ptr + col_idx, output, mask=mask)

def rms_norm(
|
||||
input: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Compute RMS normalization using Triton kernel.
|
||||
|
||||
RMS Norm normalizes the input by the root mean square and scales by weight:
|
||||
output = input / sqrt(mean(input^2) + eps) * weight
|
||||
|
||||
Args:
|
||||
input: Input tensor of shape (..., hidden_size)
|
||||
weight: Weight tensor of shape (hidden_size,)
|
||||
eps: Small constant for numerical stability
|
||||
|
||||
Returns:
|
||||
Tensor with RMS normalization applied along the last dimension
|
||||
"""
|
||||
assert weight.dim() == 1, "Weight must be 1-dimensional"
|
||||
assert input.shape[-1] == weight.shape[0], (
|
||||
f"Input last dimension ({input.shape[-1]}) must match "
|
||||
f"weight dimension ({weight.shape[0]})"
|
||||
)
|
||||
|
||||
# Flatten all dimensions except the last one
|
||||
original_shape = input.shape
|
||||
input_2d = input.reshape(-1, input.shape[-1])
|
||||
input_2d = input_2d.contiguous()
|
||||
weight = weight.contiguous()
|
||||
|
||||
n_rows, n_cols = input_2d.shape
|
||||
|
||||
output = torch.empty_like(input_2d)
|
||||
BLOCK_SIZE = 1024
|
||||
grid = (n_rows,)
|
||||
_rms_norm_kernel[grid](
|
||||
input_2d,
|
||||
weight,
|
||||
output,
|
||||
input_2d.stride(0),
|
||||
output.stride(0),
|
||||
n_cols,
|
||||
eps,
|
||||
BLOCK_SIZE=BLOCK_SIZE,
|
||||
)
|
||||
return output.reshape(original_shape)
|
||||
|
||||
|
||||
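A short usage sketch for the wrapper above (illustrative only; it assumes a CUDA device with Triton available, and the tolerances are rough bounds for bfloat16, not part of this change):

import torch

x = torch.randn(2, 16, 4096, device="cuda", dtype=torch.bfloat16)
w = torch.randn(4096, device="cuda", dtype=torch.bfloat16)

y = rms_norm(x, w, eps=1e-6)

# Plain PyTorch reference computed in float32 for comparison.
ref = x.float() * torch.rsqrt(x.float().pow(2).mean(-1, keepdim=True) + 1e-6)
ref = (ref * w.float()).to(torch.bfloat16)
assert torch.allclose(y.float(), ref.float(), atol=2e-2, rtol=2e-2)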
def rms_norm_batch_invariant(
    input: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6
) -> torch.Tensor:
    """
    Batch-invariant wrapper for RMS normalization.

    This function provides a deterministic, batch-invariant implementation
    of RMS normalization for use with the batch_invariant mode.

    Args:
        input: Input tensor of shape (..., hidden_size)
        weight: Weight tensor of shape (hidden_size,)
        eps: Small constant for numerical stability

    Returns:
        RMS normalized tensor
    """
    return rms_norm(input, weight, eps=eps)


def linear_batch_invariant(input, weight, bias=None):
    output = mm_batch_invariant(input, weight.t())
    if bias is not None:
        output = output + bias
    return output


_batch_invariant_MODE = False
_batch_invariant_LIB = None
_original_torch_bmm = None


def is_batch_invariant_mode_enabled():
@ -509,7 +686,7 @@ def is_batch_invariant_mode_enabled():


def enable_batch_invariant_mode():
    global _batch_invariant_MODE, _batch_invariant_LIB
    global _batch_invariant_MODE, _batch_invariant_LIB, _original_torch_bmm
    if _batch_invariant_MODE:
        return
@ -517,16 +694,28 @@ def enable_batch_invariant_mode():
    _batch_invariant_LIB = torch.library.Library("aten", "IMPL")
    _batch_invariant_LIB.impl("aten::mm", mm_batch_invariant, "CUDA")
    _batch_invariant_LIB.impl("aten::addmm", addmm_batch_invariant, "CUDA")
    _batch_invariant_LIB.impl("aten::matmul", matmul_batch_invariant, "CUDA")
    _batch_invariant_LIB.impl("aten::bmm", bmm_batch_invariant, "CUDA")
    _batch_invariant_LIB.impl("aten::linear", linear_batch_invariant, "CUDA")
    _batch_invariant_LIB.impl(
        "aten::_log_softmax", _log_softmax_batch_invariant, "CUDA"
    )
    _batch_invariant_LIB.impl("aten::softmax", softmax_batch_invariant, "CUDA")
    _batch_invariant_LIB.impl("aten::_softmax", softmax_batch_invariant, "CUDA")
    _batch_invariant_LIB.impl("aten::mean.dim", mean_batch_invariant, "CUDA")

    # Also monkeypatch torch.bmm directly as a fallback
    _original_torch_bmm = torch.bmm
    torch.bmm = bmm_batch_invariant


def disable_batch_invariant_mode():
    global _batch_invariant_MODE, _batch_invariant_LIB
    global _batch_invariant_MODE, _batch_invariant_LIB, _original_torch_bmm
    if _batch_invariant_LIB is not None:
        _batch_invariant_LIB._destroy()
    if _original_torch_bmm is not None:
        torch.bmm = _original_torch_bmm
        _original_torch_bmm = None
    _batch_invariant_MODE = False
    _batch_invariant_LIB = None
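Because enable/disable are symmetric, a scoped toggle is easy to build on top of them. A minimal sketch (illustrative helper, not part of this change; the commented call site uses hypothetical names):

import contextlib

@contextlib.contextmanager
def batch_invariant_kernels():
    # Route the registered aten ops through the deterministic implementations
    # for the duration of the block, then restore the defaults.
    enable_batch_invariant_mode()
    try:
        yield
    finally:
        disable_batch_invariant_mode()

# with batch_invariant_kernels():
#     logits = model(prompt_batch)  # hypothetical model call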
@ -563,17 +752,55 @@ def vllm_kernel_override_batch_invariant():
    return is_overridden


def override_envs_for_invariance():
    curr_attn_backend = envs.VLLM_ATTENTION_BACKEND
    supported_backends = [
        "FLASH_ATTN",  # best supported backend
        "FLEX_ATTENTION",
        "FLASHINFER",
        "FLASH_ATTN_MLA",
        "TRITON_MLA",
        # Not yet supported MLA backends
        # "FLASHMLA",
        # "FLASHINFER_MLA",
    ]
    if curr_attn_backend not in supported_backends:
        warning = (
            "Forcibly updating attention backend to"
            f" {supported_backends[0]} for batch_invariant. "
            f" Supported backends: {supported_backends}."
        )
        logger.warning_once(warning)
        os.environ["VLLM_ATTENTION_BACKEND"] = supported_backends[0]
    if os.environ["VLLM_ATTENTION_BACKEND"] != supported_backends[0]:
        warning = (
            "You are using a decode-invariant form of batch invariance. "
            "This will not be invariant between prefill and decode."
        )
        logger.warning_once(warning)
    os.environ["VLLM_ALLREDUCE_USE_SYMM_MEM"] = "0"

    os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"

    # NCCL determinism settings
    os.environ["NCCL_LAUNCH_MODE"] = "GROUP"
    os.environ["NCCL_COLLNET_ENABLE"] = "0"
    os.environ["NCCL_NVLS_ENABLE"] = "0"
    os.environ["NCCL_P2P_NET_DISABLE"] = "1"
    os.environ["NCCL_MIN_NCHANNELS"] = "1"
    os.environ["NCCL_MAX_NCHANNELS"] = "1"
    os.environ["NCCL_PROTO"] = "Simple"
    os.environ["NCCL_ALGO"] = "allreduce:tree"
    os.environ["NCCL_NTHREADS"] = "1"
    os.environ["NCCL_SOCKET_NTHREADS"] = "1"


def init_batch_invariance():
    # this will hit all the csrc overrides as well
    if vllm_kernel_override_batch_invariant():
        curr_attn_backend = envs.VLLM_ATTENTION_BACKEND
        supported_backends = ["FLEX_ATTENTION", "FLASHINFER"]
        if curr_attn_backend not in supported_backends:
            warning = (
                "Forcibly updating attention backend to"
                f" {supported_backends[0]} for batch_invariant. "
                f" Supported backends: {supported_backends}."
            )
            logger.warning_once(warning)
            os.environ["VLLM_ATTENTION_BACKEND"] = supported_backends[0]
        override_envs_for_invariance()
        enable_batch_invariant_mode()

        # Disable TF32 for batch invariance - it causes non-deterministic rounding
        torch.backends.cuda.matmul.allow_tf32 = False
        torch.backends.cudnn.allow_tf32 = False
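In practice the whole mode is gated on one environment variable, and vLLM calls init_batch_invariance() itself during worker setup (see the hunks near the end of this diff). A minimal end-user sketch (illustrative only; assumes an 8-GPU node able to hold the model):

import os

# Must be set before the engine spins up its workers, because
# init_batch_invariance() reads it during initialization.
os.environ["VLLM_KERNEL_OVERRIDE_BATCH_INVARIANT"] = "1"

from vllm import LLM, SamplingParams

llm = LLM(model="deepseek-ai/DeepSeek-V3", tensor_parallel_size=8)
outputs = llm.generate(["Deterministic decoding test:"],
                       SamplingParams(temperature=0.0, max_tokens=32))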
@ -15,6 +15,9 @@ import vllm.envs as envs
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from vllm import _custom_ops as ops
from vllm.logger import init_logger
from vllm.model_executor.layers.batch_invariant import (
    vllm_kernel_override_batch_invariant,
)
from vllm.model_executor.layers.fused_moe.config import (
    FUSED_MOE_UNQUANTIZED_CONFIG,
    FusedMoEQuantConfig,
@ -837,6 +840,10 @@ def get_moe_configs(
    be picked and the associated configuration chosen to invoke the kernel.
    """

    # Avoid optimizing for the batch invariant case. Use default config
    if vllm_kernel_override_batch_invariant():
        return None

    # First look up if an optimized configuration is available in the configs
    # directory
    block_shape = [block_n, block_k] if block_n and block_k else None
@ -969,6 +976,15 @@ def get_default_config(
    dtype: str | None,
    block_shape: list[int] | None = None,
) -> dict[str, int]:
    if vllm_kernel_override_batch_invariant():
        config = {
            "BLOCK_SIZE_M": 64,
            "BLOCK_SIZE_N": 64,
            "BLOCK_SIZE_K": 32,
            "GROUP_SIZE_M": 8,
        }
        return config

    if dtype == "fp8_w8a8" and block_shape is not None:
        # Block-wise quant: BLOCK_SIZE_N must be divisible by block_shape[0]
        # BLOCK_SIZE_K must be divisible by block_shape[1]
@ -1118,7 +1134,10 @@ def fused_topk_bias(
    scores_for_choice = scores.view(
        -1, n_routed_experts
    ) + e_score_correction_bias.unsqueeze(0)
    topk_indices = torch.topk(scores_for_choice, k=topk, dim=-1, sorted=False)[1]

    # For batch invariance, use sorted=True to ensure deterministic expert selection
    use_sorted = vllm_kernel_override_batch_invariant()
    topk_indices = torch.topk(scores_for_choice, k=topk, dim=-1, sorted=use_sorted)[1]
    topk_weights = scores.gather(1, topk_indices)
    if renormalize:
        topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
@ -1179,7 +1198,10 @@ def grouped_topk(
    group_scores = (
        scores.view(num_token, num_expert_group, -1).max(dim=-1).values
    )  # [n, n_group]
    group_idx = torch.topk(group_scores, k=topk_group, dim=-1, sorted=False)[

    # For batch invariance, use sorted=True to ensure deterministic expert selection
    use_sorted = vllm_kernel_override_batch_invariant()
    group_idx = torch.topk(group_scores, k=topk_group, dim=-1, sorted=use_sorted)[
        1
    ]  # [n, top_k_group]
    group_mask = torch.zeros_like(group_scores)  # [n, n_group]
@ -1192,11 +1214,13 @@ def grouped_topk(
    tmp_scores = scores.masked_fill(~score_mask.bool(), float("-inf"))  # [n, e]

    if e_score_correction_bias is not None:
        topk_ids = torch.topk(tmp_scores, k=topk, dim=-1, sorted=False)[1]
        topk_ids = torch.topk(tmp_scores, k=topk, dim=-1, sorted=use_sorted)[1]
        # Use original unbiased scores for the routing weights
        topk_weights = original_scores.gather(1, topk_ids)
    else:
        topk_weights, topk_ids = torch.topk(tmp_scores, k=topk, dim=-1, sorted=False)
        topk_weights, topk_ids = torch.topk(
            tmp_scores, k=topk, dim=-1, sorted=use_sorted
        )

    if renormalize:
        topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
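The routing change above asks torch.topk for sorted output when batch invariance is on, so the expert order seen by the subsequent gather and renormalization steps is reproducible rather than kernel-dependent. A tiny standalone illustration of the idea (not vLLM code):

import torch

scores = torch.tensor([[0.2, 0.9, 0.4, 0.7],
                       [0.1, 0.3, 0.8, 0.6]])

# sorted=True returns the winners in descending-score order, independent of
# which launch configuration the topk kernel happened to pick.
w, ids = torch.topk(scores, k=2, dim=-1, sorted=True)
# ids == [[1, 3], [2, 3]] for this input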
@ -8,6 +8,10 @@ import torch.nn.functional as F

import vllm.envs as envs
from vllm.model_executor.custom_op import CustomOp
from vllm.model_executor.layers.batch_invariant import (
    rms_norm_batch_invariant,
    vllm_kernel_override_batch_invariant,
)
from vllm.platforms import current_platform
from vllm.utils import direct_register_custom_op
@ -21,6 +25,8 @@ def rms_norm(
) -> torch.Tensor:
    from vllm import _custom_ops as ops

    if vllm_kernel_override_batch_invariant():
        return rms_norm_batch_invariant(x, weight, variance_epsilon)
    out = torch.empty_like(x)
    ops.rms_norm(
        out,
@ -39,6 +45,10 @@ def fused_add_rms_norm(
) -> tuple[torch.Tensor, torch.Tensor]:
    from vllm import _custom_ops as ops

    if vllm_kernel_override_batch_invariant():
        return rms_norm_batch_invariant(
            x + residual, weight, variance_epsilon
        ), x + residual
    ops.fused_add_rms_norm(
        x,
        residual,
@ -160,6 +160,7 @@ class MultiHeadLatentAttentionWrapper(CustomOp):
            k_pe,
            output_shape=(hidden_states.shape[0], self.num_heads * self.v_head_dim),
        )

        return self.o_proj(attn_out)[0]

    def forward_cuda(self, *args, **kwargs):
@ -14,6 +14,9 @@ import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from vllm import _custom_ops as ops
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.logger import init_logger
from vllm.model_executor.layers.batch_invariant import (
    vllm_kernel_override_batch_invariant,
)
from vllm.model_executor.layers.fused_moe import (
    FusedMoE,
    FusedMoEActivationFormat,
@ -353,6 +356,8 @@ class Fp8LinearMethod(LinearMethodBase):
        # Disable marlin for rocm
        if current_platform.is_rocm():
            self.use_marlin = False
        if vllm_kernel_override_batch_invariant():
            self.use_marlin = False

        self.use_aiter_and_is_supported = check_aiter_fp8_linear_support()
@ -534,6 +539,66 @@ class Fp8LinearMethod(LinearMethodBase):
        x: torch.Tensor,
        bias: torch.Tensor | None = None,
    ) -> torch.Tensor:
        # If batch invariant mode is enabled, dequantize and use BF16 compute
        if vllm_kernel_override_batch_invariant():
            # Dequantize FP8 weights to BF16
            weight_fp8 = layer.weight.to(torch.bfloat16)
            weight_scale = layer.weight_scale.to(torch.bfloat16)

            # Handle different quantization granularities
            if self.block_quant:
                # Block-wise quantization:
                # - Weight is NOT transposed, shape is [N, K] (output_size, input_size)
                # - Scale has shape [num_blocks_k, num_blocks_n] (TRANSPOSED!)
                assert self.weight_block_size is not None
                block_n, block_k = self.weight_block_size  # Note: order is [N, K]

                N, K = weight_fp8.shape

                # Scale is stored transposed: [num_blocks_k, num_blocks_n]
                # We need to transpose it to [num_blocks_n, num_blocks_k] first
                weight_scale = weight_scale.t()

                # Expand scale to match weight dimensions
                # scale_expanded should have shape [N, K]
                scale_expanded = weight_scale.repeat_interleave(
                    block_n, dim=0
                ).repeat_interleave(block_k, dim=1)
                # Trim to exact weight size (in case of padding)
                scale_expanded = scale_expanded[:N, :K]
                weight_bf16 = weight_fp8 * scale_expanded
            else:
                # Per-tensor quantization: weight IS transposed to [K, N]
                # scale should be scalar or [1] or per-output-channel [N]
                if weight_scale.numel() == 1:
                    # Per-tensor: simple scalar multiplication
                    weight_bf16 = weight_fp8 * weight_scale
                else:
                    # Multiple scales (fused modules like QKV)
                    # Try to infer correct broadcasting
                    # weight is [K, N], scale could be [num_logical_weights]
                    # Need to figure out how to broadcast - for now just try
                    # direct multiplication
                    if (
                        weight_scale.dim() == 1
                        and weight_scale.shape[0] == weight_fp8.shape[0]
                    ):
                        # Per-row scaling
                        weight_bf16 = weight_fp8 * weight_scale.unsqueeze(1)
                    else:
                        # Fallback
                        weight_bf16 = weight_fp8 * weight_scale

            # For block quant, weight is [N, K], for per-tensor it's [K, N]
            # F.linear expects weight to be [N, K], so:
            if self.block_quant:
                # Already in correct shape [N, K]
                output = torch.nn.functional.linear(x, weight_bf16, bias)
            else:
                # Need to transpose back: [K, N] -> [N, K]
                output = torch.nn.functional.linear(x, weight_bf16.t(), bias)
            return output

        if self.use_marlin:
            return apply_fp8_marlin_linear(
                input=x,
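The block-wise dequantization above hinges on repeat_interleave expanding a per-block scale grid back to per-element scales. A small standalone illustration of that step (shapes chosen for readability, not DeepSeek-V3's real block sizes):

import torch

block_n, block_k = 2, 4
N, K = 4, 8  # weight is [N, K]
weight = torch.randn(N, K)

# One scale per (block_n x block_k) tile: grid shape [N // block_n, K // block_k]
scales = torch.tensor([[0.5, 2.0],
                       [1.0, 4.0]])

# Expand each tile scale across its rows and columns to get a full [N, K] map.
scale_map = scales.repeat_interleave(block_n, dim=0).repeat_interleave(block_k, dim=1)
assert scale_map.shape == (N, K)
dequant = weight * scale_map  # per-element scaling, same idea as the BF16 path above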
@ -216,6 +216,7 @@ class TransformerBlock(torch.nn.Module):
        else:
            hidden_states, residual = self.input_layernorm(hidden_states, residual)
        hidden_states = self.attn(hidden_states, positions)

        # Fully Connected
        hidden_states, residual = self.post_attention_layernorm(hidden_states, residual)
        output = self.mlp(hidden_states)
@ -31,10 +31,12 @@ if is_flash_attn_varlen_func_available():
        get_scheduler_metadata,
        reshape_and_cache_flash,
    )

from vllm.config import VllmConfig, get_layers_from_vllm_config
from vllm.distributed.parallel_state import get_dcp_group
from vllm.logger import init_logger
from vllm.model_executor.layers.batch_invariant import (
    vllm_kernel_override_batch_invariant,
)
from vllm.utils import cdiv
from vllm.v1.attention.backends.utils import (
    AttentionCGSupport,
@ -306,6 +308,9 @@ class FlashAttentionMetadataBuilder(AttentionMetadataBuilder[FlashAttentionMetad
            # we only set num_splits when using cuda graphs.
            max_num_splits = self.max_num_splits

        if vllm_kernel_override_batch_invariant():
            max_num_splits = 1

        def schedule(
            batch_size, cu_query_lens, max_query_len, seqlens, max_seq_len, causal
        ):
@ -478,6 +483,9 @@ class FlashAttentionImpl(AttentionImpl):

        self.attn_type = attn_type
        self.vllm_flash_attn_version = get_flash_attn_version()
        # Cache the batch invariant result for use in forward passes
        self.batch_invariant_enabled = vllm_kernel_override_batch_invariant()

        if is_quantized_kv_cache(self.kv_cache_dtype) and not flash_attn_supports_fp8():
            raise NotImplementedError(
                "FlashAttention does not support fp8 kv-cache on this device."
@ -810,6 +818,7 @@ class FlashAttentionImpl(AttentionImpl):
                q_descale=layer._q_scale.expand(descale_shape),
                k_descale=layer._k_scale.expand(descale_shape),
                v_descale=layer._v_scale.expand(descale_shape),
                num_splits=1 if self.batch_invariant_enabled else 0,
            )

        return output
@ -954,6 +963,7 @@ def cascade_attention(
        # s_aux is incorporated into prefix_lse inside the GPU kernel,
        # enabling its effect during the final attention merge.
        s_aux=s_aux,
        num_splits=1 if vllm_kernel_override_batch_invariant() else 0,
    )

    descale_shape = (cu_query_lens.shape[0] - 1, key_cache.shape[-2])
@ -978,6 +988,7 @@ def cascade_attention(
        q_descale=q_descale.expand(descale_shape) if q_descale is not None else None,
        k_descale=k_descale.expand(descale_shape) if k_descale is not None else None,
        v_descale=v_descale.expand(descale_shape) if v_descale is not None else None,
        num_splits=1 if vllm_kernel_override_batch_invariant() else 0,
    )

    # Merge prefix and suffix outputs, and store the result in output.
@ -211,6 +211,9 @@ from vllm.attention.utils.fa_utils import get_flash_attn_version
from vllm.config import VllmConfig, get_current_vllm_config
from vllm.distributed.parallel_state import get_dcp_group, is_global_first_rank
from vllm.logger import init_logger
from vllm.model_executor.layers.batch_invariant import (
    vllm_kernel_override_batch_invariant,
)
from vllm.model_executor.layers.linear import (
    ColumnParallelLinear,
    LinearBase,
@ -1187,6 +1190,7 @@ class MLACommonBaseImpl(MLAAttentionImpl[A], Generic[A]):
    def _v_up_proj(self, x: torch.Tensor, out: torch.Tensor):
        # Convert from (B, N, L) to (N, B, L)
        x = x.view(-1, self.num_heads, self.kv_lora_rank).transpose(0, 1)

        if is_rocm_aiter_fp8bmm_enabled():
            # Multiply + Transpose (N, B, L) x (N, L, V)->(N, B, V)->(B, N, V)
            x = aiter_triton_fp8_bmm(
@ -1279,6 +1283,8 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]):
            # ROCm leverages the upstream flash_attn, which takes a parameter
            # called "return_attn_probs" instead of return_softmax_lse
            kwargs["return_attn_probs"] = return_softmax_lse
        if vllm_kernel_override_batch_invariant():
            kwargs["num_splits"] = 1

        attn_out = self.flash_attn_varlen_func(
            q=q,
@ -1841,9 +1847,11 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]):

        if has_decode:
            assert attn_metadata.decode is not None

            decode_q_nope, decode_q_pe = decode_q.split(
                [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1
            )

            # Convert from (B, N, P) to (N, B, P)
            decode_q_nope = decode_q_nope.transpose(0, 1)
@ -1868,17 +1876,18 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]):
            # Pads the head_dim if necessary (for the underlying kernel)
            N, B, P = decode_q_nope.shape
            _, _, L = self.W_UK_T.shape

            if self.q_pad_num_heads is not None:
                decode_ql_nope = decode_q_nope.new_empty(
                    (self.q_pad_num_heads, B, L)
                )
                decode_ql_nope.resize_((N, B, L))

            else:
                decode_ql_nope = decode_q_nope.new_empty((N, B, L))

            # Multiply (N, B, P) x (N, P, L) -> (N, B, L)
            torch.bmm(decode_q_nope, self.W_UK_T, out=decode_ql_nope)

            # Convert from (N, B, L) to (B, N, L)
            decode_ql_nope = decode_ql_nope.transpose(0, 1)
@ -18,6 +18,9 @@ from vllm.attention.utils.fa_utils import (
)
from vllm.config import VllmConfig
from vllm.logger import init_logger
from vllm.model_executor.layers.batch_invariant import (
    vllm_kernel_override_batch_invariant,
)
from vllm.v1.attention.backends.mla.common import (
    MLACommonBackend,
    MLACommonDecodeMetadata,
@ -107,6 +110,9 @@ class FlashAttnMLAMetadataBuilder(MLACommonMetadataBuilder[FlashAttnMLAMetadata]
        # pre-allocated during capture.
        self.max_num_splits = envs.VLLM_FLASH_ATTN_MAX_NUM_SPLITS_FOR_CUDA_GRAPH

        if vllm_kernel_override_batch_invariant():
            self.max_num_splits = 1

    def _schedule_decode(
        self, num_reqs, cu_query_lens, max_query_len, seqlens, max_seq_len, causal
    ):
@ -175,7 +181,10 @@ class FlashAttnMLAMetadataBuilder(MLACommonMetadataBuilder[FlashAttnMLAMetadata]
            # we only set num_splits when using cuda graphs.
            max_num_splits = self.max_num_splits

        return FlashAttnMLADecodeMetadata(
        if vllm_kernel_override_batch_invariant():
            max_num_splits = 1

        metadata = FlashAttnMLADecodeMetadata(
            block_table=block_table_tensor,
            seq_lens=seq_lens_device,
            query_start_loc=query_start_loc_device,
@ -185,6 +194,7 @@ class FlashAttnMLAMetadataBuilder(MLACommonMetadataBuilder[FlashAttnMLAMetadata]
            max_num_splits=max_num_splits,
            dcp_tot_seq_lens=dcp_tot_seq_lens_device,
        )
        return metadata


class FlashAttnMLAImpl(MLACommonImpl[FlashAttnMLAMetadata]):
@ -14,6 +14,9 @@ from vllm.attention.ops.flashmla import (
)
from vllm.config import VllmConfig
from vllm.logger import init_logger
from vllm.model_executor.layers.batch_invariant import (
    vllm_kernel_override_batch_invariant,
)
from vllm.v1.attention.backends.mla.common import (
    MLACommonBackend,
    MLACommonDecodeMetadata,
@ -223,19 +226,50 @@ class FlashMLAImpl(MLACommonImpl[FlashMLAMetadata]):
        if type(q) is tuple:
            q = torch.cat(q, dim=-1)

        # mypy assertion: q is now always a tensor
        assert isinstance(q, torch.Tensor)

        num_decodes = attn_metadata.num_decodes
        q = reshape_query_for_spec_decode(q, num_decodes)

        tile_scheduler_metadata = attn_metadata.decode.tile_scheduler_metadata
        num_splits = attn_metadata.decode.num_splits
        if vllm_kernel_override_batch_invariant():
            device = q.device
            dtype = torch.int32

            B = q.shape[0]
            # block_table shape: [batch_size, max_num_blocks_per_seq]
            # The number of blocks per sequence is in the second dimension
            topk = attn_metadata.decode.block_table.shape[-1]
            B_TOPK = 64
            assert topk % B_TOPK == 0, f"topk ({topk}) must be divisible by {B_TOPK}"
            end_block_idx = topk // B_TOPK

            # Single partition => num_sm_parts = 1
            # TileSchedulerMetaDataSize = 8, layout:
            #   [begin_idx, begin_block_idx, end_idx, end_block_idx,
            #    begin_n_split_idx, _, _, _]
            tile_scheduler_metadata = torch.zeros((1, 8), dtype=dtype, device=device)
            tile_scheduler_metadata[0, 0] = 0  # begin_idx
            tile_scheduler_metadata[0, 1] = 0  # sched_begin_block_idx
            tile_scheduler_metadata[0, 2] = B - 1  # end_idx
            tile_scheduler_metadata[0, 3] = end_block_idx
            tile_scheduler_metadata[0, 4] = 0  # begin_n_split_idx
            # fields [5..7] stay 0

            # Non-split path ignores num_splits, but the API requires it:
            # zeros of length B+1
            num_splits = torch.zeros((B + 1,), dtype=dtype, device=device)

        o, lse = flash_mla_with_kvcache(
            q=q,
            k_cache=kv_c_and_k_pe_cache.unsqueeze(-2),  # Add head dim of 1
            block_table=attn_metadata.decode.block_table,
            cache_seqlens=attn_metadata.decode.seq_lens,
            head_dim_v=self.kv_lora_rank,
            tile_scheduler_metadata=attn_metadata.decode.tile_scheduler_metadata,
            num_splits=attn_metadata.decode.num_splits,
            tile_scheduler_metadata=tile_scheduler_metadata,
            num_splits=num_splits,
            softmax_scale=self.scale,
            causal=True,
            descale_q=layer._q_scale.reshape(1),
@ -13,6 +13,9 @@ from vllm.attention.backends.abstract import (
from vllm.attention.ops.triton_decode_attention import decode_attention_fwd
from vllm.attention.ops.triton_flash_attention import triton_attention
from vllm.logger import init_logger
from vllm.model_executor.layers.batch_invariant import (
    vllm_kernel_override_batch_invariant,
)
from vllm.platforms import current_platform
from vllm.triton_utils import HAS_TRITON
from vllm.v1.attention.backends.mla.common import (
@ -158,7 +161,9 @@ class TritonMLAImpl(MLACommonImpl[MLACommonMetadata]):
            B, q_num_heads, self.kv_lora_rank, dtype=q.dtype, device=q.device
        )
        lse = torch.zeros(B, q_num_heads, dtype=q.dtype, device=q.device)
        num_kv_splits = 4  # TODO: heuristic

        # For batch invariance, use only 1 split to ensure deterministic reduction
        num_kv_splits = 1 if vllm_kernel_override_batch_invariant() else 4

        # TODO(lucas) Allocate ahead of time
        attn_logits = torch.empty(
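Why forcing a single split matters: split-KV attention sums partial results in whatever order the splits complete, and floating-point addition is not associative, so the reduction order (and therefore the output bits) can shift with batch size. A tiny standalone illustration of the underlying effect (not vLLM code):

import torch

a = torch.tensor([1.0, 1e8, -1e8], dtype=torch.float32)
left_to_right = (a[0] + a[1]) + a[2]   # (1 + 1e8) rounds to 1e8, so the 1 is lost
regrouped = a[0] + (a[1] + a[2])       # 1 + (1e8 - 1e8) keeps the 1
print(left_to_right.item(), regrouped.item())  # 0.0 1.0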
@ -231,9 +231,6 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
        from vllm.model_executor.models.utils import set_cpu_offload_max_bytes

        set_cpu_offload_max_bytes(int(self.cache_config.cpu_offload_gb * 1024**3))
        from vllm.model_executor.layers.batch_invariant import init_batch_invariance

        init_batch_invariance()

        model_config = self.model_config
        cache_config = self.cache_config
@ -764,6 +764,9 @@ def init_worker_distributed_environment(
) -> None:
    """Initialize the distributed environment."""
    parallel_config = vllm_config.parallel_config
    from vllm.model_executor.layers.batch_invariant import init_batch_invariance

    init_batch_invariance()
    set_custom_all_reduce(not parallel_config.disable_custom_all_reduce)

    init_distributed_environment(