mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-05-31 05:37:04 +08:00
[FEAT] [ROCm]: AITER Fused MOE V1 Support (#16752)
Signed-off-by: vllmellm <vllm.ellm@embeddedllm.com> Co-authored-by: tjtanaa <tunjian.tan@embeddedllm.com>
This commit is contained in:
parent
0d6e187e88
commit
eef364723c
@ -11,6 +11,8 @@ from vllm.model_executor.layers.fused_moe.fused_moe import (
|
|||||||
dispatch_fused_experts_func, dispatch_topk_func,
|
dispatch_fused_experts_func, dispatch_topk_func,
|
||||||
torch_vllm_inplace_fused_experts, torch_vllm_outplace_fused_experts,
|
torch_vllm_inplace_fused_experts, torch_vllm_outplace_fused_experts,
|
||||||
vllm_topk_softmax)
|
vllm_topk_softmax)
|
||||||
|
from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (
|
||||||
|
is_rocm_aiter_moe_enabled)
|
||||||
from vllm.model_executor.layers.layernorm import (
|
from vllm.model_executor.layers.layernorm import (
|
||||||
RMSNorm, dispatch_cuda_rmsnorm_func, fused_add_rms_norm, rms_norm,
|
RMSNorm, dispatch_cuda_rmsnorm_func, fused_add_rms_norm, rms_norm,
|
||||||
rocm_aiter_fused_add_rms_norm, rocm_aiter_rms_norm)
|
rocm_aiter_fused_add_rms_norm, rocm_aiter_rms_norm)
|
||||||
@ -100,11 +102,10 @@ def test_enabled_ops_invalid(env: str):
|
|||||||
def test_topk_dispatch(use_rocm_aiter: str, monkeypatch):
|
def test_topk_dispatch(use_rocm_aiter: str, monkeypatch):
|
||||||
monkeypatch.setenv("VLLM_ROCM_USE_AITER", use_rocm_aiter)
|
monkeypatch.setenv("VLLM_ROCM_USE_AITER", use_rocm_aiter)
|
||||||
topk_func = dispatch_topk_func()
|
topk_func = dispatch_topk_func()
|
||||||
|
is_rocm_aiter_moe_enabled.cache_clear()
|
||||||
if current_platform.is_rocm() and int(use_rocm_aiter):
|
if current_platform.is_rocm() and int(use_rocm_aiter):
|
||||||
from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (
|
from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (
|
||||||
rocm_aiter_topk_softmax)
|
rocm_aiter_topk_softmax)
|
||||||
|
|
||||||
assert topk_func == rocm_aiter_topk_softmax
|
assert topk_func == rocm_aiter_topk_softmax
|
||||||
else:
|
else:
|
||||||
assert topk_func == vllm_topk_softmax
|
assert topk_func == vllm_topk_softmax
|
||||||
@ -116,11 +117,11 @@ def test_fused_experts_dispatch(use_rocm_aiter: str, inplace: bool,
|
|||||||
monkeypatch):
|
monkeypatch):
|
||||||
|
|
||||||
monkeypatch.setenv("VLLM_ROCM_USE_AITER", use_rocm_aiter)
|
monkeypatch.setenv("VLLM_ROCM_USE_AITER", use_rocm_aiter)
|
||||||
|
is_rocm_aiter_moe_enabled.cache_clear()
|
||||||
fused_experts_func = dispatch_fused_experts_func(inplace)
|
fused_experts_func = dispatch_fused_experts_func(inplace)
|
||||||
if current_platform.is_rocm() and int(use_rocm_aiter):
|
if current_platform.is_rocm() and int(use_rocm_aiter):
|
||||||
from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (
|
from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (
|
||||||
rocm_aiter_fused_experts)
|
rocm_aiter_fused_experts)
|
||||||
|
|
||||||
assert fused_experts_func == rocm_aiter_fused_experts
|
assert fused_experts_func == rocm_aiter_fused_experts
|
||||||
elif inplace:
|
elif inplace:
|
||||||
assert fused_experts_func == torch_vllm_inplace_fused_experts
|
assert fused_experts_func == torch_vllm_inplace_fused_experts
|
||||||
|
|||||||
@ -1,31 +1,35 @@
|
|||||||
# SPDX-License-Identifier: Apache-2.0
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
from typing import List, Optional
|
from functools import cache
|
||||||
|
from typing import List, Optional, Tuple
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
import vllm.envs as envs
|
from vllm import envs
|
||||||
from vllm.platforms import current_platform
|
from vllm.platforms import current_platform
|
||||||
|
from vllm.utils import direct_register_custom_op
|
||||||
|
|
||||||
|
|
||||||
|
@cache
|
||||||
def is_rocm_aiter_moe_enabled() -> bool:
|
def is_rocm_aiter_moe_enabled() -> bool:
|
||||||
return current_platform.is_rocm() \
|
return current_platform.is_rocm() \
|
||||||
and envs.VLLM_ROCM_USE_AITER_MOE \
|
and envs.VLLM_ROCM_USE_AITER_MOE \
|
||||||
and envs.VLLM_ROCM_USE_AITER
|
and envs.VLLM_ROCM_USE_AITER
|
||||||
|
|
||||||
|
|
||||||
def rocm_aiter_asm_moe_tkw1(hidden_states,
|
def rocm_aiter_asm_moe_tkw1_impl(
|
||||||
w1,
|
hidden_states: torch.Tensor,
|
||||||
w2,
|
w1: torch.Tensor,
|
||||||
topk_weight,
|
w2: torch.Tensor,
|
||||||
topk_ids,
|
topk_weight: torch.Tensor,
|
||||||
fc1_scale=None,
|
topk_ids: torch.Tensor,
|
||||||
fc2_scale=None,
|
fc1_scale: Optional[torch.Tensor] = None,
|
||||||
fc1_smooth_scale=None,
|
fc2_scale: Optional[torch.Tensor] = None,
|
||||||
fc2_smooth_scale=None,
|
fc1_smooth_scale: Optional[torch.Tensor] = None,
|
||||||
a16=False,
|
fc2_smooth_scale: Optional[torch.Tensor] = None,
|
||||||
per_tensor_quant_scale=None,
|
a16: bool = False,
|
||||||
expert_mask=None,
|
per_tensor_quant_scale: Optional[torch.Tensor] = None,
|
||||||
activation_str: str = "silu") -> None:
|
expert_mask: Optional[torch.Tensor] = None,
|
||||||
|
activation_str: str = "silu") -> torch.Tensor:
|
||||||
|
|
||||||
from aiter import ActivationType
|
from aiter import ActivationType
|
||||||
from aiter.fused_moe_bf16_asm import asm_moe_tkw1
|
from aiter.fused_moe_bf16_asm import asm_moe_tkw1
|
||||||
@ -48,34 +52,236 @@ def rocm_aiter_asm_moe_tkw1(hidden_states,
|
|||||||
activation=activation)
|
activation=activation)
|
||||||
|
|
||||||
|
|
||||||
def rocm_aiter_fused_experts(
|
def rocm_aiter_asm_moe_tkw1_fake(
|
||||||
hidden_states: torch.Tensor,
|
hidden_states: torch.Tensor,
|
||||||
w1: torch.Tensor,
|
w1: torch.Tensor,
|
||||||
w2: torch.Tensor,
|
w2: torch.Tensor,
|
||||||
topk_weights: torch.Tensor,
|
topk_weight: torch.Tensor,
|
||||||
topk_ids: torch.Tensor,
|
topk_ids: torch.Tensor,
|
||||||
inplace: bool = False,
|
fc1_scale: Optional[torch.Tensor] = None,
|
||||||
activation: str = "silu",
|
fc2_scale: Optional[torch.Tensor] = None,
|
||||||
apply_router_weight_on_input: bool = False,
|
fc1_smooth_scale: Optional[torch.Tensor] = None,
|
||||||
use_fp8_w8a8: bool = False,
|
fc2_smooth_scale: Optional[torch.Tensor] = None,
|
||||||
use_int8_w8a8: bool = False,
|
a16: bool = False,
|
||||||
use_int8_w8a16: bool = False,
|
per_tensor_quant_scale: Optional[torch.Tensor] = None,
|
||||||
use_int4_w4a16: bool = False,
|
expert_mask: Optional[torch.Tensor] = None,
|
||||||
per_channel_quant: bool = False,
|
activation_str: str = "silu") -> torch.Tensor:
|
||||||
global_num_experts: int = -1,
|
return torch.empty_like(hidden_states)
|
||||||
expert_map: Optional[torch.Tensor] = None,
|
|
||||||
w1_scale: Optional[torch.Tensor] = None,
|
|
||||||
w2_scale: Optional[torch.Tensor] = None,
|
|
||||||
w1_zp: Optional[torch.Tensor] = None,
|
|
||||||
w2_zp: Optional[torch.Tensor] = None,
|
|
||||||
a1_scale: Optional[torch.Tensor] = None,
|
|
||||||
a2_scale: Optional[torch.Tensor] = None,
|
|
||||||
block_shape: Optional[List[int]] = None,
|
|
||||||
allow_deep_gemm: bool = False,
|
|
||||||
) -> torch.Tensor:
|
|
||||||
|
|
||||||
import aiter as rocm_aiter
|
|
||||||
|
def rocm_aiter_ck_moe_impl(hidden_states: torch.Tensor, w1: torch.Tensor,
|
||||||
|
w2: torch.Tensor, topk_weights: torch.Tensor,
|
||||||
|
topk_ids: torch.Tensor) -> torch.Tensor:
|
||||||
|
from aiter import ck_moe
|
||||||
|
return ck_moe(hidden_states=hidden_states,
|
||||||
|
w1=w1,
|
||||||
|
w2=w2,
|
||||||
|
topk_weights=topk_weights,
|
||||||
|
topk_ids=topk_ids)
|
||||||
|
|
||||||
|
|
||||||
|
def rocm_aiter_ck_moe_fake(hidden_states: torch.Tensor, w1: torch.Tensor,
|
||||||
|
w2: torch.Tensor, topk_weights: torch.Tensor,
|
||||||
|
topk_ids: torch.Tensor) -> torch.Tensor:
|
||||||
|
return torch.empty_like(hidden_states)
|
||||||
|
|
||||||
|
|
||||||
|
def rocm_aiter_fmoe_fp8_blockscale_g1u1_impl(
|
||||||
|
topk_ids: torch.Tensor,
|
||||||
|
topk_weights: torch.Tensor,
|
||||||
|
hidden_states_dtype: torch.dtype,
|
||||||
|
expert_mask: torch.Tensor,
|
||||||
|
a1: torch.Tensor,
|
||||||
|
w1: torch.Tensor,
|
||||||
|
w2: torch.Tensor,
|
||||||
|
w1_scale: torch.Tensor,
|
||||||
|
w2_scale: torch.Tensor,
|
||||||
|
a1_scale: torch.Tensor,
|
||||||
|
block_shape: List[int],
|
||||||
|
smooth_scale: Optional[torch.Tensor] = None) -> torch.Tensor:
|
||||||
|
from aiter import fmoe_fp8_blockscale_g1u1
|
||||||
|
from aiter.fused_moe_bf16_asm import moe_sorting_ck
|
||||||
|
|
||||||
|
topk = topk_ids.shape[1]
|
||||||
|
model_dim = w1.shape[-1]
|
||||||
|
local_E = E = w1.shape[0]
|
||||||
|
if expert_mask is not None:
|
||||||
|
E = expert_mask.numel()
|
||||||
|
|
||||||
|
(
|
||||||
|
sorted_token_ids,
|
||||||
|
sorted_weight_buf,
|
||||||
|
sorted_expert_ids,
|
||||||
|
num_valid_ids,
|
||||||
|
out_asm,
|
||||||
|
) = moe_sorting_ck(topk_ids,
|
||||||
|
topk_weights,
|
||||||
|
E,
|
||||||
|
model_dim,
|
||||||
|
hidden_states_dtype,
|
||||||
|
expert_mask=expert_mask)
|
||||||
|
|
||||||
|
fmoe_fp8_blockscale_g1u1(out_asm, a1, w1, w2, sorted_token_ids,
|
||||||
|
sorted_weight_buf, sorted_expert_ids,
|
||||||
|
num_valid_ids, topk, w1_scale.view(local_E, -1),
|
||||||
|
w2_scale.view(local_E, -1),
|
||||||
|
a1_scale.t().contiguous(), *block_shape,
|
||||||
|
smooth_scale)
|
||||||
|
|
||||||
|
return out_asm
|
||||||
|
|
||||||
|
|
||||||
|
def rocm_aiter_fmoe_fp8_blockscale_g1u1_fake(
|
||||||
|
topk_ids: torch.Tensor,
|
||||||
|
topk_weights: torch.Tensor,
|
||||||
|
hidden_states_dtype: torch.dtype,
|
||||||
|
expert_mask: torch.Tensor,
|
||||||
|
a1: torch.Tensor,
|
||||||
|
w1: torch.Tensor,
|
||||||
|
w2: torch.Tensor,
|
||||||
|
w1_scale: torch.Tensor,
|
||||||
|
w2_scale: torch.Tensor,
|
||||||
|
a1_scale: torch.Tensor,
|
||||||
|
block_shape: List[int],
|
||||||
|
smooth_scale: Optional[torch.Tensor] = None) -> torch.Tensor:
|
||||||
|
|
||||||
|
return torch.empty_like(a1, dtype=torch.bf16)
|
||||||
|
|
||||||
|
|
||||||
|
def rocm_aiter_asm_moe_impl(hidden_states: torch.Tensor,
|
||||||
|
w1: torch.Tensor,
|
||||||
|
w2: torch.Tensor,
|
||||||
|
topk_weight: torch.Tensor,
|
||||||
|
topk_ids: torch.Tensor,
|
||||||
|
fc1_scale: Optional[torch.Tensor] = None,
|
||||||
|
fc2_scale: Optional[torch.Tensor] = None,
|
||||||
|
fc1_smooth_scale: Optional[torch.Tensor] = None,
|
||||||
|
fc2_smooth_scale: Optional[torch.Tensor] = None,
|
||||||
|
a16: bool = False,
|
||||||
|
activation: str = "silu") -> torch.Tensor:
|
||||||
import aiter.fused_moe_bf16_asm as rocm_aiter_asm_fmoe
|
import aiter.fused_moe_bf16_asm as rocm_aiter_asm_fmoe
|
||||||
|
from aiter import ActivationType
|
||||||
|
|
||||||
|
assert activation in ["silu", "gelu"], "The given activation:" \
|
||||||
|
f" {activation}" \
|
||||||
|
" is not supported in" \
|
||||||
|
" AITER."
|
||||||
|
if activation == "silu":
|
||||||
|
aiter_activation = ActivationType.Silu
|
||||||
|
else:
|
||||||
|
aiter_activation = ActivationType.Gelu
|
||||||
|
|
||||||
|
return rocm_aiter_asm_fmoe.asm_moe(hidden_states=hidden_states,
|
||||||
|
w1=w1,
|
||||||
|
w2=w2,
|
||||||
|
topk_weight=topk_weight,
|
||||||
|
topk_ids=topk_ids,
|
||||||
|
fc1_scale=fc1_scale,
|
||||||
|
fc2_scale=fc2_scale,
|
||||||
|
fc1_smooth_scale=fc1_smooth_scale,
|
||||||
|
fc2_smooth_scale=fc2_smooth_scale,
|
||||||
|
a16=a16,
|
||||||
|
activation=aiter_activation)
|
||||||
|
|
||||||
|
|
||||||
|
def rocm_aiter_asm_moe_fake(hidden_states: torch.Tensor,
|
||||||
|
w1: torch.Tensor,
|
||||||
|
w2: torch.Tensor,
|
||||||
|
topk_weight: torch.Tensor,
|
||||||
|
topk_ids: torch.Tensor,
|
||||||
|
fc1_scale: Optional[torch.Tensor] = None,
|
||||||
|
fc2_scale: Optional[torch.Tensor] = None,
|
||||||
|
fc1_smooth_scale: Optional[torch.Tensor] = None,
|
||||||
|
fc2_smooth_scale: Optional[torch.Tensor] = None,
|
||||||
|
a16: bool = False,
|
||||||
|
activation: str = "silu") -> torch.Tensor:
|
||||||
|
return torch.empty_like(hidden_states)
|
||||||
|
|
||||||
|
|
||||||
|
def rocm_aiter_topk_softmax_impl(topk_weights: torch.Tensor,
|
||||||
|
topk_indices: torch.Tensor,
|
||||||
|
token_expert_indices: torch.Tensor,
|
||||||
|
gating_output: torch.Tensor,
|
||||||
|
renormalize: bool) -> None:
|
||||||
|
from aiter import topk_softmax
|
||||||
|
topk_softmax(topk_weights, topk_indices, token_expert_indices,
|
||||||
|
gating_output, renormalize)
|
||||||
|
|
||||||
|
|
||||||
|
def rocm_aiter_topk_softmax_fake(topk_weights: torch.Tensor,
|
||||||
|
topk_indices: torch.Tensor,
|
||||||
|
token_expert_indices: torch.Tensor,
|
||||||
|
gating_output: torch.Tensor,
|
||||||
|
renormalize: bool) -> None:
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
if current_platform.is_rocm():
|
||||||
|
|
||||||
|
direct_register_custom_op(
|
||||||
|
op_name="rocm_aiter_asm_moe_tkw1",
|
||||||
|
op_func=rocm_aiter_asm_moe_tkw1_impl,
|
||||||
|
mutates_args=[],
|
||||||
|
fake_impl=rocm_aiter_asm_moe_tkw1_fake,
|
||||||
|
dispatch_key=current_platform.dispatch_key,
|
||||||
|
)
|
||||||
|
|
||||||
|
direct_register_custom_op(
|
||||||
|
op_name="rocm_aiter_ck_moe",
|
||||||
|
op_func=rocm_aiter_ck_moe_impl,
|
||||||
|
mutates_args=[],
|
||||||
|
fake_impl=rocm_aiter_ck_moe_fake,
|
||||||
|
dispatch_key=current_platform.dispatch_key,
|
||||||
|
)
|
||||||
|
|
||||||
|
direct_register_custom_op(
|
||||||
|
op_name="rocm_aiter_fmoe_fp8_blockscale_g1u1",
|
||||||
|
op_func=rocm_aiter_fmoe_fp8_blockscale_g1u1_impl,
|
||||||
|
mutates_args=[],
|
||||||
|
fake_impl=rocm_aiter_fmoe_fp8_blockscale_g1u1_fake,
|
||||||
|
dispatch_key=current_platform.dispatch_key,
|
||||||
|
)
|
||||||
|
|
||||||
|
direct_register_custom_op(
|
||||||
|
op_name="rocm_aiter_asm_moe",
|
||||||
|
op_func=rocm_aiter_asm_moe_impl,
|
||||||
|
mutates_args=[],
|
||||||
|
fake_impl=rocm_aiter_asm_moe_fake,
|
||||||
|
dispatch_key=current_platform.dispatch_key,
|
||||||
|
)
|
||||||
|
|
||||||
|
direct_register_custom_op(
|
||||||
|
op_name="rocm_aiter_topk_softmax",
|
||||||
|
op_func=rocm_aiter_topk_softmax_impl,
|
||||||
|
mutates_args=["topk_weights", "topk_indices", "token_expert_indices"],
|
||||||
|
fake_impl=rocm_aiter_topk_softmax_fake,
|
||||||
|
dispatch_key=current_platform.dispatch_key,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def rocm_aiter_fused_experts(hidden_states: torch.Tensor,
|
||||||
|
w1: torch.Tensor,
|
||||||
|
w2: torch.Tensor,
|
||||||
|
topk_weights: torch.Tensor,
|
||||||
|
topk_ids: torch.Tensor,
|
||||||
|
inplace: bool = False,
|
||||||
|
activation: str = "silu",
|
||||||
|
apply_router_weight_on_input: bool = False,
|
||||||
|
use_fp8_w8a8: bool = False,
|
||||||
|
use_int8_w8a8: bool = False,
|
||||||
|
use_int8_w8a16: bool = False,
|
||||||
|
use_int4_w4a16: bool = False,
|
||||||
|
per_channel_quant: bool = False,
|
||||||
|
global_num_experts: int = -1,
|
||||||
|
expert_map: Optional[torch.Tensor] = None,
|
||||||
|
w1_scale: Optional[torch.Tensor] = None,
|
||||||
|
w2_scale: Optional[torch.Tensor] = None,
|
||||||
|
w1_zp: Optional[torch.Tensor] = None,
|
||||||
|
w2_zp: Optional[torch.Tensor] = None,
|
||||||
|
a1_scale: Optional[torch.Tensor] = None,
|
||||||
|
a2_scale: Optional[torch.Tensor] = None,
|
||||||
|
block_shape: Optional[List[int]] = None,
|
||||||
|
allow_deep_gemm: bool = False) -> torch.Tensor:
|
||||||
|
|
||||||
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
|
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
|
||||||
per_token_group_quant_fp8)
|
per_token_group_quant_fp8)
|
||||||
@ -84,60 +290,24 @@ def rocm_aiter_fused_experts(
|
|||||||
topk_weights = topk_weights.to(torch.float32)
|
topk_weights = topk_weights.to(torch.float32)
|
||||||
topk_ids = topk_ids.to(torch.int32)
|
topk_ids = topk_ids.to(torch.int32)
|
||||||
|
|
||||||
if (block_shape is not None) and use_fp8_w8a8:
|
# w8a8 block-scaled
|
||||||
|
if block_shape is not None and use_fp8_w8a8:
|
||||||
assert not apply_router_weight_on_input, (
|
assert not apply_router_weight_on_input, (
|
||||||
"apply_router_weight_on_input is not supported for block scaled moe"
|
"apply_router_weight_on_input is not supported for block scaled moe"
|
||||||
)
|
)
|
||||||
|
|
||||||
assert w1_scale is not None
|
assert w1_scale is not None
|
||||||
assert w2_scale is not None
|
assert w2_scale is not None
|
||||||
|
|
||||||
local_E = E = w1.shape[0]
|
|
||||||
if expert_map is not None:
|
|
||||||
E = expert_map.numel()
|
|
||||||
|
|
||||||
topk = topk_ids.shape[1]
|
|
||||||
model_dim = w1.shape[-1]
|
|
||||||
dtype = hidden_states.dtype
|
|
||||||
# The default block sizes are 128 in AITER.
|
# The default block sizes are 128 in AITER.
|
||||||
if block_shape is None:
|
block_shape = [128, 128] if block_shape is None else block_shape
|
||||||
block_shape = [128, 128]
|
|
||||||
|
|
||||||
scale_blk_k = block_shape[1]
|
a1, a1_scale = per_token_group_quant_fp8(hidden_states, block_shape[1])
|
||||||
|
|
||||||
(
|
return torch.ops.vllm.rocm_aiter_fmoe_fp8_blockscale_g1u1(
|
||||||
sorted_token_ids,
|
topk_ids, topk_weights, hidden_states.dtype, expert_map, a1, w1,
|
||||||
sorted_weight_buf,
|
w2, w1_scale, w2_scale, a1_scale, block_shape, None)
|
||||||
sorted_expert_ids,
|
|
||||||
num_valid_ids,
|
|
||||||
out_asm,
|
|
||||||
) = rocm_aiter_asm_fmoe.moe_sorting_ck(topk_ids,
|
|
||||||
topk_weights,
|
|
||||||
E,
|
|
||||||
model_dim,
|
|
||||||
dtype,
|
|
||||||
expert_mask=expert_map)
|
|
||||||
|
|
||||||
a1, a1_scale = per_token_group_quant_fp8(hidden_states, scale_blk_k)
|
|
||||||
rocm_aiter.fmoe_fp8_blockscale_g1u1(
|
|
||||||
out_asm,
|
|
||||||
a1,
|
|
||||||
w1,
|
|
||||||
w2,
|
|
||||||
sorted_token_ids,
|
|
||||||
sorted_weight_buf,
|
|
||||||
sorted_expert_ids,
|
|
||||||
num_valid_ids,
|
|
||||||
topk,
|
|
||||||
w1_scale.view(local_E, -1),
|
|
||||||
w2_scale.view(local_E, -1),
|
|
||||||
a1_scale.t().contiguous(),
|
|
||||||
block_shape[0],
|
|
||||||
block_shape[1],
|
|
||||||
None,
|
|
||||||
)
|
|
||||||
return out_asm
|
|
||||||
|
|
||||||
|
# w8a8 per-channel quantization
|
||||||
elif per_channel_quant and apply_router_weight_on_input and use_fp8_w8a8:
|
elif per_channel_quant and apply_router_weight_on_input and use_fp8_w8a8:
|
||||||
# AITER tkw1 kernel for FP8 models with `apply_router_weight_on_input`
|
# AITER tkw1 kernel for FP8 models with `apply_router_weight_on_input`
|
||||||
# This applies topk_weights on the GEMM output of the first FC layer
|
# This applies topk_weights on the GEMM output of the first FC layer
|
||||||
@ -148,34 +318,36 @@ def rocm_aiter_fused_experts(
|
|||||||
"Only support topk=1 when"
|
"Only support topk=1 when"
|
||||||
" `apply_router_weight_on_input` is True")
|
" `apply_router_weight_on_input` is True")
|
||||||
|
|
||||||
return rocm_aiter_asm_moe_tkw1(hidden_states,
|
return torch.ops.vllm.rocm_aiter_asm_moe_tkw1(
|
||||||
w1,
|
hidden_states,
|
||||||
w2,
|
w1,
|
||||||
topk_weights,
|
w2,
|
||||||
topk_ids,
|
topk_weights,
|
||||||
fc1_scale=w1_scale,
|
topk_ids,
|
||||||
fc2_scale=w2_scale,
|
fc1_scale=w1_scale,
|
||||||
fc1_smooth_scale=None,
|
fc2_scale=w2_scale,
|
||||||
fc2_smooth_scale=None,
|
fc1_smooth_scale=None,
|
||||||
a16=False,
|
fc2_smooth_scale=None,
|
||||||
per_tensor_quant_scale=None,
|
a16=False,
|
||||||
expert_mask=expert_map,
|
per_tensor_quant_scale=None,
|
||||||
activation_str=activation)
|
expert_mask=expert_map,
|
||||||
|
activation_str=activation)
|
||||||
|
|
||||||
|
# w8a8 per-tensor activation per-tensor weight
|
||||||
elif use_fp8_w8a8:
|
elif use_fp8_w8a8:
|
||||||
assert not apply_router_weight_on_input, (
|
assert not apply_router_weight_on_input, (
|
||||||
"apply_router_weight_on_input is not supported for fp8_w8a8")
|
"apply_router_weight_on_input is not supported for fp8_w8a8")
|
||||||
return rocm_aiter_asm_fmoe.asm_moe(hidden_states=hidden_states,
|
return torch.ops.vllm.rocm_aiter_asm_moe(hidden_states=hidden_states,
|
||||||
w1=w1,
|
w1=w1,
|
||||||
w2=w2,
|
w2=w2,
|
||||||
topk_weight=topk_weights,
|
topk_weight=topk_weights,
|
||||||
topk_ids=topk_ids,
|
topk_ids=topk_ids,
|
||||||
fc1_scale=w1_scale,
|
fc1_scale=w1_scale,
|
||||||
fc2_scale=w2_scale,
|
fc2_scale=w2_scale,
|
||||||
fc1_smooth_scale=None,
|
fc1_smooth_scale=None,
|
||||||
fc2_smooth_scale=None,
|
fc2_smooth_scale=None,
|
||||||
a16=False)
|
a16=False,
|
||||||
|
activation=activation)
|
||||||
if apply_router_weight_on_input:
|
if apply_router_weight_on_input:
|
||||||
assert (topk_weights.dim() == 2
|
assert (topk_weights.dim() == 2
|
||||||
), "`topk_weights` should be in shape (num_tokens, topk)"
|
), "`topk_weights` should be in shape (num_tokens, topk)"
|
||||||
@ -188,26 +360,26 @@ def rocm_aiter_fused_experts(
|
|||||||
topk_ids = topk_ids.to(torch.int32)
|
topk_ids = topk_ids.to(torch.int32)
|
||||||
topk_weights = torch.ones_like(topk_weights, dtype=torch.float32)
|
topk_weights = torch.ones_like(topk_weights, dtype=torch.float32)
|
||||||
|
|
||||||
return rocm_aiter.ck_moe(hidden_states=hidden_states,
|
# w16a16 fallback to rocm_aiter_ck_moe w16a16
|
||||||
w1=w1,
|
return torch.ops.vllm.rocm_aiter_ck_moe(hidden_states=hidden_states,
|
||||||
w2=w2,
|
w1=w1,
|
||||||
topk_weights=topk_weights,
|
w2=w2,
|
||||||
topk_ids=topk_ids)
|
topk_weights=topk_weights,
|
||||||
|
topk_ids=topk_ids)
|
||||||
|
|
||||||
|
|
||||||
def rocm_aiter_topk_softmax(topk_weights: torch.Tensor,
|
def rocm_aiter_topk_softmax(topk_weights: torch.Tensor,
|
||||||
topk_indices: torch.Tensor,
|
topk_indices: torch.Tensor,
|
||||||
token_expert_indices: torch.Tensor,
|
token_expert_indices: torch.Tensor,
|
||||||
gating_output: torch.Tensor,
|
gating_output: torch.Tensor,
|
||||||
renormalize: bool) -> tuple[torch.Tensor, ...]:
|
renormalize: bool) -> Tuple[torch.Tensor, ...]:
|
||||||
import aiter as rocm_aiter
|
torch.ops.vllm.rocm_aiter_topk_softmax(topk_weights, topk_indices,
|
||||||
rocm_aiter.topk_softmax(topk_weights, topk_indices, token_expert_indices,
|
token_expert_indices, gating_output,
|
||||||
gating_output, renormalize)
|
renormalize)
|
||||||
|
|
||||||
return topk_weights, topk_indices
|
return topk_weights, topk_indices
|
||||||
|
|
||||||
|
|
||||||
def shuffle_weights(*tensors: torch.Tensor) -> tuple[torch.Tensor, ...]:
|
def shuffle_weights(*tensors: torch.Tensor) -> Tuple[torch.Tensor, ...]:
|
||||||
"""
|
"""
|
||||||
Applies shuffle_weight function from AITER to each
|
Applies shuffle_weight function from AITER to each
|
||||||
input tensor and returns them.
|
input tensor and returns them.
|
||||||
@ -216,15 +388,14 @@ def shuffle_weights(*tensors: torch.Tensor) -> tuple[torch.Tensor, ...]:
|
|||||||
*tensors: Variable number of torch.Tensor objects.
|
*tensors: Variable number of torch.Tensor objects.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
A tuple of shuffled tensors.
|
A Tuple of shuffled tensors.
|
||||||
"""
|
"""
|
||||||
from aiter.ops.shuffle import shuffle_weight
|
from aiter.ops.shuffle import shuffle_weight
|
||||||
|
|
||||||
return tuple(shuffle_weight(tensor) for tensor in tensors)
|
return tuple(shuffle_weight(tensor) for tensor in tensors)
|
||||||
|
|
||||||
|
|
||||||
def expand_weights(*tensors: torch.Tensor,
|
def expand_weights(*tensors: torch.Tensor,
|
||||||
expansion_dims: list[int]) -> tuple[torch.Tensor, ...]:
|
expansion_dims: list[int]) -> Tuple[torch.Tensor, ...]:
|
||||||
"""
|
"""
|
||||||
Expands the dimensions of input tensors.
|
Expands the dimensions of input tensors.
|
||||||
|
|
||||||
@ -234,7 +405,7 @@ def expand_weights(*tensors: torch.Tensor,
|
|||||||
corresponding to each tensor.
|
corresponding to each tensor.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
A tuple of tensors with expanded dimensions.
|
A Tuple of tensors with expanded dimensions.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
assert len(tensors) == len(expansion_dims), \
|
assert len(tensors) == len(expansion_dims), \
|
||||||
|
|||||||
@ -304,9 +304,9 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
|
|||||||
e_score_correction_bias=e_score_correction_bias)
|
e_score_correction_bias=e_score_correction_bias)
|
||||||
|
|
||||||
return self.fused_experts_func(
|
return self.fused_experts_func(
|
||||||
x,
|
hidden_states=x,
|
||||||
layer.w13_weight,
|
w1=layer.w13_weight,
|
||||||
layer.w2_weight,
|
w2=layer.w2_weight,
|
||||||
topk_weights=topk_weights,
|
topk_weights=topk_weights,
|
||||||
topk_ids=topk_ids,
|
topk_ids=topk_ids,
|
||||||
inplace=True,
|
inplace=True,
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user