commit d7284a2604 (parent 506ed87e87)
Signed-off-by: arpitkh101 <arpit5khandelwal@gmail.com>
Co-authored-by: Luka Govedič <ProExpertProg@users.noreply.github.com>
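At a glance, this commit renames the `PassConfig` flags; the mapping below is reconstructed from the diff that follows (both spellings still construct a `PassConfig`, but the old names survive only as deprecated aliases that log a warning and forward to the new fields):

```python
from vllm.config.compilation import PassConfig

# Old names (deprecated in v0.12.0, still accepted with a warning)
#   enable_noop                  -> eliminate_noops
#   enable_fusion                -> fuse_norm_quant + fuse_act_quant
#   enable_attn_fusion           -> fuse_attn_quant
#   enable_sequence_parallelism  -> enable_sp
#   enable_async_tp              -> fuse_gemm_comms
#   enable_fi_allreduce_fusion   -> fuse_allreduce_rms

# New spelling, enabling every renamed pass explicitly:
pass_config = PassConfig(
    eliminate_noops=True,
    fuse_norm_quant=True,
    fuse_act_quant=True,
    fuse_attn_quant=True,
    enable_sp=True,
    fuse_gemm_comms=True,
    fuse_allreduce_rms=True,
)
```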
@@ -326,7 +326,7 @@ def async_tp_pass_on_test_model(
     vllm_config = VllmConfig()
     vllm_config.compilation_config = CompilationConfig(
         pass_config=PassConfig(
-            enable_async_tp=True,
+            fuse_gemm_comms=True,
         ),
     )
     vllm_config.device_config = DeviceConfig(device=torch.device("cuda"))
@@ -413,7 +413,7 @@ def test_async_tp_pass_correctness(
         "mode": CompilationMode.VLLM_COMPILE,
         "compile_sizes": [2, 4, 8],
         "splitting_ops": [],
-        "pass_config": {"enable_async_tp": async_tp_enabled},
+        "pass_config": {"fuse_gemm_comms": async_tp_enabled},
     }

     async_tp_args = [
@@ -295,7 +295,7 @@ def all_reduce_fusion_pass_on_test_model(
         )
     )
     vllm_config.compilation_config.pass_config = PassConfig(
-        enable_fi_allreduce_fusion=True, enable_noop=True
+        fuse_allreduce_rms=True, eliminate_noops=True
     )
     vllm_config.device_config = DeviceConfig(device=torch.device("cuda"))
     vllm_config.parallel_config.rank = local_rank  # Setup rank for debug path
@@ -192,7 +192,7 @@ def test_attn_quant(
         splitting_ops=splitting_ops,
         # Common
         mode=CompilationMode.VLLM_COMPILE,
-        pass_config=PassConfig(enable_attn_fusion=True, enable_noop=True),
+        pass_config=PassConfig(fuse_attn_quant=True, eliminate_noops=True),
         # Inductor caches custom passes by default as well via uuid
         inductor_compile_config={"force_disable_caches": True},
     )
@@ -282,9 +282,9 @@ def test_tp2_attn_quant_allreduce_rmsnorm(
         # Common
         mode=CompilationMode.VLLM_COMPILE,
         pass_config=PassConfig(
-            enable_attn_fusion=True,
-            enable_noop=True,
-            enable_fi_allreduce_fusion=True,
+            fuse_attn_quant=True,
+            eliminate_noops=True,
+            fuse_allreduce_rms=True,
         ),
         # Inductor caches custom passes by default as well via uuid
         inductor_compile_config={"force_disable_caches": True},
@@ -384,10 +384,10 @@ def test_tp2_attn_quant_async_tp(
         # Common
         level=CompilationMode.VLLM_COMPILE,
         pass_config=PassConfig(
-            enable_attn_fusion=True,
-            enable_noop=True,
-            enable_sequence_parallelism=True,
-            enable_async_tp=True,
+            fuse_attn_quant=True,
+            eliminate_noops=True,
+            enable_sp=True,
+            fuse_gemm_comms=True,
         ),
         # Inductor caches custom passes by default as well via uuid
         inductor_compile_config={"force_disable_caches": True},
@@ -153,7 +153,7 @@ class TestAllReduceRMSNormStaticQuantFP8Model(torch.nn.Module):
         ]

     def ops_in_model(self):
-        if self.vllm_config.compilation_config.pass_config.enable_fusion:
+        if self.vllm_config.compilation_config.pass_config.fuse_norm_quant:
             return [torch.ops._C.fused_add_rms_norm_static_fp8_quant.default]
         elif RMSNorm.enabled():
             return [
@@ -183,7 +183,7 @@ class TestAllReduceRMSNormStaticQuantFP8Model(torch.nn.Module):
 @pytest.mark.parametrize("seq_len", [16])
 @pytest.mark.parametrize("hidden_size", [16])
 @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
-@pytest.mark.parametrize("enable_fusion", [True, False])
+@pytest.mark.parametrize("fuse_norm_quant", [True, False])
 @pytest.mark.parametrize("dynamic", [False, True])
 @pytest.mark.skipif(envs.VLLM_TARGET_DEVICE not in ["cuda"], reason="Only test on CUDA")
 def test_sequence_parallelism_pass(
@@ -193,7 +193,7 @@ def test_sequence_parallelism_pass(
     seq_len: int,
     hidden_size: int,
     dtype: torch.dtype,
-    enable_fusion: bool,
+    fuse_norm_quant: bool,
     dynamic: bool,
 ):
     num_processes = 2
@@ -211,7 +211,7 @@ def test_sequence_parallelism_pass(
             seq_len,
             hidden_size,
             dtype,
-            enable_fusion,
+            fuse_norm_quant,
             dynamic,
         ),
         nprocs=nprocs,
@@ -229,7 +229,7 @@ def sequence_parallelism_pass_on_test_model(
     seq_len: int,
     hidden_size: int,
     dtype: torch.dtype,
-    enable_fusion: bool,
+    fuse_norm_quant: bool,
     dynamic: bool,
 ):
     current_platform.seed_everything(0)
@@ -260,9 +260,9 @@ def sequence_parallelism_pass_on_test_model(
         cudagraph_mode=CUDAGraphMode.NONE,  # avoid piecewise warnings
         custom_ops=custom_ops_list,
         pass_config=PassConfig(
-            enable_sequence_parallelism=True,
-            enable_fusion=enable_fusion,
-            enable_noop=True,
+            enable_sp=True,
+            fuse_norm_quant=fuse_norm_quant,
+            eliminate_noops=True,
         ),
     )  # NoOp needed for fusion
     device_config = DeviceConfig(device=torch.device("cuda"))
@@ -297,7 +297,7 @@ def sequence_parallelism_pass_on_test_model(
         sequence_parallelism_pass,
     ]

-    if enable_fusion:
+    if fuse_norm_quant:
         fusion_pass = RMSNormQuantFusionPass(vllm_config)
         passes_for_backend.append(fusion_pass)

@@ -122,7 +122,9 @@ def test_full_graph(
         CompilationConfig(
             mode=CompilationMode.VLLM_COMPILE,
             custom_ops=["+rms_norm"],
-            pass_config=PassConfig(enable_fusion=True, enable_noop=True),
+            pass_config=PassConfig(
+                fuse_norm_quant=True, fuse_act_quant=True, eliminate_noops=True
+            ),
         ),
         *model_info,
     )
@@ -1,6 +1,7 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 import copy
+import logging
 from contextlib import nullcontext
 from unittest.mock import patch

@@ -10,8 +11,9 @@ from pydantic import ValidationError
 from vllm.compilation.counter import compilation_counter
 from vllm.compilation.fix_functionalization import FixFunctionalizationPass
 from vllm.config import CompilationConfig, CUDAGraphMode, VllmConfig
-from vllm.config.compilation import CompilationMode
+from vllm.config.compilation import CompilationMode, PassConfig
 from vllm.engine.arg_utils import EngineArgs
+from vllm.logger import _print_warning_once
 from vllm.platforms import current_platform
 from vllm.utils.torch_utils import _is_torch_equal_or_newer

@@ -191,7 +193,7 @@ def test_splitting_ops_dynamic():
     config = VllmConfig(
         compilation_config=CompilationConfig(
             mode=CompilationMode.VLLM_COMPILE,
-            pass_config={"enable_attn_fusion": True, "enable_noop": True},
+            pass_config=PassConfig(fuse_attn_quant=True, eliminate_noops=True),
             custom_ops=["+quant_fp8"],
             cudagraph_mode=CUDAGraphMode.PIECEWISE,
         )
@@ -206,7 +208,7 @@ def test_splitting_ops_dynamic():
     config = VllmConfig(
         compilation_config=CompilationConfig(
             mode=CompilationMode.VLLM_COMPILE,
-            pass_config={"enable_attn_fusion": True, "enable_noop": True},
+            pass_config=PassConfig(fuse_attn_quant=True, eliminate_noops=True),
             custom_ops=["+quant_fp8"],
             cudagraph_mode=CUDAGraphMode.PIECEWISE,
             # work around for accessing all attntion ops
@@ -219,7 +221,7 @@ def test_splitting_ops_dynamic():
         compilation_config=CompilationConfig(
             mode=CompilationMode.VLLM_COMPILE,
             use_inductor_graph_partition=True,
-            pass_config={"enable_attn_fusion": True, "enable_noop": True},
+            pass_config=PassConfig(fuse_attn_quant=True, eliminate_noops=True),
             custom_ops=["+quant_fp8"],
             cudagraph_mode=CUDAGraphMode.PIECEWISE,
         )
@@ -227,7 +229,7 @@ def test_splitting_ops_dynamic():
     # With inductor graph partition, attn_fusion and splitting_ops
     # work together. Default splitting_ops include attention ops.
     assert config.compilation_config.splitting_ops_contain_attention()
-    # enable_attn_fusion is directly supported under
+    # fuse_attn_quant is directly supported under
     # use_inductor_graph_partition=True, and cudagraph_mode
     # is unchanged.
     assert config.compilation_config.cudagraph_mode == CUDAGraphMode.PIECEWISE
@@ -301,7 +303,7 @@ def test_should_split():
     "cudagraph_capture_sizes",
     "max_cudagraph_capture_size",
     "tp_size",
-    "enable_sequence_parallelism",
+    "enable_sp",
     "max_num_batched_tokens",
     "cudagraph_mode",
     "expected_max_size",
@@ -339,7 +341,7 @@ def test_cudagraph_sizes_post_init(
     cudagraph_capture_sizes,
     max_cudagraph_capture_size,
     tp_size,
-    enable_sequence_parallelism,
+    enable_sp,
     max_num_batched_tokens,
     cudagraph_mode,
     expected_max_size,
@@ -355,11 +357,12 @@ def test_cudagraph_sizes_post_init(
     compilation_config = CompilationConfig(
         cudagraph_capture_sizes=cudagraph_capture_sizes,
         max_cudagraph_capture_size=max_cudagraph_capture_size,
-        pass_config={
-            "enable_sequence_parallelism": enable_sequence_parallelism,
-            "enable_fusion": True,
-            "enable_noop": True,
-        },
+        pass_config=PassConfig(
+            enable_sp=enable_sp,
+            fuse_norm_quant=True,
+            fuse_act_quant=True,
+            eliminate_noops=True,
+        ),
         cudagraph_mode=cudagraph_mode,
     )
     engine_args = EngineArgs(
@@ -375,3 +378,53 @@ def test_cudagraph_sizes_post_init(
         vllm_config.compilation_config.max_cudagraph_capture_size
         == expected_max_size
     )
+
+
+def test_pass_config_deprecation(caplog_vllm):
+    caplog_vllm.set_level(logging.WARNING)
+
+    # Clear cache to ensure warnings are re-issued
+    _print_warning_once.cache_clear()
+
+    # Test enable_fusion -> fuse_norm_quant, fuse_act_quant
+    caplog_vllm.clear()
+    config = PassConfig(enable_fusion=True)
+    assert "enable_fusion is deprecated" in caplog_vllm.text
+    assert config.fuse_norm_quant is True
+    assert config.fuse_act_quant is True
+    assert config.enable_fusion is None
+
+    # Test enable_attn_fusion -> fuse_attn_quant
+    caplog_vllm.clear()
+    config = PassConfig(enable_attn_fusion=True)
+    assert "enable_attn_fusion is deprecated" in caplog_vllm.text
+    assert config.fuse_attn_quant is True
+    assert config.enable_attn_fusion is None
+
+    # Test enable_noop -> eliminate_noops
+    caplog_vllm.clear()
+    config = PassConfig(enable_noop=True)
+    assert "enable_noop is deprecated" in caplog_vllm.text
+    assert config.eliminate_noops is True
+    assert config.enable_noop is None
+
+    # Test enable_sequence_parallelism -> enable_sp
+    caplog_vllm.clear()
+    config = PassConfig(enable_sequence_parallelism=True)
+    assert "enable_sequence_parallelism is deprecated" in caplog_vllm.text
+    assert config.enable_sp is True
+    assert config.enable_sequence_parallelism is None
+
+    # Test enable_async_tp -> fuse_gemm_comms
+    caplog_vllm.clear()
+    config = PassConfig(enable_async_tp=True)
+    assert "enable_async_tp is deprecated" in caplog_vllm.text
+    assert config.fuse_gemm_comms is True
+    assert config.enable_async_tp is None
+
+    # Test enable_fi_allreduce_fusion -> fuse_allreduce_rms
+    caplog_vllm.clear()
+    config = PassConfig(enable_fi_allreduce_fusion=True)
+    assert "enable_fi_allreduce_fusion is deprecated" in caplog_vllm.text
+    assert config.fuse_allreduce_rms is True
+    assert config.enable_fi_allreduce_fusion is None
@@ -223,7 +223,11 @@ def test_fix_functionalization(
         model_config=ModelConfig(dtype=dtype),
         compilation_config=CompilationConfig(
             custom_ops=["all"],
-            pass_config=PassConfig(enable_fusion=do_fusion, enable_noop=True),
+            pass_config=PassConfig(
+                fuse_norm_quant=do_fusion,
+                fuse_act_quant=do_fusion,
+                eliminate_noops=True,
+            ),
         ),
     )

@@ -159,7 +159,9 @@ def test_fusion_rmsnorm_quant(
         compilation_config=CompilationConfig(
             mode=CompilationMode.VLLM_COMPILE,
             custom_ops=custom_ops,
-            pass_config=PassConfig(enable_fusion=True, enable_noop=True),
+            pass_config=PassConfig(
+                fuse_norm_quant=True, fuse_act_quant=True, eliminate_noops=True
+            ),
         ),
     )
     with vllm.config.set_current_vllm_config(vllm_config):
@@ -373,7 +373,7 @@ def test_attention_quant_pattern(

     # Run model with attn fusion enabled
     vllm_config.compilation_config.pass_config = PassConfig(
-        enable_attn_fusion=True, enable_noop=True
+        fuse_attn_quant=True, eliminate_noops=True
     )
     with (
         set_current_vllm_config(vllm_config),
@@ -51,7 +51,7 @@ def test_noop_elimination(dtype, num_tokens, hidden_size, buffer_size):
     vllm_config = VllmConfig(
         compilation_config=CompilationConfig(
             mode=CompilationMode.VLLM_COMPILE,
-            pass_config=PassConfig(enable_noop=True),
+            pass_config=PassConfig(eliminate_noops=True),
         )
     )
     with vllm.config.set_current_vllm_config(vllm_config):
@@ -99,7 +99,7 @@ def test_non_noop_slice_preserved():
     vllm_config = VllmConfig(
         compilation_config=CompilationConfig(
             mode=CompilationMode.VLLM_COMPILE,
-            pass_config=PassConfig(enable_noop=True),
+            pass_config=PassConfig(eliminate_noops=True),
         )
     )
     with vllm.config.set_current_vllm_config(vllm_config):
@@ -64,8 +64,11 @@ def test_pass_manager_uuid(callable):

     # UUID should be different due to config change
     config2 = copy.deepcopy(config)
-    config2.compilation_config.pass_config.enable_fusion = (
-        not config2.compilation_config.pass_config.enable_fusion
+    config2.compilation_config.pass_config.fuse_norm_quant = (
+        not config2.compilation_config.pass_config.fuse_norm_quant
+    )
+    config2.compilation_config.pass_config.fuse_act_quant = (
+        not config2.compilation_config.pass_config.fuse_act_quant
     )
     pass_manager3 = PostGradPassManager()
     pass_manager3.configure(config2)
@@ -140,7 +140,7 @@ def test_qk_norm_rope_fusion(
             custom_ops=custom_ops,
             pass_config=PassConfig(
                 enable_qk_norm_rope_fusion=True,
-                enable_noop=True,
+                eliminate_noops=True,
             ),
         ),
     )
@@ -168,7 +168,7 @@ def test_fusion_silu_and_mul_quant(
         compilation_config=CompilationConfig(
             mode=CompilationMode.VLLM_COMPILE,
             custom_ops=custom_ops,
-            pass_config=PassConfig(enable_fusion=True, enable_noop=True),
+            pass_config=PassConfig(fuse_act_quant=True, eliminate_noops=True),
         ),
     )

@@ -32,7 +32,8 @@ VLLM_MULTI_NODE = os.getenv("VLLM_MULTI_NODE", "0") == "1"
 class ParallelSetup(NamedTuple):
     tp_size: int
     pp_size: int
-    enable_fusion: bool
+    fuse_norm_quant: bool
+    fuse_act_quant: bool
     eager_mode: bool
     chunked_prefill: bool

@@ -66,7 +67,8 @@ class SPTestSettings:
                 ParallelSetup(
                     tp_size=tp_base,
                     pp_size=pp_multiplier * pp_base,
-                    enable_fusion=False,
+                    fuse_norm_quant=False,
+                    fuse_act_quant=False,
                     eager_mode=eager_mode_val,
                     chunked_prefill=chunked_prefill_val,
                 )
@@ -97,7 +99,8 @@ class SPTestSettings:
                 ParallelSetup(
                     tp_size=tp_base,
                     pp_size=pp_multiplier * pp_base,
-                    enable_fusion=False,
+                    fuse_norm_quant=False,
+                    fuse_act_quant=False,
                     eager_mode=eager_mode_val,
                     chunked_prefill=chunked_prefill_val,
                 )
@@ -126,7 +129,8 @@ class SPTestSettings:
                 ParallelSetup(
                     tp_size=tp_base,
                     pp_size=pp_base,
-                    enable_fusion=fusion_val,
+                    fuse_norm_quant=fusion_val,
+                    fuse_act_quant=fusion_val,
                     eager_mode=True,
                     chunked_prefill=False,
                 )
@@ -162,7 +166,7 @@ def _compare_sp(
     test_options: SPTestOptions,
     num_gpus_available: int,
     use_inductor_graph_partition: bool,
-    enable_async_tp: bool,
+    fuse_gemm_comms: bool,
     *,
     method: Literal["generate", "encode"],
     is_multimodal: bool,
@@ -170,7 +174,8 @@ def _compare_sp(
     (
         tp_size,
         pp_size,
-        enable_fusion,
+        fuse_norm_quant,
+        fuse_act_quant,
         eager_mode,
         chunked_prefill,
     ) = parallel_setup
@@ -248,10 +253,11 @@ def _compare_sp(
         "mode": CompilationMode.VLLM_COMPILE,
         "compile_sizes": [4, 8],
         "pass_config": {
-            "enable_sequence_parallelism": True,
-            "enable_async_tp": enable_async_tp,
-            "enable_fusion": enable_fusion,
-            "enable_noop": True,
+            "enable_sp": True,
+            "fuse_gemm_comms": fuse_gemm_comms,
+            "fuse_norm_quant": fuse_norm_quant,
+            "fuse_act_quant": fuse_act_quant,
+            "eliminate_noops": True,
         },
         "use_inductor_graph_partition": use_inductor_graph_partition,
     }
@@ -309,7 +315,7 @@ SP_TEST_MODELS = [
     ],
 )
 @pytest.mark.parametrize("use_inductor_graph_partition", [True, False])
-@pytest.mark.parametrize("enable_async_tp", [False])  # TODO: enable async TP
+@pytest.mark.parametrize("fuse_gemm_comms", [False])  # TODO: enable async TP
 @create_new_process_for_each_test()
 def test_tp_sp_generation(
     model_id: str,
@@ -319,7 +325,7 @@ def test_tp_sp_generation(
     test_options: SPTestOptions,
     num_gpus_available,
     use_inductor_graph_partition: bool,
-    enable_async_tp: bool,
+    fuse_gemm_comms: bool,
 ):
     if use_inductor_graph_partition and not is_torch_equal_or_newer("2.9.0.dev"):
         pytest.skip("inductor graph partition is only available in PyTorch 2.9+")
@@ -328,7 +334,7 @@ def test_tp_sp_generation(
     if (
         "fp8" in model_id.lower()
         and current_platform.get_device_capability() < (9, 0)
-        and (not enable_async_tp)
+        and (not fuse_gemm_comms)
     ):
         pytest.skip("FP8 reduction support begins with sm90 capable devices.")

@@ -340,7 +346,7 @@ def test_tp_sp_generation(
         test_options,
         num_gpus_available,
         use_inductor_graph_partition,
-        enable_async_tp=enable_async_tp,
+        fuse_gemm_comms=fuse_gemm_comms,
         method="generate",
         is_multimodal=False,
     )
@@ -1023,17 +1023,17 @@ def test_vllm_config_explicit_overrides():
     assert config.compilation_config.cudagraph_mode == CUDAGraphMode.NONE

     # Explicit pass config flags to override defaults
-    pass_config = PassConfig(enable_noop=True, enable_attn_fusion=True)
+    pass_config = PassConfig(eliminate_noops=True, fuse_attn_quant=True)
     compilation_config = CompilationConfig(pass_config=pass_config)
     config = VllmConfig(
         optimization_level=OptimizationLevel.O0,
         compilation_config=compilation_config,
     )
-    assert config.compilation_config.pass_config.enable_noop is True
-    assert config.compilation_config.pass_config.enable_attn_fusion is True
+    assert config.compilation_config.pass_config.eliminate_noops is True
+    assert config.compilation_config.pass_config.fuse_attn_quant is True

     # Explicit cudagraph mode override on quantized model at O2
-    pass_config = PassConfig(enable_async_tp=True)
+    pass_config = PassConfig(fuse_gemm_comms=True)
     compilation_config = CompilationConfig(
         cudagraph_mode=CUDAGraphMode.NONE, pass_config=pass_config
     )
@@ -1043,7 +1043,7 @@ def test_vllm_config_explicit_overrides():
         compilation_config=compilation_config,
     )
     assert config.compilation_config.cudagraph_mode == CUDAGraphMode.NONE
-    assert config.compilation_config.pass_config.enable_async_tp is True
+    assert config.compilation_config.pass_config.fuse_gemm_comms is True
     # Mode should still use default for O2
     assert config.compilation_config.mode == CompilationMode.VLLM_COMPILE

@@ -1093,7 +1093,7 @@ def test_vllm_config_explicit_overrides():
         compilation_config=compilation_config,
     )
     # Explicit override should be respected
-    assert config.compilation_config.pass_config.enable_noop is False
+    assert config.compilation_config.pass_config.eliminate_noops is False
     # Other fields should still use defaults
     assert config.compilation_config.mode == CompilationMode.VLLM_COMPILE
     assert config.compilation_config.cudagraph_mode == CUDAGraphMode.FULL_AND_PIECEWISE
@@ -92,22 +92,23 @@ class PostGradPassManager(CustomGraphPass):

         # Set the current vllm config to allow tracing CustomOp instances
         with set_current_vllm_config(config, check_compile=False):
-            if self.pass_config.enable_noop:
+            if self.pass_config.eliminate_noops:
                 self.passes += [NoOpEliminationPass(config)]

-            if self.pass_config.enable_sequence_parallelism:
+            if self.pass_config.enable_sp:
                 self.passes += [SequenceParallelismPass(config)]
-            if self.pass_config.enable_async_tp:
+            if self.pass_config.fuse_gemm_comms:
                 self.passes += [AsyncTPPass(config)]

-            if self.pass_config.enable_fi_allreduce_fusion:
+            if self.pass_config.fuse_allreduce_rms:
                 self.passes += [AllReduceFusionPass(config)]

-            if self.pass_config.enable_fusion:
+            if self.pass_config.fuse_norm_quant:
                 self.passes += [RMSNormQuantFusionPass(config)]
+            if self.pass_config.fuse_act_quant:
                 self.passes += [ActivationQuantFusionPass(config)]

-            if self.pass_config.enable_attn_fusion:
+            if self.pass_config.fuse_attn_quant:
                 self.passes += [AttnFusionPass(config)]

             if self.pass_config.enable_qk_norm_rope_fusion:
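For quick reference, the `configure` logic above now keys each post-grad pass off the renamed flags. A sketch of the flag-to-pass mapping as read from this hunk (plain data for readability, not a vLLM API):

```python
# Which pass each renamed PassConfig flag turns on in
# PostGradPassManager.configure, summarized from the diff above.
FLAG_TO_PASS = {
    "eliminate_noops": "NoOpEliminationPass",
    "enable_sp": "SequenceParallelismPass",
    "fuse_gemm_comms": "AsyncTPPass",
    "fuse_allreduce_rms": "AllReduceFusionPass",
    "fuse_norm_quant": "RMSNormQuantFusionPass",
    "fuse_act_quant": "ActivationQuantFusionPass",
    "fuse_attn_quant": "AttnFusionPass",
}
```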
@@ -13,7 +13,7 @@ from pydantic.dataclasses import dataclass

 import vllm.envs as envs
 from vllm.compilation.inductor_pass import CallableInductorPass, InductorPass
-from vllm.config.utils import config
+from vllm.config.utils import config, handle_deprecated
 from vllm.logger import init_logger
 from vllm.platforms import current_platform
 from vllm.utils.import_utils import resolve_obj_by_qualname
@@ -105,18 +105,43 @@ class PassConfig:
     improper state.
     """

+    # New flags
+    fuse_norm_quant: bool = Field(default=None)
+    """Fuse the custom RMSNorm + quant ops."""
+    fuse_act_quant: bool = Field(default=None)
+    """Fuse the custom SiluMul + quant ops."""
+    fuse_attn_quant: bool = Field(default=None)
+    """Fuse the custom attention + quant ops."""
+    eliminate_noops: bool = Field(default=None)
+    """Eliminate no-op ops."""
+    enable_sp: bool = Field(default=None)
+    """Enable sequence parallelism."""
+    fuse_gemm_comms: bool = Field(default=None)
+    """Enable async TP."""
+    fuse_allreduce_rms: bool = Field(default=None)
+    """Enable flashinfer allreduce fusion."""
+
+    # Deprecated flags
     enable_fusion: bool = Field(default=None)
-    """Whether to enable the custom fusion (RMSNorm/SiluMul+quant) pass."""
+    """Deprecated in: v0.12.0. Use fuse_norm_quant and fuse_act_quant
+    instead. Will be removed in v0.13.0 or v1.0.0, whichever is sooner.
+    """
     enable_attn_fusion: bool = Field(default=None)
-    """Whether to enable the custom attention+quant fusion pass."""
+    """Deprecated in: v0.12.0. Use fuse_attn_quant instead.
+    Will be removed in v0.13.0 or v1.0.0, whichever is sooner."""
     enable_noop: bool = Field(default=None)
-    """Whether to enable the custom no-op elimination pass."""
+    """Deprecated in: v0.12.0. Use eliminate_noops instead.
+    Will be removed in v0.13.0 or v1.0.0, whichever is sooner."""
     enable_sequence_parallelism: bool = Field(default=None)
-    """Whether to enable sequence parallelism."""
+    """Deprecated in: v0.12.0. Use enable_sp instead.
+    Will be removed in v0.13.0 or v1.0.0, whichever is sooner."""
     enable_async_tp: bool = Field(default=None)
-    """Whether to enable async TP."""
+    """Deprecated in: v0.12.0. Use fuse_gemm_comms instead.
+    Will be removed in v0.13.0 or v1.0.0, whichever is sooner."""
     enable_fi_allreduce_fusion: bool = Field(default=None)
-    """Whether to enable flashinfer allreduce fusion."""
+    """Deprecated in: v0.12.0. Use fuse_allreduce_rms instead.
+    Will be removed in v0.13.0 or v1.0.0, whichever is sooner."""

     fi_allreduce_fusion_max_size_mb: float | None = None
     """The threshold of the communicated tensor sizes under which
     vllm should use flashinfer fused allreduce. Specified as a
@@ -136,7 +161,7 @@ class PassConfig:
     },
     }, where key is the device capability"""
     enable_qk_norm_rope_fusion: bool = False
-    """Whether to enable the fused Q/K RMSNorm + RoPE pass."""
+    """Enable fused Q/K RMSNorm + RoPE pass."""

     # TODO(luka) better pass enabling system.

@@ -174,6 +199,13 @@ class PassConfig:
         return InductorPass.hash_dict(asdict(self))

     @field_validator(
+        "fuse_norm_quant",
+        "fuse_act_quant",
+        "fuse_attn_quant",
+        "eliminate_noops",
+        "enable_sp",
+        "fuse_gemm_comms",
+        "fuse_allreduce_rms",
         "enable_fusion",
         "enable_attn_fusion",
         "enable_noop",
@@ -190,18 +222,71 @@ class PassConfig:
         return handler(value)

     def __post_init__(self) -> None:
-        if not self.enable_noop:
-            if self.enable_fusion:
+        # Handle deprecation and defaults
+
+        # Map old flags to new flags and issue warnings
+        handle_deprecated(
+            self,
+            "enable_fusion",
+            ["fuse_norm_quant", "fuse_act_quant"],
+            "v0.13.0 or v1.0.0, whichever is sooner",
+        )
+
+        handle_deprecated(
+            self,
+            "enable_attn_fusion",
+            "fuse_attn_quant",
+            "v0.13.0 or v1.0.0, whichever is sooner",
+        )
+
+        handle_deprecated(
+            self,
+            "enable_sequence_parallelism",
+            "enable_sp",
+            "v0.13.0 or v1.0.0, whichever is sooner",
+        )
+
+        handle_deprecated(
+            self,
+            "enable_async_tp",
+            "fuse_gemm_comms",
+            "v0.13.0 or v1.0.0, whichever is sooner",
+        )
+
+        handle_deprecated(
+            self,
+            "enable_fi_allreduce_fusion",
+            "fuse_allreduce_rms",
+            "v0.13.0 or v1.0.0, whichever is sooner",
+        )
+
+        handle_deprecated(
+            self,
+            "enable_noop",
+            "eliminate_noops",
+            "v0.13.0 or v1.0.0, whichever is sooner",
+        )
+
+        # Force old flags to None to ensure they are not used
+        self.enable_fusion = None
+        self.enable_attn_fusion = None
+        self.enable_noop = None
+        self.enable_sequence_parallelism = None
+        self.enable_async_tp = None
+        self.enable_fi_allreduce_fusion = None
+
+        if not self.eliminate_noops:
+            if self.fuse_norm_quant or self.fuse_act_quant:
                 logger.warning_once(
                     "Fusion enabled but reshape elimination disabled. "
                     "RMSNorm/SiluMul + quant (fp8) fusion might not work"
                 )
-            if self.enable_attn_fusion:
+            if self.fuse_attn_quant:
                 logger.warning_once(
                     "Fusion enabled but reshape elimination disabled. "
                     "Attention + quant (fp8) fusion might not work"
                 )
-            if self.enable_fi_allreduce_fusion:
+            if self.fuse_allreduce_rms:
                 logger.warning_once(
                     "Fusion enabled but reshape elimination disabled. "
                     "Allreduce + rms norm + quant (fp8) fusion might not work"
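A short usage sketch of the `__post_init__` behavior above (hedged; it mirrors the new deprecation test earlier in this diff):

```python
from vllm.config.compilation import PassConfig

# Old alias still works, but is remapped onto the new flags and then cleared;
# construction logs "enable_fusion is deprecated ...".
legacy = PassConfig(enable_fusion=True)
assert legacy.fuse_norm_quant is True and legacy.fuse_act_quant is True
assert legacy.enable_fusion is None

# Fusion without no-op elimination is allowed but warns that the
# RMSNorm/SiluMul + quant (fp8) fusion might not fire.
risky = PassConfig(fuse_norm_quant=True, eliminate_noops=False)
```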
@@ -873,7 +958,7 @@ class CompilationConfig:
             self.set_splitting_ops_for_inductor_graph_partition()
             return

-        if self.pass_config.enable_attn_fusion:
+        if self.pass_config.fuse_attn_quant:
             # here use_inductor_graph_partition is False
             self.set_splitting_ops_for_attn_fusion()
             return
@@ -915,12 +1000,12 @@ class CompilationConfig:
             self.splitting_ops = list(self._attention_ops)

     def set_splitting_ops_for_attn_fusion(self):
-        assert self.pass_config.enable_attn_fusion
+        assert self.pass_config.fuse_attn_quant
         if self.splitting_ops is None:
             self.splitting_ops = []
             if self.cudagraph_mode.has_piecewise_cudagraphs():
                 logger.warning_once(
-                    "enable_attn_fusion is incompatible with piecewise "
+                    "fuse_attn_quant is incompatible with piecewise "
                     "cudagraph when use_inductor_graph_partition is off. "
                     "In this case, splitting_ops will be set to empty "
                     "list, and cudagraph_mode will be set to FULL. "
@@ -931,8 +1016,7 @@ class CompilationConfig:
                 self.cudagraph_mode = CUDAGraphMode.FULL

         assert not self.splitting_ops_contain_attention(), (
-            "attention ops should not be in splitting_ops "
-            "when enable_attn_fusion is True"
+            "attention ops should not be in splitting_ops when fuse_attn_quant is True"
         )

     def splitting_ops_contain_attention(self) -> bool:
@@ -1008,7 +1092,7 @@ class CompilationConfig:
         self, uniform_decode_query_len: int, tensor_parallel_size: int
     ):
         multiple_of = uniform_decode_query_len
-        if tensor_parallel_size > 1 and self.pass_config.enable_sequence_parallelism:
+        if tensor_parallel_size > 1 and self.pass_config.enable_sp:
             multiple_of = max(uniform_decode_query_len, tensor_parallel_size)
         if (
             multiple_of % uniform_decode_query_len != 0
@@ -19,6 +19,10 @@ import torch
 from pydantic.fields import FieldInfo
 from typing_extensions import runtime_checkable

+from vllm.logger import init_logger
+
+logger = init_logger(__name__)
+
 if TYPE_CHECKING:
     from _typeshed import DataclassInstance
 else:
@@ -293,3 +297,28 @@ def get_hash_factors(config: ConfigT, ignored_factors: set[str]) -> dict[str, ob
 def hash_factors(items: dict[str, object]) -> str:
     """Return a SHA-256 hex digest of the canonical items structure."""
     return hashlib.sha256(json.dumps(items, sort_keys=True).encode()).hexdigest()
+
+
+def handle_deprecated(
+    config: ConfigT,
+    old_name: str,
+    new_name_or_names: str | list[str],
+    removal_version: str,
+) -> None:
+    old_val = getattr(config, old_name)
+    if old_val is None:
+        return
+
+    if isinstance(new_name_or_names, str):
+        new_names = [new_name_or_names]
+    else:
+        new_names = new_name_or_names
+
+    msg = (
+        f"{old_name} is deprecated and will be removed in {removal_version}. "
+        f"Use {', '.join(new_names)} instead."
+    )
+    logger.warning(msg)
+
+    for new_name in new_names:
+        setattr(config, new_name, old_val)
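A minimal sketch of calling the new helper (the dataclass below is hypothetical and only for illustration; the real caller is `PassConfig.__post_init__` earlier in this diff):

```python
from dataclasses import dataclass


@dataclass
class DemoConfig:  # hypothetical config, not part of vLLM
    old_flag: bool | None = None
    new_flag: bool | None = None

    def __post_init__(self) -> None:
        # Logs "old_flag is deprecated ..." and copies the value onto
        # new_flag whenever old_flag was explicitly set.
        handle_deprecated(self, "old_flag", "new_flag", "v0.13.0")


cfg = DemoConfig(old_flag=True)
assert cfg.new_flag is True
```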
@@ -83,22 +83,33 @@ IS_DENSE = False
 # See https://github.com/vllm-project/vllm/issues/25689.


-def enable_fusion(cfg: "VllmConfig") -> bool:
-    """Returns True if RMS norm or quant FP8 is enabled."""
+def enable_norm_fusion(cfg: "VllmConfig") -> bool:
+    """Enable if either RMS norm or quant FP8 custom op is active;
+    otherwise Inductor handles fusion."""
+
     return cfg.compilation_config.is_custom_op_enabled(
         "rms_norm"
     ) or cfg.compilation_config.is_custom_op_enabled("quant_fp8")


+def enable_act_fusion(cfg: "VllmConfig") -> bool:
+    """Enable if either SiLU+Mul or quant FP8 custom op is active;
+    otherwise Inductor handles fusion."""
+    return cfg.compilation_config.is_custom_op_enabled(
+        "silu_and_mul"
+    ) or cfg.compilation_config.is_custom_op_enabled("quant_fp8")
+
+
 OPTIMIZATION_LEVEL_00 = {
     "compilation_config": {
         "pass_config": {
-            "enable_noop": False,
-            "enable_fusion": False,
-            "enable_fi_allreduce_fusion": False,
-            "enable_attn_fusion": False,
-            "enable_sequence_parallelism": False,
-            "enable_async_tp": False,
+            "eliminate_noops": False,
+            "fuse_norm_quant": False,
+            "fuse_act_quant": False,
+            "fuse_allreduce_rms": False,
+            "fuse_attn_quant": False,
+            "enable_sp": False,
+            "fuse_gemm_comms": False,
         },
         "cudagraph_mode": CUDAGraphMode.NONE,
         "use_inductor_graph_partition": False,
@@ -107,12 +118,13 @@ OPTIMIZATION_LEVEL_00 = {
 OPTIMIZATION_LEVEL_01 = {
     "compilation_config": {
         "pass_config": {
-            "enable_noop": True,
-            "enable_fusion": enable_fusion,
-            "enable_fi_allreduce_fusion": False,
-            "enable_attn_fusion": False,
-            "enable_sequence_parallelism": False,
-            "enable_async_tp": False,
+            "eliminate_noops": True,
+            "fuse_norm_quant": enable_norm_fusion,
+            "fuse_act_quant": enable_act_fusion,
+            "fuse_allreduce_rms": False,
+            "fuse_attn_quant": False,
+            "enable_sp": False,
+            "fuse_gemm_comms": False,
         },
         "cudagraph_mode": CUDAGraphMode.PIECEWISE,
         "use_inductor_graph_partition": False,
@@ -121,12 +133,13 @@ OPTIMIZATION_LEVEL_01 = {
 OPTIMIZATION_LEVEL_02 = {
     "compilation_config": {
         "pass_config": {
-            "enable_noop": True,
-            "enable_fusion": enable_fusion,
-            "enable_fi_allreduce_fusion": False,
-            "enable_attn_fusion": IS_QUANTIZED,
-            "enable_sequence_parallelism": IS_DENSE,
-            "enable_async_tp": IS_DENSE,
+            "eliminate_noops": True,
+            "fuse_norm_quant": enable_norm_fusion,
+            "fuse_act_quant": enable_act_fusion,
+            "fuse_allreduce_rms": False,
+            "fuse_attn_quant": IS_QUANTIZED,
+            "enable_sp": IS_DENSE,
+            "fuse_gemm_comms": IS_DENSE,
         },
         "cudagraph_mode": CUDAGraphMode.FULL_AND_PIECEWISE,
         "use_inductor_graph_partition": False,
@@ -135,12 +148,13 @@ OPTIMIZATION_LEVEL_02 = {
 OPTIMIZATION_LEVEL_03 = {
     "compilation_config": {
         "pass_config": {
-            "enable_noop": True,
-            "enable_fusion": enable_fusion,
-            "enable_fi_allreduce_fusion": False,
-            "enable_attn_fusion": IS_QUANTIZED,
-            "enable_sequence_parallelism": IS_DENSE,
-            "enable_async_tp": IS_DENSE,
+            "eliminate_noops": True,
+            "fuse_norm_quant": enable_norm_fusion,
+            "fuse_act_quant": enable_act_fusion,
+            "fuse_allreduce_rms": False,
+            "fuse_attn_quant": IS_QUANTIZED,
+            "enable_sp": IS_DENSE,
+            "fuse_gemm_comms": IS_DENSE,
         },
         "cudagraph_mode": CUDAGraphMode.FULL_AND_PIECEWISE,
         "use_inductor_graph_partition": False,
@@ -645,9 +659,9 @@ class VllmConfig:

         # async tp is built on top of sequence parallelism
         # and requires it to be enabled.
-        if self.compilation_config.pass_config.enable_async_tp:
-            self.compilation_config.pass_config.enable_sequence_parallelism = True
-        if self.compilation_config.pass_config.enable_sequence_parallelism:
+        if self.compilation_config.pass_config.fuse_gemm_comms:
+            self.compilation_config.pass_config.enable_sp = True
+        if self.compilation_config.pass_config.enable_sp:
             if "-rms_norm" in self.compilation_config.custom_ops:
                 logger.warning(
                     "RMS norm force disabled, sequence parallelism might break"
@@ -797,7 +811,7 @@ class VllmConfig:
         # Do this after all the updates to compilation_config.mode
         self.compilation_config.set_splitting_ops_for_v1()

-        if self.compilation_config.pass_config.enable_sequence_parallelism:
+        if self.compilation_config.pass_config.enable_sp:
             # With pipeline parallelism or dynamo partitioning,
             # native rms norm tracing errors due to incorrect residual shape.
             # Use custom rms norm to unblock. In the future,
@@ -1062,7 +1076,7 @@ class VllmConfig:

         if (
             self.parallel_config.tensor_parallel_size > 1
-            and self.compilation_config.pass_config.enable_sequence_parallelism
+            and self.compilation_config.pass_config.enable_sp
         ):
             cudagraph_capture_sizes = self.update_sizes_for_sequence_parallelism(
                 cudagraph_capture_sizes
@@ -2417,10 +2417,7 @@ class GPUModelRunner(
         # Pad tokens to multiple of tensor_parallel_size when
         # enabled collective fusion for SP
         tp_size = self.vllm_config.parallel_config.tensor_parallel_size
-        if (
-            self.compilation_config.pass_config.enable_sequence_parallelism
-            and tp_size > 1
-        ):
+        if self.compilation_config.pass_config.enable_sp and tp_size > 1:
             return round_up(num_scheduled_tokens, tp_size)
         return num_scheduled_tokens

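Worked example of the padding above (illustrative arithmetic only; `round_up` is re-defined here to mirror vLLM's helper rather than imported):

```python
def round_up(x: int, multiple: int) -> int:
    # Smallest multiple of `multiple` that is >= x.
    return ((x + multiple - 1) // multiple) * multiple


# With enable_sp=True and tp_size=4, 10 scheduled tokens pad up to 12.
assert round_up(10, 4) == 12
```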
@@ -552,7 +552,7 @@ class Worker(WorkerBase):

         if (
             parallel_config.pipeline_parallel_size > 1
-            and compilation_config.pass_config.enable_sequence_parallelism
+            and compilation_config.pass_config.enable_sp
             and forward_pass
         ):
             # currently only supported by V1 GPUModelRunner
@@ -342,7 +342,7 @@ def is_residual_scattered_for_sp(
       partition), SP is always applied
     - Otherwise, SP is only applied for specific shapes in compile_sizes
     """
-    if not vllm_config.compilation_config.pass_config.enable_sequence_parallelism:
+    if not vllm_config.compilation_config.pass_config.enable_sp:
        return False

    tp = vllm_config.parallel_config.tensor_parallel_size