[Feature] Batch invariant torch.compile (#27660)

Signed-off-by: PaulZhang12 <paulzhan@fb.com>
Co-authored-by: Wentao Ye <44945378+yewentao256@users.noreply.github.com>
Paul Zhang committed 2025-10-30 16:11:29 -04:00 (via GitHub)
commit e7acb20076 (parent 4b68c4a55b)
4 changed files with 82 additions and 9 deletions
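
For context: "batch invariant" here means that a request's numerical output must not depend on how many other requests happen to be batched with it. The snippet below is an illustrative sketch of that property in plain PyTorch (it requires a CUDA device and is not part of this commit):

import torch

torch.manual_seed(0)
w = torch.randn(64, 64, device="cuda", dtype=torch.bfloat16)
x = torch.randn(8, 64, device="cuda", dtype=torch.bfloat16)

alone = x[:1] @ w          # the request processed by itself
in_batch = (x @ w)[:1]     # the same request processed inside a batch of 8

# Batch invariance requires these two results to be bitwise identical; with
# reduced-precision reductions or batch-dependent GEMM splits they can drift
# in the low bits, which is what this mode is designed to prevent.
print(torch.equal(alone, in_batch))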


@@ -20,9 +20,6 @@ from vllm.config.pooler import PoolerConfig
 from vllm.config.scheduler import RunnerType
 from vllm.config.utils import assert_hashable, config, getattr_iter
 from vllm.logger import init_logger
-from vllm.model_executor.layers.batch_invariant import (
-    vllm_is_batch_invariant,
-)
 from vllm.platforms import current_platform
 from vllm.transformers_utils.config import (
     ConfigFormat,
@@ -436,10 +433,6 @@ class ModelConfig:
         skip_mm_profiling: bool | None,
         video_pruning_rate: float | None,
     ) -> None:
-        # Enable batch invariance settings if requested
-        if vllm_is_batch_invariant():
-            self.enforce_eager = True
-
         # Set the default seed to 0 in V1.
         # NOTE(woosuk): In V1, we use separate processes for workers (unless
         # VLLM_ENABLE_V1_MULTIPROCESSING=0), so setting a seed here
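
The two hunks above are the user-facing half of the change: enabling batch-invariant mode no longer forces eager execution, so torch.compile stays available. A minimal sketch of the before/after behavior (the helper function name is mine; vllm_is_batch_invariant and enforce_eager come from the diff):

from vllm.model_executor.layers.batch_invariant import vllm_is_batch_invariant

def apply_batch_invariant_defaults(model_config) -> None:
    # Pre-#27660 (removed above): batch invariance forced eager execution.
    # if vllm_is_batch_invariant():
    #     model_config.enforce_eager = True
    # Post-#27660: enforce_eager is left untouched, so the model can still
    # be compiled while running in batch-invariant mode.
    pass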


@@ -251,6 +251,9 @@ def disable_compile_cache() -> bool:
 def use_aot_compile() -> bool:
+    from vllm.model_executor.layers.batch_invariant import (
+        vllm_is_batch_invariant,
+    )
     from vllm.utils.torch_utils import is_torch_equal_or_newer
 
     default_value = (
@@ -259,7 +262,10 @@
         else "0"
     )
-    return os.environ.get("VLLM_USE_AOT_COMPILE", default_value) == "1"
+    return (
+        not vllm_is_batch_invariant()
+        and os.environ.get("VLLM_USE_AOT_COMPILE", default_value) == "1"
+    )
 
 
 def env_with_choices(
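
For readers who only need the resulting behavior, here is a self-contained rendering of the new gate (the _sketch helper is mine; the VLLM_USE_AOT_COMPILE variable and the batch-invariance check are from the diff):

import os

def use_aot_compile_sketch(batch_invariant: bool, default_value: str = "0") -> bool:
    # AOT compile is now forced off whenever batch-invariant mode is active,
    # regardless of how VLLM_USE_AOT_COMPILE is set.
    return (
        not batch_invariant
        and os.environ.get("VLLM_USE_AOT_COMPILE", default_value) == "1"
    )

assert use_aot_compile_sketch(batch_invariant=True) is False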


@@ -11,6 +11,7 @@ import torch
 import vllm.envs as envs
 from vllm.logger import init_logger
 from vllm.triton_utils import tl, triton
+from vllm.utils.torch_utils import is_torch_equal_or_newer
 
 logger = init_logger(__name__)
@@ -716,6 +717,10 @@ def linear_batch_invariant(input, weight, bias=None):
 _batch_invariant_MODE = False
 _batch_invariant_LIB = None
 _original_torch_bmm = None
+_original_fp16_reduction_precision = None
+_original_bf16_reduction_precision = None
+_original_cublas_workspace_cfg = None
+_original_cublaslt_workspace_size = None
 
 
 def is_batch_invariant_mode_enabled():
@@ -724,6 +729,8 @@
 def enable_batch_invariant_mode():
     global _batch_invariant_MODE, _batch_invariant_LIB, _original_torch_bmm
+    global _original_fp16_reduction_precision, _original_bf16_reduction_precision
+    global _original_cublas_workspace_cfg, _original_cublaslt_workspace_size
     if _batch_invariant_MODE:
         return
@@ -745,14 +752,75 @@
     _original_torch_bmm = torch.bmm
     torch.bmm = bmm_batch_invariant
 
+    _original_bf16_reduction_precision = (
+        torch.backends.cuda.matmul.allow_bf16_reduced_precision_reduction
+    )
+    _original_fp16_reduction_precision = (
+        torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction
+    )
+    reduced_precision_val = (
+        (False, False) if is_torch_equal_or_newer("2.10.0.dev") else False
+    )
+    torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = (
+        reduced_precision_val
+    )
+    torch.backends.cuda.matmul.allow_bf16_reduced_precision_reduction = (
+        reduced_precision_val
+    )
+
+    torch.backends.cuda.preferred_blas_library(backend="cublaslt")
+
+    if not is_torch_equal_or_newer("2.10.0.dev"):
+        _original_cublas_workspace_cfg = os.environ.get("CUBLAS_WORKSPACE_CONFIG", None)
+        _original_cublaslt_workspace_size = os.environ.get(
+            "CUBLASLT_WORKSPACE_SIZE", None
+        )
+        os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":16:8"
+        os.environ["CUBLASLT_WORKSPACE_SIZE"] = "1"
+
 
 def disable_batch_invariant_mode():
     global _batch_invariant_MODE, _batch_invariant_LIB, _original_torch_bmm
+    global _original_fp16_reduction_precision, _original_bf16_reduction_precision
+    global _original_cublas_workspace_cfg, _original_cublaslt_workspace_size
     if not _batch_invariant_MODE:
         return
 
     if _batch_invariant_LIB is not None:
         _batch_invariant_LIB._destroy()
     if _original_torch_bmm is not None:
         torch.bmm = _original_torch_bmm
         _original_torch_bmm = None
+
+    if _original_bf16_reduction_precision is not None:
+        torch.backends.cuda.matmul.allow_bf16_reduced_precision_reduction = (
+            _original_bf16_reduction_precision
+        )
+        _original_bf16_reduction_precision = None
+    if _original_fp16_reduction_precision is not None:
+        torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = (
+            _original_fp16_reduction_precision
+        )
+        _original_fp16_reduction_precision = None
+
+    torch.backends.cuda.preferred_blas_library(backend="default")
+
+    if not is_torch_equal_or_newer("2.10.0.dev"):
+        # Restore the cublas env vars to their previous values. If a previous
+        # value is None, the env var was not set before, so remove it instead.
+        if _original_cublas_workspace_cfg:
+            os.environ["CUBLAS_WORKSPACE_CONFIG"] = _original_cublas_workspace_cfg
+        elif "CUBLAS_WORKSPACE_CONFIG" in os.environ:
+            del os.environ["CUBLAS_WORKSPACE_CONFIG"]
+        if _original_cublaslt_workspace_size:
+            os.environ["CUBLASLT_WORKSPACE_SIZE"] = _original_cublaslt_workspace_size
+        elif "CUBLASLT_WORKSPACE_SIZE" in os.environ:
+            del os.environ["CUBLASLT_WORKSPACE_SIZE"]
+        _original_cublas_workspace_cfg = None
+        _original_cublaslt_workspace_size = None
+
     _batch_invariant_MODE = False
     _batch_invariant_LIB = None
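
Taken together, the enable path pins every knob that can change the reduction order of a GEMM: full-precision accumulation, a fixed BLAS backend, and a fixed cuBLAS workspace. A standalone sketch of the same knobs outside vLLM (the helper name is mine; the flag and env var values mirror the diff, and on PyTorch 2.10+ the reduction flags take the tuple form shown above rather than a plain bool):

import os
import torch

def pin_matmul_determinism() -> None:
    # Force full-precision accumulation in fp16/bf16 matmuls.
    torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = False
    torch.backends.cuda.matmul.allow_bf16_reduced_precision_reduction = False
    # Route GEMMs through cuBLASLt and pin its workspace so algorithm
    # selection stays stable from call to call.
    torch.backends.cuda.preferred_blas_library(backend="cublaslt")
    os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":16:8"
    os.environ["CUBLASLT_WORKSPACE_SIZE"] = "1"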
@@ -831,6 +899,9 @@ def override_envs_for_invariance():
     os.environ["NCCL_NTHREADS"] = "1"
     os.environ["NCCL_SOCKET_NTHREADS"] = "1"
 
+    # torch.compile settings
+    os.environ["VLLM_USE_AOT_COMPILE"] = "0"
+
 
 def init_batch_invariance():
     # this will hit all the csrc overrides as well
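
Because enable/disable now juggle module globals, torch backend flags, and process environment variables, it is easy to leave the process half-configured if an exception fires in between. A hypothetical convenience wrapper (not part of this commit) that guarantees the saved state is restored:

from contextlib import contextmanager

from vllm.model_executor.layers.batch_invariant import (
    disable_batch_invariant_mode,
    enable_batch_invariant_mode,
)

@contextmanager
def batch_invariant_mode():
    # Enable on entry; always restore globals, flags, and env vars on exit.
    enable_batch_invariant_mode()
    try:
        yield
    finally:
        disable_batch_invariant_mode()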


@@ -363,6 +363,7 @@ class Fp8LinearMethod(LinearMethodBase):
         self.use_marlin = False
         self.use_aiter_and_is_supported = check_aiter_fp8_linear_support()
+        self.use_deep_gemm = is_deep_gemm_supported()
 
         self.weight_block_size = self.quant_config.weight_block_size
         self.block_quant = self.weight_block_size is not None
@@ -545,8 +546,10 @@
         # if batch invariant mode is enabled, prefer DeepGEMM FP8 path
         # we will use BF16 dequant when DeepGEMM is not supported.
         if vllm_is_batch_invariant():
+            # Call is_deep_gemm_supported() ahead of time for torch.compile;
+            # dynamo has trouble tracing through it.
             if self.block_quant and should_use_deepgemm_for_fp8_linear(
-                torch.bfloat16, layer.weight, None
+                torch.bfloat16, layer.weight, self.use_deep_gemm
             ):
                 # use group quant consistent with block size across K
                 assert self.act_q_group_shape is not None
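
The fp8 change illustrates a general pattern for torch.compile friendliness: environment probes that dynamo struggles to trace (here is_deep_gemm_supported()) are evaluated once at construction time, and the compiled path only branches on the cached attribute. A generic sketch of that hoisting pattern (class, method, and probe names are mine, not vLLM's):

import torch

def _probe_accelerator_feature() -> bool:
    # Stand-in for is_deep_gemm_supported(): any hardware/environment check
    # that should never be traced inside a compiled region.
    return torch.cuda.is_available()

class Fp8LinearSketch:
    def __init__(self) -> None:
        # Evaluated eagerly, outside any compiled code.
        self.use_deep_gemm = _probe_accelerator_feature()

    def forward(self, x: torch.Tensor, w: torch.Tensor) -> torch.Tensor:
        # torch.compile sees only a constant boolean attribute here,
        # not the probe itself.
        if self.use_deep_gemm:
            return x @ w.t()
        return (x.float() @ w.t().float()).to(x.dtype)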