[Feature] Extend batch invariant torch.compile to B200 (#27856)

Signed-off-by: PaulZhang12 <paulzhan@fb.com>
Paul Zhang 2025-11-05 13:04:49 -05:00 committed by GitHub
parent 40db194446
commit faedbb4d4f
3 changed files with 30 additions and 17 deletions


@@ -456,7 +456,6 @@ def test_simple_generation(backend, monkeypatch: pytest.MonkeyPatch):
        model=model,
        max_num_seqs=1,
        tensor_parallel_size=int(os.getenv("VLLM_TP_SIZE", "1")),
        enforce_eager=True,
        gpu_memory_utilization=0.9,
        max_model_len=2048,
        dtype="bfloat16",
@@ -998,7 +997,6 @@ def LLM_with_max_seqs(
        dtype="bfloat16",
        tensor_parallel_size=int(os.getenv("VLLM_TP_SIZE", "1")),
        enable_prefix_caching=False,
        enforce_eager=True,
        # Enable for MOE models
        # enable_expert_parallel=True,
    )
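
With enforce_eager=True removed from both test fixtures, the batch-invariant tests now exercise the torch.compile path as well. A minimal sketch of the same setup outside the test harness, assuming the public LLM API and the VLLM_BATCH_INVARIANT environment variable referenced later in this diff (the model name and prompt are placeholders):

import os

os.environ["VLLM_BATCH_INVARIANT"] = "1"  # enable batch-invariant kernels before building the engine

from vllm import LLM, SamplingParams

llm = LLM(
    model="meta-llama/Llama-3.1-8B-Instruct",  # placeholder model
    max_num_seqs=1,
    gpu_memory_utilization=0.9,
    max_model_len=2048,
    dtype="bfloat16",
    # enforce_eager is intentionally left unset so the model graph is compiled,
    # which this PR extends to B200 under batch-invariant mode.
)
outputs = llm.generate(["Hello"], SamplingParams(temperature=0.0, max_tokens=8))
print(outputs[0].outputs[0].text)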


@@ -1,7 +1,6 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import contextlib
import functools
import os
from collections import namedtuple
from collections.abc import Callable
@@ -11,6 +10,7 @@ import torch
import vllm.envs as envs
from vllm.logger import init_logger
from vllm.platforms import current_platform
from vllm.triton_utils import tl, triton
from vllm.utils.torch_utils import is_torch_equal_or_newer
@@ -737,11 +737,28 @@ def enable_batch_invariant_mode():
    _batch_invariant_MODE = True
    _batch_invariant_LIB = torch.library.Library("aten", "IMPL")
    _batch_invariant_LIB.impl("aten::mm", mm_batch_invariant, "CUDA")
    _batch_invariant_LIB.impl("aten::addmm", addmm_batch_invariant, "CUDA")
    _batch_invariant_LIB.impl("aten::matmul", matmul_batch_invariant, "CUDA")
    _batch_invariant_LIB.impl("aten::bmm", bmm_batch_invariant, "CUDA")
    _batch_invariant_LIB.impl("aten::linear", linear_batch_invariant, "CUDA")
    # Batch-invariant matmuls are no longer needed after the cuBLAS overrides (torch >= 2.10)
    if not is_torch_equal_or_newer("2.10.0.dev"):
        if current_platform.is_device_capability(100):
            # For PyTorch 2.9, B200 uses GEMV for bs=1
            # Requires https://github.com/pytorch/pytorch/pull/166735
            _batch_invariant_LIB.impl("aten::mm", mm_batch_invariant, "CUDA")
            _batch_invariant_LIB.impl("aten::addmm", addmm_batch_invariant, "CUDA")
            _batch_invariant_LIB.impl("aten::matmul", matmul_batch_invariant, "CUDA")
            _batch_invariant_LIB.impl("aten::linear", linear_batch_invariant, "CUDA")
        else:
            # On Hopper, the only source of batch variance is split-K, which can be
            # disabled through the cuBLAS workspace config
            _original_cublas_workspace_cfg = os.environ.get(
                "CUBLAS_WORKSPACE_CONFIG", None
            )
            _original_cublaslt_workspace_size = os.environ.get(
                "CUBLASLT_WORKSPACE_SIZE", None
            )
            os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":16:8"
            os.environ["CUBLASLT_WORKSPACE_SIZE"] = "1"
    _batch_invariant_LIB.impl(
        "aten::_log_softmax", _log_softmax_batch_invariant, "CUDA"
    )
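
The new gate above only matters on PyTorch releases older than 2.10: on B200 it keeps the Triton batch-invariant matmul overrides (per the comment, PyTorch 2.9 falls back to GEMV at batch size 1), while on Hopper it pins the cuBLAS workspace instead. A rough standalone sketch of the same decision written directly against torch, assuming a CUDA device is present (vLLM's current_platform.is_device_capability(100) and is_torch_equal_or_newer wrap equivalent checks):

import torch
from packaging import version

pre_2_10 = version.parse(torch.__version__) < version.parse("2.10.0.dev0")
major, minor = torch.cuda.get_device_capability()

if pre_2_10 and (major, minor) == (10, 0):
    # Blackwell / B200 reports compute capability 10.0: keep the Triton
    # batch-invariant matmul overrides registered above.
    pass
elif pre_2_10:
    # Hopper (9.0): restrict the cuBLAS workspace so split-K kernels,
    # the remaining source of batch variance, are not selected.
    pass
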
@@ -750,6 +767,7 @@ def enable_batch_invariant_mode():
    _batch_invariant_LIB.impl("aten::mean.dim", mean_batch_invariant, "CUDA")
    # Also monkeypatch torch.bmm directly as a fallback
    _batch_invariant_LIB.impl("aten::bmm", bmm_batch_invariant, "CUDA")
    _original_torch_bmm = torch.bmm
    torch.bmm = bmm_batch_invariant
@ -771,14 +789,6 @@ def enable_batch_invariant_mode():
)
torch.backends.cuda.preferred_blas_library(backend="cublaslt")
if not is_torch_equal_or_newer("2.10.0.dev"):
_original_cublas_workspace_cfg = os.environ.get("CUBLAS_WORKSPACE_CONFIG", None)
_original_cublaslt_workspace_size = os.environ.get(
"CUBLASLT_WORKSPACE_SIZE", None
)
os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":16:8"
os.environ["CUBLASLT_WORKSPACE_SIZE"] = "1"
def disable_batch_invariant_mode():
global _batch_invariant_MODE, _batch_invariant_LIB, _original_torch_bmm
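
The unconditional workspace override removed above now runs only on the Hopper branch, and the previous values are saved before being overwritten. A minimal sketch of how a disable path could restore them, assuming the module-level _original_cublas_workspace_cfg / _original_cublaslt_workspace_size variables shown in the diff (the helper name below is hypothetical, not part of this PR):

import os

def _restore_cublas_workspace_env():  # hypothetical helper, not in the PR
    # Undo the CUBLAS_WORKSPACE_CONFIG / CUBLASLT_WORKSPACE_SIZE overrides
    # applied by enable_batch_invariant_mode().
    for key, original in (
        ("CUBLAS_WORKSPACE_CONFIG", _original_cublas_workspace_cfg),
        ("CUBLASLT_WORKSPACE_SIZE", _original_cublaslt_workspace_size),
    ):
        if original is None:
            os.environ.pop(key, None)  # variable was unset before: remove it again
        else:
            os.environ[key] = original  # variable was set before: restore old value
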
@@ -847,7 +857,6 @@ def get_batch_invariant_attention_block_size() -> AttentionBlockSize:
    return AttentionBlockSize(block_m=16, block_n=16)


@functools.cache
def vllm_is_batch_invariant():
    env_key = "VLLM_BATCH_INVARIANT"
    is_overridden = False
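
The hunk above shows only the start of vllm_is_batch_invariant; the function reads the VLLM_BATCH_INVARIANT environment variable and caches the result. A minimal sketch of that pattern, assuming values like "1"/"true" count as enabled (the exact parsing and override handling in vLLM may differ):

import functools
import os

@functools.cache
def vllm_is_batch_invariant() -> bool:
    # Consult the environment only once per process; later calls reuse the result.
    value = os.getenv("VLLM_BATCH_INVARIANT", "0").strip().lower()
    return value in ("1", "true", "yes")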


@@ -19,6 +19,9 @@ import torch
import vllm.envs as envs
from vllm.logger import init_logger
from vllm.model_executor.layers.batch_invariant import (
    vllm_is_batch_invariant,
)
from vllm.platforms import current_platform

logger = init_logger(__name__)
@@ -222,6 +225,9 @@ def force_use_trtllm_attention() -> bool | None:
    return `True` if TRTLLM attention is forced to be used,
    return `False` if TRTLLM attention is forced not to be used.
    """
    if vllm_is_batch_invariant():
        logger.info_once("VLLM_USE_TRTLLM_ATTENTION is disabled for batch-invariant")
        return False
    return _force_use_trtllm_attention(envs.VLLM_USE_TRTLLM_ATTENTION)
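
The net effect is that batch-invariant mode takes precedence over the VLLM_USE_TRTLLM_ATTENTION override. A small illustration of that precedence, assuming this helper lives in vllm.utils.flashinfer as the diff suggests and that both environment variables are set before the first call:

import os

os.environ["VLLM_BATCH_INVARIANT"] = "1"
os.environ["VLLM_USE_TRTLLM_ATTENTION"] = "1"  # ignored once batch invariance is on

from vllm.utils.flashinfer import force_use_trtllm_attention

# Batch-invariant mode wins: TRTLLM attention is force-disabled.
assert force_use_trtllm_attention() is False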