diff --git a/tools/check_triton_import.py b/tools/check_triton_import.py index 77b2dfc39188..c01d9d4ab079 100644 --- a/tools/check_triton_import.py +++ b/tools/check_triton_import.py @@ -14,6 +14,12 @@ ALLOWED_LINES = { "from vllm.triton_utils import tl, triton", } +ALLOWED_FILES = {"vllm/triton_utils/importing.py"} + + +def is_allowed_file(current_file: str) -> bool: + return current_file in ALLOWED_FILES + def is_forbidden_import(line: str) -> bool: stripped = line.strip() @@ -25,10 +31,14 @@ def parse_diff(diff: str) -> list[str]: violations = [] current_file = None current_lineno = None + skip_allowed_file = False for line in diff.splitlines(): if line.startswith("+++ b/"): current_file = line[6:] + skip_allowed_file = is_allowed_file(current_file) + elif skip_allowed_file: + continue elif line.startswith("@@"): match = re.search(r"\+(\d+)", line) if match: