diff --git a/vllm/model_executor/layers/batch_invariant.py b/vllm/model_executor/layers/batch_invariant.py index 4f31e5afa1ac9..fde0826779eb1 100644 --- a/vllm/model_executor/layers/batch_invariant.py +++ b/vllm/model_executor/layers/batch_invariant.py @@ -6,7 +6,7 @@ from typing import Any import torch -import vllm.envs as envs +from vllm.attention.backends.registry import AttentionBackendEnum from vllm.logger import init_logger from vllm.platforms import current_platform from vllm.triton_utils import tl, triton @@ -1004,27 +1004,30 @@ def vllm_is_batch_invariant() -> bool: return VLLM_BATCH_INVARIANT -def override_envs_for_invariance(): - curr_attn_backend = envs.VLLM_ATTENTION_BACKEND +def override_envs_for_invariance( + attention_backend: AttentionBackendEnum | None, +): supported_backends = [ - "FLASH_ATTN", # best supported backend - "FLASHINFER", - "FLASH_ATTN_MLA", - "TRITON_MLA", + AttentionBackendEnum.FLASH_ATTN, # best supported backend + AttentionBackendEnum.FLASHINFER, + AttentionBackendEnum.FLASH_ATTN_MLA, + AttentionBackendEnum.TRITON_MLA, # Not yet supported MLA backends - # "FLASHMLA", - # "FLEX_ATTENTION", # IMA issue even if we disable batch invariance - # "FLASHINFER_MLA", https://github.com/vllm-project/vllm/pull/28967 + # AttentionBackendEnum.FLASHMLA, + # AttentionBackendEnum.FLEX_ATTENTION, # IMA issue + # AttentionBackendEnum.FLASHINFER_MLA, # PR #28967 ] - if curr_attn_backend not in supported_backends: + if attention_backend not in supported_backends: + supported_names = [b.name for b in supported_backends] + backend_name = attention_backend.name if attention_backend else None error = ( "VLLM batch_invariant mode requires an attention backend in " - f"{supported_backends}, but got '{curr_attn_backend}'. " - "Please set the 'VLLM_ATTENTION_BACKEND' environment variable " - "to one of the supported backends before enabling batch_invariant." + f"{supported_names}, but got '{backend_name}'. " + "Please use --attention-backend or attention_config to set " + "one of the supported backends before enabling batch_invariant." ) raise RuntimeError(error) - if os.environ["VLLM_ATTENTION_BACKEND"] != supported_backends[0]: + if attention_backend != supported_backends[0]: warning = ( "You are using a decode-invariant form of batch invariance. " "This will not be invariant between prefill and decode." @@ -1050,10 +1053,12 @@ def override_envs_for_invariance(): os.environ["VLLM_USE_AOT_COMPILE"] = "0" -def init_batch_invariance(): +def init_batch_invariance( + attention_backend: AttentionBackendEnum | None, +): # this will hit all the csrc overrides as well if vllm_is_batch_invariant(): - override_envs_for_invariance() + override_envs_for_invariance(attention_backend) enable_batch_invariant_mode() # Disable TF32 for batch invariance - it causes non-deterministic rounding diff --git a/vllm/v1/worker/gpu_worker.py b/vllm/v1/worker/gpu_worker.py index 21a8564f83c40..1e13650cd083e 100644 --- a/vllm/v1/worker/gpu_worker.py +++ b/vllm/v1/worker/gpu_worker.py @@ -931,10 +931,11 @@ def init_worker_distributed_environment( backend: str = "nccl", ) -> None: """Initialize the distributed environment.""" + attention_config = vllm_config.attention_config parallel_config = vllm_config.parallel_config from vllm.model_executor.layers.batch_invariant import init_batch_invariance - init_batch_invariance() + init_batch_invariance(attention_config.backend) set_custom_all_reduce(not parallel_config.disable_custom_all_reduce) init_method = distributed_init_method or "env://"