[Feature] Extend batch invariant torch.compile to B200 (#27856)

Signed-off-by: PaulZhang12 <paulzhan@fb.com>
Paul Zhang 2025-11-05 13:04:49 -05:00
parent 40db194446
commit faedbb4d4f
3 changed files with 30 additions and 17 deletions


@@ -456,7 +456,6 @@ def test_simple_generation(backend, monkeypatch: pytest.MonkeyPatch):
         model=model,
         max_num_seqs=1,
         tensor_parallel_size=int(os.getenv("VLLM_TP_SIZE", "1")),
-        enforce_eager=True,
         gpu_memory_utilization=0.9,
         max_model_len=2048,
         dtype="bfloat16",
@@ -998,7 +997,6 @@ def LLM_with_max_seqs(
         dtype="bfloat16",
         tensor_parallel_size=int(os.getenv("VLLM_TP_SIZE", "1")),
         enable_prefix_caching=False,
-        enforce_eager=True,
         # Enable for MOE models
         # enable_expert_parallel=True,
     )
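Both fixtures drop `enforce_eager=True`, so these tests now exercise the torch.compile path instead of eager execution. A minimal sketch of the resulting setup, assuming the `VLLM_BATCH_INVARIANT` env toggle used elsewhere in this PR (the model name and prompt are placeholders, not taken from the tests):

    import os

    # Env toggle read by vllm_is_batch_invariant() in this PR
    os.environ["VLLM_BATCH_INVARIANT"] = "1"

    from vllm import LLM, SamplingParams

    llm = LLM(
        model="Qwen/Qwen2.5-1.5B-Instruct",  # placeholder model, not from the tests
        max_num_seqs=1,
        tensor_parallel_size=int(os.getenv("VLLM_TP_SIZE", "1")),
        gpu_memory_utilization=0.9,
        max_model_len=2048,
        dtype="bfloat16",
        # enforce_eager is left at its default (False): the compiled path is
        # exactly what this PR extends to B200.
    )
    out = llm.generate(["Hello"], SamplingParams(temperature=0.0, max_tokens=8))
    print(out[0].outputs[0].text)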


@@ -1,7 +1,6 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 import contextlib
-import functools
 import os
 from collections import namedtuple
 from collections.abc import Callable
@@ -11,6 +10,7 @@ import torch
 import vllm.envs as envs
 from vllm.logger import init_logger
+from vllm.platforms import current_platform
 from vllm.triton_utils import tl, triton
 from vllm.utils.torch_utils import is_torch_equal_or_newer
@@ -737,11 +737,28 @@ def enable_batch_invariant_mode():
     _batch_invariant_MODE = True
     _batch_invariant_LIB = torch.library.Library("aten", "IMPL")
-    _batch_invariant_LIB.impl("aten::mm", mm_batch_invariant, "CUDA")
-    _batch_invariant_LIB.impl("aten::addmm", addmm_batch_invariant, "CUDA")
-    _batch_invariant_LIB.impl("aten::matmul", matmul_batch_invariant, "CUDA")
-    _batch_invariant_LIB.impl("aten::bmm", bmm_batch_invariant, "CUDA")
-    _batch_invariant_LIB.impl("aten::linear", linear_batch_invariant, "CUDA")
+    # Batch-invariant matmuls are no longer needed once the cuBLAS overrides apply
+    if not is_torch_equal_or_newer("2.10.0.dev"):
+        if current_platform.is_device_capability(100):
+            # For PyTorch 2.9, B200 uses GEMV for bs=1
+            # Requires https://github.com/pytorch/pytorch/pull/166735
+            _batch_invariant_LIB.impl("aten::mm", mm_batch_invariant, "CUDA")
+            _batch_invariant_LIB.impl("aten::addmm", addmm_batch_invariant, "CUDA")
+            _batch_invariant_LIB.impl("aten::matmul", matmul_batch_invariant, "CUDA")
+            _batch_invariant_LIB.impl("aten::linear", linear_batch_invariant, "CUDA")
+        else:
+            # The only source of batch variance on Hopper is split-K; it can be
+            # disabled through the cuBLAS workspace config
+            _original_cublas_workspace_cfg = os.environ.get(
+                "CUBLAS_WORKSPACE_CONFIG", None
+            )
+            _original_cublaslt_workspace_size = os.environ.get(
+                "CUBLASLT_WORKSPACE_SIZE", None
+            )
+            os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":16:8"
+            os.environ["CUBLASLT_WORKSPACE_SIZE"] = "1"
     _batch_invariant_LIB.impl(
         "aten::_log_softmax", _log_softmax_batch_invariant, "CUDA"
     )
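The Hopper branch relies on the cuBLAS workspace variables: capping the workspace with `CUBLAS_WORKSPACE_CONFIG=:16:8` steers cuBLAS away from split-K kernels, whose reduction order varies with batch size. A small standalone check of the property being targeted (shapes are illustrative; whether the rows match bitwise depends on the GPU and PyTorch build, and the variables must be set before cuBLAS initializes):

    import os

    # Must be set before the first cuBLAS call to take effect
    os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":16:8"
    os.environ["CUBLASLT_WORKSPACE_SIZE"] = "1"

    import torch

    a = torch.randn(32, 4096, device="cuda", dtype=torch.bfloat16)
    w = torch.randn(4096, 4096, device="cuda", dtype=torch.bfloat16)

    row_alone = (a[:1] @ w)[0]   # row 0 computed as a bs=1 matmul
    row_in_batch = (a @ w)[0]    # same row computed inside a batch of 32

    # With split-K disabled, the reduction order no longer depends on batch
    # size, so these should match bitwise (on Hopper, per this diff's comment).
    print(torch.equal(row_alone, row_in_batch))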
@@ -750,6 +767,7 @@ def enable_batch_invariant_mode():
     _batch_invariant_LIB.impl("aten::mean.dim", mean_batch_invariant, "CUDA")
     # Also monkeypatch torch.bmm directly as a fallback
+    _batch_invariant_LIB.impl("aten::bmm", bmm_batch_invariant, "CUDA")
     _original_torch_bmm = torch.bmm
     torch.bmm = bmm_batch_invariant
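Registering `aten::bmm` covers calls routed through the dispatcher, while rebinding `torch.bmm` catches direct Python-level references; the original is saved so it can be restored on disable. A sketch of that save/patch/restore pattern with illustrative names (not the module's actual globals):

    import torch

    _original_torch_bmm = None

    def bmm_stand_in(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
        # Stand-in for bmm_batch_invariant; just forwards to the saved original
        return _original_torch_bmm(a, b)

    def patch_bmm() -> None:
        global _original_torch_bmm
        _original_torch_bmm = torch.bmm      # save the original before rebinding
        torch.bmm = bmm_stand_in

    def unpatch_bmm() -> None:
        global _original_torch_bmm
        if _original_torch_bmm is not None:
            torch.bmm = _original_torch_bmm  # restore the saved binding
            _original_torch_bmm = None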
@@ -771,14 +789,6 @@ def enable_batch_invariant_mode():
     )
     torch.backends.cuda.preferred_blas_library(backend="cublaslt")
-    if not is_torch_equal_or_newer("2.10.0.dev"):
-        _original_cublas_workspace_cfg = os.environ.get("CUBLAS_WORKSPACE_CONFIG", None)
-        _original_cublaslt_workspace_size = os.environ.get(
-            "CUBLASLT_WORKSPACE_SIZE", None
-        )
-        os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":16:8"
-        os.environ["CUBLASLT_WORKSPACE_SIZE"] = "1"

 def disable_batch_invariant_mode():
     global _batch_invariant_MODE, _batch_invariant_LIB, _original_torch_bmm
@@ -847,7 +857,6 @@ def get_batch_invariant_attention_block_size() -> AttentionBlockSize:
     return AttentionBlockSize(block_m=16, block_n=16)

-@functools.cache
 def vllm_is_batch_invariant():
     env_key = "VLLM_BATCH_INVARIANT"
     is_overridden = False
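Dropping `@functools.cache` means `vllm_is_batch_invariant()` re-reads the environment on every call; with the cache, toggling `VLLM_BATCH_INVARIANT` after the first call had no effect. An illustrative reproduction of the pitfall (simplified; not the module's actual implementation):

    import functools
    import os

    @functools.cache
    def cached_flag() -> bool:
        return os.getenv("VLLM_BATCH_INVARIANT", "0") == "1"

    def uncached_flag() -> bool:
        return os.getenv("VLLM_BATCH_INVARIANT", "0") == "1"

    os.environ["VLLM_BATCH_INVARIANT"] = "0"
    print(cached_flag(), uncached_flag())  # False False

    os.environ["VLLM_BATCH_INVARIANT"] = "1"
    print(cached_flag(), uncached_flag())  # False True <- cache ignores the change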


@@ -19,6 +19,9 @@ import torch
 import vllm.envs as envs
 from vllm.logger import init_logger
+from vllm.model_executor.layers.batch_invariant import (
+    vllm_is_batch_invariant,
+)
 from vllm.platforms import current_platform

 logger = init_logger(__name__)
@@ -222,6 +225,9 @@ def force_use_trtllm_attention() -> bool | None:
     return `True` if TRTLLM attention is forced to be used,
     return `False` if TRTLLM attention is forced to be not used.
     """
+    if vllm_is_batch_invariant():
+        logger.info_once("VLLM_USE_TRTLLM_ATTENTION is disabled for batch-invariant")
+        return False
     return _force_use_trtllm_attention(envs.VLLM_USE_TRTLLM_ATTENTION)
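The added guard takes precedence over any explicit `VLLM_USE_TRTLLM_ATTENTION` setting: batch-invariant mode always disables TRTLLM attention. An illustrative sketch of the resulting decision order (simplified; `None` stands for the usual auto-detection fallback):

    def force_use_trtllm_attention_sketch(
        batch_invariant: bool, env_value: bool | None
    ) -> bool | None:
        # Batch-invariant mode wins over any explicit env override
        if batch_invariant:
            return False
        return env_value  # None means "decide automatically later"

    assert force_use_trtllm_attention_sketch(True, True) is False
    assert force_use_trtllm_attention_sketch(False, True) is True
    assert force_use_trtllm_attention_sketch(False, None) is None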