diff --git a/vllm/config/model.py b/vllm/config/model.py
index 092c67e7bed8..082f90653f5a 100644
--- a/vllm/config/model.py
+++ b/vllm/config/model.py
@@ -20,9 +20,6 @@ from vllm.config.pooler import PoolerConfig
 from vllm.config.scheduler import RunnerType
 from vllm.config.utils import assert_hashable, config, getattr_iter
 from vllm.logger import init_logger
-from vllm.model_executor.layers.batch_invariant import (
-    vllm_is_batch_invariant,
-)
 from vllm.platforms import current_platform
 from vllm.transformers_utils.config import (
     ConfigFormat,
@@ -436,10 +433,6 @@ class ModelConfig:
         skip_mm_profiling: bool | None,
         video_pruning_rate: float | None,
     ) -> None:
-        # Enable batch invariance settings if requested
-        if vllm_is_batch_invariant():
-            self.enforce_eager = True
-
         # Set the default seed to 0 in V1.
         # NOTE(woosuk): In V1, we use separate processes for workers (unless
         # VLLM_ENABLE_V1_MULTIPROCESSING=0), so setting a seed here
diff --git a/vllm/envs.py b/vllm/envs.py
index 2744335ed3d3..21237c70a45e 100755
--- a/vllm/envs.py
+++ b/vllm/envs.py
@@ -251,6 +251,9 @@ def disable_compile_cache() -> bool:
 
 
 def use_aot_compile() -> bool:
+    from vllm.model_executor.layers.batch_invariant import (
+        vllm_is_batch_invariant,
+    )
     from vllm.utils.torch_utils import is_torch_equal_or_newer
 
     default_value = (
@@ -259,7 +262,10 @@ def use_aot_compile() -> bool:
         else "0"
     )
 
-    return os.environ.get("VLLM_USE_AOT_COMPILE", default_value) == "1"
+    return (
+        not vllm_is_batch_invariant()
+        and os.environ.get("VLLM_USE_AOT_COMPILE", default_value) == "1"
+    )
 
 
 def env_with_choices(
diff --git a/vllm/model_executor/layers/batch_invariant.py b/vllm/model_executor/layers/batch_invariant.py
index 5706786bccb1..39e77b935d3d 100644
--- a/vllm/model_executor/layers/batch_invariant.py
+++ b/vllm/model_executor/layers/batch_invariant.py
@@ -11,6 +11,7 @@ import torch
 import vllm.envs as envs
 from vllm.logger import init_logger
 from vllm.triton_utils import tl, triton
+from vllm.utils.torch_utils import is_torch_equal_or_newer
 
 logger = init_logger(__name__)
 
@@ -716,6 +717,10 @@ def linear_batch_invariant(input, weight, bias=None):
 _batch_invariant_MODE = False
 _batch_invariant_LIB = None
 _original_torch_bmm = None
+_original_fp16_reduction_precision = None
+_original_bf16_reduction_precision = None
+_original_cublas_workspace_cfg = None
+_original_cublaslt_workspace_size = None
 
 
 def is_batch_invariant_mode_enabled():
@@ -724,6 +729,8 @@ def is_batch_invariant_mode_enabled():
 
 def enable_batch_invariant_mode():
     global _batch_invariant_MODE, _batch_invariant_LIB, _original_torch_bmm
+    global _original_fp16_reduction_precision, _original_bf16_reduction_precision
+    global _original_cublas_workspace_cfg, _original_cublaslt_workspace_size
     if _batch_invariant_MODE:
         return
 
@@ -745,14 +752,75 @@
     _original_torch_bmm = torch.bmm
     torch.bmm = bmm_batch_invariant
 
+    _original_bf16_reduction_precision = (
+        torch.backends.cuda.matmul.allow_bf16_reduced_precision_reduction
+    )
+    _original_fp16_reduction_precision = (
+        torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction
+    )
+
+    reduced_precision_val = (
+        (False, False) if is_torch_equal_or_newer("2.10.0.dev") else False
+    )
+    torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = (
+        reduced_precision_val
+    )
+    torch.backends.cuda.matmul.allow_bf16_reduced_precision_reduction = (
+        reduced_precision_val
+    )
+    torch.backends.cuda.preferred_blas_library(backend="cublaslt")
+
+    if not is_torch_equal_or_newer("2.10.0.dev"):
+        _original_cublas_workspace_cfg = os.environ.get("CUBLAS_WORKSPACE_CONFIG", None)
+        _original_cublaslt_workspace_size = os.environ.get(
+            "CUBLASLT_WORKSPACE_SIZE", None
+        )
+        os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":16:8"
+        os.environ["CUBLASLT_WORKSPACE_SIZE"] = "1"
+
 
 def disable_batch_invariant_mode():
     global _batch_invariant_MODE, _batch_invariant_LIB, _original_torch_bmm
+    global _original_fp16_reduction_precision, _original_bf16_reduction_precision
+    global _original_cublas_workspace_cfg, _original_cublaslt_workspace_size
+    if not _batch_invariant_MODE:
+        return
+
     if _batch_invariant_LIB is not None:
         _batch_invariant_LIB._destroy()
     if _original_torch_bmm is not None:
         torch.bmm = _original_torch_bmm
         _original_torch_bmm = None
+
+    if _original_bf16_reduction_precision is not None:
+        torch.backends.cuda.matmul.allow_bf16_reduced_precision_reduction = (
+            _original_bf16_reduction_precision
+        )
+        _original_bf16_reduction_precision = None
+    if _original_fp16_reduction_precision is not None:
+        torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = (
+            _original_fp16_reduction_precision
+        )
+        _original_fp16_reduction_precision = None
+
+    torch.backends.cuda.preferred_blas_library(backend="default")
+
+    if not is_torch_equal_or_newer("2.10.0.dev"):
+        # Set cublas env vars to previous results. If previous results are None,
+        # that means the env vars were not set, so we should remove them.
+        if _original_cublas_workspace_cfg:
+            os.environ["CUBLAS_WORKSPACE_CONFIG"] = _original_cublas_workspace_cfg
+        elif "CUBLAS_WORKSPACE_CONFIG" in os.environ:
+            del os.environ["CUBLAS_WORKSPACE_CONFIG"]
+
+        if _original_cublaslt_workspace_size:
+            os.environ["CUBLASLT_WORKSPACE_SIZE"] = _original_cublaslt_workspace_size
+        elif "CUBLASLT_WORKSPACE_SIZE" in os.environ:
+            del os.environ["CUBLASLT_WORKSPACE_SIZE"]
+
+        _original_cublas_workspace_cfg = None
+        _original_cublaslt_workspace_size = None
+
     _batch_invariant_MODE = False
     _batch_invariant_LIB = None
 
@@ -831,6 +899,9 @@ def override_envs_for_invariance():
     os.environ["NCCL_NTHREADS"] = "1"
     os.environ["NCCL_SOCKET_NTHREADS"] = "1"
 
+    # torch.compile settings
+    os.environ["VLLM_USE_AOT_COMPILE"] = "0"
+
 
 def init_batch_invariance():
     # this will hit all the csrc overrides as well
diff --git a/vllm/model_executor/layers/quantization/fp8.py b/vllm/model_executor/layers/quantization/fp8.py
index e5681cb85625..f82eccb88ce0 100644
--- a/vllm/model_executor/layers/quantization/fp8.py
+++ b/vllm/model_executor/layers/quantization/fp8.py
@@ -363,6 +363,7 @@ class Fp8LinearMethod(LinearMethodBase):
         self.use_marlin = False
 
         self.use_aiter_and_is_supported = check_aiter_fp8_linear_support()
+        self.use_deep_gemm = is_deep_gemm_supported()
 
         self.weight_block_size = self.quant_config.weight_block_size
         self.block_quant = self.weight_block_size is not None
@@ -545,8 +546,10 @@ class Fp8LinearMethod(LinearMethodBase):
         # if batch invariant mode is enabled, prefer DeepGEMM FP8 path
         # we will use BF16 dequant when DeepGEMM is not supported.
         if vllm_is_batch_invariant():
+            # Call is_deep_gemm_supported() ahead of time for torch.compile
+            # dynamo has trouble tracing through
             if self.block_quant and should_use_deepgemm_for_fp8_linear(
-                torch.bfloat16, layer.weight, None
+                torch.bfloat16, layer.weight, self.use_deep_gemm
             ):
                 # use group quant consistent with block size across K
                 assert self.act_q_group_shape is not None