mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-10 02:44:57 +08:00
[MISC][pre-commit] Add pre-commit check for triton import (#17716)
Signed-off-by: Mengqing Cao <cmq0113@163.com>
This commit is contained in:
parent
07458a51ce
commit
e77dc4bad8
@ -128,6 +128,13 @@ repos:
|
||||
name: Update Dockerfile dependency graph
|
||||
entry: tools/update-dockerfile-graph.sh
|
||||
language: script
|
||||
# forbid directly import triton
|
||||
- id: forbid-direct-triton-import
|
||||
name: "Forbid direct 'import triton'"
|
||||
entry: python tools/check_triton_import.py
|
||||
language: python
|
||||
types: [python]
|
||||
pass_filenames: false
|
||||
# Keep `suggestion` last
|
||||
- id: suggestion
|
||||
name: Suggestion
|
||||
|
||||
75
tools/check_triton_import.py
Normal file
75
tools/check_triton_import.py
Normal file
@ -0,0 +1,75 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
import re
|
||||
import subprocess
|
||||
import sys
|
||||
|
||||
FORBIDDEN_IMPORT_RE = re.compile(r"^(from|import)\s+triton(\s|\.|$)")
|
||||
|
||||
# the way allowed to import triton
|
||||
ALLOWED_LINES = {
|
||||
"from vllm.triton_utils import triton",
|
||||
"from vllm.triton_utils import tl",
|
||||
"from vllm.triton_utils import tl, triton",
|
||||
}
|
||||
|
||||
|
||||
def is_forbidden_import(line: str) -> bool:
|
||||
stripped = line.strip()
|
||||
return bool(
|
||||
FORBIDDEN_IMPORT_RE.match(stripped)) and stripped not in ALLOWED_LINES
|
||||
|
||||
|
||||
def parse_diff(diff: str) -> list[str]:
|
||||
violations = []
|
||||
current_file = None
|
||||
current_lineno = None
|
||||
|
||||
for line in diff.splitlines():
|
||||
if line.startswith("+++ b/"):
|
||||
current_file = line[6:]
|
||||
elif line.startswith("@@"):
|
||||
match = re.search(r"\+(\d+)", line)
|
||||
if match:
|
||||
current_lineno = int(
|
||||
match.group(1)) - 1 # next "+ line" is here
|
||||
elif line.startswith("+") and not line.startswith("++"):
|
||||
current_lineno += 1
|
||||
code_line = line[1:]
|
||||
if is_forbidden_import(code_line):
|
||||
violations.append(
|
||||
f"{current_file}:{current_lineno}: {code_line.strip()}")
|
||||
return violations
|
||||
|
||||
|
||||
def get_diff(diff_type: str) -> str:
|
||||
if diff_type == "staged":
|
||||
return subprocess.check_output(
|
||||
["git", "diff", "--cached", "--unified=0"], text=True)
|
||||
elif diff_type == "unstaged":
|
||||
return subprocess.check_output(["git", "diff", "--unified=0"],
|
||||
text=True)
|
||||
else:
|
||||
raise ValueError(f"Unknown diff_type: {diff_type}")
|
||||
|
||||
|
||||
def main():
|
||||
all_violations = []
|
||||
for diff_type in ["staged", "unstaged"]:
|
||||
try:
|
||||
diff_output = get_diff(diff_type)
|
||||
violations = parse_diff(diff_output)
|
||||
all_violations.extend(violations)
|
||||
except subprocess.CalledProcessError as e:
|
||||
print(f"[{diff_type}] Git diff failed: {e}", file=sys.stderr)
|
||||
|
||||
if all_violations:
|
||||
print("❌ Forbidden direct `import triton` detected."
|
||||
" ➤ Use `from vllm.triton_utils import triton` instead.\n")
|
||||
for v in all_violations:
|
||||
print(f"❌ {v}")
|
||||
return 1
|
||||
return 0
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
sys.exit(main())
|
||||
Loading…
x
Reference in New Issue
Block a user