[Bugfix] Fix triton import with local TritonPlaceholder (#17446)
Signed-off-by: Mengqing Cao <cmq0113@163.com>
parent 05e1f96419
commit f9bc5a0693
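In short, the fix routes every Triton import through vllm.triton_utils, which now exports either the real triton / triton.language modules or local placeholder instances when Triton is not installed, and the dummy decorators are made to accept both the bare and the argument-taking forms. The per-file changes below all follow the same pattern; a minimal sketch of it (comments are explanatory, not part of the diff):

# Previously, kernel and benchmark modules imported Triton directly:
#     import triton
#     import triton.language as tl
# which fails at import time on installs without Triton. They now import
# through the vLLM wrapper, which hands back either the real modules or
# the placeholders defined in vllm/triton_utils/importing.py:
from vllm.triton_utils import tl, triton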
@@ -10,12 +10,12 @@ from typing import Any, TypedDict

 import ray
 import torch
-import triton
 from ray.experimental.tqdm_ray import tqdm
 from transformers import AutoConfig

 from vllm.model_executor.layers.fused_moe.fused_moe import *
 from vllm.platforms import current_platform
+from vllm.triton_utils import triton
 from vllm.utils import FlexibleArgumentParser

 FP8_DTYPE = current_platform.fp8_dtype()
@@ -4,11 +4,11 @@ import itertools
 from typing import Optional, Union

 import torch
-import triton
 from flashinfer.norm import fused_add_rmsnorm, rmsnorm
 from torch import nn

 from vllm import _custom_ops as vllm_ops
+from vllm.triton_utils import triton


 class HuggingFaceRMSNorm(nn.Module):
@@ -6,13 +6,13 @@ import time

 # Import DeepGEMM functions
 import deep_gemm
 import torch
-import triton
 from deep_gemm import calc_diff, ceil_div, get_col_major_tma_aligned_tensor

 # Import vLLM functions
 from vllm import _custom_ops as ops
 from vllm.model_executor.layers.quantization.utils.fp8_utils import (
     per_token_group_quant_fp8, w8a8_block_fp8_matmul)
+from vllm.triton_utils import triton


 # Copied from
@@ -5,11 +5,11 @@ import random

 import pytest
 import torch
-import triton

 from vllm.attention.ops.flashmla import (flash_mla_with_kvcache,
                                          get_mla_metadata,
                                          is_flashmla_supported)
+from vllm.triton_utils import triton


 def cal_diff(x: torch.Tensor, y: torch.Tensor, name: str) -> None:
tests/test_triton_utils.py (new file, 92 lines)
@@ -0,0 +1,92 @@
# SPDX-License-Identifier: Apache-2.0

import sys
import types
from unittest import mock

from vllm.triton_utils.importing import (TritonLanguagePlaceholder,
                                         TritonPlaceholder)


def test_triton_placeholder_is_module():
    triton = TritonPlaceholder()
    assert isinstance(triton, types.ModuleType)
    assert triton.__name__ == "triton"


def test_triton_language_placeholder_is_module():
    triton_language = TritonLanguagePlaceholder()
    assert isinstance(triton_language, types.ModuleType)
    assert triton_language.__name__ == "triton.language"


def test_triton_placeholder_decorators():
    triton = TritonPlaceholder()

    @triton.jit
    def foo(x):
        return x

    @triton.autotune
    def bar(x):
        return x

    @triton.heuristics
    def baz(x):
        return x

    assert foo(1) == 1
    assert bar(2) == 2
    assert baz(3) == 3


def test_triton_placeholder_decorators_with_args():
    triton = TritonPlaceholder()

    @triton.jit(debug=True)
    def foo(x):
        return x

    @triton.autotune(configs=[], key="x")
    def bar(x):
        return x

    @triton.heuristics(
        {"BLOCK_SIZE": lambda args: 128 if args["x"] > 1024 else 64})
    def baz(x):
        return x

    assert foo(1) == 1
    assert bar(2) == 2
    assert baz(3) == 3


def test_triton_placeholder_language():
    lang = TritonLanguagePlaceholder()
    assert isinstance(lang, types.ModuleType)
    assert lang.__name__ == "triton.language"
    assert lang.constexpr is None
    assert lang.dtype is None
    assert lang.int64 is None


def test_triton_placeholder_language_from_parent():
    triton = TritonPlaceholder()
    lang = triton.language
    assert isinstance(lang, TritonLanguagePlaceholder)


def test_no_triton_fallback():
    # clear existing triton modules
    sys.modules.pop("triton", None)
    sys.modules.pop("triton.language", None)
    sys.modules.pop("vllm.triton_utils", None)
    sys.modules.pop("vllm.triton_utils.importing", None)

    # mock triton not being installed
    with mock.patch.dict(sys.modules, {"triton": None}):
        from vllm.triton_utils import HAS_TRITON, tl, triton
        assert HAS_TRITON is False
        assert triton.__class__.__name__ == "TritonPlaceholder"
        assert triton.language.__class__.__name__ == "TritonLanguagePlaceholder"
        assert tl.__class__.__name__ == "TritonLanguagePlaceholder"
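The new test module exercises the placeholder path directly: the placeholders present themselves as module objects with the expected names, the dummy decorators work both bare and with arguments, and importing vllm.triton_utils with triton masked out of sys.modules falls back to the placeholders instead of raising. Assuming a normal source checkout, it should pass on a machine without Triton via pytest tests/test_triton_utils.py.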
@@ -1,8 +1,8 @@
 # SPDX-License-Identifier: Apache-2.0

 import torch
-import triton
-import triton.language as tl
+
+from vllm.triton_utils import tl, triton


 def blocksparse_flash_attn_varlen_fwd(
@@ -8,7 +8,8 @@ from functools import lru_cache

 import numpy as np
 import torch
-import triton
+
+from vllm.triton_utils import triton


 class csr_matrix:
@@ -7,11 +7,10 @@
 # - Thomas Parnell <tpa@zurich.ibm.com>

 import torch
-import triton
-import triton.language as tl

 from vllm import _custom_ops as ops
 from vllm.platforms.rocm import use_rocm_custom_paged_attention
+from vllm.triton_utils import tl, triton

 from .prefix_prefill import context_attention_fwd
@@ -4,10 +4,9 @@
 # https://github.com/ModelTC/lightllm/blob/main/lightllm/models/llama/triton_kernel/context_flashattention_nopad.py

 import torch
-import triton
-import triton.language as tl

 from vllm.platforms import current_platform
+from vllm.triton_utils import tl, triton

 # Static kernels parameters
 BASE_BLOCK = 128 if current_platform.has_device_capability(80) else 64
@@ -30,10 +30,8 @@ It supports page size >= 1.

 import logging

-import triton
-import triton.language as tl
-
 from vllm.platforms import current_platform
+from vllm.triton_utils import tl, triton

 is_hip_ = current_platform.is_rocm()
@@ -25,11 +25,10 @@ Currently only the forward kernel is supported, and contains these features:
 from typing import Optional

 import torch
-import triton
-import triton.language as tl

 from vllm import _custom_ops as ops
 from vllm.platforms import current_platform
+from vllm.triton_utils import tl, triton

 SUPPORTED_LAYOUTS = ['thd', 'bhsd', 'bshd']
@@ -2,8 +2,8 @@
 from typing import Optional

 import torch
-import triton
-import triton.language as tl
+
+from vllm.triton_utils import tl, triton


 # Implements section 2.2 of https://www.arxiv.org/pdf/2501.01005
@@ -2,8 +2,7 @@
 """
 Utilities for Punica kernel construction.
 """
-import triton
-import triton.language as tl
+from vllm.triton_utils import tl, triton


 @triton.jit
@@ -6,8 +6,6 @@ import os
 from typing import Any, Callable, Dict, List, Optional, Tuple

 import torch
-import triton
-import triton.language as tl

 import vllm.envs as envs
 from vllm import _custom_ops as ops
@@ -21,6 +19,7 @@ from vllm.model_executor.layers.quantization.utils.fp8_utils import (
 from vllm.model_executor.layers.quantization.utils.int8_utils import (
     per_token_group_quant_int8, per_token_quant_int8)
 from vllm.platforms import current_platform
+from vllm.triton_utils import tl, triton
 from vllm.utils import direct_register_custom_op

 from .rocm_aiter_fused_moe import is_rocm_aiter_moe_enabled
@@ -2,11 +2,10 @@
 from typing import Optional, Tuple

 import torch
-import triton
-import triton.language as tl

 import vllm.envs as envs
 from vllm import _custom_ops as ops
+from vllm.triton_utils import tl, triton
 from vllm.utils import round_up

@@ -1,9 +1,9 @@
 # SPDX-License-Identifier: Apache-2.0
 import torch
-import triton
-import triton.language as tl
 from einops import rearrange

+from vllm.triton_utils import tl, triton


 @triton.jit
 def _fwd_diag_kernel(Q, K, V, Out, S, b: tl.constexpr, h: tl.constexpr, n,
@@ -4,13 +4,11 @@
 # Adapted from https://github.com/state-spaces/mamba/blob/v2.2.4/mamba_ssm/ops/triton/selective_state_update.py

 import torch
-import triton
-import triton.language as tl
 from packaging import version

 from vllm import _custom_ops as ops
 from vllm.attention.backends.utils import PAD_SLOT_ID
-from vllm.triton_utils import HAS_TRITON
+from vllm.triton_utils import HAS_TRITON, tl, triton

 TRITON3 = HAS_TRITON and (version.parse(triton.__version__)
                           >= version.parse("3.0.0"))
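Note the guard order in this hunk: the placeholder module defines no __version__ attribute, so evaluating version.parse(triton.__version__) on a Triton-less install would raise; HAS_TRITON and (...) short-circuits first, leaving TRITON3 as False and keeping the module importable.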
@@ -8,8 +8,8 @@
 import math

 import torch
-import triton
-import triton.language as tl
+
+from vllm.triton_utils import tl, triton


 @triton.autotune(
@@ -6,10 +6,10 @@
 # ruff: noqa: E501,SIM102

 import torch
-import triton
-import triton.language as tl
 from packaging import version

+from vllm.triton_utils import tl, triton
+
 TRITON_22 = version.parse(triton.__version__) >= version.parse('2.2.0')
@@ -8,8 +8,8 @@
 import math

 import torch
-import triton
-import triton.language as tl
+
+from vllm.triton_utils import tl, triton

 from .mamba_ssm import softplus
@@ -6,10 +6,11 @@
 # ruff: noqa: E501

 import torch
-import triton
 from einops import rearrange
 from packaging import version

+from vllm.triton_utils import triton
+
 from .ssd_bmm import _bmm_chunk_fwd
 from .ssd_chunk_scan import _chunk_scan_fwd
 from .ssd_chunk_state import (_chunk_cumsum_fwd, _chunk_state_fwd,
@@ -6,8 +6,8 @@
 # ruff: noqa: E501

 import torch
-import triton
-import triton.language as tl
+
+from vllm.triton_utils import tl, triton


 @triton.autotune(
@@ -1,8 +1,8 @@
 # SPDX-License-Identifier: Apache-2.0

 import torch
-import triton
-import triton.language as tl
+
+from vllm.triton_utils import tl, triton

 AWQ_TRITON_SUPPORTED_GROUP_SIZES = [-1, 32, 64, 128]
@@ -3,8 +3,8 @@
 from typing import Optional, Type

 import torch
-import triton
-import triton.language as tl
+
+from vllm.triton_utils import tl, triton


 def is_weak_contiguous(x: torch.Tensor):
@@ -7,8 +7,6 @@ import os
 from typing import Any, Dict, List, Optional, Tuple, Union

 import torch
-import triton
-import triton.language as tl

 from vllm import _custom_ops as ops
 from vllm.logger import init_logger
@@ -17,6 +15,7 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import (
 from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
     CUTLASS_BLOCK_FP8_SUPPORTED)
 from vllm.platforms import current_platform
+from vllm.triton_utils import tl, triton
 from vllm.utils import direct_register_custom_op

 logger = init_logger(__name__)
@@ -8,10 +8,9 @@ import os
 from typing import Any, Dict, List, Optional, Tuple

 import torch
-import triton
-import triton.language as tl

 from vllm.platforms import current_platform
+from vllm.triton_utils import tl, triton

 logger = logging.getLogger(__name__)
@@ -1,5 +1,13 @@
 # SPDX-License-Identifier: Apache-2.0

-from vllm.triton_utils.importing import HAS_TRITON
+from vllm.triton_utils.importing import (HAS_TRITON, TritonLanguagePlaceholder,
+                                         TritonPlaceholder)

-__all__ = ["HAS_TRITON"]
+if HAS_TRITON:
+    import triton
+    import triton.language as tl
+else:
+    triton = TritonPlaceholder()
+    tl = TritonLanguagePlaceholder()
+
+__all__ = ["HAS_TRITON", "triton", "tl"]
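With triton and tl now exported unconditionally from vllm.triton_utils, call sites can import them once and gate actual kernel launches on HAS_TRITON. A small sketch of that consumer pattern, assuming Triton and a CUDA tensor for the fast path; the kernel and function names here are illustrative, not taken from the diff:

import torch

from vllm.triton_utils import HAS_TRITON, tl, triton


@triton.jit  # a harmless no-op decorator when Triton is absent
def _copy_kernel(src_ptr, dst_ptr, n, BLOCK: tl.constexpr):
    pid = tl.program_id(0)
    offs = pid * BLOCK + tl.arange(0, BLOCK)
    mask = offs < n
    tl.store(dst_ptr + offs, tl.load(src_ptr + offs, mask=mask), mask=mask)


def copy(src: torch.Tensor) -> torch.Tensor:
    dst = torch.empty_like(src)
    if HAS_TRITON and src.is_cuda:
        # Real Triton path: launch the jitted kernel.
        grid = (triton.cdiv(src.numel(), 1024), )
        _copy_kernel[grid](src, dst, src.numel(), BLOCK=1024)
    else:
        # Placeholder path: the module still imports; fall back to eager copy.
        dst.copy_(src)
    return dst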
@@ -16,32 +16,34 @@ if not HAS_TRITON:
     logger.info("Triton not installed or not compatible; certain GPU-related"
                 " functions will not be available.")

-    class TritonPlaceholder(types.ModuleType):
-
-        def __init__(self):
-            super().__init__("triton")
-            self.jit = self._dummy_decorator("jit")
-            self.autotune = self._dummy_decorator("autotune")
-            self.heuristics = self._dummy_decorator("heuristics")
-            self.language = TritonLanguagePlaceholder()
-            logger.warning_once(
-                "Triton is not installed. Using dummy decorators. "
-                "Install it via `pip install triton` to enable kernel"
-                "compilation.")
-
-        def _dummy_decorator(self, name):
-
-            def decorator(func=None, **kwargs):
-                if func is None:
-                    return lambda f: f
-                return func
-
-            return decorator
-
-    class TritonLanguagePlaceholder(types.ModuleType):
-
-        def __init__(self):
-            super().__init__("triton.language")
-            self.constexpr = None
-            self.dtype = None
-            self.int64 = None
+
+class TritonPlaceholder(types.ModuleType):
+
+    def __init__(self):
+        super().__init__("triton")
+        self.jit = self._dummy_decorator("jit")
+        self.autotune = self._dummy_decorator("autotune")
+        self.heuristics = self._dummy_decorator("heuristics")
+        self.language = TritonLanguagePlaceholder()
+        logger.warning_once(
+            "Triton is not installed. Using dummy decorators. "
+            "Install it via `pip install triton` to enable kernel"
+            " compilation.")
+
+    def _dummy_decorator(self, name):
+
+        def decorator(*args, **kwargs):
+            if args and callable(args[0]):
+                return args[0]
+            return lambda f: f
+
+        return decorator
+
+
+class TritonLanguagePlaceholder(types.ModuleType):
+
+    def __init__(self):
+        super().__init__("triton.language")
+        self.constexpr = None
+        self.dtype = None
+        self.int64 = None
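The behavioural change in this hunk is in _dummy_decorator: the old version returned its first positional argument unchanged, so a configured use such as triton.heuristics({...}) handed back the config dict, which then failed when Python applied it to the decorated function. The new version only short-circuits when that argument is actually callable (the bare @triton.jit case) and otherwise returns an identity decorator. A quick check of both forms against the new placeholder, mirroring the added tests:

from vllm.triton_utils.importing import TritonPlaceholder

triton = TritonPlaceholder()


# Bare form: the decorator receives the function itself and returns it.
@triton.jit
def scale(x):
    return 2 * x


# Configured form: the first positional argument is a config dict, not a
# callable, so the placeholder returns an identity decorator instead of
# handing back the dict (the old behaviour, which raised a TypeError).
@triton.heuristics({"BLOCK_SIZE": lambda args: 64})
def shift(x):
    return x + 1


assert scale(3) == 6
assert shift(3) == 4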
@@ -3,10 +3,9 @@ from typing import Optional

 import torch
 import torch.nn as nn
-import triton
-import triton.language as tl

 from vllm.logger import init_logger
+from vllm.triton_utils import tl, triton
 from vllm.v1.sample.metadata import SamplingMetadata
 from vllm.v1.sample.ops.topk_topp_sampler import apply_top_k_top_p
 from vllm.v1.spec_decode.metadata import SpecDecodeMetadata
@@ -1,8 +1,6 @@
 # SPDX-License-Identifier: Apache-2.0
 import torch
 import torch.nn as nn
-import triton
-import triton.language as tl

 from vllm.config import CompilationLevel, VllmConfig, set_current_vllm_config
 from vllm.forward_context import set_forward_context
@@ -11,6 +9,7 @@ from vllm.model_executor.model_loader.loader import get_model_loader
 from vllm.model_executor.model_loader.utils import set_default_torch_dtype
 from vllm.model_executor.models import ModelRegistry
 from vllm.model_executor.models.llama_eagle3 import Eagle3LlamaForCausalLM
+from vllm.triton_utils import tl, triton
 from vllm.v1.attention.backends.flash_attn import FlashAttentionMetadata
 from vllm.v1.sample.metadata import SamplingMetadata