Revert "Add batch invariant kernel override for FlashInfer backend [2/n]" (#26220)

Cyrus Leung 2025-10-04 17:45:08 +08:00 committed by GitHub
parent 7d6b03381e
commit 1838cd4860
3 changed files with 29 additions and 84 deletions

View File

@@ -76,21 +76,18 @@ def test_v1_generation_is_deterministic_across_batch_sizes_with_needle():
       seed.
     - Keep max_tokens and max_model_len bounded for speed and memory use.
     """
-    seed = int(os.getenv("VLLM_TEST_SEED", "12345"))
-    random.seed(seed)
+    random.seed(12345)
     # Allow overrides from environment (useful for CI tuning)
     # "facebook/opt-125m" is too small, doesn't reliably test determinism
     model = os.getenv("VLLM_TEST_MODEL", "Qwen/Qwen3-1.7B")
     num_trials = int(os.getenv("VLLM_NEEDLE_TRIALS", "5"))
-    max_batch_size = int(os.getenv("VLLM_NEEDLE_BATCH_SIZE", "128"))
-    min_random_prompt = int(os.getenv("VLLM_MIN_PROMPT", "1024"))
-    max_random_prompt = int(os.getenv("VLLM_MAX_PROMPT", "2048"))
-    assert max_batch_size >= 2, "Batch size should be >= 2 to mix needle."
+    batch_size = int(os.getenv("VLLM_NEEDLE_BATCH_SIZE", "64"))
+    assert batch_size >= 2, "Batch size should be >= 2 to mix needle."
     # Keep GPU memory usage low to avoid startup allocation failures.
-    gpu_mem_util = float(os.getenv("VLLM_GPU_MEMORY_UTILIZATION", "0.4"))
-    max_model_len = int(os.getenv("VLLM_MAX_MODEL_LEN", "5120"))
+    gpu_mem_util = float(os.getenv("VLLM_GPU_MEMORY_UTILIZATION", "0.3"))
+    max_model_len = int(os.getenv("VLLM_MAX_MODEL_LEN", "4096"))
     swap_space_gb = int(os.getenv("VLLM_SWAP_SPACE_GB", "4"))
     # Sampling parameters: longer outputs with a more random-sounding
@@ -114,7 +111,7 @@ def test_v1_generation_is_deterministic_across_batch_sizes_with_needle():
         # Engine with bs=1 behavior
         llm_bs1 = LLM_with_max_seqs(
             model=model,
-            max_num_seqs=128,
+            max_num_seqs=1,
             gpu_memory_utilization=gpu_mem_util,
             max_model_len=max_model_len,
             swap_space=swap_space_gb,
@@ -129,7 +126,7 @@ def test_v1_generation_is_deterministic_across_batch_sizes_with_needle():
         # Engine with larger batch limit (e.g., 64)
         llm_bsN = LLM_with_max_seqs(
             model=model,
-            max_num_seqs=128,
+            max_num_seqs=batch_size,
             gpu_memory_utilization=gpu_mem_util,
             max_model_len=max_model_len,
             swap_space=swap_space_gb,
@@ -138,17 +135,15 @@ def test_v1_generation_is_deterministic_across_batch_sizes_with_needle():
         mismatches = 0
         for trial in range(num_trials):
-            # Create a batch of size `max_batch_size` and insert the needle at
+            # Create a batch of size `batch_size` and insert the needle at
             # a random index
             prompts: list[str] = []
-            batch_size = random.randint(max_batch_size // 2, max_batch_size)
             needle_pos = random.randint(0, batch_size - 1)
             for i in range(batch_size):
                 if i == needle_pos:
                     prompts.append(needle_prompt)
                 else:
-                    prompts.append(
-                        _random_prompt(min_random_prompt, max_random_prompt))
+                    prompts.append(_random_prompt())
             # Generate with the larger-batch engine
             outputs = llm_bsN.generate(prompts, sampling)
@@ -159,19 +154,17 @@ def test_v1_generation_is_deterministic_across_batch_sizes_with_needle():
             text = needle_output.outputs[0].text
             if text != baseline_text:
-                print(
-                    f"{text}\n\n== Not the same as ==\n\n{baseline_text}\n\n")
                 mismatches += 1
         passes = num_trials - mismatches
         # Dump how many passed vs failed
         print(f"[determinism] total={num_trials}, passed={passes}, "
-              f"failed={mismatches}, max_batch_size={max_batch_size}")
+              f"failed={mismatches}, batch_size={batch_size}")
         if mismatches > 0:
             pytest.fail(
                 f"Nondeterministic outputs detected: {mismatches} failed out "
-                f"of {num_trials} trials (max_batch_size={max_batch_size}).")
+                f"of {num_trials} trials (batch_size={batch_size}).")
     finally:
         # Ensure engines are shutdown to free GPU/VRAM across test sessions
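
Note: the needle check in the hunks above boils down to decoding one prompt alone and decoding the same prompt inside a larger batch, then comparing the two outputs. A minimal standalone sketch of that idea using vLLM's public LLM/SamplingParams API; the test's local helpers (LLM_with_max_seqs, _random_prompt) are replaced here with plain constructor calls and made-up filler prompts, so this is an illustration, not the test itself:

    import random

    from vllm import LLM, SamplingParams

    needle = "The capital of France is"
    filler = [f"Write a short story about item {i}." for i in range(31)]
    sampling = SamplingParams(temperature=0.0, max_tokens=32, seed=12345)

    # Baseline: the needle prompt decoded with the batch size capped at 1.
    llm_bs1 = LLM(model="Qwen/Qwen3-1.7B", max_num_seqs=1,
                  gpu_memory_utilization=0.3, max_model_len=4096)
    baseline = llm_bs1.generate([needle], sampling)[0].outputs[0].text

    # Same prompt decoded while mixed into a larger batch at a random position.
    llm_bsN = LLM(model="Qwen/Qwen3-1.7B", max_num_seqs=64,
                  gpu_memory_utilization=0.3, max_model_len=4096)
    prompts = filler[:]
    pos = random.randint(0, len(prompts))
    prompts.insert(pos, needle)
    batched = llm_bsN.generate(prompts, sampling)[pos].outputs[0].text

    # With batch-invariant kernels these should be bitwise identical.
    assert batched == baseline, "needle output changed when batched"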
@@ -203,14 +196,9 @@ def _extract_step_logprobs(request_output):
     not torch.cuda.is_available(),
     reason="Requires CUDA to match production inference path.",
 )
-@pytest.mark.parametrize("backend", ["FLEX_ATTENTION", "FLASHINFER"])
-def test_logprobs_bitwise_batch_invariance_bs1_vs_bsN(backend):
-    backend = os.getenv("VLLM_ATTENTION_BACKEND", backend)
-    os.environ["VLLM_ATTENTION_BACKEND"] = backend
-    seed = int(os.getenv("VLLM_TEST_SEED", "12345"))
-    random.seed(seed)
+def test_logprobs_bitwise_batch_invariance_bs1_vs_bs2():
+    #model_name = os.getenv("VLLM_TEST_MODEL", "facebook/opt-125m")
     model_name = os.getenv("VLLM_TEST_MODEL", "Qwen/Qwen3-1.7B")
     tp_size = int(os.getenv("VLLM_TEST_TP_SIZE", "1"))
@@ -224,15 +212,10 @@ def test_logprobs_bitwise_batch_invariance_bs1_vs_bsN(backend):
     prompts = [
         "The capital of France is",
         "The capital of Germany is",
-        _random_prompt(10, 1024),
-        _random_prompt(10, 1024),
-        _random_prompt(10, 1024),
-        _random_prompt(10, 1024),
-        _random_prompt(10, 1024),
     ]
     sp = SamplingParams(
-        temperature=0.6,
+        temperature=0.0,
         top_p=1.0,
         max_tokens=8,
         # Seed shouldn't matter at temperature=0, but keeping it stable anyway.
@@ -251,25 +234,25 @@ def test_logprobs_bitwise_batch_invariance_bs1_vs_bsN(backend):
                         "enable logprobs return to run this test.")
         bs1_logprobs_per_prompt.append(step_logprobs)
-    # BS=N: run prompts in a batch and collect logprobs per step for each
+    # BS=2: run prompts in a batch and collect logprobs per step for each
     # prompt.
     outs_batched = llm.generate(prompts, sp, use_tqdm=False)
     assert len(outs_batched) == len(prompts)
-    bsN_logprobs_per_prompt = []
+    bs2_logprobs_per_prompt = []
     for o in outs_batched:
         step_logprobs = _extract_step_logprobs(o)
         if step_logprobs is None:
             pytest.skip("Logits are not available on RequestOutput; "
                         "enable logprobs return to run this test.")
-        bsN_logprobs_per_prompt.append(step_logprobs)
+        bs2_logprobs_per_prompt.append(step_logprobs)
-    # Compare step-by-step logprobs for each prompt between BS=1 and BS=N runs.
-    for i, (logprobs_bs1, logprobs_bsN) in enumerate(
-            zip(bs1_logprobs_per_prompt, bsN_logprobs_per_prompt)):
-        assert len(logprobs_bs1) == len(logprobs_bsN), (
+    # Compare step-by-step logprobs for each prompt between BS=1 and BS=2 runs.
+    for i, (logprobs_bs1, logprobs_bs2) in enumerate(
+            zip(bs1_logprobs_per_prompt, bs2_logprobs_per_prompt)):
+        assert len(logprobs_bs1) == len(logprobs_bs2), (
             f"Different number of generation steps for prompt index {i}: "
-            f"{len(logprobs_bs1)} (BS=1) vs {len(logprobs_bsN)} (BS=N)")
+            f"{len(logprobs_bs1)} (BS=1) vs {len(logprobs_bs2)} (BS=2)")
-        for t, (a, b) in enumerate(zip(logprobs_bs1, logprobs_bsN)):
+        for t, (a, b) in enumerate(zip(logprobs_bs1, logprobs_bs2)):
             assert a.shape == b.shape, (
                 f"Logits shape mismatch at prompt {i}, step {t}: "
                 f"{a.shape} vs {b.shape}")

View File

@@ -8,12 +8,8 @@ from typing import Any, Union
 import torch
-import vllm.envs as envs
-from vllm.logger import init_logger
 from vllm.triton_utils import tl, triton
-logger = init_logger(__name__)
 def _matmul_launch_metadata(grid: Callable[..., Any], kernel: Any,
                             args: dict[str, Any]) -> dict[str, Any]:
@@ -561,12 +557,5 @@ def vllm_kernel_override_batch_invariant():
 def init_batch_invariance():
     # this will hit all the csrc overrides as well
     if vllm_kernel_override_batch_invariant():
-        curr_attn_backend = envs.VLLM_ATTENTION_BACKEND
-        supported_backends = ["FLEX_ATTENTION", "FLASHINFER"]
-        if curr_attn_backend not in supported_backends:
-            warning = "Forcibly updating attention backend to" \
-                f" {supported_backends[0]} for batch_invariant. " \
-                f" Supported backends: {supported_backends}."
-            logger.warning_once(warning)
-            os.environ["VLLM_ATTENTION_BACKEND"] = supported_backends[0]
+        os.environ["VLLM_ATTENTION_BACKEND"] = "FLEX_ATTENTION"
         enable_batch_invariant_mode()
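
For context on what is reverted here: the removed branch validated the configured attention backend against a small allow-list and only forced the fallback when the configured backend could not run batch-invariant, instead of unconditionally pinning FLEX_ATTENTION. A standalone sketch of that pattern, with warnings.warn standing in for vLLM's logger.warning_once:

    import os
    import warnings

    SUPPORTED_BACKENDS = ["FLEX_ATTENTION", "FLASHINFER"]

    def force_batch_invariant_backend() -> str:
        # Only override when the configured backend is not in the allow-list.
        current = os.environ.get("VLLM_ATTENTION_BACKEND")
        if current not in SUPPORTED_BACKENDS:
            warnings.warn(
                f"Forcibly updating attention backend to {SUPPORTED_BACKENDS[0]} "
                f"for batch_invariant. Supported backends: {SUPPORTED_BACKENDS}.")
            os.environ["VLLM_ATTENTION_BACKEND"] = SUPPORTED_BACKENDS[0]
        return os.environ["VLLM_ATTENTION_BACKEND"]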

View File

@@ -20,8 +20,6 @@ from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
                                               AttentionType)
 from vllm.config import CUDAGraphMode, VllmConfig
 from vllm.logger import init_logger
-from vllm.model_executor.layers.batch_invariant import (
-    vllm_kernel_override_batch_invariant)
 from vllm.model_executor.layers.quantization.utils.quant_utils import (
     QuantKey, kFp8StaticTensorSym, kNvfp4Quant)
 from vllm.platforms import current_platform
@@ -44,7 +42,6 @@ from vllm.v1.attention.backends.utils import (AttentionCGSupport,
 from vllm.v1.kv_cache_interface import AttentionSpec
 FLASHINFER_WORKSPACE_BUFFER_SIZE = 256 * 1024 * 1024
-FLASHINFER_WORKSPACE_BUFFER_SIZE_BATCH_INVARIANT = 2048 * 1024 * 1024
 FP8_DTYPE = current_platform.fp8_dtype()
 FP4_DTYPE = torch.uint8
@@ -266,15 +263,6 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
         self._prefill_wrapper = None  # Wrapper for prefill/append
         self._decode_wrapper = None  # Wrapper for decode (general shape)
-        if vllm_kernel_override_batch_invariant():
-            self.decode_fixed_split_size = 2048
-            self.prefill_fixed_split_size = 4096
-            self.disable_split_kv = True
-        else:
-            self.decode_fixed_split_size = -1
-            self.prefill_fixed_split_size = -1
-            self.disable_split_kv = False
         self.compilation_config = vllm_config.compilation_config
         max_num_pages_per_req = cdiv(self.model_config.max_model_len,
                                      self.kv_cache_spec.block_size)
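
The fixed split sizes and disable_split_kv flag removed above target floating-point reduction order: when a split-KV attention kernel chooses its splits based on batch composition, the same query can be reduced in a different order and drift in the last bits. A tiny, FlashInfer-independent illustration of that underlying effect:

    import torch

    torch.manual_seed(0)
    x = torch.randn(4096, dtype=torch.float32)

    # Same values, two different reduction orders.
    whole = x.sum()
    chunked = torch.stack([chunk.sum() for chunk in x.chunk(8)]).sum()

    print(whole.item(), chunked.item())
    print(torch.equal(whole, chunked))  # often False: bitwise-different results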
@@ -368,12 +356,10 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
     def _get_workspace_buffer(self):
         if self._workspace_buffer is None:
-            buffer_size = FLASHINFER_WORKSPACE_BUFFER_SIZE
-            if vllm_kernel_override_batch_invariant():
-                buffer_size = FLASHINFER_WORKSPACE_BUFFER_SIZE_BATCH_INVARIANT
-            self._workspace_buffer = torch.zeros(buffer_size,
-                                                 dtype=torch.uint8,
-                                                 device=self.device)
+            self._workspace_buffer = torch.zeros(
+                FLASHINFER_WORKSPACE_BUFFER_SIZE,
+                dtype=torch.uint8,
+                device=self.device)
         return self._workspace_buffer
     def _get_prefill_wrapper(self):
@@ -629,8 +615,6 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
                     logits_soft_cap=self.logits_soft_cap,
                     q_data_type=self.q_data_type,
                     kv_data_type=self.kv_cache_dtype,
-                    fixed_split_size=self.prefill_fixed_split_size,
-                    disable_split_kv=self.disable_split_kv,
                 )
             else:
                 attn_metadata.qo_indptr_gpu = qo_indptr_cpu.to(
@@ -684,8 +668,6 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
                 logits_soft_cap=self.logits_soft_cap,
                 q_data_type=self.q_data_type,
                 kv_data_type=self.kv_cache_dtype,
-                fixed_split_size=self.decode_fixed_split_size,
-                disable_split_kv=self.disable_split_kv,
             )
         return attn_metadata
@@ -1066,8 +1048,6 @@ def fast_plan_decode(
     rope_scale: Optional[float] = None,
     rope_theta: Optional[float] = None,
     non_blocking: bool = True,
-    fixed_split_size: int = -1,
-    disable_split_kv: bool = False,
 ) -> None:
     """
     A faster version of BatchDecodeWithPagedKVCacheWrapper::plan used for
@@ -1105,10 +1085,6 @@ def fast_plan_decode(
             rope_scale,
             rope_theta,
             non_blocking,
-            None,  # block_tables
-            None,  # seq_lens
-            fixed_split_size,
-            disable_split_kv,
         )
         self.vllm_first_call = False
         return
@@ -1154,7 +1130,7 @@ def fast_plan_decode(
     qo_indptr_host = _get_range_buf(batch_size + 1, "cpu")
     try:
-        # Make sure we pass exactly 18 arguments for tensor core version
+        # Make sure we pass exactly 15 arguments for tensor core version
         self._plan_info = self._cached_module.plan(
             self._float_workspace_buffer,
             self._int_workspace_buffer,
@@ -1171,9 +1147,6 @@ def fast_plan_decode(
             head_dim,
             head_dim,
             False,  # causal
-            window_left,
-            fixed_split_size,
-            disable_split_kv,
         )
     except Exception as e:
         raise RuntimeError(f"Error in tensor core plan: {e}") from e