[Bugfix] Fix triton import with local TritonPlaceholder (#17446)

Signed-off-by: Mengqing Cao <cmq0113@163.com>
Mengqing Cao 2025-05-06 17:53:09 +08:00 committed by GitHub
parent 05e1f96419
commit f9bc5a0693
30 changed files with 165 additions and 75 deletions
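
The change is mechanical across the touched files: kernel, benchmark, and test modules stop importing `triton` / `triton.language` directly and instead import them from `vllm.triton_utils`, which exposes either the real modules or local placeholder modules when Triton is not installed. A minimal sketch of the resulting pattern (the kernel below is illustrative only, not taken from this diff):

    # Before: fails at import time on platforms without Triton
    # import triton
    # import triton.language as tl

    # After: resolves to real Triton when available, otherwise to the placeholders
    from vllm.triton_utils import tl, triton

    @triton.jit
    def add_one_kernel(x_ptr, y_ptr, n, BLOCK: tl.constexpr):
        # With the placeholder, @triton.jit is a no-op and tl.constexpr is None,
        # so this module still imports cleanly; the kernel is simply never launched.
        pid = tl.program_id(0)
        offs = pid * BLOCK + tl.arange(0, BLOCK)
        mask = offs < n
        tl.store(y_ptr + offs, tl.load(x_ptr + offs, mask=mask) + 1, mask=mask)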

View File

@@ -10,12 +10,12 @@ from typing import Any, TypedDict
import ray
import torch
-import triton
from ray.experimental.tqdm_ray import tqdm
from transformers import AutoConfig
from vllm.model_executor.layers.fused_moe.fused_moe import *
from vllm.platforms import current_platform
+from vllm.triton_utils import triton
from vllm.utils import FlexibleArgumentParser
FP8_DTYPE = current_platform.fp8_dtype()

View File

@@ -4,11 +4,11 @@ import itertools
from typing import Optional, Union
import torch
-import triton
from flashinfer.norm import fused_add_rmsnorm, rmsnorm
from torch import nn
from vllm import _custom_ops as vllm_ops
+from vllm.triton_utils import triton
class HuggingFaceRMSNorm(nn.Module):

View File

@@ -6,13 +6,13 @@ import time
# Import DeepGEMM functions
import deep_gemm
import torch
-import triton
from deep_gemm import calc_diff, ceil_div, get_col_major_tma_aligned_tensor
# Import vLLM functions
from vllm import _custom_ops as ops
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
    per_token_group_quant_fp8, w8a8_block_fp8_matmul)
+from vllm.triton_utils import triton
# Copied from

View File

@@ -5,11 +5,11 @@ import random
import pytest
import torch
-import triton
from vllm.attention.ops.flashmla import (flash_mla_with_kvcache,
                                          get_mla_metadata,
                                          is_flashmla_supported)
+from vllm.triton_utils import triton
def cal_diff(x: torch.Tensor, y: torch.Tensor, name: str) -> None:

View File

@@ -0,0 +1,92 @@
+# SPDX-License-Identifier: Apache-2.0
+import sys
+import types
+from unittest import mock
+
+from vllm.triton_utils.importing import (TritonLanguagePlaceholder,
+                                         TritonPlaceholder)
+
+
+def test_triton_placeholder_is_module():
+    triton = TritonPlaceholder()
+    assert isinstance(triton, types.ModuleType)
+    assert triton.__name__ == "triton"
+
+
+def test_triton_language_placeholder_is_module():
+    triton_language = TritonLanguagePlaceholder()
+    assert isinstance(triton_language, types.ModuleType)
+    assert triton_language.__name__ == "triton.language"
+
+
+def test_triton_placeholder_decorators():
+    triton = TritonPlaceholder()
+
+    @triton.jit
+    def foo(x):
+        return x
+
+    @triton.autotune
+    def bar(x):
+        return x
+
+    @triton.heuristics
+    def baz(x):
+        return x
+
+    assert foo(1) == 1
+    assert bar(2) == 2
+    assert baz(3) == 3
+
+
+def test_triton_placeholder_decorators_with_args():
+    triton = TritonPlaceholder()
+
+    @triton.jit(debug=True)
+    def foo(x):
+        return x
+
+    @triton.autotune(configs=[], key="x")
+    def bar(x):
+        return x
+
+    @triton.heuristics(
+        {"BLOCK_SIZE": lambda args: 128 if args["x"] > 1024 else 64})
+    def baz(x):
+        return x
+
+    assert foo(1) == 1
+    assert bar(2) == 2
+    assert baz(3) == 3
+
+
+def test_triton_placeholder_language():
+    lang = TritonLanguagePlaceholder()
+    assert isinstance(lang, types.ModuleType)
+    assert lang.__name__ == "triton.language"
+    assert lang.constexpr is None
+    assert lang.dtype is None
+    assert lang.int64 is None
+
+
+def test_triton_placeholder_language_from_parent():
+    triton = TritonPlaceholder()
+    lang = triton.language
+    assert isinstance(lang, TritonLanguagePlaceholder)
+
+
+def test_no_triton_fallback():
+    # clear existing triton modules
+    sys.modules.pop("triton", None)
+    sys.modules.pop("triton.language", None)
+    sys.modules.pop("vllm.triton_utils", None)
+    sys.modules.pop("vllm.triton_utils.importing", None)
+
+    # mock triton not being installed
+    with mock.patch.dict(sys.modules, {"triton": None}):
+        from vllm.triton_utils import HAS_TRITON, tl, triton
+        assert HAS_TRITON is False
+        assert triton.__class__.__name__ == "TritonPlaceholder"
+        assert triton.language.__class__.__name__ == "TritonLanguagePlaceholder"
+        assert tl.__class__.__name__ == "TritonLanguagePlaceholder"

View File

@@ -1,8 +1,8 @@
# SPDX-License-Identifier: Apache-2.0
import torch
-import triton
-import triton.language as tl
+from vllm.triton_utils import tl, triton
def blocksparse_flash_attn_varlen_fwd(

View File

@@ -8,7 +8,8 @@ from functools import lru_cache
import numpy as np
import torch
-import triton
+from vllm.triton_utils import triton
class csr_matrix:

View File

@@ -7,11 +7,10 @@
# - Thomas Parnell <tpa@zurich.ibm.com>
import torch
-import triton
-import triton.language as tl
from vllm import _custom_ops as ops
from vllm.platforms.rocm import use_rocm_custom_paged_attention
+from vllm.triton_utils import tl, triton
from .prefix_prefill import context_attention_fwd

View File

@@ -4,10 +4,9 @@
# https://github.com/ModelTC/lightllm/blob/main/lightllm/models/llama/triton_kernel/context_flashattention_nopad.py
import torch
-import triton
-import triton.language as tl
from vllm.platforms import current_platform
+from vllm.triton_utils import tl, triton
# Static kernels parameters
BASE_BLOCK = 128 if current_platform.has_device_capability(80) else 64

View File

@@ -30,10 +30,8 @@ It supports page size >= 1.
import logging
-import triton
-import triton.language as tl
from vllm.platforms import current_platform
+from vllm.triton_utils import tl, triton
is_hip_ = current_platform.is_rocm()

View File

@@ -25,11 +25,10 @@ Currently only the forward kernel is supported, and contains these features:
from typing import Optional
import torch
-import triton
-import triton.language as tl
from vllm import _custom_ops as ops
from vllm.platforms import current_platform
+from vllm.triton_utils import tl, triton
SUPPORTED_LAYOUTS = ['thd', 'bhsd', 'bshd']

View File

@@ -2,8 +2,8 @@
from typing import Optional
import torch
-import triton
-import triton.language as tl
+from vllm.triton_utils import tl, triton
# Implements section 2.2 of https://www.arxiv.org/pdf/2501.01005

View File

@@ -2,8 +2,7 @@
"""
Utilities for Punica kernel construction.
"""
-import triton
-import triton.language as tl
+from vllm.triton_utils import tl, triton
@triton.jit

View File

@@ -6,8 +6,6 @@ import os
from typing import Any, Callable, Dict, List, Optional, Tuple
import torch
-import triton
-import triton.language as tl
import vllm.envs as envs
from vllm import _custom_ops as ops
@@ -21,6 +19,7 @@ from vllm.model_executor.layers.quantization.utils.fp8_utils import (
from vllm.model_executor.layers.quantization.utils.int8_utils import (
    per_token_group_quant_int8, per_token_quant_int8)
from vllm.platforms import current_platform
+from vllm.triton_utils import tl, triton
from vllm.utils import direct_register_custom_op
from .rocm_aiter_fused_moe import is_rocm_aiter_moe_enabled

View File

@@ -2,11 +2,10 @@
from typing import Optional, Tuple
import torch
-import triton
-import triton.language as tl
import vllm.envs as envs
from vllm import _custom_ops as ops
+from vllm.triton_utils import tl, triton
from vllm.utils import round_up

View File

@@ -1,9 +1,9 @@
# SPDX-License-Identifier: Apache-2.0
import torch
-import triton
-import triton.language as tl
from einops import rearrange
+from vllm.triton_utils import tl, triton
@triton.jit
def _fwd_diag_kernel(Q, K, V, Out, S, b: tl.constexpr, h: tl.constexpr, n,

View File

@@ -4,13 +4,11 @@
# Adapted from https://github.com/state-spaces/mamba/blob/v2.2.4/mamba_ssm/ops/triton/selective_state_update.py
import torch
-import triton
-import triton.language as tl
from packaging import version
from vllm import _custom_ops as ops
from vllm.attention.backends.utils import PAD_SLOT_ID
-from vllm.triton_utils import HAS_TRITON
+from vllm.triton_utils import HAS_TRITON, tl, triton
TRITON3 = HAS_TRITON and (version.parse(triton.__version__)
                          >= version.parse("3.0.0"))

View File

@@ -8,8 +8,8 @@
import math
import torch
-import triton
-import triton.language as tl
+from vllm.triton_utils import tl, triton
@triton.autotune(

View File

@@ -6,10 +6,10 @@
# ruff: noqa: E501,SIM102
import torch
-import triton
-import triton.language as tl
from packaging import version
+from vllm.triton_utils import tl, triton
TRITON_22 = version.parse(triton.__version__) >= version.parse('2.2.0')

View File

@@ -8,8 +8,8 @@
import math
import torch
-import triton
-import triton.language as tl
+from vllm.triton_utils import tl, triton
from .mamba_ssm import softplus

View File

@@ -6,10 +6,11 @@
# ruff: noqa: E501
import torch
-import triton
from einops import rearrange
from packaging import version
+from vllm.triton_utils import triton
from .ssd_bmm import _bmm_chunk_fwd
from .ssd_chunk_scan import _chunk_scan_fwd
from .ssd_chunk_state import (_chunk_cumsum_fwd, _chunk_state_fwd,

View File

@@ -6,8 +6,8 @@
# ruff: noqa: E501
import torch
-import triton
-import triton.language as tl
+from vllm.triton_utils import tl, triton
@triton.autotune(

View File

@@ -1,8 +1,8 @@
# SPDX-License-Identifier: Apache-2.0
import torch
-import triton
-import triton.language as tl
+from vllm.triton_utils import tl, triton
AWQ_TRITON_SUPPORTED_GROUP_SIZES = [-1, 32, 64, 128]

View File

@@ -3,8 +3,8 @@
from typing import Optional, Type
import torch
-import triton
-import triton.language as tl
+from vllm.triton_utils import tl, triton
def is_weak_contiguous(x: torch.Tensor):

View File

@@ -7,8 +7,6 @@ import os
from typing import Any, Dict, List, Optional, Tuple, Union
import torch
-import triton
-import triton.language as tl
from vllm import _custom_ops as ops
from vllm.logger import init_logger
@@ -17,6 +15,7 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import (
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
    CUTLASS_BLOCK_FP8_SUPPORTED)
from vllm.platforms import current_platform
+from vllm.triton_utils import tl, triton
from vllm.utils import direct_register_custom_op
logger = init_logger(__name__)

View File

@@ -8,10 +8,9 @@ import os
from typing import Any, Dict, List, Optional, Tuple
import torch
-import triton
-import triton.language as tl
from vllm.platforms import current_platform
+from vllm.triton_utils import tl, triton
logger = logging.getLogger(__name__)

View File

@@ -1,5 +1,13 @@
# SPDX-License-Identifier: Apache-2.0
-from vllm.triton_utils.importing import HAS_TRITON
-__all__ = ["HAS_TRITON"]
+from vllm.triton_utils.importing import (HAS_TRITON, TritonLanguagePlaceholder,
+                                         TritonPlaceholder)
+
+if HAS_TRITON:
+    import triton
+    import triton.language as tl
+else:
+    triton = TritonPlaceholder()
+    tl = TritonLanguagePlaceholder()
+
+__all__ = ["HAS_TRITON", "triton", "tl"]
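
With this `__init__.py`, the same import line works whether or not Triton is installed; a quick way to see which implementation was picked up (a sketch, assuming an interactive session with vLLM on the path):

    from vllm.triton_utils import HAS_TRITON, tl, triton

    if HAS_TRITON:
        # Real modules: tl is triton.language and kernels can be compiled.
        print("triton", triton.__version__)
    else:
        # Local placeholders: decorators are no-ops and tl attributes are None.
        print(type(triton).__name__, type(tl).__name__)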

View File

@@ -16,32 +16,34 @@ if not HAS_TRITON:
    logger.info("Triton not installed or not compatible; certain GPU-related"
                " functions will not be available.")

-class TritonPlaceholder(types.ModuleType):
-    def __init__(self):
-        super().__init__("triton")
-        self.jit = self._dummy_decorator("jit")
-        self.autotune = self._dummy_decorator("autotune")
-        self.heuristics = self._dummy_decorator("heuristics")
-        self.language = TritonLanguagePlaceholder()
-        logger.warning_once(
-            "Triton is not installed. Using dummy decorators. "
-            "Install it via `pip install triton` to enable kernel"
-            "compilation.")
-
-    def _dummy_decorator(self, name):
-
-        def decorator(func=None, **kwargs):
-            if func is None:
-                return lambda f: f
-            return func
-
-        return decorator
-
-
-class TritonLanguagePlaceholder(types.ModuleType):
-    def __init__(self):
-        super().__init__("triton.language")
-        self.constexpr = None
-        self.dtype = None
-        self.int64 = None
+class TritonPlaceholder(types.ModuleType):
+
+    def __init__(self):
+        super().__init__("triton")
+        self.jit = self._dummy_decorator("jit")
+        self.autotune = self._dummy_decorator("autotune")
+        self.heuristics = self._dummy_decorator("heuristics")
+        self.language = TritonLanguagePlaceholder()
+        logger.warning_once(
+            "Triton is not installed. Using dummy decorators. "
+            "Install it via `pip install triton` to enable kernel"
+            " compilation.")
+
+    def _dummy_decorator(self, name):
+
+        def decorator(*args, **kwargs):
+            if args and callable(args[0]):
+                return args[0]
+            return lambda f: f
+
+        return decorator
+
+
+class TritonLanguagePlaceholder(types.ModuleType):
+
+    def __init__(self):
+        super().__init__("triton.language")
+        self.constexpr = None
+        self.dtype = None
+        self.int64 = None

View File

@@ -3,10 +3,9 @@ from typing import Optional
import torch
import torch.nn as nn
-import triton
-import triton.language as tl
from vllm.logger import init_logger
+from vllm.triton_utils import tl, triton
from vllm.v1.sample.metadata import SamplingMetadata
from vllm.v1.sample.ops.topk_topp_sampler import apply_top_k_top_p
from vllm.v1.spec_decode.metadata import SpecDecodeMetadata

View File

@@ -1,8 +1,6 @@
# SPDX-License-Identifier: Apache-2.0
import torch
import torch.nn as nn
-import triton
-import triton.language as tl
from vllm.config import CompilationLevel, VllmConfig, set_current_vllm_config
from vllm.forward_context import set_forward_context
@@ -11,6 +9,7 @@ from vllm.model_executor.model_loader.loader import get_model_loader
from vllm.model_executor.model_loader.utils import set_default_torch_dtype
from vllm.model_executor.models import ModelRegistry
from vllm.model_executor.models.llama_eagle3 import Eagle3LlamaForCausalLM
+from vllm.triton_utils import tl, triton
from vllm.v1.attention.backends.flash_attn import FlashAttentionMetadata
from vllm.v1.sample.metadata import SamplingMetadata