Mirror of https://git.datalinker.icu/vllm-project/vllm.git (synced 2025-12-09 23:45:54 +08:00)
[FEAT] [ROCm]: Add AITER CK 2 Stages MoE support (#17110)

Signed-off-by: tjtanaa <tunjian.tan@embeddedllm.com>
Co-authored-by: Gregory Shtrasberg <Gregory.Shtrasberg@amd.com>

Parent: 38fe728d60
Commit: 612c2edb4f
@@ -224,9 +224,16 @@ def test_mixtral_moe(dtype: torch.dtype, padding: bool, use_rocm_aiter: bool,
     """Make sure our Mixtral MoE implementation agrees with the one from
     huggingface."""
 
+    # clear the cache before every test
+    from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (
+        is_rocm_aiter_moe_enabled)
+    is_rocm_aiter_moe_enabled.cache_clear()
     if use_rocm_aiter:
         monkeypatch.setenv("VLLM_ROCM_USE_AITER", "1")
 
+    if dtype == torch.float32:
+        pytest.skip("AITER ROCm test skip for float32")
+
     # Instantiate our and huggingface's MoE blocks
     config = MixtralConfig()
     hf_moe = MixtralSparseMoeBlock(config).to(dtype).to("cuda")
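The cache_clear() call is what keeps these parametrized tests independent: is_rocm_aiter_moe_enabled() memoizes its result, so a VLLM_ROCM_USE_AITER value read in one test would otherwise leak into the next. A minimal standalone sketch of the pattern, assuming the helper is a functools-cached env-var check (hypothetical example, not vLLM code):

    import functools
    import os

    @functools.cache  # memoized: the env var is read only on the first call
    def is_fast_path_enabled() -> bool:
        return os.environ.get("VLLM_ROCM_USE_AITER", "0") == "1"

    # A test must reset the memo before flipping the environment:
    is_fast_path_enabled.cache_clear()
    os.environ["VLLM_ROCM_USE_AITER"] = "1"
    assert is_fast_path_enabled()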
@@ -8,10 +8,8 @@ from vllm.model_executor.custom_op import CustomOp
 from vllm.model_executor.layers.activation import (GeluAndMul,
                                                    ReLUSquaredActivation,
                                                    SiluAndMul)
-from vllm.model_executor.layers.fused_moe.fused_moe import (
-    dispatch_fused_experts_func, dispatch_topk_func,
-    torch_vllm_inplace_fused_experts, torch_vllm_outplace_fused_experts,
-    vllm_topk_softmax)
+from vllm.model_executor.layers.fused_moe.fused_moe import (dispatch_topk_func,
+                                                            vllm_topk_softmax)
 from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (
     is_rocm_aiter_moe_enabled)
 from vllm.model_executor.layers.layernorm import (
@@ -142,24 +140,6 @@ def test_topk_dispatch(use_rocm_aiter: str, monkeypatch):
     assert topk_func == vllm_topk_softmax
 
 
-@pytest.mark.parametrize("use_rocm_aiter", ["0", "1"])
-@pytest.mark.parametrize("inplace", [True, False])
-def test_fused_experts_dispatch(use_rocm_aiter: str, inplace: bool,
-                                monkeypatch):
-
-    monkeypatch.setenv("VLLM_ROCM_USE_AITER", use_rocm_aiter)
-    is_rocm_aiter_moe_enabled.cache_clear()
-    fused_experts_func = dispatch_fused_experts_func(inplace)
-    if current_platform.is_rocm() and int(use_rocm_aiter):
-        from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (
-            rocm_aiter_fused_experts)
-        assert fused_experts_func == rocm_aiter_fused_experts
-    elif inplace:
-        assert fused_experts_func == torch_vllm_inplace_fused_experts
-    else:
-        assert fused_experts_func == torch_vllm_outplace_fused_experts
-
-
 @pytest.mark.parametrize("add_residual", [True, False])
 @pytest.mark.parametrize("use_rocm_aiter", ["0", "1"])
 @pytest.mark.parametrize("use_rocm_aiter_norm", ["0", "1"])
@@ -1100,9 +1100,6 @@ def torch_vllm_outplace_fused_experts(**kwargs) -> torch.Tensor:
 
 
 def dispatch_fused_experts_func(inplace: bool) -> Callable[..., torch.Tensor]:
-    if is_rocm_aiter_moe_enabled():
-        from .rocm_aiter_fused_moe import rocm_aiter_fused_experts
-        return rocm_aiter_fused_experts
     if inplace:
         return torch_vllm_inplace_fused_experts
     return torch_vllm_outplace_fused_experts
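After this hunk, the module-level dispatcher knows nothing about AITER; it only chooses between the in-place and out-of-place Triton-backed paths, and AITER routing moves into the MoE method classes (next hunk). A sketch of the resulting call site, using only names from this diff:

    # Sketch: the simplified dispatcher only picks the torch/Triton variant.
    fused_experts_fn = dispatch_fused_experts_func(inplace=True)
    # fused_experts_fn is torch_vllm_inplace_fused_experts here;
    # with inplace=False it would be torch_vllm_outplace_fused_experts.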
@@ -84,6 +84,16 @@ class FusedMoEMethodBase(QuantizeMethodBase):
 class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
     """MoE method without quantization."""
 
+    def __init__(self):
+        super().__init__()
+
+        self.rocm_aiter_moe_enabled = is_rocm_aiter_moe_enabled()
+        if self.rocm_aiter_moe_enabled:
+            from .rocm_aiter_fused_moe import rocm_aiter_fused_experts
+            self.rocm_aiter_fused_experts = rocm_aiter_fused_experts
+        else:
+            self.rocm_aiter_fused_experts = None  # type: ignore
+
     def create_weights(self, layer: torch.nn.Module, num_experts: int,
                        hidden_size: int, intermediate_size_per_partition: int,
                        params_dtype: torch.dtype, **extra_weight_attrs):
@@ -126,11 +136,13 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
         layer.w2_weight.data = self._maybe_pad_weight(layer.w2_weight.data)
         # Lazy import to avoid importing triton.
         from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (
-            is_rocm_aiter_moe_enabled, shuffle_weights)
-        if is_rocm_aiter_moe_enabled():
-            # reshaping weights is required for aiter moe kernel.
-            shuffled_w13, shuffled_w2 = shuffle_weights(
-                layer.w13_weight.data, layer.w2_weight.data)
+            shuffle_weights)
+
+        if self.rocm_aiter_moe_enabled:
+            # use 2stage ck moe layout
+            shuffled_w13, shuffled_w2 = shuffle_weights(layer.w13_weight.data,
+                                                        layer.w2_weight.data,
+                                                        layout=(32, 32))
 
             layer.w13_weight.data = shuffled_w13
             layer.w2_weight.data = shuffled_w2
@@ -211,6 +223,16 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
             scoring_func=scoring_func,
             e_score_correction_bias=e_score_correction_bias)
 
+        if self.rocm_aiter_moe_enabled:
+            return self.rocm_aiter_fused_experts(
+                hidden_states=x,
+                w1=layer.w13_weight,
+                w2=layer.w2_weight,
+                topk_weights=topk_weights,
+                topk_ids=topk_ids,
+                activation=activation,
+                apply_router_weight_on_input=apply_router_weight_on_input)
+
         return fused_experts(
             hidden_states=x,
             w1=layer.w13_weight,
@@ -20,7 +20,7 @@ def rocm_aiter_asm_moe_tkw1_impl(
     hidden_states: torch.Tensor,
     w1: torch.Tensor,
     w2: torch.Tensor,
-    topk_weight: torch.Tensor,
+    topk_weights: torch.Tensor,
     topk_ids: torch.Tensor,
     fc1_scale: Optional[torch.Tensor] = None,
     fc2_scale: Optional[torch.Tensor] = None,
@@ -40,7 +40,7 @@ def rocm_aiter_asm_moe_tkw1_impl(
     return asm_moe_tkw1(hidden_states,
                         w1,
                         w2,
-                        topk_weight,
+                        topk_weights,
                         topk_ids,
                         fc1_scale=fc1_scale,
                         fc2_scale=fc2_scale,
@@ -56,7 +56,7 @@ def rocm_aiter_asm_moe_tkw1_fake(
     hidden_states: torch.Tensor,
     w1: torch.Tensor,
     w2: torch.Tensor,
-    topk_weight: torch.Tensor,
+    topk_weights: torch.Tensor,
     topk_ids: torch.Tensor,
     fc1_scale: Optional[torch.Tensor] = None,
     fc2_scale: Optional[torch.Tensor] = None,
@@ -69,23 +69,6 @@ def rocm_aiter_asm_moe_tkw1_fake(
     return torch.empty_like(hidden_states)
 
 
-def rocm_aiter_ck_moe_impl(hidden_states: torch.Tensor, w1: torch.Tensor,
-                           w2: torch.Tensor, topk_weights: torch.Tensor,
-                           topk_ids: torch.Tensor) -> torch.Tensor:
-    from aiter import ck_moe
-    return ck_moe(hidden_states=hidden_states,
-                  w1=w1,
-                  w2=w2,
-                  topk_weights=topk_weights,
-                  topk_ids=topk_ids)
-
-
-def rocm_aiter_ck_moe_fake(hidden_states: torch.Tensor, w1: torch.Tensor,
-                           w2: torch.Tensor, topk_weights: torch.Tensor,
-                           topk_ids: torch.Tensor) -> torch.Tensor:
-    return torch.empty_like(hidden_states)
-
-
 def rocm_aiter_fmoe_fp8_blockscale_g1u1_impl(
     topk_ids: torch.Tensor,
     topk_weights: torch.Tensor,
@@ -152,7 +135,7 @@ def rocm_aiter_fmoe_fp8_blockscale_g1u1_fake(
 def rocm_aiter_asm_moe_impl(hidden_states: torch.Tensor,
                             w1: torch.Tensor,
                             w2: torch.Tensor,
-                            topk_weight: torch.Tensor,
+                            topk_weights: torch.Tensor,
                             topk_ids: torch.Tensor,
                             fc1_scale: Optional[torch.Tensor] = None,
                             fc2_scale: Optional[torch.Tensor] = None,
@@ -175,7 +158,7 @@ def rocm_aiter_asm_moe_impl(hidden_states: torch.Tensor,
     return rocm_aiter_asm_fmoe.asm_moe(hidden_states=hidden_states,
                                        w1=w1,
                                        w2=w2,
-                                       topk_weight=topk_weight,
+                                       topk_weight=topk_weights,
                                        topk_ids=topk_ids,
                                        fc1_scale=fc1_scale,
                                        fc2_scale=fc2_scale,
@@ -188,7 +171,7 @@ def rocm_aiter_asm_moe_impl(hidden_states: torch.Tensor,
 def rocm_aiter_asm_moe_fake(hidden_states: torch.Tensor,
                             w1: torch.Tensor,
                             w2: torch.Tensor,
-                            topk_weight: torch.Tensor,
+                            topk_weights: torch.Tensor,
                             topk_ids: torch.Tensor,
                             fc1_scale: Optional[torch.Tensor] = None,
                             fc2_scale: Optional[torch.Tensor] = None,
@@ -199,6 +182,49 @@ def rocm_aiter_asm_moe_fake(hidden_states: torch.Tensor,
     return torch.empty_like(hidden_states)
 
 
+def rocm_aiter_ck_moe_2stages_impl(
+    hidden_states: torch.Tensor,
+    w1: torch.Tensor,
+    w2: torch.Tensor,
+    topk_weights: torch.Tensor,
+    topk_ids: torch.Tensor,
+    fc1_scale: Optional[torch.Tensor] = None,
+    fc2_scale: Optional[torch.Tensor] = None,
+    a1_scale: Optional[torch.Tensor] = None,
+    a2_scale: Optional[torch.Tensor] = None,
+    block_size: Optional[list[int]] = None,
+    expert_mask: Optional[torch.Tensor] = None,
+) -> torch.Tensor:
+    from aiter.fused_moe_bf16_asm import ck_moe_2stages
+    return ck_moe_2stages(a1=hidden_states,
+                          w1=w1,
+                          w2=w2,
+                          topk_weight=topk_weights,
+                          topk_ids=topk_ids,
+                          fc1_scale=fc1_scale,
+                          fc2_scale=fc2_scale,
+                          a1_scale=a1_scale,
+                          a2_scale=a2_scale,
+                          block_size=block_size,
+                          expert_mask=expert_mask)
+
+
+def rocm_aiter_ck_moe_2stages_fake(
+    hidden_states: torch.Tensor,
+    w1: torch.Tensor,
+    w2: torch.Tensor,
+    topk_weights: torch.Tensor,
+    topk_ids: torch.Tensor,
+    fc1_scale: Optional[torch.Tensor] = None,
+    fc2_scale: Optional[torch.Tensor] = None,
+    a1_scale: Optional[torch.Tensor] = None,
+    a2_scale: Optional[torch.Tensor] = None,
+    block_size: Optional[list[int]] = None,
+    expert_mask: Optional[torch.Tensor] = None,
+) -> torch.Tensor:
+    return torch.empty_like(hidden_states)
+
+
 def rocm_aiter_topk_softmax_impl(topk_weights: torch.Tensor,
                                  topk_indices: torch.Tensor,
                                  token_expert_indices: torch.Tensor,
@@ -258,14 +284,6 @@ if current_platform.is_rocm():
         dispatch_key=current_platform.dispatch_key,
     )
 
-    direct_register_custom_op(
-        op_name="rocm_aiter_ck_moe",
-        op_func=rocm_aiter_ck_moe_impl,
-        mutates_args=[],
-        fake_impl=rocm_aiter_ck_moe_fake,
-        dispatch_key=current_platform.dispatch_key,
-    )
-
     direct_register_custom_op(
         op_name="rocm_aiter_fmoe_fp8_blockscale_g1u1",
         op_func=rocm_aiter_fmoe_fp8_blockscale_g1u1_impl,
@@ -282,6 +300,14 @@ if current_platform.is_rocm():
         dispatch_key=current_platform.dispatch_key,
     )
 
+    direct_register_custom_op(
+        op_name="rocm_aiter_ck_moe_2stages",
+        op_func=rocm_aiter_ck_moe_2stages_impl,
+        mutates_args=[],
+        fake_impl=rocm_aiter_ck_moe_2stages_fake,
+        dispatch_key=current_platform.dispatch_key,
+    )
+
     direct_register_custom_op(
         op_name="rocm_aiter_topk_softmax",
         op_func=rocm_aiter_topk_softmax_impl,
@@ -331,29 +357,21 @@ def rocm_aiter_biased_group_topk(
     return topk_weights, topk_ids
 
 
-def rocm_aiter_fused_experts(hidden_states: torch.Tensor,
-                             w1: torch.Tensor,
-                             w2: torch.Tensor,
-                             topk_weights: torch.Tensor,
-                             topk_ids: torch.Tensor,
-                             inplace: bool = False,
-                             activation: str = "silu",
-                             apply_router_weight_on_input: bool = False,
-                             use_fp8_w8a8: bool = False,
-                             use_int8_w8a8: bool = False,
-                             use_int8_w8a16: bool = False,
-                             use_int4_w4a16: bool = False,
-                             per_channel_quant: bool = False,
-                             global_num_experts: int = -1,
-                             expert_map: Optional[torch.Tensor] = None,
-                             w1_scale: Optional[torch.Tensor] = None,
-                             w2_scale: Optional[torch.Tensor] = None,
-                             w1_zp: Optional[torch.Tensor] = None,
-                             w2_zp: Optional[torch.Tensor] = None,
-                             a1_scale: Optional[torch.Tensor] = None,
-                             a2_scale: Optional[torch.Tensor] = None,
-                             block_shape: Optional[list[int]] = None,
-                             allow_deep_gemm: bool = False) -> torch.Tensor:
+def rocm_aiter_fused_experts(
+    hidden_states: torch.Tensor,
+    w1: torch.Tensor,
+    w2: torch.Tensor,
+    topk_weights: torch.Tensor,
+    topk_ids: torch.Tensor,
+    activation: str = "silu",
+    apply_router_weight_on_input: bool = False,
+    use_fp8_w8a8: bool = False,
+    per_channel_quant: bool = False,
+    w1_scale: Optional[torch.Tensor] = None,
+    w2_scale: Optional[torch.Tensor] = None,
+    a1_scale: Optional[torch.Tensor] = None,
+    a2_scale: Optional[torch.Tensor] = None,
+    block_shape: Optional[list[int]] = None) -> torch.Tensor:
 
     from vllm.model_executor.layers.quantization.utils.fp8_utils import (
         per_token_group_quant_fp8)
@@ -376,8 +394,8 @@ def rocm_aiter_fused_experts(
         a1, a1_scale = per_token_group_quant_fp8(hidden_states, block_shape[1])
 
         return torch.ops.vllm.rocm_aiter_fmoe_fp8_blockscale_g1u1(
-            topk_ids, topk_weights, hidden_states.dtype, expert_map, a1, w1,
-            w2, w1_scale, w2_scale, a1_scale, block_shape, None)
+            topk_ids, topk_weights, hidden_states.dtype, None, a1, w1, w2,
+            w1_scale, w2_scale, a1_scale, block_shape, None)
 
     # w8a8 per-channel quantization
     elif per_channel_quant and apply_router_weight_on_input and use_fp8_w8a8:
@@ -402,17 +420,36 @@ def rocm_aiter_fused_experts(
             fc2_smooth_scale=None,
             a16=False,
             per_tensor_quant_scale=None,
-            expert_mask=expert_map,
+            expert_mask=None,
             activation_str=activation)
 
     # w8a8 per-tensor activation per-tensor weight
     elif use_fp8_w8a8:
         assert not apply_router_weight_on_input, (
             "apply_router_weight_on_input is not supported for fp8_w8a8")
+
+        # - faster static per-tensor-activation static per-tensor-weight
+        # fp8 quantization w8a8
+        if a1_scale is not None and a2_scale is not None:
+            return torch.ops.vllm.rocm_aiter_ck_moe_2stages(
+                hidden_states=hidden_states,
+                w1=w1,
+                w2=w2,
+                topk_weights=topk_weights,
+                topk_ids=topk_ids,
+                fc1_scale=w1_scale,
+                fc2_scale=w2_scale,
+                a1_scale=a1_scale,
+                a2_scale=a2_scale)
+
+        # - fallback static per-tensor-activation static per-tensor-weight
+        # fp8 quantization w8a8
+        # - dynamic per-tensor activation static per-tensor-weight
+        # fp8 quantization w8a8
         return torch.ops.vllm.rocm_aiter_asm_moe(hidden_states=hidden_states,
                                                  w1=w1,
                                                  w2=w2,
-                                                 topk_weight=topk_weights,
+                                                 topk_weights=topk_weights,
                                                  topk_ids=topk_ids,
                                                  fc1_scale=w1_scale,
                                                  fc2_scale=w2_scale,
@@ -432,12 +469,12 @@ def rocm_aiter_fused_experts(
         topk_ids = topk_ids.to(torch.int32)
         topk_weights = torch.ones_like(topk_weights, dtype=torch.float32)
 
-    # w16a16 fallback to rocm_aiter_ck_moe w16a16
-    return torch.ops.vllm.rocm_aiter_ck_moe(hidden_states=hidden_states,
-                                            w1=w1,
-                                            w2=w2,
-                                            topk_weights=topk_weights,
-                                            topk_ids=topk_ids)
+    return torch.ops.vllm.rocm_aiter_ck_moe_2stages(
+        hidden_states=hidden_states,
+        w1=w1,
+        w2=w2,
+        topk_weights=topk_weights,
+        topk_ids=topk_ids)
 
 
 def rocm_aiter_topk_softmax(topk_weights: torch.Tensor,
@@ -451,7 +488,8 @@ def rocm_aiter_topk_softmax(topk_weights: torch.Tensor,
     return topk_weights, topk_indices
 
 
-def shuffle_weights(*tensors: torch.Tensor) -> tuple[torch.Tensor, ...]:
+def shuffle_weights(*tensors: torch.Tensor,
+                    layout: tuple[int, int]) -> tuple[torch.Tensor, ...]:
     """
     Applies shuffle_weight function from AITER to each
     input tensor and returns them.
@@ -463,7 +501,8 @@ def shuffle_weights(*tensors: torch.Tensor,
         A Tuple of shuffled tensors.
     """
     from aiter.ops.shuffle import shuffle_weight
-    return tuple(shuffle_weight(tensor) for tensor in tensors)
+
+    return tuple(shuffle_weight(tensor, layout=layout) for tensor in tensors)
 
 
 def expand_weights(*tensors: torch.Tensor,
@@ -485,4 +524,4 @@ def expand_weights(*tensors: torch.Tensor,
 
     return tuple(
         tensor.unsqueeze(-1).unsqueeze(-1).expand((-1, dim, -1))
         for tensor, dim in zip(tensors, expansion_dims))
@@ -125,6 +125,10 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
         # Disable marlin for rocm
         if current_platform.is_rocm():
             self.use_marlin = False
+        from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (
+            is_rocm_aiter_moe_enabled)
+
+        self.rocm_aiter_moe_enabled = is_rocm_aiter_moe_enabled()
 
     def create_weights(self, layer: torch.nn.Module, num_experts: int,
                        hidden_size: int, intermediate_size_per_partition: int,
@@ -276,24 +280,22 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
         layer.w13_weight_scale = torch.nn.Parameter(max_w13_scales,
                                                     requires_grad=False)
 
-        from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (
-            is_rocm_aiter_moe_enabled)
-
         # Property to determine if AITER is used
-        if is_rocm_aiter_moe_enabled():
+        if self.rocm_aiter_moe_enabled:
             from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (  # noqa E501
                 rocm_aiter_fused_experts, shuffle_weights)
 
             # reshaping weights is required for aiter moe kernel.
-            shuffled_w13, shuffled_w2 = shuffle_weights(
-                layer.w13_weight.data, layer.w2_weight.data)
+            shuffled_w13, shuffled_w2 = shuffle_weights(layer.w13_weight.data,
+                                                        layer.w2_weight.data,
+                                                        layout=(16, 16))
 
             layer.w13_weight = torch.nn.Parameter(shuffled_w13,
                                                   requires_grad=False)
             layer.w2_weight = torch.nn.Parameter(shuffled_w2,
                                                  requires_grad=False)
 
-            self.fused_experts_func = rocm_aiter_fused_experts
+            self.rocm_aiter_fused_experts_func = rocm_aiter_fused_experts
         else:
             from vllm.model_executor.layers.fused_moe import fused_experts
             self.fused_experts_func = fused_experts
@@ -335,6 +337,22 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
             scoring_func=scoring_func,
             e_score_correction_bias=e_score_correction_bias)
 
+        if self.rocm_aiter_moe_enabled:
+            return self.rocm_aiter_fused_experts_func(
+                hidden_states=x,
+                w1=layer.w13_weight,
+                w2=layer.w2_weight,
+                topk_weights=topk_weights,
+                topk_ids=topk_ids,
+                activation=activation,
+                apply_router_weight_on_input=apply_router_weight_on_input,
+                use_fp8_w8a8=True,
+                per_channel_quant=self.weight_quant.strategy ==
+                QuantizationStrategy.CHANNEL,
+                w1_scale=layer.w13_weight_scale,
+                w2_scale=layer.w2_weight_scale,
+                a1_scale=layer.w13_input_scale,
+                a2_scale=layer.w2_input_scale)
         if self.use_marlin:
             assert activation == "silu", (
                 f"{activation} not supported for Marlin MoE.")
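The per_channel_quant argument here is derived from the checkpoint's quantization strategy; inside rocm_aiter_fused_experts it is what steers channel-quantized checkpoints toward the tkw1 kernel when the router weight is applied on the input. Sketch of the flag derivation (names from the diff):

    per_channel = (self.weight_quant.strategy == QuantizationStrategy.CHANNEL)
    # per_channel_quant=True together with apply_router_weight_on_input=True
    # selects rocm_aiter_asm_moe_tkw1 inside rocm_aiter_fused_experts.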
@@ -591,6 +591,8 @@ class Fp8MoEMethod(FusedMoEMethodBase):
         from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (
             expand_weights, is_rocm_aiter_moe_enabled, shuffle_weights)
 
+        self.rocm_aiter_moe_enabled = is_rocm_aiter_moe_enabled()
+
         # TODO (rob): refactor block quant into separate class.
         if self.block_quant:
             assert self.quant_config.activation_scheme == "dynamic"
@@ -616,10 +618,12 @@ class Fp8MoEMethod(FusedMoEMethodBase):
             layer.w2_weight = Parameter(w2_weight, requires_grad=False)
             layer.w2_weight_scale_inv = Parameter(w2_weight_scale_inv,
                                                   requires_grad=False)
-            if is_rocm_aiter_moe_enabled():
+            if self.rocm_aiter_moe_enabled:
                 # reshaping weights is required for aiter moe kernel.
                 shuffled_w13, shuffled_w2 = shuffle_weights(
-                    layer.w13_weight.data, layer.w2_weight.data)
+                    layer.w13_weight.data,
+                    layer.w2_weight.data,
+                    layout=(16, 16))
 
                 layer.w13_weight = torch.nn.Parameter(shuffled_w13,
                                                       requires_grad=False)
@@ -663,7 +667,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
                                                  requires_grad=False)
             layer.w2_weight = torch.nn.Parameter(w2_weight,
                                                  requires_grad=False)
-            if is_rocm_aiter_moe_enabled():
+            if self.rocm_aiter_moe_enabled:
                 # reshaping weights is required for aiter moe kernel.
                 w13_scales, w2_scales = expand_weights(
                     layer.w13_weight_scale.data,
@@ -676,8 +680,9 @@ class Fp8MoEMethod(FusedMoEMethodBase):
                 layer.w2_weight_scale = torch.nn.Parameter(
                     w2_scales.contiguous(), requires_grad=False)
 
-                shuffled_w13, shuffled_w2 = shuffle_weights(
-                    layer.w13_weight, layer.w2_weight)
+                shuffled_w13, shuffled_w2 = shuffle_weights(layer.w13_weight,
+                                                            layer.w2_weight,
+                                                            layout=(16, 16))
 
                 layer.w13_weight = torch.nn.Parameter(shuffled_w13,
                                                       requires_grad=False)
@@ -748,7 +753,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
                         dq_weight, max_w13_scales[expert_id])
                     start += shard_size
 
-            if is_rocm_aiter_moe_enabled():
+            if self.rocm_aiter_moe_enabled:
                 # reshaping weights is required for aiter moe kernel.
                 expansion_dims = [
                     layer.w13_weight.shape[1], layer.w2_weight.shape[1]
@@ -760,8 +765,9 @@ class Fp8MoEMethod(FusedMoEMethodBase):
                 layer.w2_weight_scale = torch.nn.Parameter(
                     w2_scales.contiguous(), requires_grad=False)
 
-                shuffled_w13, shuffled_w2 = shuffle_weights(
-                    layer.w13_weight, layer.w2_weight)
+                shuffled_w13, shuffled_w2 = shuffle_weights(layer.w13_weight,
+                                                            layer.w2_weight,
+                                                            layout=(32, 32))
 
                 layer.w13_weight = torch.nn.Parameter(shuffled_w13,
                                                       requires_grad=False)
@@ -796,6 +802,8 @@ class Fp8MoEMethod(FusedMoEMethodBase):
         activation: str = "silu",
     ) -> torch.Tensor:
         from vllm.model_executor.layers.fused_moe import fused_experts
+        from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (
+            rocm_aiter_fused_experts)
 
         topk_weights, topk_ids = FusedMoE.select_experts(
             hidden_states=x,
@@ -810,6 +818,24 @@ class Fp8MoEMethod(FusedMoEMethodBase):
             e_score_correction_bias=e_score_correction_bias,
         )
 
+        if self.rocm_aiter_moe_enabled:
+            return rocm_aiter_fused_experts(
+                x,
+                layer.w13_weight,
+                layer.w2_weight,
+                topk_weights=topk_weights,
+                topk_ids=topk_ids,
+                activation=activation,
+                use_fp8_w8a8=True,
+                apply_router_weight_on_input=apply_router_weight_on_input,
+                w1_scale=(layer.w13_weight_scale_inv
+                          if self.block_quant else layer.w13_weight_scale),
+                w2_scale=(layer.w2_weight_scale_inv
+                          if self.block_quant else layer.w2_weight_scale),
+                a1_scale=layer.w13_input_scale,
+                a2_scale=layer.w2_input_scale,
+                block_shape=self.quant_config.weight_block_size)
+
         if self.use_marlin:
             assert activation == "silu", (
                 f"{activation} not supported for Marlin MoE.")
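Two details worth noting in this call: block-quantized checkpoints keep their scales under the *_scale_inv parameter names, so the scale arguments are chosen per self.block_quant, and passing block_shape=self.quant_config.weight_block_size is what routes block-quantized models to the blockscale g1u1 kernel shown earlier. Condensed from the hunk above:

    w1_scale = (layer.w13_weight_scale_inv
                if self.block_quant else layer.w13_weight_scale)
    w2_scale = (layer.w2_weight_scale_inv
                if self.block_quant else layer.w2_weight_scale)
    # weight_block_size is None for non-block-quant checkpoints, so those
    # fall through to the ck_moe_2stages / asm_moe branches instead.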