mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-05-30 22:37:14 +08:00
[ROCm] Add aiter tkw1 kernel for Llama4 fp8 (#16727)
Signed-off-by: kliuae <kuanfu.liu@embeddedllm.com> Signed-off-by: tjtanaa <tunjian.tan@embeddedllm.com> Co-authored-by: tjtanaa <tunjian.tan@embeddedllm.com> Co-authored-by: vllmellm <vllm.ellm@embeddedllm.com>
This commit is contained in:
parent
0e4254492f
commit
5b794cae8d
@ -12,7 +12,7 @@ ARG PYTORCH_REPO="https://github.com/pytorch/pytorch.git"
|
|||||||
ARG PYTORCH_VISION_REPO="https://github.com/pytorch/vision.git"
|
ARG PYTORCH_VISION_REPO="https://github.com/pytorch/vision.git"
|
||||||
ARG FA_BRANCH="1a7f4dfa"
|
ARG FA_BRANCH="1a7f4dfa"
|
||||||
ARG FA_REPO="https://github.com/Dao-AILab/flash-attention.git"
|
ARG FA_REPO="https://github.com/Dao-AILab/flash-attention.git"
|
||||||
ARG AITER_BRANCH="8970b25b"
|
ARG AITER_BRANCH="5a77249"
|
||||||
ARG AITER_REPO="https://github.com/ROCm/aiter.git"
|
ARG AITER_REPO="https://github.com/ROCm/aiter.git"
|
||||||
|
|
||||||
FROM ${BASE_IMAGE} AS base
|
FROM ${BASE_IMAGE} AS base
|
||||||
|
|||||||
@ -77,7 +77,6 @@ if TYPE_CHECKING:
|
|||||||
VLLM_ROCM_USE_AITER: bool = False
|
VLLM_ROCM_USE_AITER: bool = False
|
||||||
VLLM_ROCM_USE_AITER_LINEAR: bool = True
|
VLLM_ROCM_USE_AITER_LINEAR: bool = True
|
||||||
VLLM_ROCM_USE_AITER_MOE: bool = True
|
VLLM_ROCM_USE_AITER_MOE: bool = True
|
||||||
VLLM_ROCM_USE_AITER_FP8_BLOCK_SCALED_MOE: bool = False
|
|
||||||
VLLM_ROCM_USE_AITER_RMSNORM: bool = True
|
VLLM_ROCM_USE_AITER_RMSNORM: bool = True
|
||||||
VLLM_ROCM_FP8_PADDING: bool = True
|
VLLM_ROCM_FP8_PADDING: bool = True
|
||||||
VLLM_ROCM_MOE_PADDING: bool = True
|
VLLM_ROCM_MOE_PADDING: bool = True
|
||||||
@ -546,13 +545,6 @@ environment_variables: dict[str, Callable[[], Any]] = {
|
|||||||
lambda: (os.getenv("VLLM_ROCM_USE_AITER_MOE", "True").lower() in
|
lambda: (os.getenv("VLLM_ROCM_USE_AITER_MOE", "True").lower() in
|
||||||
("true", "1")),
|
("true", "1")),
|
||||||
|
|
||||||
# Whether to use aiter block scaled moe kernel.
|
|
||||||
# By default this is disabled.
|
|
||||||
"VLLM_ROCM_USE_AITER_FP8_BLOCK_SCALED_MOE":
|
|
||||||
lambda:
|
|
||||||
(os.getenv("VLLM_ROCM_USE_AITER_FP8_BLOCK_SCALED_MOE", "false").lower() in
|
|
||||||
("true", "1")),
|
|
||||||
|
|
||||||
# use aiter rms norm op if aiter ops are enabled.
|
# use aiter rms norm op if aiter ops are enabled.
|
||||||
"VLLM_ROCM_USE_AITER_RMSNORM":
|
"VLLM_ROCM_USE_AITER_RMSNORM":
|
||||||
lambda: (os.getenv("VLLM_ROCM_USE_AITER_RMSNORM", "True").lower() in
|
lambda: (os.getenv("VLLM_ROCM_USE_AITER_RMSNORM", "True").lower() in
|
||||||
|
|||||||
@ -23,9 +23,7 @@ from vllm.model_executor.layers.quantization.utils.int8_utils import (
|
|||||||
from vllm.platforms import current_platform
|
from vllm.platforms import current_platform
|
||||||
from vllm.utils import direct_register_custom_op
|
from vllm.utils import direct_register_custom_op
|
||||||
|
|
||||||
from .rocm_aiter_fused_moe import (is_rocm_aiter_moe_enabled,
|
from .rocm_aiter_fused_moe import is_rocm_aiter_moe_enabled
|
||||||
rocm_aiter_fused_experts,
|
|
||||||
rocm_aiter_topk_softmax)
|
|
||||||
|
|
||||||
logger = init_logger(__name__)
|
logger = init_logger(__name__)
|
||||||
|
|
||||||
@ -846,6 +844,7 @@ def vllm_topk_softmax(topk_weights: torch.Tensor, topk_indices: torch.Tensor,
|
|||||||
|
|
||||||
def dispatch_topk_func() -> Callable[..., tuple[torch.Tensor, ...]]:
|
def dispatch_topk_func() -> Callable[..., tuple[torch.Tensor, ...]]:
|
||||||
if is_rocm_aiter_moe_enabled():
|
if is_rocm_aiter_moe_enabled():
|
||||||
|
from .rocm_aiter_fused_moe import rocm_aiter_topk_softmax
|
||||||
return rocm_aiter_topk_softmax
|
return rocm_aiter_topk_softmax
|
||||||
return vllm_topk_softmax
|
return vllm_topk_softmax
|
||||||
|
|
||||||
@ -1102,6 +1101,7 @@ def torch_vllm_outplace_fused_experts(**kwargs) -> torch.Tensor:
|
|||||||
|
|
||||||
def dispatch_fused_experts_func(inplace: bool) -> Callable[..., torch.Tensor]:
|
def dispatch_fused_experts_func(inplace: bool) -> Callable[..., torch.Tensor]:
|
||||||
if is_rocm_aiter_moe_enabled():
|
if is_rocm_aiter_moe_enabled():
|
||||||
|
from .rocm_aiter_fused_moe import rocm_aiter_fused_experts
|
||||||
return rocm_aiter_fused_experts
|
return rocm_aiter_fused_experts
|
||||||
if inplace:
|
if inplace:
|
||||||
return torch_vllm_inplace_fused_experts
|
return torch_vllm_inplace_fused_experts
|
||||||
|
|||||||
@ -10,28 +10,68 @@ from vllm.platforms import current_platform
|
|||||||
def is_rocm_aiter_moe_enabled() -> bool:
|
def is_rocm_aiter_moe_enabled() -> bool:
|
||||||
return current_platform.is_rocm() \
|
return current_platform.is_rocm() \
|
||||||
and envs.VLLM_ROCM_USE_AITER_MOE \
|
and envs.VLLM_ROCM_USE_AITER_MOE \
|
||||||
and envs.VLLM_ROCM_USE_AITER \
|
and envs.VLLM_ROCM_USE_AITER
|
||||||
|
|
||||||
|
|
||||||
def is_rocm_aiter_block_scaled_moe_enabled() -> bool:
|
def rocm_aiter_asm_moe_tkw1(hidden_states,
|
||||||
return is_rocm_aiter_moe_enabled() and \
|
w1,
|
||||||
envs.VLLM_ROCM_USE_AITER_FP8_BLOCK_SCALED_MOE
|
w2,
|
||||||
|
topk_weight,
|
||||||
|
topk_ids,
|
||||||
|
fc1_scale=None,
|
||||||
|
fc2_scale=None,
|
||||||
|
fc1_smooth_scale=None,
|
||||||
|
fc2_smooth_scale=None,
|
||||||
|
a16=False,
|
||||||
|
per_tensor_quant_scale=None,
|
||||||
|
expert_mask=None,
|
||||||
|
activation_str: str = "silu") -> None:
|
||||||
|
|
||||||
|
from aiter import ActivationType
|
||||||
|
from aiter.fused_moe_bf16_asm import asm_moe_tkw1
|
||||||
|
|
||||||
|
activation = \
|
||||||
|
ActivationType.Gelu if activation_str == "gelu" else ActivationType.Silu
|
||||||
|
|
||||||
|
return asm_moe_tkw1(hidden_states,
|
||||||
|
w1,
|
||||||
|
w2,
|
||||||
|
topk_weight,
|
||||||
|
topk_ids,
|
||||||
|
fc1_scale=fc1_scale,
|
||||||
|
fc2_scale=fc2_scale,
|
||||||
|
fc1_smooth_scale=fc1_smooth_scale,
|
||||||
|
fc2_smooth_scale=fc2_smooth_scale,
|
||||||
|
a16=a16,
|
||||||
|
per_tensor_quant_scale=per_tensor_quant_scale,
|
||||||
|
expert_mask=expert_mask,
|
||||||
|
activation=activation)
|
||||||
|
|
||||||
|
|
||||||
def rocm_aiter_fused_experts(
|
def rocm_aiter_fused_experts(
|
||||||
*,
|
hidden_states: torch.Tensor,
|
||||||
hidden_states: torch.Tensor,
|
w1: torch.Tensor,
|
||||||
w1: torch.Tensor,
|
w2: torch.Tensor,
|
||||||
w2: torch.Tensor,
|
topk_weights: torch.Tensor,
|
||||||
topk_weights: torch.Tensor,
|
topk_ids: torch.Tensor,
|
||||||
topk_ids: torch.Tensor,
|
inplace: bool = False,
|
||||||
use_fp8_w8a8: bool = False,
|
activation: str = "silu",
|
||||||
apply_router_weight_on_input: bool = False,
|
apply_router_weight_on_input: bool = False,
|
||||||
w1_scale: Optional[torch.Tensor] = None,
|
use_fp8_w8a8: bool = False,
|
||||||
w2_scale: Optional[torch.Tensor] = None,
|
use_int8_w8a8: bool = False,
|
||||||
block_shape: Optional[List[int]] = None,
|
use_int8_w8a16: bool = False,
|
||||||
expert_mask: Optional[torch.Tensor] = None,
|
use_int4_w4a16: bool = False,
|
||||||
**kwagrs # Ignore additional keyword arguments
|
per_channel_quant: bool = False,
|
||||||
|
global_num_experts: int = -1,
|
||||||
|
expert_map: Optional[torch.Tensor] = None,
|
||||||
|
w1_scale: Optional[torch.Tensor] = None,
|
||||||
|
w2_scale: Optional[torch.Tensor] = None,
|
||||||
|
w1_zp: Optional[torch.Tensor] = None,
|
||||||
|
w2_zp: Optional[torch.Tensor] = None,
|
||||||
|
a1_scale: Optional[torch.Tensor] = None,
|
||||||
|
a2_scale: Optional[torch.Tensor] = None,
|
||||||
|
block_shape: Optional[List[int]] = None,
|
||||||
|
allow_deep_gemm: bool = False,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
|
|
||||||
import aiter as rocm_aiter
|
import aiter as rocm_aiter
|
||||||
@ -40,25 +80,21 @@ def rocm_aiter_fused_experts(
|
|||||||
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
|
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
|
||||||
per_token_group_quant_fp8)
|
per_token_group_quant_fp8)
|
||||||
|
|
||||||
if apply_router_weight_on_input:
|
# All AITER Fused MoE kernels are expecting the following datatypes
|
||||||
assert (topk_weights.dim() == 2
|
topk_weights = topk_weights.to(torch.float32)
|
||||||
), "`topk_weights` should be in shape (num_tokens, topk)"
|
topk_ids = topk_ids.to(torch.int32)
|
||||||
_, topk = topk_weights.shape
|
|
||||||
assert (
|
|
||||||
topk == 1
|
|
||||||
), "Only support topk=1 when `apply_router_weight_on_input` is True"
|
|
||||||
|
|
||||||
hidden_states = hidden_states * topk_weights.to(hidden_states.dtype)
|
if (block_shape is not None) and use_fp8_w8a8:
|
||||||
topk_ids = topk_ids.to(torch.int32)
|
assert not apply_router_weight_on_input, (
|
||||||
topk_weights = torch.ones_like(topk_weights, dtype=torch.float32)
|
"apply_router_weight_on_input is not supported for block scaled moe"
|
||||||
|
)
|
||||||
|
|
||||||
if envs.VLLM_ROCM_USE_AITER_FP8_BLOCK_SCALED_MOE and use_fp8_w8a8:
|
|
||||||
assert w1_scale is not None
|
assert w1_scale is not None
|
||||||
assert w2_scale is not None
|
assert w2_scale is not None
|
||||||
|
|
||||||
local_E = E = w1.shape[0]
|
local_E = E = w1.shape[0]
|
||||||
if expert_mask is not None:
|
if expert_map is not None:
|
||||||
E = expert_mask.numel()
|
E = expert_map.numel()
|
||||||
|
|
||||||
topk = topk_ids.shape[1]
|
topk = topk_ids.shape[1]
|
||||||
model_dim = w1.shape[-1]
|
model_dim = w1.shape[-1]
|
||||||
@ -80,7 +116,7 @@ def rocm_aiter_fused_experts(
|
|||||||
E,
|
E,
|
||||||
model_dim,
|
model_dim,
|
||||||
dtype,
|
dtype,
|
||||||
expert_mask=expert_mask)
|
expert_mask=expert_map)
|
||||||
|
|
||||||
a1, a1_scale = per_token_group_quant_fp8(hidden_states, scale_blk_k)
|
a1, a1_scale = per_token_group_quant_fp8(hidden_states, scale_blk_k)
|
||||||
rocm_aiter.fmoe_fp8_blockscale_g1u1(
|
rocm_aiter.fmoe_fp8_blockscale_g1u1(
|
||||||
@ -102,7 +138,33 @@ def rocm_aiter_fused_experts(
|
|||||||
)
|
)
|
||||||
return out_asm
|
return out_asm
|
||||||
|
|
||||||
|
elif per_channel_quant and apply_router_weight_on_input and use_fp8_w8a8:
|
||||||
|
# AITER tkw1 kernel for FP8 models with `apply_router_weight_on_input`
|
||||||
|
# This applies topk_weights on the GEMM output of the first FC layer
|
||||||
|
# rather than the second FC.
|
||||||
|
assert (topk_weights.dim() == 2
|
||||||
|
), "`topk_weights` should be in shape (num_tokens, topk)"
|
||||||
|
assert topk_weights.shape[-1] == 1, (
|
||||||
|
"Only support topk=1 when"
|
||||||
|
" `apply_router_weight_on_input` is True")
|
||||||
|
|
||||||
|
return rocm_aiter_asm_moe_tkw1(hidden_states,
|
||||||
|
w1,
|
||||||
|
w2,
|
||||||
|
topk_weights,
|
||||||
|
topk_ids,
|
||||||
|
fc1_scale=w1_scale,
|
||||||
|
fc2_scale=w2_scale,
|
||||||
|
fc1_smooth_scale=None,
|
||||||
|
fc2_smooth_scale=None,
|
||||||
|
a16=False,
|
||||||
|
per_tensor_quant_scale=None,
|
||||||
|
expert_mask=expert_map,
|
||||||
|
activation_str=activation)
|
||||||
|
|
||||||
elif use_fp8_w8a8:
|
elif use_fp8_w8a8:
|
||||||
|
assert not apply_router_weight_on_input, (
|
||||||
|
"apply_router_weight_on_input is not supported for fp8_w8a8")
|
||||||
return rocm_aiter_asm_fmoe.asm_moe(hidden_states=hidden_states,
|
return rocm_aiter_asm_fmoe.asm_moe(hidden_states=hidden_states,
|
||||||
w1=w1,
|
w1=w1,
|
||||||
w2=w2,
|
w2=w2,
|
||||||
@ -114,6 +176,18 @@ def rocm_aiter_fused_experts(
|
|||||||
fc2_smooth_scale=None,
|
fc2_smooth_scale=None,
|
||||||
a16=False)
|
a16=False)
|
||||||
|
|
||||||
|
if apply_router_weight_on_input:
|
||||||
|
assert (topk_weights.dim() == 2
|
||||||
|
), "`topk_weights` should be in shape (num_tokens, topk)"
|
||||||
|
_, topk = topk_weights.shape
|
||||||
|
assert (
|
||||||
|
topk == 1
|
||||||
|
), "Only support topk=1 when `apply_router_weight_on_input` is True"
|
||||||
|
|
||||||
|
hidden_states = hidden_states * topk_weights.to(hidden_states.dtype)
|
||||||
|
topk_ids = topk_ids.to(torch.int32)
|
||||||
|
topk_weights = torch.ones_like(topk_weights, dtype=torch.float32)
|
||||||
|
|
||||||
return rocm_aiter.ck_moe(hidden_states=hidden_states,
|
return rocm_aiter.ck_moe(hidden_states=hidden_states,
|
||||||
w1=w1,
|
w1=w1,
|
||||||
w2=w2,
|
w2=w2,
|
||||||
|
|||||||
@ -250,6 +250,28 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
|
|||||||
layer.w13_weight_scale = torch.nn.Parameter(max_w13_scales,
|
layer.w13_weight_scale = torch.nn.Parameter(max_w13_scales,
|
||||||
requires_grad=False)
|
requires_grad=False)
|
||||||
|
|
||||||
|
from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (
|
||||||
|
is_rocm_aiter_moe_enabled)
|
||||||
|
|
||||||
|
# Property to determine if AITER is used
|
||||||
|
if is_rocm_aiter_moe_enabled():
|
||||||
|
from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import ( # noqa E501
|
||||||
|
rocm_aiter_fused_experts, shuffle_weights)
|
||||||
|
|
||||||
|
# reshaping weights is required for aiter moe kernel.
|
||||||
|
shuffled_w13, shuffled_w2 = shuffle_weights(
|
||||||
|
layer.w13_weight.data, layer.w2_weight.data)
|
||||||
|
|
||||||
|
layer.w13_weight = torch.nn.Parameter(shuffled_w13,
|
||||||
|
requires_grad=False)
|
||||||
|
layer.w2_weight = torch.nn.Parameter(shuffled_w2,
|
||||||
|
requires_grad=False)
|
||||||
|
|
||||||
|
self.fused_experts_func = rocm_aiter_fused_experts
|
||||||
|
else:
|
||||||
|
from vllm.model_executor.layers.fused_moe import fused_experts
|
||||||
|
self.fused_experts_func = fused_experts
|
||||||
|
|
||||||
def apply(
|
def apply(
|
||||||
self,
|
self,
|
||||||
layer: torch.nn.Module,
|
layer: torch.nn.Module,
|
||||||
@ -268,7 +290,6 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
|
|||||||
apply_router_weight_on_input: bool = False,
|
apply_router_weight_on_input: bool = False,
|
||||||
activation: str = "silu",
|
activation: str = "silu",
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
from vllm.model_executor.layers.fused_moe import fused_experts
|
|
||||||
|
|
||||||
topk_weights, topk_ids = FusedMoE.select_experts(
|
topk_weights, topk_ids = FusedMoE.select_experts(
|
||||||
hidden_states=x,
|
hidden_states=x,
|
||||||
@ -282,7 +303,7 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
|
|||||||
scoring_func=scoring_func,
|
scoring_func=scoring_func,
|
||||||
e_score_correction_bias=e_score_correction_bias)
|
e_score_correction_bias=e_score_correction_bias)
|
||||||
|
|
||||||
return fused_experts(
|
return self.fused_experts_func(
|
||||||
x,
|
x,
|
||||||
layer.w13_weight,
|
layer.w13_weight,
|
||||||
layer.w2_weight,
|
layer.w2_weight,
|
||||||
|
|||||||
@ -575,8 +575,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
|
|||||||
def process_weights_after_loading(self, layer: Module) -> None:
|
def process_weights_after_loading(self, layer: Module) -> None:
|
||||||
# Lazy import to avoid importing triton too early.
|
# Lazy import to avoid importing triton too early.
|
||||||
from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (
|
from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (
|
||||||
expand_weights, is_rocm_aiter_block_scaled_moe_enabled,
|
expand_weights, is_rocm_aiter_moe_enabled, shuffle_weights)
|
||||||
is_rocm_aiter_moe_enabled, shuffle_weights)
|
|
||||||
|
|
||||||
# TODO (rob): refactor block quant into separate class.
|
# TODO (rob): refactor block quant into separate class.
|
||||||
if self.block_quant:
|
if self.block_quant:
|
||||||
@ -603,7 +602,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
|
|||||||
layer.w2_weight = Parameter(w2_weight, requires_grad=False)
|
layer.w2_weight = Parameter(w2_weight, requires_grad=False)
|
||||||
layer.w2_weight_scale_inv = Parameter(w2_weight_scale_inv,
|
layer.w2_weight_scale_inv = Parameter(w2_weight_scale_inv,
|
||||||
requires_grad=False)
|
requires_grad=False)
|
||||||
if is_rocm_aiter_block_scaled_moe_enabled():
|
if is_rocm_aiter_moe_enabled():
|
||||||
# reshaping weights is required for aiter moe kernel.
|
# reshaping weights is required for aiter moe kernel.
|
||||||
shuffled_w13, shuffled_w2 = shuffle_weights(
|
shuffled_w13, shuffled_w2 = shuffle_weights(
|
||||||
layer.w13_weight.data, layer.w2_weight.data)
|
layer.w13_weight.data, layer.w2_weight.data)
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user