diff --git a/benchmarks/kernels/benchmark_moe.py b/benchmarks/kernels/benchmark_moe.py
index 9407747f7843..1884a80a4077 100644
--- a/benchmarks/kernels/benchmark_moe.py
+++ b/benchmarks/kernels/benchmark_moe.py
@@ -10,12 +10,12 @@ from typing import Any, TypedDict

 import ray
 import torch
-import triton
 from ray.experimental.tqdm_ray import tqdm
 from transformers import AutoConfig

 from vllm.model_executor.layers.fused_moe.fused_moe import *
 from vllm.platforms import current_platform
+from vllm.triton_utils import triton
 from vllm.utils import FlexibleArgumentParser

 FP8_DTYPE = current_platform.fp8_dtype()
diff --git a/benchmarks/kernels/benchmark_rmsnorm.py b/benchmarks/kernels/benchmark_rmsnorm.py
index eaf6b25e8ca4..09a319ccf1d1 100644
--- a/benchmarks/kernels/benchmark_rmsnorm.py
+++ b/benchmarks/kernels/benchmark_rmsnorm.py
@@ -4,11 +4,11 @@ import itertools
 from typing import Optional, Union

 import torch
-import triton
 from flashinfer.norm import fused_add_rmsnorm, rmsnorm
 from torch import nn

 from vllm import _custom_ops as vllm_ops
+from vllm.triton_utils import triton


 class HuggingFaceRMSNorm(nn.Module):
diff --git a/benchmarks/kernels/deepgemm/benchmark_fp8_block_dense_gemm.py b/benchmarks/kernels/deepgemm/benchmark_fp8_block_dense_gemm.py
index 7892f126e7d6..5fa55bb974e1 100644
--- a/benchmarks/kernels/deepgemm/benchmark_fp8_block_dense_gemm.py
+++ b/benchmarks/kernels/deepgemm/benchmark_fp8_block_dense_gemm.py
@@ -6,13 +6,13 @@ import time
 # Import DeepGEMM functions
 import deep_gemm
 import torch
-import triton
 from deep_gemm import calc_diff, ceil_div, get_col_major_tma_aligned_tensor

 # Import vLLM functions
 from vllm import _custom_ops as ops
 from vllm.model_executor.layers.quantization.utils.fp8_utils import (
     per_token_group_quant_fp8, w8a8_block_fp8_matmul)
+from vllm.triton_utils import triton


 # Copied from
diff --git a/tests/kernels/attention/test_flashmla.py b/tests/kernels/attention/test_flashmla.py
index 3985c6834f60..0d51a8e7fee1 100644
--- a/tests/kernels/attention/test_flashmla.py
+++ b/tests/kernels/attention/test_flashmla.py
@@ -5,11 +5,11 @@ import random

 import pytest
 import torch
-import triton

 from vllm.attention.ops.flashmla import (flash_mla_with_kvcache,
                                          get_mla_metadata,
                                          is_flashmla_supported)
+from vllm.triton_utils import triton


 def cal_diff(x: torch.Tensor, y: torch.Tensor, name: str) -> None:
diff --git a/tests/test_triton_utils.py b/tests/test_triton_utils.py
new file mode 100644
index 000000000000..eb8ad48fdead
--- /dev/null
+++ b/tests/test_triton_utils.py
@@ -0,0 +1,92 @@
+# SPDX-License-Identifier: Apache-2.0
+
+import sys
+import types
+from unittest import mock
+
+from vllm.triton_utils.importing import (TritonLanguagePlaceholder,
+                                         TritonPlaceholder)
+
+
+def test_triton_placeholder_is_module():
+    triton = TritonPlaceholder()
+    assert isinstance(triton, types.ModuleType)
+    assert triton.__name__ == "triton"
+
+
+def test_triton_language_placeholder_is_module():
+    triton_language = TritonLanguagePlaceholder()
+    assert isinstance(triton_language, types.ModuleType)
+    assert triton_language.__name__ == "triton.language"
+
+
+def test_triton_placeholder_decorators():
+    triton = TritonPlaceholder()
+
+    @triton.jit
+    def foo(x):
+        return x
+
+    @triton.autotune
+    def bar(x):
+        return x
+
+    @triton.heuristics
+    def baz(x):
+        return x
+
+    assert foo(1) == 1
+    assert bar(2) == 2
+    assert baz(3) == 3
+
+
+def test_triton_placeholder_decorators_with_args():
+    triton = TritonPlaceholder()
+
+    @triton.jit(debug=True)
+    def foo(x):
+        return x
+
+    @triton.autotune(configs=[], key="x")
+    def bar(x):
+        return x
+
+    @triton.heuristics(
+        {"BLOCK_SIZE": lambda args: 128 if args["x"] > 1024 else 64})
+    def baz(x):
+        return x
+
+    assert foo(1) == 1
+    assert bar(2) == 2
+    assert baz(3) == 3
+
+
+def test_triton_placeholder_language():
+    lang = TritonLanguagePlaceholder()
+    assert isinstance(lang, types.ModuleType)
+    assert lang.__name__ == "triton.language"
+    assert lang.constexpr is None
+    assert lang.dtype is None
+    assert lang.int64 is None
+
+
+def test_triton_placeholder_language_from_parent():
+    triton = TritonPlaceholder()
+    lang = triton.language
+    assert isinstance(lang, TritonLanguagePlaceholder)
+
+
+def test_no_triton_fallback():
+    # clear existing triton modules
+    sys.modules.pop("triton", None)
+    sys.modules.pop("triton.language", None)
+    sys.modules.pop("vllm.triton_utils", None)
+    sys.modules.pop("vllm.triton_utils.importing", None)
+
+    # mock triton not being installed
+    with mock.patch.dict(sys.modules, {"triton": None}):
+        from vllm.triton_utils import HAS_TRITON, tl, triton
+        assert HAS_TRITON is False
+        assert triton.__class__.__name__ == "TritonPlaceholder"
+        assert triton.language.__class__.__name__ == "TritonLanguagePlaceholder"
+        assert tl.__class__.__name__ == "TritonLanguagePlaceholder"
diff --git a/vllm/attention/ops/blocksparse_attention/blocksparse_attention_kernel.py b/vllm/attention/ops/blocksparse_attention/blocksparse_attention_kernel.py
index 71caf3cbac02..bc87ce33a301 100644
--- a/vllm/attention/ops/blocksparse_attention/blocksparse_attention_kernel.py
+++ b/vllm/attention/ops/blocksparse_attention/blocksparse_attention_kernel.py
@@ -1,8 +1,8 @@
 # SPDX-License-Identifier: Apache-2.0

 import torch
-import triton
-import triton.language as tl
+
+from vllm.triton_utils import tl, triton


 def blocksparse_flash_attn_varlen_fwd(
diff --git a/vllm/attention/ops/blocksparse_attention/utils.py b/vllm/attention/ops/blocksparse_attention/utils.py
index 4de9bd530642..e64fc1139713 100644
--- a/vllm/attention/ops/blocksparse_attention/utils.py
+++ b/vllm/attention/ops/blocksparse_attention/utils.py
@@ -8,7 +8,8 @@ from functools import lru_cache

 import numpy as np
 import torch
-import triton
+
+from vllm.triton_utils import triton


 class csr_matrix:
diff --git a/vllm/attention/ops/chunked_prefill_paged_decode.py b/vllm/attention/ops/chunked_prefill_paged_decode.py
index 759b3d8536dd..dc039a0259aa 100644
--- a/vllm/attention/ops/chunked_prefill_paged_decode.py
+++ b/vllm/attention/ops/chunked_prefill_paged_decode.py
@@ -7,11 +7,10 @@
 #  - Thomas Parnell

 import torch
-import triton
-import triton.language as tl

 from vllm import _custom_ops as ops
 from vllm.platforms.rocm import use_rocm_custom_paged_attention
+from vllm.triton_utils import tl, triton

 from .prefix_prefill import context_attention_fwd

diff --git a/vllm/attention/ops/prefix_prefill.py b/vllm/attention/ops/prefix_prefill.py
index a8c8d8409620..86d256b630bf 100644
--- a/vllm/attention/ops/prefix_prefill.py
+++ b/vllm/attention/ops/prefix_prefill.py
@@ -4,10 +4,9 @@
 # https://github.com/ModelTC/lightllm/blob/main/lightllm/models/llama/triton_kernel/context_flashattention_nopad.py

 import torch
-import triton
-import triton.language as tl

 from vllm.platforms import current_platform
+from vllm.triton_utils import tl, triton

 # Static kernels parameters
 BASE_BLOCK = 128 if current_platform.has_device_capability(80) else 64
diff --git a/vllm/attention/ops/triton_decode_attention.py b/vllm/attention/ops/triton_decode_attention.py
index 35ee0835f42a..fb983907e375 100644
--- a/vllm/attention/ops/triton_decode_attention.py
+++ b/vllm/attention/ops/triton_decode_attention.py
@@ -30,10 +30,8 @@ It supports page size >= 1.

 import logging

-import triton
-import triton.language as tl
-
 from vllm.platforms import current_platform
+from vllm.triton_utils import tl, triton

 is_hip_ = current_platform.is_rocm()

diff --git a/vllm/attention/ops/triton_flash_attention.py b/vllm/attention/ops/triton_flash_attention.py
index 23ac7d7dc84c..8940d0b66225 100644
--- a/vllm/attention/ops/triton_flash_attention.py
+++ b/vllm/attention/ops/triton_flash_attention.py
@@ -25,11 +25,10 @@ Currently only the forward kernel is supported, and contains these features:
 from typing import Optional

 import torch
-import triton
-import triton.language as tl

 from vllm import _custom_ops as ops
 from vllm.platforms import current_platform
+from vllm.triton_utils import tl, triton

 SUPPORTED_LAYOUTS = ['thd', 'bhsd', 'bshd']

diff --git a/vllm/attention/ops/triton_merge_attn_states.py b/vllm/attention/ops/triton_merge_attn_states.py
index 250426d9faa5..30e61b6d8263 100644
--- a/vllm/attention/ops/triton_merge_attn_states.py
+++ b/vllm/attention/ops/triton_merge_attn_states.py
@@ -2,8 +2,8 @@
 from typing import Optional

 import torch
-import triton
-import triton.language as tl
+
+from vllm.triton_utils import tl, triton


 # Implements section 2.2 of https://www.arxiv.org/pdf/2501.01005
diff --git a/vllm/lora/ops/triton_ops/kernel_utils.py b/vllm/lora/ops/triton_ops/kernel_utils.py
index 5b8c19376106..0f971c03592d 100644
--- a/vllm/lora/ops/triton_ops/kernel_utils.py
+++ b/vllm/lora/ops/triton_ops/kernel_utils.py
@@ -2,8 +2,7 @@
 """
 Utilities for Punica kernel construction.
 """
-import triton
-import triton.language as tl
+from vllm.triton_utils import tl, triton


 @triton.jit
diff --git a/vllm/model_executor/layers/fused_moe/fused_moe.py b/vllm/model_executor/layers/fused_moe/fused_moe.py
index c1edbda0dd22..075b98d14860 100644
--- a/vllm/model_executor/layers/fused_moe/fused_moe.py
+++ b/vllm/model_executor/layers/fused_moe/fused_moe.py
@@ -6,8 +6,6 @@ import os
 from typing import Any, Callable, Dict, List, Optional, Tuple

 import torch
-import triton
-import triton.language as tl

 import vllm.envs as envs
 from vllm import _custom_ops as ops
@@ -21,6 +19,7 @@ from vllm.model_executor.layers.quantization.utils.fp8_utils import (
 from vllm.model_executor.layers.quantization.utils.int8_utils import (
     per_token_group_quant_int8, per_token_quant_int8)
 from vllm.platforms import current_platform
+from vllm.triton_utils import tl, triton
 from vllm.utils import direct_register_custom_op

 from .rocm_aiter_fused_moe import is_rocm_aiter_moe_enabled
diff --git a/vllm/model_executor/layers/fused_moe/moe_align_block_size.py b/vllm/model_executor/layers/fused_moe/moe_align_block_size.py
index 07d51acf9867..b68e58efa884 100644
--- a/vllm/model_executor/layers/fused_moe/moe_align_block_size.py
+++ b/vllm/model_executor/layers/fused_moe/moe_align_block_size.py
@@ -2,11 +2,10 @@
 from typing import Optional, Tuple

 import torch
-import triton
-import triton.language as tl

 import vllm.envs as envs
 from vllm import _custom_ops as ops
+from vllm.triton_utils import tl, triton
 from vllm.utils import round_up


diff --git a/vllm/model_executor/layers/lightning_attn.py b/vllm/model_executor/layers/lightning_attn.py
index de360778f28c..96659af408ed 100644
--- a/vllm/model_executor/layers/lightning_attn.py
+++ b/vllm/model_executor/layers/lightning_attn.py
@@ -1,9 +1,9 @@
 # SPDX-License-Identifier: Apache-2.0
 import torch
-import triton
-import triton.language as tl
 from einops import rearrange

+from vllm.triton_utils import tl, triton
+

 @triton.jit
 def _fwd_diag_kernel(Q, K, V, Out, S, b: tl.constexpr, h: tl.constexpr, n,
diff --git a/vllm/model_executor/layers/mamba/ops/mamba_ssm.py b/vllm/model_executor/layers/mamba/ops/mamba_ssm.py
index 9fbad9d2f91e..689c940d11ba 100644
--- a/vllm/model_executor/layers/mamba/ops/mamba_ssm.py
+++ b/vllm/model_executor/layers/mamba/ops/mamba_ssm.py
@@ -4,13 +4,11 @@
 # Adapted from https://github.com/state-spaces/mamba/blob/v2.2.4/mamba_ssm/ops/triton/selective_state_update.py

 import torch
-import triton
-import triton.language as tl
 from packaging import version

 from vllm import _custom_ops as ops
 from vllm.attention.backends.utils import PAD_SLOT_ID
-from vllm.triton_utils import HAS_TRITON
+from vllm.triton_utils import HAS_TRITON, tl, triton

 TRITON3 = HAS_TRITON and (version.parse(triton.__version__)
                           >= version.parse("3.0.0"))
diff --git a/vllm/model_executor/layers/mamba/ops/ssd_bmm.py b/vllm/model_executor/layers/mamba/ops/ssd_bmm.py
index 388a63327213..0fdb055aab82 100644
--- a/vllm/model_executor/layers/mamba/ops/ssd_bmm.py
+++ b/vllm/model_executor/layers/mamba/ops/ssd_bmm.py
@@ -8,8 +8,8 @@
 import math

 import torch
-import triton
-import triton.language as tl
+
+from vllm.triton_utils import tl, triton


 @triton.autotune(
diff --git a/vllm/model_executor/layers/mamba/ops/ssd_chunk_scan.py b/vllm/model_executor/layers/mamba/ops/ssd_chunk_scan.py
index 005917f23638..1652c51814cd 100644
--- a/vllm/model_executor/layers/mamba/ops/ssd_chunk_scan.py
+++ b/vllm/model_executor/layers/mamba/ops/ssd_chunk_scan.py
@@ -6,10 +6,10 @@
 # ruff: noqa: E501,SIM102

 import torch
-import triton
-import triton.language as tl
 from packaging import version

+from vllm.triton_utils import tl, triton
+
 TRITON_22 = version.parse(triton.__version__) >= version.parse('2.2.0')


diff --git a/vllm/model_executor/layers/mamba/ops/ssd_chunk_state.py b/vllm/model_executor/layers/mamba/ops/ssd_chunk_state.py
index a970ac94580b..ee633569097b 100644
--- a/vllm/model_executor/layers/mamba/ops/ssd_chunk_state.py
+++ b/vllm/model_executor/layers/mamba/ops/ssd_chunk_state.py
@@ -8,8 +8,8 @@
 import math

 import torch
-import triton
-import triton.language as tl
+
+from vllm.triton_utils import tl, triton

 from .mamba_ssm import softplus

diff --git a/vllm/model_executor/layers/mamba/ops/ssd_combined.py b/vllm/model_executor/layers/mamba/ops/ssd_combined.py
index 3febd4ccb992..e9efe6428252 100644
--- a/vllm/model_executor/layers/mamba/ops/ssd_combined.py
+++ b/vllm/model_executor/layers/mamba/ops/ssd_combined.py
@@ -6,10 +6,11 @@
 # ruff: noqa: E501

 import torch
-import triton
 from einops import rearrange
 from packaging import version

+from vllm.triton_utils import triton
+
 from .ssd_bmm import _bmm_chunk_fwd
 from .ssd_chunk_scan import _chunk_scan_fwd
 from .ssd_chunk_state import (_chunk_cumsum_fwd, _chunk_state_fwd,
diff --git a/vllm/model_executor/layers/mamba/ops/ssd_state_passing.py b/vllm/model_executor/layers/mamba/ops/ssd_state_passing.py
index 219c5306f425..6f69ca74389e 100644
--- a/vllm/model_executor/layers/mamba/ops/ssd_state_passing.py
+++ b/vllm/model_executor/layers/mamba/ops/ssd_state_passing.py
@@ -6,8 +6,8 @@
 # ruff: noqa: E501

 import torch
-import triton
-import triton.language as tl
+
+from vllm.triton_utils import tl, triton


 @triton.autotune(
diff --git a/vllm/model_executor/layers/quantization/awq_triton.py b/vllm/model_executor/layers/quantization/awq_triton.py
index 09efd4dbd797..5e5491578979 100644
--- a/vllm/model_executor/layers/quantization/awq_triton.py
+++ b/vllm/model_executor/layers/quantization/awq_triton.py
@@ -1,8 +1,8 @@
 # SPDX-License-Identifier: Apache-2.0

 import torch
-import triton
-import triton.language as tl
+
+from vllm.triton_utils import tl, triton

 AWQ_TRITON_SUPPORTED_GROUP_SIZES = [-1, 32, 64, 128]

diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/triton_scaled_mm.py b/vllm/model_executor/layers/quantization/compressed_tensors/triton_scaled_mm.py
index b69c5e7a02a7..d5d98ee8ba4d 100644
--- a/vllm/model_executor/layers/quantization/compressed_tensors/triton_scaled_mm.py
+++ b/vllm/model_executor/layers/quantization/compressed_tensors/triton_scaled_mm.py
@@ -3,8 +3,8 @@
 from typing import Optional, Type

 import torch
-import triton
-import triton.language as tl
+
+from vllm.triton_utils import tl, triton


 def is_weak_contiguous(x: torch.Tensor):
diff --git a/vllm/model_executor/layers/quantization/utils/fp8_utils.py b/vllm/model_executor/layers/quantization/utils/fp8_utils.py
index ecb7996e1e8c..064cbb8cf52d 100644
--- a/vllm/model_executor/layers/quantization/utils/fp8_utils.py
+++ b/vllm/model_executor/layers/quantization/utils/fp8_utils.py
@@ -7,8 +7,6 @@ import os
 from typing import Any, Dict, List, Optional, Tuple, Union

 import torch
-import triton
-import triton.language as tl

 from vllm import _custom_ops as ops
 from vllm.logger import init_logger
@@ -17,6 +15,7 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import (
 from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
     CUTLASS_BLOCK_FP8_SUPPORTED)
 from vllm.platforms import current_platform
+from vllm.triton_utils import tl, triton
 from vllm.utils import direct_register_custom_op

 logger = init_logger(__name__)
diff --git a/vllm/model_executor/layers/quantization/utils/int8_utils.py b/vllm/model_executor/layers/quantization/utils/int8_utils.py
index aaaf7a9e0a4c..431f0cf73fad 100644
--- a/vllm/model_executor/layers/quantization/utils/int8_utils.py
+++ b/vllm/model_executor/layers/quantization/utils/int8_utils.py
@@ -8,10 +8,9 @@ import os
 from typing import Any, Dict, List, Optional, Tuple

 import torch
-import triton
-import triton.language as tl

 from vllm.platforms import current_platform
+from vllm.triton_utils import tl, triton

 logger = logging.getLogger(__name__)

diff --git a/vllm/triton_utils/__init__.py b/vllm/triton_utils/__init__.py
index bffc56a2e75c..9f14a907af3a 100644
--- a/vllm/triton_utils/__init__.py
+++ b/vllm/triton_utils/__init__.py
@@ -1,5 +1,13 @@
 # SPDX-License-Identifier: Apache-2.0

-from vllm.triton_utils.importing import HAS_TRITON
+from vllm.triton_utils.importing import (HAS_TRITON, TritonLanguagePlaceholder,
+                                         TritonPlaceholder)

-__all__ = ["HAS_TRITON"]
+if HAS_TRITON:
+    import triton
+    import triton.language as tl
+else:
+    triton = TritonPlaceholder()
+    tl = TritonLanguagePlaceholder()
+
+__all__ = ["HAS_TRITON", "triton", "tl"]
diff --git a/vllm/triton_utils/importing.py b/vllm/triton_utils/importing.py
index 0a0c0a4bd178..8cf2e01a33bd 100644
--- a/vllm/triton_utils/importing.py
+++ b/vllm/triton_utils/importing.py
@@ -16,32 +16,34 @@ if not HAS_TRITON:
     logger.info("Triton not installed or not compatible; certain GPU-related"
                 " functions will not be available.")

-    class TritonPlaceholder(types.ModuleType):
-
-        def __init__(self):
-            super().__init__("triton")
-            self.jit = self._dummy_decorator("jit")
-            self.autotune = self._dummy_decorator("autotune")
-            self.heuristics = self._dummy_decorator("heuristics")
-            self.language = TritonLanguagePlaceholder()
-            logger.warning_once(
-                "Triton is not installed. Using dummy decorators. "
-                "Install it via `pip install triton` to enable kernel"
-                "compilation.")
+
+class TritonPlaceholder(types.ModuleType):

-        def _dummy_decorator(self, name):
+    def __init__(self):
+        super().__init__("triton")
+        self.jit = self._dummy_decorator("jit")
+        self.autotune = self._dummy_decorator("autotune")
+        self.heuristics = self._dummy_decorator("heuristics")
+        self.language = TritonLanguagePlaceholder()
+        logger.warning_once(
+            "Triton is not installed. Using dummy decorators. "
+            "Install it via `pip install triton` to enable kernel"
+            " compilation.")

-            def decorator(func=None, **kwargs):
-                if func is None:
-                    return lambda f: f
-                return func
+    def _dummy_decorator(self, name):

-            return decorator
+        def decorator(*args, **kwargs):
+            if args and callable(args[0]):
+                return args[0]
+            return lambda f: f

-    class TritonLanguagePlaceholder(types.ModuleType):
+        return decorator

-        def __init__(self):
-            super().__init__("triton.language")
-            self.constexpr = None
-            self.dtype = None
-            self.int64 = None
+
+class TritonLanguagePlaceholder(types.ModuleType):
+
+    def __init__(self):
+        super().__init__("triton.language")
+        self.constexpr = None
+        self.dtype = None
+        self.int64 = None
diff --git a/vllm/v1/sample/rejection_sampler.py b/vllm/v1/sample/rejection_sampler.py
index b25443dd45ed..17b870fede8e 100644
--- a/vllm/v1/sample/rejection_sampler.py
+++ b/vllm/v1/sample/rejection_sampler.py
@@ -3,10 +3,9 @@ from typing import Optional

 import torch
 import torch.nn as nn
-import triton
-import triton.language as tl

 from vllm.logger import init_logger
+from vllm.triton_utils import tl, triton
 from vllm.v1.sample.metadata import SamplingMetadata
 from vllm.v1.sample.ops.topk_topp_sampler import apply_top_k_top_p
 from vllm.v1.spec_decode.metadata import SpecDecodeMetadata
diff --git a/vllm/v1/spec_decode/eagle.py b/vllm/v1/spec_decode/eagle.py
index 07097d7da68f..6d71743c5e36 100644
--- a/vllm/v1/spec_decode/eagle.py
+++ b/vllm/v1/spec_decode/eagle.py
@@ -1,8 +1,6 @@
 # SPDX-License-Identifier: Apache-2.0
 import torch
 import torch.nn as nn
-import triton
-import triton.language as tl

 from vllm.config import CompilationLevel, VllmConfig, set_current_vllm_config
 from vllm.forward_context import set_forward_context
@@ -11,6 +9,7 @@ from vllm.model_executor.model_loader.loader import get_model_loader
 from vllm.model_executor.model_loader.utils import set_default_torch_dtype
 from vllm.model_executor.models import ModelRegistry
 from vllm.model_executor.models.llama_eagle3 import Eagle3LlamaForCausalLM
+from vllm.triton_utils import tl, triton
 from vllm.v1.attention.backends.flash_attn import FlashAttentionMetadata
 from vllm.v1.sample.metadata import SamplingMetadata