mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-03-25 14:32:28 +08:00
[TPU] Avoid Triton Import (#15589)
Signed-off-by: rshaw@neuralmagic.com <robertgshaw2@gmail.com>
This commit is contained in:
parent
df8d3d1287
commit
e1e0fd7543
@ -16,8 +16,6 @@ from vllm.distributed import (get_dp_group, get_tensor_model_parallel_rank,
|
||||
from vllm.forward_context import ForwardContext, get_forward_context
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.custom_op import CustomOp
|
||||
from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (
|
||||
is_rocm_aiter_moe_enabled, shuffle_weights)
|
||||
from vllm.model_executor.layers.quantization.base_config import (
|
||||
QuantizationConfig, QuantizeMethodBase)
|
||||
from vllm.model_executor.utils import set_weight_attrs
|
||||
@ -119,7 +117,9 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
|
||||
layer.w2_weight = torch.nn.Parameter(self._maybe_pad_weight(
|
||||
layer.w2_weight.data),
|
||||
requires_grad=False)
|
||||
|
||||
# Lazy import to avoid importing triton.
|
||||
from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (
|
||||
is_rocm_aiter_moe_enabled, shuffle_weights)
|
||||
if is_rocm_aiter_moe_enabled():
|
||||
# reshaping weights is required for aiter moe kernel.
|
||||
shuffled_w13, shuffled_w2 = shuffle_weights(
|
||||
|
||||
@ -13,9 +13,6 @@ from vllm.distributed import get_tensor_model_parallel_world_size
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.fused_moe import (FusedMoE, FusedMoEMethodBase,
|
||||
FusedMoeWeightScaleSupported)
|
||||
from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (
|
||||
expand_weights, is_rocm_aiter_block_scaled_moe_enabled,
|
||||
is_rocm_aiter_moe_enabled, shuffle_weights)
|
||||
from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase,
|
||||
UnquantizedLinearMethod)
|
||||
from vllm.model_executor.layers.quantization.base_config import (
|
||||
@ -532,6 +529,11 @@ class Fp8MoEMethod(FusedMoEMethodBase):
|
||||
layer.w2_input_scale = None
|
||||
|
||||
def process_weights_after_loading(self, layer: Module) -> None:
|
||||
# Lazy import to avoid importing triton too early.
|
||||
from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (
|
||||
expand_weights, is_rocm_aiter_block_scaled_moe_enabled,
|
||||
is_rocm_aiter_moe_enabled, shuffle_weights)
|
||||
|
||||
# TODO (rob): refactor block quant into separate class.
|
||||
if self.block_quant:
|
||||
assert self.quant_config.activation_scheme == "dynamic"
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user