mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-04-08 05:37:02 +08:00
[Bug] Fix batch invariant in torch 2.10 (#30907)
Signed-off-by: yewentao256 <zhyanwentao@126.com> Signed-off-by: Wentao Ye <44945378+yewentao256@users.noreply.github.com> Co-authored-by: Cyrus Leung <tlleungac@connect.ust.hk>
This commit is contained in:
parent
eee600c34f
commit
6628758233
@ -933,30 +933,26 @@ def enable_batch_invariant_mode():
|
||||
_batch_invariant_MODE = True
|
||||
_batch_invariant_LIB = torch.library.Library("aten", "IMPL")
|
||||
|
||||
# Batch invariant matmuls are no longer needed after cublas overrides
|
||||
if not is_torch_equal_or_newer("2.10.0.dev"):
|
||||
if (
|
||||
current_platform.is_device_capability_family(100)
|
||||
or current_platform.is_device_capability(80)
|
||||
or current_platform.is_device_capability(89)
|
||||
):
|
||||
# For PyTorch 2.9, B200 uses GEMV for bs=1
|
||||
# Requires https://github.com/pytorch/pytorch/pull/166735
|
||||
_batch_invariant_LIB.impl("aten::mm", mm_batch_invariant, "CUDA")
|
||||
_batch_invariant_LIB.impl("aten::addmm", addmm_batch_invariant, "CUDA")
|
||||
_batch_invariant_LIB.impl("aten::matmul", matmul_batch_invariant, "CUDA")
|
||||
_batch_invariant_LIB.impl("aten::linear", linear_batch_invariant, "CUDA")
|
||||
else:
|
||||
# Only source of batch invariance for Hopper is split-k, can disable through
|
||||
# cuBLAS workspace config
|
||||
_original_cublas_workspace_cfg = os.environ.get(
|
||||
"CUBLAS_WORKSPACE_CONFIG", None
|
||||
)
|
||||
_original_cublaslt_workspace_size = os.environ.get(
|
||||
"CUBLASLT_WORKSPACE_SIZE", None
|
||||
)
|
||||
os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":16:8"
|
||||
os.environ["CUBLASLT_WORKSPACE_SIZE"] = "1"
|
||||
if (
|
||||
current_platform.is_device_capability_family(100)
|
||||
or current_platform.is_device_capability(80)
|
||||
or current_platform.is_device_capability(89)
|
||||
):
|
||||
# For PyTorch 2.9, B200 uses GEMV for bs=1
|
||||
# Requires https://github.com/pytorch/pytorch/pull/166735
|
||||
_batch_invariant_LIB.impl("aten::mm", mm_batch_invariant, "CUDA")
|
||||
_batch_invariant_LIB.impl("aten::addmm", addmm_batch_invariant, "CUDA")
|
||||
_batch_invariant_LIB.impl("aten::matmul", matmul_batch_invariant, "CUDA")
|
||||
_batch_invariant_LIB.impl("aten::linear", linear_batch_invariant, "CUDA")
|
||||
else:
|
||||
# Only source of batch invariance for Hopper is split-k, can disable through
|
||||
# cuBLAS workspace config
|
||||
_original_cublas_workspace_cfg = os.environ.get("CUBLAS_WORKSPACE_CONFIG", None)
|
||||
_original_cublaslt_workspace_size = os.environ.get(
|
||||
"CUBLASLT_WORKSPACE_SIZE", None
|
||||
)
|
||||
os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":16:8"
|
||||
os.environ["CUBLASLT_WORKSPACE_SIZE"] = "1"
|
||||
|
||||
_batch_invariant_LIB.impl(
|
||||
"aten::_log_softmax", _log_softmax_batch_invariant, "CUDA"
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user