[Core] Remove forced None assignment for deprecated PassConfig flags (#29994)

Signed-off-by: arpitkh101 <arpit5khandelwal@gmail.com>
Co-authored-by: Cyrus Leung <tlleungac@connect.ust.hk>
This commit is contained in:
Arpit Khandelwal 2025-12-04 04:15:04 -05:00 committed by GitHub
parent ffdd18111b
commit dfdda96747
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 27 additions and 17 deletions

View File

@@ -392,39 +392,48 @@ def test_pass_config_deprecation(caplog_vllm):
assert "enable_fusion is deprecated" in caplog_vllm.text assert "enable_fusion is deprecated" in caplog_vllm.text
assert config.fuse_norm_quant is True assert config.fuse_norm_quant is True
assert config.fuse_act_quant is True assert config.fuse_act_quant is True
assert config.enable_fusion is None assert config.enable_fusion is True
# Test enable_attn_fusion -> fuse_attn_quant # Test enable_attn_fusion -> fuse_attn_quant
caplog_vllm.clear() caplog_vllm.clear()
config = PassConfig(enable_attn_fusion=True) config = PassConfig(enable_attn_fusion=True)
assert "enable_attn_fusion is deprecated" in caplog_vllm.text assert "enable_attn_fusion is deprecated" in caplog_vllm.text
assert config.fuse_attn_quant is True assert config.fuse_attn_quant is True
assert config.enable_attn_fusion is None assert config.enable_attn_fusion is True
# Test enable_noop -> eliminate_noops # Test enable_noop -> eliminate_noops
caplog_vllm.clear() caplog_vllm.clear()
config = PassConfig(enable_noop=True) config = PassConfig(enable_noop=True)
assert "enable_noop is deprecated" in caplog_vllm.text assert "enable_noop is deprecated" in caplog_vllm.text
assert config.eliminate_noops is True assert config.eliminate_noops is True
assert config.enable_noop is None assert config.enable_noop is True
# Test enable_sequence_parallelism -> enable_sp # Test enable_sequence_parallelism -> enable_sp
caplog_vllm.clear() caplog_vllm.clear()
config = PassConfig(enable_sequence_parallelism=True) config = PassConfig(enable_sequence_parallelism=True)
assert "enable_sequence_parallelism is deprecated" in caplog_vllm.text assert "enable_sequence_parallelism is deprecated" in caplog_vllm.text
assert config.enable_sp is True assert config.enable_sp is True
assert config.enable_sequence_parallelism is None assert config.enable_sequence_parallelism is True
# Test enable_async_tp -> fuse_gemm_comms # Test enable_async_tp -> fuse_gemm_comms
caplog_vllm.clear() caplog_vllm.clear()
config = PassConfig(enable_async_tp=True) config = PassConfig(enable_async_tp=True)
assert "enable_async_tp is deprecated" in caplog_vllm.text assert "enable_async_tp is deprecated" in caplog_vllm.text
assert config.fuse_gemm_comms is True assert config.fuse_gemm_comms is True
assert config.enable_async_tp is None assert config.enable_async_tp is True
# Test enable_fi_allreduce_fusion -> fuse_allreduce_rms # Test enable_fi_allreduce_fusion -> fuse_allreduce_rms
caplog_vllm.clear() caplog_vllm.clear()
config = PassConfig(enable_fi_allreduce_fusion=True) config = PassConfig(enable_fi_allreduce_fusion=True)
assert "enable_fi_allreduce_fusion is deprecated" in caplog_vllm.text assert "enable_fi_allreduce_fusion is deprecated" in caplog_vllm.text
assert config.fuse_allreduce_rms is True assert config.fuse_allreduce_rms is True
assert config.enable_fi_allreduce_fusion is None assert config.enable_fi_allreduce_fusion is True
# Test hash consistency
config_old = PassConfig(enable_fusion=True)
config_new = PassConfig(fuse_norm_quant=True, fuse_act_quant=True)
assert config_old.compute_hash() == config_new.compute_hash()
config_old = PassConfig(enable_async_tp=True)
config_new = PassConfig(fuse_gemm_comms=True)
assert config_old.compute_hash() == config_new.compute_hash()

View File

@@ -4,7 +4,7 @@
import enum import enum
from collections import Counter from collections import Counter
from collections.abc import Callable from collections.abc import Callable
from dataclasses import asdict, field from dataclasses import field
from pathlib import Path from pathlib import Path
from typing import TYPE_CHECKING, Any, ClassVar, Literal from typing import TYPE_CHECKING, Any, ClassVar, Literal
@@ -13,7 +13,7 @@ from pydantic.dataclasses import dataclass
import vllm.envs as envs import vllm.envs as envs
from vllm.compilation.inductor_pass import CallableInductorPass, InductorPass from vllm.compilation.inductor_pass import CallableInductorPass, InductorPass
from vllm.config.utils import config, handle_deprecated from vllm.config.utils import config, get_hash_factors, handle_deprecated, hash_factors
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.utils.import_utils import resolve_obj_by_qualname from vllm.utils.import_utils import resolve_obj_by_qualname
@@ -196,7 +196,16 @@ class PassConfig:
Any new fields that affect compilation should be added to the hash. Any new fields that affect compilation should be added to the hash.
Any future fields that don't affect compilation should be excluded. Any future fields that don't affect compilation should be excluded.
""" """
return InductorPass.hash_dict(asdict(self))
ignored_fields = [
"enable_fusion",
"enable_attn_fusion",
"enable_noop",
"enable_sequence_parallelism",
"enable_async_tp",
"enable_fi_allreduce_fusion",
]
return hash_factors(get_hash_factors(self, ignored_factors=ignored_fields))
@field_validator( @field_validator(
"fuse_norm_quant", "fuse_norm_quant",
@@ -267,14 +276,6 @@ class PassConfig:
"v0.13.0 or v1.0.0, whichever is sooner", "v0.13.0 or v1.0.0, whichever is sooner",
) )
# Force old flags to None to ensure they are not used
self.enable_fusion = None
self.enable_attn_fusion = None
self.enable_noop = None
self.enable_sequence_parallelism = None
self.enable_async_tp = None
self.enable_fi_allreduce_fusion = None
if not self.eliminate_noops: if not self.eliminate_noops:
if self.fuse_norm_quant or self.fuse_act_quant: if self.fuse_norm_quant or self.fuse_act_quant:
logger.warning_once( logger.warning_once(