diff --git a/vllm/model_executor/layers/batch_invariant.py b/vllm/model_executor/layers/batch_invariant.py index fde0826779eb1..1058270889b29 100644 --- a/vllm/model_executor/layers/batch_invariant.py +++ b/vllm/model_executor/layers/batch_invariant.py @@ -933,30 +933,26 @@ def enable_batch_invariant_mode(): _batch_invariant_MODE = True _batch_invariant_LIB = torch.library.Library("aten", "IMPL") - # Batch invariant matmuls are no longer needed after cublas overrides - if not is_torch_equal_or_newer("2.10.0.dev"): - if ( - current_platform.is_device_capability_family(100) - or current_platform.is_device_capability(80) - or current_platform.is_device_capability(89) - ): - # For PyTorch 2.9, B200 uses GEMV for bs=1 - # Requires https://github.com/pytorch/pytorch/pull/166735 - _batch_invariant_LIB.impl("aten::mm", mm_batch_invariant, "CUDA") - _batch_invariant_LIB.impl("aten::addmm", addmm_batch_invariant, "CUDA") - _batch_invariant_LIB.impl("aten::matmul", matmul_batch_invariant, "CUDA") - _batch_invariant_LIB.impl("aten::linear", linear_batch_invariant, "CUDA") - else: - # Only source of batch invariance for Hopper is split-k, can disable through - # cuBLAS workspace config - _original_cublas_workspace_cfg = os.environ.get( - "CUBLAS_WORKSPACE_CONFIG", None - ) - _original_cublaslt_workspace_size = os.environ.get( - "CUBLASLT_WORKSPACE_SIZE", None - ) - os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":16:8" - os.environ["CUBLASLT_WORKSPACE_SIZE"] = "1" + if ( + current_platform.is_device_capability_family(100) + or current_platform.is_device_capability(80) + or current_platform.is_device_capability(89) + ): + # For PyTorch 2.9, B200 uses GEMV for bs=1 + # Requires https://github.com/pytorch/pytorch/pull/166735 + _batch_invariant_LIB.impl("aten::mm", mm_batch_invariant, "CUDA") + _batch_invariant_LIB.impl("aten::addmm", addmm_batch_invariant, "CUDA") + _batch_invariant_LIB.impl("aten::matmul", matmul_batch_invariant, "CUDA") + _batch_invariant_LIB.impl("aten::linear", linear_batch_invariant, "CUDA") + else: + # Only source of batch invariance for Hopper is split-k, can disable through + # cuBLAS workspace config + _original_cublas_workspace_cfg = os.environ.get("CUBLAS_WORKSPACE_CONFIG", None) + _original_cublaslt_workspace_size = os.environ.get( + "CUBLASLT_WORKSPACE_SIZE", None + ) + os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":16:8" + os.environ["CUBLASLT_WORKSPACE_SIZE"] = "1" _batch_invariant_LIB.impl( "aten::_log_softmax", _log_softmax_batch_invariant, "CUDA"