mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-06-11 09:42:16 +08:00
[compile] Enable sequence parallelism matching w/o custom ops enabled (#27126)
Signed-off-by: angelayi <yiangela7@gmail.com> Signed-off-by: Luka Govedič <ProExpertProg@users.noreply.github.com> Signed-off-by: ProExpertProg <lgovedic@redhat.com> Co-authored-by: Luka Govedič <lgovedic@redhat.com> Co-authored-by: Luka Govedič <ProExpertProg@users.noreply.github.com> Co-authored-by: Luka Govedič <luka.govedic@gmail.com>
This commit is contained in:
parent
173b356abf
commit
f36292dbee
@ -478,10 +478,11 @@ steps:
|
|||||||
- vllm/
|
- vllm/
|
||||||
- tests/compile
|
- tests/compile
|
||||||
commands:
|
commands:
|
||||||
|
# fp8 kv scales not supported on sm89, tested on Blackwell instead
|
||||||
- pytest -v -s compile/test_full_graph.py -k 'not test_fp8_kv_scale_compile'
|
- pytest -v -s compile/test_full_graph.py -k 'not test_fp8_kv_scale_compile'
|
||||||
# Limit to no custom ops to reduce running time
|
# Limit to no custom ops to reduce running time
|
||||||
# Wrap with quotes to escape yaml and avoid starting -k string with a -
|
# Wrap with quotes to escape yaml and avoid starting -k string with a -
|
||||||
- "pytest -v -s compile/test_fusions_e2e.py -k 'TRITON and -quant_fp8'"
|
- "pytest -v -s compile/test_fusions_e2e.py -k 'TRITON and not +quant_fp8 and not Llama-4'"
|
||||||
|
|
||||||
- label: Cudagraph test
|
- label: Cudagraph test
|
||||||
timeout_in_minutes: 20
|
timeout_in_minutes: 20
|
||||||
@ -925,7 +926,7 @@ steps:
|
|||||||
- pytest -v -s tests/kernels/moe/test_ocp_mx_moe.py
|
- pytest -v -s tests/kernels/moe/test_ocp_mx_moe.py
|
||||||
- pytest -v -s tests/kernels/moe/test_flashinfer.py
|
- pytest -v -s tests/kernels/moe/test_flashinfer.py
|
||||||
|
|
||||||
- label: Blackwell Fusion Tests # 30 min
|
- label: Blackwell Fusion & Compile Tests # 30 min
|
||||||
timeout_in_minutes: 40
|
timeout_in_minutes: 40
|
||||||
working_dir: "/vllm-workspace/"
|
working_dir: "/vllm-workspace/"
|
||||||
gpu: b200
|
gpu: b200
|
||||||
@ -946,7 +947,9 @@ steps:
|
|||||||
- pytest -v -s tests/compile/test_fusion_all_reduce.py
|
- pytest -v -s tests/compile/test_fusion_all_reduce.py
|
||||||
# Limit to Inductor partition, no custom ops, and allreduce & attn fusion to reduce running time
|
# Limit to Inductor partition, no custom ops, and allreduce & attn fusion to reduce running time
|
||||||
# Wrap with quotes to escape yaml
|
# Wrap with quotes to escape yaml
|
||||||
- "pytest -v -s tests/compile/test_fusions_e2e.py::test_tp2_attn_quant_allreduce_rmsnorm -k 'True and Llama-3.1 and -quant_fp8 and -rms_norm'"
|
- "pytest -v -s tests/compile/test_fusions_e2e.py::test_tp2_attn_quant_allreduce_rmsnorm -k 'True and not +quant_fp8 and not +rms_norm'"
|
||||||
|
# test_fp8_kv_scale_compile requires FlashAttention (not supported on default L4/L40)
|
||||||
|
- pytest -v -s tests/compile/test_full_graph.py::test_fp8_kv_scale_compile
|
||||||
|
|
||||||
- label: Blackwell Fusion E2E Tests # 30 min
|
- label: Blackwell Fusion E2E Tests # 30 min
|
||||||
timeout_in_minutes: 40
|
timeout_in_minutes: 40
|
||||||
@ -969,8 +972,6 @@ steps:
|
|||||||
- nvidia-smi
|
- nvidia-smi
|
||||||
# Run all e2e fusion tests
|
# Run all e2e fusion tests
|
||||||
- pytest -v -s tests/compile/test_fusions_e2e.py
|
- pytest -v -s tests/compile/test_fusions_e2e.py
|
||||||
# test_fp8_kv_scale_compile requires FlashAttention (not supported on default L4/L40)
|
|
||||||
- pytest -v -s tests/compile/test_full_graph.py::test_fp8_kv_scale_compile
|
|
||||||
|
|
||||||
- label: Blackwell GPT-OSS Eval
|
- label: Blackwell GPT-OSS Eval
|
||||||
timeout_in_minutes: 60
|
timeout_in_minutes: 60
|
||||||
@ -1266,7 +1267,8 @@ steps:
|
|||||||
- pytest -v -s tests/compile/test_async_tp.py
|
- pytest -v -s tests/compile/test_async_tp.py
|
||||||
- pytest -v -s tests/compile/test_sequence_parallelism.py
|
- pytest -v -s tests/compile/test_sequence_parallelism.py
|
||||||
- pytest -v -s tests/compile/test_fusion_all_reduce.py
|
- pytest -v -s tests/compile/test_fusion_all_reduce.py
|
||||||
- pytest -v -s tests/compile/test_fusions_e2e.py::test_tp2_attn_quant_allreduce_rmsnorm
|
- "pytest -v -s tests/compile/test_fusions_e2e.py -k 'not Llama-4'"
|
||||||
|
- pytest -v -s tests/distributed/test_sequence_parallel.py
|
||||||
- pytest -v -s tests/distributed/test_context_parallel.py
|
- pytest -v -s tests/distributed/test_context_parallel.py
|
||||||
- CUDA_VISIBLE_DEVICES=1,2 VLLM_ALL2ALL_BACKEND=deepep_high_throughput VLLM_USE_DEEP_GEMM=1 VLLM_LOGGING_LEVEL=DEBUG python3 examples/offline_inference/data_parallel.py --model Qwen/Qwen1.5-MoE-A2.7B --tp-size=1 --dp-size=2 --max-model-len 2048
|
- CUDA_VISIBLE_DEVICES=1,2 VLLM_ALL2ALL_BACKEND=deepep_high_throughput VLLM_USE_DEEP_GEMM=1 VLLM_LOGGING_LEVEL=DEBUG python3 examples/offline_inference/data_parallel.py --model Qwen/Qwen1.5-MoE-A2.7B --tp-size=1 --dp-size=2 --max-model-len 2048
|
||||||
- pytest -v -s tests/v1/distributed/test_dbo.py
|
- pytest -v -s tests/v1/distributed/test_dbo.py
|
||||||
|
|||||||
@ -20,13 +20,22 @@ from vllm.utils.torch_utils import is_torch_equal_or_newer
|
|||||||
|
|
||||||
from ..utils import flat_product, multi_gpu_test
|
from ..utils import flat_product, multi_gpu_test
|
||||||
|
|
||||||
|
is_blackwell = lambda: current_platform.is_device_capability(100)
|
||||||
|
"""Are we running on Blackwell, a lot of tests depend on it"""
|
||||||
|
|
||||||
|
|
||||||
|
class Matches(NamedTuple):
|
||||||
|
attention_fusion: int = 0
|
||||||
|
allreduce_fusion: int = 0
|
||||||
|
sequence_parallel: int = 0
|
||||||
|
async_tp: int = 0
|
||||||
|
|
||||||
|
|
||||||
class ModelBackendTestCase(NamedTuple):
|
class ModelBackendTestCase(NamedTuple):
|
||||||
model_name: str
|
model_name: str
|
||||||
model_kwargs: dict[str, Any]
|
model_kwargs: dict[str, Any]
|
||||||
backend: AttentionBackendEnum
|
backend: AttentionBackendEnum
|
||||||
attention_fusions: int
|
matches: Matches
|
||||||
allreduce_fusions: int | None = None
|
|
||||||
|
|
||||||
|
|
||||||
MODELS_FP8: list[ModelBackendTestCase] = []
|
MODELS_FP8: list[ModelBackendTestCase] = []
|
||||||
@ -38,17 +47,33 @@ if current_platform.is_cuda():
|
|||||||
ModelBackendTestCase(
|
ModelBackendTestCase(
|
||||||
# Use smaller model for L40s in CI
|
# Use smaller model for L40s in CI
|
||||||
model_name="RedHatAI/Meta-Llama-3.1-8B-Instruct-FP8",
|
model_name="RedHatAI/Meta-Llama-3.1-8B-Instruct-FP8",
|
||||||
model_kwargs=dict(max_model_len=1024),
|
# TODO while llama4 is broken, use FLASHINFER for llama3 on Blackwell
|
||||||
backend=AttentionBackendEnum.TRITON_ATTN,
|
# so FI attention+fp8_quant is at least tested once
|
||||||
attention_fusions=32,
|
model_kwargs=dict(max_model_len=1024, kv_cache_dtype="fp8"),
|
||||||
allreduce_fusions=65,
|
backend=AttentionBackendEnum.FLASHINFER
|
||||||
|
if is_blackwell()
|
||||||
|
else AttentionBackendEnum.TRITON_ATTN,
|
||||||
|
matches=Matches(
|
||||||
|
attention_fusion=32,
|
||||||
|
allreduce_fusion=65,
|
||||||
|
sequence_parallel=65,
|
||||||
|
async_tp=128,
|
||||||
|
),
|
||||||
),
|
),
|
||||||
ModelBackendTestCase(
|
ModelBackendTestCase(
|
||||||
model_name="nvidia/Llama-4-Scout-17B-16E-Instruct-FP8",
|
model_name="nvidia/Llama-4-Scout-17B-16E-Instruct-FP8",
|
||||||
model_kwargs=dict(max_model_len=1024, kv_cache_dtype="fp8"),
|
model_kwargs=dict(max_model_len=1024, kv_cache_dtype="fp8"),
|
||||||
backend=AttentionBackendEnum.FLASHINFER,
|
# TODO FlashInfer attn broken on Hopper with kvcache=fp8:
|
||||||
attention_fusions=48,
|
# https://github.com/vllm-project/vllm/issues/28568
|
||||||
allreduce_fusions=96,
|
# TODO FlashInfer attn broken on Blackwell for llama4:
|
||||||
|
# https://github.com/vllm-project/vllm/issues/28604
|
||||||
|
backend=AttentionBackendEnum.TRITON_ATTN,
|
||||||
|
matches=Matches(
|
||||||
|
attention_fusion=48,
|
||||||
|
allreduce_fusion=96,
|
||||||
|
sequence_parallel=96,
|
||||||
|
async_tp=95, # mlp is moe, no fusion there
|
||||||
|
),
|
||||||
),
|
),
|
||||||
]
|
]
|
||||||
|
|
||||||
@ -57,8 +82,12 @@ if current_platform.is_cuda():
|
|||||||
model_name="nvidia/Llama-3.1-8B-Instruct-FP4",
|
model_name="nvidia/Llama-3.1-8B-Instruct-FP4",
|
||||||
model_kwargs=dict(max_model_len=1024, kv_cache_dtype="fp8"),
|
model_kwargs=dict(max_model_len=1024, kv_cache_dtype="fp8"),
|
||||||
backend=AttentionBackendEnum.FLASHINFER,
|
backend=AttentionBackendEnum.FLASHINFER,
|
||||||
attention_fusions=32,
|
matches=Matches(
|
||||||
allreduce_fusions=65,
|
attention_fusion=32,
|
||||||
|
allreduce_fusion=65,
|
||||||
|
sequence_parallel=65,
|
||||||
|
async_tp=128,
|
||||||
|
),
|
||||||
),
|
),
|
||||||
]
|
]
|
||||||
|
|
||||||
@ -68,15 +97,23 @@ if current_platform.is_cuda():
|
|||||||
model_name="meta-llama/Llama-3.1-8B-Instruct",
|
model_name="meta-llama/Llama-3.1-8B-Instruct",
|
||||||
model_kwargs=dict(max_model_len=1024),
|
model_kwargs=dict(max_model_len=1024),
|
||||||
backend=AttentionBackendEnum.TRITON_ATTN,
|
backend=AttentionBackendEnum.TRITON_ATTN,
|
||||||
attention_fusions=0,
|
matches=Matches(
|
||||||
allreduce_fusions=65,
|
attention_fusion=0,
|
||||||
|
allreduce_fusion=65,
|
||||||
|
sequence_parallel=65,
|
||||||
|
async_tp=128,
|
||||||
|
),
|
||||||
),
|
),
|
||||||
ModelBackendTestCase(
|
ModelBackendTestCase(
|
||||||
model_name="Qwen/Qwen3-30B-A3B",
|
model_name="Qwen/Qwen3-30B-A3B",
|
||||||
model_kwargs=dict(max_model_len=1024),
|
model_kwargs=dict(max_model_len=1024),
|
||||||
backend=AttentionBackendEnum.TRITON_ATTN,
|
backend=AttentionBackendEnum.TRITON_ATTN,
|
||||||
attention_fusions=0,
|
matches=Matches(
|
||||||
allreduce_fusions=97,
|
attention_fusion=0,
|
||||||
|
allreduce_fusion=97,
|
||||||
|
sequence_parallel=97,
|
||||||
|
async_tp=96, # MLP is MoE, half the fusions of dense
|
||||||
|
),
|
||||||
),
|
),
|
||||||
]
|
]
|
||||||
|
|
||||||
@ -86,19 +123,19 @@ elif current_platform.is_rocm():
|
|||||||
model_name="amd/Llama-3.1-8B-Instruct-FP8-KV",
|
model_name="amd/Llama-3.1-8B-Instruct-FP8-KV",
|
||||||
model_kwargs=dict(max_model_len=1024),
|
model_kwargs=dict(max_model_len=1024),
|
||||||
backend=AttentionBackendEnum.TRITON_ATTN,
|
backend=AttentionBackendEnum.TRITON_ATTN,
|
||||||
attention_fusions=32,
|
matches=Matches(attention_fusion=32),
|
||||||
),
|
),
|
||||||
ModelBackendTestCase(
|
ModelBackendTestCase(
|
||||||
model_name="amd/Llama-3.1-8B-Instruct-FP8-KV",
|
model_name="amd/Llama-3.1-8B-Instruct-FP8-KV",
|
||||||
model_kwargs=dict(max_model_len=1024),
|
model_kwargs=dict(max_model_len=1024),
|
||||||
backend=AttentionBackendEnum.ROCM_ATTN,
|
backend=AttentionBackendEnum.ROCM_ATTN,
|
||||||
attention_fusions=32,
|
matches=Matches(attention_fusion=32),
|
||||||
),
|
),
|
||||||
ModelBackendTestCase(
|
ModelBackendTestCase(
|
||||||
model_name="amd/Llama-3.1-8B-Instruct-FP8-KV",
|
model_name="amd/Llama-3.1-8B-Instruct-FP8-KV",
|
||||||
model_kwargs=dict(max_model_len=1024),
|
model_kwargs=dict(max_model_len=1024),
|
||||||
backend=AttentionBackendEnum.ROCM_AITER_UNIFIED_ATTN,
|
backend=AttentionBackendEnum.ROCM_AITER_UNIFIED_ATTN,
|
||||||
attention_fusions=32,
|
matches=Matches(attention_fusion=32),
|
||||||
),
|
),
|
||||||
]
|
]
|
||||||
|
|
||||||
@ -106,8 +143,7 @@ CUSTOM_OPS_FP8 = ["-quant_fp8", "+quant_fp8"]
|
|||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
"model_name, model_kwargs, backend, "
|
"model_name, model_kwargs, backend, matches, custom_ops",
|
||||||
"attention_fusions, allreduce_fusions, custom_ops",
|
|
||||||
# Test attention+quant_fp8 fusion with custom and torch impls of QuantFP8
|
# Test attention+quant_fp8 fusion with custom and torch impls of QuantFP8
|
||||||
list(flat_product(MODELS_FP8, CUSTOM_OPS_FP8))
|
list(flat_product(MODELS_FP8, CUSTOM_OPS_FP8))
|
||||||
# quant_fp4 only has the custom impl
|
# quant_fp4 only has the custom impl
|
||||||
@ -118,15 +154,14 @@ def test_attn_quant(
|
|||||||
model_name: str,
|
model_name: str,
|
||||||
model_kwargs: dict[str, Any],
|
model_kwargs: dict[str, Any],
|
||||||
backend: AttentionBackendEnum,
|
backend: AttentionBackendEnum,
|
||||||
attention_fusions: int,
|
matches: Matches,
|
||||||
allreduce_fusions: int,
|
|
||||||
custom_ops: str,
|
custom_ops: str,
|
||||||
inductor_graph_partition: bool,
|
inductor_graph_partition: bool,
|
||||||
caplog_mp_spawn,
|
caplog_mp_spawn,
|
||||||
monkeypatch,
|
monkeypatch,
|
||||||
):
|
):
|
||||||
if backend == AttentionBackendEnum.FLASHINFER and (
|
if backend == AttentionBackendEnum.FLASHINFER and (
|
||||||
not current_platform.is_device_capability((10, 0)) or not has_flashinfer()
|
not is_blackwell() or not has_flashinfer()
|
||||||
):
|
):
|
||||||
pytest.skip("FlashInfer attn fusion requires Blackwell and flashinfer")
|
pytest.skip("FlashInfer attn fusion requires Blackwell and flashinfer")
|
||||||
if inductor_graph_partition and not is_torch_equal_or_newer("2.9.0.dev"):
|
if inductor_graph_partition and not is_torch_equal_or_newer("2.9.0.dev"):
|
||||||
@ -169,12 +204,12 @@ def test_attn_quant(
|
|||||||
with caplog_mp_spawn(logging.DEBUG) as log_holder:
|
with caplog_mp_spawn(logging.DEBUG) as log_holder:
|
||||||
run_model(compilation_config, model_name, **model_kwargs)
|
run_model(compilation_config, model_name, **model_kwargs)
|
||||||
|
|
||||||
matches = re.findall(
|
log_matches = re.findall(
|
||||||
r"fusion_attn.py:\d+] Fused quant onto (\d+) attention nodes",
|
r"fusion_attn.py:\d+] Fused quant onto (\d+) attention nodes",
|
||||||
log_holder.text,
|
log_holder.text,
|
||||||
)
|
)
|
||||||
assert len(matches) == 1, log_holder.text
|
assert len(log_matches) == 1, log_holder.text
|
||||||
assert int(matches[0]) == attention_fusions
|
assert int(log_matches[0]) == matches.attention_fusion
|
||||||
|
|
||||||
|
|
||||||
CUSTOM_OPS_RMS_NORM = ["-rms_norm", "+rms_norm"]
|
CUSTOM_OPS_RMS_NORM = ["-rms_norm", "+rms_norm"]
|
||||||
@ -187,8 +222,7 @@ def custom_ops_product(*custom_ops_lists: list[str]) -> Iterable[str]:
|
|||||||
|
|
||||||
@multi_gpu_test(num_gpus=2)
|
@multi_gpu_test(num_gpus=2)
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
"model_name, model_kwargs, backend, "
|
"model_name, model_kwargs, backend, matches, custom_ops",
|
||||||
"attention_fusions, allreduce_fusions, custom_ops",
|
|
||||||
# Toggle RMSNorm and QuantFP8 for FP8 models
|
# Toggle RMSNorm and QuantFP8 for FP8 models
|
||||||
list(
|
list(
|
||||||
flat_product(
|
flat_product(
|
||||||
@ -209,8 +243,7 @@ def test_tp2_attn_quant_allreduce_rmsnorm(
|
|||||||
model_name: str,
|
model_name: str,
|
||||||
model_kwargs: dict,
|
model_kwargs: dict,
|
||||||
backend: AttentionBackendEnum,
|
backend: AttentionBackendEnum,
|
||||||
attention_fusions: int,
|
matches: Matches,
|
||||||
allreduce_fusions: int,
|
|
||||||
custom_ops: str,
|
custom_ops: str,
|
||||||
inductor_graph_partition: bool,
|
inductor_graph_partition: bool,
|
||||||
caplog_mp_spawn,
|
caplog_mp_spawn,
|
||||||
@ -219,6 +252,13 @@ def test_tp2_attn_quant_allreduce_rmsnorm(
|
|||||||
if inductor_graph_partition and not is_torch_equal_or_newer("2.9.0.dev"):
|
if inductor_graph_partition and not is_torch_equal_or_newer("2.9.0.dev"):
|
||||||
pytest.skip("Inductor graph partition requires torch>=2.9")
|
pytest.skip("Inductor graph partition requires torch>=2.9")
|
||||||
|
|
||||||
|
if "fp4" in model_name.lower() and not is_blackwell():
|
||||||
|
pytest.skip("NVFP4 quant requires Blackwell")
|
||||||
|
|
||||||
|
if backend == AttentionBackendEnum.FLASHINFER and not is_blackwell():
|
||||||
|
# FlashInfer attn fusion requires Blackwell
|
||||||
|
matches = matches._replace(attention_fusion=0)
|
||||||
|
|
||||||
custom_ops_list = custom_ops.split(",") if custom_ops else []
|
custom_ops_list = custom_ops.split(",") if custom_ops else []
|
||||||
|
|
||||||
if inductor_graph_partition:
|
if inductor_graph_partition:
|
||||||
@ -258,23 +298,135 @@ def test_tp2_attn_quant_allreduce_rmsnorm(
|
|||||||
run_model(
|
run_model(
|
||||||
compilation_config, model_name, tensor_parallel_size=2, **model_kwargs
|
compilation_config, model_name, tensor_parallel_size=2, **model_kwargs
|
||||||
)
|
)
|
||||||
matches = re.findall(
|
log_matches = re.findall(
|
||||||
r"fusion_attn.py:\d+] Fused quant onto (\d+) attention nodes",
|
r"fusion_attn.py:\d+] Fused quant onto (\d+) attention nodes",
|
||||||
log_holder.text,
|
log_holder.text,
|
||||||
)
|
)
|
||||||
assert len(matches) == 2, log_holder.text
|
assert len(log_matches) == 2, log_holder.text
|
||||||
|
|
||||||
assert int(matches[0]) == attention_fusions
|
assert int(log_matches[0]) == matches.attention_fusion
|
||||||
assert int(matches[1]) == attention_fusions
|
assert int(log_matches[1]) == matches.attention_fusion
|
||||||
|
|
||||||
matches = re.findall(
|
log_matches = re.findall(
|
||||||
r"collective_fusion.py:\d+] Replaced (\d+) patterns",
|
r"collective_fusion.py:\d+] Replaced (\d+) patterns",
|
||||||
log_holder.text,
|
log_holder.text,
|
||||||
)
|
)
|
||||||
assert len(matches) == 2, log_holder.text
|
assert len(log_matches) == 2, log_holder.text
|
||||||
|
|
||||||
assert int(matches[0]) == allreduce_fusions
|
assert int(log_matches[0]) == matches.allreduce_fusion
|
||||||
assert int(matches[1]) == allreduce_fusions
|
assert int(log_matches[1]) == matches.allreduce_fusion
|
||||||
|
|
||||||
|
|
||||||
|
@multi_gpu_test(num_gpus=2)
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"model_name, model_kwargs, backend, matches, custom_ops",
|
||||||
|
# Toggle RMSNorm and QuantFP8 for FP8 models
|
||||||
|
list(
|
||||||
|
flat_product(
|
||||||
|
MODELS_FP8, custom_ops_product(CUSTOM_OPS_FP8, CUSTOM_OPS_RMS_NORM)
|
||||||
|
)
|
||||||
|
)
|
||||||
|
# Toggle RMSNorm for FP4 models and unquantized models
|
||||||
|
+ list(flat_product(MODELS_FP4 + MODELS, CUSTOM_OPS_RMS_NORM)),
|
||||||
|
)
|
||||||
|
@pytest.mark.parametrize("inductor_graph_partition", [True, False])
|
||||||
|
@pytest.mark.skipif(
|
||||||
|
not current_platform.is_cuda(),
|
||||||
|
reason="sequence parallel only tested on CUDA",
|
||||||
|
)
|
||||||
|
def test_tp2_attn_quant_async_tp(
|
||||||
|
model_name: str,
|
||||||
|
model_kwargs: dict,
|
||||||
|
backend: AttentionBackendEnum,
|
||||||
|
matches: Matches,
|
||||||
|
custom_ops: str,
|
||||||
|
inductor_graph_partition: bool,
|
||||||
|
caplog_mp_spawn,
|
||||||
|
monkeypatch,
|
||||||
|
):
|
||||||
|
if is_blackwell():
|
||||||
|
# TODO: https://github.com/vllm-project/vllm/issues/27893
|
||||||
|
pytest.skip("Blackwell is not supported for AsyncTP pass")
|
||||||
|
|
||||||
|
if inductor_graph_partition and not is_torch_equal_or_newer("2.9.0.dev"):
|
||||||
|
pytest.skip("Inductor graph partition requires torch>=2.9")
|
||||||
|
|
||||||
|
if "fp4" in model_name.lower() and not is_blackwell():
|
||||||
|
pytest.skip("NVFP4 quant requires Blackwell")
|
||||||
|
|
||||||
|
if backend == AttentionBackendEnum.FLASHINFER:
|
||||||
|
if not has_flashinfer():
|
||||||
|
pytest.skip("FlashInfer backend requires flashinfer installed")
|
||||||
|
if not is_blackwell():
|
||||||
|
# FlashInfer attn fusion requires Blackwell
|
||||||
|
matches = matches._replace(attention_fusion=0)
|
||||||
|
|
||||||
|
custom_ops_list = custom_ops.split(",") if custom_ops else []
|
||||||
|
|
||||||
|
if inductor_graph_partition:
|
||||||
|
mode = CUDAGraphMode.FULL_AND_PIECEWISE
|
||||||
|
splitting_ops: list[str] | None = None
|
||||||
|
else:
|
||||||
|
mode = CUDAGraphMode.FULL_DECODE_ONLY
|
||||||
|
splitting_ops = []
|
||||||
|
|
||||||
|
# Disable compile cache to make sure custom passes run.
|
||||||
|
# Otherwise, we can't verify fusion happened through the logs.
|
||||||
|
monkeypatch.setenv("VLLM_DISABLE_COMPILE_CACHE", "1")
|
||||||
|
|
||||||
|
# To capture subprocess logs, we need to know whether spawn or fork is used.
|
||||||
|
# Force spawn as it is more general.
|
||||||
|
monkeypatch.setenv("VLLM_WORKER_MULTIPROC_METHOD", "spawn")
|
||||||
|
monkeypatch.setenv("VLLM_ATTENTION_BACKEND", backend.name)
|
||||||
|
|
||||||
|
compilation_config = CompilationConfig(
|
||||||
|
# Testing properties
|
||||||
|
use_inductor_graph_partition=inductor_graph_partition,
|
||||||
|
cudagraph_mode=mode,
|
||||||
|
custom_ops=custom_ops_list,
|
||||||
|
splitting_ops=splitting_ops,
|
||||||
|
# Common
|
||||||
|
level=CompilationMode.VLLM_COMPILE,
|
||||||
|
pass_config=PassConfig(
|
||||||
|
enable_attn_fusion=True,
|
||||||
|
enable_noop=True,
|
||||||
|
enable_sequence_parallelism=True,
|
||||||
|
enable_async_tp=True,
|
||||||
|
),
|
||||||
|
# Inductor caches custom passes by default as well via uuid
|
||||||
|
inductor_compile_config={"force_disable_caches": True},
|
||||||
|
)
|
||||||
|
|
||||||
|
with caplog_mp_spawn(logging.DEBUG) as log_holder:
|
||||||
|
run_model(
|
||||||
|
compilation_config, model_name, tensor_parallel_size=2, **model_kwargs
|
||||||
|
)
|
||||||
|
log_matches = re.findall(
|
||||||
|
r"fusion_attn.py:\d+] Fused quant onto (\d+) attention nodes",
|
||||||
|
log_holder.text,
|
||||||
|
)
|
||||||
|
assert len(log_matches) == 2, log_holder.text
|
||||||
|
|
||||||
|
assert int(log_matches[0]) == matches.attention_fusion
|
||||||
|
assert int(log_matches[1]) == matches.attention_fusion
|
||||||
|
|
||||||
|
log_matches = re.findall(
|
||||||
|
r"sequence_parallelism.py:\d+] Replaced (\d+) patterns",
|
||||||
|
log_holder.text,
|
||||||
|
)
|
||||||
|
assert len(log_matches) == 2, log_holder.text
|
||||||
|
|
||||||
|
assert int(log_matches[0]) == matches.sequence_parallel
|
||||||
|
assert int(log_matches[1]) == matches.sequence_parallel
|
||||||
|
|
||||||
|
log_matches = re.findall(
|
||||||
|
r"collective_fusion.py:\d+] Replaced (\d+) patterns",
|
||||||
|
log_holder.text,
|
||||||
|
)
|
||||||
|
assert len(log_matches) == 2, log_holder.text
|
||||||
|
|
||||||
|
assert int(log_matches[0]) == matches.async_tp
|
||||||
|
assert int(log_matches[1]) == matches.async_tp
|
||||||
|
|
||||||
|
|
||||||
def run_model(compile_config: int | CompilationConfig, model: str, **model_kwargs):
|
def run_model(compile_config: int | CompilationConfig, model: str, **model_kwargs):
|
||||||
|
|||||||
@ -5,15 +5,15 @@ import pytest
|
|||||||
import torch
|
import torch
|
||||||
|
|
||||||
import vllm.envs as envs
|
import vllm.envs as envs
|
||||||
from vllm.compilation.fix_functionalization import FixFunctionalizationPass
|
|
||||||
from vllm.compilation.fusion import RMSNormQuantFusionPass
|
from vllm.compilation.fusion import RMSNormQuantFusionPass
|
||||||
from vllm.compilation.fx_utils import find_auto_fn, find_auto_fn_maybe, is_func
|
from vllm.compilation.fx_utils import find_auto_fn
|
||||||
from vllm.compilation.noop_elimination import NoOpEliminationPass
|
from vllm.compilation.noop_elimination import NoOpEliminationPass
|
||||||
from vllm.compilation.post_cleanup import PostCleanupPass
|
from vllm.compilation.post_cleanup import PostCleanupPass
|
||||||
from vllm.compilation.sequence_parallelism import SequenceParallelismPass
|
from vllm.compilation.sequence_parallelism import SequenceParallelismPass
|
||||||
from vllm.compilation.vllm_inductor_pass import VllmInductorPass
|
from vllm.compilation.vllm_inductor_pass import VllmInductorPass
|
||||||
from vllm.config import (
|
from vllm.config import (
|
||||||
CompilationConfig,
|
CompilationConfig,
|
||||||
|
CUDAGraphMode,
|
||||||
DeviceConfig,
|
DeviceConfig,
|
||||||
ModelConfig,
|
ModelConfig,
|
||||||
PassConfig,
|
PassConfig,
|
||||||
@ -27,6 +27,7 @@ from vllm.distributed.parallel_state import (
|
|||||||
initialize_model_parallel,
|
initialize_model_parallel,
|
||||||
)
|
)
|
||||||
from vllm.model_executor.layers.layernorm import RMSNorm
|
from vllm.model_executor.layers.layernorm import RMSNorm
|
||||||
|
from vllm.model_executor.layers.quantization.utils.quant_utils import GroupShape
|
||||||
from vllm.model_executor.layers.quantization.utils.w8a8_utils import Fp8LinearOp
|
from vllm.model_executor.layers.quantization.utils.w8a8_utils import Fp8LinearOp
|
||||||
from vllm.platforms import current_platform
|
from vllm.platforms import current_platform
|
||||||
from vllm.utils.system_utils import update_environment_variables
|
from vllm.utils.system_utils import update_environment_variables
|
||||||
@ -43,172 +44,157 @@ prompts = [
|
|||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
class TestModel(torch.nn.Module):
|
class TestAllReduceRMSNormModel(torch.nn.Module):
|
||||||
def __init__(self, hidden_size=16, intermediate_size=32):
|
def __init__(self, hidden_size=16, eps=1e-6):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.hidden_size = hidden_size
|
self.hidden_size = hidden_size
|
||||||
self.intermediate_size = intermediate_size
|
self.eps = eps
|
||||||
self.gate_proj = torch.nn.Parameter(
|
self.norm = [RMSNorm(hidden_size, eps) for i in range(4)]
|
||||||
torch.empty((intermediate_size, hidden_size))
|
self.w = [torch.rand(hidden_size, hidden_size) for _ in range(3)]
|
||||||
)
|
|
||||||
self.norm = RMSNorm(intermediate_size, 1e-05)
|
|
||||||
# Initialize weights
|
|
||||||
torch.nn.init.normal_(self.gate_proj, std=0.02)
|
|
||||||
|
|
||||||
def forward(self, hidden_states, residual):
|
def forward(self, x):
|
||||||
"""
|
z = torch.relu(x)
|
||||||
Forward pass implementing the operations in the FX graph
|
x = resid = tensor_model_parallel_all_reduce(z)
|
||||||
|
y = self.norm[0](x)
|
||||||
|
|
||||||
Args:
|
z2 = torch.mm(y, self.w[0])
|
||||||
hidden_states: Input tensor
|
x2 = tensor_model_parallel_all_reduce(z2)
|
||||||
residual: Residual tensor from previous layer
|
|
||||||
|
|
||||||
Returns:
|
y2, resid = self.norm[1](x2, resid)
|
||||||
Tuple containing the output tensor
|
|
||||||
"""
|
|
||||||
# Reshape input
|
|
||||||
view = hidden_states.reshape(-1, self.hidden_size)
|
|
||||||
|
|
||||||
# matrix multiplication
|
z3 = torch.mm(y2, self.w[1])
|
||||||
permute = self.gate_proj.permute(1, 0)
|
x3 = tensor_model_parallel_all_reduce(z3)
|
||||||
mm = torch.mm(view, permute)
|
|
||||||
|
|
||||||
# Tensor parallel all-reduce
|
y3, resid = self.norm[2](x3, resid)
|
||||||
all_reduce = tensor_model_parallel_all_reduce(mm)
|
|
||||||
|
|
||||||
# layer normalization
|
z4 = torch.mm(y3, self.w[2])
|
||||||
norm_output, residual_output = self.norm(all_reduce, residual)
|
x4 = tensor_model_parallel_all_reduce(z4)
|
||||||
|
|
||||||
return norm_output, residual_output
|
y4, resid = self.norm[3](x4, resid)
|
||||||
|
return y4
|
||||||
|
|
||||||
def ops_in_model_before(self):
|
def ops_in_model_before(self):
|
||||||
return [torch.ops.vllm.all_reduce.default]
|
return [torch.ops.vllm.all_reduce.default]
|
||||||
|
|
||||||
def ops_in_model_after(self):
|
def ops_in_model_after(self):
|
||||||
return [
|
return [
|
||||||
torch.ops.vllm.reduce_scatter.default,
|
|
||||||
torch.ops.vllm.all_gather.default,
|
torch.ops.vllm.all_gather.default,
|
||||||
|
torch.ops.vllm.reduce_scatter.default,
|
||||||
]
|
]
|
||||||
|
|
||||||
def ops_in_model(self):
|
def ops_in_model(self):
|
||||||
return [torch.ops._C.fused_add_rms_norm.default]
|
if RMSNorm.enabled():
|
||||||
|
return [
|
||||||
|
torch.ops._C.rms_norm.default,
|
||||||
|
torch.ops._C.fused_add_rms_norm.default,
|
||||||
|
]
|
||||||
|
else:
|
||||||
|
return []
|
||||||
|
|
||||||
|
|
||||||
class TestQuantModel(torch.nn.Module):
|
class TestAllReduceRMSNormStaticQuantFP8Model(torch.nn.Module):
|
||||||
def __init__(self, hidden_size=16, intermediate_size=32):
|
def __init__(self, hidden_size=16, eps=1e-6):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.hidden_size = hidden_size
|
|
||||||
self.intermediate_size = intermediate_size
|
|
||||||
self.vllm_config = get_current_vllm_config()
|
self.vllm_config = get_current_vllm_config()
|
||||||
self.gate_proj = torch.nn.Parameter(
|
self.hidden_size = hidden_size
|
||||||
torch.empty((intermediate_size, hidden_size)), requires_grad=False
|
self.eps = eps
|
||||||
)
|
self.norm = [RMSNorm(hidden_size, eps) for i in range(4)]
|
||||||
self.norm = RMSNorm(intermediate_size, 1e-05)
|
self.wscale = [torch.rand(1, dtype=torch.float32) for _ in range(3)]
|
||||||
# Initialize weights
|
self.w = [
|
||||||
torch.nn.init.normal_(self.gate_proj, std=0.02)
|
torch.rand(hidden_size, hidden_size)
|
||||||
|
.to(dtype=current_platform.fp8_dtype())
|
||||||
|
.t()
|
||||||
|
for _ in range(3)
|
||||||
|
]
|
||||||
|
|
||||||
self.fp8_linear = Fp8LinearOp(act_quant_static=True)
|
self.fp8_linear = Fp8LinearOp(
|
||||||
|
act_quant_static=True,
|
||||||
self.scale = torch.rand(1, dtype=torch.float32)
|
act_quant_group_shape=GroupShape.PER_TENSOR,
|
||||||
# Create a weight that is compatible with torch._scaled_mm,
|
|
||||||
# which expects a column-major layout.
|
|
||||||
self.w = torch.rand(hidden_size, intermediate_size).to(dtype=FP8_DTYPE).t()
|
|
||||||
self.wscale = torch.rand(1, dtype=torch.float32)
|
|
||||||
|
|
||||||
def forward(self, hidden_states, residual):
|
|
||||||
"""
|
|
||||||
Forward pass implementing the operations in the FX graph
|
|
||||||
|
|
||||||
Args:
|
|
||||||
hidden_states: Input tensor
|
|
||||||
residual: Residual tensor from previous layer
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Tuple containing the output tensor
|
|
||||||
"""
|
|
||||||
# Reshape input
|
|
||||||
view = hidden_states.reshape(-1, self.hidden_size)
|
|
||||||
|
|
||||||
# matrix multiplication
|
|
||||||
permute = self.gate_proj.permute(1, 0)
|
|
||||||
mm = torch.mm(view, permute)
|
|
||||||
|
|
||||||
# Tensor parallel all-reduce
|
|
||||||
all_reduce = tensor_model_parallel_all_reduce(mm)
|
|
||||||
|
|
||||||
# layer normalization
|
|
||||||
norm_output, residual_output = self.norm(all_reduce, residual)
|
|
||||||
|
|
||||||
# scaled_mm with static input quantization
|
|
||||||
fp8_linear_result = self.fp8_linear.apply(
|
|
||||||
norm_output,
|
|
||||||
self.w,
|
|
||||||
self.wscale,
|
|
||||||
input_scale=self.scale.to(norm_output.device),
|
|
||||||
)
|
)
|
||||||
|
|
||||||
return fp8_linear_result, residual_output
|
self.scale = [torch.rand(1, dtype=torch.float32) for _ in range(3)]
|
||||||
|
|
||||||
def ops_in_model_before(self):
|
def forward(self, hidden_states):
|
||||||
ops_to_remove = [torch.ops.vllm.all_reduce.default] # Always removed by SP
|
# avoid having graph input be an arg to a pattern directly
|
||||||
# The following are only removed if fusion happens
|
z = torch.relu(hidden_states)
|
||||||
if (
|
x = resid = tensor_model_parallel_all_reduce(z)
|
||||||
self.vllm_config
|
y = self.norm[0](x)
|
||||||
and self.vllm_config.compilation_config.pass_config.enable_fusion
|
|
||||||
):
|
z2 = self.fp8_linear.apply(
|
||||||
ops_to_remove.extend(
|
y, self.w[0], self.wscale[0], input_scale=self.scale[0]
|
||||||
[
|
)
|
||||||
torch.ops._C.fused_add_rms_norm.default,
|
|
||||||
torch.ops._C.static_scaled_fp8_quant.default,
|
x2 = tensor_model_parallel_all_reduce(z2)
|
||||||
]
|
y2, resid = self.norm[1](x2, resid)
|
||||||
)
|
|
||||||
return ops_to_remove
|
z3 = self.fp8_linear.apply(
|
||||||
|
y2, self.w[1], self.wscale[1], input_scale=self.scale[1]
|
||||||
|
)
|
||||||
|
|
||||||
|
x3 = tensor_model_parallel_all_reduce(z3)
|
||||||
|
y3, resid = self.norm[2](x3, resid) # use resid here
|
||||||
|
|
||||||
|
z4 = self.fp8_linear.apply(
|
||||||
|
y3, self.w[2], self.wscale[2], input_scale=self.scale[2]
|
||||||
|
)
|
||||||
|
x4 = tensor_model_parallel_all_reduce(z4)
|
||||||
|
y4, resid = self.norm[3](x4, resid) # use resid here
|
||||||
|
return y4
|
||||||
|
|
||||||
def ops_in_model_after(self):
|
def ops_in_model_after(self):
|
||||||
ops_to_add = [
|
return [
|
||||||
torch.ops.vllm.reduce_scatter.default,
|
|
||||||
torch.ops.vllm.all_gather.default,
|
torch.ops.vllm.all_gather.default,
|
||||||
|
torch.ops.vllm.reduce_scatter.default,
|
||||||
|
]
|
||||||
|
|
||||||
|
def ops_in_model_before(self):
|
||||||
|
return [
|
||||||
|
torch.ops.vllm.all_reduce.default,
|
||||||
]
|
]
|
||||||
# The following is only added if fusion happens
|
|
||||||
if (
|
|
||||||
self.vllm_config
|
|
||||||
and self.vllm_config.compilation_config.pass_config.enable_fusion
|
|
||||||
):
|
|
||||||
ops_to_add.append(torch.ops._C.fused_add_rms_norm_static_fp8_quant.default)
|
|
||||||
return ops_to_add
|
|
||||||
|
|
||||||
def ops_in_model(self):
|
def ops_in_model(self):
|
||||||
if (
|
if self.vllm_config.compilation_config.pass_config.enable_fusion:
|
||||||
self.vllm_config
|
|
||||||
and self.vllm_config.compilation_config.pass_config.enable_fusion
|
|
||||||
):
|
|
||||||
# If fusion happens, the fused op is the one
|
|
||||||
# we check for (de)functionalization
|
|
||||||
return [torch.ops._C.fused_add_rms_norm_static_fp8_quant.default]
|
return [torch.ops._C.fused_add_rms_norm_static_fp8_quant.default]
|
||||||
else:
|
elif RMSNorm.enabled():
|
||||||
# If no fusion, the original ops are checked
|
|
||||||
return [
|
return [
|
||||||
torch.ops._C.fused_add_rms_norm.default,
|
torch.ops._C.fused_add_rms_norm.default,
|
||||||
# TODO functionalization pass does not handle this yet
|
|
||||||
# torch.ops._C.static_scaled_fp8_quant.default,
|
|
||||||
]
|
]
|
||||||
|
elif self.fp8_linear.quant_fp8.enabled():
|
||||||
|
return [
|
||||||
|
torch.ops._C.static_scaled_fp8_quant.default,
|
||||||
|
]
|
||||||
|
else:
|
||||||
|
return []
|
||||||
|
|
||||||
|
|
||||||
@multi_gpu_test(num_gpus=2)
|
@multi_gpu_test(num_gpus=2)
|
||||||
@pytest.mark.parametrize("test_model_cls", [TestModel, TestQuantModel])
|
@pytest.mark.parametrize(
|
||||||
|
"test_model_cls, custom_ops",
|
||||||
|
[
|
||||||
|
(TestAllReduceRMSNormModel, "+rms_norm"),
|
||||||
|
(TestAllReduceRMSNormModel, "-rms_norm"),
|
||||||
|
(TestAllReduceRMSNormStaticQuantFP8Model, "+rms_norm,+quant_fp8"),
|
||||||
|
(TestAllReduceRMSNormStaticQuantFP8Model, "+rms_norm,-quant_fp8"),
|
||||||
|
(TestAllReduceRMSNormStaticQuantFP8Model, "-rms_norm,+quant_fp8"),
|
||||||
|
(TestAllReduceRMSNormStaticQuantFP8Model, "-rms_norm,-quant_fp8"),
|
||||||
|
],
|
||||||
|
)
|
||||||
@pytest.mark.parametrize("batch_size", [8])
|
@pytest.mark.parametrize("batch_size", [8])
|
||||||
@pytest.mark.parametrize("seq_len", [16])
|
@pytest.mark.parametrize("seq_len", [16])
|
||||||
@pytest.mark.parametrize("hidden_size", [16])
|
@pytest.mark.parametrize("hidden_size", [16])
|
||||||
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
|
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
|
||||||
@pytest.mark.parametrize("enable_fusion", [True, False])
|
@pytest.mark.parametrize("enable_fusion", [True, False])
|
||||||
|
@pytest.mark.parametrize("dynamic", [False, True])
|
||||||
@pytest.mark.skipif(envs.VLLM_TARGET_DEVICE not in ["cuda"], reason="Only test on CUDA")
|
@pytest.mark.skipif(envs.VLLM_TARGET_DEVICE not in ["cuda"], reason="Only test on CUDA")
|
||||||
def test_sequence_parallelism_pass(
|
def test_sequence_parallelism_pass(
|
||||||
test_model_cls: type[torch.nn.Module],
|
test_model_cls: type[torch.nn.Module],
|
||||||
|
custom_ops: str,
|
||||||
batch_size: int,
|
batch_size: int,
|
||||||
seq_len: int,
|
seq_len: int,
|
||||||
hidden_size: int,
|
hidden_size: int,
|
||||||
dtype: torch.dtype,
|
dtype: torch.dtype,
|
||||||
enable_fusion: bool,
|
enable_fusion: bool,
|
||||||
|
dynamic: bool,
|
||||||
):
|
):
|
||||||
num_processes = 2
|
num_processes = 2
|
||||||
|
|
||||||
@ -220,11 +206,13 @@ def test_sequence_parallelism_pass(
|
|||||||
args=(
|
args=(
|
||||||
num_processes,
|
num_processes,
|
||||||
test_model_cls,
|
test_model_cls,
|
||||||
|
custom_ops,
|
||||||
batch_size,
|
batch_size,
|
||||||
seq_len,
|
seq_len,
|
||||||
hidden_size,
|
hidden_size,
|
||||||
dtype,
|
dtype,
|
||||||
enable_fusion,
|
enable_fusion,
|
||||||
|
dynamic,
|
||||||
),
|
),
|
||||||
nprocs=nprocs,
|
nprocs=nprocs,
|
||||||
)
|
)
|
||||||
@ -236,11 +224,13 @@ def sequence_parallelism_pass_on_test_model(
|
|||||||
local_rank: int,
|
local_rank: int,
|
||||||
world_size: int,
|
world_size: int,
|
||||||
test_model_cls: type[torch.nn.Module],
|
test_model_cls: type[torch.nn.Module],
|
||||||
|
custom_ops: str,
|
||||||
batch_size: int,
|
batch_size: int,
|
||||||
seq_len: int,
|
seq_len: int,
|
||||||
hidden_size: int,
|
hidden_size: int,
|
||||||
dtype: torch.dtype,
|
dtype: torch.dtype,
|
||||||
enable_fusion: bool,
|
enable_fusion: bool,
|
||||||
|
dynamic: bool,
|
||||||
):
|
):
|
||||||
current_platform.seed_everything(0)
|
current_platform.seed_everything(0)
|
||||||
|
|
||||||
@ -264,12 +254,16 @@ def sequence_parallelism_pass_on_test_model(
|
|||||||
initialize_model_parallel(tensor_model_parallel_size=world_size)
|
initialize_model_parallel(tensor_model_parallel_size=world_size)
|
||||||
|
|
||||||
# configure vllm config for SequenceParallelismPass
|
# configure vllm config for SequenceParallelismPass
|
||||||
|
custom_ops_list = custom_ops.split(",") if custom_ops else []
|
||||||
compilation_config = CompilationConfig(
|
compilation_config = CompilationConfig(
|
||||||
|
splitting_ops=[], # avoid automatic rms_norm enablement
|
||||||
|
cudagraph_mode=CUDAGraphMode.NONE, # avoid piecewise warnings
|
||||||
|
custom_ops=custom_ops_list,
|
||||||
pass_config=PassConfig(
|
pass_config=PassConfig(
|
||||||
enable_sequence_parallelism=True,
|
enable_sequence_parallelism=True,
|
||||||
enable_fusion=enable_fusion,
|
enable_fusion=enable_fusion,
|
||||||
enable_noop=True,
|
enable_noop=True,
|
||||||
)
|
),
|
||||||
) # NoOp needed for fusion
|
) # NoOp needed for fusion
|
||||||
device_config = DeviceConfig(device=torch.device("cuda"))
|
device_config = DeviceConfig(device=torch.device("cuda"))
|
||||||
|
|
||||||
@ -289,7 +283,6 @@ def sequence_parallelism_pass_on_test_model(
|
|||||||
with set_current_vllm_config(vllm_config):
|
with set_current_vllm_config(vllm_config):
|
||||||
noop_pass = NoOpEliminationPass(vllm_config)
|
noop_pass = NoOpEliminationPass(vllm_config)
|
||||||
sequence_parallelism_pass = SequenceParallelismPass(vllm_config)
|
sequence_parallelism_pass = SequenceParallelismPass(vllm_config)
|
||||||
func_pass = FixFunctionalizationPass(vllm_config)
|
|
||||||
cleanup_pass = PostCleanupPass(vllm_config)
|
cleanup_pass = PostCleanupPass(vllm_config)
|
||||||
assert (
|
assert (
|
||||||
sequence_parallelism_pass.compilation_config.splitting_ops
|
sequence_parallelism_pass.compilation_config.splitting_ops
|
||||||
@ -310,38 +303,29 @@ def sequence_parallelism_pass_on_test_model(
|
|||||||
|
|
||||||
passes_for_backend.append(cleanup_pass)
|
passes_for_backend.append(cleanup_pass)
|
||||||
|
|
||||||
backend_no_func = TestBackend(*passes_for_backend)
|
backend = TestBackend(*passes_for_backend)
|
||||||
backend_func = TestBackend(*passes_for_backend, func_pass)
|
|
||||||
|
|
||||||
model = test_model_cls(hidden_size, hidden_size * 2)
|
model = test_model_cls(hidden_size)
|
||||||
|
|
||||||
hidden_states = torch.randn((batch_size * seq_len, hidden_size), dtype=dtype)
|
hidden_states = torch.randn((batch_size * seq_len, hidden_size), dtype=dtype)
|
||||||
residual = torch.randn((batch_size * seq_len, hidden_size), dtype=dtype)
|
|
||||||
|
|
||||||
compiled_model_no_func = torch.compile(model, backend=backend_no_func)
|
if dynamic:
|
||||||
compiled_model_no_func(hidden_states, residual)
|
torch._dynamo.mark_dynamic(hidden_states, 0)
|
||||||
compiled_model_func = torch.compile(model, backend=backend_func)
|
|
||||||
compiled_model_func(hidden_states, residual)
|
|
||||||
|
|
||||||
assert sequence_parallelism_pass.matched_count == 1
|
compiled_model = torch.compile(model, backend=backend)
|
||||||
|
compiled_model(hidden_states)
|
||||||
|
|
||||||
|
assert sequence_parallelism_pass.matched_count == 4
|
||||||
|
|
||||||
# In pre-nodes, all reduce should be there,
|
# In pre-nodes, all reduce should be there,
|
||||||
# reduce scatter and all gather should not
|
# reduce scatter and all gather should not
|
||||||
backend_no_func.check_before_ops(model.ops_in_model_before())
|
for op in model.ops_in_model_before():
|
||||||
|
assert backend.op_count(op, before=True) == 4
|
||||||
|
|
||||||
# In post-nodes, reduce scatter and all gather should be there,
|
# In post-nodes, reduce scatter and all gather should be there,
|
||||||
# all reduce should not
|
# all reduce should not
|
||||||
backend_no_func.check_after_ops(model.ops_in_model_after())
|
for op in model.ops_in_model_after():
|
||||||
|
assert backend.op_count(op, before=False) == 4
|
||||||
|
|
||||||
# check if the functionalization pass is applied
|
|
||||||
for op in model.ops_in_model():
|
for op in model.ops_in_model():
|
||||||
find_auto_fn(backend_no_func.graph_post_pass.nodes, op)
|
find_auto_fn(backend.graph_post_pass.nodes, op)
|
||||||
assert find_auto_fn_maybe(backend_func.graph_post_pass.nodes, op) is None
|
|
||||||
|
|
||||||
# make sure the ops were all de-functionalized
|
|
||||||
found = dict()
|
|
||||||
for node in backend_func.graph_post_pass.nodes:
|
|
||||||
for op in model.ops_in_model():
|
|
||||||
if is_func(node, op):
|
|
||||||
found[op] = True
|
|
||||||
assert all(found[op] for op in model.ops_in_model())
|
|
||||||
|
|||||||
@ -18,6 +18,7 @@ import pytest
|
|||||||
from vllm.config.compilation import CompilationMode
|
from vllm.config.compilation import CompilationMode
|
||||||
from vllm.config.model import RunnerOption
|
from vllm.config.model import RunnerOption
|
||||||
from vllm.logger import init_logger
|
from vllm.logger import init_logger
|
||||||
|
from vllm.platforms import current_platform
|
||||||
from vllm.utils.torch_utils import is_torch_equal_or_newer
|
from vllm.utils.torch_utils import is_torch_equal_or_newer
|
||||||
|
|
||||||
from ..models.registry import HF_EXAMPLE_MODELS
|
from ..models.registry import HF_EXAMPLE_MODELS
|
||||||
@ -161,6 +162,7 @@ def _compare_sp(
|
|||||||
test_options: SPTestOptions,
|
test_options: SPTestOptions,
|
||||||
num_gpus_available: int,
|
num_gpus_available: int,
|
||||||
use_inductor_graph_partition: bool,
|
use_inductor_graph_partition: bool,
|
||||||
|
enable_async_tp: bool,
|
||||||
*,
|
*,
|
||||||
method: Literal["generate", "encode"],
|
method: Literal["generate", "encode"],
|
||||||
is_multimodal: bool,
|
is_multimodal: bool,
|
||||||
@ -244,10 +246,10 @@ def _compare_sp(
|
|||||||
|
|
||||||
compilation_config = {
|
compilation_config = {
|
||||||
"mode": CompilationMode.VLLM_COMPILE,
|
"mode": CompilationMode.VLLM_COMPILE,
|
||||||
"custom_ops": ["+rms_norm"],
|
|
||||||
"compile_sizes": [4, 8],
|
"compile_sizes": [4, 8],
|
||||||
"pass_config": {
|
"pass_config": {
|
||||||
"enable_sequence_parallelism": True,
|
"enable_sequence_parallelism": True,
|
||||||
|
"enable_async_tp": enable_async_tp,
|
||||||
"enable_fusion": enable_fusion,
|
"enable_fusion": enable_fusion,
|
||||||
"enable_noop": True,
|
"enable_noop": True,
|
||||||
},
|
},
|
||||||
@ -307,6 +309,7 @@ SP_TEST_MODELS = [
|
|||||||
],
|
],
|
||||||
)
|
)
|
||||||
@pytest.mark.parametrize("use_inductor_graph_partition", [True, False])
|
@pytest.mark.parametrize("use_inductor_graph_partition", [True, False])
|
||||||
|
@pytest.mark.parametrize("enable_async_tp", [False]) # TODO: enable async TP
|
||||||
@create_new_process_for_each_test()
|
@create_new_process_for_each_test()
|
||||||
def test_tp_sp_generation(
|
def test_tp_sp_generation(
|
||||||
model_id: str,
|
model_id: str,
|
||||||
@ -316,10 +319,19 @@ def test_tp_sp_generation(
|
|||||||
test_options: SPTestOptions,
|
test_options: SPTestOptions,
|
||||||
num_gpus_available,
|
num_gpus_available,
|
||||||
use_inductor_graph_partition: bool,
|
use_inductor_graph_partition: bool,
|
||||||
|
enable_async_tp: bool,
|
||||||
):
|
):
|
||||||
if use_inductor_graph_partition and not is_torch_equal_or_newer("2.9.0.dev"):
|
if use_inductor_graph_partition and not is_torch_equal_or_newer("2.9.0.dev"):
|
||||||
pytest.skip("inductor graph partition is only available in PyTorch 2.9+")
|
pytest.skip("inductor graph partition is only available in PyTorch 2.9+")
|
||||||
|
|
||||||
|
# Skip FP8 SP-only test on sm89 (compute capability 8.9)
|
||||||
|
if (
|
||||||
|
"fp8" in model_id.lower()
|
||||||
|
and current_platform.get_device_capability() < (9, 0)
|
||||||
|
and (not enable_async_tp)
|
||||||
|
):
|
||||||
|
pytest.skip("FP8 reduction support begins with sm90 capable devices.")
|
||||||
|
|
||||||
_compare_sp(
|
_compare_sp(
|
||||||
model_id,
|
model_id,
|
||||||
parallel_setup,
|
parallel_setup,
|
||||||
@ -328,6 +340,7 @@ def test_tp_sp_generation(
|
|||||||
test_options,
|
test_options,
|
||||||
num_gpus_available,
|
num_gpus_available,
|
||||||
use_inductor_graph_partition,
|
use_inductor_graph_partition,
|
||||||
|
enable_async_tp=enable_async_tp,
|
||||||
method="generate",
|
method="generate",
|
||||||
is_multimodal=False,
|
is_multimodal=False,
|
||||||
)
|
)
|
||||||
|
|||||||
@ -1,6 +1,8 @@
|
|||||||
# SPDX-License-Identifier: Apache-2.0
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
|
||||||
|
import functools
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch._inductor.pattern_matcher as pm
|
import torch._inductor.pattern_matcher as pm
|
||||||
import torch.fx as fx
|
import torch.fx as fx
|
||||||
@ -10,98 +12,28 @@ from vllm.config import VllmConfig
|
|||||||
from vllm.distributed import get_tp_group, tensor_model_parallel_all_reduce
|
from vllm.distributed import get_tp_group, tensor_model_parallel_all_reduce
|
||||||
from vllm.distributed.parallel_state import get_tensor_model_parallel_world_size
|
from vllm.distributed.parallel_state import get_tensor_model_parallel_world_size
|
||||||
from vllm.logger import init_logger
|
from vllm.logger import init_logger
|
||||||
|
from vllm.model_executor.layers.quantization.utils.quant_utils import (
|
||||||
|
kFp8StaticTensorSym,
|
||||||
|
)
|
||||||
from vllm.platforms import current_platform
|
from vllm.platforms import current_platform
|
||||||
|
|
||||||
from .inductor_pass import enable_fake_mode
|
from .inductor_pass import enable_fake_mode
|
||||||
|
from .matcher_utils import MatcherFusedAddRMSNorm, MatcherQuantFP8, MatcherRMSNorm
|
||||||
|
from .noop_elimination import NoOpEliminationPass
|
||||||
from .vllm_inductor_pass import VllmInductorPass, VllmPatternMatcherPass
|
from .vllm_inductor_pass import VllmInductorPass, VllmPatternMatcherPass
|
||||||
|
|
||||||
logger = init_logger(__name__)
|
logger = init_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class _RMSNormAndQuantOpHelper:
|
def get_first_out_wrapper(fn):
|
||||||
"""Base helper for RMSNorm and RMSNorm + Quantization functionalization."""
|
@functools.wraps(fn)
|
||||||
|
def wrapper(*args):
|
||||||
|
return fn(*args)[0]
|
||||||
|
|
||||||
def __init__(
|
return wrapper
|
||||||
self,
|
|
||||||
epsilon: float,
|
|
||||||
dtype: torch.dtype,
|
|
||||||
device: str,
|
|
||||||
quant_op: torch._ops.OpOverload | None = None,
|
|
||||||
**kwargs,
|
|
||||||
):
|
|
||||||
self.epsilon = epsilon
|
|
||||||
self.dtype = dtype
|
|
||||||
self.device = device
|
|
||||||
self.quant_op = quant_op
|
|
||||||
|
|
||||||
def _functional_rmsnorm(self, result_buffer, input_tensor, weight_tensor):
|
|
||||||
return torch.ops.higher_order.auto_functionalized(
|
|
||||||
torch.ops._C.rms_norm.default,
|
|
||||||
result=result_buffer,
|
|
||||||
input=input_tensor,
|
|
||||||
weight=weight_tensor,
|
|
||||||
epsilon=self.epsilon,
|
|
||||||
)
|
|
||||||
|
|
||||||
def _functional_fused_add_rmsnorm(
|
|
||||||
self, input_tensor, residual_tensor, weight_tensor
|
|
||||||
):
|
|
||||||
return torch.ops.higher_order.auto_functionalized(
|
|
||||||
torch.ops._C.fused_add_rms_norm.default,
|
|
||||||
input=input_tensor,
|
|
||||||
residual=residual_tensor,
|
|
||||||
weight=weight_tensor,
|
|
||||||
epsilon=self.epsilon,
|
|
||||||
)
|
|
||||||
|
|
||||||
def _functional_rmsnorm_then_quant(
|
|
||||||
self,
|
|
||||||
rmsnorm_result_buffer,
|
|
||||||
quant_result_buffer,
|
|
||||||
input_tensor,
|
|
||||||
weight_tensor,
|
|
||||||
scale_tensor,
|
|
||||||
):
|
|
||||||
if self.quant_op is None:
|
|
||||||
raise RuntimeError(
|
|
||||||
"_RMSNormAndQuantOpHelper was not initialized with a quant_op."
|
|
||||||
)
|
|
||||||
rmsnorm_out_tuple = self._functional_rmsnorm(
|
|
||||||
rmsnorm_result_buffer, input_tensor, weight_tensor
|
|
||||||
)
|
|
||||||
quant_out_tuple = torch.ops.higher_order.auto_functionalized(
|
|
||||||
self.quant_op,
|
|
||||||
result=quant_result_buffer,
|
|
||||||
input=rmsnorm_out_tuple[1],
|
|
||||||
scale=scale_tensor,
|
|
||||||
)
|
|
||||||
return quant_out_tuple
|
|
||||||
|
|
||||||
def _functional_fused_add_rmsnorm_then_quant(
|
|
||||||
self,
|
|
||||||
quant_result_buffer,
|
|
||||||
input_tensor,
|
|
||||||
residual_tensor,
|
|
||||||
weight_tensor,
|
|
||||||
scale_tensor,
|
|
||||||
):
|
|
||||||
if self.quant_op is None:
|
|
||||||
raise RuntimeError(
|
|
||||||
"_RMSNormAndQuantOpHelper was not initialized with a quant_op."
|
|
||||||
)
|
|
||||||
fused_add_rmsnorm_out_tuple = self._functional_fused_add_rmsnorm(
|
|
||||||
input_tensor, residual_tensor, weight_tensor
|
|
||||||
)
|
|
||||||
quant_out_tuple = torch.ops.higher_order.auto_functionalized(
|
|
||||||
self.quant_op,
|
|
||||||
result=quant_result_buffer,
|
|
||||||
input=fused_add_rmsnorm_out_tuple[1],
|
|
||||||
scale=scale_tensor,
|
|
||||||
)
|
|
||||||
return quant_out_tuple, fused_add_rmsnorm_out_tuple[2]
|
|
||||||
|
|
||||||
|
|
||||||
class _SequenceParallelPatternHelper(_RMSNormAndQuantOpHelper):
|
class _SequenceParallelPatternHelper:
|
||||||
"""Helper for sequence parallelism patterns."""
|
"""Helper for sequence parallelism patterns."""
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
@ -109,10 +41,10 @@ class _SequenceParallelPatternHelper(_RMSNormAndQuantOpHelper):
|
|||||||
epsilon: float,
|
epsilon: float,
|
||||||
dtype: torch.dtype,
|
dtype: torch.dtype,
|
||||||
device: str,
|
device: str,
|
||||||
quant_op: torch._ops.OpOverload | None = None,
|
|
||||||
**kwargs,
|
|
||||||
):
|
):
|
||||||
super().__init__(epsilon, dtype, device, quant_op=quant_op, **kwargs)
|
self.epsilon = epsilon
|
||||||
|
self.dtype = dtype
|
||||||
|
self.device = device
|
||||||
self.tp_group = get_tp_group()
|
self.tp_group = get_tp_group()
|
||||||
self.tp_size = get_tensor_model_parallel_world_size()
|
self.tp_size = get_tensor_model_parallel_world_size()
|
||||||
|
|
||||||
@ -131,36 +63,34 @@ class _SequenceParallelPatternHelper(_RMSNormAndQuantOpHelper):
|
|||||||
|
|
||||||
|
|
||||||
class FirstAllReduceRMSNormPattern(_SequenceParallelPatternHelper):
|
class FirstAllReduceRMSNormPattern(_SequenceParallelPatternHelper):
|
||||||
|
def __init__(self, epsilon: float, dtype: torch.dtype, device: str):
|
||||||
|
super().__init__(epsilon, dtype, device)
|
||||||
|
self.rmsnorm_matcher = MatcherRMSNorm(epsilon)
|
||||||
|
|
||||||
def get_inputs(self):
|
def get_inputs(self):
|
||||||
input = torch.empty([1, 8, 4], device=self.device, dtype=self.dtype)
|
input = torch.empty([1, 8, 4], device=self.device, dtype=self.dtype)
|
||||||
permute = torch.empty([1, 8, 4], device=self.device, dtype=self.dtype)
|
|
||||||
arg3_1 = torch.empty([4], device=self.device, dtype=self.dtype)
|
arg3_1 = torch.empty([4], device=self.device, dtype=self.dtype)
|
||||||
|
|
||||||
return [input, permute, arg3_1]
|
return [input, arg3_1]
|
||||||
|
|
||||||
def register(self, pm_pass: PatternMatcherPass):
|
def register(self, pm_pass: PatternMatcherPass):
|
||||||
def pattern(
|
def pattern(
|
||||||
input: torch.Tensor,
|
input: torch.Tensor,
|
||||||
permute: torch.Tensor,
|
|
||||||
arg3_1: torch.Tensor,
|
arg3_1: torch.Tensor,
|
||||||
):
|
):
|
||||||
all_reduce = self._all_reduce(input)
|
all_reduce = self._all_reduce(input)
|
||||||
rmsnorm = self._functional_rmsnorm(permute, all_reduce, arg3_1)
|
rmsnorm = self.rmsnorm_matcher(all_reduce, arg3_1)
|
||||||
|
|
||||||
return rmsnorm[1], all_reduce
|
return rmsnorm, all_reduce
|
||||||
|
|
||||||
def replacement(
|
def replacement(
|
||||||
input: torch.Tensor,
|
input: torch.Tensor,
|
||||||
permute: torch.Tensor,
|
|
||||||
arg3_1: torch.Tensor,
|
arg3_1: torch.Tensor,
|
||||||
):
|
):
|
||||||
reduce_scatter = self._reduce_scatter(input)
|
reduce_scatter = self._reduce_scatter(input)
|
||||||
|
|
||||||
rmsnorm_result = torch.empty_like(reduce_scatter)
|
rmsnorm = self.rmsnorm_matcher(reduce_scatter, arg3_1)
|
||||||
rmsnorm = self._functional_rmsnorm(rmsnorm_result, reduce_scatter, arg3_1)
|
all_gather = self._all_gather(rmsnorm)
|
||||||
|
|
||||||
all_gather = self._all_gather(rmsnorm[1])
|
|
||||||
|
|
||||||
return all_gather, reduce_scatter
|
return all_gather, reduce_scatter
|
||||||
|
|
||||||
pm.register_replacement(
|
pm.register_replacement(
|
||||||
@ -169,6 +99,10 @@ class FirstAllReduceRMSNormPattern(_SequenceParallelPatternHelper):
|
|||||||
|
|
||||||
|
|
||||||
class MiddleAllReduceRMSNormPattern(_SequenceParallelPatternHelper):
|
class MiddleAllReduceRMSNormPattern(_SequenceParallelPatternHelper):
|
||||||
|
def __init__(self, epsilon: float, dtype: torch.dtype, device: str):
|
||||||
|
super().__init__(epsilon, dtype, device)
|
||||||
|
self.rmsnorm_matcher = MatcherFusedAddRMSNorm(epsilon)
|
||||||
|
|
||||||
def get_inputs(self):
|
def get_inputs(self):
|
||||||
mm_1 = torch.empty([4, 4], device=self.device, dtype=self.dtype)
|
mm_1 = torch.empty([4, 4], device=self.device, dtype=self.dtype)
|
||||||
|
|
||||||
@ -188,67 +122,34 @@ class MiddleAllReduceRMSNormPattern(_SequenceParallelPatternHelper):
|
|||||||
rms_norm_weights: torch.Tensor,
|
rms_norm_weights: torch.Tensor,
|
||||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||||
all_reduce = self._all_reduce(mm_1)
|
all_reduce = self._all_reduce(mm_1)
|
||||||
rmsnorm = self._functional_fused_add_rmsnorm(
|
rmsnorm = self.rmsnorm_matcher(all_reduce, rms_norm_weights, residual)
|
||||||
all_reduce, residual, rms_norm_weights
|
return rmsnorm[0], rmsnorm[1]
|
||||||
)
|
|
||||||
return rmsnorm[1], rmsnorm[2]
|
|
||||||
|
|
||||||
def replacement(
|
def replacement(
|
||||||
residual: torch.Tensor,
|
residual: torch.Tensor,
|
||||||
mm_1: torch.Tensor,
|
mm_1: torch.Tensor,
|
||||||
rms_norm_weights: torch.Tensor,
|
rms_norm_weights: torch.Tensor,
|
||||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||||
|
# pattern matcher replaces from top-to-bottom,
|
||||||
|
# so residual is still the full size here.
|
||||||
|
# once the seqpar pattern with the previous rmsnorm is replaced
|
||||||
reduce_scatter = self._reduce_scatter(mm_1)
|
reduce_scatter = self._reduce_scatter(mm_1)
|
||||||
rmsnorm = self._functional_fused_add_rmsnorm(
|
residual = residual[0 : reduce_scatter.size(0), ...]
|
||||||
reduce_scatter, residual, rms_norm_weights
|
rmsnorm = self.rmsnorm_matcher(reduce_scatter, rms_norm_weights, residual)
|
||||||
)
|
all_gather = self._all_gather(rmsnorm[0])
|
||||||
all_gather = self._all_gather(rmsnorm[1])
|
# shape of residual changes but that's fine,
|
||||||
return all_gather, rmsnorm[2]
|
# next node is already slicing it, now becomes a noop
|
||||||
|
return all_gather, rmsnorm[1]
|
||||||
|
|
||||||
pm.register_replacement(
|
pm.register_replacement(
|
||||||
pattern, replacement, self.get_inputs(), pm.fwd_only, pm_pass
|
pattern, replacement, self.get_inputs(), pm.fwd_only, pm_pass
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
class LastAllReduceRMSNormPattern(_SequenceParallelPatternHelper):
|
|
||||||
def get_inputs(self):
|
|
||||||
mm_1 = torch.empty([4, 4], device=self.device, dtype=self.dtype)
|
|
||||||
|
|
||||||
residual = torch.empty([4, 4], device=self.device, dtype=self.dtype)
|
|
||||||
rms_norm_weights = torch.empty([4, 4], device=self.device, dtype=self.dtype)
|
|
||||||
|
|
||||||
return [
|
|
||||||
residual,
|
|
||||||
mm_1,
|
|
||||||
rms_norm_weights,
|
|
||||||
]
|
|
||||||
|
|
||||||
def register(self, pm_pass: PatternMatcherPass):
|
|
||||||
def pattern(
|
|
||||||
residual: torch.Tensor,
|
|
||||||
mm_1: torch.Tensor,
|
|
||||||
rms_norm_weights: torch.Tensor,
|
|
||||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
|
||||||
all_reduce = self._all_reduce(mm_1)
|
|
||||||
rmsnorm = self._functional_fused_add_rmsnorm(
|
|
||||||
all_reduce, residual, rms_norm_weights
|
|
||||||
)
|
|
||||||
return rmsnorm[1]
|
|
||||||
|
|
||||||
def replacement(
|
|
||||||
residual: torch.Tensor,
|
|
||||||
mm_1: torch.Tensor,
|
|
||||||
rms_norm_weights: torch.Tensor,
|
|
||||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
|
||||||
reduce_scatter = self._reduce_scatter(mm_1)
|
|
||||||
rmsnorm = self._functional_fused_add_rmsnorm(
|
|
||||||
reduce_scatter, residual, rms_norm_weights
|
|
||||||
)
|
|
||||||
normalized = self._all_gather(rmsnorm[1])
|
|
||||||
return normalized
|
|
||||||
|
|
||||||
pm.register_replacement(
|
pm.register_replacement(
|
||||||
pattern, replacement, self.get_inputs(), pm.fwd_only, pm_pass
|
get_first_out_wrapper(pattern),
|
||||||
|
get_first_out_wrapper(replacement),
|
||||||
|
self.get_inputs(),
|
||||||
|
pm.fwd_only,
|
||||||
|
pm_pass,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@ -257,52 +158,41 @@ FP8_DTYPE = current_platform.fp8_dtype()
|
|||||||
|
|
||||||
class FirstAllReduceRMSNormStaticFP8Pattern(_SequenceParallelPatternHelper):
|
class FirstAllReduceRMSNormStaticFP8Pattern(_SequenceParallelPatternHelper):
|
||||||
def __init__(
|
def __init__(
|
||||||
self, epsilon: float, dtype: torch.dtype, device: str, op: torch._ops.OpOverload
|
self,
|
||||||
|
epsilon: float,
|
||||||
|
dtype: torch.dtype,
|
||||||
|
device: str,
|
||||||
):
|
):
|
||||||
super().__init__(epsilon, dtype, device, quant_op=op)
|
super().__init__(epsilon, dtype, device)
|
||||||
|
self.rmsnorm_matcher = MatcherRMSNorm(epsilon)
|
||||||
|
self.quant_matcher = MatcherQuantFP8(kFp8StaticTensorSym)
|
||||||
|
|
||||||
def get_inputs(self):
|
def get_inputs(self):
|
||||||
input = torch.zeros([1, 8, 4], device=self.device, dtype=self.dtype)
|
input = torch.zeros([1, 8, 4], device=self.device, dtype=self.dtype)
|
||||||
rmsnorm_result = torch.empty([1, 8, 4], device=self.device, dtype=self.dtype)
|
|
||||||
quant_result = torch.empty([1, 8, 4], device=self.device, dtype=FP8_DTYPE)
|
|
||||||
weight = torch.empty([4], device=self.device, dtype=self.dtype)
|
weight = torch.empty([4], device=self.device, dtype=self.dtype)
|
||||||
scale = torch.tensor(1.0, device=self.device, dtype=torch.float32)
|
scale = torch.tensor(1.0, device=self.device, dtype=torch.float32)
|
||||||
return [input, rmsnorm_result, quant_result, weight, scale]
|
return [input, weight, scale]
|
||||||
|
|
||||||
def register(self, pm_pass: PatternMatcherPass):
|
def register(self, pm_pass: PatternMatcherPass):
|
||||||
def pattern(
|
def pattern(
|
||||||
input: torch.Tensor,
|
input: torch.Tensor,
|
||||||
rmsnorm_result: torch.Tensor,
|
|
||||||
quant_result: torch.Tensor,
|
|
||||||
weight: torch.Tensor,
|
weight: torch.Tensor,
|
||||||
scale: torch.Tensor,
|
scale: torch.Tensor,
|
||||||
):
|
):
|
||||||
all_reduce = self._all_reduce(input)
|
all_reduce = self._all_reduce(input)
|
||||||
static_fp8 = self._functional_rmsnorm_then_quant(
|
rms = self.rmsnorm_matcher(all_reduce, weight)
|
||||||
rmsnorm_result, quant_result, all_reduce, weight, scale
|
quant, _ = self.quant_matcher(rms, scale)
|
||||||
)
|
return quant, all_reduce
|
||||||
return static_fp8[1], all_reduce
|
|
||||||
|
|
||||||
def replacement(
|
def replacement(
|
||||||
input: torch.Tensor,
|
input: torch.Tensor,
|
||||||
rmsnorm_result: torch.Tensor,
|
|
||||||
quant_result: torch.Tensor,
|
|
||||||
weight: torch.Tensor,
|
weight: torch.Tensor,
|
||||||
scale: torch.Tensor,
|
scale: torch.Tensor,
|
||||||
):
|
):
|
||||||
reduce_scatter = self._reduce_scatter(input)
|
reduce_scatter = self._reduce_scatter(input)
|
||||||
|
rms = self.rmsnorm_matcher(reduce_scatter, weight)
|
||||||
rmsnorm_result = torch.empty_like(
|
quant, _ = self.quant_matcher(rms, scale)
|
||||||
reduce_scatter, dtype=rmsnorm_result.dtype
|
all_gather = self._all_gather(quant)
|
||||||
)
|
|
||||||
quant_result = torch.empty_like(
|
|
||||||
rmsnorm_result, # Output of RMSNorm
|
|
||||||
dtype=quant_result.dtype,
|
|
||||||
)
|
|
||||||
static_fp8 = self._functional_rmsnorm_then_quant(
|
|
||||||
rmsnorm_result, quant_result, reduce_scatter, weight, scale
|
|
||||||
)
|
|
||||||
all_gather = self._all_gather(static_fp8[1])
|
|
||||||
|
|
||||||
return all_gather, reduce_scatter
|
return all_gather, reduce_scatter
|
||||||
|
|
||||||
@ -312,118 +202,64 @@ class FirstAllReduceRMSNormStaticFP8Pattern(_SequenceParallelPatternHelper):
|
|||||||
|
|
||||||
|
|
||||||
class MiddleAllReduceRMSNormStaticFP8Pattern(_SequenceParallelPatternHelper):
|
class MiddleAllReduceRMSNormStaticFP8Pattern(_SequenceParallelPatternHelper):
|
||||||
def __init__(
|
def __init__(self, epsilon: float, dtype: torch.dtype, device: str):
|
||||||
self, epsilon: float, dtype: torch.dtype, device: str, op: torch._ops.OpOverload
|
super().__init__(epsilon, dtype, device)
|
||||||
):
|
self.rmsnorm_matcher = MatcherFusedAddRMSNorm(epsilon)
|
||||||
super().__init__(epsilon, dtype, device, quant_op=op)
|
self.quant_matcher = MatcherQuantFP8(kFp8StaticTensorSym)
|
||||||
|
|
||||||
def get_inputs(self):
|
def get_inputs(self):
|
||||||
mm_1 = torch.empty([4, 4], device=self.device, dtype=self.dtype)
|
mm_1 = torch.empty([4, 4], device=self.device, dtype=self.dtype)
|
||||||
|
|
||||||
residual = torch.empty([4, 4], device=self.device, dtype=self.dtype)
|
residual = torch.empty([4, 4], device=self.device, dtype=self.dtype)
|
||||||
rms_norm_weights = torch.empty([4, 4], device=self.device, dtype=self.dtype)
|
rms_norm_weights = torch.empty([4, 4], device=self.device, dtype=self.dtype)
|
||||||
result = torch.empty([4, 4], device=self.device, dtype=FP8_DTYPE)
|
|
||||||
scale = torch.empty([1, 1], device=self.device, dtype=torch.float32)
|
scale = torch.empty([1, 1], device=self.device, dtype=torch.float32)
|
||||||
|
|
||||||
return [
|
return [residual, mm_1, rms_norm_weights, scale]
|
||||||
result,
|
|
||||||
residual,
|
|
||||||
mm_1,
|
|
||||||
rms_norm_weights,
|
|
||||||
scale,
|
|
||||||
]
|
|
||||||
|
|
||||||
def register(self, pm_pass: PatternMatcherPass):
|
def register(self, pm_pass: PatternMatcherPass):
|
||||||
def pattern(
|
def pattern(
|
||||||
result: torch.Tensor,
|
|
||||||
residual: torch.Tensor,
|
residual: torch.Tensor,
|
||||||
mm_1: torch.Tensor,
|
mm_1: torch.Tensor,
|
||||||
rms_norm_weights: torch.Tensor,
|
rms_norm_weights: torch.Tensor,
|
||||||
scale: torch.Tensor,
|
scale: torch.Tensor,
|
||||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||||
all_reduce = self._all_reduce(mm_1)
|
all_reduce = self._all_reduce(mm_1)
|
||||||
static_fp8, rmsnorm_residual_out = (
|
rms, residual_out = self.rmsnorm_matcher(
|
||||||
self._functional_fused_add_rmsnorm_then_quant( # noqa: E501
|
all_reduce, rms_norm_weights, residual
|
||||||
result, all_reduce, residual, rms_norm_weights, scale
|
|
||||||
)
|
|
||||||
)
|
)
|
||||||
return static_fp8[1], rmsnorm_residual_out
|
quant, _ = self.quant_matcher(rms, scale)
|
||||||
|
return quant, residual_out
|
||||||
|
|
||||||
def replacement(
|
def replacement(
|
||||||
result: torch.Tensor,
|
|
||||||
residual: torch.Tensor,
|
residual: torch.Tensor,
|
||||||
mm_1: torch.Tensor,
|
mm_1: torch.Tensor,
|
||||||
rms_norm_weights: torch.Tensor,
|
rms_norm_weights: torch.Tensor,
|
||||||
scale: torch.Tensor,
|
scale: torch.Tensor,
|
||||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||||
|
# pattern matcher replaces from the end of the graph backwards,
|
||||||
|
# so residual is still the full size here.
|
||||||
|
# add a temporary slice which will become a noop
|
||||||
|
# once the seqpar pattern with the previous rmsnorm is replaced
|
||||||
reduce_scatter = self._reduce_scatter(mm_1)
|
reduce_scatter = self._reduce_scatter(mm_1)
|
||||||
quant_result_buf = torch.empty_like(reduce_scatter, dtype=result.dtype)
|
residual = residual[0 : reduce_scatter.size(0), ...]
|
||||||
static_fp8, rmsnorm_residual_out = (
|
rms, residual_out = self.rmsnorm_matcher(
|
||||||
self._functional_fused_add_rmsnorm_then_quant( # noqa: E501
|
reduce_scatter, rms_norm_weights, residual
|
||||||
quant_result_buf, reduce_scatter, residual, rms_norm_weights, scale
|
|
||||||
)
|
|
||||||
)
|
)
|
||||||
all_gather = self._all_gather(static_fp8[1])
|
quant, _ = self.quant_matcher(rms, scale)
|
||||||
return all_gather, rmsnorm_residual_out
|
all_gather = self._all_gather(quant)
|
||||||
|
# shape of residual changes but that's fine,
|
||||||
|
# next node is already slicing it, now becomes a noop
|
||||||
|
return all_gather, residual_out
|
||||||
|
|
||||||
pm.register_replacement(
|
pm.register_replacement(
|
||||||
pattern, replacement, self.get_inputs(), pm.fwd_only, pm_pass
|
pattern, replacement, self.get_inputs(), pm.fwd_only, pm_pass
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
class LastAllReduceRMSNormStaticFP8Pattern(_SequenceParallelPatternHelper):
|
|
||||||
def __init__(
|
|
||||||
self, epsilon: float, dtype: torch.dtype, device: str, op: torch._ops.OpOverload
|
|
||||||
):
|
|
||||||
super().__init__(epsilon, dtype, device, quant_op=op)
|
|
||||||
|
|
||||||
def get_inputs(self):
|
|
||||||
mm_1 = torch.empty([4, 4], device=self.device, dtype=self.dtype)
|
|
||||||
|
|
||||||
residual = torch.empty([4, 4], device=self.device, dtype=self.dtype)
|
|
||||||
rms_norm_weights = torch.empty([4, 4], device=self.device, dtype=self.dtype)
|
|
||||||
result = torch.empty([4, 4], device=self.device, dtype=FP8_DTYPE)
|
|
||||||
scale = torch.empty([1, 1], device=self.device, dtype=torch.float32)
|
|
||||||
|
|
||||||
return [
|
|
||||||
result,
|
|
||||||
residual,
|
|
||||||
mm_1,
|
|
||||||
rms_norm_weights,
|
|
||||||
scale,
|
|
||||||
]
|
|
||||||
|
|
||||||
def register(self, pm_pass: PatternMatcherPass):
|
|
||||||
def pattern(
|
|
||||||
result: torch.Tensor,
|
|
||||||
residual: torch.Tensor,
|
|
||||||
mm_1: torch.Tensor,
|
|
||||||
rms_norm_weights: torch.Tensor,
|
|
||||||
scale: torch.Tensor,
|
|
||||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
|
||||||
all_reduce = self._all_reduce(mm_1)
|
|
||||||
static_fp8, _ = self._functional_fused_add_rmsnorm_then_quant(
|
|
||||||
result, all_reduce, residual, rms_norm_weights, scale
|
|
||||||
)
|
|
||||||
return static_fp8[1]
|
|
||||||
|
|
||||||
def replacement(
|
|
||||||
result: torch.Tensor,
|
|
||||||
residual: torch.Tensor,
|
|
||||||
mm_1: torch.Tensor,
|
|
||||||
rms_norm_weights: torch.Tensor,
|
|
||||||
scale: torch.Tensor,
|
|
||||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
|
||||||
reduce_scatter = self._reduce_scatter(mm_1)
|
|
||||||
quant_result_buf = torch.empty_like(reduce_scatter, dtype=result.dtype)
|
|
||||||
static_fp8, _ = self._functional_fused_add_rmsnorm_then_quant(
|
|
||||||
quant_result_buf, reduce_scatter, residual, rms_norm_weights, scale
|
|
||||||
)
|
|
||||||
normalized = self._all_gather(static_fp8[1])
|
|
||||||
return normalized
|
|
||||||
|
|
||||||
pm.register_replacement(
|
pm.register_replacement(
|
||||||
pattern, replacement, self.get_inputs(), pm.fwd_only, pm_pass
|
get_first_out_wrapper(pattern),
|
||||||
|
get_first_out_wrapper(replacement),
|
||||||
|
self.get_inputs(),
|
||||||
|
pm.fwd_only,
|
||||||
|
pm_pass,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@ -445,27 +281,45 @@ class SequenceParallelismPass(VllmPatternMatcherPass):
|
|||||||
GEMM + ReduceScatter and AllGather + GEMM fusions. These fusions can
|
GEMM + ReduceScatter and AllGather + GEMM fusions. These fusions can
|
||||||
significantly reduce communication overhead and improve overall model
|
significantly reduce communication overhead and improve overall model
|
||||||
performance.
|
performance.
|
||||||
|
|
||||||
|
|
||||||
|
This pass splits up the residual tensor across TP ranks and hence divides its size.
|
||||||
|
Because the pattern matcher starts at the end of the graph, the replacement
|
||||||
|
contains a slice that temporarily conforms the input residual to the correct size.
|
||||||
|
After all patterns have been matched, we use a NoOpEliminationPass to clean up
|
||||||
|
what have now become no-op slices.
|
||||||
|
|
||||||
|
Note that an older version of the pass did not need this as it operated only on
|
||||||
|
rms_norm and fused_rms_norm_add custom ops, which did not complain about
|
||||||
|
mismatched shapes during replacement. So this approach has the same assumption that
|
||||||
|
correctness is only maintained if all rms_norm operations are split across ranks.
|
||||||
|
|
||||||
|
Correctness-wise, this approach is strictly better than before - before,
|
||||||
|
the graph was incorrect semantically and shape-wise during the pass.
|
||||||
|
With this approach there's only semantic incorrectness during the pass.
|
||||||
|
Both approaches restore a correct graph once all patterns are matched.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
@enable_fake_mode
|
@enable_fake_mode
|
||||||
def __init__(self, config: VllmConfig):
|
def __init__(self, config: VllmConfig):
|
||||||
super().__init__(config)
|
super().__init__(config)
|
||||||
|
|
||||||
|
# Used to cleanup redundant views created temporarily
|
||||||
|
# to circumvent residual shape change issues
|
||||||
|
self.noop_cleanup = NoOpEliminationPass(config)
|
||||||
|
self.noop_cleanup.pass_name = f"{self.pass_name}.{self.noop_cleanup.pass_name}"
|
||||||
|
|
||||||
self.patterns: PatternMatcherPass = PatternMatcherPass(
|
self.patterns: PatternMatcherPass = PatternMatcherPass(
|
||||||
pass_name="sequence_parallelism_pass"
|
pass_name="sequence_parallelism_pass"
|
||||||
)
|
)
|
||||||
|
|
||||||
for epsilon in [1e-5, 1e-6]:
|
for epsilon in [1e-5, 1e-6]:
|
||||||
# RMSNorm + Static FP8 quantization patterns
|
# RMSNorm + Static FP8 quantization patterns
|
||||||
fp8_quant_op = torch.ops._C.static_scaled_fp8_quant.default
|
|
||||||
FirstAllReduceRMSNormStaticFP8Pattern(
|
FirstAllReduceRMSNormStaticFP8Pattern(
|
||||||
epsilon, self.model_dtype, self.device, fp8_quant_op
|
epsilon, self.model_dtype, self.device
|
||||||
).register(self.patterns)
|
).register(self.patterns)
|
||||||
MiddleAllReduceRMSNormStaticFP8Pattern(
|
MiddleAllReduceRMSNormStaticFP8Pattern(
|
||||||
epsilon, self.model_dtype, self.device, fp8_quant_op
|
epsilon, self.model_dtype, self.device
|
||||||
).register(self.patterns)
|
|
||||||
LastAllReduceRMSNormStaticFP8Pattern(
|
|
||||||
epsilon, self.model_dtype, self.device, fp8_quant_op
|
|
||||||
).register(self.patterns)
|
).register(self.patterns)
|
||||||
|
|
||||||
# Normal RMSNorm patterns
|
# Normal RMSNorm patterns
|
||||||
@ -477,9 +331,6 @@ class SequenceParallelismPass(VllmPatternMatcherPass):
|
|||||||
epsilon, self.model_dtype, self.device
|
epsilon, self.model_dtype, self.device
|
||||||
).register(self.patterns)
|
).register(self.patterns)
|
||||||
|
|
||||||
LastAllReduceRMSNormPattern(
|
|
||||||
epsilon, self.model_dtype, self.device
|
|
||||||
).register(self.patterns)
|
|
||||||
self.dump_patterns(config, self.patterns)
|
self.dump_patterns(config, self.patterns)
|
||||||
|
|
||||||
def is_applicable(self, shape: int | None) -> bool:
|
def is_applicable(self, shape: int | None) -> bool:
|
||||||
@ -508,3 +359,5 @@ class SequenceParallelismPass(VllmPatternMatcherPass):
|
|||||||
def __call__(self, graph: fx.Graph):
|
def __call__(self, graph: fx.Graph):
|
||||||
self.matched_count = self.patterns.apply(graph)
|
self.matched_count = self.patterns.apply(graph)
|
||||||
logger.debug("Replaced %s patterns", self.matched_count)
|
logger.debug("Replaced %s patterns", self.matched_count)
|
||||||
|
# Clean up reshape nodes
|
||||||
|
self.noop_cleanup(graph)
|
||||||
|
|||||||
@ -445,8 +445,6 @@ class VllmConfig:
|
|||||||
# and requires it to be enabled.
|
# and requires it to be enabled.
|
||||||
if self.compilation_config.pass_config.enable_async_tp:
|
if self.compilation_config.pass_config.enable_async_tp:
|
||||||
self.compilation_config.pass_config.enable_sequence_parallelism = True
|
self.compilation_config.pass_config.enable_sequence_parallelism = True
|
||||||
if self.compilation_config.pass_config.enable_sequence_parallelism:
|
|
||||||
self.compilation_config.custom_ops.append("+rms_norm")
|
|
||||||
|
|
||||||
if current_platform.support_static_graph_mode():
|
if current_platform.support_static_graph_mode():
|
||||||
# if cudagraph_mode is not explicitly set by users, set default
|
# if cudagraph_mode is not explicitly set by users, set default
|
||||||
@ -620,6 +618,32 @@ class VllmConfig:
|
|||||||
if self.compilation_config.mode == CompilationMode.VLLM_COMPILE:
|
if self.compilation_config.mode == CompilationMode.VLLM_COMPILE:
|
||||||
self.compilation_config.set_splitting_ops_for_v1()
|
self.compilation_config.set_splitting_ops_for_v1()
|
||||||
|
|
||||||
|
if self.compilation_config.pass_config.enable_sequence_parallelism:
|
||||||
|
# With pipeline parallelism or dynamo partitioning,
|
||||||
|
# native rms norm tracing errors due to incorrect residual shape.
|
||||||
|
# Use custom rms norm to unblock. In the future,
|
||||||
|
# the pass will operate on higher-level IR to avoid the issue.
|
||||||
|
# TODO: https://github.com/vllm-project/vllm/issues/27894
|
||||||
|
is_fullgraph = (
|
||||||
|
self.compilation_config.use_inductor_graph_partition
|
||||||
|
or len(self.compilation_config.splitting_ops) == 0
|
||||||
|
)
|
||||||
|
if self.parallel_config.pipeline_parallel_size > 1 or not is_fullgraph:
|
||||||
|
if "-rms_norm" not in self.compilation_config.custom_ops:
|
||||||
|
self.compilation_config.custom_ops.append("+rms_norm")
|
||||||
|
else:
|
||||||
|
regime = (
|
||||||
|
"Dynamo partition"
|
||||||
|
if not is_fullgraph
|
||||||
|
else "pipeline parallelism"
|
||||||
|
)
|
||||||
|
logger.warning_once(
|
||||||
|
"Sequence parallelism not supported with"
|
||||||
|
"native rms_norm when using %s, "
|
||||||
|
"this will likely lead to an error.",
|
||||||
|
regime,
|
||||||
|
)
|
||||||
|
|
||||||
# final check of cudagraph mode after all possible updates
|
# final check of cudagraph mode after all possible updates
|
||||||
if current_platform.is_cuda_alike():
|
if current_platform.is_cuda_alike():
|
||||||
if (
|
if (
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user