Mirror of https://git.datalinker.icu/vllm-project/vllm.git (synced 2025-12-10 12:05:48 +08:00)

Deepseek-v3 Batch Invariant on 8xH100 (#26609)

Signed-off-by: Bram Wasti <bwasti@meta.com>
Co-authored-by: Wentao Ye <44945378+yewentao256@users.noreply.github.com>

This commit is contained in:
parent 785d8b6410
commit 7d8975de84
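
The tests in this diff exercise batch invariance by comparing sampled tokens and logprobs from single-prompt (BS=1) runs against a batched (BS=N) run. A minimal sketch of that workflow, assembled from the APIs used below, is shown here; the model name and prompts are placeholders, and VLLM_KERNEL_OVERRIDE_BATCH_INVARIANT is the opt-in environment variable these tests toggle. This is an editorial illustration, not part of the commit.

# Minimal sketch (assumptions noted above): check that sampled tokens are
# identical when a prompt is served alone (BS=1) vs. inside a batch (BS=N).
import os

os.environ["VLLM_KERNEL_OVERRIDE_BATCH_INVARIANT"] = "1"  # opt-in, as in the fixture below

from vllm import LLM, SamplingParams

llm = LLM(model="Qwen/Qwen3-1.7B", enable_prefix_caching=False, dtype="bfloat16")
sp = SamplingParams(temperature=0.6, top_p=1.0, max_tokens=8, seed=1234, logprobs=5)

prompts = ["The capital of France is"] * 4  # placeholder prompts
solo = [llm.generate([p], sp, use_tqdm=False)[0] for p in prompts]  # BS=1 runs
batched = llm.generate(prompts, sp, use_tqdm=False)                 # BS=N run

for a, b in zip(solo, batched):
    # With batch-invariant kernels enabled, the sampled token ids (and the
    # per-step logprobs) are expected to match bitwise.
    assert a.outputs[0].token_ids == b.outputs[0].token_ids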
@@ -3,56 +3,73 @@
 import contextlib
 import os
 import random
-import string
 
 import pytest
 import torch
 
 from vllm import LLM, SamplingParams
+from vllm.platforms import current_platform
 
 
+@pytest.fixture(autouse=True)
+def enable_batch_invariant_mode():
+    """Automatically enable batch invariant kernel overrides for all tests."""
+    old_value = os.environ.get("VLLM_KERNEL_OVERRIDE_BATCH_INVARIANT")
+    os.environ["VLLM_KERNEL_OVERRIDE_BATCH_INVARIANT"] = "1"
+    yield
+    # Restore original value after test
+    if old_value is None:
+        os.environ.pop("VLLM_KERNEL_OVERRIDE_BATCH_INVARIANT", None)
+    else:
+        os.environ["VLLM_KERNEL_OVERRIDE_BATCH_INVARIANT"] = old_value
+
+
 def _random_prompt(min_words: int = 1024, max_words: int = 1024 * 2) -> str:
-    # Lightweight random prompt generator to vary prompt lengths and content.
-    vocab = [
-        "alpha",
-        "bravo",
-        "charlie",
-        "delta",
-        "echo",
-        "foxtrot",
-        "golf",
-        "hotel",
-        "india",
-        "juliet",
-        "kilo",
-        "lima",
-        "mike",
-        "november",
-        "oscar",
-        "papa",
-        "quebec",
-        "romeo",
-        "sierra",
-        "tango",
-        "uniform",
-        "victor",
-        "whiskey",
-        "xray",
-        "yankee",
-        "zulu",
+    # Generate more realistic prompts that will actually produce varied tokens
+    # Use a mix of common English text patterns
+    prompt_templates = [
+        # Question-answer style
+        "Question: What is the capital of France?\nAnswer: The capital of France is",
+        "Q: How does photosynthesis work?\nA: Photosynthesis is the process by which",
+        "User: Can you explain quantum mechanics?\nAssistant: Quantum mechanics is",
+        # Story/narrative style
+        "Once upon a time in a distant galaxy, there lived",
+        "The old man walked slowly down the street, remembering",
+        "In the year 2157, humanity finally discovered",
+        # Technical/code style
+        "To implement a binary search tree in Python, first we need to",
+        "The algorithm works by iterating through the array and",
+        "Here's how to optimize database queries using indexing:",
+        # Factual/informative style
+        "The Renaissance was a period in European history that",
+        "Climate change is caused by several factors including",
+        "The human brain contains approximately 86 billion neurons which",
+        # Conversational style
+        "I've been thinking about getting a new laptop because",
+        "Yesterday I went to the store and bought",
+        "My favorite thing about summer is definitely",
     ]
-    n = random.randint(min_words, max_words)
-    words = random.choices(vocab, k=n)
 
-    # Add some noise and punctuation variability
-    if random.random() < 0.5:
-        words[0] = words[0].capitalize()
-    if random.random() < 0.2:
-        words.append("".join(random.choices(string.ascii_lowercase, k=5)))
-    punct = random.choice([".", "?", "!", "...", ""])
-    return " ".join(words) + punct
+    # Pick a random template
+    base_prompt = random.choice(prompt_templates)
+
+    # Add some padding to vary the length if needed
+    if min_words > 50:
+        # For longer prompts, repeat context
+        padding_text = (
+            " This is an interesting topic that deserves more explanation. "
+            * (min_words // 50)
+        )
+        base_prompt = base_prompt + padding_text
+
+    return base_prompt
 
 
+@pytest.mark.skipif(
+    not current_platform.has_device_capability(90),
+    reason="Batch invariance tests only supported on Hopper (SM90)",
+)
 @pytest.mark.timeout(1000)
 def test_v1_generation_is_deterministic_across_batch_sizes_with_needle():
     """
@@ -91,7 +108,6 @@ def test_v1_generation_is_deterministic_across_batch_sizes_with_needle():
     # Keep GPU memory usage low to avoid startup allocation failures.
     gpu_mem_util = float(os.getenv("VLLM_GPU_MEMORY_UTILIZATION", "0.4"))
     max_model_len = int(os.getenv("VLLM_MAX_MODEL_LEN", "5120"))
-    swap_space_gb = int(os.getenv("VLLM_SWAP_SPACE_GB", "4"))
 
     # Sampling parameters: longer outputs with a more random-sounding
     # continuation,but still deterministic due to fixed seed.
@@ -117,7 +133,6 @@ def test_v1_generation_is_deterministic_across_batch_sizes_with_needle():
         max_num_seqs=max_batch_size,
         gpu_memory_utilization=gpu_mem_util,
         max_model_len=max_model_len,
-        swap_space=swap_space_gb,
     )
 
     # Baseline generation for the needle prompt alone.
@@ -132,7 +147,6 @@ def test_v1_generation_is_deterministic_across_batch_sizes_with_needle():
         max_num_seqs=max_batch_size,
         gpu_memory_utilization=gpu_mem_util,
         max_model_len=max_model_len,
-        swap_space=swap_space_gb,
     )
 
     mismatches = 0
@@ -195,16 +209,21 @@ def _extract_step_logprobs(request_output):
                 ],
                 dtype=torch.float32,
             )
-            return t
+            return t, inner.token_ids
 
-    return None
+    return None, None
 
 
+@pytest.mark.skipif(
+    not current_platform.has_device_capability(90),
+    reason="Batch invariance tests only supported on Hopper (SM90)",
+)
 @pytest.mark.skipif(
     not torch.cuda.is_available(),
     reason="Requires CUDA to match production inference path.",
 )
-@pytest.mark.parametrize("backend", ["FLEX_ATTENTION", "FLASHINFER"])
+@pytest.mark.parametrize("backend", ["FLASH_ATTN", "FLASHINFER"])
+@pytest.mark.forked
 def test_logprobs_bitwise_batch_invariance_bs1_vs_bsN(backend):
     backend = os.getenv("VLLM_ATTENTION_BACKEND", backend)
     os.environ["VLLM_ATTENTION_BACKEND"] = backend
@@ -214,69 +233,754 @@ def test_logprobs_bitwise_batch_invariance_bs1_vs_bsN(backend):
     model_name = os.getenv("VLLM_TEST_MODEL", "Qwen/Qwen3-1.7B")
     tp_size = int(os.getenv("VLLM_TEST_TP_SIZE", "1"))
 
-    # Force float32 to avoid precision-induced differences.
+    # For batch invariance, disable custom all-reduce to ensure deterministic
+    # all-reduce operations (custom all-reduce may not be deterministic)
+    from vllm.model_executor.layers.batch_invariant import (
+        vllm_kernel_override_batch_invariant,
+    )
+
+    disable_custom_ar = vllm_kernel_override_batch_invariant()
+
+    if disable_custom_ar:
+        print(f"\n{'=' * 80}")
+        print(f"BATCH INVARIANCE MODE: Disabling custom all-reduce (TP={tp_size})")
+        print(f"{'=' * 80}\n")
+
     llm = LLM(
         model=model_name,
         tensor_parallel_size=tp_size,
-        enforce_eager=True,
         enable_prefix_caching=False,
+        max_num_seqs=32,
+        max_model_len=8192,
+        dtype="bfloat16",  # not everything is supported
     )
 
-    prompts = [_random_prompt(10, 1024) for i in range(100)]
+    # Use more realistic prompts for better token generation
+    prompts = [_random_prompt(10, 50) for i in range(32)]
 
     sp = SamplingParams(
         temperature=0.6,
         top_p=1.0,
         max_tokens=8,
-        # Seed shouldn't matter at temperature=0, but keeping it stable anyway.
         seed=1234,
         logprobs=5,
     )
 
     # BS=1: run prompts individually and collect logprobs per step.
+    print("\n" + "=" * 80)
+    print("STARTING BS=1 RUNS (each prompt individually)")
+    print("=" * 80 + "\n")
+
     bs1_logprobs_per_prompt = []
-    for p in prompts:
+    bs1_tokens_per_prompt = []
+    for idx, p in enumerate(prompts):
+        print(f"\n[BS=1] Running prompt {idx}/{len(prompts)} - Preview: {p[:80]}...")
         outs = llm.generate([p], sp, use_tqdm=False)
         assert len(outs) == 1
-        step_logprobs = _extract_step_logprobs(outs[0])
+        step_logprobs, token_ids = _extract_step_logprobs(outs[0])
         if step_logprobs is None:
             pytest.skip(
                 "Logits are not available on RequestOutput; "
                 "enable logprobs return to run this test."
             )
         bs1_logprobs_per_prompt.append(step_logprobs)
+        bs1_tokens_per_prompt.append(token_ids)
+        print(f"[BS=1] Prompt {idx} generated tokens: {token_ids}")
 
     # BS=N: run prompts in a batch and collect logprobs per step for each
     # prompt.
+    print("\n" + "=" * 80)
+    print(f"STARTING BS={len(prompts)} RUN (all prompts batched)")
+    print("=" * 80 + "\n")
+
     outs_batched = llm.generate(prompts, sp, use_tqdm=False)
     assert len(outs_batched) == len(prompts)
     bsN_logprobs_per_prompt = []
-    for o in outs_batched:
-        step_logprobs = _extract_step_logprobs(o)
+    bsN_tokens_per_prompt = []
+
+    print(f"\n[BS={len(prompts)}] Processing batched outputs...")
+    for idx, o in enumerate(outs_batched):
+        tokens = o.outputs[0].token_ids if o.outputs else "N/A"
+        print(f"[BS={len(prompts)}] Prompt {idx} generated tokens: {tokens}")
+        step_logprobs, token_ids = _extract_step_logprobs(o)
         if step_logprobs is None:
             pytest.skip(
                 "Logits are not available on RequestOutput; "
                 "enable logprobs return to run this test."
             )
         bsN_logprobs_per_prompt.append(step_logprobs)
+        bsN_tokens_per_prompt.append(token_ids)
 
     # Compare step-by-step logprobs for each prompt between BS=1 and BS=N runs.
-    for i, (logprobs_bs1, logprobs_bsN) in enumerate(
-        zip(bs1_logprobs_per_prompt, bsN_logprobs_per_prompt)
-    ):
-        assert len(logprobs_bs1) == len(logprobs_bsN), (
-            f"Different number of generation steps for prompt index {i}: "
-            f"{len(logprobs_bs1)} (BS=1) vs {len(logprobs_bsN)} (BS=N)"
-        )
-
-        for t, (a, b) in enumerate(zip(logprobs_bs1, logprobs_bsN)):
-            assert a.shape == b.shape, (
-                f"Logits shape mismatch at prompt {i}, step {t}: {a.shape} vs {b.shape}"
-            )
-            # Bitwise exact equality.
-            assert torch.equal(a, b), (
-                f"Bitwise logprobs mismatch at prompt {i}, step {t} "
-                f"(dtype={a.dtype}, shape={a.shape})."
-            )
+    failed_prompts = []
+    for i, (logprobs_bs1, logprobs_bsN, tokens_bs1, tokens_bsN) in enumerate(
+        zip(
+            bs1_logprobs_per_prompt,
+            bsN_logprobs_per_prompt,
+            bs1_tokens_per_prompt,
+            bsN_tokens_per_prompt,
+        )
+    ):
+        if len(logprobs_bs1) != len(logprobs_bsN):
+            reason = (
+                f"Different number of steps: {len(logprobs_bs1)} (BS=1) "
+                f"vs {len(logprobs_bsN)} (BS=N)"
+            )
+            failed_prompts.append(
+                {
+                    "prompt_idx": i,
+                    "step": "all",
+                    "reason": reason,
+                    "prompt_preview": prompts[i][:100],
+                    "bs1_tokens": tokens_bs1,
+                    "bsN_tokens": tokens_bsN,
+                }
+            )
+            continue
+
+        # Check if tokens match first
+        if tokens_bs1 != tokens_bsN:
+            failed_prompts.append(
+                {
+                    "prompt_idx": i,
+                    "step": "sampling",
+                    "reason": "Different tokens sampled",
+                    "prompt_preview": prompts[i][:100],
+                    "bs1_tokens": tokens_bs1,
+                    "bsN_tokens": tokens_bsN,
+                    "bs1_all_logprobs": [
+                        logprobs_bs1[s].tolist() for s in range(len(logprobs_bs1))
+                    ],
+                    "bsN_all_logprobs": [
+                        logprobs_bsN[s].tolist() for s in range(len(logprobs_bsN))
+                    ],
+                }
+            )
+            continue
+
+        for t, (a, b) in enumerate(zip(logprobs_bs1, logprobs_bsN)):
+            if a.shape != b.shape:
+                failed_prompts.append(
+                    {
+                        "prompt_idx": i,
+                        "step": t,
+                        "reason": f"Shape mismatch: {a.shape} vs {b.shape}",
+                        "prompt_preview": prompts[i][:100],
+                        "bs1_tokens": tokens_bs1,
+                        "bsN_tokens": tokens_bsN,
+                    }
+                )
+                break
+
+            if not torch.equal(a, b):
+                max_diff = torch.abs(a - b).max().item()
+                # Print which token failed
+                print(f"\n[DIVERGENCE] Prompt {i}, Token {t}: max_diff={max_diff:.6e}")
+                bs1_tok = tokens_bs1[t] if t < len(tokens_bs1) else "N/A"
+                bsN_tok = tokens_bsN[t] if t < len(tokens_bsN) else "N/A"
+                print(f"  Token IDs: bs1={bs1_tok}, bsN={bsN_tok}")
+                print(f"  BS=1 logprob: {a.tolist()}")
+                print(f"  BS=N logprob: {b.tolist()}")
+                failed_prompts.append(
+                    {
+                        "prompt_idx": i,
+                        "step": t,
+                        "reason": f"Bitwise mismatch (max_diff={max_diff:.6e})",
+                        "prompt_preview": prompts[i][:100],
+                        "bs1_tokens": tokens_bs1,
+                        "bsN_tokens": tokens_bsN,
+                        "bs1_all_logprobs": [
+                            logprobs_bs1[s].tolist() for s in range(len(logprobs_bs1))
+                        ],
+                        "bsN_all_logprobs": [
+                            logprobs_bsN[s].tolist() for s in range(len(logprobs_bsN))
+                        ],
+                    }
+                )
+                break
+
+    # Print summary of all failures
+    if failed_prompts:
+        print(f"\n{'=' * 80}")
+        fail_msg = (
+            f"BATCH INVARIANCE FAILURES: {len(failed_prompts)}/"
+            f"{len(prompts)} prompts failed"
+        )
+        print(fail_msg)
+        print(f"{'=' * 80}")
+        for fail in failed_prompts:
+            print(f"\nPrompt {fail['prompt_idx']} (step {fail['step']}):")
+            print(f"  Reason: {fail['reason']}")
+            print(f"  Preview: {fail['prompt_preview']}...")
+
+            # Always show the tokens
+            if "bs1_tokens" in fail:
+                print(f"  BS=1 tokens: {fail['bs1_tokens']}")
+            if "bsN_tokens" in fail:
+                print(f"  BS=N tokens: {fail['bsN_tokens']}")
+
+            if "bs1_all_logprobs" in fail:
+                print(f"  BS=1 logprobs for all {len(fail['bs1_all_logprobs'])} steps:")
+                for step_idx, logprobs in enumerate(fail["bs1_all_logprobs"]):
+                    print(f"    Step {step_idx}: {logprobs}")
+                print(f"  BS=N logprobs for all {len(fail['bsN_all_logprobs'])} steps:")
+                for step_idx, logprobs in enumerate(fail["bsN_all_logprobs"]):
+                    print(f"    Step {step_idx}: {logprobs}")
+        print(f"{'=' * 80}\n")
+
+        # Fail the test with summary
+        msg = (
+            f"Batch invariance violated in {len(failed_prompts)}/"
+            f"{len(prompts)} prompts. See output above for details."
+        )
+        pytest.fail(msg)
+
+
+@pytest.mark.skipif(
+    not current_platform.has_device_capability(90),
+    reason="Batch invariance tests only supported on Hopper (SM90)",
+)
+def test_simple_generation():
+    """
+    Simple test that runs the model with a basic prompt and prints the output.
+    Useful for quick smoke testing and debugging.
+    """
+    model = os.getenv("VLLM_TEST_MODEL", "Qwen/Qwen3-1.7B")
+
+    llm = LLM(
+        model=model,
+        max_num_seqs=1,
+        tensor_parallel_size=int(os.getenv("VLLM_TP_SIZE", "1")),
+        enforce_eager=True,
+        gpu_memory_utilization=0.9,
+        max_model_len=2048,
+        dtype="bfloat16",
+        enable_prefix_caching=False,
+    )
+
+    prompt = "the capital of france is"
+    sampling_params = SamplingParams(
+        temperature=0.0,
+        max_tokens=20,
+    )
+
+    print(f"\n{'=' * 80}")
+    print("Running simple generation test")
+    print(f"Prompt: '{prompt}'")
+    print(f"{'=' * 80}\n")
+
+    try:
+        outputs = llm.generate([prompt], sampling_params)
+
+        assert len(outputs) == 1
+        output_text = outputs[0].outputs[0].text
+
+        print(f"Output: '{output_text}'")
+        print(f"\n{'=' * 80}")
+        print(f"Full completion: '{prompt}{output_text}'")
+        print(f"{'=' * 80}\n")
+
+    finally:
+        with contextlib.suppress(Exception):
+            llm.shutdown()
+
+
+@pytest.mark.skipif(
+    not current_platform.has_device_capability(90),
+    reason="Batch invariance tests only supported on Hopper (SM90)",
+)
+@pytest.mark.skipif(
+    not torch.cuda.is_available(),
+    reason="Requires CUDA to match production inference path.",
+)
+@pytest.mark.parametrize("backend", ["FLASH_ATTN", "FLASHINFER"])
+@pytest.mark.forked
+def test_logprobs_WITHOUT_batch_invariance_should_FAIL(backend):
+    """
+    This test is the inverse of test_logprobs_bitwise_batch_invariance_bs1_vs_bsN.
+    It DISABLES batch invariance mode and expects to see non-deterministic behavior
+    between BS=1 and BS=N runs. This demonstrates that batch invariance is actually
+    doing something useful.
+
+    The test will PASS if we detect differences (proving batch invariance matters).
+    The test will FAIL if everything matches (suggesting batch invariance isn't needed).
+    """
+    backend = os.getenv("VLLM_ATTENTION_BACKEND", backend)
+    os.environ["VLLM_ATTENTION_BACKEND"] = backend
+
+    # CRITICAL: Disable batch invariance for this test
+    old_value = os.environ.get("VLLM_KERNEL_OVERRIDE_BATCH_INVARIANT")
+    os.environ["VLLM_KERNEL_OVERRIDE_BATCH_INVARIANT"] = "0"
+
+    try:
+        seed = int(os.getenv("VLLM_TEST_SEED", "12345"))
+        random.seed(seed)
+        model_name = os.getenv("VLLM_TEST_MODEL", "Qwen/Qwen3-1.7B")
+        tp_size = int(os.getenv("VLLM_TEST_TP_SIZE", "1"))
+
+        print(f"\n{'=' * 80}")
+        print("BATCH INVARIANCE DISABLED: Expecting non-deterministic behavior")
+        print(f"{'=' * 80}\n")
+
+        llm = LLM(
+            model=model_name,
+            tensor_parallel_size=tp_size,
+            enable_prefix_caching=False,
+            max_num_seqs=32,
+            max_model_len=8192,
+            dtype="bfloat16",
+        )
+
+        # Use more realistic prompts for better token generation
+        prompts = [_random_prompt(10, 50) for i in range(32)]
+
+        sp = SamplingParams(
+            temperature=0.6,
+            top_p=1.0,
+            max_tokens=8,
+            seed=1234,
+            logprobs=5,
+        )
+
+        # BS=1: run prompts individually and collect logprobs per step.
+        print("\n" + "=" * 80)
+        print("STARTING BS=1 RUNS (each prompt individually)")
+        print("=" * 80 + "\n")
+
+        bs1_logprobs_per_prompt = []
+        bs1_tokens_per_prompt = []
+        for idx, p in enumerate(prompts):
+            print(
+                f"\n[BS=1] Running prompt {idx}/{len(prompts)} - Preview: {p[:80]}..."
+            )
+            outs = llm.generate([p], sp, use_tqdm=False)
+            assert len(outs) == 1
+            step_logprobs, token_ids = _extract_step_logprobs(outs[0])
+            if step_logprobs is None:
+                pytest.skip(
+                    "Logits are not available on RequestOutput; "
+                    "enable logprobs return to run this test."
+                )
+            bs1_logprobs_per_prompt.append(step_logprobs)
+            bs1_tokens_per_prompt.append(token_ids)
+            print(f"[BS=1] Prompt {idx} generated tokens: {token_ids}")
+
+        # BS=N: run prompts in a batch and collect logprobs per step for each prompt.
+        print("\n" + "=" * 80)
+        print(f"STARTING BS={len(prompts)} RUN (all prompts batched)")
+        print("=" * 80 + "\n")
+
+        outs_batched = llm.generate(prompts, sp, use_tqdm=False)
+        assert len(outs_batched) == len(prompts)
+        bsN_logprobs_per_prompt = []
+        bsN_tokens_per_prompt = []
+
+        print(f"\n[BS={len(prompts)}] Processing batched outputs...")
+        for idx, o in enumerate(outs_batched):
+            tokens = o.outputs[0].token_ids if o.outputs else "N/A"
+            print(f"[BS={len(prompts)}] Prompt {idx} generated tokens: {tokens}")
+            step_logprobs, token_ids = _extract_step_logprobs(o)
+            if step_logprobs is None:
+                pytest.skip(
+                    "Logits are not available on RequestOutput; "
+                    "enable logprobs return to run this test."
+                )
+            bsN_logprobs_per_prompt.append(step_logprobs)
+            bsN_tokens_per_prompt.append(token_ids)
+
+        # Compare step-by-step logprobs for each prompt between BS=1 and BS=N runs.
+        differences_found = []
+        for i, (logprobs_bs1, logprobs_bsN, tokens_bs1, tokens_bsN) in enumerate(
+            zip(
+                bs1_logprobs_per_prompt,
+                bsN_logprobs_per_prompt,
+                bs1_tokens_per_prompt,
+                bsN_tokens_per_prompt,
+            )
+        ):
+            if len(logprobs_bs1) != len(logprobs_bsN):
+                reason = (
+                    f"Different number of steps: {len(logprobs_bs1)} (BS=1) "
+                    f"vs {len(logprobs_bsN)} (BS=N)"
+                )
+                differences_found.append(
+                    {
+                        "prompt_idx": i,
+                        "step": "all",
+                        "reason": reason,
+                        "prompt_preview": prompts[i][:100],
+                        "bs1_tokens": tokens_bs1,
+                        "bsN_tokens": tokens_bsN,
+                    }
+                )
+                continue
+
+            # Check if tokens match first
+            if tokens_bs1 != tokens_bsN:
+                differences_found.append(
+                    {
+                        "prompt_idx": i,
+                        "step": "sampling",
+                        "reason": "Different tokens sampled",
+                        "prompt_preview": prompts[i][:100],
+                        "bs1_tokens": tokens_bs1,
+                        "bsN_tokens": tokens_bsN,
+                    }
+                )
+                continue
+
+            for t, (a, b) in enumerate(zip(logprobs_bs1, logprobs_bsN)):
+                if a.shape != b.shape:
+                    differences_found.append(
+                        {
+                            "prompt_idx": i,
+                            "step": t,
+                            "reason": f"Shape mismatch: {a.shape} vs {b.shape}",
+                            "prompt_preview": prompts[i][:100],
+                            "bs1_tokens": tokens_bs1,
+                            "bsN_tokens": tokens_bsN,
+                        }
+                    )
+                    break
+
+                if not torch.equal(a, b):
+                    max_diff = torch.abs(a - b).max().item()
+                    print(
+                        f"\n[EXPECTED DIVERGENCE FOUND] Prompt {i}, "
+                        f"Token {t}: max_diff={max_diff:.6e}"
+                    )
+                    bs1_tok = tokens_bs1[t] if t < len(tokens_bs1) else "N/A"
+                    bsN_tok = tokens_bsN[t] if t < len(tokens_bsN) else "N/A"
+                    print(f"  Token IDs: bs1={bs1_tok}, bsN={bsN_tok}")
+                    print(f"  BS=1 logprob: {a.tolist()}")
+                    print(f"  BS=N logprob: {b.tolist()}")
+                    differences_found.append(
+                        {
+                            "prompt_idx": i,
+                            "step": t,
+                            "reason": f"Bitwise mismatch (max_diff={max_diff:.6e})",
+                            "prompt_preview": prompts[i][:100],
+                            "bs1_tokens": tokens_bs1,
+                            "bsN_tokens": tokens_bsN,
+                        }
+                    )
+                    break
+
+        # Print summary
+        print(f"\n{'=' * 80}")
+        if differences_found:
+            success_msg = (
+                f"✓ SUCCESS: Batch invariance is doing something! "
+                f"Found {len(differences_found)}/{len(prompts)} prompts "
+                f"with differences when batch invariance was DISABLED."
+            )
+            print(success_msg)
+            print(f"{'=' * 80}")
+            for diff in differences_found:
+                print(f"\nPrompt {diff['prompt_idx']} (step {diff['step']}):")
+                print(f"  Reason: {diff['reason']}")
+                print(f"  Preview: {diff['prompt_preview']}...")
+                if "bs1_tokens" in diff:
+                    print(f"  BS=1 tokens: {diff['bs1_tokens']}")
+                if "bsN_tokens" in diff:
+                    print(f"  BS=N tokens: {diff['bsN_tokens']}")
+            print(f"{'=' * 80}\n")
+            # Test PASSES because we found differences (batch invariance matters!)
+            return
+        else:
+            # Test FAILS because everything matched even without batch invariance
+            fail_msg = (
+                f"✗ UNEXPECTED: All {len(prompts)} prompts matched "
+                f"between BS=1 and BS=N even with batch invariance DISABLED. "
+                f"This suggests batch invariance might not be necessary, "
+                f"or the test needs more sensitive prompts."
+            )
+            print(fail_msg)
+            print(f"{'=' * 80}\n")
+            pytest.fail(fail_msg)
+
+    finally:
+        # Restore original value
+        if old_value is None:
+            os.environ.pop("VLLM_KERNEL_OVERRIDE_BATCH_INVARIANT", None)
+        else:
+            os.environ["VLLM_KERNEL_OVERRIDE_BATCH_INVARIANT"] = old_value
+
+
+@pytest.mark.skipif(
+    not current_platform.has_device_capability(90),
+    reason="Batch invariance tests only supported on Hopper (SM90)",
+)
+@pytest.mark.skipif(
+    not torch.cuda.is_available(),
+    reason="Requires CUDA to match production inference path.",
+)
+@pytest.mark.parametrize("backend", ["FLASH_ATTN"])
+@pytest.mark.forked
+def test_decode_logprobs_match_prefill_logprobs(backend):
+    """
+    Test that verifies decode logprobs match prefill logprobs.
+
+    For each decoded token at position i:
+    1. Run decode to generate N tokens and collect their logprobs
+    2. For each position i in [0, N):
+       - Take prefix = prompt + tokens[0:i]
+       - Run prefill(prefix + tokens[i]) to get logprob of tokens[i]
+       - Verify prefill logprob matches decode logprob bitwise
+
+    This ensures that the logprobs from decode are consistent with what
+    we would get if we ran prefill on each prefix.
+    """
+    backend = os.getenv("VLLM_ATTENTION_BACKEND", backend)
+    os.environ["VLLM_ATTENTION_BACKEND"] = backend
+
+    seed = int(os.getenv("VLLM_TEST_SEED", "12345"))
+    random.seed(seed)
+    model_name = os.getenv("VLLM_TEST_MODEL", "Qwen/Qwen3-1.7B")
+    tp_size = int(os.getenv("VLLM_TEST_TP_SIZE", "1"))
+
+    from vllm.model_executor.layers.batch_invariant import (
+        vllm_kernel_override_batch_invariant,
+    )
+
+    disable_custom_ar = vllm_kernel_override_batch_invariant()
+
+    if disable_custom_ar:
+        print(f"\n{'=' * 80}")
+        print(f"BATCH INVARIANCE MODE: Disabling custom all-reduce (TP={tp_size})")
+        print(f"{'=' * 80}\n")
+
+    llm = LLM(
+        model=model_name,
+        tensor_parallel_size=tp_size,
+        enable_prefix_caching=False,
+        max_num_seqs=32,
+        max_model_len=8192,
+        dtype="bfloat16",
+    )
+
+    # Use a few test prompts
+    num_test_prompts = int(os.getenv("VLLM_DECODE_PREFILL_NUM_PROMPTS", "4"))
+    prompts = [_random_prompt(10, 50) for _ in range(num_test_prompts)]
+
+    # Generate longer sequences to test multiple decode steps
+    max_tokens = int(os.getenv("VLLM_DECODE_PREFILL_MAX_TOKENS", "16"))
+
+    sp = SamplingParams(
+        temperature=0.0,  # Greedy for determinism
+        max_tokens=max_tokens,
+        logprobs=5,
+    )
+
+    print("\n" + "=" * 80)
+    print("STEP 1: Running decode to generate tokens and collect logprobs")
+    print("=" * 80 + "\n")
+
+    # Step 1: Run decode and collect logprobs
+    decode_outputs = llm.generate(prompts, sp, use_tqdm=False)
+
+    failed_comparisons = []
+
+    for prompt_idx, (prompt, decode_output) in enumerate(zip(prompts, decode_outputs)):
+        print(f"\n[Prompt {prompt_idx}] Testing: {prompt[:80]}...")
+
+        # Extract decode logprobs and tokens
+        decode_logprobs, token_ids = _extract_step_logprobs(decode_output)
+        if decode_logprobs is None:
+            pytest.skip(
+                "Logprobs are not available on RequestOutput; "
+                "enable logprobs return to run this test."
+            )
+
+        print(f"[Prompt {prompt_idx}] Generated {len(token_ids)} tokens: {token_ids}")
+        print(f"[Prompt {prompt_idx}] Decode logprobs: {decode_logprobs.tolist()}")
+
+        # Step 2: For each token position, run prefill and compare
+        print(f"\n[Prompt {prompt_idx}] Verifying each token via prefill...")
+
+        for token_idx in range(len(token_ids)):
+            # Construct the prefix up to (but not including) this token
+            current_token = token_ids[token_idx]
+
+            # We need to detokenize to get the text prefix
+            # For this, we'll use the tokenizer from the LLM
+            # However, the LLM API doesn't expose tokenizer easily, so we'll
+            # construct the prefix by decoding from the original prompt
+
+            # Get text up to this point by using the output text
+            # This is approximate but should work for verification
+            if token_idx == 0:
+                prefix_prompt = prompt
+            else:
+                # Use the partial output text up to this token
+                # We'll need to construct this from the full output
+                prefix_output = decode_output.outputs[0]
+                # Get the text for tokens 0 to token_idx-1
+                # Unfortunately, we don't have per-token text, so we'll use
+                # a different approach: run prefill with prompt + tokens[0:token_idx]
+
+                # Actually, we need to get the actual text. Let's use a workaround:
+                # Run a generation with max_tokens = token_idx to get that prefix
+                prefix_sp = SamplingParams(
+                    temperature=0.0,
+                    max_tokens=token_idx,
+                    logprobs=1,
+                )
+                prefix_output = llm.generate([prompt], prefix_sp, use_tqdm=False)[0]
+                prefix_prompt = prompt + prefix_output.outputs[0].text
+
+            # Now run prefill with max_tokens=1 to get the logprob of the next token
+            prefill_sp = SamplingParams(
+                temperature=0.0,
+                max_tokens=1,
+                logprobs=5,
+            )
+
+            print(
+                f"  [Token {token_idx}] Running prefill for prefix "
+                f"(len={len(prefix_prompt)})..."
+            )
+            prefill_output = llm.generate([prefix_prompt], prefill_sp, use_tqdm=False)[
+                0
+            ]
+            prefill_logprobs, prefill_token_ids = _extract_step_logprobs(prefill_output)
+
+            if prefill_logprobs is None:
+                print(f"  [Token {token_idx}] Warning: No prefill logprobs available")
+                continue
+
+            # The first token from prefill should match the current token
+            prefill_token = prefill_token_ids[0]
+            prefill_logprob = prefill_logprobs[0].item()
+            decode_logprob = decode_logprobs[token_idx].item()
+
+            print(
+                f"  [Token {token_idx}] Decode token: {current_token}, "
+                f"logprob: {decode_logprob:.8f}"
+            )
+            print(
+                f"  [Token {token_idx}] Prefill token: {prefill_token}, "
+                f"logprob: {prefill_logprob:.8f}"
+            )
+
+            # Check if tokens match
+            if current_token != prefill_token:
+                failed_comparisons.append(
+                    {
+                        "prompt_idx": prompt_idx,
+                        "token_idx": token_idx,
+                        "reason": "Token mismatch",
+                        "decode_token": current_token,
+                        "prefill_token": prefill_token,
+                        "decode_logprob": decode_logprob,
+                        "prefill_logprob": prefill_logprob,
+                        "prompt_text": prompt[:100],
+                        "prefix_text": prefix_prompt[:100],
+                    }
+                )
+                print(f"  [Token {token_idx}] ✗ TOKEN MISMATCH!")
+                continue
+
+            # Check if logprobs match bitwise
+            if decode_logprob != prefill_logprob:
+                diff = abs(decode_logprob - prefill_logprob)
+                failed_comparisons.append(
+                    {
+                        "prompt_idx": prompt_idx,
+                        "token_idx": token_idx,
+                        "reason": "Logprob mismatch",
+                        "decode_token": current_token,
+                        "prefill_token": prefill_token,
+                        "decode_logprob": decode_logprob,
+                        "prefill_logprob": prefill_logprob,
+                        "diff": diff,
+                        "prompt_text": prompt[:100],
+                        "prefix_text": prefix_prompt[:100],
+                        "decode_all_tokens": token_ids,
+                        "decode_all_logprobs": decode_logprobs.tolist(),
+                    }
+                )
+                print(f"  [Token {token_idx}] ✗ LOGPROB MISMATCH! diff={diff:.8e}")
+            else:
+                print(f"  [Token {token_idx}] ✓ Match (bitwise equal)")
+
+    # Print summary
+    print(f"\n{'=' * 80}")
+    if failed_comparisons:
+        print(f"DECODE-PREFILL MISMATCH: {len(failed_comparisons)} failures detected")
+        print(f"{'=' * 80}")
+
+        # Group failures by prompt for better readability
+        failures_by_prompt: dict[int, list[dict]] = {}
+        for fail in failed_comparisons:
+            pid = fail["prompt_idx"]
+            if pid not in failures_by_prompt:
+                failures_by_prompt[pid] = []
+            failures_by_prompt[pid].append(fail)
+
+        for prompt_idx, failures in failures_by_prompt.items():
+            print(f"\n{'=' * 80}")
+            print(f"PROMPT {prompt_idx}: {failures[0]['prompt_text']}...")
+            print(f"{'=' * 80}")
+            print(f"Total failures for this prompt: {len(failures)}")
+
+            # Show where mismatches occur (which token positions)
+            mismatch_positions = [f["token_idx"] for f in failures]
+            print(f"Mismatch at token positions: {mismatch_positions}")
+
+            # Show first few failures in detail
+            for i, fail in enumerate(failures[:5]):  # Show first 5 failures per prompt
+                print(f"\n  [Failure {i + 1}] Token position {fail['token_idx']}:")
+                print(f"    Reason: {fail['reason']}")
+                print(f"    Prefix text: '{fail['prefix_text']}...'")
+                print(
+                    f"    Decode:  token={fail['decode_token']}, "
+                    f"logprob={fail['decode_logprob']:.10f}"
+                )
+                print(
+                    f"    Prefill: token={fail['prefill_token']}, "
+                    f"logprob={fail['prefill_logprob']:.10f}"
+                )
+                if "diff" in fail:
+                    print(f"    Difference: {fail['diff']:.10e}")
+                    # Show in hex to see bitwise difference
+                    import struct
+
+                    decode_hex = struct.pack("f", fail["decode_logprob"]).hex()
+                    prefill_hex = struct.pack("f", fail["prefill_logprob"]).hex()
+                    print(f"    Decode logprob (hex):  0x{decode_hex}")
+                    print(f"    Prefill logprob (hex): 0x{prefill_hex}")
+
+                # If we have all tokens/logprobs, show the context
+                if "decode_all_tokens" in fail and "decode_all_logprobs" in fail:
+                    token_idx = fail["token_idx"]
+                    all_tokens = fail["decode_all_tokens"]
+                    all_logprobs = fail["decode_all_logprobs"]
+
+                    # Show context: 2 tokens before and after
+                    start = max(0, token_idx - 2)
+                    end = min(len(all_tokens), token_idx + 3)
+
+                    print(f"    Context (tokens {start} to {end - 1}):")
+                    for j in range(start, end):
+                        marker = " <-- MISMATCH" if j == token_idx else ""
+                        print(
+                            f"      [{j}] token={all_tokens[j]}, "
+                            f"logprob={all_logprobs[j]:.8f}{marker}"
+                        )
+
+            if len(failures) > 5:
+                print(f"\n  ... and {len(failures) - 5} more failures for this prompt")
+
+        print(f"\n{'=' * 80}\n")
+
+        pytest.fail(
+            f"Decode logprobs do not match prefill logprobs: "
+            f"{len(failed_comparisons)} mismatches found."
+        )
+    else:
+        print("✓ SUCCESS: All decode logprobs match prefill logprobs bitwise!")
+        print(f"{'=' * 80}\n")
+
+
 def LLM_with_max_seqs(
@@ -284,7 +988,6 @@ def LLM_with_max_seqs(
     max_num_seqs: int,
    gpu_memory_utilization: float,
    max_model_len: int,
-    swap_space: int,
 ) -> LLM:
     """
     Helper to construct an LLM with a specific max_num_seqs (batch-size limit)
@@ -293,17 +996,10 @@ def LLM_with_max_seqs(
     return LLM(
         model=model,
         max_num_seqs=max_num_seqs,
-        # Constrain GPU memory pool so test can run even on busy GPUs.
         gpu_memory_utilization=gpu_memory_utilization,
-        # Keep KV cache footprint small while allowing longer outputs.
         max_model_len=max_model_len,
-        # Allow some CPU offload if needed.
-        swap_space=swap_space,
-        # Keep things lean and CI-friendly.
-        dtype="auto",
-        # Single-GPU by default; override externally if desired.
+        dtype="bfloat16",
         tensor_parallel_size=int(os.getenv("VLLM_TP_SIZE", "1")),
-        trust_remote_code=os.getenv("VLLM_TRUST_REMOTE_CODE", "0") == "1",
         enable_prefix_caching=False,
         enforce_eager=True,
         # Enable for MOE models

tests/v1/generation/test_rms_norm_batch_invariant.py (new file, 346 lines)
@@ -0,0 +1,346 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+"""
+Test batch-invariant RMS normalization against standard implementations.
+
+This test compares the Triton-based batch-invariant RMS norm implementation
+with the standard CUDA-based implementation to ensure numerical accuracy.
+"""
+
+import pytest
+import torch
+
+from vllm.model_executor.layers.batch_invariant import rms_norm as triton_rms_norm
+from vllm.model_executor.layers.layernorm import RMSNorm
+from vllm.platforms import current_platform
+
+
+@pytest.mark.skipif(
+    not current_platform.has_device_capability(90),
+    reason="Batch invariance tests only supported on Hopper (SM90)",
+)
+@pytest.mark.skipif(
+    not torch.cuda.is_available(), reason="Requires CUDA for RMS norm kernels"
+)
+@pytest.mark.parametrize("batch_size", [1, 4, 16, 64])
+@pytest.mark.parametrize("hidden_size", [512, 2048, 4096, 8192])
+@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
+@pytest.mark.parametrize("eps", [1e-6, 1e-5])
+def test_rms_norm_batch_invariant_vs_standard(
+    batch_size: int, hidden_size: int, dtype: torch.dtype, eps: float
+):
+    """
+    Compare batch-invariant Triton RMS norm against standard CUDA implementation.
+
+    Tests that the Triton-based batch-invariant RMS norm produces numerically
+    equivalent results to the standard CUDA implementation across various
+    configurations.
+    """
+    device = torch.device("cuda")
+
+    # Create test input and weight
+    torch.manual_seed(42)
+    input_tensor = torch.randn(batch_size, hidden_size, dtype=dtype, device=device)
+    weight = torch.randn(hidden_size, dtype=dtype, device=device)
+
+    # Standard implementation (CUDA ops)
+    rms_norm_layer = RMSNorm(hidden_size, eps=eps, dtype=dtype).to(device)
+    rms_norm_layer.weight.data = weight.clone()
+
+    standard_output = rms_norm_layer.forward_cuda(input_tensor)
+
+    # Batch-invariant implementation (Triton)
+    triton_output = triton_rms_norm(input_tensor, weight, eps=eps)
+
+    # Compare outputs
+    # Use looser tolerance for bfloat16 due to its lower precision
+    if dtype == torch.bfloat16:
+        rtol, atol = 1e-1, 1e-1  # 10% relative tolerance for bfloat16
+    else:
+        rtol, atol = 1e-2, 1e-2  # 1% for float16/float32
+
+    torch.testing.assert_close(
+        triton_output,
+        standard_output,
+        rtol=rtol,
+        atol=atol,
+        msg=f"RMS norm mismatch for batch_size={batch_size}, "
+        f"hidden_size={hidden_size}, "
+        f"dtype={dtype}, eps={eps}",
+    )
+
+
+@pytest.mark.skipif(
+    not current_platform.has_device_capability(90),
+    reason="Batch invariance tests only supported on Hopper (SM90)",
+)
+@pytest.mark.skipif(
+    not torch.cuda.is_available(), reason="Requires CUDA for RMS norm kernels"
+)
+@pytest.mark.parametrize("batch_size", [1, 16, 128])
+@pytest.mark.parametrize("seq_len", [1, 32, 512])
+@pytest.mark.parametrize("hidden_size", [2048, 4096])
+def test_rms_norm_3d_input(batch_size: int, seq_len: int, hidden_size: int):
+    """
+    Test RMS norm with 3D input tensors (batch, seq_len, hidden_size).
+
+    Ensures that the batch-invariant RMS norm correctly handles multi-dimensional
+    inputs that are common in transformer models.
+    """
+    device = torch.device("cuda")
+    dtype = torch.bfloat16
+    eps = 1e-6
+
+    torch.manual_seed(42)
+    input_tensor = torch.randn(
+        batch_size, seq_len, hidden_size, dtype=dtype, device=device
+    )
+    weight = torch.randn(hidden_size, dtype=dtype, device=device)
+
+    # Standard implementation
+    rms_norm_layer = RMSNorm(hidden_size, eps=eps, dtype=dtype).to(device)
+    rms_norm_layer.weight.data = weight.clone()
+    standard_output = rms_norm_layer.forward_cuda(input_tensor)
+
+    # Batch-invariant implementation
+    triton_output = triton_rms_norm(input_tensor, weight, eps=eps)
+
+    # Use looser tolerance for bfloat16
+    rtol, atol = 1e-1, 1e-1  # 10% tolerance for bfloat16
+
+    torch.testing.assert_close(
+        triton_output,
+        standard_output,
+        rtol=rtol,
+        atol=atol,
+        msg=f"RMS norm mismatch for 3D input with batch_size={batch_size}, "
+        f"seq_len={seq_len}, hidden_size={hidden_size}",
+    )
+
+
+@pytest.mark.skipif(
+    not current_platform.has_device_capability(90),
+    reason="Batch invariance tests only supported on Hopper (SM90)",
+)
+@pytest.mark.skipif(
+    not torch.cuda.is_available(), reason="Requires CUDA for RMS norm kernels"
+)
+def test_rms_norm_numerical_stability():
+    """
+    Test RMS norm numerical stability with extreme values.
+
+    Ensures that both implementations handle edge cases like very small or large
+    values without producing NaN or Inf.
+    """
+    device = torch.device("cuda")
+    dtype = torch.float16
+    eps = 1e-6
+    hidden_size = 2048
+
+    # Test cases with extreme values
+    test_cases = [
+        # Very small values
+        torch.ones(4, hidden_size, dtype=dtype, device=device) * 1e-5,
+        # Very large values
+        torch.ones(4, hidden_size, dtype=dtype, device=device) * 1e4,
+        # Mixed small and large
+        torch.randn(4, hidden_size, dtype=dtype, device=device) * 100,
+        # Values near zero
+        torch.randn(4, hidden_size, dtype=dtype, device=device) * 1e-6,
+    ]
+
+    weight = torch.ones(hidden_size, dtype=dtype, device=device)
+
+    for idx, input_tensor in enumerate(test_cases):
+        # Standard implementation
+        rms_norm_layer = RMSNorm(hidden_size, eps=eps, dtype=dtype).to(device)
+        rms_norm_layer.weight.data = weight.clone()
+        standard_output = rms_norm_layer.forward_cuda(input_tensor)
+
+        # Batch-invariant implementation
+        triton_output = triton_rms_norm(input_tensor, weight, eps=eps)
+
+        # Check for NaN or Inf
+        assert not torch.isnan(standard_output).any(), (
+            f"Standard RMS norm produced NaN for test case {idx}"
+        )
+        assert not torch.isinf(standard_output).any(), (
+            f"Standard RMS norm produced Inf for test case {idx}"
+        )
+        assert not torch.isnan(triton_output).any(), (
+            f"Triton RMS norm produced NaN for test case {idx}"
+        )
+        assert not torch.isinf(triton_output).any(), (
+            f"Triton RMS norm produced Inf for test case {idx}"
+        )
+
+        # Compare outputs - very lenient for extreme values with float16
+        torch.testing.assert_close(
+            triton_output,
+            standard_output,
+            rtol=2e-1,  # 20% tolerance for extreme values
+            atol=2e-1,
+            msg=f"RMS norm mismatch for extreme value test case {idx}",
+        )
+
+
+@pytest.mark.skipif(
+    not current_platform.has_device_capability(90),
+    reason="Batch invariance tests only supported on Hopper (SM90)",
+)
+@pytest.mark.skipif(
+    not torch.cuda.is_available(), reason="Requires CUDA for RMS norm kernels"
+)
+def test_rms_norm_formula():
+    """
+    Test that RMS norm follows the correct mathematical formula.
+
+    Verifies: output = input / sqrt(mean(input^2) + eps) * weight
+    """
+    device = torch.device("cuda")
+    dtype = torch.float32  # Use float32 for higher precision in formula check
+    eps = 1e-6
+    hidden_size = 1024
+
+    torch.manual_seed(42)
+    input_tensor = torch.randn(8, hidden_size, dtype=dtype, device=device)
+    weight = torch.randn(hidden_size, dtype=dtype, device=device)
+
+    # Compute expected output using the formula
+    variance = (input_tensor.pow(2).mean(dim=-1, keepdim=True)).to(dtype)
+    expected_output = input_tensor * torch.rsqrt(variance + eps) * weight
+
+    # Batch-invariant implementation
+    triton_output = triton_rms_norm(input_tensor, weight, eps=eps)
+
+    # Compare against formula
+    torch.testing.assert_close(
+        triton_output,
+        expected_output,
+        rtol=1e-4,
+        atol=1e-4,
+        msg="Triton RMS norm doesn't match expected formula",
+    )
+
+
+@pytest.mark.skipif(
+    not current_platform.has_device_capability(90),
+    reason="Batch invariance tests only supported on Hopper (SM90)",
+)
+@pytest.mark.skipif(
+    not torch.cuda.is_available(), reason="Requires CUDA for RMS norm kernels"
+)
+@pytest.mark.parametrize("hidden_size", [128, 1024, 4096, 16384])
+def test_rms_norm_different_hidden_sizes(hidden_size: int):
+    """
+    Test RMS norm with various hidden sizes to ensure block size handling.
+
+    The Triton kernel uses a fixed BLOCK_SIZE=1024, so this tests that it
+    correctly handles hidden sizes both smaller and larger than the block size.
+    """
+    device = torch.device("cuda")
+    dtype = torch.bfloat16
+    eps = 1e-6
+    batch_size = 16
+
+    torch.manual_seed(42)
+    input_tensor = torch.randn(batch_size, hidden_size, dtype=dtype, device=device)
+    weight = torch.randn(hidden_size, dtype=dtype, device=device)
+
+    # Standard implementation
+    rms_norm_layer = RMSNorm(hidden_size, eps=eps, dtype=dtype).to(device)
+    rms_norm_layer.weight.data = weight.clone()
+    standard_output = rms_norm_layer.forward_cuda(input_tensor)
+
+    # Batch-invariant implementation
+    triton_output = triton_rms_norm(input_tensor, weight, eps=eps)
+
+    # Use looser tolerance for bfloat16
+    rtol, atol = 1e-1, 1e-1  # 10% tolerance for bfloat16
+
+    torch.testing.assert_close(
+        triton_output,
+        standard_output,
+        rtol=rtol,
+        atol=atol,
+        msg=f"RMS norm mismatch for hidden_size={hidden_size}",
+    )
+
+
+@pytest.mark.skipif(
+    not current_platform.has_device_capability(90),
+    reason="Batch invariance tests only supported on Hopper (SM90)",
+)
+@pytest.mark.skipif(
+    not torch.cuda.is_available(), reason="Requires CUDA for RMS norm kernels"
+)
+def test_rms_norm_determinism():
+    """
+    Test that batch-invariant RMS norm produces deterministic results.
+
+    Runs the same input through the kernel multiple times and verifies
+    identical outputs.
+    """
+    device = torch.device("cuda")
+    dtype = torch.bfloat16
+    eps = 1e-6
+    hidden_size = 4096
+    batch_size = 32
+
+    torch.manual_seed(42)
+    input_tensor = torch.randn(batch_size, hidden_size, dtype=dtype, device=device)
+    weight = torch.randn(hidden_size, dtype=dtype, device=device)
+
+    # Run multiple times
+    outputs = []
+    for _ in range(5):
+        output = triton_rms_norm(input_tensor.clone(), weight, eps=eps)
+        outputs.append(output)
+
+    # All outputs should be identical
+    reference = outputs[0]
+    for idx, output in enumerate(outputs[1:], start=1):
+        torch.testing.assert_close(
+            output,
+            reference,
+            rtol=0.0,
+            atol=0.0,
+            msg=f"RMS norm not deterministic: run {idx} differs from reference",
+        )
+
+
+if __name__ == "__main__":
+    # Run a quick smoke test
+    print("Running quick smoke test of RMS norm implementations...")
+
+    device = torch.device("cuda")
+    batch_size = 8
+    hidden_size = 4096
+    dtype = torch.bfloat16
+    eps = 1e-6
+
+    torch.manual_seed(42)
+    input_tensor = torch.randn(batch_size, hidden_size, dtype=dtype, device=device)
+    weight = torch.randn(hidden_size, dtype=dtype, device=device)
+
+    # Standard implementation
+    rms_norm_layer = RMSNorm(hidden_size, eps=eps, dtype=dtype).to(device)
+    rms_norm_layer.weight.data = weight.clone()
+    standard_output = rms_norm_layer.forward_cuda(input_tensor)
+
+    # Batch-invariant implementation
+    triton_output = triton_rms_norm(input_tensor, weight, eps=eps)
+
+    # Compare
+    max_diff = (triton_output - standard_output).abs().max().item()
+    mean_diff = (triton_output - standard_output).abs().mean().item()
+
+    print(f"Max difference: {max_diff:.6e}")
+    print(f"Mean difference: {mean_diff:.6e}")
+    print(f"Standard output sample: {standard_output[0, :5].tolist()}")
+    print(f"Triton output sample: {triton_output[0, :5].tolist()}")
+
+    if max_diff < 1e-3:
+        print("✓ Smoke test passed!")
+    else:
+        print("✗ Smoke test failed - differences too large")
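
For reference, the formula checked by test_rms_norm_formula above can be written as follows (an editorial restatement of the docstring, not part of the diff), where d is the hidden size, w the learned weight, and eps the stabilizing constant:

\mathrm{RMSNorm}(x)_j = \frac{x_j \, w_j}{\sqrt{\tfrac{1}{d} \sum_{k=1}^{d} x_k^{2} + \varepsilon}}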
@@ -3,6 +3,7 @@
 
 import hashlib
 import inspect
+import os
 import pickle
 from unittest.mock import patch
 
@@ -168,7 +169,8 @@ def _compute_code_hash(files: set[str]) -> str:
     )
     file_contents = {}
     for filepath in files:
-        if filepath == "<string>":
+        # Skip files that don't exist (e.g., <string>, <frozen modules>, etc.)
+        if not os.path.isfile(filepath):
             file_contents[filepath] = ""
         else:
             with open(filepath) as f:
@@ -20,6 +20,9 @@ from vllm.config.pooler import PoolerConfig
 from vllm.config.scheduler import RunnerType
 from vllm.config.utils import assert_hashable, config, getattr_iter
 from vllm.logger import init_logger
+from vllm.model_executor.layers.batch_invariant import (
+    vllm_kernel_override_batch_invariant,
+)
 from vllm.platforms import current_platform
 from vllm.transformers_utils.config import (
     ConfigFormat,
@ -419,6 +422,10 @@ class ModelConfig:
|
|||||||
skip_mm_profiling: bool | None,
|
skip_mm_profiling: bool | None,
|
||||||
video_pruning_rate: float | None,
|
video_pruning_rate: float | None,
|
||||||
) -> None:
|
) -> None:
|
||||||
|
# Enable batch invariance settings if requested
|
||||||
|
if vllm_kernel_override_batch_invariant():
|
||||||
|
self.enforce_eager = True
|
||||||
|
|
||||||
# Set the default seed to 0 in V1.
|
# Set the default seed to 0 in V1.
|
||||||
# NOTE(woosuk): In V0, we set the default seed to None because the
|
# NOTE(woosuk): In V0, we set the default seed to None because the
|
||||||
# driver worker shares the same process as the user process, and thus
|
# driver worker shares the same process as the user process, and thus
|
||||||
|
|||||||
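Note: the hunks that follow all use the same pattern: query vllm_kernel_override_batch_invariant() once and, when it returns True, steer the configuration toward the deterministic path (eager execution here, later disabling custom all-reduce, symmetric-memory all-reduce, Marlin, split-KV attention, and so on). A minimal sketch of that pattern, with a hypothetical attribute name standing in for the feature being switched off:

    # Sketch of the gating pattern used throughout this commit.
    # `some_fast_but_variant_path` is a hypothetical stand-in for the features
    # being disabled (custom all-reduce, marlin, split-KV attention, ...).
    from vllm.model_executor.layers.batch_invariant import (
        vllm_kernel_override_batch_invariant,
    )

    def configure(config) -> None:
        if vllm_kernel_override_batch_invariant():
            # Prefer the deterministic, batch-invariant implementation.
            config.some_fast_but_variant_path = False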
@@ -14,6 +14,9 @@ from typing_extensions import Self
 import vllm.envs as envs
 from vllm.config.utils import config
 from vllm.logger import init_logger
+from vllm.model_executor.layers.batch_invariant import (
+    vllm_kernel_override_batch_invariant,
+)
 from vllm.platforms import current_platform
 from vllm.utils import cuda_device_count_stateless, get_open_ports_list
 

@@ -560,7 +563,10 @@ class ParallelConfig:
     def _verify_args(self) -> Self:
         # Lazy import to avoid circular import
         from vllm.executor.executor_base import ExecutorBase
-        from vllm.platforms import current_platform
+
+        # Enable batch invariance settings if requested
+        if vllm_kernel_override_batch_invariant():
+            self.disable_custom_all_reduce = True
 
         if (
             self.distributed_executor_backend is not None

@@ -19,6 +19,9 @@ import torch.multiprocessing as mp
 import vllm.envs as envs
 from vllm.distributed.device_communicators.cuda_wrapper import CudaRTLibrary
 from vllm.logger import init_logger
+from vllm.model_executor.layers.batch_invariant import (
+    vllm_kernel_override_batch_invariant,
+)
 from vllm.utils import cuda_device_count_stateless, update_environment_variables
 
 logger = init_logger(__name__)

@@ -71,6 +74,9 @@ def should_nccl_symm_mem_allreduce(world_size: int, input_tensor: torch.Tensor)
         is_symmetric_memory_enabled,
     )
 
+    if vllm_kernel_override_batch_invariant():
+        return False
+
     if not is_symmetric_memory_enabled():
         return False
     if world_size < NCCL_SYMM_MEM_ALL_REDUCE_CONFIG["min_world_size"]:

@@ -9,6 +9,9 @@ from vllm.distributed.device_communicators.all_reduce_utils import (
     SYMM_MEM_ALL_REDUCE_MAX_SIZES,
 )
 from vllm.logger import init_logger
+from vllm.model_executor.layers.batch_invariant import (
+    vllm_kernel_override_batch_invariant,
+)
 from vllm.platforms import current_platform
 
 try:

@@ -100,6 +103,8 @@ class SymmMemCommunicator:
             return
         self.force_multimem = force_multimem
         self.disabled = False
+        if vllm_kernel_override_batch_invariant():
+            self.disabled = True
 
     def should_use_symm_mem(self, inp: torch.Tensor):
         if self.disabled:

@@ -1694,7 +1694,7 @@ class EngineArgs:
     ) -> None:
         """Set Default Arguments for V1 Engine."""
 
-        # V1 always uses chunked prefills and prefix caching
+        # V1 uses chunked prefills and prefix caching by default
         # for non-pooling tasks.
         # For pooling tasks the default is False
         if model_config.runner_type != "pooling":
@@ -395,7 +395,6 @@ def mean_dim(
         Tensor with mean values along specified dimension
     """
     # Validate inputs
-    assert input.is_cuda, "Input must be a CUDA tensor"
    assert -input.ndim <= dim < input.ndim, (
         f"Invalid dimension {dim} for tensor with {input.ndim} dimensions"
     )

@@ -470,6 +469,45 @@ def mm_batch_invariant(a, b):
     return matmul_persistent(a, b)
 
 
+def matmul_batch_invariant(a, b, *, out=None):
+    # torch.matmul can handle various dimensions
+    # For 2D x 2D, it's the same as mm
+    if a.ndim == 2 and b.ndim == 2:
+        result = matmul_persistent(a, b)
+        if out is not None:
+            out.copy_(result)
+            return out
+        return result
+    elif a.ndim == 3 and b.ndim == 3:
+        # Handle batched case like bmm
+        return bmm_batch_invariant(a, b, out=out)
+    else:
+        raise ValueError(
+            f"matmul_batch_invariant currently only supports 2D x 2D and 3D x 3D, "
+            f"got shapes {a.shape} and {b.shape}"
+        )
+
+
+def bmm_batch_invariant(a, b, *, out=None):
+    # Batched matrix multiply: (B, M, K) x (B, K, N) -> (B, M, N)
+    # Process each batch separately with our persistent kernel
+    if a.ndim == 3 and b.ndim == 3:
+        results = []
+        for i in range(a.shape[0]):
+            results.append(matmul_persistent(a[i], b[i]))
+        result = torch.stack(results, dim=0)
+
+        if out is not None:
+            out.copy_(result)
+            return out
+        return result
+    else:
+        raise ValueError(
+            f"bmm_batch_invariant expects 3D tensors, "
+            f"got shapes {a.shape} and {b.shape}"
+        )
+
+
 def addmm_batch_invariant(bias, a, b):
     return matmul_persistent(a, b, bias=bias)
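Note: bmm_batch_invariant keeps batched matmul deterministic by running the same persistent Triton kernel on each batch slice instead of letting a shape-dependent split-K algorithm be chosen. A quick way to exercise it (a sketch, assuming a CUDA device and the functions above):

    # Sketch: exercise the batch-invariant bmm and check run-to-run determinism.
    import torch
    from vllm.model_executor.layers.batch_invariant import bmm_batch_invariant

    a = torch.randn(4, 128, 256, device="cuda", dtype=torch.bfloat16)
    b = torch.randn(4, 256, 64, device="cuda", dtype=torch.bfloat16)

    out1 = bmm_batch_invariant(a, b)
    out2 = bmm_batch_invariant(a, b)
    ref = torch.bmm(a, b)

    print("max |out - torch.bmm|:", (out1 - ref).abs().max().item())  # small bf16 gap
    assert torch.equal(out1, out2)  # repeated calls should be bit-identical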
@@ -479,11 +517,24 @@ def _log_softmax_batch_invariant(input, dim, _half_to_float):
     return log_softmax(input, dim=dim)
 
 
+def softmax_batch_invariant(input, dim, dtype=None):
+    # Compute softmax in a deterministic way
+    # First subtract max for numerical stability (standard practice)
+    input_max = torch.amax(input, dim=dim, keepdim=True)
+    input = input - input_max
+    exp_x = torch.exp(input)
+    sum_exp_x = torch.sum(exp_x, dim=dim, keepdim=True)
+    return exp_x / sum_exp_x
+
+
 def mean_batch_invariant(input, dim, keepdim=False, dtype: torch.dtype | None = None):
     assert dtype is None or dtype == torch.float32, f"unsupported dtype: {dtype}"
 
     result = input.to(torch.float32)
 
+    if len(dim) == 0:
+        dim = [i for i in range(len(input.shape))]
+
     # Sort dimensions to reduce from largest to smallest to handle shifting dims
     # during iterative reduction.
     sorted_dims = sorted([d % input.ndim for d in dim], reverse=True)
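Note: softmax_batch_invariant re-expresses softmax with plain PyTorch reductions (amax, exp, sum), so the reduction order is fixed by those ops rather than by a fused kernel's tiling. It computes the usual numerically stabilized form softmax(x)_i = exp(x_i - max(x)) / sum_j exp(x_j - max(x)). A small check against torch.softmax (sketch only):

    # Sketch: the deterministic softmax should match torch.softmax numerically.
    import torch
    from vllm.model_executor.layers.batch_invariant import softmax_batch_invariant

    x = torch.randn(8, 1024, device="cuda", dtype=torch.float32)
    torch.testing.assert_close(
        softmax_batch_invariant(x, dim=-1),
        torch.softmax(x, dim=-1),
    )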
@@ -500,8 +551,134 @@ def mean_batch_invariant(input, dim, keepdim=False, dtype: torch.dtype | None =
     return result
 
 
+@triton.jit
+def _rms_norm_kernel(
+    input_ptr,
+    weight_ptr,
+    output_ptr,
+    input_row_stride,
+    output_row_stride,
+    n_cols,
+    eps,
+    BLOCK_SIZE: tl.constexpr,
+):
+    """
+    Compute RMS normalization along the last dimension of a 2D tensor.
+
+    RMS Norm: y = x / sqrt(mean(x^2) + eps) * weight
+
+    Each block handles one row of the input tensor.
+    """
+    row_idx = tl.program_id(0).to(tl.int64)
+    row_start_ptr = input_ptr + row_idx * input_row_stride
+    output_row_start_ptr = output_ptr + row_idx * output_row_stride
+
+    # Step 1: Compute sum of squares in float32 to avoid overflow
+    sum_sq = tl.zeros([1], dtype=tl.float32)
+    for col_offset in range(0, n_cols, BLOCK_SIZE):
+        col_idx = col_offset + tl.arange(0, BLOCK_SIZE)
+        mask = col_idx < n_cols
+
+        vals = tl.load(row_start_ptr + col_idx, mask=mask, other=0.0)
+        # Convert to float32 for accumulation to prevent overflow
+        vals_f32 = vals.to(tl.float32)
+        sq_vals = vals_f32 * vals_f32
+        sum_sq += tl.sum(tl.where(mask, sq_vals, 0.0))
+
+    # Step 2: Compute RMS (root mean square) in float32
+    mean_sq = sum_sq / n_cols
+    rms = tl.sqrt(mean_sq + eps)
+    inv_rms = 1.0 / rms
+
+    # Step 3: Normalize and apply weight
+    for col_offset in range(0, n_cols, BLOCK_SIZE):
+        col_idx = col_offset + tl.arange(0, BLOCK_SIZE)
+        mask = col_idx < n_cols
+        vals = tl.load(row_start_ptr + col_idx, mask=mask, other=0.0)
+        weight = tl.load(weight_ptr + col_idx, mask=mask, other=1.0)
+        # Compute in float32 then convert back to input dtype
+        vals_f32 = vals.to(tl.float32)
+        weight_f32 = weight.to(tl.float32)
+        output_f32 = vals_f32 * inv_rms * weight_f32
+        output = output_f32.to(vals.dtype)
+        tl.store(output_row_start_ptr + col_idx, output, mask=mask)
+
+
+def rms_norm(
+    input: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6
+) -> torch.Tensor:
+    """
+    Compute RMS normalization using Triton kernel.
+
+    RMS Norm normalizes the input by the root mean square and scales by weight:
+        output = input / sqrt(mean(input^2) + eps) * weight
+
+    Args:
+        input: Input tensor of shape (..., hidden_size)
+        weight: Weight tensor of shape (hidden_size,)
+        eps: Small constant for numerical stability
+
+    Returns:
+        Tensor with RMS normalization applied along the last dimension
+    """
+    assert weight.dim() == 1, "Weight must be 1-dimensional"
+    assert input.shape[-1] == weight.shape[0], (
+        f"Input last dimension ({input.shape[-1]}) must match "
+        f"weight dimension ({weight.shape[0]})"
+    )
+
+    # Flatten all dimensions except the last one
+    original_shape = input.shape
+    input_2d = input.reshape(-1, input.shape[-1])
+    input_2d = input_2d.contiguous()
+    weight = weight.contiguous()
+
+    n_rows, n_cols = input_2d.shape
+
+    output = torch.empty_like(input_2d)
+    BLOCK_SIZE = 1024
+    grid = (n_rows,)
+    _rms_norm_kernel[grid](
+        input_2d,
+        weight,
+        output,
+        input_2d.stride(0),
+        output.stride(0),
+        n_cols,
+        eps,
+        BLOCK_SIZE=BLOCK_SIZE,
+    )
+    return output.reshape(original_shape)
+
+
+def rms_norm_batch_invariant(
+    input: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6
+) -> torch.Tensor:
+    """
+    Batch-invariant wrapper for RMS normalization.
+
+    This function provides a deterministic, batch-invariant implementation
+    of RMS normalization for use with the batch_invariant mode.
+
+    Args:
+        input: Input tensor of shape (..., hidden_size)
+        weight: Weight tensor of shape (hidden_size,)
+        eps: Small constant for numerical stability
+
+    Returns:
+        RMS normalized tensor
+    """
+    return rms_norm(input, weight, eps=eps)
+
+
+def linear_batch_invariant(input, weight, bias=None):
+    output = mm_batch_invariant(input, weight.t())
+    if bias is not None:
+        output = output + bias
+    return output
+
+
 _batch_invariant_MODE = False
 _batch_invariant_LIB = None
+_original_torch_bmm = None
 
 
 def is_batch_invariant_mode_enabled():
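Note: the kernel accumulates the sum of squares in float32 and only casts back to the input dtype at the end, which makes a pure-PyTorch reference easy to state. A sketch of that reference (useful when eyeballing the kernel, not part of the diff):

    # Sketch: float32 reference for y = x / sqrt(mean(x^2) + eps) * weight.
    import torch

    def rms_norm_reference(x: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6):
        x_f32 = x.float()
        inv_rms = torch.rsqrt(x_f32.pow(2).mean(dim=-1, keepdim=True) + eps)
        return (x_f32 * inv_rms * weight.float()).to(x.dtype)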
@@ -509,7 +686,7 @@ def is_batch_invariant_mode_enabled():
 
 
 def enable_batch_invariant_mode():
-    global _batch_invariant_MODE, _batch_invariant_LIB
+    global _batch_invariant_MODE, _batch_invariant_LIB, _original_torch_bmm
     if _batch_invariant_MODE:
         return
 

@@ -517,16 +694,28 @@ def enable_batch_invariant_mode():
     _batch_invariant_LIB = torch.library.Library("aten", "IMPL")
     _batch_invariant_LIB.impl("aten::mm", mm_batch_invariant, "CUDA")
     _batch_invariant_LIB.impl("aten::addmm", addmm_batch_invariant, "CUDA")
+    _batch_invariant_LIB.impl("aten::matmul", matmul_batch_invariant, "CUDA")
+    _batch_invariant_LIB.impl("aten::bmm", bmm_batch_invariant, "CUDA")
+    _batch_invariant_LIB.impl("aten::linear", linear_batch_invariant, "CUDA")
     _batch_invariant_LIB.impl(
         "aten::_log_softmax", _log_softmax_batch_invariant, "CUDA"
     )
+    _batch_invariant_LIB.impl("aten::softmax", softmax_batch_invariant, "CUDA")
+    _batch_invariant_LIB.impl("aten::_softmax", softmax_batch_invariant, "CUDA")
     _batch_invariant_LIB.impl("aten::mean.dim", mean_batch_invariant, "CUDA")
 
+    # Also monkeypatch torch.bmm directly as a fallback
+    _original_torch_bmm = torch.bmm
+    torch.bmm = bmm_batch_invariant
+
 
 def disable_batch_invariant_mode():
-    global _batch_invariant_MODE, _batch_invariant_LIB
+    global _batch_invariant_MODE, _batch_invariant_LIB, _original_torch_bmm
     if _batch_invariant_LIB is not None:
         _batch_invariant_LIB._destroy()
+    if _original_torch_bmm is not None:
+        torch.bmm = _original_torch_bmm
+        _original_torch_bmm = None
     _batch_invariant_MODE = False
     _batch_invariant_LIB = None
 
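Note: enable_batch_invariant_mode() registers the deterministic implementations as CUDA overrides for the corresponding aten ops via torch.library, and additionally monkeypatches torch.bmm, saving the original so disable_batch_invariant_mode() can restore it. Usage is symmetric (a sketch using only the functions defined above):

    # Sketch: route mm/bmm/matmul/softmax/... through the deterministic kernels.
    import torch
    from vllm.model_executor.layers.batch_invariant import (
        disable_batch_invariant_mode,
        enable_batch_invariant_mode,
    )

    enable_batch_invariant_mode()
    try:
        a = torch.randn(32, 512, device="cuda", dtype=torch.bfloat16)
        b = torch.randn(512, 512, device="cuda", dtype=torch.bfloat16)
        out = a @ b  # with the overrides active, this hits the batch-invariant matmul
    finally:
        disable_batch_invariant_mode()  # restores the aten impls and torch.bmm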
@@ -563,17 +752,55 @@ def vllm_kernel_override_batch_invariant():
     return is_overridden
 
 
+def override_envs_for_invariance():
+    curr_attn_backend = envs.VLLM_ATTENTION_BACKEND
+    supported_backends = [
+        "FLASH_ATTN",  # best supported backend
+        "FLEX_ATTENTION",
+        "FLASHINFER",
+        "FLASH_ATTN_MLA",
+        "TRITON_MLA",
+        # Not yet supported MLA backends
+        # "FLASHMLA",
+        # "FLASHINFER_MLA",
+    ]
+    if curr_attn_backend not in supported_backends:
+        warning = (
+            "Forcibly updating attention backend to"
+            f" {supported_backends[0]} for batch_invariant. "
+            f" Supported backends: {supported_backends}."
+        )
+        logger.warning_once(warning)
+        os.environ["VLLM_ATTENTION_BACKEND"] = supported_backends[0]
+    if os.environ["VLLM_ATTENTION_BACKEND"] != supported_backends[0]:
+        warning = (
+            "You are using a decode-invariant form of batch invariance. "
+            "This will not be invariant between prefill and decode."
+        )
+        logger.warning_once(warning)
+    os.environ["VLLM_ALLREDUCE_USE_SYMM_MEM"] = "0"
+
+    os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"
+
+    # NCCL determinism settings
+    os.environ["NCCL_LAUNCH_MODE"] = "GROUP"
+    os.environ["NCCL_COLLNET_ENABLE"] = "0"
+    os.environ["NCCL_NVLS_ENABLE"] = "0"
+    os.environ["NCCL_P2P_NET_DISABLE"] = "1"
+    os.environ["NCCL_MIN_NCHANNELS"] = "1"
+    os.environ["NCCL_MAX_NCHANNELS"] = "1"
+    os.environ["NCCL_PROTO"] = "Simple"
+    os.environ["NCCL_ALGO"] = "allreduce:tree"
+    os.environ["NCCL_NTHREADS"] = "1"
+    os.environ["NCCL_SOCKET_NTHREADS"] = "1"
+
+
 def init_batch_invariance():
     # this will hit all the csrc overrides as well
     if vllm_kernel_override_batch_invariant():
-        curr_attn_backend = envs.VLLM_ATTENTION_BACKEND
-        supported_backends = ["FLEX_ATTENTION", "FLASHINFER"]
-        if curr_attn_backend not in supported_backends:
-            warning = (
-                "Forcibly updating attention backend to"
-                f" {supported_backends[0]} for batch_invariant. "
-                f" Supported backends: {supported_backends}."
-            )
-            logger.warning_once(warning)
-            os.environ["VLLM_ATTENTION_BACKEND"] = supported_backends[0]
+        override_envs_for_invariance()
         enable_batch_invariant_mode()
+
+        # Disable TF32 for batch invariance - it causes non-deterministic rounding
+        torch.backends.cuda.matmul.allow_tf32 = False
+        torch.backends.cudnn.allow_tf32 = False
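Note: the environment overrides only have an effect if they are exported before the attention backend, cuBLAS handles, and NCCL communicators are created, which is why init_batch_invariance() is invoked early during worker startup (see the gpu_worker hunk near the end of this diff). For a standalone experiment, the same idea can be approximated by pinning the backend up front; VLLM_ATTENTION_BACKEND is an existing vLLM knob, and the backend list is the one in the code above (sketch only):

    # Sketch: pin the attention backend before vLLM selects one.
    import os

    os.environ.setdefault("VLLM_ATTENTION_BACKEND", "FLASH_ATTN")

    from vllm.model_executor.layers.batch_invariant import init_batch_invariance

    init_batch_invariance()  # no-op unless the batch-invariant override is requested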
@@ -15,6 +15,9 @@ import vllm.envs as envs
 import vllm.model_executor.layers.fused_moe.modular_kernel as mk
 from vllm import _custom_ops as ops
 from vllm.logger import init_logger
+from vllm.model_executor.layers.batch_invariant import (
+    vllm_kernel_override_batch_invariant,
+)
 from vllm.model_executor.layers.fused_moe.config import (
     FUSED_MOE_UNQUANTIZED_CONFIG,
     FusedMoEQuantConfig,

@@ -837,6 +840,10 @@ def get_moe_configs(
     be picked and the associated configuration chosen to invoke the kernel.
     """
 
+    # Avoid optimizing for the batch invariant case. Use default config
+    if vllm_kernel_override_batch_invariant():
+        return None
+
     # First look up if an optimized configuration is available in the configs
     # directory
     block_shape = [block_n, block_k] if block_n and block_k else None

@@ -969,6 +976,15 @@ def get_default_config(
     dtype: str | None,
     block_shape: list[int] | None = None,
 ) -> dict[str, int]:
+    if vllm_kernel_override_batch_invariant():
+        config = {
+            "BLOCK_SIZE_M": 64,
+            "BLOCK_SIZE_N": 64,
+            "BLOCK_SIZE_K": 32,
+            "GROUP_SIZE_M": 8,
+        }
+        return config
+
     if dtype == "fp8_w8a8" and block_shape is not None:
         # Block-wise quant: BLOCK_SIZE_N must be divisible by block_shape[0]
         # BLOCK_SIZE_K must be divisible by block_shape[1]

@@ -1118,7 +1134,10 @@ def fused_topk_bias(
     scores_for_choice = scores.view(
         -1, n_routed_experts
     ) + e_score_correction_bias.unsqueeze(0)
-    topk_indices = torch.topk(scores_for_choice, k=topk, dim=-1, sorted=False)[1]
+
+    # For batch invariance, use sorted=True to ensure deterministic expert selection
+    use_sorted = vllm_kernel_override_batch_invariant()
+    topk_indices = torch.topk(scores_for_choice, k=topk, dim=-1, sorted=use_sorted)[1]
     topk_weights = scores.gather(1, topk_indices)
     if renormalize:
         topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)

@@ -1179,7 +1198,10 @@ def grouped_topk(
     group_scores = (
         scores.view(num_token, num_expert_group, -1).max(dim=-1).values
     )  # [n, n_group]
-    group_idx = torch.topk(group_scores, k=topk_group, dim=-1, sorted=False)[
+
+    # For batch invariance, use sorted=True to ensure deterministic expert selection
+    use_sorted = vllm_kernel_override_batch_invariant()
+    group_idx = torch.topk(group_scores, k=topk_group, dim=-1, sorted=use_sorted)[
         1
     ]  # [n, top_k_group]
     group_mask = torch.zeros_like(group_scores)  # [n, n_group]

@@ -1192,11 +1214,13 @@ def grouped_topk(
     tmp_scores = scores.masked_fill(~score_mask.bool(), float("-inf"))  # [n, e]
 
     if e_score_correction_bias is not None:
-        topk_ids = torch.topk(tmp_scores, k=topk, dim=-1, sorted=False)[1]
+        topk_ids = torch.topk(tmp_scores, k=topk, dim=-1, sorted=use_sorted)[1]
         # Use original unbiased scores for the routing weights
         topk_weights = original_scores.gather(1, topk_ids)
     else:
-        topk_weights, topk_ids = torch.topk(tmp_scores, k=topk, dim=-1, sorted=False)
+        topk_weights, topk_ids = torch.topk(
+            tmp_scores, k=topk, dim=-1, sorted=use_sorted
+        )
 
     if renormalize:
         topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
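Note: the routing changes swap sorted=False for sorted=use_sorted because torch.topk with sorted=False does not guarantee the order of the returned top-k elements, so the downstream gather/renormalize can see a different permutation between otherwise identical runs; sorted=True pins the order. A small illustration (sketch):

    # Sketch: sorted=True pins the order of the selected experts.
    import torch

    scores = torch.randn(2, 8)

    _, idx_sorted = torch.topk(scores, k=2, dim=-1, sorted=True)
    _, idx_unsorted = torch.topk(scores, k=2, dim=-1, sorted=False)

    # Same *set* of experts either way (ties aside)...
    assert torch.equal(
        idx_sorted.sort(dim=-1).values, idx_unsorted.sort(dim=-1).values
    )
    # ...but only sorted=True guarantees a deterministic ordering of the indices.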
@@ -8,6 +8,10 @@ import torch.nn.functional as F
 
 import vllm.envs as envs
 from vllm.model_executor.custom_op import CustomOp
+from vllm.model_executor.layers.batch_invariant import (
+    rms_norm_batch_invariant,
+    vllm_kernel_override_batch_invariant,
+)
 from vllm.platforms import current_platform
 from vllm.utils import direct_register_custom_op
 

@@ -21,6 +25,8 @@ def rms_norm(
 ) -> torch.Tensor:
     from vllm import _custom_ops as ops
 
+    if vllm_kernel_override_batch_invariant():
+        return rms_norm_batch_invariant(x, weight, variance_epsilon)
     out = torch.empty_like(x)
     ops.rms_norm(
         out,

@@ -39,6 +45,10 @@ def fused_add_rms_norm(
 ) -> tuple[torch.Tensor, torch.Tensor]:
     from vllm import _custom_ops as ops
 
+    if vllm_kernel_override_batch_invariant():
+        return rms_norm_batch_invariant(
+            x + residual, weight, variance_epsilon
+        ), x + residual
     ops.fused_add_rms_norm(
         x,
         residual,
@@ -160,6 +160,7 @@ class MultiHeadLatentAttentionWrapper(CustomOp):
             k_pe,
             output_shape=(hidden_states.shape[0], self.num_heads * self.v_head_dim),
         )
+
         return self.o_proj(attn_out)[0]
 
     def forward_cuda(self, *args, **kwargs):

@@ -14,6 +14,9 @@ import vllm.model_executor.layers.fused_moe.modular_kernel as mk
 from vllm import _custom_ops as ops
 from vllm.distributed import get_tensor_model_parallel_world_size
 from vllm.logger import init_logger
+from vllm.model_executor.layers.batch_invariant import (
+    vllm_kernel_override_batch_invariant,
+)
 from vllm.model_executor.layers.fused_moe import (
     FusedMoE,
     FusedMoEActivationFormat,

@@ -353,6 +356,8 @@ class Fp8LinearMethod(LinearMethodBase):
         # Disable marlin for rocm
         if current_platform.is_rocm():
             self.use_marlin = False
+        if vllm_kernel_override_batch_invariant():
+            self.use_marlin = False
 
         self.use_aiter_and_is_supported = check_aiter_fp8_linear_support()
 
@@ -534,6 +539,66 @@ class Fp8LinearMethod(LinearMethodBase):
         x: torch.Tensor,
         bias: torch.Tensor | None = None,
     ) -> torch.Tensor:
+        # If batch invariant mode is enabled, dequantize and use BF16 compute
+        if vllm_kernel_override_batch_invariant():
+            # Dequantize FP8 weights to BF16
+            weight_fp8 = layer.weight.to(torch.bfloat16)
+            weight_scale = layer.weight_scale.to(torch.bfloat16)
+
+            # Handle different quantization granularities
+            if self.block_quant:
+                # Block-wise quantization:
+                # - Weight is NOT transposed, shape is [N, K] (output_size, input_size)
+                # - Scale has shape [num_blocks_k, num_blocks_n] (TRANSPOSED!)
+                assert self.weight_block_size is not None
+                block_n, block_k = self.weight_block_size  # Note: order is [N, K]
+
+                N, K = weight_fp8.shape
+
+                # Scale is stored transposed: [num_blocks_k, num_blocks_n]
+                # We need to transpose it to [num_blocks_n, num_blocks_k] first
+                weight_scale = weight_scale.t()
+
+                # Expand scale to match weight dimensions
+                # scale_expanded should have shape [N, K]
+                scale_expanded = weight_scale.repeat_interleave(
+                    block_n, dim=0
+                ).repeat_interleave(block_k, dim=1)
+                # Trim to exact weight size (in case of padding)
+                scale_expanded = scale_expanded[:N, :K]
+                weight_bf16 = weight_fp8 * scale_expanded
+            else:
+                # Per-tensor quantization: weight IS transposed to [K, N]
+                # scale should be scalar or [1] or per-output-channel [N]
+                if weight_scale.numel() == 1:
+                    # Per-tensor: simple scalar multiplication
+                    weight_bf16 = weight_fp8 * weight_scale
+                else:
+                    # Multiple scales (fused modules like QKV)
+                    # Try to infer correct broadcasting
+                    # weight is [K, N], scale could be [num_logical_weights]
+                    # Need to figure out how to broadcast - for now just try
+                    # direct multiplication
+                    if (
+                        weight_scale.dim() == 1
+                        and weight_scale.shape[0] == weight_fp8.shape[0]
+                    ):
+                        # Per-row scaling
+                        weight_bf16 = weight_fp8 * weight_scale.unsqueeze(1)
+                    else:
+                        # Fallback
+                        weight_bf16 = weight_fp8 * weight_scale
+
+            # For block quant, weight is [N, K], for per-tensor it's [K, N]
+            # F.linear expects weight to be [N, K], so:
+            if self.block_quant:
+                # Already in correct shape [N, K]
+                output = torch.nn.functional.linear(x, weight_bf16, bias)
+            else:
+                # Need to transpose back: [K, N] -> [N, K]
+                output = torch.nn.functional.linear(x, weight_bf16.t(), bias)
+            return output
+
         if self.use_marlin:
             return apply_fp8_marlin_linear(
                 input=x,
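Note: for block-quantized FP8 weights the scale tensor holds one value per (block_n x block_k) tile, so dequantization expands it with repeat_interleave along both axes before multiplying. A tiny worked example of that expansion, with made-up sizes (sketch only):

    # Sketch: expand a per-block scale of shape [N/block_n, K/block_k] to [N, K].
    import torch

    block_n, block_k = 2, 4
    weight = torch.arange(4 * 8, dtype=torch.float32).reshape(4, 8)  # [N=4, K=8]
    scale = torch.tensor([[0.5, 2.0], [1.0, 4.0]])                   # [2, 2]

    scale_expanded = scale.repeat_interleave(block_n, dim=0).repeat_interleave(
        block_k, dim=1
    )                                                                # -> [4, 8]
    dequant = weight * scale_expanded
    assert dequant.shape == weight.shape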
@@ -31,10 +31,12 @@ if is_flash_attn_varlen_func_available():
         get_scheduler_metadata,
         reshape_and_cache_flash,
     )
 
 from vllm.config import VllmConfig, get_layers_from_vllm_config
 from vllm.distributed.parallel_state import get_dcp_group
 from vllm.logger import init_logger
+from vllm.model_executor.layers.batch_invariant import (
+    vllm_kernel_override_batch_invariant,
+)
 from vllm.utils import cdiv
 from vllm.v1.attention.backends.utils import (
     AttentionCGSupport,

@@ -306,6 +308,9 @@ class FlashAttentionMetadataBuilder(AttentionMetadataBuilder[FlashAttentionMetad
         # we only set num_splits when using cuda graphs.
         max_num_splits = self.max_num_splits
 
+        if vllm_kernel_override_batch_invariant():
+            max_num_splits = 1
+
         def schedule(
             batch_size, cu_query_lens, max_query_len, seqlens, max_seq_len, causal
         ):

@@ -478,6 +483,9 @@ class FlashAttentionImpl(AttentionImpl):
 
         self.attn_type = attn_type
         self.vllm_flash_attn_version = get_flash_attn_version()
+        # Cache the batch invariant result for use in forward passes
+        self.batch_invariant_enabled = vllm_kernel_override_batch_invariant()
+
         if is_quantized_kv_cache(self.kv_cache_dtype) and not flash_attn_supports_fp8():
             raise NotImplementedError(
                 "FlashAttention does not support fp8 kv-cache on this device."

@@ -810,6 +818,7 @@ class FlashAttentionImpl(AttentionImpl):
                 q_descale=layer._q_scale.expand(descale_shape),
                 k_descale=layer._k_scale.expand(descale_shape),
                 v_descale=layer._v_scale.expand(descale_shape),
+                num_splits=1 if self.batch_invariant_enabled else 0,
             )
 
         return output

@@ -954,6 +963,7 @@ def cascade_attention(
         # s_aux is incorporated into prefix_lse inside the GPU kernel,
         # enabling its effect during the final attention merge.
         s_aux=s_aux,
+        num_splits=1 if vllm_kernel_override_batch_invariant() else 0,
     )
 
     descale_shape = (cu_query_lens.shape[0] - 1, key_cache.shape[-2])

@@ -978,6 +988,7 @@ def cascade_attention(
         q_descale=q_descale.expand(descale_shape) if q_descale is not None else None,
         k_descale=k_descale.expand(descale_shape) if k_descale is not None else None,
         v_descale=v_descale.expand(descale_shape) if v_descale is not None else None,
+        num_splits=1 if vllm_kernel_override_batch_invariant() else 0,
    )
 
     # Merge prefix and suffix outputs, and store the result in output.
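Note: the num_splits / num_kv_splits / max_num_splits changes in the attention backends all exist for the same reason: split-KV attention accumulates partial results in a work-distribution-dependent order, and floating-point addition is not associative, so the combined result can shift with batch size or scheduling. Forcing a single split fixes the reduction order. The underlying effect is the usual one (sketch):

    # Sketch: summing the same numbers in chunks vs. all at once can differ in fp32.
    import torch

    torch.manual_seed(0)
    x = torch.randn(1_000_000, dtype=torch.float32)

    full = x.sum()
    chunked = torch.stack([c.sum() for c in x.chunk(8)]).sum()

    # Often a few ULPs apart; with num_splits forced to 1 there is only one order.
    print((full - chunked).abs().item())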
@@ -211,6 +211,9 @@ from vllm.attention.utils.fa_utils import get_flash_attn_version
 from vllm.config import VllmConfig, get_current_vllm_config
 from vllm.distributed.parallel_state import get_dcp_group, is_global_first_rank
 from vllm.logger import init_logger
+from vllm.model_executor.layers.batch_invariant import (
+    vllm_kernel_override_batch_invariant,
+)
 from vllm.model_executor.layers.linear import (
     ColumnParallelLinear,
     LinearBase,

@@ -1187,6 +1190,7 @@ class MLACommonBaseImpl(MLAAttentionImpl[A], Generic[A]):
     def _v_up_proj(self, x: torch.Tensor, out: torch.Tensor):
         # Convert from (B, N, L) to (N, B, L)
         x = x.view(-1, self.num_heads, self.kv_lora_rank).transpose(0, 1)
+
         if is_rocm_aiter_fp8bmm_enabled():
             # Multiply + Transpose (N, B, L) x (N, L, V)->(N, B, V)->(B, N, V)
             x = aiter_triton_fp8_bmm(

@@ -1279,6 +1283,8 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]):
             # ROCm leverages the upstream flash_attn, which takes a parameter
             # called "return_attn_probs" instead of return_softmax_lse
             kwargs["return_attn_probs"] = return_softmax_lse
+        if vllm_kernel_override_batch_invariant():
+            kwargs["num_splits"] = 1
 
         attn_out = self.flash_attn_varlen_func(
             q=q,

@@ -1841,9 +1847,11 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]):
 
         if has_decode:
             assert attn_metadata.decode is not None
+
             decode_q_nope, decode_q_pe = decode_q.split(
                 [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1
             )
+
             # Convert from (B, N, P) to (N, B, P)
             decode_q_nope = decode_q_nope.transpose(0, 1)
 

@@ -1868,17 +1876,18 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]):
             # Pads the head_dim if necessary (for the underlying kernel)
             N, B, P = decode_q_nope.shape
             _, _, L = self.W_UK_T.shape
+
             if self.q_pad_num_heads is not None:
                 decode_ql_nope = decode_q_nope.new_empty(
                     (self.q_pad_num_heads, B, L)
                 )
                 decode_ql_nope.resize_((N, B, L))
             else:
                 decode_ql_nope = decode_q_nope.new_empty((N, B, L))
 
             # Multiply (N, B, P) x (N, P, L) -> (N, B, L)
             torch.bmm(decode_q_nope, self.W_UK_T, out=decode_ql_nope)
 
             # Convert from (N, B, L) to (B, N, L)
             decode_ql_nope = decode_ql_nope.transpose(0, 1)
 

@@ -18,6 +18,9 @@ from vllm.attention.utils.fa_utils import (
 )
 from vllm.config import VllmConfig
 from vllm.logger import init_logger
+from vllm.model_executor.layers.batch_invariant import (
+    vllm_kernel_override_batch_invariant,
+)
 from vllm.v1.attention.backends.mla.common import (
     MLACommonBackend,
     MLACommonDecodeMetadata,

@@ -107,6 +110,9 @@ class FlashAttnMLAMetadataBuilder(MLACommonMetadataBuilder[FlashAttnMLAMetadata]
         # pre-allocated during capture.
         self.max_num_splits = envs.VLLM_FLASH_ATTN_MAX_NUM_SPLITS_FOR_CUDA_GRAPH
 
+        if vllm_kernel_override_batch_invariant():
+            self.max_num_splits = 1
+
     def _schedule_decode(
         self, num_reqs, cu_query_lens, max_query_len, seqlens, max_seq_len, causal
     ):

@@ -175,7 +181,10 @@ class FlashAttnMLAMetadataBuilder(MLACommonMetadataBuilder[FlashAttnMLAMetadata]
         # we only set num_splits when using cuda graphs.
         max_num_splits = self.max_num_splits
 
-        return FlashAttnMLADecodeMetadata(
+        if vllm_kernel_override_batch_invariant():
+            max_num_splits = 1
+
+        metadata = FlashAttnMLADecodeMetadata(
             block_table=block_table_tensor,
             seq_lens=seq_lens_device,
             query_start_loc=query_start_loc_device,

@@ -185,6 +194,7 @@ class FlashAttnMLAMetadataBuilder(MLACommonMetadataBuilder[FlashAttnMLAMetadata]
             max_num_splits=max_num_splits,
             dcp_tot_seq_lens=dcp_tot_seq_lens_device,
         )
+        return metadata
 
 
 class FlashAttnMLAImpl(MLACommonImpl[FlashAttnMLAMetadata]):
@@ -14,6 +14,9 @@ from vllm.attention.ops.flashmla import (
 )
 from vllm.config import VllmConfig
 from vllm.logger import init_logger
+from vllm.model_executor.layers.batch_invariant import (
+    vllm_kernel_override_batch_invariant,
+)
 from vllm.v1.attention.backends.mla.common import (
     MLACommonBackend,
     MLACommonDecodeMetadata,

@@ -223,19 +226,50 @@ class FlashMLAImpl(MLACommonImpl[FlashMLAMetadata]):
         if type(q) is tuple:
             q = torch.cat(q, dim=-1)
+
+        # mypy assertion: q is now always a tensor
         assert isinstance(q, torch.Tensor)
+
         num_decodes = attn_metadata.num_decodes
         q = reshape_query_for_spec_decode(q, num_decodes)
 
+        tile_scheduler_metadata = attn_metadata.decode.tile_scheduler_metadata
+        num_splits = attn_metadata.decode.num_splits
+        if vllm_kernel_override_batch_invariant():
+            device = q.device
+            dtype = torch.int32
+
+            B = q.shape[0]
+            # block_table shape: [batch_size, max_num_blocks_per_seq]
+            # The number of blocks per sequence is in the second dimension
+            topk = attn_metadata.decode.block_table.shape[-1]
+            B_TOPK = 64
+            assert topk % B_TOPK == 0, f"topk ({topk}) must be divisible by {B_TOPK}"
+            end_block_idx = topk // B_TOPK
+
+            # Single partition => num_sm_parts = 1
+            # TileSchedulerMetaDataSize = 8, layout:
+            # [begin_idx, begin_block_idx, end_idx, end_block_idx,
+            #  begin_n_split_idx, _, _, _]
+            tile_scheduler_metadata = torch.zeros((1, 8), dtype=dtype, device=device)
+            tile_scheduler_metadata[0, 0] = 0  # begin_idx
+            tile_scheduler_metadata[0, 1] = 0  # sched_begin_block_idx
+            tile_scheduler_metadata[0, 2] = B - 1  # end_idx
+            tile_scheduler_metadata[0, 3] = end_block_idx
+            tile_scheduler_metadata[0, 4] = 0  # begin_n_split_idx
+            # fields [5..7] stay 0
+
+            # Non-split path ignores num_splits, but the API requires it:
+            # zeros of length B+1
+            num_splits = torch.zeros((B + 1,), dtype=dtype, device=device)
+
         o, lse = flash_mla_with_kvcache(
             q=q,
             k_cache=kv_c_and_k_pe_cache.unsqueeze(-2),  # Add head dim of 1
             block_table=attn_metadata.decode.block_table,
             cache_seqlens=attn_metadata.decode.seq_lens,
             head_dim_v=self.kv_lora_rank,
-            tile_scheduler_metadata=attn_metadata.decode.tile_scheduler_metadata,
-            num_splits=attn_metadata.decode.num_splits,
+            tile_scheduler_metadata=tile_scheduler_metadata,
+            num_splits=num_splits,
             softmax_scale=self.scale,
             causal=True,
             descale_q=layer._q_scale.reshape(1),
@@ -13,6 +13,9 @@ from vllm.attention.backends.abstract import (
 from vllm.attention.ops.triton_decode_attention import decode_attention_fwd
 from vllm.attention.ops.triton_flash_attention import triton_attention
 from vllm.logger import init_logger
+from vllm.model_executor.layers.batch_invariant import (
+    vllm_kernel_override_batch_invariant,
+)
 from vllm.platforms import current_platform
 from vllm.triton_utils import HAS_TRITON
 from vllm.v1.attention.backends.mla.common import (

@@ -158,7 +161,9 @@ class TritonMLAImpl(MLACommonImpl[MLACommonMetadata]):
             B, q_num_heads, self.kv_lora_rank, dtype=q.dtype, device=q.device
         )
         lse = torch.zeros(B, q_num_heads, dtype=q.dtype, device=q.device)
-        num_kv_splits = 4  # TODO: heuristic
+
+        # For batch invariance, use only 1 split to ensure deterministic reduction
+        num_kv_splits = 1 if vllm_kernel_override_batch_invariant() else 4
+
         # TODO(lucas) Allocate ahead of time
         attn_logits = torch.empty(

@@ -231,9 +231,6 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
         from vllm.model_executor.models.utils import set_cpu_offload_max_bytes
 
         set_cpu_offload_max_bytes(int(self.cache_config.cpu_offload_gb * 1024**3))
-        from vllm.model_executor.layers.batch_invariant import init_batch_invariance
-
-        init_batch_invariance()
 
         model_config = self.model_config
         cache_config = self.cache_config

@@ -764,6 +764,9 @@ def init_worker_distributed_environment(
 ) -> None:
     """Initialize the distributed environment."""
     parallel_config = vllm_config.parallel_config
+    from vllm.model_executor.layers.batch_invariant import init_batch_invariance
+
+    init_batch_invariance()
     set_custom_all_reduce(not parallel_config.disable_custom_all_reduce)
 
     init_distributed_environment(
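Note: taken together, the wiring is: the worker calls init_batch_invariance() before the distributed environment comes up, which pins the attention backend and registers the deterministic aten overrides, and the config/communicator/MoE/FP8 hunks above then keep every later choice on the single deterministic path. From the user side this is driven by one opt-in switch, a sketch assuming the VLLM_KERNEL_OVERRIDE_BATCH_INVARIANT=1 environment variable that vllm_kernel_override_batch_invariant() reads:

    # Sketch: opting an offline run into batch-invariant kernels (8 GPUs assumed).
    import os

    os.environ["VLLM_KERNEL_OVERRIDE_BATCH_INVARIANT"] = "1"

    from vllm import LLM, SamplingParams

    llm = LLM(model="deepseek-ai/DeepSeek-V3", tensor_parallel_size=8)
    params = SamplingParams(temperature=0.0, max_tokens=32)
    print(llm.generate(["The capital of France is"], params)[0].outputs[0].text)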