diff --git a/tests/compile/test_config.py b/tests/compile/test_config.py index 9e912c6d810d..8dd6959a01d0 100644 --- a/tests/compile/test_config.py +++ b/tests/compile/test_config.py @@ -392,39 +392,48 @@ def test_pass_config_deprecation(caplog_vllm): assert "enable_fusion is deprecated" in caplog_vllm.text assert config.fuse_norm_quant is True assert config.fuse_act_quant is True - assert config.enable_fusion is None + assert config.enable_fusion is True # Test enable_attn_fusion -> fuse_attn_quant caplog_vllm.clear() config = PassConfig(enable_attn_fusion=True) assert "enable_attn_fusion is deprecated" in caplog_vllm.text assert config.fuse_attn_quant is True - assert config.enable_attn_fusion is None + assert config.enable_attn_fusion is True # Test enable_noop -> eliminate_noops caplog_vllm.clear() config = PassConfig(enable_noop=True) assert "enable_noop is deprecated" in caplog_vllm.text assert config.eliminate_noops is True - assert config.enable_noop is None + assert config.enable_noop is True # Test enable_sequence_parallelism -> enable_sp caplog_vllm.clear() config = PassConfig(enable_sequence_parallelism=True) assert "enable_sequence_parallelism is deprecated" in caplog_vllm.text assert config.enable_sp is True - assert config.enable_sequence_parallelism is None + assert config.enable_sequence_parallelism is True # Test enable_async_tp -> fuse_gemm_comms caplog_vllm.clear() config = PassConfig(enable_async_tp=True) assert "enable_async_tp is deprecated" in caplog_vllm.text assert config.fuse_gemm_comms is True - assert config.enable_async_tp is None + assert config.enable_async_tp is True # Test enable_fi_allreduce_fusion -> fuse_allreduce_rms caplog_vllm.clear() config = PassConfig(enable_fi_allreduce_fusion=True) assert "enable_fi_allreduce_fusion is deprecated" in caplog_vllm.text assert config.fuse_allreduce_rms is True - assert config.enable_fi_allreduce_fusion is None + assert config.enable_fi_allreduce_fusion is True + + # Test hash consistency + config_old = PassConfig(enable_fusion=True) + config_new = PassConfig(fuse_norm_quant=True, fuse_act_quant=True) + assert config_old.compute_hash() == config_new.compute_hash() + + config_old = PassConfig(enable_async_tp=True) + config_new = PassConfig(fuse_gemm_comms=True) + assert config_old.compute_hash() == config_new.compute_hash() diff --git a/vllm/config/compilation.py b/vllm/config/compilation.py index 963b091939e0..d3d50e6ae7b2 100644 --- a/vllm/config/compilation.py +++ b/vllm/config/compilation.py @@ -4,7 +4,7 @@ import enum from collections import Counter from collections.abc import Callable -from dataclasses import asdict, field +from dataclasses import field from pathlib import Path from typing import TYPE_CHECKING, Any, ClassVar, Literal @@ -13,7 +13,7 @@ from pydantic.dataclasses import dataclass import vllm.envs as envs from vllm.compilation.inductor_pass import CallableInductorPass, InductorPass -from vllm.config.utils import config, handle_deprecated +from vllm.config.utils import config, get_hash_factors, handle_deprecated, hash_factors from vllm.logger import init_logger from vllm.platforms import current_platform from vllm.utils.import_utils import resolve_obj_by_qualname @@ -196,7 +196,16 @@ class PassConfig: Any new fields that affect compilation should be added to the hash. Any future fields that don't affect compilation should be excluded. """ - return InductorPass.hash_dict(asdict(self)) + + ignored_fields = [ + "enable_fusion", + "enable_attn_fusion", + "enable_noop", + "enable_sequence_parallelism", + "enable_async_tp", + "enable_fi_allreduce_fusion", + ] + return hash_factors(get_hash_factors(self, ignored_factors=ignored_fields)) @field_validator( "fuse_norm_quant", @@ -267,14 +276,6 @@ class PassConfig: "v0.13.0 or v1.0.0, whichever is sooner", ) - # Force old flags to None to ensure they are not used - self.enable_fusion = None - self.enable_attn_fusion = None - self.enable_noop = None - self.enable_sequence_parallelism = None - self.enable_async_tp = None - self.enable_fi_allreduce_fusion = None - if not self.eliminate_noops: if self.fuse_norm_quant or self.fuse_act_quant: logger.warning_once(