diff --git a/vllm/model_executor/layers/fused_moe/gpt_oss_triton_kernels_moe.py b/vllm/model_executor/layers/fused_moe/gpt_oss_triton_kernels_moe.py
index 4482029c16a87..6d6a2e22bc5f3 100644
--- a/vllm/model_executor/layers/fused_moe/gpt_oss_triton_kernels_moe.py
+++ b/vllm/model_executor/layers/fused_moe/gpt_oss_triton_kernels_moe.py
@@ -1,6 +1,6 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
-from typing import Any, Optional
+from typing import TYPE_CHECKING, Any, Optional
 
 import torch
 
@@ -8,13 +8,16 @@ import vllm.model_executor.layers.fused_moe.modular_kernel as mk
 from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import (
     TopKWeightAndReduceDelegate)
 from vllm.model_executor.layers.fused_moe.utils import extract_required_args
+from vllm.utils import has_triton_kernels
 
-if True:
+if has_triton_kernels():
     import triton_kernels.swiglu
-    from triton_kernels.matmul_ogs import (FnSpecs, FusedActivation,
-                                           PrecisionConfig, matmul_ogs)
+    from triton_kernels.matmul_ogs import FnSpecs, FusedActivation, matmul_ogs
     from triton_kernels.routing import routing
 
+if TYPE_CHECKING:
+    from triton_kernels.matmul_ogs import PrecisionConfig
+
 
 def triton_kernel_moe_forward(
     hidden_states: torch.Tensor,
@@ -33,8 +36,8 @@ def triton_kernel_moe_forward(
     w2_scale: Optional[torch.Tensor] = None,
     w1_bias: Optional[torch.Tensor] = None,
     w2_bias: Optional[torch.Tensor] = None,
-    w1_precision=None,  # PrecisionConfig or None
-    w2_precision=None,  # PrecisionConfig or None
+    w1_precision: Optional["PrecisionConfig"] = None,
+    w2_precision: Optional["PrecisionConfig"] = None,
     a1_scale: Optional[torch.Tensor] = None,
     a2_scale: Optional[torch.Tensor] = None,
     block_shape: Optional[list[int]] = None,
@@ -90,8 +93,8 @@ def triton_kernel_fused_experts(
     w2_scale: Optional[torch.Tensor] = None,
     w1_bias: Optional[torch.Tensor] = None,
     w2_bias: Optional[torch.Tensor] = None,
-    w1_precision=None,  # PrecisionConfig or None
-    w2_precision=None,  # PrecisionConfig or None
+    w1_precision: Optional["PrecisionConfig"] = None,
+    w2_precision: Optional["PrecisionConfig"] = None,
     a1_scale: Optional[torch.Tensor] = None,
     a2_scale: Optional[torch.Tensor] = None,
     block_shape: Optional[list[int]] = None,
@@ -141,8 +144,14 @@ def triton_kernel_fused_experts(
 
 class BatchedOAITritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
 
-    def __init__(self, quant_config, max_num_tokens: int, num_dispatchers: int,
-                 w1_precision: PrecisionConfig, w2_precision: PrecisionConfig):
+    def __init__(
+        self,
+        quant_config,
+        max_num_tokens: int,
+        num_dispatchers: int,
+        w1_precision: "PrecisionConfig",
+        w2_precision: "PrecisionConfig",
+    ):
         super().__init__(quant_config)
         self.max_num_tokens = max_num_tokens
         self.num_dispatchers = num_dispatchers
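
The diff applies one idiom in three places: the triton_kernels imports are gated behind a runtime availability check, while PrecisionConfig, which is only ever used in annotations, moves under TYPE_CHECKING and is referenced as a quoted forward reference. Below is a minimal, self-contained sketch of that idiom under stated assumptions: heavy_dep, has_heavy_dep, run_kernel, and HeavyConfig are hypothetical stand-ins, not part of vLLM or triton_kernels.

    # Sketch: import an optional dependency only when present, and keep
    # annotation-only names out of the runtime import path entirely.
    from typing import TYPE_CHECKING, Optional


    def has_heavy_dep() -> bool:
        # Hypothetical availability probe, analogous in spirit to
        # vllm.utils.has_triton_kernels().
        try:
            import heavy_dep  # noqa: F401
            return True
        except ImportError:
            return False


    if has_heavy_dep():
        # Runtime import: only names that are actually called at runtime.
        from heavy_dep import run_kernel

    if TYPE_CHECKING:
        # Type-only import: TYPE_CHECKING is False at runtime, so this line
        # never executes, and the module still imports cleanly when
        # heavy_dep is not installed.
        from heavy_dep import HeavyConfig


    def forward(x, config: Optional["HeavyConfig"] = None):
        # The quoted annotation is resolved lazily by type checkers, so it
        # is safe even though HeavyConfig does not exist at runtime here.
        # Calling this without heavy_dep installed would raise NameError,
        # which is the expected failure mode for the optional path.
        return run_kernel(x, config)

This is why the diff can both drop PrecisionConfig from the guarded import and still annotate w1_precision and w2_precision with Optional["PrecisionConfig"]: the name only needs to exist for type checkers, never at runtime.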