Mirror of https://git.datalinker.icu/vllm-project/vllm.git (synced 2025-12-14 06:45:00 +08:00)
[FEAT] [ROCm] Upgrade AITER Fused MoE kernels. (#18271)
Signed-off-by: vllmellm <vllm.ellm@embeddedllm.com>
parent b50602d5f0
commit d260f799a9
@@ -419,10 +419,8 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
             shuffle_weights)
 
         if self.rocm_aiter_moe_enabled:
-            # use 2stage ck moe layout
-            shuffled_w13, shuffled_w2 = shuffle_weights(layer.w13_weight.data,
-                                                        layer.w2_weight.data,
-                                                        layout=(32, 32))
+            shuffled_w13, shuffled_w2 = shuffle_weights(
+                layer.w13_weight.data, layer.w2_weight.data)
 
             layer.w13_weight.data = shuffled_w13
             layer.w2_weight.data = shuffled_w2
@@ -1,4 +1,5 @@
 # SPDX-License-Identifier: Apache-2.0
+from enum import IntEnum
 from functools import cache
 from typing import Optional
 
@@ -9,6 +10,28 @@ from vllm.platforms import current_platform
 from vllm.utils import direct_register_custom_op
 
 
+class QuantMethod(IntEnum):
+    # This allows interfacing with AITER QuantType Enum
+    # without importing the QuantType from AITER globally.
+
+    # Note that these quantization methods are
+    # supported in AITER package. However,
+    # not all are used in this module.
+
+    NO = 0  # a16w16
+    PER_TENSOR = 1  # w8a8 (pre_Tensor)
+    PER_TOKEN = 2  # w8a8/w8a4 (per_Token)
+    BLOCK_1X128 = 3  # block quantized w8a8 (per_1x128)
+    BLOCK_128x128 = 4  # block quantized w8a8 (per_128x128)
+
+
+class ActivationMethod(IntEnum):
+    # This allows interfacing with AITER ActivationType enum
+    # without importing the ActivationType enum from AITER globally.
+    SILU = 0
+    GELU = 1
+
+
 @cache
 def is_rocm_aiter_moe_enabled() -> bool:
     return current_platform.is_rocm() \
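For orientation, here is a minimal, CPU-runnable sketch of the pattern these enums enable: plain integers cross the torch custom-op boundary, and the AITER enum is only reconstructed inside the op implementation. The _AiterActivationType class below is a hypothetical stand-in for aiter.ActivationType, assuming it uses the same integer values; on a ROCm host the real enum is imported lazily, exactly as the hunks below do.

from enum import IntEnum


class ActivationMethod(IntEnum):
    # vLLM-side mirror of AITER's activation enum (as added above).
    SILU = 0
    GELU = 1


class _AiterActivationType(IntEnum):
    # Hypothetical stand-in for aiter.ActivationType, assuming matching values.
    Silu = 0
    Gelu = 1


def to_aiter_activation(activation_method: int) -> _AiterActivationType:
    # Reconstruct the AITER enum from the integer that travelled through
    # the custom-op boundary.
    return _AiterActivationType(activation_method)


assert to_aiter_activation(ActivationMethod.GELU.value) is _AiterActivationType.Gelu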
@@ -29,13 +52,12 @@ def rocm_aiter_asm_moe_tkw1_impl(
         a16: bool = False,
         per_tensor_quant_scale: Optional[torch.Tensor] = None,
         expert_mask: Optional[torch.Tensor] = None,
-        activation_str: str = "silu") -> torch.Tensor:
+        activation_method: int = ActivationMethod.SILU.value) -> torch.Tensor:
 
     from aiter import ActivationType
     from aiter.fused_moe_bf16_asm import asm_moe_tkw1
 
-    activation = \
-        ActivationType.Gelu if activation_str == "gelu" else ActivationType.Silu
+    activation = ActivationType(activation_method)
 
     return asm_moe_tkw1(hidden_states,
                         w1,
@@ -65,163 +87,7 @@ def rocm_aiter_asm_moe_tkw1_fake(
         a16: bool = False,
         per_tensor_quant_scale: Optional[torch.Tensor] = None,
         expert_mask: Optional[torch.Tensor] = None,
-        activation_str: str = "silu") -> torch.Tensor:
+        activation_method: int = ActivationMethod.SILU.value) -> torch.Tensor:
     return torch.empty_like(hidden_states)
-
-
-def rocm_aiter_fmoe_fp8_blockscale_g1u1_impl(
-        topk_ids: torch.Tensor,
-        topk_weights: torch.Tensor,
-        hidden_states_dtype: torch.dtype,
-        expert_mask: torch.Tensor,
-        a1: torch.Tensor,
-        w1: torch.Tensor,
-        w2: torch.Tensor,
-        w1_scale: torch.Tensor,
-        w2_scale: torch.Tensor,
-        a1_scale: torch.Tensor,
-        block_shape: list[int],
-        smooth_scale: Optional[torch.Tensor] = None) -> torch.Tensor:
-    from aiter import fmoe_fp8_blockscale_g1u1
-    from aiter.fused_moe_bf16_asm import moe_sorting_ck
-
-    topk = topk_ids.shape[1]
-    model_dim = w1.shape[-1]
-    local_E = E = w1.shape[0]
-    if expert_mask is not None:
-        E = expert_mask.numel()
-
-    (
-        sorted_token_ids,
-        sorted_weight_buf,
-        sorted_expert_ids,
-        num_valid_ids,
-        out_asm,
-    ) = moe_sorting_ck(topk_ids,
-                       topk_weights,
-                       E,
-                       model_dim,
-                       hidden_states_dtype,
-                       expert_mask=expert_mask)
-
-    fmoe_fp8_blockscale_g1u1(out_asm, a1, w1, w2, sorted_token_ids,
-                             sorted_weight_buf, sorted_expert_ids,
-                             num_valid_ids, topk,
-                             a1_scale.t().contiguous(),
-                             w1_scale.view(local_E, -1),
-                             w2_scale.view(local_E,
-                                           -1), *block_shape, smooth_scale)
-
-    return out_asm
-
-
-def rocm_aiter_fmoe_fp8_blockscale_g1u1_fake(
-        topk_ids: torch.Tensor,
-        topk_weights: torch.Tensor,
-        hidden_states_dtype: torch.dtype,
-        expert_mask: torch.Tensor,
-        a1: torch.Tensor,
-        w1: torch.Tensor,
-        w2: torch.Tensor,
-        w1_scale: torch.Tensor,
-        w2_scale: torch.Tensor,
-        a1_scale: torch.Tensor,
-        block_shape: list[int],
-        smooth_scale: Optional[torch.Tensor] = None) -> torch.Tensor:
-
-    return torch.empty_like(a1, dtype=hidden_states_dtype)
-
-
-def rocm_aiter_asm_moe_impl(hidden_states: torch.Tensor,
-                            w1: torch.Tensor,
-                            w2: torch.Tensor,
-                            topk_weights: torch.Tensor,
-                            topk_ids: torch.Tensor,
-                            fc1_scale: Optional[torch.Tensor] = None,
-                            fc2_scale: Optional[torch.Tensor] = None,
-                            fc1_smooth_scale: Optional[torch.Tensor] = None,
-                            fc2_smooth_scale: Optional[torch.Tensor] = None,
-                            a16: bool = False,
-                            activation: str = "silu") -> torch.Tensor:
-    import aiter.fused_moe_bf16_asm as rocm_aiter_asm_fmoe
-    from aiter import ActivationType
-
-    assert activation in ["silu", "gelu"], "The given activation:" \
-        f" {activation}" \
-        " is not supported in" \
-        " AITER."
-    if activation == "silu":
-        aiter_activation = ActivationType.Silu
-    else:
-        aiter_activation = ActivationType.Gelu
-
-    return rocm_aiter_asm_fmoe.asm_moe(hidden_states=hidden_states,
-                                       w1=w1,
-                                       w2=w2,
-                                       topk_weight=topk_weights,
-                                       topk_ids=topk_ids,
-                                       fc1_scale=fc1_scale,
-                                       fc2_scale=fc2_scale,
-                                       fc1_smooth_scale=fc1_smooth_scale,
-                                       fc2_smooth_scale=fc2_smooth_scale,
-                                       a16=a16,
-                                       activation=aiter_activation)
-
-
-def rocm_aiter_asm_moe_fake(hidden_states: torch.Tensor,
-                            w1: torch.Tensor,
-                            w2: torch.Tensor,
-                            topk_weights: torch.Tensor,
-                            topk_ids: torch.Tensor,
-                            fc1_scale: Optional[torch.Tensor] = None,
-                            fc2_scale: Optional[torch.Tensor] = None,
-                            fc1_smooth_scale: Optional[torch.Tensor] = None,
-                            fc2_smooth_scale: Optional[torch.Tensor] = None,
-                            a16: bool = False,
-                            activation: str = "silu") -> torch.Tensor:
-    return torch.empty_like(hidden_states)
-
-
-def rocm_aiter_ck_moe_2stages_impl(
-        hidden_states: torch.Tensor,
-        w1: torch.Tensor,
-        w2: torch.Tensor,
-        topk_weights: torch.Tensor,
-        topk_ids: torch.Tensor,
-        fc1_scale: Optional[torch.Tensor] = None,
-        fc2_scale: Optional[torch.Tensor] = None,
-        a1_scale: Optional[torch.Tensor] = None,
-        a2_scale: Optional[torch.Tensor] = None,
-        block_size: Optional[list[int]] = None,
-        expert_mask: Optional[torch.Tensor] = None,
-) -> torch.Tensor:
-    from aiter.fused_moe_bf16_asm import ck_moe_2stages
-    return ck_moe_2stages(a1=hidden_states,
-                          w1=w1,
-                          w2=w2,
-                          topk_weight=topk_weights,
-                          topk_ids=topk_ids,
-                          fc1_scale=fc1_scale,
-                          fc2_scale=fc2_scale,
-                          a1_scale=a1_scale,
-                          a2_scale=a2_scale,
-                          block_size=block_size,
-                          expert_mask=expert_mask)
-
-
-def rocm_aiter_ck_moe_2stages_fake(
-        hidden_states: torch.Tensor,
-        w1: torch.Tensor,
-        w2: torch.Tensor,
-        topk_weights: torch.Tensor,
-        topk_ids: torch.Tensor,
-        fc1_scale: Optional[torch.Tensor] = None,
-        fc2_scale: Optional[torch.Tensor] = None,
-        a1_scale: Optional[torch.Tensor] = None,
-        a2_scale: Optional[torch.Tensor] = None,
-        block_size: Optional[list[int]] = None,
-        expert_mask: Optional[torch.Tensor] = None,
-) -> torch.Tensor:
-    return torch.empty_like(hidden_states)
 
 
@@ -274,6 +140,50 @@ def rocm_aiter_biased_grouped_topk_fake(
     pass
 
 
+def rocm_aiter_fused_moe_impl(
+    hidden_states: torch.Tensor,
+    w1: torch.Tensor,
+    w2: torch.Tensor,
+    topk_weight: torch.Tensor,
+    topk_ids: torch.Tensor,
+    expert_mask: Optional[torch.Tensor] = None,
+    activation_method: int = ActivationMethod.SILU.value,
+    quant_method: int = QuantMethod.NO.value,
+    doweight_stage1: bool = False,
+    w1_scale: Optional[torch.Tensor] = None,
+    w2_scale: Optional[torch.Tensor] = None,
+    a1_scale: Optional[torch.Tensor] = None,
+    a2_scale: Optional[torch.Tensor] = None,
+) -> torch.Tensor:
+    from aiter import ActivationType, QuantType
+    from aiter.fused_moe import fused_moe
+
+    activation = ActivationType(activation_method)
+    quant_type = QuantType(quant_method)
+
+    return fused_moe(hidden_states, w1, w2, topk_weight, topk_ids, expert_mask,
+                     activation, quant_type, doweight_stage1, w1_scale,
+                     w2_scale, a1_scale, a2_scale)
+
+
+def rocm_aiter_fused_moe_fake(
+    hidden_states: torch.Tensor,
+    w1: torch.Tensor,
+    w2: torch.Tensor,
+    topk_weight: torch.Tensor,
+    topk_ids: torch.Tensor,
+    expert_mask: Optional[torch.Tensor] = None,
+    activation_method: int = ActivationMethod.SILU.value,
+    quant_method: int = QuantMethod.NO.value,
+    doweight_stage1: bool = False,
+    w1_scale: Optional[torch.Tensor] = None,
+    w2_scale: Optional[torch.Tensor] = None,
+    a1_scale: Optional[torch.Tensor] = None,
+    a2_scale: Optional[torch.Tensor] = None,
+) -> torch.Tensor:
+    return torch.empty_like(hidden_states)
+
+
 if current_platform.is_rocm():
 
     direct_register_custom_op(
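The impl/fake pair above follows vLLM's custom-op registration pattern: the fake implementation only has to return a correctly shaped tensor so graph tracing and torch.compile work without touching AITER. A minimal sketch of that pattern with a toy op follows; the op name and function bodies are hypothetical, and only the keyword arguments mirror the registration calls in this diff. Once registered, the op is reachable under the torch.ops.vllm namespace, which is how rocm_aiter_fused_experts invokes rocm_aiter_fused_moe later in this commit.

import torch

from vllm.platforms import current_platform
from vllm.utils import direct_register_custom_op


def my_scale_impl(x: torch.Tensor, factor: float) -> torch.Tensor:
    # Real computation, only reachable on the registered dispatch key.
    return x * factor


def my_scale_fake(x: torch.Tensor, factor: float) -> torch.Tensor:
    # Shape/dtype-only stand-in used during tracing.
    return torch.empty_like(x)


if current_platform.is_rocm():
    direct_register_custom_op(
        op_name="my_scale",  # hypothetical name; exposed as torch.ops.vllm.my_scale
        op_func=my_scale_impl,
        mutates_args=[],
        fake_impl=my_scale_fake,
        dispatch_key=current_platform.dispatch_key,
    )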
@@ -285,26 +195,10 @@ if current_platform.is_rocm():
     )
 
     direct_register_custom_op(
-        op_name="rocm_aiter_fmoe_fp8_blockscale_g1u1",
-        op_func=rocm_aiter_fmoe_fp8_blockscale_g1u1_impl,
+        op_name="rocm_aiter_fused_moe",
+        op_func=rocm_aiter_fused_moe_impl,
         mutates_args=[],
-        fake_impl=rocm_aiter_fmoe_fp8_blockscale_g1u1_fake,
-        dispatch_key=current_platform.dispatch_key,
-    )
-
-    direct_register_custom_op(
-        op_name="rocm_aiter_asm_moe",
-        op_func=rocm_aiter_asm_moe_impl,
-        mutates_args=[],
-        fake_impl=rocm_aiter_asm_moe_fake,
-        dispatch_key=current_platform.dispatch_key,
-    )
-
-    direct_register_custom_op(
-        op_name="rocm_aiter_ck_moe_2stages",
-        op_func=rocm_aiter_ck_moe_2stages_impl,
-        mutates_args=[],
-        fake_impl=rocm_aiter_ck_moe_2stages_fake,
+        fake_impl=rocm_aiter_fused_moe_fake,
         dispatch_key=current_platform.dispatch_key,
     )
 
@@ -373,32 +267,14 @@ def rocm_aiter_fused_experts(
         a2_scale: Optional[torch.Tensor] = None,
         block_shape: Optional[list[int]] = None) -> torch.Tensor:
 
-    from vllm.model_executor.layers.quantization.utils.fp8_utils import (
-        per_token_group_quant_fp8)
+    activation_method = (ActivationMethod.SILU
+                         if activation == "silu" else ActivationMethod.GELU)
 
     # All AITER Fused MoE kernels are expecting the following datatypes
     topk_weights = topk_weights.to(torch.float32)
     topk_ids = topk_ids.to(torch.int32)
 
-    # w8a8 block-scaled
-    if block_shape is not None and use_fp8_w8a8:
-        assert not apply_router_weight_on_input, (
-            "apply_router_weight_on_input is not supported for block scaled moe"
-        )
-        assert w1_scale is not None
-        assert w2_scale is not None
-
-        # The default block sizes are 128 in AITER.
-        block_shape = [128, 128] if block_shape is None else block_shape
-
-        a1, a1_scale = per_token_group_quant_fp8(hidden_states, block_shape[1])
-
-        return torch.ops.vllm.rocm_aiter_fmoe_fp8_blockscale_g1u1(
-            topk_ids, topk_weights, hidden_states.dtype, None, a1, w1, w2,
-            w1_scale, w2_scale, a1_scale, block_shape, None)
-
     # w8a8 per-channel quantization
-    elif per_channel_quant and apply_router_weight_on_input and use_fp8_w8a8:
+    if per_channel_quant and apply_router_weight_on_input and use_fp8_w8a8:
         # AITER tkw1 kernel for FP8 models with `apply_router_weight_on_input`
         # This applies topk_weights on the GEMM output of the first FC layer
         # rather than the second FC.
@@ -421,60 +297,44 @@ def rocm_aiter_fused_experts(
             a16=False,
             per_tensor_quant_scale=None,
             expert_mask=None,
-            activation_str=activation)
-
-    # w8a8 per-tensor activation per-tensor weight
-    elif use_fp8_w8a8:
-        assert not apply_router_weight_on_input, (
-            "apply_router_weight_on_input is not supported for fp8_w8a8")
-
-        # - faster static per-tensor-activation static per-tensor-weight
-        # fp8 quantization w8a8
-        if a1_scale is not None and a2_scale is not None:
-            return torch.ops.vllm.rocm_aiter_ck_moe_2stages(
-                hidden_states=hidden_states,
-                w1=w1,
-                w2=w2,
-                topk_weights=topk_weights,
-                topk_ids=topk_ids,
-                fc1_scale=w1_scale,
-                fc2_scale=w2_scale,
-                a1_scale=a1_scale,
-                a2_scale=a2_scale)
-
-        # - fallback static per-tensor-activation static per-tensor-weight
-        # fp8 quantization w8a8
-        # - dynamic per-tensor activation static per-tensor-weight
-        # fp8 quantization w8a8
-        return torch.ops.vllm.rocm_aiter_asm_moe(hidden_states=hidden_states,
-                                                 w1=w1,
-                                                 w2=w2,
-                                                 topk_weights=topk_weights,
-                                                 topk_ids=topk_ids,
-                                                 fc1_scale=w1_scale,
-                                                 fc2_scale=w2_scale,
-                                                 fc1_smooth_scale=None,
-                                                 fc2_smooth_scale=None,
-                                                 a16=False,
-                                                 activation=activation)
-    if apply_router_weight_on_input:
-        assert (topk_weights.dim() == 2
-                ), "`topk_weights` should be in shape (num_tokens, topk)"
-        _, topk = topk_weights.shape
-        assert (
-            topk == 1
-        ), "Only support topk=1 when `apply_router_weight_on_input` is True"
-
-        hidden_states = hidden_states * topk_weights.to(hidden_states.dtype)
-        topk_ids = topk_ids.to(torch.int32)
-        topk_weights = torch.ones_like(topk_weights, dtype=torch.float32)
-
-    return torch.ops.vllm.rocm_aiter_ck_moe_2stages(
-        hidden_states=hidden_states,
-        w1=w1,
-        w2=w2,
-        topk_weights=topk_weights,
-        topk_ids=topk_ids)
+            activation_method=activation_method)
+    else:
+        quant_method = QuantMethod.NO.value
+
+    # w8a8 block-scaled
+    if block_shape is not None and use_fp8_w8a8:
+        assert not apply_router_weight_on_input, (
+            "apply_router_weight_on_input is\
+                not supported for block scaled moe")
+        assert w1_scale is not None
+        assert w2_scale is not None
+        quant_method = QuantMethod.BLOCK_128x128.value
+    elif use_fp8_w8a8:
+        # Currently only per tensor quantization method is enabled.
+        quant_method = QuantMethod.PER_TENSOR.value
+
+    if apply_router_weight_on_input:
+        assert (topk_weights.dim() == 2
+                ), "`topk_weights` should be in shape (num_tokens, topk)"
+        _, topk = topk_weights.shape
+        assert (
+            topk == 1
+        ), "Only support topk=1 when `apply_router_weight_on_input` is True"
+
+    return torch.ops.vllm.rocm_aiter_fused_moe(
+        hidden_states,
+        w1,
+        w2,
+        topk_weights,
+        topk_ids,
+        quant_method=quant_method,
+        activation_method=activation_method,
+        w1_scale=w1_scale,
+        w2_scale=w2_scale,
+        a1_scale=a1_scale,
+        a2_scale=a2_scale,
+        doweight_stage1=apply_router_weight_on_input)
 
 
 def rocm_aiter_topk_softmax(topk_weights: torch.Tensor,
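The net effect of this hunk is that rocm_aiter_fused_experts no longer fans out to three different ops; it just picks a QuantMethod and calls the single fused op. Below is a condensed, CPU-runnable paraphrase of that selection, a sketch of the branching shown above rather than an authoritative restatement (the per-channel tkw1 path is handled separately before this point).

from enum import IntEnum
from typing import Optional


class QuantMethod(IntEnum):
    NO = 0
    PER_TENSOR = 1
    PER_TOKEN = 2
    BLOCK_1X128 = 3
    BLOCK_128x128 = 4


def select_quant_method(use_fp8_w8a8: bool,
                        block_shape: Optional[list[int]]) -> QuantMethod:
    # Mirrors the if/elif chain above: block-scaled fp8 wins over plain
    # per-tensor fp8, and everything else stays unquantized (a16w16).
    if use_fp8_w8a8 and block_shape is not None:
        return QuantMethod.BLOCK_128x128
    if use_fp8_w8a8:
        return QuantMethod.PER_TENSOR
    return QuantMethod.NO


assert select_quant_method(True, [128, 128]) is QuantMethod.BLOCK_128x128
assert select_quant_method(True, None) is QuantMethod.PER_TENSOR
assert select_quant_method(False, None) is QuantMethod.NO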
@@ -488,14 +348,21 @@ def rocm_aiter_topk_softmax(topk_weights: torch.Tensor,
     return topk_weights, topk_indices
 
 
-def shuffle_weights(*tensors: torch.Tensor,
-                    layout: tuple[int, int]) -> tuple[torch.Tensor, ...]:
+def shuffle_weights(
+    *tensors: torch.Tensor, layout: tuple[int, int] = (16, 16)
+) -> tuple[torch.Tensor, ...]:
     """
     Applies shuffle_weight function from AITER to each
     input tensor and returns them.
 
+    Rearranges (shuffles) the input tensor/s
+    into a specified block layout for optimized computation.
+
     Args:
         *tensors: Variable number of torch.Tensor objects.
+        layout: A pair of integers specifying the
+            block sizes used to divide the tensors during shuffling.
+            Default is (16, 16).
 
     Returns:
         A Tuple of shuffled tensors.
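With the new (16, 16) default, call sites such as the quantization methods later in this diff can drop the explicit layout argument. A usage sketch follows; the tensor shapes are illustrative only, and this needs a ROCm host with AITER installed, which is what is_rocm_aiter_moe_enabled gates.

import torch

from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (
    is_rocm_aiter_moe_enabled, shuffle_weights)

if is_rocm_aiter_moe_enabled():
    w13 = torch.randn(8, 256, 128, dtype=torch.bfloat16, device="cuda")
    w2 = torch.randn(8, 128, 128, dtype=torch.bfloat16, device="cuda")
    # Equivalent to shuffle_weights(w13, w2, layout=(16, 16)) after this change.
    shuffled_w13, shuffled_w2 = shuffle_weights(w13, w2)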
@@ -503,25 +370,3 @@ def shuffle_weights(
     from aiter.ops.shuffle import shuffle_weight
 
     return tuple(shuffle_weight(tensor, layout=layout) for tensor in tensors)
-
-
-def expand_weights(*tensors: torch.Tensor,
-                   expansion_dims: list[int]) -> tuple[torch.Tensor, ...]:
-    """
-    Expands the dimensions of input tensors.
-
-    Args:
-        *tensors: A variable number of torch.Tensor objects.
-        expansion_dims: A list of expansion dimensions
-            corresponding to each tensor.
-
-    Returns:
-        A Tuple of tensors with expanded dimensions.
-    """
-
-    assert len(tensors) == len(expansion_dims), \
-        "Number of tensors must match the number of expansion dimensions."
-
-    return tuple(
-        tensor.unsqueeze(-1).unsqueeze(-1).expand((-1, dim, -1))
-        for tensor, dim in zip(tensors, expansion_dims))
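expand_weights is deleted outright, presumably because the upgraded AITER fused_moe entry point takes the per-tensor scales as-is (the Fp8MoEMethod hunks below stop expanding the w13/w2 weight scales). For reference, here is a stand-alone, runnable illustration of what the removed helper did.

import torch

# The removed helper broadcast one scale per expert out to a per-row shape:
# unsqueeze twice, then expand along the requested dimension.
scales = torch.tensor([0.5, 0.25])  # one scale for each of 2 experts
expanded = scales.unsqueeze(-1).unsqueeze(-1).expand((-1, 4, -1))
assert expanded.shape == (2, 4, 1)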
@@ -286,9 +286,8 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
             rocm_aiter_fused_experts, shuffle_weights)
 
         # reshaping weights is required for aiter moe kernel.
-        shuffled_w13, shuffled_w2 = shuffle_weights(layer.w13_weight.data,
-                                                    layer.w2_weight.data,
-                                                    layout=(16, 16))
+        shuffled_w13, shuffled_w2 = shuffle_weights(
+            layer.w13_weight.data, layer.w2_weight.data)
 
         layer.w13_weight = torch.nn.Parameter(shuffled_w13,
                                               requires_grad=False)
@@ -595,7 +595,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
     def process_weights_after_loading(self, layer: Module) -> None:
         # Lazy import to avoid importing triton too early.
         from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (
-            expand_weights, is_rocm_aiter_moe_enabled, shuffle_weights)
+            is_rocm_aiter_moe_enabled, shuffle_weights)
 
         self.rocm_aiter_moe_enabled = is_rocm_aiter_moe_enabled()
 
@@ -627,9 +627,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
         if self.rocm_aiter_moe_enabled:
             # reshaping weights is required for aiter moe kernel.
             shuffled_w13, shuffled_w2 = shuffle_weights(
-                layer.w13_weight.data,
-                layer.w2_weight.data,
-                layout=(16, 16))
+                layer.w13_weight.data, layer.w2_weight.data)
 
             layer.w13_weight = torch.nn.Parameter(shuffled_w13,
                                                   requires_grad=False)
@@ -675,20 +673,8 @@ class Fp8MoEMethod(FusedMoEMethodBase):
                                                         requires_grad=False)
             if self.rocm_aiter_moe_enabled:
                 # reshaping weights is required for aiter moe kernel.
-                w13_scales, w2_scales = expand_weights(
-                    layer.w13_weight_scale.data,
-                    layer.w2_weight_scale.data,
-                    expansion_dims=[
-                        layer.w13_weight.shape[1], layer.w2_weight.shape[1]
-                    ])
-                layer.w13_weight_scale = torch.nn.Parameter(
-                    w13_scales.contiguous(), requires_grad=False)
-                layer.w2_weight_scale = torch.nn.Parameter(
-                    w2_scales.contiguous(), requires_grad=False)
-
-                shuffled_w13, shuffled_w2 = shuffle_weights(layer.w13_weight,
-                                                            layer.w2_weight,
-                                                            layout=(16, 16))
+                shuffled_w13, shuffled_w2 = shuffle_weights(
+                    layer.w13_weight, layer.w2_weight)
 
                 layer.w13_weight = torch.nn.Parameter(shuffled_w13,
                                                       requires_grad=False)
@@ -760,20 +746,8 @@ class Fp8MoEMethod(FusedMoEMethodBase):
                 start += shard_size
 
             if self.rocm_aiter_moe_enabled:
-                # reshaping weights is required for aiter moe kernel.
-                expansion_dims = [
-                    layer.w13_weight.shape[1], layer.w2_weight.shape[1]
-                ]
-                max_w13_scales, w2_scales = expand_weights(
-                    max_w13_scales,
-                    layer.w2_weight_scale.data,
-                    expansion_dims=expansion_dims)
-                layer.w2_weight_scale = torch.nn.Parameter(
-                    w2_scales.contiguous(), requires_grad=False)
-
-                shuffled_w13, shuffled_w2 = shuffle_weights(layer.w13_weight,
-                                                            layer.w2_weight,
-                                                            layout=(32, 32))
+                shuffled_w13, shuffled_w2 = shuffle_weights(
+                    layer.w13_weight, layer.w2_weight)
 
                 layer.w13_weight = torch.nn.Parameter(shuffled_w13,
                                                       requires_grad=False)
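Taken together, the weight-processing hunks now reduce to the same short sequence in every quantization method. A condensed sketch of that per-layer flow is shown below; it assumes a layer object exposing w13_weight/w2_weight parameters and omits the quantized-scale handling and error checks.

import torch


def process_weights_for_aiter(layer: torch.nn.Module) -> None:
    # Lazy import, matching the pattern used in the hunks above.
    from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (
        shuffle_weights)

    # Re-layout both expert weight tensors for the AITER fused MoE kernels;
    # the (16, 16) layout is now the default.
    shuffled_w13, shuffled_w2 = shuffle_weights(layer.w13_weight.data,
                                                layer.w2_weight.data)
    layer.w13_weight = torch.nn.Parameter(shuffled_w13, requires_grad=False)
    layer.w2_weight = torch.nn.Parameter(shuffled_w2, requires_grad=False)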