From 7d8975de84dadb87df23eae4a663a3c6620433dd Mon Sep 17 00:00:00 2001 From: Bram Wasti Date: Wed, 15 Oct 2025 22:06:02 -0700 Subject: [PATCH] Deepseek-v3 Batch Invariant on 8xH100 (#26609) Signed-off-by: Bram Wasti Co-authored-by: Wentao Ye <44945378+yewentao256@users.noreply.github.com> --- tests/v1/generation/test_batch_invariance.py | 842 ++++++++++++++++-- .../test_rms_norm_batch_invariant.py | 346 +++++++ vllm/compilation/caching.py | 4 +- vllm/config/model.py | 7 + vllm/config/parallel.py | 8 +- .../device_communicators/all_reduce_utils.py | 6 + .../device_communicators/symm_mem.py | 5 + vllm/engine/arg_utils.py | 2 +- vllm/model_executor/layers/batch_invariant.py | 253 +++++- .../layers/fused_moe/fused_moe.py | 32 +- vllm/model_executor/layers/layernorm.py | 10 + vllm/model_executor/layers/mla.py | 1 + .../model_executor/layers/quantization/fp8.py | 65 ++ vllm/model_executor/models/gpt_oss.py | 1 + vllm/v1/attention/backends/flash_attn.py | 13 +- vllm/v1/attention/backends/mla/common.py | 11 +- .../attention/backends/mla/flashattn_mla.py | 12 +- vllm/v1/attention/backends/mla/flashmla.py | 38 +- vllm/v1/attention/backends/mla/triton_mla.py | 7 +- vllm/v1/worker/gpu_model_runner.py | 3 - vllm/v1/worker/gpu_worker.py | 3 + 21 files changed, 1567 insertions(+), 102 deletions(-) create mode 100644 tests/v1/generation/test_rms_norm_batch_invariant.py diff --git a/tests/v1/generation/test_batch_invariance.py b/tests/v1/generation/test_batch_invariance.py index c9989a7ebe8a8..6fe7c42df2830 100644 --- a/tests/v1/generation/test_batch_invariance.py +++ b/tests/v1/generation/test_batch_invariance.py @@ -3,56 +3,73 @@ import contextlib import os import random -import string import pytest import torch from vllm import LLM, SamplingParams +from vllm.platforms import current_platform + + +@pytest.fixture(autouse=True) +def enable_batch_invariant_mode(): + """Automatically enable batch invariant kernel overrides for all tests.""" + old_value = os.environ.get("VLLM_KERNEL_OVERRIDE_BATCH_INVARIANT") + os.environ["VLLM_KERNEL_OVERRIDE_BATCH_INVARIANT"] = "1" + yield + # Restore original value after test + if old_value is None: + os.environ.pop("VLLM_KERNEL_OVERRIDE_BATCH_INVARIANT", None) + else: + os.environ["VLLM_KERNEL_OVERRIDE_BATCH_INVARIANT"] = old_value def _random_prompt(min_words: int = 1024, max_words: int = 1024 * 2) -> str: - # Lightweight random prompt generator to vary prompt lengths and content. 
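# Illustrative aside (not part of this patch): the reason batch-invariant
# kernels are needed at all is that floating-point addition is not
# associative, so reducing the same values in a different order (which is
# what happens when a request lands in a different batch shape) can give
# bitwise-different results. A minimal, self-contained sketch; the shape and
# values are arbitrary:
import torch

x = torch.randn(1 << 14, dtype=torch.float32)
a = x.sum()                             # one accumulation order
b = x.view(128, 128).sum(dim=1).sum()   # same values, different grouping
print(torch.equal(a, b))                # often False, though both sums are "correct"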
- vocab = [ - "alpha", - "bravo", - "charlie", - "delta", - "echo", - "foxtrot", - "golf", - "hotel", - "india", - "juliet", - "kilo", - "lima", - "mike", - "november", - "oscar", - "papa", - "quebec", - "romeo", - "sierra", - "tango", - "uniform", - "victor", - "whiskey", - "xray", - "yankee", - "zulu", + # Generate more realistic prompts that will actually produce varied tokens + # Use a mix of common English text patterns + + prompt_templates = [ + # Question-answer style + "Question: What is the capital of France?\nAnswer: The capital of France is", + "Q: How does photosynthesis work?\nA: Photosynthesis is the process by which", + "User: Can you explain quantum mechanics?\nAssistant: Quantum mechanics is", + # Story/narrative style + "Once upon a time in a distant galaxy, there lived", + "The old man walked slowly down the street, remembering", + "In the year 2157, humanity finally discovered", + # Technical/code style + "To implement a binary search tree in Python, first we need to", + "The algorithm works by iterating through the array and", + "Here's how to optimize database queries using indexing:", + # Factual/informative style + "The Renaissance was a period in European history that", + "Climate change is caused by several factors including", + "The human brain contains approximately 86 billion neurons which", + # Conversational style + "I've been thinking about getting a new laptop because", + "Yesterday I went to the store and bought", + "My favorite thing about summer is definitely", ] - n = random.randint(min_words, max_words) - words = random.choices(vocab, k=n) - # Add some noise and punctuation variability - if random.random() < 0.5: - words[0] = words[0].capitalize() - if random.random() < 0.2: - words.append("".join(random.choices(string.ascii_lowercase, k=5))) - punct = random.choice([".", "?", "!", "...", ""]) - return " ".join(words) + punct + # Pick a random template + base_prompt = random.choice(prompt_templates) + + # Add some padding to vary the length if needed + if min_words > 50: + # For longer prompts, repeat context + padding_text = ( + " This is an interesting topic that deserves more explanation. " + * (min_words // 50) + ) + base_prompt = base_prompt + padding_text + + return base_prompt +@pytest.mark.skipif( + not current_platform.has_device_capability(90), + reason="Batch invariance tests only supported on Hopper (SM90)", +) @pytest.mark.timeout(1000) def test_v1_generation_is_deterministic_across_batch_sizes_with_needle(): """ @@ -91,7 +108,6 @@ def test_v1_generation_is_deterministic_across_batch_sizes_with_needle(): # Keep GPU memory usage low to avoid startup allocation failures. gpu_mem_util = float(os.getenv("VLLM_GPU_MEMORY_UTILIZATION", "0.4")) max_model_len = int(os.getenv("VLLM_MAX_MODEL_LEN", "5120")) - swap_space_gb = int(os.getenv("VLLM_SWAP_SPACE_GB", "4")) # Sampling parameters: longer outputs with a more random-sounding # continuation,but still deterministic due to fixed seed. @@ -117,7 +133,6 @@ def test_v1_generation_is_deterministic_across_batch_sizes_with_needle(): max_num_seqs=max_batch_size, gpu_memory_utilization=gpu_mem_util, max_model_len=max_model_len, - swap_space=swap_space_gb, ) # Baseline generation for the needle prompt alone. 
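# Illustrative aside (not part of this patch): the SM90 gate added above is
# repeated on each test in this module. If desired, it could be defined once
# as a shared marker, using only symbols this file already imports. A minimal
# sketch; the name `requires_sm90` is illustrative:
import pytest
from vllm.platforms import current_platform

requires_sm90 = pytest.mark.skipif(
    not current_platform.has_device_capability(90),
    reason="Batch invariance tests only supported on Hopper (SM90)",
)
# Each test could then be decorated with @requires_sm90 instead of repeating
# the full skipif expression.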
@@ -132,7 +147,6 @@ def test_v1_generation_is_deterministic_across_batch_sizes_with_needle(): max_num_seqs=max_batch_size, gpu_memory_utilization=gpu_mem_util, max_model_len=max_model_len, - swap_space=swap_space_gb, ) mismatches = 0 @@ -195,16 +209,21 @@ def _extract_step_logprobs(request_output): ], dtype=torch.float32, ) - return t + return t, inner.token_ids - return None + return None, None +@pytest.mark.skipif( + not current_platform.has_device_capability(90), + reason="Batch invariance tests only supported on Hopper (SM90)", +) @pytest.mark.skipif( not torch.cuda.is_available(), reason="Requires CUDA to match production inference path.", ) -@pytest.mark.parametrize("backend", ["FLEX_ATTENTION", "FLASHINFER"]) +@pytest.mark.parametrize("backend", ["FLASH_ATTN", "FLASHINFER"]) +@pytest.mark.forked def test_logprobs_bitwise_batch_invariance_bs1_vs_bsN(backend): backend = os.getenv("VLLM_ATTENTION_BACKEND", backend) os.environ["VLLM_ATTENTION_BACKEND"] = backend @@ -214,69 +233,754 @@ def test_logprobs_bitwise_batch_invariance_bs1_vs_bsN(backend): model_name = os.getenv("VLLM_TEST_MODEL", "Qwen/Qwen3-1.7B") tp_size = int(os.getenv("VLLM_TEST_TP_SIZE", "1")) - # Force float32 to avoid precision-induced differences. + # For batch invariance, disable custom all-reduce to ensure deterministic + # all-reduce operations (custom all-reduce may not be deterministic) + from vllm.model_executor.layers.batch_invariant import ( + vllm_kernel_override_batch_invariant, + ) + + disable_custom_ar = vllm_kernel_override_batch_invariant() + + if disable_custom_ar: + print(f"\n{'=' * 80}") + print(f"BATCH INVARIANCE MODE: Disabling custom all-reduce (TP={tp_size})") + print(f"{'=' * 80}\n") + llm = LLM( model=model_name, tensor_parallel_size=tp_size, - enforce_eager=True, enable_prefix_caching=False, + max_num_seqs=32, + max_model_len=8192, + dtype="bfloat16", # not everything is supported ) - prompts = [_random_prompt(10, 1024) for i in range(100)] + # Use more realistic prompts for better token generation + prompts = [_random_prompt(10, 50) for i in range(32)] sp = SamplingParams( temperature=0.6, top_p=1.0, max_tokens=8, - # Seed shouldn't matter at temperature=0, but keeping it stable anyway. seed=1234, logprobs=5, ) # BS=1: run prompts individually and collect logprobs per step. + print("\n" + "=" * 80) + print("STARTING BS=1 RUNS (each prompt individually)") + print("=" * 80 + "\n") + bs1_logprobs_per_prompt = [] - for p in prompts: + bs1_tokens_per_prompt = [] + for idx, p in enumerate(prompts): + print(f"\n[BS=1] Running prompt {idx}/{len(prompts)} - Preview: {p[:80]}...") outs = llm.generate([p], sp, use_tqdm=False) assert len(outs) == 1 - step_logprobs = _extract_step_logprobs(outs[0]) + step_logprobs, token_ids = _extract_step_logprobs(outs[0]) if step_logprobs is None: pytest.skip( "Logits are not available on RequestOutput; " "enable logprobs return to run this test." ) bs1_logprobs_per_prompt.append(step_logprobs) + bs1_tokens_per_prompt.append(token_ids) + print(f"[BS=1] Prompt {idx} generated tokens: {token_ids}") # BS=N: run prompts in a batch and collect logprobs per step for each # prompt. 
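# Illustrative aside (not part of this patch): the comparison in this test is
# bit-for-bit (torch.equal), not tolerance-based (torch.allclose). A small
# runnable example of the difference, with arbitrary values:
import torch

a = torch.tensor([0.1, 0.2, 0.3])
b = a + 1e-7                        # perturb by 1e-7
print(torch.allclose(a, b))         # True under the default tolerances
print(torch.equal(a, b))            # False: batch invariance requires this stricter check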
+ print("\n" + "=" * 80) + print(f"STARTING BS={len(prompts)} RUN (all prompts batched)") + print("=" * 80 + "\n") + outs_batched = llm.generate(prompts, sp, use_tqdm=False) assert len(outs_batched) == len(prompts) bsN_logprobs_per_prompt = [] - for o in outs_batched: - step_logprobs = _extract_step_logprobs(o) + bsN_tokens_per_prompt = [] + + print(f"\n[BS={len(prompts)}] Processing batched outputs...") + for idx, o in enumerate(outs_batched): + tokens = o.outputs[0].token_ids if o.outputs else "N/A" + print(f"[BS={len(prompts)}] Prompt {idx} generated tokens: {tokens}") + step_logprobs, token_ids = _extract_step_logprobs(o) if step_logprobs is None: pytest.skip( "Logits are not available on RequestOutput; " "enable logprobs return to run this test." ) bsN_logprobs_per_prompt.append(step_logprobs) + bsN_tokens_per_prompt.append(token_ids) # Compare step-by-step logprobs for each prompt between BS=1 and BS=N runs. - for i, (logprobs_bs1, logprobs_bsN) in enumerate( - zip(bs1_logprobs_per_prompt, bsN_logprobs_per_prompt) - ): - assert len(logprobs_bs1) == len(logprobs_bsN), ( - f"Different number of generation steps for prompt index {i}: " - f"{len(logprobs_bs1)} (BS=1) vs {len(logprobs_bsN)} (BS=N)" + failed_prompts = [] + for i, (logprobs_bs1, logprobs_bsN, tokens_bs1, tokens_bsN) in enumerate( + zip( + bs1_logprobs_per_prompt, + bsN_logprobs_per_prompt, + bs1_tokens_per_prompt, + bsN_tokens_per_prompt, ) + ): + if len(logprobs_bs1) != len(logprobs_bsN): + reason = ( + f"Different number of steps: {len(logprobs_bs1)} (BS=1) " + f"vs {len(logprobs_bsN)} (BS=N)" + ) + failed_prompts.append( + { + "prompt_idx": i, + "step": "all", + "reason": reason, + "prompt_preview": prompts[i][:100], + "bs1_tokens": tokens_bs1, + "bsN_tokens": tokens_bsN, + } + ) + continue + + # Check if tokens match first + if tokens_bs1 != tokens_bsN: + failed_prompts.append( + { + "prompt_idx": i, + "step": "sampling", + "reason": "Different tokens sampled", + "prompt_preview": prompts[i][:100], + "bs1_tokens": tokens_bs1, + "bsN_tokens": tokens_bsN, + "bs1_all_logprobs": [ + logprobs_bs1[s].tolist() for s in range(len(logprobs_bs1)) + ], + "bsN_all_logprobs": [ + logprobs_bsN[s].tolist() for s in range(len(logprobs_bsN)) + ], + } + ) + continue + for t, (a, b) in enumerate(zip(logprobs_bs1, logprobs_bsN)): - assert a.shape == b.shape, ( - f"Logits shape mismatch at prompt {i}, step {t}: {a.shape} vs {b.shape}" + if a.shape != b.shape: + failed_prompts.append( + { + "prompt_idx": i, + "step": t, + "reason": f"Shape mismatch: {a.shape} vs {b.shape}", + "prompt_preview": prompts[i][:100], + "bs1_tokens": tokens_bs1, + "bsN_tokens": tokens_bsN, + } + ) + break + + if not torch.equal(a, b): + max_diff = torch.abs(a - b).max().item() + # Print which token failed + print(f"\n[DIVERGENCE] Prompt {i}, Token {t}: max_diff={max_diff:.6e}") + bs1_tok = tokens_bs1[t] if t < len(tokens_bs1) else "N/A" + bsN_tok = tokens_bsN[t] if t < len(tokens_bsN) else "N/A" + print(f" Token IDs: bs1={bs1_tok}, bsN={bsN_tok}") + print(f" BS=1 logprob: {a.tolist()}") + print(f" BS=N logprob: {b.tolist()}") + failed_prompts.append( + { + "prompt_idx": i, + "step": t, + "reason": f"Bitwise mismatch (max_diff={max_diff:.6e})", + "prompt_preview": prompts[i][:100], + "bs1_tokens": tokens_bs1, + "bsN_tokens": tokens_bsN, + "bs1_all_logprobs": [ + logprobs_bs1[s].tolist() for s in range(len(logprobs_bs1)) + ], + "bsN_all_logprobs": [ + logprobs_bsN[s].tolist() for s in range(len(logprobs_bsN)) + ], + } + ) + break + + # Print summary of all failures + 
if failed_prompts: + print(f"\n{'=' * 80}") + fail_msg = ( + f"BATCH INVARIANCE FAILURES: {len(failed_prompts)}/" + f"{len(prompts)} prompts failed" + ) + print(fail_msg) + print(f"{'=' * 80}") + for fail in failed_prompts: + print(f"\nPrompt {fail['prompt_idx']} (step {fail['step']}):") + print(f" Reason: {fail['reason']}") + print(f" Preview: {fail['prompt_preview']}...") + + # Always show the tokens + if "bs1_tokens" in fail: + print(f" BS=1 tokens: {fail['bs1_tokens']}") + if "bsN_tokens" in fail: + print(f" BS=N tokens: {fail['bsN_tokens']}") + + if "bs1_all_logprobs" in fail: + print(f" BS=1 logprobs for all {len(fail['bs1_all_logprobs'])} steps:") + for step_idx, logprobs in enumerate(fail["bs1_all_logprobs"]): + print(f" Step {step_idx}: {logprobs}") + print(f" BS=N logprobs for all {len(fail['bsN_all_logprobs'])} steps:") + for step_idx, logprobs in enumerate(fail["bsN_all_logprobs"]): + print(f" Step {step_idx}: {logprobs}") + print(f"{'=' * 80}\n") + + # Fail the test with summary + msg = ( + f"Batch invariance violated in {len(failed_prompts)}/" + f"{len(prompts)} prompts. See output above for details." + ) + pytest.fail(msg) + + +@pytest.mark.skipif( + not current_platform.has_device_capability(90), + reason="Batch invariance tests only supported on Hopper (SM90)", +) +def test_simple_generation(): + """ + Simple test that runs the model with a basic prompt and prints the output. + Useful for quick smoke testing and debugging. + """ + model = os.getenv("VLLM_TEST_MODEL", "Qwen/Qwen3-1.7B") + + llm = LLM( + model=model, + max_num_seqs=1, + tensor_parallel_size=int(os.getenv("VLLM_TP_SIZE", "1")), + enforce_eager=True, + gpu_memory_utilization=0.9, + max_model_len=2048, + dtype="bfloat16", + enable_prefix_caching=False, + ) + + prompt = "the capital of france is" + sampling_params = SamplingParams( + temperature=0.0, + max_tokens=20, + ) + + print(f"\n{'=' * 80}") + print("Running simple generation test") + print(f"Prompt: '{prompt}'") + print(f"{'=' * 80}\n") + + try: + outputs = llm.generate([prompt], sampling_params) + + assert len(outputs) == 1 + output_text = outputs[0].outputs[0].text + + print(f"Output: '{output_text}'") + print(f"\n{'=' * 80}") + print(f"Full completion: '{prompt}{output_text}'") + print(f"{'=' * 80}\n") + + finally: + with contextlib.suppress(Exception): + llm.shutdown() + + +@pytest.mark.skipif( + not current_platform.has_device_capability(90), + reason="Batch invariance tests only supported on Hopper (SM90)", +) +@pytest.mark.skipif( + not torch.cuda.is_available(), + reason="Requires CUDA to match production inference path.", +) +@pytest.mark.parametrize("backend", ["FLASH_ATTN", "FLASHINFER"]) +@pytest.mark.forked +def test_logprobs_WITHOUT_batch_invariance_should_FAIL(backend): + """ + This test is the inverse of test_logprobs_bitwise_batch_invariance_bs1_vs_bsN. + It DISABLES batch invariance mode and expects to see non-deterministic behavior + between BS=1 and BS=N runs. This demonstrates that batch invariance is actually + doing something useful. + + The test will PASS if we detect differences (proving batch invariance matters). + The test will FAIL if everything matches (suggesting batch invariance isn't needed). 
+ """ + backend = os.getenv("VLLM_ATTENTION_BACKEND", backend) + os.environ["VLLM_ATTENTION_BACKEND"] = backend + + # CRITICAL: Disable batch invariance for this test + old_value = os.environ.get("VLLM_KERNEL_OVERRIDE_BATCH_INVARIANT") + os.environ["VLLM_KERNEL_OVERRIDE_BATCH_INVARIANT"] = "0" + + try: + seed = int(os.getenv("VLLM_TEST_SEED", "12345")) + random.seed(seed) + model_name = os.getenv("VLLM_TEST_MODEL", "Qwen/Qwen3-1.7B") + tp_size = int(os.getenv("VLLM_TEST_TP_SIZE", "1")) + + print(f"\n{'=' * 80}") + print("BATCH INVARIANCE DISABLED: Expecting non-deterministic behavior") + print(f"{'=' * 80}\n") + + llm = LLM( + model=model_name, + tensor_parallel_size=tp_size, + enable_prefix_caching=False, + max_num_seqs=32, + max_model_len=8192, + dtype="bfloat16", + ) + + # Use more realistic prompts for better token generation + prompts = [_random_prompt(10, 50) for i in range(32)] + + sp = SamplingParams( + temperature=0.6, + top_p=1.0, + max_tokens=8, + seed=1234, + logprobs=5, + ) + + # BS=1: run prompts individually and collect logprobs per step. + print("\n" + "=" * 80) + print("STARTING BS=1 RUNS (each prompt individually)") + print("=" * 80 + "\n") + + bs1_logprobs_per_prompt = [] + bs1_tokens_per_prompt = [] + for idx, p in enumerate(prompts): + print( + f"\n[BS=1] Running prompt {idx}/{len(prompts)} - Preview: {p[:80]}..." ) - # Bitwise exact equality. - assert torch.equal(a, b), ( - f"Bitwise logprobs mismatch at prompt {i}, step {t} " - f"(dtype={a.dtype}, shape={a.shape})." + outs = llm.generate([p], sp, use_tqdm=False) + assert len(outs) == 1 + step_logprobs, token_ids = _extract_step_logprobs(outs[0]) + if step_logprobs is None: + pytest.skip( + "Logits are not available on RequestOutput; " + "enable logprobs return to run this test." + ) + bs1_logprobs_per_prompt.append(step_logprobs) + bs1_tokens_per_prompt.append(token_ids) + print(f"[BS=1] Prompt {idx} generated tokens: {token_ids}") + + # BS=N: run prompts in a batch and collect logprobs per step for each prompt. + print("\n" + "=" * 80) + print(f"STARTING BS={len(prompts)} RUN (all prompts batched)") + print("=" * 80 + "\n") + + outs_batched = llm.generate(prompts, sp, use_tqdm=False) + assert len(outs_batched) == len(prompts) + bsN_logprobs_per_prompt = [] + bsN_tokens_per_prompt = [] + + print(f"\n[BS={len(prompts)}] Processing batched outputs...") + for idx, o in enumerate(outs_batched): + tokens = o.outputs[0].token_ids if o.outputs else "N/A" + print(f"[BS={len(prompts)}] Prompt {idx} generated tokens: {tokens}") + step_logprobs, token_ids = _extract_step_logprobs(o) + if step_logprobs is None: + pytest.skip( + "Logits are not available on RequestOutput; " + "enable logprobs return to run this test." + ) + bsN_logprobs_per_prompt.append(step_logprobs) + bsN_tokens_per_prompt.append(token_ids) + + # Compare step-by-step logprobs for each prompt between BS=1 and BS=N runs. 
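# Illustrative aside (not part of this patch): this test saves and restores
# VLLM_KERNEL_OVERRIDE_BATCH_INVARIANT with a manual try/finally; the same
# pattern could be wrapped in a context manager. A minimal sketch, with the
# illustrative name `_env_override`:
import contextlib
import os

@contextlib.contextmanager
def _env_override(name: str, value: str):
    # Set `name` to `value` for the duration of the block, then restore it.
    old = os.environ.get(name)
    os.environ[name] = value
    try:
        yield
    finally:
        if old is None:
            os.environ.pop(name, None)
        else:
            os.environ[name] = old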
+ differences_found = [] + for i, (logprobs_bs1, logprobs_bsN, tokens_bs1, tokens_bsN) in enumerate( + zip( + bs1_logprobs_per_prompt, + bsN_logprobs_per_prompt, + bs1_tokens_per_prompt, + bsN_tokens_per_prompt, ) + ): + if len(logprobs_bs1) != len(logprobs_bsN): + reason = ( + f"Different number of steps: {len(logprobs_bs1)} (BS=1) " + f"vs {len(logprobs_bsN)} (BS=N)" + ) + differences_found.append( + { + "prompt_idx": i, + "step": "all", + "reason": reason, + "prompt_preview": prompts[i][:100], + "bs1_tokens": tokens_bs1, + "bsN_tokens": tokens_bsN, + } + ) + continue + + # Check if tokens match first + if tokens_bs1 != tokens_bsN: + differences_found.append( + { + "prompt_idx": i, + "step": "sampling", + "reason": "Different tokens sampled", + "prompt_preview": prompts[i][:100], + "bs1_tokens": tokens_bs1, + "bsN_tokens": tokens_bsN, + } + ) + continue + + for t, (a, b) in enumerate(zip(logprobs_bs1, logprobs_bsN)): + if a.shape != b.shape: + differences_found.append( + { + "prompt_idx": i, + "step": t, + "reason": f"Shape mismatch: {a.shape} vs {b.shape}", + "prompt_preview": prompts[i][:100], + "bs1_tokens": tokens_bs1, + "bsN_tokens": tokens_bsN, + } + ) + break + + if not torch.equal(a, b): + max_diff = torch.abs(a - b).max().item() + print( + f"\n[EXPECTED DIVERGENCE FOUND] Prompt {i}, " + f"Token {t}: max_diff={max_diff:.6e}" + ) + bs1_tok = tokens_bs1[t] if t < len(tokens_bs1) else "N/A" + bsN_tok = tokens_bsN[t] if t < len(tokens_bsN) else "N/A" + print(f" Token IDs: bs1={bs1_tok}, bsN={bsN_tok}") + print(f" BS=1 logprob: {a.tolist()}") + print(f" BS=N logprob: {b.tolist()}") + differences_found.append( + { + "prompt_idx": i, + "step": t, + "reason": f"Bitwise mismatch (max_diff={max_diff:.6e})", + "prompt_preview": prompts[i][:100], + "bs1_tokens": tokens_bs1, + "bsN_tokens": tokens_bsN, + } + ) + break + + # Print summary + print(f"\n{'=' * 80}") + if differences_found: + success_msg = ( + f"✓ SUCCESS: Batch invariance is doing something! " + f"Found {len(differences_found)}/{len(prompts)} prompts " + f"with differences when batch invariance was DISABLED." + ) + print(success_msg) + print(f"{'=' * 80}") + for diff in differences_found: + print(f"\nPrompt {diff['prompt_idx']} (step {diff['step']}):") + print(f" Reason: {diff['reason']}") + print(f" Preview: {diff['prompt_preview']}...") + if "bs1_tokens" in diff: + print(f" BS=1 tokens: {diff['bs1_tokens']}") + if "bsN_tokens" in diff: + print(f" BS=N tokens: {diff['bsN_tokens']}") + print(f"{'=' * 80}\n") + # Test PASSES because we found differences (batch invariance matters!) + return + else: + # Test FAILS because everything matched even without batch invariance + fail_msg = ( + f"✗ UNEXPECTED: All {len(prompts)} prompts matched " + f"between BS=1 and BS=N even with batch invariance DISABLED. " + f"This suggests batch invariance might not be necessary, " + f"or the test needs more sensitive prompts." 
+ ) + print(fail_msg) + print(f"{'=' * 80}\n") + pytest.fail(fail_msg) + + finally: + # Restore original value + if old_value is None: + os.environ.pop("VLLM_KERNEL_OVERRIDE_BATCH_INVARIANT", None) + else: + os.environ["VLLM_KERNEL_OVERRIDE_BATCH_INVARIANT"] = old_value + + +@pytest.mark.skipif( + not current_platform.has_device_capability(90), + reason="Batch invariance tests only supported on Hopper (SM90)", +) +@pytest.mark.skipif( + not torch.cuda.is_available(), + reason="Requires CUDA to match production inference path.", +) +@pytest.mark.parametrize("backend", ["FLASH_ATTN"]) +@pytest.mark.forked +def test_decode_logprobs_match_prefill_logprobs(backend): + """ + Test that verifies decode logprobs match prefill logprobs. + + For each decoded token at position i: + 1. Run decode to generate N tokens and collect their logprobs + 2. For each position i in [0, N): + - Take prefix = prompt + tokens[0:i] + - Run prefill(prefix + tokens[i]) to get logprob of tokens[i] + - Verify prefill logprob matches decode logprob bitwise + + This ensures that the logprobs from decode are consistent with what + we would get if we ran prefill on each prefix. + """ + backend = os.getenv("VLLM_ATTENTION_BACKEND", backend) + os.environ["VLLM_ATTENTION_BACKEND"] = backend + + seed = int(os.getenv("VLLM_TEST_SEED", "12345")) + random.seed(seed) + model_name = os.getenv("VLLM_TEST_MODEL", "Qwen/Qwen3-1.7B") + tp_size = int(os.getenv("VLLM_TEST_TP_SIZE", "1")) + + from vllm.model_executor.layers.batch_invariant import ( + vllm_kernel_override_batch_invariant, + ) + + disable_custom_ar = vllm_kernel_override_batch_invariant() + + if disable_custom_ar: + print(f"\n{'=' * 80}") + print(f"BATCH INVARIANCE MODE: Disabling custom all-reduce (TP={tp_size})") + print(f"{'=' * 80}\n") + + llm = LLM( + model=model_name, + tensor_parallel_size=tp_size, + enable_prefix_caching=False, + max_num_seqs=32, + max_model_len=8192, + dtype="bfloat16", + ) + + # Use a few test prompts + num_test_prompts = int(os.getenv("VLLM_DECODE_PREFILL_NUM_PROMPTS", "4")) + prompts = [_random_prompt(10, 50) for _ in range(num_test_prompts)] + + # Generate longer sequences to test multiple decode steps + max_tokens = int(os.getenv("VLLM_DECODE_PREFILL_MAX_TOKENS", "16")) + + sp = SamplingParams( + temperature=0.0, # Greedy for determinism + max_tokens=max_tokens, + logprobs=5, + ) + + print("\n" + "=" * 80) + print("STEP 1: Running decode to generate tokens and collect logprobs") + print("=" * 80 + "\n") + + # Step 1: Run decode and collect logprobs + decode_outputs = llm.generate(prompts, sp, use_tqdm=False) + + failed_comparisons = [] + + for prompt_idx, (prompt, decode_output) in enumerate(zip(prompts, decode_outputs)): + print(f"\n[Prompt {prompt_idx}] Testing: {prompt[:80]}...") + + # Extract decode logprobs and tokens + decode_logprobs, token_ids = _extract_step_logprobs(decode_output) + if decode_logprobs is None: + pytest.skip( + "Logprobs are not available on RequestOutput; " + "enable logprobs return to run this test." 
+ ) + + print(f"[Prompt {prompt_idx}] Generated {len(token_ids)} tokens: {token_ids}") + print(f"[Prompt {prompt_idx}] Decode logprobs: {decode_logprobs.tolist()}") + + # Step 2: For each token position, run prefill and compare + print(f"\n[Prompt {prompt_idx}] Verifying each token via prefill...") + + for token_idx in range(len(token_ids)): + # Construct the prefix up to (but not including) this token + current_token = token_ids[token_idx] + + # We need to detokenize to get the text prefix + # For this, we'll use the tokenizer from the LLM + # However, the LLM API doesn't expose tokenizer easily, so we'll + # construct the prefix by decoding from the original prompt + + # Get text up to this point by using the output text + # This is approximate but should work for verification + if token_idx == 0: + prefix_prompt = prompt + else: + # Use the partial output text up to this token + # We'll need to construct this from the full output + prefix_output = decode_output.outputs[0] + # Get the text for tokens 0 to token_idx-1 + # Unfortunately, we don't have per-token text, so we'll use + # a different approach: run prefill with prompt + tokens[0:token_idx] + + # Actually, we need to get the actual text. Let's use a workaround: + # Run a generation with max_tokens = token_idx to get that prefix + prefix_sp = SamplingParams( + temperature=0.0, + max_tokens=token_idx, + logprobs=1, + ) + prefix_output = llm.generate([prompt], prefix_sp, use_tqdm=False)[0] + prefix_prompt = prompt + prefix_output.outputs[0].text + + # Now run prefill with max_tokens=1 to get the logprob of the next token + prefill_sp = SamplingParams( + temperature=0.0, + max_tokens=1, + logprobs=5, + ) + + print( + f" [Token {token_idx}] Running prefill for prefix " + f"(len={len(prefix_prompt)})..." 
+ ) + prefill_output = llm.generate([prefix_prompt], prefill_sp, use_tqdm=False)[ + 0 + ] + prefill_logprobs, prefill_token_ids = _extract_step_logprobs(prefill_output) + + if prefill_logprobs is None: + print(f" [Token {token_idx}] Warning: No prefill logprobs available") + continue + + # The first token from prefill should match the current token + prefill_token = prefill_token_ids[0] + prefill_logprob = prefill_logprobs[0].item() + decode_logprob = decode_logprobs[token_idx].item() + + print( + f" [Token {token_idx}] Decode token: {current_token}, " + f"logprob: {decode_logprob:.8f}" + ) + print( + f" [Token {token_idx}] Prefill token: {prefill_token}, " + f"logprob: {prefill_logprob:.8f}" + ) + + # Check if tokens match + if current_token != prefill_token: + failed_comparisons.append( + { + "prompt_idx": prompt_idx, + "token_idx": token_idx, + "reason": "Token mismatch", + "decode_token": current_token, + "prefill_token": prefill_token, + "decode_logprob": decode_logprob, + "prefill_logprob": prefill_logprob, + "prompt_text": prompt[:100], + "prefix_text": prefix_prompt[:100], + } + ) + print(f" [Token {token_idx}] ✗ TOKEN MISMATCH!") + continue + + # Check if logprobs match bitwise + if decode_logprob != prefill_logprob: + diff = abs(decode_logprob - prefill_logprob) + failed_comparisons.append( + { + "prompt_idx": prompt_idx, + "token_idx": token_idx, + "reason": "Logprob mismatch", + "decode_token": current_token, + "prefill_token": prefill_token, + "decode_logprob": decode_logprob, + "prefill_logprob": prefill_logprob, + "diff": diff, + "prompt_text": prompt[:100], + "prefix_text": prefix_prompt[:100], + "decode_all_tokens": token_ids, + "decode_all_logprobs": decode_logprobs.tolist(), + } + ) + print(f" [Token {token_idx}] ✗ LOGPROB MISMATCH! 
diff={diff:.8e}") + else: + print(f" [Token {token_idx}] ✓ Match (bitwise equal)") + + # Print summary + print(f"\n{'=' * 80}") + if failed_comparisons: + print(f"DECODE-PREFILL MISMATCH: {len(failed_comparisons)} failures detected") + print(f"{'=' * 80}") + + # Group failures by prompt for better readability + failures_by_prompt: dict[int, list[dict]] = {} + for fail in failed_comparisons: + pid = fail["prompt_idx"] + if pid not in failures_by_prompt: + failures_by_prompt[pid] = [] + failures_by_prompt[pid].append(fail) + + for prompt_idx, failures in failures_by_prompt.items(): + print(f"\n{'=' * 80}") + print(f"PROMPT {prompt_idx}: {failures[0]['prompt_text']}...") + print(f"{'=' * 80}") + print(f"Total failures for this prompt: {len(failures)}") + + # Show where mismatches occur (which token positions) + mismatch_positions = [f["token_idx"] for f in failures] + print(f"Mismatch at token positions: {mismatch_positions}") + + # Show first few failures in detail + for i, fail in enumerate(failures[:5]): # Show first 5 failures per prompt + print(f"\n [Failure {i + 1}] Token position {fail['token_idx']}:") + print(f" Reason: {fail['reason']}") + print(f" Prefix text: '{fail['prefix_text']}...'") + print( + f" Decode: token={fail['decode_token']}, " + f"logprob={fail['decode_logprob']:.10f}" + ) + print( + f" Prefill: token={fail['prefill_token']}, " + f"logprob={fail['prefill_logprob']:.10f}" + ) + if "diff" in fail: + print(f" Difference: {fail['diff']:.10e}") + # Show in hex to see bitwise difference + import struct + + decode_hex = struct.pack("f", fail["decode_logprob"]).hex() + prefill_hex = struct.pack("f", fail["prefill_logprob"]).hex() + print(f" Decode logprob (hex): 0x{decode_hex}") + print(f" Prefill logprob (hex): 0x{prefill_hex}") + + # If we have all tokens/logprobs, show the context + if "decode_all_tokens" in fail and "decode_all_logprobs" in fail: + token_idx = fail["token_idx"] + all_tokens = fail["decode_all_tokens"] + all_logprobs = fail["decode_all_logprobs"] + + # Show context: 2 tokens before and after + start = max(0, token_idx - 2) + end = min(len(all_tokens), token_idx + 3) + + print(f" Context (tokens {start} to {end - 1}):") + for j in range(start, end): + marker = " <-- MISMATCH" if j == token_idx else "" + print( + f" [{j}] token={all_tokens[j]}, " + f"logprob={all_logprobs[j]:.8f}{marker}" + ) + + if len(failures) > 5: + print(f"\n ... and {len(failures) - 5} more failures for this prompt") + + print(f"\n{'=' * 80}\n") + + pytest.fail( + f"Decode logprobs do not match prefill logprobs: " + f"{len(failed_comparisons)} mismatches found." + ) + else: + print("✓ SUCCESS: All decode logprobs match prefill logprobs bitwise!") + print(f"{'=' * 80}\n") def LLM_with_max_seqs( @@ -284,7 +988,6 @@ def LLM_with_max_seqs( max_num_seqs: int, gpu_memory_utilization: float, max_model_len: int, - swap_space: int, ) -> LLM: """ Helper to construct an LLM with a specific max_num_seqs (batch-size limit) @@ -293,17 +996,10 @@ def LLM_with_max_seqs( return LLM( model=model, max_num_seqs=max_num_seqs, - # Constrain GPU memory pool so test can run even on busy GPUs. gpu_memory_utilization=gpu_memory_utilization, - # Keep KV cache footprint small while allowing longer outputs. max_model_len=max_model_len, - # Allow some CPU offload if needed. - swap_space=swap_space, - # Keep things lean and CI-friendly. - dtype="auto", - # Single-GPU by default; override externally if desired. 
+ dtype="bfloat16", tensor_parallel_size=int(os.getenv("VLLM_TP_SIZE", "1")), - trust_remote_code=os.getenv("VLLM_TRUST_REMOTE_CODE", "0") == "1", enable_prefix_caching=False, enforce_eager=True, # Enable for MOE models diff --git a/tests/v1/generation/test_rms_norm_batch_invariant.py b/tests/v1/generation/test_rms_norm_batch_invariant.py new file mode 100644 index 0000000000000..12d960362430b --- /dev/null +++ b/tests/v1/generation/test_rms_norm_batch_invariant.py @@ -0,0 +1,346 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +""" +Test batch-invariant RMS normalization against standard implementations. + +This test compares the Triton-based batch-invariant RMS norm implementation +with the standard CUDA-based implementation to ensure numerical accuracy. +""" + +import pytest +import torch + +from vllm.model_executor.layers.batch_invariant import rms_norm as triton_rms_norm +from vllm.model_executor.layers.layernorm import RMSNorm +from vllm.platforms import current_platform + + +@pytest.mark.skipif( + not current_platform.has_device_capability(90), + reason="Batch invariance tests only supported on Hopper (SM90)", +) +@pytest.mark.skipif( + not torch.cuda.is_available(), reason="Requires CUDA for RMS norm kernels" +) +@pytest.mark.parametrize("batch_size", [1, 4, 16, 64]) +@pytest.mark.parametrize("hidden_size", [512, 2048, 4096, 8192]) +@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) +@pytest.mark.parametrize("eps", [1e-6, 1e-5]) +def test_rms_norm_batch_invariant_vs_standard( + batch_size: int, hidden_size: int, dtype: torch.dtype, eps: float +): + """ + Compare batch-invariant Triton RMS norm against standard CUDA implementation. + + Tests that the Triton-based batch-invariant RMS norm produces numerically + equivalent results to the standard CUDA implementation across various + configurations. + """ + device = torch.device("cuda") + + # Create test input and weight + torch.manual_seed(42) + input_tensor = torch.randn(batch_size, hidden_size, dtype=dtype, device=device) + weight = torch.randn(hidden_size, dtype=dtype, device=device) + + # Standard implementation (CUDA ops) + rms_norm_layer = RMSNorm(hidden_size, eps=eps, dtype=dtype).to(device) + rms_norm_layer.weight.data = weight.clone() + + standard_output = rms_norm_layer.forward_cuda(input_tensor) + + # Batch-invariant implementation (Triton) + triton_output = triton_rms_norm(input_tensor, weight, eps=eps) + + # Compare outputs + # Use looser tolerance for bfloat16 due to its lower precision + if dtype == torch.bfloat16: + rtol, atol = 1e-1, 1e-1 # 10% relative tolerance for bfloat16 + else: + rtol, atol = 1e-2, 1e-2 # 1% for float16/float32 + + torch.testing.assert_close( + triton_output, + standard_output, + rtol=rtol, + atol=atol, + msg=f"RMS norm mismatch for batch_size={batch_size}, " + f"hidden_size={hidden_size}, " + f"dtype={dtype}, eps={eps}", + ) + + +@pytest.mark.skipif( + not current_platform.has_device_capability(90), + reason="Batch invariance tests only supported on Hopper (SM90)", +) +@pytest.mark.skipif( + not torch.cuda.is_available(), reason="Requires CUDA for RMS norm kernels" +) +@pytest.mark.parametrize("batch_size", [1, 16, 128]) +@pytest.mark.parametrize("seq_len", [1, 32, 512]) +@pytest.mark.parametrize("hidden_size", [2048, 4096]) +def test_rms_norm_3d_input(batch_size: int, seq_len: int, hidden_size: int): + """ + Test RMS norm with 3D input tensors (batch, seq_len, hidden_size). 
+ + Ensures that the batch-invariant RMS norm correctly handles multi-dimensional + inputs that are common in transformer models. + """ + device = torch.device("cuda") + dtype = torch.bfloat16 + eps = 1e-6 + + torch.manual_seed(42) + input_tensor = torch.randn( + batch_size, seq_len, hidden_size, dtype=dtype, device=device + ) + weight = torch.randn(hidden_size, dtype=dtype, device=device) + + # Standard implementation + rms_norm_layer = RMSNorm(hidden_size, eps=eps, dtype=dtype).to(device) + rms_norm_layer.weight.data = weight.clone() + standard_output = rms_norm_layer.forward_cuda(input_tensor) + + # Batch-invariant implementation + triton_output = triton_rms_norm(input_tensor, weight, eps=eps) + + # Use looser tolerance for bfloat16 + rtol, atol = 1e-1, 1e-1 # 10% tolerance for bfloat16 + + torch.testing.assert_close( + triton_output, + standard_output, + rtol=rtol, + atol=atol, + msg=f"RMS norm mismatch for 3D input with batch_size={batch_size}, " + f"seq_len={seq_len}, hidden_size={hidden_size}", + ) + + +@pytest.mark.skipif( + not current_platform.has_device_capability(90), + reason="Batch invariance tests only supported on Hopper (SM90)", +) +@pytest.mark.skipif( + not torch.cuda.is_available(), reason="Requires CUDA for RMS norm kernels" +) +def test_rms_norm_numerical_stability(): + """ + Test RMS norm numerical stability with extreme values. + + Ensures that both implementations handle edge cases like very small or large + values without producing NaN or Inf. + """ + device = torch.device("cuda") + dtype = torch.float16 + eps = 1e-6 + hidden_size = 2048 + + # Test cases with extreme values + test_cases = [ + # Very small values + torch.ones(4, hidden_size, dtype=dtype, device=device) * 1e-5, + # Very large values + torch.ones(4, hidden_size, dtype=dtype, device=device) * 1e4, + # Mixed small and large + torch.randn(4, hidden_size, dtype=dtype, device=device) * 100, + # Values near zero + torch.randn(4, hidden_size, dtype=dtype, device=device) * 1e-6, + ] + + weight = torch.ones(hidden_size, dtype=dtype, device=device) + + for idx, input_tensor in enumerate(test_cases): + # Standard implementation + rms_norm_layer = RMSNorm(hidden_size, eps=eps, dtype=dtype).to(device) + rms_norm_layer.weight.data = weight.clone() + standard_output = rms_norm_layer.forward_cuda(input_tensor) + + # Batch-invariant implementation + triton_output = triton_rms_norm(input_tensor, weight, eps=eps) + + # Check for NaN or Inf + assert not torch.isnan(standard_output).any(), ( + f"Standard RMS norm produced NaN for test case {idx}" + ) + assert not torch.isinf(standard_output).any(), ( + f"Standard RMS norm produced Inf for test case {idx}" + ) + assert not torch.isnan(triton_output).any(), ( + f"Triton RMS norm produced NaN for test case {idx}" + ) + assert not torch.isinf(triton_output).any(), ( + f"Triton RMS norm produced Inf for test case {idx}" + ) + + # Compare outputs - very lenient for extreme values with float16 + torch.testing.assert_close( + triton_output, + standard_output, + rtol=2e-1, # 20% tolerance for extreme values + atol=2e-1, + msg=f"RMS norm mismatch for extreme value test case {idx}", + ) + + +@pytest.mark.skipif( + not current_platform.has_device_capability(90), + reason="Batch invariance tests only supported on Hopper (SM90)", +) +@pytest.mark.skipif( + not torch.cuda.is_available(), reason="Requires CUDA for RMS norm kernels" +) +def test_rms_norm_formula(): + """ + Test that RMS norm follows the correct mathematical formula. 
+ + Verifies: output = input / sqrt(mean(input^2) + eps) * weight + """ + device = torch.device("cuda") + dtype = torch.float32 # Use float32 for higher precision in formula check + eps = 1e-6 + hidden_size = 1024 + + torch.manual_seed(42) + input_tensor = torch.randn(8, hidden_size, dtype=dtype, device=device) + weight = torch.randn(hidden_size, dtype=dtype, device=device) + + # Compute expected output using the formula + variance = (input_tensor.pow(2).mean(dim=-1, keepdim=True)).to(dtype) + expected_output = input_tensor * torch.rsqrt(variance + eps) * weight + + # Batch-invariant implementation + triton_output = triton_rms_norm(input_tensor, weight, eps=eps) + + # Compare against formula + torch.testing.assert_close( + triton_output, + expected_output, + rtol=1e-4, + atol=1e-4, + msg="Triton RMS norm doesn't match expected formula", + ) + + +@pytest.mark.skipif( + not current_platform.has_device_capability(90), + reason="Batch invariance tests only supported on Hopper (SM90)", +) +@pytest.mark.skipif( + not torch.cuda.is_available(), reason="Requires CUDA for RMS norm kernels" +) +@pytest.mark.parametrize("hidden_size", [128, 1024, 4096, 16384]) +def test_rms_norm_different_hidden_sizes(hidden_size: int): + """ + Test RMS norm with various hidden sizes to ensure block size handling. + + The Triton kernel uses a fixed BLOCK_SIZE=1024, so this tests that it + correctly handles hidden sizes both smaller and larger than the block size. + """ + device = torch.device("cuda") + dtype = torch.bfloat16 + eps = 1e-6 + batch_size = 16 + + torch.manual_seed(42) + input_tensor = torch.randn(batch_size, hidden_size, dtype=dtype, device=device) + weight = torch.randn(hidden_size, dtype=dtype, device=device) + + # Standard implementation + rms_norm_layer = RMSNorm(hidden_size, eps=eps, dtype=dtype).to(device) + rms_norm_layer.weight.data = weight.clone() + standard_output = rms_norm_layer.forward_cuda(input_tensor) + + # Batch-invariant implementation + triton_output = triton_rms_norm(input_tensor, weight, eps=eps) + + # Use looser tolerance for bfloat16 + rtol, atol = 1e-1, 1e-1 # 10% tolerance for bfloat16 + + torch.testing.assert_close( + triton_output, + standard_output, + rtol=rtol, + atol=atol, + msg=f"RMS norm mismatch for hidden_size={hidden_size}", + ) + + +@pytest.mark.skipif( + not current_platform.has_device_capability(90), + reason="Batch invariance tests only supported on Hopper (SM90)", +) +@pytest.mark.skipif( + not torch.cuda.is_available(), reason="Requires CUDA for RMS norm kernels" +) +def test_rms_norm_determinism(): + """ + Test that batch-invariant RMS norm produces deterministic results. + + Runs the same input through the kernel multiple times and verifies + identical outputs. 
+ """ + device = torch.device("cuda") + dtype = torch.bfloat16 + eps = 1e-6 + hidden_size = 4096 + batch_size = 32 + + torch.manual_seed(42) + input_tensor = torch.randn(batch_size, hidden_size, dtype=dtype, device=device) + weight = torch.randn(hidden_size, dtype=dtype, device=device) + + # Run multiple times + outputs = [] + for _ in range(5): + output = triton_rms_norm(input_tensor.clone(), weight, eps=eps) + outputs.append(output) + + # All outputs should be identical + reference = outputs[0] + for idx, output in enumerate(outputs[1:], start=1): + torch.testing.assert_close( + output, + reference, + rtol=0.0, + atol=0.0, + msg=f"RMS norm not deterministic: run {idx} differs from reference", + ) + + +if __name__ == "__main__": + # Run a quick smoke test + print("Running quick smoke test of RMS norm implementations...") + + device = torch.device("cuda") + batch_size = 8 + hidden_size = 4096 + dtype = torch.bfloat16 + eps = 1e-6 + + torch.manual_seed(42) + input_tensor = torch.randn(batch_size, hidden_size, dtype=dtype, device=device) + weight = torch.randn(hidden_size, dtype=dtype, device=device) + + # Standard implementation + rms_norm_layer = RMSNorm(hidden_size, eps=eps, dtype=dtype).to(device) + rms_norm_layer.weight.data = weight.clone() + standard_output = rms_norm_layer.forward_cuda(input_tensor) + + # Batch-invariant implementation + triton_output = triton_rms_norm(input_tensor, weight, eps=eps) + + # Compare + max_diff = (triton_output - standard_output).abs().max().item() + mean_diff = (triton_output - standard_output).abs().mean().item() + + print(f"Max difference: {max_diff:.6e}") + print(f"Mean difference: {mean_diff:.6e}") + print(f"Standard output sample: {standard_output[0, :5].tolist()}") + print(f"Triton output sample: {triton_output[0, :5].tolist()}") + + if max_diff < 1e-3: + print("✓ Smoke test passed!") + else: + print("✗ Smoke test failed - differences too large") diff --git a/vllm/compilation/caching.py b/vllm/compilation/caching.py index fc930e9b4f143..16e34c2711e9f 100644 --- a/vllm/compilation/caching.py +++ b/vllm/compilation/caching.py @@ -3,6 +3,7 @@ import hashlib import inspect +import os import pickle from unittest.mock import patch @@ -168,7 +169,8 @@ def _compute_code_hash(files: set[str]) -> str: ) file_contents = {} for filepath in files: - if filepath == "": + # Skip files that don't exist (e.g., , , etc.) + if not os.path.isfile(filepath): file_contents[filepath] = "" else: with open(filepath) as f: diff --git a/vllm/config/model.py b/vllm/config/model.py index 6e5757ba037d5..ebad9bfb9c904 100644 --- a/vllm/config/model.py +++ b/vllm/config/model.py @@ -20,6 +20,9 @@ from vllm.config.pooler import PoolerConfig from vllm.config.scheduler import RunnerType from vllm.config.utils import assert_hashable, config, getattr_iter from vllm.logger import init_logger +from vllm.model_executor.layers.batch_invariant import ( + vllm_kernel_override_batch_invariant, +) from vllm.platforms import current_platform from vllm.transformers_utils.config import ( ConfigFormat, @@ -419,6 +422,10 @@ class ModelConfig: skip_mm_profiling: bool | None, video_pruning_rate: float | None, ) -> None: + # Enable batch invariance settings if requested + if vllm_kernel_override_batch_invariant(): + self.enforce_eager = True + # Set the default seed to 0 in V1. 
# NOTE(woosuk): In V0, we set the default seed to None because the # driver worker shares the same process as the user process, and thus diff --git a/vllm/config/parallel.py b/vllm/config/parallel.py index 944a1e8666f4b..9b0634ad2ac9c 100644 --- a/vllm/config/parallel.py +++ b/vllm/config/parallel.py @@ -14,6 +14,9 @@ from typing_extensions import Self import vllm.envs as envs from vllm.config.utils import config from vllm.logger import init_logger +from vllm.model_executor.layers.batch_invariant import ( + vllm_kernel_override_batch_invariant, +) from vllm.platforms import current_platform from vllm.utils import cuda_device_count_stateless, get_open_ports_list @@ -560,7 +563,10 @@ class ParallelConfig: def _verify_args(self) -> Self: # Lazy import to avoid circular import from vllm.executor.executor_base import ExecutorBase - from vllm.platforms import current_platform + + # Enable batch invariance settings if requested + if vllm_kernel_override_batch_invariant(): + self.disable_custom_all_reduce = True if ( self.distributed_executor_backend is not None diff --git a/vllm/distributed/device_communicators/all_reduce_utils.py b/vllm/distributed/device_communicators/all_reduce_utils.py index 9e99fd01a9197..a3eef87b451fa 100644 --- a/vllm/distributed/device_communicators/all_reduce_utils.py +++ b/vllm/distributed/device_communicators/all_reduce_utils.py @@ -19,6 +19,9 @@ import torch.multiprocessing as mp import vllm.envs as envs from vllm.distributed.device_communicators.cuda_wrapper import CudaRTLibrary from vllm.logger import init_logger +from vllm.model_executor.layers.batch_invariant import ( + vllm_kernel_override_batch_invariant, +) from vllm.utils import cuda_device_count_stateless, update_environment_variables logger = init_logger(__name__) @@ -71,6 +74,9 @@ def should_nccl_symm_mem_allreduce(world_size: int, input_tensor: torch.Tensor) is_symmetric_memory_enabled, ) + if vllm_kernel_override_batch_invariant(): + return False + if not is_symmetric_memory_enabled(): return False if world_size < NCCL_SYMM_MEM_ALL_REDUCE_CONFIG["min_world_size"]: diff --git a/vllm/distributed/device_communicators/symm_mem.py b/vllm/distributed/device_communicators/symm_mem.py index 96f8e7b355352..f214c013bd3b4 100644 --- a/vllm/distributed/device_communicators/symm_mem.py +++ b/vllm/distributed/device_communicators/symm_mem.py @@ -9,6 +9,9 @@ from vllm.distributed.device_communicators.all_reduce_utils import ( SYMM_MEM_ALL_REDUCE_MAX_SIZES, ) from vllm.logger import init_logger +from vllm.model_executor.layers.batch_invariant import ( + vllm_kernel_override_batch_invariant, +) from vllm.platforms import current_platform try: @@ -100,6 +103,8 @@ class SymmMemCommunicator: return self.force_multimem = force_multimem self.disabled = False + if vllm_kernel_override_batch_invariant(): + self.disabled = True def should_use_symm_mem(self, inp: torch.Tensor): if self.disabled: diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index 801c30dc94786..654857315b15c 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -1694,7 +1694,7 @@ class EngineArgs: ) -> None: """Set Default Arguments for V1 Engine.""" - # V1 always uses chunked prefills and prefix caching + # V1 uses chunked prefills and prefix caching by default # for non-pooling tasks. 
# For pooling tasks the default is False if model_config.runner_type != "pooling": diff --git a/vllm/model_executor/layers/batch_invariant.py b/vllm/model_executor/layers/batch_invariant.py index 653fbef1cafe8..029605aed5028 100644 --- a/vllm/model_executor/layers/batch_invariant.py +++ b/vllm/model_executor/layers/batch_invariant.py @@ -395,7 +395,6 @@ def mean_dim( Tensor with mean values along specified dimension """ # Validate inputs - assert input.is_cuda, "Input must be a CUDA tensor" assert -input.ndim <= dim < input.ndim, ( f"Invalid dimension {dim} for tensor with {input.ndim} dimensions" ) @@ -470,6 +469,45 @@ def mm_batch_invariant(a, b): return matmul_persistent(a, b) +def matmul_batch_invariant(a, b, *, out=None): + # torch.matmul can handle various dimensions + # For 2D x 2D, it's the same as mm + if a.ndim == 2 and b.ndim == 2: + result = matmul_persistent(a, b) + if out is not None: + out.copy_(result) + return out + return result + elif a.ndim == 3 and b.ndim == 3: + # Handle batched case like bmm + return bmm_batch_invariant(a, b, out=out) + else: + raise ValueError( + f"matmul_batch_invariant currently only supports 2D x 2D and 3D x 3D, " + f"got shapes {a.shape} and {b.shape}" + ) + + +def bmm_batch_invariant(a, b, *, out=None): + # Batched matrix multiply: (B, M, K) x (B, K, N) -> (B, M, N) + # Process each batch separately with our persistent kernel + if a.ndim == 3 and b.ndim == 3: + results = [] + for i in range(a.shape[0]): + results.append(matmul_persistent(a[i], b[i])) + result = torch.stack(results, dim=0) + + if out is not None: + out.copy_(result) + return out + return result + else: + raise ValueError( + f"bmm_batch_invariant expects 3D tensors, " + f"got shapes {a.shape} and {b.shape}" + ) + + def addmm_batch_invariant(bias, a, b): return matmul_persistent(a, b, bias=bias) @@ -479,11 +517,24 @@ def _log_softmax_batch_invariant(input, dim, _half_to_float): return log_softmax(input, dim=dim) +def softmax_batch_invariant(input, dim, dtype=None): + # Compute softmax in a deterministic way + # First subtract max for numerical stability (standard practice) + input_max = torch.amax(input, dim=dim, keepdim=True) + input = input - input_max + exp_x = torch.exp(input) + sum_exp_x = torch.sum(exp_x, dim=dim, keepdim=True) + return exp_x / sum_exp_x + + def mean_batch_invariant(input, dim, keepdim=False, dtype: torch.dtype | None = None): assert dtype is None or dtype == torch.float32, f"unsupported dtype: {dtype}" result = input.to(torch.float32) + if len(dim) == 0: + dim = [i for i in range(len(input.shape))] + # Sort dimensions to reduce from largest to smallest to handle shifting dims # during iterative reduction. sorted_dims = sorted([d % input.ndim for d in dim], reverse=True) @@ -500,8 +551,134 @@ def mean_batch_invariant(input, dim, keepdim=False, dtype: torch.dtype | None = return result +@triton.jit +def _rms_norm_kernel( + input_ptr, + weight_ptr, + output_ptr, + input_row_stride, + output_row_stride, + n_cols, + eps, + BLOCK_SIZE: tl.constexpr, +): + """ + Compute RMS normalization along the last dimension of a 2D tensor. + RMS Norm: y = x / sqrt(mean(x^2) + eps) * weight + Each block handles one row of the input tensor. 
+ """ + row_idx = tl.program_id(0).to(tl.int64) + row_start_ptr = input_ptr + row_idx * input_row_stride + output_row_start_ptr = output_ptr + row_idx * output_row_stride + + # Step 1: Compute sum of squares in float32 to avoid overflow + sum_sq = tl.zeros([1], dtype=tl.float32) + for col_offset in range(0, n_cols, BLOCK_SIZE): + col_idx = col_offset + tl.arange(0, BLOCK_SIZE) + mask = col_idx < n_cols + + vals = tl.load(row_start_ptr + col_idx, mask=mask, other=0.0) + # Convert to float32 for accumulation to prevent overflow + vals_f32 = vals.to(tl.float32) + sq_vals = vals_f32 * vals_f32 + sum_sq += tl.sum(tl.where(mask, sq_vals, 0.0)) + + # Step 2: Compute RMS (root mean square) in float32 + mean_sq = sum_sq / n_cols + rms = tl.sqrt(mean_sq + eps) + inv_rms = 1.0 / rms + + # Step 3: Normalize and apply weight + for col_offset in range(0, n_cols, BLOCK_SIZE): + col_idx = col_offset + tl.arange(0, BLOCK_SIZE) + mask = col_idx < n_cols + vals = tl.load(row_start_ptr + col_idx, mask=mask, other=0.0) + weight = tl.load(weight_ptr + col_idx, mask=mask, other=1.0) + # Compute in float32 then convert back to input dtype + vals_f32 = vals.to(tl.float32) + weight_f32 = weight.to(tl.float32) + output_f32 = vals_f32 * inv_rms * weight_f32 + output = output_f32.to(vals.dtype) + tl.store(output_row_start_ptr + col_idx, output, mask=mask) + + +def rms_norm( + input: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6 +) -> torch.Tensor: + """ + Compute RMS normalization using Triton kernel. + + RMS Norm normalizes the input by the root mean square and scales by weight: + output = input / sqrt(mean(input^2) + eps) * weight + + Args: + input: Input tensor of shape (..., hidden_size) + weight: Weight tensor of shape (hidden_size,) + eps: Small constant for numerical stability + + Returns: + Tensor with RMS normalization applied along the last dimension + """ + assert weight.dim() == 1, "Weight must be 1-dimensional" + assert input.shape[-1] == weight.shape[0], ( + f"Input last dimension ({input.shape[-1]}) must match " + f"weight dimension ({weight.shape[0]})" + ) + + # Flatten all dimensions except the last one + original_shape = input.shape + input_2d = input.reshape(-1, input.shape[-1]) + input_2d = input_2d.contiguous() + weight = weight.contiguous() + + n_rows, n_cols = input_2d.shape + + output = torch.empty_like(input_2d) + BLOCK_SIZE = 1024 + grid = (n_rows,) + _rms_norm_kernel[grid]( + input_2d, + weight, + output, + input_2d.stride(0), + output.stride(0), + n_cols, + eps, + BLOCK_SIZE=BLOCK_SIZE, + ) + return output.reshape(original_shape) + + +def rms_norm_batch_invariant( + input: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6 +) -> torch.Tensor: + """ + Batch-invariant wrapper for RMS normalization. + + This function provides a deterministic, batch-invariant implementation + of RMS normalization for use with the batch_invariant mode. 
+ + Args: + input: Input tensor of shape (..., hidden_size) + weight: Weight tensor of shape (hidden_size,) + eps: Small constant for numerical stability + + Returns: + RMS normalized tensor + """ + return rms_norm(input, weight, eps=eps) + + +def linear_batch_invariant(input, weight, bias=None): + output = mm_batch_invariant(input, weight.t()) + if bias is not None: + output = output + bias + return output + + _batch_invariant_MODE = False _batch_invariant_LIB = None +_original_torch_bmm = None def is_batch_invariant_mode_enabled(): @@ -509,7 +686,7 @@ def is_batch_invariant_mode_enabled(): def enable_batch_invariant_mode(): - global _batch_invariant_MODE, _batch_invariant_LIB + global _batch_invariant_MODE, _batch_invariant_LIB, _original_torch_bmm if _batch_invariant_MODE: return @@ -517,16 +694,28 @@ def enable_batch_invariant_mode(): _batch_invariant_LIB = torch.library.Library("aten", "IMPL") _batch_invariant_LIB.impl("aten::mm", mm_batch_invariant, "CUDA") _batch_invariant_LIB.impl("aten::addmm", addmm_batch_invariant, "CUDA") + _batch_invariant_LIB.impl("aten::matmul", matmul_batch_invariant, "CUDA") + _batch_invariant_LIB.impl("aten::bmm", bmm_batch_invariant, "CUDA") + _batch_invariant_LIB.impl("aten::linear", linear_batch_invariant, "CUDA") _batch_invariant_LIB.impl( "aten::_log_softmax", _log_softmax_batch_invariant, "CUDA" ) + _batch_invariant_LIB.impl("aten::softmax", softmax_batch_invariant, "CUDA") + _batch_invariant_LIB.impl("aten::_softmax", softmax_batch_invariant, "CUDA") _batch_invariant_LIB.impl("aten::mean.dim", mean_batch_invariant, "CUDA") + # Also monkeypatch torch.bmm directly as a fallback + _original_torch_bmm = torch.bmm + torch.bmm = bmm_batch_invariant + def disable_batch_invariant_mode(): - global _batch_invariant_MODE, _batch_invariant_LIB + global _batch_invariant_MODE, _batch_invariant_LIB, _original_torch_bmm if _batch_invariant_LIB is not None: _batch_invariant_LIB._destroy() + if _original_torch_bmm is not None: + torch.bmm = _original_torch_bmm + _original_torch_bmm = None _batch_invariant_MODE = False _batch_invariant_LIB = None @@ -563,17 +752,55 @@ def vllm_kernel_override_batch_invariant(): return is_overridden +def override_envs_for_invariance(): + curr_attn_backend = envs.VLLM_ATTENTION_BACKEND + supported_backends = [ + "FLASH_ATTN", # best supported backend + "FLEX_ATTENTION", + "FLASHINFER", + "FLASH_ATTN_MLA", + "TRITON_MLA", + # Not yet supported MLA backends + # "FLASHMLA", + # "FLASHINFER_MLA", + ] + if curr_attn_backend not in supported_backends: + warning = ( + "Forcibly updating attention backend to" + f" {supported_backends[0]} for batch_invariant. " + f" Supported backends: {supported_backends}." + ) + logger.warning_once(warning) + os.environ["VLLM_ATTENTION_BACKEND"] = supported_backends[0] + if os.environ["VLLM_ATTENTION_BACKEND"] != supported_backends[0]: + warning = ( + "You are using a decode-invariant form of batch invariance. " + "This will not be invariant between prefill and decode." 
+ ) + logger.warning_once(warning) + os.environ["VLLM_ALLREDUCE_USE_SYMM_MEM"] = "0" + + os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8" + + # NCCL determinism settings + os.environ["NCCL_LAUNCH_MODE"] = "GROUP" + os.environ["NCCL_COLLNET_ENABLE"] = "0" + os.environ["NCCL_NVLS_ENABLE"] = "0" + os.environ["NCCL_P2P_NET_DISABLE"] = "1" + os.environ["NCCL_MIN_NCHANNELS"] = "1" + os.environ["NCCL_MAX_NCHANNELS"] = "1" + os.environ["NCCL_PROTO"] = "Simple" + os.environ["NCCL_ALGO"] = "allreduce:tree" + os.environ["NCCL_NTHREADS"] = "1" + os.environ["NCCL_SOCKET_NTHREADS"] = "1" + + def init_batch_invariance(): # this will hit all the csrc overrides as well if vllm_kernel_override_batch_invariant(): - curr_attn_backend = envs.VLLM_ATTENTION_BACKEND - supported_backends = ["FLEX_ATTENTION", "FLASHINFER"] - if curr_attn_backend not in supported_backends: - warning = ( - "Forcibly updating attention backend to" - f" {supported_backends[0]} for batch_invariant. " - f" Supported backends: {supported_backends}." - ) - logger.warning_once(warning) - os.environ["VLLM_ATTENTION_BACKEND"] = supported_backends[0] + override_envs_for_invariance() enable_batch_invariant_mode() + + # Disable TF32 for batch invariance - it causes non-deterministic rounding + torch.backends.cuda.matmul.allow_tf32 = False + torch.backends.cudnn.allow_tf32 = False diff --git a/vllm/model_executor/layers/fused_moe/fused_moe.py b/vllm/model_executor/layers/fused_moe/fused_moe.py index 9f66e47dcb96c..256f4964b6548 100644 --- a/vllm/model_executor/layers/fused_moe/fused_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_moe.py @@ -15,6 +15,9 @@ import vllm.envs as envs import vllm.model_executor.layers.fused_moe.modular_kernel as mk from vllm import _custom_ops as ops from vllm.logger import init_logger +from vllm.model_executor.layers.batch_invariant import ( + vllm_kernel_override_batch_invariant, +) from vllm.model_executor.layers.fused_moe.config import ( FUSED_MOE_UNQUANTIZED_CONFIG, FusedMoEQuantConfig, @@ -837,6 +840,10 @@ def get_moe_configs( be picked and the associated configuration chosen to invoke the kernel. """ + # Avoid optimizing for the batch invariant case. 
Use default config + if vllm_kernel_override_batch_invariant(): + return None + # First look up if an optimized configuration is available in the configs # directory block_shape = [block_n, block_k] if block_n and block_k else None @@ -969,6 +976,15 @@ def get_default_config( dtype: str | None, block_shape: list[int] | None = None, ) -> dict[str, int]: + if vllm_kernel_override_batch_invariant(): + config = { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 32, + "GROUP_SIZE_M": 8, + } + return config + if dtype == "fp8_w8a8" and block_shape is not None: # Block-wise quant: BLOCK_SIZE_N must be divisible by block_shape[0] # BLOCK_SIZE_K must be divisible by block_shape[1] @@ -1118,7 +1134,10 @@ def fused_topk_bias( scores_for_choice = scores.view( -1, n_routed_experts ) + e_score_correction_bias.unsqueeze(0) - topk_indices = torch.topk(scores_for_choice, k=topk, dim=-1, sorted=False)[1] + + # For batch invariance, use sorted=True to ensure deterministic expert selection + use_sorted = vllm_kernel_override_batch_invariant() + topk_indices = torch.topk(scores_for_choice, k=topk, dim=-1, sorted=use_sorted)[1] topk_weights = scores.gather(1, topk_indices) if renormalize: topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True) @@ -1179,7 +1198,10 @@ def grouped_topk( group_scores = ( scores.view(num_token, num_expert_group, -1).max(dim=-1).values ) # [n, n_group] - group_idx = torch.topk(group_scores, k=topk_group, dim=-1, sorted=False)[ + + # For batch invariance, use sorted=True to ensure deterministic expert selection + use_sorted = vllm_kernel_override_batch_invariant() + group_idx = torch.topk(group_scores, k=topk_group, dim=-1, sorted=use_sorted)[ 1 ] # [n, top_k_group] group_mask = torch.zeros_like(group_scores) # [n, n_group] @@ -1192,11 +1214,13 @@ def grouped_topk( tmp_scores = scores.masked_fill(~score_mask.bool(), float("-inf")) # [n, e] if e_score_correction_bias is not None: - topk_ids = torch.topk(tmp_scores, k=topk, dim=-1, sorted=False)[1] + topk_ids = torch.topk(tmp_scores, k=topk, dim=-1, sorted=use_sorted)[1] # Use original unbiased scores for the routing weights topk_weights = original_scores.gather(1, topk_ids) else: - topk_weights, topk_ids = torch.topk(tmp_scores, k=topk, dim=-1, sorted=False) + topk_weights, topk_ids = torch.topk( + tmp_scores, k=topk, dim=-1, sorted=use_sorted + ) if renormalize: topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True) diff --git a/vllm/model_executor/layers/layernorm.py b/vllm/model_executor/layers/layernorm.py index 135fbda2d540f..a689bc7be00f8 100644 --- a/vllm/model_executor/layers/layernorm.py +++ b/vllm/model_executor/layers/layernorm.py @@ -8,6 +8,10 @@ import torch.nn.functional as F import vllm.envs as envs from vllm.model_executor.custom_op import CustomOp +from vllm.model_executor.layers.batch_invariant import ( + rms_norm_batch_invariant, + vllm_kernel_override_batch_invariant, +) from vllm.platforms import current_platform from vllm.utils import direct_register_custom_op @@ -21,6 +25,8 @@ def rms_norm( ) -> torch.Tensor: from vllm import _custom_ops as ops + if vllm_kernel_override_batch_invariant(): + return rms_norm_batch_invariant(x, weight, variance_epsilon) out = torch.empty_like(x) ops.rms_norm( out, @@ -39,6 +45,10 @@ def fused_add_rms_norm( ) -> tuple[torch.Tensor, torch.Tensor]: from vllm import _custom_ops as ops + if vllm_kernel_override_batch_invariant(): + return rms_norm_batch_invariant( + x + residual, weight, variance_epsilon + ), x + residual ops.fused_add_rms_norm( x, 
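In the batch-invariant branch of fused_add_rms_norm above, the fused custom op is replaced by two explicit steps. A rough PyTorch equivalent of what that branch returns is sketched below; add_rms_norm_reference is an illustrative name, not vLLM API.

import torch

def add_rms_norm_reference(x, residual, weight, eps=1e-6):
    # The override returns (rms_norm(x + residual), x + residual): the summed
    # hidden state is normalized, and the same sum becomes the new residual.
    hidden = x + residual
    inv_rms = torch.rsqrt(hidden.float().pow(2).mean(-1, keepdim=True) + eps)
    normed = (hidden.float() * inv_rms * weight.float()).to(x.dtype)
    return normed, hidden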
residual, diff --git a/vllm/model_executor/layers/mla.py b/vllm/model_executor/layers/mla.py index 4c81162d7d2b9..34f05f2ee9624 100644 --- a/vllm/model_executor/layers/mla.py +++ b/vllm/model_executor/layers/mla.py @@ -160,6 +160,7 @@ class MultiHeadLatentAttentionWrapper(CustomOp): k_pe, output_shape=(hidden_states.shape[0], self.num_heads * self.v_head_dim), ) + return self.o_proj(attn_out)[0] def forward_cuda(self, *args, **kwargs): diff --git a/vllm/model_executor/layers/quantization/fp8.py b/vllm/model_executor/layers/quantization/fp8.py index 5967ee9b6e3f3..4edb55d816cf3 100644 --- a/vllm/model_executor/layers/quantization/fp8.py +++ b/vllm/model_executor/layers/quantization/fp8.py @@ -14,6 +14,9 @@ import vllm.model_executor.layers.fused_moe.modular_kernel as mk from vllm import _custom_ops as ops from vllm.distributed import get_tensor_model_parallel_world_size from vllm.logger import init_logger +from vllm.model_executor.layers.batch_invariant import ( + vllm_kernel_override_batch_invariant, +) from vllm.model_executor.layers.fused_moe import ( FusedMoE, FusedMoEActivationFormat, @@ -353,6 +356,8 @@ class Fp8LinearMethod(LinearMethodBase): # Disable marlin for rocm if current_platform.is_rocm(): self.use_marlin = False + if vllm_kernel_override_batch_invariant(): + self.use_marlin = False self.use_aiter_and_is_supported = check_aiter_fp8_linear_support() @@ -534,6 +539,66 @@ class Fp8LinearMethod(LinearMethodBase): x: torch.Tensor, bias: torch.Tensor | None = None, ) -> torch.Tensor: + # If batch invariant mode is enabled, dequantize and use BF16 compute + if vllm_kernel_override_batch_invariant(): + # Dequantize FP8 weights to BF16 + weight_fp8 = layer.weight.to(torch.bfloat16) + weight_scale = layer.weight_scale.to(torch.bfloat16) + + # Handle different quantization granularities + if self.block_quant: + # Block-wise quantization: + # - Weight is NOT transposed, shape is [N, K] (output_size, input_size) + # - Scale has shape [num_blocks_k, num_blocks_n] (TRANSPOSED!) 
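The scale bookkeeping in this block-quant branch is easier to follow with concrete shapes. The toy walkthrough below assumes the layout described in the comments above (an [N, K] weight with a transposed [num_blocks_k, num_blocks_n] scale) and is illustrative only.

import torch

N, K = 4, 8                      # toy weight stored as [N, K]
block_n, block_k = 2, 4          # weight_block_size, in [N, K] order
weight_bf16 = torch.randn(N, K, dtype=torch.bfloat16)
scale_t = torch.rand(K // block_k, N // block_n, dtype=torch.bfloat16)  # stored transposed

scale = scale_t.t()                                                # [num_blocks_n, num_blocks_k] = [2, 2]
scale_expanded = scale.repeat_interleave(block_n, dim=0)           # [2, 2] -> [4, 2]
scale_expanded = scale_expanded.repeat_interleave(block_k, dim=1)  # [4, 2] -> [4, 8]
dequant = weight_bf16 * scale_expanded[:N, :K]                     # elementwise, as in the code below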
+ assert self.weight_block_size is not None + block_n, block_k = self.weight_block_size # Note: order is [N, K] + + N, K = weight_fp8.shape + + # Scale is stored transposed: [num_blocks_k, num_blocks_n] + # We need to transpose it to [num_blocks_n, num_blocks_k] first + weight_scale = weight_scale.t() + + # Expand scale to match weight dimensions + # scale_expanded should have shape [N, K] + scale_expanded = weight_scale.repeat_interleave( + block_n, dim=0 + ).repeat_interleave(block_k, dim=1) + # Trim to exact weight size (in case of padding) + scale_expanded = scale_expanded[:N, :K] + weight_bf16 = weight_fp8 * scale_expanded + else: + # Per-tensor quantization: weight IS transposed to [K, N] + # scale should be scalar or [1] or per-output-channel [N] + if weight_scale.numel() == 1: + # Per-tensor: simple scalar multiplication + weight_bf16 = weight_fp8 * weight_scale + else: + # Multiple scales (fused modules like QKV) + # Try to infer correct broadcasting + # weight is [K, N], scale could be [num_logical_weights] + # Need to figure out how to broadcast - for now just try + # direct multiplication + if ( + weight_scale.dim() == 1 + and weight_scale.shape[0] == weight_fp8.shape[0] + ): + # Per-row scaling + weight_bf16 = weight_fp8 * weight_scale.unsqueeze(1) + else: + # Fallback + weight_bf16 = weight_fp8 * weight_scale + + # For block quant, weight is [N, K], for per-tensor it's [K, N] + # F.linear expects weight to be [N, K], so: + if self.block_quant: + # Already in correct shape [N, K] + output = torch.nn.functional.linear(x, weight_bf16, bias) + else: + # Need to transpose back: [K, N] -> [N, K] + output = torch.nn.functional.linear(x, weight_bf16.t(), bias) + return output + if self.use_marlin: return apply_fp8_marlin_linear( input=x, diff --git a/vllm/model_executor/models/gpt_oss.py b/vllm/model_executor/models/gpt_oss.py index fcba9b8e66c29..7f4040ca94223 100644 --- a/vllm/model_executor/models/gpt_oss.py +++ b/vllm/model_executor/models/gpt_oss.py @@ -216,6 +216,7 @@ class TransformerBlock(torch.nn.Module): else: hidden_states, residual = self.input_layernorm(hidden_states, residual) hidden_states = self.attn(hidden_states, positions) + # Fully Connected hidden_states, residual = self.post_attention_layernorm(hidden_states, residual) output = self.mlp(hidden_states) diff --git a/vllm/v1/attention/backends/flash_attn.py b/vllm/v1/attention/backends/flash_attn.py index 087f995e0528b..6811860a34b07 100755 --- a/vllm/v1/attention/backends/flash_attn.py +++ b/vllm/v1/attention/backends/flash_attn.py @@ -31,10 +31,12 @@ if is_flash_attn_varlen_func_available(): get_scheduler_metadata, reshape_and_cache_flash, ) - from vllm.config import VllmConfig, get_layers_from_vllm_config from vllm.distributed.parallel_state import get_dcp_group from vllm.logger import init_logger +from vllm.model_executor.layers.batch_invariant import ( + vllm_kernel_override_batch_invariant, +) from vllm.utils import cdiv from vllm.v1.attention.backends.utils import ( AttentionCGSupport, @@ -306,6 +308,9 @@ class FlashAttentionMetadataBuilder(AttentionMetadataBuilder[FlashAttentionMetad # we only set num_splits when using cuda graphs. 
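Pinning the split count matters because floating-point reduction is not associative: combining partial sums in a different order, or a different number of partials, changes the low-order bits. A small self-contained demonstration:

import torch

torch.manual_seed(0)
x = torch.randn(1 << 16, dtype=torch.float32)

full = x.sum()                                    # one reduction over the whole vector
split = sum(chunk.sum() for chunk in x.chunk(4))  # four partial sums combined afterwards

# The two results typically differ in the last bits; forcing num_splits=1
# (and max_num_splits=1) removes this source of run-to-run drift from the
# attention kernels.
print((full - split).abs().item())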
max_num_splits = self.max_num_splits + if vllm_kernel_override_batch_invariant(): + max_num_splits = 1 + def schedule( batch_size, cu_query_lens, max_query_len, seqlens, max_seq_len, causal ): @@ -478,6 +483,9 @@ class FlashAttentionImpl(AttentionImpl): self.attn_type = attn_type self.vllm_flash_attn_version = get_flash_attn_version() + # Cache the batch invariant result for use in forward passes + self.batch_invariant_enabled = vllm_kernel_override_batch_invariant() + if is_quantized_kv_cache(self.kv_cache_dtype) and not flash_attn_supports_fp8(): raise NotImplementedError( "FlashAttention does not support fp8 kv-cache on this device." @@ -810,6 +818,7 @@ class FlashAttentionImpl(AttentionImpl): q_descale=layer._q_scale.expand(descale_shape), k_descale=layer._k_scale.expand(descale_shape), v_descale=layer._v_scale.expand(descale_shape), + num_splits=1 if self.batch_invariant_enabled else 0, ) return output @@ -954,6 +963,7 @@ def cascade_attention( # s_aux is incorporated into prefix_lse inside the GPU kernel, # enabling its effect during the final attention merge. s_aux=s_aux, + num_splits=1 if vllm_kernel_override_batch_invariant() else 0, ) descale_shape = (cu_query_lens.shape[0] - 1, key_cache.shape[-2]) @@ -978,6 +988,7 @@ def cascade_attention( q_descale=q_descale.expand(descale_shape) if q_descale is not None else None, k_descale=k_descale.expand(descale_shape) if k_descale is not None else None, v_descale=v_descale.expand(descale_shape) if v_descale is not None else None, + num_splits=1 if vllm_kernel_override_batch_invariant() else 0, ) # Merge prefix and suffix outputs, and store the result in output. diff --git a/vllm/v1/attention/backends/mla/common.py b/vllm/v1/attention/backends/mla/common.py index f7e6f12363ad8..1d4e3e4cfe227 100755 --- a/vllm/v1/attention/backends/mla/common.py +++ b/vllm/v1/attention/backends/mla/common.py @@ -211,6 +211,9 @@ from vllm.attention.utils.fa_utils import get_flash_attn_version from vllm.config import VllmConfig, get_current_vllm_config from vllm.distributed.parallel_state import get_dcp_group, is_global_first_rank from vllm.logger import init_logger +from vllm.model_executor.layers.batch_invariant import ( + vllm_kernel_override_batch_invariant, +) from vllm.model_executor.layers.linear import ( ColumnParallelLinear, LinearBase, @@ -1187,6 +1190,7 @@ class MLACommonBaseImpl(MLAAttentionImpl[A], Generic[A]): def _v_up_proj(self, x: torch.Tensor, out: torch.Tensor): # Convert from (B, N, L) to (N, B, L) x = x.view(-1, self.num_heads, self.kv_lora_rank).transpose(0, 1) + if is_rocm_aiter_fp8bmm_enabled(): # Multiply + Transpose (N, B, L) x (N, L, V)->(N, B, V)->(B, N, V) x = aiter_triton_fp8_bmm( @@ -1279,6 +1283,8 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]): # ROCm leverages the upstream flash_attn, which takes a parameter # called "return_attn_probs" instead of return_softmax_lse kwargs["return_attn_probs"] = return_softmax_lse + if vllm_kernel_override_batch_invariant(): + kwargs["num_splits"] = 1 attn_out = self.flash_attn_varlen_func( q=q, @@ -1841,9 +1847,11 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]): if has_decode: assert attn_metadata.decode is not None + decode_q_nope, decode_q_pe = decode_q.split( [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1 ) + # Convert from (B, N, P) to (N, B, P) decode_q_nope = decode_q_nope.transpose(0, 1) @@ -1868,17 +1876,18 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]): # Pads the head_dim if necessary (for the underlying kernel) N, B, P = decode_q_nope.shape _, 
_, L = self.W_UK_T.shape + if self.q_pad_num_heads is not None: decode_ql_nope = decode_q_nope.new_empty( (self.q_pad_num_heads, B, L) ) decode_ql_nope.resize_((N, B, L)) - else: decode_ql_nope = decode_q_nope.new_empty((N, B, L)) # Multiply (N, B, P) x (N, P, L) -> (N, B, L) torch.bmm(decode_q_nope, self.W_UK_T, out=decode_ql_nope) + # Convert from (N, B, L) to (B, N, L) decode_ql_nope = decode_ql_nope.transpose(0, 1) diff --git a/vllm/v1/attention/backends/mla/flashattn_mla.py b/vllm/v1/attention/backends/mla/flashattn_mla.py index 446f1c4f1f961..3e404d50ee7c3 100644 --- a/vllm/v1/attention/backends/mla/flashattn_mla.py +++ b/vllm/v1/attention/backends/mla/flashattn_mla.py @@ -18,6 +18,9 @@ from vllm.attention.utils.fa_utils import ( ) from vllm.config import VllmConfig from vllm.logger import init_logger +from vllm.model_executor.layers.batch_invariant import ( + vllm_kernel_override_batch_invariant, +) from vllm.v1.attention.backends.mla.common import ( MLACommonBackend, MLACommonDecodeMetadata, @@ -107,6 +110,9 @@ class FlashAttnMLAMetadataBuilder(MLACommonMetadataBuilder[FlashAttnMLAMetadata] # pre-allocated during capture. self.max_num_splits = envs.VLLM_FLASH_ATTN_MAX_NUM_SPLITS_FOR_CUDA_GRAPH + if vllm_kernel_override_batch_invariant(): + self.max_num_splits = 1 + def _schedule_decode( self, num_reqs, cu_query_lens, max_query_len, seqlens, max_seq_len, causal ): @@ -175,7 +181,10 @@ class FlashAttnMLAMetadataBuilder(MLACommonMetadataBuilder[FlashAttnMLAMetadata] # we only set num_splits when using cuda graphs. max_num_splits = self.max_num_splits - return FlashAttnMLADecodeMetadata( + if vllm_kernel_override_batch_invariant(): + max_num_splits = 1 + + metadata = FlashAttnMLADecodeMetadata( block_table=block_table_tensor, seq_lens=seq_lens_device, query_start_loc=query_start_loc_device, @@ -185,6 +194,7 @@ class FlashAttnMLAMetadataBuilder(MLACommonMetadataBuilder[FlashAttnMLAMetadata] max_num_splits=max_num_splits, dcp_tot_seq_lens=dcp_tot_seq_lens_device, ) + return metadata class FlashAttnMLAImpl(MLACommonImpl[FlashAttnMLAMetadata]): diff --git a/vllm/v1/attention/backends/mla/flashmla.py b/vllm/v1/attention/backends/mla/flashmla.py index b15c09294c6b7..fc8fb34afb18a 100644 --- a/vllm/v1/attention/backends/mla/flashmla.py +++ b/vllm/v1/attention/backends/mla/flashmla.py @@ -14,6 +14,9 @@ from vllm.attention.ops.flashmla import ( ) from vllm.config import VllmConfig from vllm.logger import init_logger +from vllm.model_executor.layers.batch_invariant import ( + vllm_kernel_override_batch_invariant, +) from vllm.v1.attention.backends.mla.common import ( MLACommonBackend, MLACommonDecodeMetadata, @@ -223,19 +226,50 @@ class FlashMLAImpl(MLACommonImpl[FlashMLAMetadata]): if type(q) is tuple: q = torch.cat(q, dim=-1) + # mypy assertion: q is now always a tensor assert isinstance(q, torch.Tensor) num_decodes = attn_metadata.num_decodes q = reshape_query_for_spec_decode(q, num_decodes) + tile_scheduler_metadata = attn_metadata.decode.tile_scheduler_metadata + num_splits = attn_metadata.decode.num_splits + if vllm_kernel_override_batch_invariant(): + device = q.device + dtype = torch.int32 + + B = q.shape[0] + # block_table shape: [batch_size, max_num_blocks_per_seq] + # The number of blocks per sequence is in the second dimension + topk = attn_metadata.decode.block_table.shape[-1] + B_TOPK = 64 + assert topk % B_TOPK == 0, f"topk ({topk}) must be divisible by {B_TOPK}" + end_block_idx = topk // B_TOPK + + # Single partition => num_sm_parts = 1 + # TileSchedulerMetaDataSize = 8, layout: 
+ # [begin_idx, begin_block_idx, end_idx, end_block_idx, + # begin_n_split_idx, _, _, _] + tile_scheduler_metadata = torch.zeros((1, 8), dtype=dtype, device=device) + tile_scheduler_metadata[0, 0] = 0 # begin_idx + tile_scheduler_metadata[0, 1] = 0 # sched_begin_block_idx + tile_scheduler_metadata[0, 2] = B - 1 # end_idx + tile_scheduler_metadata[0, 3] = end_block_idx + tile_scheduler_metadata[0, 4] = 0 # begin_n_split_idx + # fields [5..7] stay 0 + + # Non-split path ignores num_splits, but the API requires it: + # zeros of length B+1 + num_splits = torch.zeros((B + 1,), dtype=dtype, device=device) + o, lse = flash_mla_with_kvcache( q=q, k_cache=kv_c_and_k_pe_cache.unsqueeze(-2), # Add head dim of 1 block_table=attn_metadata.decode.block_table, cache_seqlens=attn_metadata.decode.seq_lens, head_dim_v=self.kv_lora_rank, - tile_scheduler_metadata=attn_metadata.decode.tile_scheduler_metadata, - num_splits=attn_metadata.decode.num_splits, + tile_scheduler_metadata=tile_scheduler_metadata, + num_splits=num_splits, softmax_scale=self.scale, causal=True, descale_q=layer._q_scale.reshape(1), diff --git a/vllm/v1/attention/backends/mla/triton_mla.py b/vllm/v1/attention/backends/mla/triton_mla.py index e2df0179d99a8..d3524020bc7fa 100644 --- a/vllm/v1/attention/backends/mla/triton_mla.py +++ b/vllm/v1/attention/backends/mla/triton_mla.py @@ -13,6 +13,9 @@ from vllm.attention.backends.abstract import ( from vllm.attention.ops.triton_decode_attention import decode_attention_fwd from vllm.attention.ops.triton_flash_attention import triton_attention from vllm.logger import init_logger +from vllm.model_executor.layers.batch_invariant import ( + vllm_kernel_override_batch_invariant, +) from vllm.platforms import current_platform from vllm.triton_utils import HAS_TRITON from vllm.v1.attention.backends.mla.common import ( @@ -158,7 +161,9 @@ class TritonMLAImpl(MLACommonImpl[MLACommonMetadata]): B, q_num_heads, self.kv_lora_rank, dtype=q.dtype, device=q.device ) lse = torch.zeros(B, q_num_heads, dtype=q.dtype, device=q.device) - num_kv_splits = 4 # TODO: heuristic + + # For batch invariance, use only 1 split to ensure deterministic reduction + num_kv_splits = 1 if vllm_kernel_override_batch_invariant() else 4 # TODO(lucas) Allocate ahead of time attn_logits = torch.empty( diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 9e394dbb592ec..7e72ce937be41 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -231,9 +231,6 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): from vllm.model_executor.models.utils import set_cpu_offload_max_bytes set_cpu_offload_max_bytes(int(self.cache_config.cpu_offload_gb * 1024**3)) - from vllm.model_executor.layers.batch_invariant import init_batch_invariance - - init_batch_invariance() model_config = self.model_config cache_config = self.cache_config diff --git a/vllm/v1/worker/gpu_worker.py b/vllm/v1/worker/gpu_worker.py index 0e9ab3f9148b9..00dc7682c973b 100644 --- a/vllm/v1/worker/gpu_worker.py +++ b/vllm/v1/worker/gpu_worker.py @@ -764,6 +764,9 @@ def init_worker_distributed_environment( ) -> None: """Initialize the distributed environment.""" parallel_config = vllm_config.parallel_config + from vllm.model_executor.layers.batch_invariant import init_batch_invariance + + init_batch_invariance() set_custom_all_reduce(not parallel_config.disable_custom_all_reduce) init_distributed_environment(