[Bugfix] Fix triton import with local TritonPlaceholder (#17446)

Signed-off-by: Mengqing Cao <cmq0113@163.com>
Author: Mengqing Cao, 2025-05-06 17:53:09 +08:00 (committed by GitHub)
parent 05e1f96419
commit f9bc5a0693
30 changed files with 165 additions and 75 deletions
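
Nearly every hunk below makes the same mechanical swap: direct `import triton` / `import triton.language as tl` lines become imports from `vllm.triton_utils`, which re-exports the real Triton when it is installed and the new placeholder modules when it is not. A minimal sketch of the pattern (the kernel below is illustrative only, not part of this commit):

    # Before: fails at import time on installs without Triton.
    # import triton
    # import triton.language as tl

    # After: always importable; resolves to TritonPlaceholder /
    # TritonLanguagePlaceholder when Triton is missing.
    from vllm.triton_utils import tl, triton

    @triton.jit  # a no-op decorator under the placeholder
    def _demo_kernel(x_ptr, n: tl.constexpr):
        ...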

@@ -10,12 +10,12 @@ from typing import Any, TypedDict
 import ray
 import torch
-import triton
 from ray.experimental.tqdm_ray import tqdm
 from transformers import AutoConfig
 
 from vllm.model_executor.layers.fused_moe.fused_moe import *
 from vllm.platforms import current_platform
+from vllm.triton_utils import triton
 from vllm.utils import FlexibleArgumentParser
 
 FP8_DTYPE = current_platform.fp8_dtype()

@@ -4,11 +4,11 @@ import itertools
 from typing import Optional, Union
 
 import torch
-import triton
 from flashinfer.norm import fused_add_rmsnorm, rmsnorm
 from torch import nn
 
 from vllm import _custom_ops as vllm_ops
+from vllm.triton_utils import triton
 
 
 class HuggingFaceRMSNorm(nn.Module):

@@ -6,13 +6,13 @@ import time
 # Import DeepGEMM functions
 import deep_gemm
 import torch
-import triton
 from deep_gemm import calc_diff, ceil_div, get_col_major_tma_aligned_tensor
 
 # Import vLLM functions
 from vllm import _custom_ops as ops
 from vllm.model_executor.layers.quantization.utils.fp8_utils import (
     per_token_group_quant_fp8, w8a8_block_fp8_matmul)
+from vllm.triton_utils import triton
 
 
 # Copied from

@@ -5,11 +5,11 @@ import random
 
 import pytest
 import torch
-import triton
 
 from vllm.attention.ops.flashmla import (flash_mla_with_kvcache,
                                          get_mla_metadata,
                                          is_flashmla_supported)
+from vllm.triton_utils import triton
 
 
 def cal_diff(x: torch.Tensor, y: torch.Tensor, name: str) -> None:

@@ -0,0 +1,92 @@
+# SPDX-License-Identifier: Apache-2.0
+import sys
+import types
+from unittest import mock
+
+from vllm.triton_utils.importing import (TritonLanguagePlaceholder,
+                                         TritonPlaceholder)
+
+
+def test_triton_placeholder_is_module():
+    triton = TritonPlaceholder()
+    assert isinstance(triton, types.ModuleType)
+    assert triton.__name__ == "triton"
+
+
+def test_triton_language_placeholder_is_module():
+    triton_language = TritonLanguagePlaceholder()
+    assert isinstance(triton_language, types.ModuleType)
+    assert triton_language.__name__ == "triton.language"
+
+
+def test_triton_placeholder_decorators():
+    triton = TritonPlaceholder()
+
+    @triton.jit
+    def foo(x):
+        return x
+
+    @triton.autotune
+    def bar(x):
+        return x
+
+    @triton.heuristics
+    def baz(x):
+        return x
+
+    assert foo(1) == 1
+    assert bar(2) == 2
+    assert baz(3) == 3
+
+
+def test_triton_placeholder_decorators_with_args():
+    triton = TritonPlaceholder()
+
+    @triton.jit(debug=True)
+    def foo(x):
+        return x
+
+    @triton.autotune(configs=[], key="x")
+    def bar(x):
+        return x
+
+    @triton.heuristics(
+        {"BLOCK_SIZE": lambda args: 128 if args["x"] > 1024 else 64})
+    def baz(x):
+        return x
+
+    assert foo(1) == 1
+    assert bar(2) == 2
+    assert baz(3) == 3
+
+
+def test_triton_placeholder_language():
+    lang = TritonLanguagePlaceholder()
+    assert isinstance(lang, types.ModuleType)
+    assert lang.__name__ == "triton.language"
+    assert lang.constexpr is None
+    assert lang.dtype is None
+    assert lang.int64 is None
+
+
+def test_triton_placeholder_language_from_parent():
+    triton = TritonPlaceholder()
+    lang = triton.language
+    assert isinstance(lang, TritonLanguagePlaceholder)
+
+
+def test_no_triton_fallback():
+    # clear existing triton modules
+    sys.modules.pop("triton", None)
+    sys.modules.pop("triton.language", None)
+    sys.modules.pop("vllm.triton_utils", None)
+    sys.modules.pop("vllm.triton_utils.importing", None)
+
+    # mock triton not being installed
+    with mock.patch.dict(sys.modules, {"triton": None}):
+        from vllm.triton_utils import HAS_TRITON, tl, triton
+        assert HAS_TRITON is False
+        assert triton.__class__.__name__ == "TritonPlaceholder"
+        assert triton.language.__class__.__name__ == (
+            "TritonLanguagePlaceholder")
+        assert tl.__class__.__name__ == "TritonLanguagePlaceholder"

@@ -1,8 +1,8 @@
 # SPDX-License-Identifier: Apache-2.0
 
 import torch
-import triton
-import triton.language as tl
+
+from vllm.triton_utils import tl, triton
 
 
 def blocksparse_flash_attn_varlen_fwd(

@@ -8,7 +8,8 @@ from functools import lru_cache
 
 import numpy as np
 import torch
-import triton
+
+from vllm.triton_utils import triton
 
 
 class csr_matrix:

@@ -7,11 +7,10 @@
 #  - Thomas Parnell <tpa@zurich.ibm.com>
 
 import torch
-import triton
-import triton.language as tl
 
 from vllm import _custom_ops as ops
 from vllm.platforms.rocm import use_rocm_custom_paged_attention
+from vllm.triton_utils import tl, triton
 
 from .prefix_prefill import context_attention_fwd

@@ -4,10 +4,9 @@
 # https://github.com/ModelTC/lightllm/blob/main/lightllm/models/llama/triton_kernel/context_flashattention_nopad.py
 
 import torch
-import triton
-import triton.language as tl
 
 from vllm.platforms import current_platform
+from vllm.triton_utils import tl, triton
 
 # Static kernels parameters
 BASE_BLOCK = 128 if current_platform.has_device_capability(80) else 64

@@ -30,10 +30,8 @@ It supports page size >= 1.
 
 import logging
 
-import triton
-import triton.language as tl
-
 from vllm.platforms import current_platform
+from vllm.triton_utils import tl, triton
 
 is_hip_ = current_platform.is_rocm()

@@ -25,11 +25,10 @@ Currently only the forward kernel is supported, and contains these features:
 from typing import Optional
 
 import torch
-import triton
-import triton.language as tl
 
 from vllm import _custom_ops as ops
 from vllm.platforms import current_platform
+from vllm.triton_utils import tl, triton
 
 SUPPORTED_LAYOUTS = ['thd', 'bhsd', 'bshd']

@@ -2,8 +2,8 @@
 from typing import Optional
 
 import torch
-import triton
-import triton.language as tl
+
+from vllm.triton_utils import tl, triton
 
 # Implements section 2.2 of https://www.arxiv.org/pdf/2501.01005

@@ -2,8 +2,7 @@
 """
 Utilities for Punica kernel construction.
 """
-import triton
-import triton.language as tl
+from vllm.triton_utils import tl, triton
 
 
 @triton.jit

@@ -6,8 +6,6 @@ import os
 from typing import Any, Callable, Dict, List, Optional, Tuple
 
 import torch
-import triton
-import triton.language as tl
 
 import vllm.envs as envs
 from vllm import _custom_ops as ops
@@ -21,6 +19,7 @@ from vllm.model_executor.layers.quantization.utils.fp8_utils import (
 from vllm.model_executor.layers.quantization.utils.int8_utils import (
     per_token_group_quant_int8, per_token_quant_int8)
 from vllm.platforms import current_platform
+from vllm.triton_utils import tl, triton
 from vllm.utils import direct_register_custom_op
 
 from .rocm_aiter_fused_moe import is_rocm_aiter_moe_enabled

@@ -2,11 +2,10 @@
 from typing import Optional, Tuple
 
 import torch
-import triton
-import triton.language as tl
 
 import vllm.envs as envs
 from vllm import _custom_ops as ops
+from vllm.triton_utils import tl, triton
 from vllm.utils import round_up

@@ -1,9 +1,9 @@
 # SPDX-License-Identifier: Apache-2.0
 
 import torch
-import triton
-import triton.language as tl
 from einops import rearrange
+
+from vllm.triton_utils import tl, triton
 
 
 @triton.jit
 def _fwd_diag_kernel(Q, K, V, Out, S, b: tl.constexpr, h: tl.constexpr, n,

@@ -4,13 +4,11 @@
 # Adapted from https://github.com/state-spaces/mamba/blob/v2.2.4/mamba_ssm/ops/triton/selective_state_update.py
 
 import torch
-import triton
-import triton.language as tl
 from packaging import version
 
 from vllm import _custom_ops as ops
 from vllm.attention.backends.utils import PAD_SLOT_ID
-from vllm.triton_utils import HAS_TRITON
+from vllm.triton_utils import HAS_TRITON, tl, triton
 
 TRITON3 = HAS_TRITON and (version.parse(triton.__version__)
                           >= version.parse("3.0.0"))
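
Note the short-circuit in the `TRITON3` probe above: the placeholder module defines no `__version__`, so the `HAS_TRITON and` guard is what keeps this file importable on Triton-less installs, where `TRITON3` is simply `False`. A small sketch of the guard pattern, assuming only what this hunk shows:

    from packaging import version

    from vllm.triton_utils import HAS_TRITON, triton

    # Unsafe on Triton-less installs: the placeholder has no __version__.
    # TRITON3 = version.parse(triton.__version__) >= version.parse("3.0.0")

    # Safe: short-circuits before touching triton.__version__.
    TRITON3 = HAS_TRITON and (version.parse(triton.__version__)
                              >= version.parse("3.0.0"))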

@@ -8,8 +8,8 @@
 import math
 
 import torch
-import triton
-import triton.language as tl
+
+from vllm.triton_utils import tl, triton
 
 
 @triton.autotune(

@@ -6,10 +6,10 @@
 # ruff: noqa: E501,SIM102
 
 import torch
-import triton
-import triton.language as tl
 from packaging import version
 
+from vllm.triton_utils import tl, triton
+
 TRITON_22 = version.parse(triton.__version__) >= version.parse('2.2.0')

@@ -8,8 +8,8 @@
 import math
 
 import torch
-import triton
-import triton.language as tl
+
+from vllm.triton_utils import tl, triton
 
 from .mamba_ssm import softplus

@@ -6,10 +6,11 @@
 # ruff: noqa: E501
 
 import torch
-import triton
 from einops import rearrange
 from packaging import version
 
+from vllm.triton_utils import triton
+
 from .ssd_bmm import _bmm_chunk_fwd
 from .ssd_chunk_scan import _chunk_scan_fwd
 from .ssd_chunk_state import (_chunk_cumsum_fwd, _chunk_state_fwd,

@@ -6,8 +6,8 @@
 # ruff: noqa: E501
 
 import torch
-import triton
-import triton.language as tl
+
+from vllm.triton_utils import tl, triton
 
 
 @triton.autotune(

@@ -1,8 +1,8 @@
 # SPDX-License-Identifier: Apache-2.0
 
 import torch
-import triton
-import triton.language as tl
+
+from vllm.triton_utils import tl, triton
 
 AWQ_TRITON_SUPPORTED_GROUP_SIZES = [-1, 32, 64, 128]

@@ -3,8 +3,8 @@
 from typing import Optional, Type
 
 import torch
-import triton
-import triton.language as tl
+
+from vllm.triton_utils import tl, triton
 
 
 def is_weak_contiguous(x: torch.Tensor):

@@ -7,8 +7,6 @@ import os
 from typing import Any, Dict, List, Optional, Tuple, Union
 
 import torch
-import triton
-import triton.language as tl
 
 from vllm import _custom_ops as ops
 from vllm.logger import init_logger
@@ -17,6 +15,7 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import (
 from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
     CUTLASS_BLOCK_FP8_SUPPORTED)
 from vllm.platforms import current_platform
+from vllm.triton_utils import tl, triton
 from vllm.utils import direct_register_custom_op
 
 logger = init_logger(__name__)

@@ -8,10 +8,9 @@ import os
 from typing import Any, Dict, List, Optional, Tuple
 
 import torch
-import triton
-import triton.language as tl
 
 from vllm.platforms import current_platform
+from vllm.triton_utils import tl, triton
 
 logger = logging.getLogger(__name__)

@@ -1,5 +1,13 @@
 # SPDX-License-Identifier: Apache-2.0
 
-from vllm.triton_utils.importing import HAS_TRITON
+from vllm.triton_utils.importing import (HAS_TRITON, TritonLanguagePlaceholder,
+                                         TritonPlaceholder)
 
-__all__ = ["HAS_TRITON"]
+if HAS_TRITON:
+    import triton
+    import triton.language as tl
+else:
+    triton = TritonPlaceholder()
+    tl = TritonLanguagePlaceholder()
+
+__all__ = ["HAS_TRITON", "triton", "tl"]

@@ -16,32 +16,34 @@ if not HAS_TRITON:
     logger.info("Triton not installed or not compatible; certain GPU-related"
                 " functions will not be available.")
 
-    class TritonPlaceholder(types.ModuleType):
-
-        def __init__(self):
-            super().__init__("triton")
-            self.jit = self._dummy_decorator("jit")
-            self.autotune = self._dummy_decorator("autotune")
-            self.heuristics = self._dummy_decorator("heuristics")
-            self.language = TritonLanguagePlaceholder()
-            logger.warning_once(
-                "Triton is not installed. Using dummy decorators. "
-                "Install it via `pip install triton` to enable kernel"
-                "compilation.")
-
-        def _dummy_decorator(self, name):
-
-            def decorator(func=None, **kwargs):
-                if func is None:
-                    return lambda f: f
-                return func
-
-            return decorator
-
-    class TritonLanguagePlaceholder(types.ModuleType):
-
-        def __init__(self):
-            super().__init__("triton.language")
-            self.constexpr = None
-            self.dtype = None
-            self.int64 = None
+
+class TritonPlaceholder(types.ModuleType):
+
+    def __init__(self):
+        super().__init__("triton")
+        self.jit = self._dummy_decorator("jit")
+        self.autotune = self._dummy_decorator("autotune")
+        self.heuristics = self._dummy_decorator("heuristics")
+        self.language = TritonLanguagePlaceholder()
+        logger.warning_once(
+            "Triton is not installed. Using dummy decorators. "
+            "Install it via `pip install triton` to enable kernel"
+            " compilation.")
+
+    def _dummy_decorator(self, name):
+
+        def decorator(*args, **kwargs):
+            if args and callable(args[0]):
+                return args[0]
+            return lambda f: f
+
+        return decorator
+
+
+class TritonLanguagePlaceholder(types.ModuleType):
+
+    def __init__(self):
+        super().__init__("triton.language")
+        self.constexpr = None
+        self.dtype = None
+        self.int64 = None
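
The reworked `_dummy_decorator` distinguishes bare decoration, where the decorated function arrives as the first positional argument, from parameterized decoration, where configuration arrives instead. The old `func=None` version returned any positional argument as if it were the function, so `@triton.heuristics({...})` handed back the dict and the subsequent decoration call failed. A standalone sketch of the dispatch rule (the demo name is ours, not from this diff):

    def _dummy_decorator_demo(*args, **kwargs):
        # Bare use (@deco): the function itself is the first positional arg.
        if args and callable(args[0]):
            return args[0]
        # Parameterized use (@deco(cfg) or @deco(key=...)): swallow the
        # config and return an identity decorator.
        return lambda f: f

    @_dummy_decorator_demo
    def foo(x):
        return x

    @_dummy_decorator_demo({"BLOCK_SIZE": lambda args: 64})
    def bar(x):
        return x

    assert foo(1) == 1 and bar(2) == 2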

@@ -3,10 +3,9 @@ from typing import Optional
 
 import torch
 import torch.nn as nn
-import triton
-import triton.language as tl
 
 from vllm.logger import init_logger
+from vllm.triton_utils import tl, triton
 from vllm.v1.sample.metadata import SamplingMetadata
 from vllm.v1.sample.ops.topk_topp_sampler import apply_top_k_top_p
 from vllm.v1.spec_decode.metadata import SpecDecodeMetadata

@@ -1,8 +1,6 @@
 # SPDX-License-Identifier: Apache-2.0
 import torch
 import torch.nn as nn
-import triton
-import triton.language as tl
 
 from vllm.config import CompilationLevel, VllmConfig, set_current_vllm_config
 from vllm.forward_context import set_forward_context
@@ -11,6 +9,7 @@ from vllm.model_executor.model_loader.loader import get_model_loader
 from vllm.model_executor.model_loader.utils import set_default_torch_dtype
 from vllm.model_executor.models import ModelRegistry
 from vllm.model_executor.models.llama_eagle3 import Eagle3LlamaForCausalLM
+from vllm.triton_utils import tl, triton
 from vllm.v1.attention.backends.flash_attn import FlashAttentionMetadata
 from vllm.v1.sample.metadata import SamplingMetadata