Mirror of https://git.datalinker.icu/vllm-project/vllm.git
[Bugfix] Fix triton import with local TritonPlaceholder (#17446)
Signed-off-by: Mengqing Cao <cmq0113@163.com>
parent 05e1f96419
commit f9bc5a0693
@@ -10,12 +10,12 @@ from typing import Any, TypedDict
 import ray
 import torch
-import triton
 from ray.experimental.tqdm_ray import tqdm
 from transformers import AutoConfig

 from vllm.model_executor.layers.fused_moe.fused_moe import *
 from vllm.platforms import current_platform
+from vllm.triton_utils import triton
 from vllm.utils import FlexibleArgumentParser

 FP8_DTYPE = current_platform.fp8_dtype()
@@ -4,11 +4,11 @@ import itertools
 from typing import Optional, Union

 import torch
-import triton
 from flashinfer.norm import fused_add_rmsnorm, rmsnorm
 from torch import nn

 from vllm import _custom_ops as vllm_ops
+from vllm.triton_utils import triton


 class HuggingFaceRMSNorm(nn.Module):
@@ -6,13 +6,13 @@ import time
 # Import DeepGEMM functions
 import deep_gemm
 import torch
-import triton
 from deep_gemm import calc_diff, ceil_div, get_col_major_tma_aligned_tensor

 # Import vLLM functions
 from vllm import _custom_ops as ops
 from vllm.model_executor.layers.quantization.utils.fp8_utils import (
     per_token_group_quant_fp8, w8a8_block_fp8_matmul)
+from vllm.triton_utils import triton


 # Copied from
@@ -5,11 +5,11 @@ import random

 import pytest
 import torch
-import triton

 from vllm.attention.ops.flashmla import (flash_mla_with_kvcache,
                                          get_mla_metadata,
                                          is_flashmla_supported)
+from vllm.triton_utils import triton


 def cal_diff(x: torch.Tensor, y: torch.Tensor, name: str) -> None:
tests/test_triton_utils.py (new file, 92 lines)
@@ -0,0 +1,92 @@
+# SPDX-License-Identifier: Apache-2.0
+
+import sys
+import types
+from unittest import mock
+
+from vllm.triton_utils.importing import (TritonLanguagePlaceholder,
+                                         TritonPlaceholder)
+
+
+def test_triton_placeholder_is_module():
+    triton = TritonPlaceholder()
+    assert isinstance(triton, types.ModuleType)
+    assert triton.__name__ == "triton"
+
+
+def test_triton_language_placeholder_is_module():
+    triton_language = TritonLanguagePlaceholder()
+    assert isinstance(triton_language, types.ModuleType)
+    assert triton_language.__name__ == "triton.language"
+
+
+def test_triton_placeholder_decorators():
+    triton = TritonPlaceholder()
+
+    @triton.jit
+    def foo(x):
+        return x
+
+    @triton.autotune
+    def bar(x):
+        return x
+
+    @triton.heuristics
+    def baz(x):
+        return x
+
+    assert foo(1) == 1
+    assert bar(2) == 2
+    assert baz(3) == 3
+
+
+def test_triton_placeholder_decorators_with_args():
+    triton = TritonPlaceholder()
+
+    @triton.jit(debug=True)
+    def foo(x):
+        return x
+
+    @triton.autotune(configs=[], key="x")
+    def bar(x):
+        return x
+
+    @triton.heuristics(
+        {"BLOCK_SIZE": lambda args: 128 if args["x"] > 1024 else 64})
+    def baz(x):
+        return x
+
+    assert foo(1) == 1
+    assert bar(2) == 2
+    assert baz(3) == 3
+
+
+def test_triton_placeholder_language():
+    lang = TritonLanguagePlaceholder()
+    assert isinstance(lang, types.ModuleType)
+    assert lang.__name__ == "triton.language"
+    assert lang.constexpr is None
+    assert lang.dtype is None
+    assert lang.int64 is None
+
+
+def test_triton_placeholder_language_from_parent():
+    triton = TritonPlaceholder()
+    lang = triton.language
+    assert isinstance(lang, TritonLanguagePlaceholder)
+
+
+def test_no_triton_fallback():
+    # clear existing triton modules
+    sys.modules.pop("triton", None)
+    sys.modules.pop("triton.language", None)
+    sys.modules.pop("vllm.triton_utils", None)
+    sys.modules.pop("vllm.triton_utils.importing", None)
+
+    # mock triton not being installed
+    with mock.patch.dict(sys.modules, {"triton": None}):
+        from vllm.triton_utils import HAS_TRITON, tl, triton
+        assert HAS_TRITON is False
+        assert triton.__class__.__name__ == "TritonPlaceholder"
+        assert triton.language.__class__.__name__ == "TritonLanguagePlaceholder"
+        assert tl.__class__.__name__ == "TritonLanguagePlaceholder"
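A standalone illustration (not part of this commit) of the sys.modules trick that test_no_triton_fallback relies on: mapping a module name to None in sys.modules makes any later import of that name raise ImportError, which lets the test simulate a missing Triton even on machines where it is installed. The module name below is just a stand-in.

import sys
from unittest import mock

# Temporarily map "json" to None, the same way the test maps "triton";
# the patch is reverted when the with-block exits.
with mock.patch.dict(sys.modules, {"json": None}):
    try:
        import json  # noqa: F401
    except ImportError as exc:
        print(f"import blocked: {exc}")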
@@ -1,8 +1,8 @@
 # SPDX-License-Identifier: Apache-2.0

 import torch
-import triton
-import triton.language as tl
+
+from vllm.triton_utils import tl, triton


 def blocksparse_flash_attn_varlen_fwd(
@@ -8,7 +8,8 @@ from functools import lru_cache

 import numpy as np
 import torch
-import triton
+
+from vllm.triton_utils import triton


 class csr_matrix:
@@ -7,11 +7,10 @@
 # - Thomas Parnell <tpa@zurich.ibm.com>

 import torch
-import triton
-import triton.language as tl

 from vllm import _custom_ops as ops
 from vllm.platforms.rocm import use_rocm_custom_paged_attention
+from vllm.triton_utils import tl, triton

 from .prefix_prefill import context_attention_fwd
@@ -4,10 +4,9 @@
 # https://github.com/ModelTC/lightllm/blob/main/lightllm/models/llama/triton_kernel/context_flashattention_nopad.py

 import torch
-import triton
-import triton.language as tl

 from vllm.platforms import current_platform
+from vllm.triton_utils import tl, triton

 # Static kernels parameters
 BASE_BLOCK = 128 if current_platform.has_device_capability(80) else 64
@@ -30,10 +30,8 @@ It supports page size >= 1.

 import logging

-import triton
-import triton.language as tl
-
 from vllm.platforms import current_platform
+from vllm.triton_utils import tl, triton

 is_hip_ = current_platform.is_rocm()
@@ -25,11 +25,10 @@ Currently only the forward kernel is supported, and contains these features:
 from typing import Optional

 import torch
-import triton
-import triton.language as tl

 from vllm import _custom_ops as ops
 from vllm.platforms import current_platform
+from vllm.triton_utils import tl, triton

 SUPPORTED_LAYOUTS = ['thd', 'bhsd', 'bshd']
@@ -2,8 +2,8 @@
 from typing import Optional

 import torch
-import triton
-import triton.language as tl
+
+from vllm.triton_utils import tl, triton


 # Implements section 2.2 of https://www.arxiv.org/pdf/2501.01005
@@ -2,8 +2,7 @@
 """
 Utilities for Punica kernel construction.
 """
-import triton
-import triton.language as tl
+from vllm.triton_utils import tl, triton


 @triton.jit
@@ -6,8 +6,6 @@ import os
 from typing import Any, Callable, Dict, List, Optional, Tuple

 import torch
-import triton
-import triton.language as tl

 import vllm.envs as envs
 from vllm import _custom_ops as ops
@@ -21,6 +19,7 @@ from vllm.model_executor.layers.quantization.utils.fp8_utils import (
 from vllm.model_executor.layers.quantization.utils.int8_utils import (
     per_token_group_quant_int8, per_token_quant_int8)
 from vllm.platforms import current_platform
+from vllm.triton_utils import tl, triton
 from vllm.utils import direct_register_custom_op

 from .rocm_aiter_fused_moe import is_rocm_aiter_moe_enabled
@@ -2,11 +2,10 @@
 from typing import Optional, Tuple

 import torch
-import triton
-import triton.language as tl

 import vllm.envs as envs
 from vllm import _custom_ops as ops
+from vllm.triton_utils import tl, triton
 from vllm.utils import round_up
@@ -1,9 +1,9 @@
 # SPDX-License-Identifier: Apache-2.0
 import torch
-import triton
-import triton.language as tl
 from einops import rearrange
+
+from vllm.triton_utils import tl, triton


 @triton.jit
 def _fwd_diag_kernel(Q, K, V, Out, S, b: tl.constexpr, h: tl.constexpr, n,
@@ -4,13 +4,11 @@
 # Adapted from https://github.com/state-spaces/mamba/blob/v2.2.4/mamba_ssm/ops/triton/selective_state_update.py

 import torch
-import triton
-import triton.language as tl
 from packaging import version

 from vllm import _custom_ops as ops
 from vllm.attention.backends.utils import PAD_SLOT_ID
-from vllm.triton_utils import HAS_TRITON
+from vllm.triton_utils import HAS_TRITON, tl, triton

 TRITON3 = HAS_TRITON and (version.parse(triton.__version__)
                           >= version.parse("3.0.0"))
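The widened import above keeps the existing TRITON3 check working in both configurations: HAS_TRITON short-circuits the expression, so the placeholder module, which carries no real __version__, is never asked for one. A minimal sketch of that guard pattern (assumes vLLM is importable; the 3.0.0 threshold is taken from the hunk itself):

from packaging import version

from vllm.triton_utils import HAS_TRITON, triton

# Real Triton: compare the installed version. No Triton: HAS_TRITON is
# False and the right-hand side is never evaluated.
TRITON3 = HAS_TRITON and (version.parse(triton.__version__)
                          >= version.parse("3.0.0"))
print(TRITON3)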
@@ -8,8 +8,8 @@
 import math

 import torch
-import triton
-import triton.language as tl
+
+from vllm.triton_utils import tl, triton


 @triton.autotune(
@@ -6,10 +6,10 @@
 # ruff: noqa: E501,SIM102

 import torch
-import triton
-import triton.language as tl
 from packaging import version

+from vllm.triton_utils import tl, triton
+
 TRITON_22 = version.parse(triton.__version__) >= version.parse('2.2.0')
@@ -8,8 +8,8 @@
 import math

 import torch
-import triton
-import triton.language as tl
+
+from vllm.triton_utils import tl, triton

 from .mamba_ssm import softplus
@@ -6,10 +6,11 @@
 # ruff: noqa: E501

 import torch
-import triton
 from einops import rearrange
 from packaging import version

+from vllm.triton_utils import triton
+
 from .ssd_bmm import _bmm_chunk_fwd
 from .ssd_chunk_scan import _chunk_scan_fwd
 from .ssd_chunk_state import (_chunk_cumsum_fwd, _chunk_state_fwd,
@@ -6,8 +6,8 @@
 # ruff: noqa: E501

 import torch
-import triton
-import triton.language as tl
+
+from vllm.triton_utils import tl, triton


 @triton.autotune(
@@ -1,8 +1,8 @@
 # SPDX-License-Identifier: Apache-2.0

 import torch
-import triton
-import triton.language as tl
+
+from vllm.triton_utils import tl, triton

 AWQ_TRITON_SUPPORTED_GROUP_SIZES = [-1, 32, 64, 128]
@@ -3,8 +3,8 @@
 from typing import Optional, Type

 import torch
-import triton
-import triton.language as tl
+
+from vllm.triton_utils import tl, triton


 def is_weak_contiguous(x: torch.Tensor):
@@ -7,8 +7,6 @@ import os
 from typing import Any, Dict, List, Optional, Tuple, Union

 import torch
-import triton
-import triton.language as tl

 from vllm import _custom_ops as ops
 from vllm.logger import init_logger
@@ -17,6 +15,7 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import (
 from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
     CUTLASS_BLOCK_FP8_SUPPORTED)
 from vllm.platforms import current_platform
+from vllm.triton_utils import tl, triton
 from vllm.utils import direct_register_custom_op

 logger = init_logger(__name__)
@@ -8,10 +8,9 @@ import os
 from typing import Any, Dict, List, Optional, Tuple

 import torch
-import triton
-import triton.language as tl

 from vllm.platforms import current_platform
+from vllm.triton_utils import tl, triton

 logger = logging.getLogger(__name__)
@@ -1,5 +1,13 @@
 # SPDX-License-Identifier: Apache-2.0

-from vllm.triton_utils.importing import HAS_TRITON
+from vllm.triton_utils.importing import (HAS_TRITON, TritonLanguagePlaceholder,
+                                         TritonPlaceholder)

-__all__ = ["HAS_TRITON"]
+if HAS_TRITON:
+    import triton
+    import triton.language as tl
+else:
+    triton = TritonPlaceholder()
+    tl = TritonLanguagePlaceholder()
+
+__all__ = ["HAS_TRITON", "triton", "tl"]
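This hunk (evidently the vllm.triton_utils package __init__, given the __all__ it exports) is what lets every kernel file above use a single import line: the real triton and triton.language are re-exported when available, and the placeholders are substituted otherwise. A consumer-side sketch under that assumption; the kernel below is a toy, not taken from the diff:

from vllm.triton_utils import HAS_TRITON, tl, triton


# With Triton installed, triton.jit is the real JIT decorator; without it,
# the placeholder's dummy decorator hands the function back unchanged, so
# this module still imports cleanly.
@triton.jit
def _toy_kernel(x_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
    ...


if not HAS_TRITON:
    print("Triton missing: _toy_kernel is defined but cannot be launched")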
@@ -16,6 +16,7 @@ if not HAS_TRITON:
     logger.info("Triton not installed or not compatible; certain GPU-related"
                 " functions will not be available.")

+
 class TritonPlaceholder(types.ModuleType):

     def __init__(self):
@@ -31,13 +32,14 @@ if not HAS_TRITON:

     def _dummy_decorator(self, name):

-        def decorator(func=None, **kwargs):
-            if func is None:
-                return lambda f: f
-            return func
+        def decorator(*args, **kwargs):
+            if args and callable(args[0]):
+                return args[0]
+            return lambda f: f

         return decorator


 class TritonLanguagePlaceholder(types.ModuleType):

     def __init__(self):
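This hunk is the core of the bugfix. The old placeholder decorator assumed keyword-only configuration, so a call such as triton.heuristics({...}) returned the positional dict itself, and Python then failed when it applied that dict to the kernel function. The fixed version passes a callable through untouched and otherwise returns an identity decorator, covering both bare and parameterized usage. A self-contained re-statement of that logic (mirrors the diff, no vLLM import needed):

def _dummy_decorator(name):
    # Same behaviour as the fixed TritonPlaceholder._dummy_decorator: bare
    # use (@deco) receives the function positionally and returns it;
    # configured use (@deco(cfg) or @deco(key="x")) gets an identity
    # decorator instead.
    def decorator(*args, **kwargs):
        if args and callable(args[0]):
            return args[0]
        return lambda f: f

    return decorator


heuristics = _dummy_decorator("heuristics")


@heuristics({"BLOCK_SIZE": lambda args: 64})  # positional dict: broke before the fix
def kernel(x):
    return x


assert kernel(3) == 3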
@@ -3,10 +3,9 @@ from typing import Optional

 import torch
 import torch.nn as nn
-import triton
-import triton.language as tl

 from vllm.logger import init_logger
+from vllm.triton_utils import tl, triton
 from vllm.v1.sample.metadata import SamplingMetadata
 from vllm.v1.sample.ops.topk_topp_sampler import apply_top_k_top_p
 from vllm.v1.spec_decode.metadata import SpecDecodeMetadata
@@ -1,8 +1,6 @@
 # SPDX-License-Identifier: Apache-2.0
 import torch
 import torch.nn as nn
-import triton
-import triton.language as tl

 from vllm.config import CompilationLevel, VllmConfig, set_current_vllm_config
 from vllm.forward_context import set_forward_context
@@ -11,6 +9,7 @@ from vllm.model_executor.model_loader.loader import get_model_loader
 from vllm.model_executor.model_loader.utils import set_default_torch_dtype
 from vllm.model_executor.models import ModelRegistry
 from vllm.model_executor.models.llama_eagle3 import Eagle3LlamaForCausalLM
+from vllm.triton_utils import tl, triton
 from vllm.v1.attention.backends.flash_attn import FlashAttentionMetadata
 from vllm.v1.sample.metadata import SamplingMetadata