mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-13 20:44:27 +08:00
[Core] Remove forced None assignment for deprecated PassConfig flags (#29994)
Signed-off-by: arpitkh101 <arpit5khandelwal@gmail.com> Co-authored-by: Cyrus Leung <tlleungac@connect.ust.hk>
This commit is contained in:
parent
ffdd18111b
commit
dfdda96747
@ -392,39 +392,48 @@ def test_pass_config_deprecation(caplog_vllm):
|
|||||||
assert "enable_fusion is deprecated" in caplog_vllm.text
|
assert "enable_fusion is deprecated" in caplog_vllm.text
|
||||||
assert config.fuse_norm_quant is True
|
assert config.fuse_norm_quant is True
|
||||||
assert config.fuse_act_quant is True
|
assert config.fuse_act_quant is True
|
||||||
assert config.enable_fusion is None
|
assert config.enable_fusion is True
|
||||||
|
|
||||||
# Test enable_attn_fusion -> fuse_attn_quant
|
# Test enable_attn_fusion -> fuse_attn_quant
|
||||||
caplog_vllm.clear()
|
caplog_vllm.clear()
|
||||||
config = PassConfig(enable_attn_fusion=True)
|
config = PassConfig(enable_attn_fusion=True)
|
||||||
assert "enable_attn_fusion is deprecated" in caplog_vllm.text
|
assert "enable_attn_fusion is deprecated" in caplog_vllm.text
|
||||||
assert config.fuse_attn_quant is True
|
assert config.fuse_attn_quant is True
|
||||||
assert config.enable_attn_fusion is None
|
assert config.enable_attn_fusion is True
|
||||||
|
|
||||||
# Test enable_noop -> eliminate_noops
|
# Test enable_noop -> eliminate_noops
|
||||||
caplog_vllm.clear()
|
caplog_vllm.clear()
|
||||||
config = PassConfig(enable_noop=True)
|
config = PassConfig(enable_noop=True)
|
||||||
assert "enable_noop is deprecated" in caplog_vllm.text
|
assert "enable_noop is deprecated" in caplog_vllm.text
|
||||||
assert config.eliminate_noops is True
|
assert config.eliminate_noops is True
|
||||||
assert config.enable_noop is None
|
assert config.enable_noop is True
|
||||||
|
|
||||||
# Test enable_sequence_parallelism -> enable_sp
|
# Test enable_sequence_parallelism -> enable_sp
|
||||||
caplog_vllm.clear()
|
caplog_vllm.clear()
|
||||||
config = PassConfig(enable_sequence_parallelism=True)
|
config = PassConfig(enable_sequence_parallelism=True)
|
||||||
assert "enable_sequence_parallelism is deprecated" in caplog_vllm.text
|
assert "enable_sequence_parallelism is deprecated" in caplog_vllm.text
|
||||||
assert config.enable_sp is True
|
assert config.enable_sp is True
|
||||||
assert config.enable_sequence_parallelism is None
|
assert config.enable_sequence_parallelism is True
|
||||||
|
|
||||||
# Test enable_async_tp -> fuse_gemm_comms
|
# Test enable_async_tp -> fuse_gemm_comms
|
||||||
caplog_vllm.clear()
|
caplog_vllm.clear()
|
||||||
config = PassConfig(enable_async_tp=True)
|
config = PassConfig(enable_async_tp=True)
|
||||||
assert "enable_async_tp is deprecated" in caplog_vllm.text
|
assert "enable_async_tp is deprecated" in caplog_vllm.text
|
||||||
assert config.fuse_gemm_comms is True
|
assert config.fuse_gemm_comms is True
|
||||||
assert config.enable_async_tp is None
|
assert config.enable_async_tp is True
|
||||||
|
|
||||||
# Test enable_fi_allreduce_fusion -> fuse_allreduce_rms
|
# Test enable_fi_allreduce_fusion -> fuse_allreduce_rms
|
||||||
caplog_vllm.clear()
|
caplog_vllm.clear()
|
||||||
config = PassConfig(enable_fi_allreduce_fusion=True)
|
config = PassConfig(enable_fi_allreduce_fusion=True)
|
||||||
assert "enable_fi_allreduce_fusion is deprecated" in caplog_vllm.text
|
assert "enable_fi_allreduce_fusion is deprecated" in caplog_vllm.text
|
||||||
assert config.fuse_allreduce_rms is True
|
assert config.fuse_allreduce_rms is True
|
||||||
assert config.enable_fi_allreduce_fusion is None
|
assert config.enable_fi_allreduce_fusion is True
|
||||||
|
|
||||||
|
# Test hash consistency
|
||||||
|
config_old = PassConfig(enable_fusion=True)
|
||||||
|
config_new = PassConfig(fuse_norm_quant=True, fuse_act_quant=True)
|
||||||
|
assert config_old.compute_hash() == config_new.compute_hash()
|
||||||
|
|
||||||
|
config_old = PassConfig(enable_async_tp=True)
|
||||||
|
config_new = PassConfig(fuse_gemm_comms=True)
|
||||||
|
assert config_old.compute_hash() == config_new.compute_hash()
|
||||||
|
|||||||
@ -4,7 +4,7 @@
|
|||||||
import enum
|
import enum
|
||||||
from collections import Counter
|
from collections import Counter
|
||||||
from collections.abc import Callable
|
from collections.abc import Callable
|
||||||
from dataclasses import asdict, field
|
from dataclasses import field
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import TYPE_CHECKING, Any, ClassVar, Literal
|
from typing import TYPE_CHECKING, Any, ClassVar, Literal
|
||||||
|
|
||||||
@ -13,7 +13,7 @@ from pydantic.dataclasses import dataclass
|
|||||||
|
|
||||||
import vllm.envs as envs
|
import vllm.envs as envs
|
||||||
from vllm.compilation.inductor_pass import CallableInductorPass, InductorPass
|
from vllm.compilation.inductor_pass import CallableInductorPass, InductorPass
|
||||||
from vllm.config.utils import config, handle_deprecated
|
from vllm.config.utils import config, get_hash_factors, handle_deprecated, hash_factors
|
||||||
from vllm.logger import init_logger
|
from vllm.logger import init_logger
|
||||||
from vllm.platforms import current_platform
|
from vllm.platforms import current_platform
|
||||||
from vllm.utils.import_utils import resolve_obj_by_qualname
|
from vllm.utils.import_utils import resolve_obj_by_qualname
|
||||||
@ -196,7 +196,16 @@ class PassConfig:
|
|||||||
Any new fields that affect compilation should be added to the hash.
|
Any new fields that affect compilation should be added to the hash.
|
||||||
Any future fields that don't affect compilation should be excluded.
|
Any future fields that don't affect compilation should be excluded.
|
||||||
"""
|
"""
|
||||||
return InductorPass.hash_dict(asdict(self))
|
|
||||||
|
ignored_fields = [
|
||||||
|
"enable_fusion",
|
||||||
|
"enable_attn_fusion",
|
||||||
|
"enable_noop",
|
||||||
|
"enable_sequence_parallelism",
|
||||||
|
"enable_async_tp",
|
||||||
|
"enable_fi_allreduce_fusion",
|
||||||
|
]
|
||||||
|
return hash_factors(get_hash_factors(self, ignored_factors=ignored_fields))
|
||||||
|
|
||||||
@field_validator(
|
@field_validator(
|
||||||
"fuse_norm_quant",
|
"fuse_norm_quant",
|
||||||
@ -267,14 +276,6 @@ class PassConfig:
|
|||||||
"v0.13.0 or v1.0.0, whichever is sooner",
|
"v0.13.0 or v1.0.0, whichever is sooner",
|
||||||
)
|
)
|
||||||
|
|
||||||
# Force old flags to None to ensure they are not used
|
|
||||||
self.enable_fusion = None
|
|
||||||
self.enable_attn_fusion = None
|
|
||||||
self.enable_noop = None
|
|
||||||
self.enable_sequence_parallelism = None
|
|
||||||
self.enable_async_tp = None
|
|
||||||
self.enable_fi_allreduce_fusion = None
|
|
||||||
|
|
||||||
if not self.eliminate_noops:
|
if not self.eliminate_noops:
|
||||||
if self.fuse_norm_quant or self.fuse_act_quant:
|
if self.fuse_norm_quant or self.fuse_act_quant:
|
||||||
logger.warning_once(
|
logger.warning_once(
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user