From 699bca76c00b81ba6c7ead38fed01712f5f56aa1 Mon Sep 17 00:00:00 2001 From: Wentao Ye <44945378+yewentao256@users.noreply.github.com> Date: Mon, 24 Nov 2025 19:49:01 -0500 Subject: [PATCH] [UX] Raise error for attn backend of batch invariant (#29348) Signed-off-by: yewentao256 --- vllm/model_executor/layers/batch_invariant.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/vllm/model_executor/layers/batch_invariant.py b/vllm/model_executor/layers/batch_invariant.py index 8b33727f05fbc..be7f673e5618f 100644 --- a/vllm/model_executor/layers/batch_invariant.py +++ b/vllm/model_executor/layers/batch_invariant.py @@ -812,19 +812,19 @@ def override_envs_for_invariance(): # "TRITON_MLA", ] if curr_attn_backend not in supported_backends: - warning = ( - "Forcibly updating attention backend to" - f" {supported_backends[0]} for batch_invariant. " - f" Supported backends: {supported_backends}." + error = ( + "VLLM batch_invariant mode requires an attention backend in " + f"{supported_backends}, but got '{curr_attn_backend}'. " + "Please set the 'VLLM_ATTENTION_BACKEND' environment variable " + "to one of the supported backends before enabling batch_invariant." ) - logger.warning_once(warning) - os.environ["VLLM_ATTENTION_BACKEND"] = supported_backends[0] + raise RuntimeError(error) if os.environ["VLLM_ATTENTION_BACKEND"] != supported_backends[0]: warning = ( "You are using a decode-invariant form of batch invariance. " "This will not be invariant between prefill and decode." ) - logger.warning_once(warning) + logger.warning_once(warning, scope="local") os.environ["VLLM_ALLREDUCE_USE_SYMM_MEM"] = "0" os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"