diff --git a/vllm/model_executor/layers/batch_invariant.py b/vllm/model_executor/layers/batch_invariant.py
index 746a543ab827d..7920d117de5e0 100644
--- a/vllm/model_executor/layers/batch_invariant.py
+++ b/vllm/model_executor/layers/batch_invariant.py
@@ -1,8 +1,6 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
-import contextlib
 import os
-from collections import namedtuple
 from collections.abc import Callable
 from functools import cache
 from typing import Any
@@ -725,10 +723,6 @@ _original_cublas_workspace_cfg = None
 _original_cublaslt_workspace_size = None
 
 
-def is_batch_invariant_mode_enabled():
-    return _batch_invariant_MODE
-
-
 def enable_batch_invariant_mode():
     global _batch_invariant_MODE, _batch_invariant_LIB, _original_torch_bmm
     global _original_fp16_reduction_precision, _original_bf16_reduction_precision
@@ -791,73 +785,6 @@ def enable_batch_invariant_mode():
     torch.backends.cuda.preferred_blas_library(backend="cublaslt")
 
 
-def disable_batch_invariant_mode():
-    global _batch_invariant_MODE, _batch_invariant_LIB, _original_torch_bmm
-    global _original_fp16_reduction_precision, _original_bf16_reduction_precision
-    global _original_cublas_workspace_cfg, _original_cublaslt_workspace_size
-    if not _batch_invariant_MODE:
-        return
-
-    if _batch_invariant_LIB is not None:
-        _batch_invariant_LIB._destroy()
-    if _original_torch_bmm is not None:
-        torch.bmm = _original_torch_bmm
-        _original_torch_bmm = None
-
-    if _original_bf16_reduction_precision is not None:
-        torch.backends.cuda.matmul.allow_bf16_reduced_precision_reduction = (
-            _original_bf16_reduction_precision
-        )
-        _original_bf16_reduction_precision = None
-    if _original_fp16_reduction_precision is not None:
-        torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = (
-            _original_fp16_reduction_precision
-        )
-        _original_fp16_reduction_precision = None
-
-    torch.backends.cuda.preferred_blas_library(backend="default")
-
-    if not is_torch_equal_or_newer("2.10.0.dev"):
-        # Set cublas env vars to previous results. If previous results are None,
-        # that means the env vars were not set, so we should remove them.
-        if _original_cublas_workspace_cfg:
-            os.environ["CUBLAS_WORKSPACE_CONFIG"] = _original_cublas_workspace_cfg
-        elif "CUBLAS_WORKSPACE_CONFIG" in os.environ:
-            del os.environ["CUBLAS_WORKSPACE_CONFIG"]
-
-        if _original_cublaslt_workspace_size:
-            os.environ["CUBLASLT_WORKSPACE_SIZE"] = _original_cublaslt_workspace_size
-        elif "CUBLASLT_WORKSPACE_SIZE" in os.environ:
-            del os.environ["CUBLASLT_WORKSPACE_SIZE"]
-
-        _original_cublas_workspace_cfg = None
-        _original_cublaslt_workspace_size = None
-
-    _batch_invariant_MODE = False
-    _batch_invariant_LIB = None
-
-
-@contextlib.contextmanager
-def set_batch_invariant_mode(enabled: bool = True):
-    global _batch_invariant_MODE, _batch_invariant_LIB
-    old_data = (_batch_invariant_MODE, _batch_invariant_LIB)
-    if enabled:
-        enable_batch_invariant_mode()
-    else:
-        disable_batch_invariant_mode()
-    yield
-    if _batch_invariant_LIB is not None:
-        _batch_invariant_LIB._destroy()
-    _batch_invariant_MODE, _batch_invariant_LIB = old_data
-
-
-AttentionBlockSize = namedtuple("AttentionBlockSize", ["block_m", "block_n"])
-
-
-def get_batch_invariant_attention_block_size() -> AttentionBlockSize:
-    return AttentionBlockSize(block_m=16, block_n=16)
-
-
 @cache
 def vllm_is_batch_invariant():
     env_key = "VLLM_BATCH_INVARIANT"
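
For context, the entry points that survive this cleanup are enable_batch_invariant_mode() and the cached vllm_is_batch_invariant() flag visible in the context lines above. A minimal caller sketch follows; it is not part of the diff, and the exact truthiness parsing of VLLM_BATCH_INVARIANT is an assumption:

    # Hypothetical usage sketch: with disable_batch_invariant_mode() and the
    # set_batch_invariant_mode() context manager removed, batch-invariant mode
    # is opt-in via the environment and assumed to stay enabled for the
    # lifetime of the process.
    import os

    # Opt in before vLLM reads the flag; vllm_is_batch_invariant() is
    # @cache-decorated, so the first read is sticky.
    os.environ["VLLM_BATCH_INVARIANT"] = "1"

    from vllm.model_executor.layers.batch_invariant import (
        enable_batch_invariant_mode,
        vllm_is_batch_invariant,
    )

    if vllm_is_batch_invariant():
        # Patches torch.bmm, the matmul reduced-precision reduction flags,
        # and the preferred BLAS backend, per the enable path above.
        enable_batch_invariant_mode()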