From c1acd6d7d48505d070546b7afa922e4a93ac5447 Mon Sep 17 00:00:00 2001
From: Wentao Ye <44945378+yewentao256@users.noreply.github.com>
Date: Sat, 12 Jul 2025 22:39:55 -0400
Subject: [PATCH] [Refactor] Change the way of import triton (#20774)

Signed-off-by: yewentao256
---
 tests/kernels/moe/test_batched_moe.py                     | 2 +-
 vllm/attention/ops/triton_unified_attention.py            | 3 +--
 vllm/lora/ops/triton_ops/lora_expand_op.py                | 3 +--
 vllm/lora/ops/triton_ops/lora_shrink_op.py                | 3 +--
 vllm/model_executor/layers/fused_moe/fused_batched_moe.py | 3 +--
 5 files changed, 5 insertions(+), 9 deletions(-)

diff --git a/tests/kernels/moe/test_batched_moe.py b/tests/kernels/moe/test_batched_moe.py
index c9a4375ac939e..69317405d48b7 100644
--- a/tests/kernels/moe/test_batched_moe.py
+++ b/tests/kernels/moe/test_batched_moe.py
@@ -6,7 +6,6 @@ from typing import Optional
 
 import pytest
 import torch
-import triton.language as tl
 
 from tests.kernels.moe.utils import (batched_moe,
                                      make_quantized_test_activations,
@@ -18,6 +17,7 @@ from vllm.model_executor.layers.fused_moe.fused_batched_moe import (
     invoke_moe_batched_triton_kernel)
 from vllm.model_executor.layers.fused_moe.fused_moe import fused_topk
 from vllm.platforms import current_platform
+from vllm.triton_utils import tl
 
 MNK_FACTORS = [
     (1, 128, 128),
diff --git a/vllm/attention/ops/triton_unified_attention.py b/vllm/attention/ops/triton_unified_attention.py
index f9645f651351f..eb9c4f1c1030a 100644
--- a/vllm/attention/ops/triton_unified_attention.py
+++ b/vllm/attention/ops/triton_unified_attention.py
@@ -8,10 +8,9 @@
 # - Thomas Parnell
 
 import torch
-import triton
-import triton.language as tl
 
 from vllm.logger import init_logger
+from vllm.triton_utils import tl, triton
 
 logger = init_logger(__name__)
 
diff --git a/vllm/lora/ops/triton_ops/lora_expand_op.py b/vllm/lora/ops/triton_ops/lora_expand_op.py
index eaef8e2c1905e..b1ab84e08ba76 100644
--- a/vllm/lora/ops/triton_ops/lora_expand_op.py
+++ b/vllm/lora/ops/triton_ops/lora_expand_op.py
@@ -8,12 +8,11 @@ https://arxiv.org/abs/2310.18547
 """
 
 import torch
-import triton
-import triton.language as tl
 
 from vllm.lora.ops.triton_ops.kernel_utils import do_expand_kernel
 from vllm.lora.ops.triton_ops.utils import _get_lora_b_ptr
 from vllm.platforms import current_platform
+from vllm.triton_utils import tl, triton
 from vllm.utils import direct_register_custom_op
 
 
diff --git a/vllm/lora/ops/triton_ops/lora_shrink_op.py b/vllm/lora/ops/triton_ops/lora_shrink_op.py
index d299fa5e8e1a5..1e7075ab07151 100644
--- a/vllm/lora/ops/triton_ops/lora_shrink_op.py
+++ b/vllm/lora/ops/triton_ops/lora_shrink_op.py
@@ -8,12 +8,11 @@ https://arxiv.org/abs/2310.18547
 """
 
 import torch
-import triton
-import triton.language as tl
 
 from vllm.lora.ops.triton_ops.kernel_utils import do_shrink_kernel
 from vllm.lora.ops.triton_ops.utils import _get_lora_a_ptr
 from vllm.platforms import current_platform
+from vllm.triton_utils import tl, triton
 from vllm.utils import direct_register_custom_op
 
 
diff --git a/vllm/model_executor/layers/fused_moe/fused_batched_moe.py b/vllm/model_executor/layers/fused_moe/fused_batched_moe.py
index 34f8c124759a8..61247e93091f1 100644
--- a/vllm/model_executor/layers/fused_moe/fused_batched_moe.py
+++ b/vllm/model_executor/layers/fused_moe/fused_batched_moe.py
@@ -4,8 +4,6 @@
 from typing import Optional
 
 import torch
-import triton
-import triton.language as tl
 
 import vllm.model_executor.layers.fused_moe.modular_kernel as mk
 from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig
@@ -18,6 +16,7 @@ from vllm.model_executor.layers.fused_moe.utils import (
     normalize_scales_shape)
 from vllm.model_executor.layers.quantization.utils.quant_utils import (
     group_broadcast)
+from vllm.triton_utils import tl, triton
 
 
 @triton.jit
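
The patch above replaces direct "import triton" / "import triton.language as tl"
statements with a single centralized import from vllm.triton_utils. A minimal
sketch of what such a centralized import shim can look like follows; the
HAS_TRITON flag and the placeholder fallbacks are assumptions for illustration
only and may not match the actual vllm/triton_utils implementation.

    # Hypothetical triton_utils-style shim (illustrative sketch, not the real module).
    import types
    from importlib.util import find_spec

    # Detect Triton without raising ImportError on platforms that lack it.
    HAS_TRITON = find_spec("triton") is not None

    if HAS_TRITON:
        import triton
        import triton.language as tl
    else:
        # Placeholders so "from vllm.triton_utils import tl, triton" still
        # imports cleanly; call sites guarded by HAS_TRITON never launch
        # Triton kernels on these platforms.
        triton = types.SimpleNamespace(jit=lambda fn: fn)
        tl = types.ModuleType("triton.language")

    __all__ = ["HAS_TRITON", "tl", "triton"]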