diff --git a/tests/compile/distributed/test_fusions_e2e.py b/tests/compile/distributed/test_fusions_e2e.py index bd326f1157d8f..80086c4e03a9c 100644 --- a/tests/compile/distributed/test_fusions_e2e.py +++ b/tests/compile/distributed/test_fusions_e2e.py @@ -523,6 +523,8 @@ CUSTOM_OPS_QUANT_RMS_NORM = ["+quant_fp8,+rms_norm"] list[tuple[Any, ...]](flat_product(MODELS_GROUP_FP8, CUSTOM_OPS_QUANT_RMS_NORM)), ) @pytest.mark.parametrize("inductor_graph_partition", [True, False]) +# TODO: remove skip after we fix the fusion thoroughly +@pytest.mark.skipif(is_blackwell(), reason="Temporarily disabled on Blackwell") def test_rms_group_quant( model_name: str, model_kwargs: dict[str, Any], @@ -562,7 +564,7 @@ def test_rms_group_quant( splitting_ops=splitting_ops, # Common mode=CompilationMode.VLLM_COMPILE, - pass_config=PassConfig(eliminate_noops=True, enable_fusion=True), + pass_config=PassConfig(eliminate_noops=True, fuse_norm_quant=True), # Inductor caches custom passes by default as well via uuid inductor_compile_config={"force_disable_caches": True}, )