diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py index 711bdfd688501..750c5f731c7c6 100644 --- a/vllm/model_executor/layers/fused_moe/layer.py +++ b/vllm/model_executor/layers/fused_moe/layer.py @@ -16,8 +16,6 @@ from vllm.distributed import (get_dp_group, get_tensor_model_parallel_rank, from vllm.forward_context import ForwardContext, get_forward_context from vllm.logger import init_logger from vllm.model_executor.custom_op import CustomOp -from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import ( - is_rocm_aiter_moe_enabled, shuffle_weights) from vllm.model_executor.layers.quantization.base_config import ( QuantizationConfig, QuantizeMethodBase) from vllm.model_executor.utils import set_weight_attrs @@ -119,7 +117,9 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp): layer.w2_weight = torch.nn.Parameter(self._maybe_pad_weight( layer.w2_weight.data), requires_grad=False) - + # Lazy import to avoid importing triton. + from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import ( + is_rocm_aiter_moe_enabled, shuffle_weights) if is_rocm_aiter_moe_enabled(): # reshaping weights is required for aiter moe kernel. shuffled_w13, shuffled_w2 = shuffle_weights( diff --git a/vllm/model_executor/layers/quantization/fp8.py b/vllm/model_executor/layers/quantization/fp8.py index bc17a569da2c3..f3907b4784b54 100644 --- a/vllm/model_executor/layers/quantization/fp8.py +++ b/vllm/model_executor/layers/quantization/fp8.py @@ -13,9 +13,6 @@ from vllm.distributed import get_tensor_model_parallel_world_size from vllm.logger import init_logger from vllm.model_executor.layers.fused_moe import (FusedMoE, FusedMoEMethodBase, FusedMoeWeightScaleSupported) -from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import ( - expand_weights, is_rocm_aiter_block_scaled_moe_enabled, - is_rocm_aiter_moe_enabled, shuffle_weights) from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase, UnquantizedLinearMethod) from vllm.model_executor.layers.quantization.base_config import ( @@ -532,6 +529,11 @@ class Fp8MoEMethod(FusedMoEMethodBase): layer.w2_input_scale = None def process_weights_after_loading(self, layer: Module) -> None: + # Lazy import to avoid importing triton too early. + from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import ( + expand_weights, is_rocm_aiter_block_scaled_moe_enabled, + is_rocm_aiter_moe_enabled, shuffle_weights) + # TODO (rob): refactor block quant into separate class. if self.block_quant: assert self.quant_config.activation_scheme == "dynamic"