mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-04-08 01:07:03 +08:00
Merge: ROCM:upstream_fp8_with_static_scales_gpt_oss
Signed-off-by: Gregory Shtrasberg <Gregory.Shtrasberg@amd.com>
This commit is contained in:
parent
577128bd80
commit
7b4424b404
@ -655,6 +655,7 @@ class rocm_aiter_ops:
|
||||
envs.VLLM_ROCM_USE_AITER_TRITON_FUSED_ADD_RMSNORM_PAD
|
||||
)
|
||||
_FMOE_ENABLED = envs.VLLM_ROCM_USE_AITER_MOE
|
||||
_CK_MXFP4_MOE = envs.VLLM_ROCM_USE_CK_MXFP4_MOE
|
||||
_MLA_ENABLED = envs.VLLM_ROCM_USE_AITER_MLA
|
||||
_PG_ATTN_ENABLED = envs.VLLM_ROCM_USE_AITER_PAGED_ATTN
|
||||
_MHA_ENABLED = envs.VLLM_ROCM_USE_AITER_MHA
|
||||
@ -701,6 +702,11 @@ class rocm_aiter_ops:
|
||||
""" "Verifies device specs and availability of env variable."""
|
||||
return cls._AITER_ENABLED and cls._FMOE_ENABLED
|
||||
|
||||
@classmethod
|
||||
@if_aiter_supported
|
||||
def is_mxfp4_aiter_moe(cls) -> bool:
|
||||
return cls._AITER_ENABLED and cls._CK_MXFP4_MOE
|
||||
|
||||
@classmethod
|
||||
@if_aiter_supported
|
||||
def is_fusion_moe_shared_experts_enabled(cls) -> bool:
|
||||
|
||||
13
vllm/envs.py
13
vllm/envs.py
@ -114,6 +114,8 @@ if TYPE_CHECKING:
|
||||
VLLM_ROCM_USE_AITER_PAGED_ATTN: bool = False
|
||||
VLLM_ROCM_USE_AITER_LINEAR: bool = True
|
||||
VLLM_ROCM_USE_AITER_MOE: bool = True
|
||||
VLLM_ROCM_USE_CK_MXFP4_MOE: bool = True
|
||||
VLLM_ROCM_USE_AITER_FUSED_MOE_A16W4: bool = True
|
||||
VLLM_ROCM_USE_AITER_RMSNORM: bool = True
|
||||
VLLM_ROCM_USE_AITER_TRITON_FUSED_ADD_RMSNORM_PAD: bool = True
|
||||
VLLM_ROCM_USE_AITER_MLA: bool = True
|
||||
@ -953,6 +955,17 @@ environment_variables: dict[str, Callable[[], Any]] = {
|
||||
"VLLM_ROCM_USE_AITER_MOE": lambda: (
|
||||
os.getenv("VLLM_ROCM_USE_AITER_MOE", "True").lower() in ("true", "1")
|
||||
),
|
||||
# Whether to use aiter ck mxfp4 moe ops.
|
||||
# By default is enabled.
|
||||
"VLLM_ROCM_USE_CK_MXFP4_MOE": lambda: (
|
||||
os.getenv("VLLM_ROCM_USE_CK_MXFP4_MOE", "True").lower() in ("true", "1")
|
||||
),
|
||||
# Whether to use aiter w4a16 moe ops.
|
||||
# By default is disabled.
|
||||
"VLLM_ROCM_USE_AITER_FUSED_MOE_A16W4": lambda: (
|
||||
os.getenv("VLLM_ROCM_USE_AITER_FUSED_MOE_A16W4", "False").lower()
|
||||
in ("true", "1")
|
||||
),
|
||||
# use aiter rms norm op if aiter ops are enabled.
|
||||
"VLLM_ROCM_USE_AITER_RMSNORM": lambda: (
|
||||
os.getenv("VLLM_ROCM_USE_AITER_RMSNORM", "True").lower() in ("true", "1")
|
||||
|
||||
@ -711,6 +711,32 @@ def int8_w8a16_moe_quant_config(
|
||||
)
|
||||
|
||||
|
||||
def mxfp4_w4a4_moe_quant_config(
|
||||
w1_scale: Union[torch.Tensor, "PrecisionConfig"],
|
||||
w2_scale: Union[torch.Tensor, "PrecisionConfig"],
|
||||
a1_scale: torch.Tensor | None = None,
|
||||
a2_scale: torch.Tensor | None = None,
|
||||
w1_bias: torch.Tensor | None = None,
|
||||
w2_bias: torch.Tensor | None = None,
|
||||
block_shape: list[int] | None = None,
|
||||
) -> FusedMoEQuantConfig:
|
||||
"""
|
||||
Construct a quant config for mxfp4 activations and mxfp4 weights.
|
||||
"""
|
||||
return FusedMoEQuantConfig.make(
|
||||
"mxfp4",
|
||||
w1_scale=w1_scale,
|
||||
w2_scale=w2_scale,
|
||||
a1_scale=a1_scale,
|
||||
a2_scale=a2_scale,
|
||||
w1_bias=w1_bias,
|
||||
w2_bias=w2_bias,
|
||||
per_act_token_quant=False,
|
||||
per_out_ch_quant=False,
|
||||
block_shape=block_shape,
|
||||
)
|
||||
|
||||
|
||||
def int4_w4afp8_moe_quant_config(
|
||||
w1_scale: torch.Tensor,
|
||||
w2_scale: torch.Tensor,
|
||||
|
||||
@ -196,6 +196,178 @@ def triton_kernel_fused_experts(
|
||||
return output_tensor
|
||||
|
||||
|
||||
def triton_kernel_moe_oss_forward(
|
||||
hidden_states: torch.Tensor,
|
||||
w1, # Tensor or triton_kernels.Tensor
|
||||
w2, # Tensor or triton_kernels.Tensor
|
||||
gating_output: torch.Tensor,
|
||||
topk: int,
|
||||
renormalize: bool,
|
||||
activation: str = "silu",
|
||||
quant_config: FusedMoEQuantConfig | None = None,
|
||||
apply_router_weight_on_input: bool = False,
|
||||
global_num_experts: int = -1,
|
||||
expert_map: torch.Tensor | None = None,
|
||||
unpadded_N_w1=None,
|
||||
unpadded_K_w1=None,
|
||||
unpadded_N_w2=None,
|
||||
unpadded_K_w2=None,
|
||||
) -> torch.Tensor:
|
||||
assert quant_config is not None
|
||||
|
||||
if quant_config.use_mxfp4_w4a16:
|
||||
routing_data, gather_idx, scatter_idx = routing(
|
||||
gating_output, topk, sm_first=not renormalize
|
||||
)
|
||||
elif quant_config.use_mxfp4_w4a4:
|
||||
from aiter.ops.triton.moe_routing.routing import routing as aiter_routing
|
||||
|
||||
routing_data, gather_idx, scatter_idx = aiter_routing(
|
||||
gating_output, topk, sm_first=not renormalize
|
||||
)
|
||||
|
||||
return triton_kernel_fused_oss_experts(
|
||||
None,
|
||||
hidden_states,
|
||||
w1,
|
||||
w2,
|
||||
routing_data,
|
||||
gather_idx,
|
||||
scatter_idx,
|
||||
activation=activation,
|
||||
quant_config=quant_config,
|
||||
apply_router_weight_on_input=apply_router_weight_on_input,
|
||||
global_num_experts=global_num_experts,
|
||||
expert_map=expert_map,
|
||||
unpadded_N_w1=unpadded_N_w1,
|
||||
unpadded_K_w1=unpadded_K_w1,
|
||||
unpadded_N_w2=unpadded_N_w2,
|
||||
unpadded_K_w2=unpadded_K_w2,
|
||||
)
|
||||
|
||||
|
||||
# This is a triton implementation of the fused_experts function
|
||||
def triton_kernel_fused_oss_experts(
|
||||
output_tensor: torch.Tensor,
|
||||
hidden_states: torch.Tensor,
|
||||
w1, # Tensor or triton_kernels.Tensor
|
||||
w2, # Tensor or triton_kernels.Tensor
|
||||
routing_data, # RoutingData
|
||||
gather_indx, # GatherIndx
|
||||
scatter_indx, # ScatterIndx
|
||||
activation: str = "silu",
|
||||
quant_config: FusedMoEQuantConfig | None = None,
|
||||
swiglu_alpha: float = 1.702,
|
||||
swiglu_limit: float = 7.0,
|
||||
apply_router_weight_on_input: bool = False,
|
||||
global_num_experts: int = -1,
|
||||
expert_map: torch.Tensor | None = None,
|
||||
a1q_scale: torch.Tensor | None = None,
|
||||
unpadded_N_w1=None,
|
||||
unpadded_K_w1=None,
|
||||
unpadded_N_w2=None,
|
||||
unpadded_K_w2=None,
|
||||
) -> torch.Tensor:
|
||||
if quant_config is None:
|
||||
quant_config = FUSED_MOE_UNQUANTIZED_CONFIG
|
||||
|
||||
# type check, uint8 means mxfp4
|
||||
assert hidden_states.dtype == torch.bfloat16
|
||||
assert quant_config.w1_bias is None or quant_config.w1_bias.dtype == torch.float32
|
||||
assert quant_config.w2_bias is None or quant_config.w2_bias.dtype == torch.float32
|
||||
|
||||
# Shape check, only check non-mxfp4
|
||||
assert hidden_states.shape[-1] == w1.shape[-2]
|
||||
assert w2.shape[-1] == w1.shape[1]
|
||||
|
||||
E, _, N = w1.shape
|
||||
|
||||
if global_num_experts == -1:
|
||||
global_num_experts = E
|
||||
|
||||
gammas = routing_data.gate_scal if routing_data else None
|
||||
|
||||
if quant_config.use_mxfp4_w4a16:
|
||||
act = FusedActivation(
|
||||
FnSpecs("swiglu", triton_kernels.swiglu.swiglu_fn, ("alpha", "limit")),
|
||||
(swiglu_alpha, swiglu_limit),
|
||||
2,
|
||||
)
|
||||
intermediate_cache1 = matmul_ogs(
|
||||
hidden_states,
|
||||
w1,
|
||||
quant_config.w1_bias,
|
||||
routing_data,
|
||||
gather_indx=gather_indx,
|
||||
precision_config=quant_config.w1_precision,
|
||||
gammas=gammas if apply_router_weight_on_input else None,
|
||||
fused_activation=act,
|
||||
)
|
||||
intermediate_cache3 = matmul_ogs(
|
||||
intermediate_cache1,
|
||||
w2,
|
||||
quant_config.w2_bias,
|
||||
routing_data,
|
||||
scatter_indx=scatter_indx,
|
||||
precision_config=quant_config.w2_precision,
|
||||
gammas=None if apply_router_weight_on_input else gammas,
|
||||
y=output_tensor,
|
||||
)
|
||||
|
||||
elif quant_config.use_mxfp4_w4a4:
|
||||
from aiter.ops.triton.moe_op_gemm_a8w4 import moe_gemm_a8w4
|
||||
from aiter.ops.triton.quant_moe import downcast_to_static_fp8
|
||||
|
||||
assert quant_config.w1_precision is not None, (
|
||||
"w1_precision in quant config can't be None"
|
||||
)
|
||||
assert quant_config.w2_precision is not None, (
|
||||
"w2_precision in quant config can't be None"
|
||||
)
|
||||
|
||||
hidden_states = downcast_to_static_fp8(
|
||||
hidden_states, quant_config.w1_precision.flex_ctx.lhs_data.scale
|
||||
)
|
||||
|
||||
intermediate_cache1 = moe_gemm_a8w4(
|
||||
hidden_states,
|
||||
w1.storage.data,
|
||||
None,
|
||||
quant_config.w1_precision.weight_scale.storage.data,
|
||||
quant_config.w1_precision.flex_ctx.lhs_data.scale,
|
||||
quant_config.w2_precision.flex_ctx.lhs_data.scale,
|
||||
quant_config.w1_bias,
|
||||
routing_data,
|
||||
gather_indx=gather_indx,
|
||||
gammas=gammas if apply_router_weight_on_input else None,
|
||||
swizzle_mx_scale="CDNA4_SCALE",
|
||||
out_dtype=torch.float8_e4m3fn,
|
||||
apply_swiglu=True,
|
||||
alpha=swiglu_alpha,
|
||||
limit=swiglu_limit,
|
||||
unpadded_N=unpadded_N_w1,
|
||||
unpadded_K=unpadded_K_w1,
|
||||
)
|
||||
|
||||
intermediate_cache3 = moe_gemm_a8w4(
|
||||
intermediate_cache1,
|
||||
w2.storage.data,
|
||||
None,
|
||||
quant_config.w2_precision.weight_scale.storage.data,
|
||||
quant_config.w2_precision.flex_ctx.lhs_data.scale,
|
||||
None,
|
||||
quant_config.w2_bias,
|
||||
routing_data,
|
||||
scatter_indx=scatter_indx,
|
||||
gammas=None if apply_router_weight_on_input else gammas,
|
||||
swizzle_mx_scale="CDNA4_SCALE",
|
||||
unpadded_N=unpadded_N_w2,
|
||||
unpadded_K=unpadded_K_w2,
|
||||
)
|
||||
|
||||
return intermediate_cache3
|
||||
|
||||
|
||||
def make_routing_data(
|
||||
topk_ids: torch.Tensor,
|
||||
topk_weights: torch.Tensor,
|
||||
@ -443,9 +615,6 @@ class UnfusedOAITritonExperts(BaseOAITritonExperts):
|
||||
expert_tokens_meta: mk.ExpertTokensMetadata | None,
|
||||
apply_router_weight_on_input: bool,
|
||||
):
|
||||
if self.quant_config is None:
|
||||
self.quant_config = FUSED_MOE_UNQUANTIZED_CONFIG
|
||||
|
||||
if expert_map is not None:
|
||||
topk_ids = expert_map[topk_ids]
|
||||
|
||||
|
||||
@ -275,7 +275,9 @@ def maybe_roundup_hidden_size(
|
||||
)
|
||||
|
||||
# we are padding globally so EP buffer allocation works
|
||||
if quant_config and quant_config.get_name() == "mxfp4":
|
||||
if quant_config and (
|
||||
quant_config.get_name() == "mxfp4" or quant_config.get_name() == "quark"
|
||||
):
|
||||
from vllm.model_executor.layers.quantization.mxfp4 import (
|
||||
Mxfp4Backend,
|
||||
get_mxfp4_backend,
|
||||
@ -414,7 +416,10 @@ class FusedMoE(CustomOp):
|
||||
# Expert mapping used in self.load_weights
|
||||
self.expert_mapping = expert_mapping
|
||||
|
||||
self._is_mxfp4 = self.is_mxfp4_quant(quant_config=quant_config)
|
||||
|
||||
# Round up hidden size if needed.
|
||||
unpadded_hidden_size = hidden_size
|
||||
hidden_size = maybe_roundup_hidden_size(
|
||||
hidden_size,
|
||||
moe_in_dtype,
|
||||
@ -635,6 +640,7 @@ class FusedMoE(CustomOp):
|
||||
moe_quant_params = {
|
||||
"num_experts": self.local_num_experts,
|
||||
"hidden_size": hidden_size,
|
||||
"unpadded_hidden_size": unpadded_hidden_size,
|
||||
"intermediate_size_per_partition": self.intermediate_size_per_partition,
|
||||
"params_dtype": params_dtype,
|
||||
"weight_loader": self.weight_loader,
|
||||
@ -1115,16 +1121,42 @@ class FusedMoE(CustomOp):
|
||||
expert_id: int,
|
||||
return_success: bool = False,
|
||||
) -> bool | None:
|
||||
if self.quant_config and self.quant_config.get_name() == "mxfp4":
|
||||
# (FIXME) for gpt-oss all experts are combined
|
||||
if "bias" in weight_name:
|
||||
dim1 = loaded_weight.shape[1]
|
||||
param.data[:, :dim1].copy_(loaded_weight)
|
||||
else:
|
||||
dim1 = loaded_weight.shape[1]
|
||||
dim2 = loaded_weight.shape[2]
|
||||
param.data[:, :dim1, :dim2].copy_(loaded_weight)
|
||||
return True if return_success else None
|
||||
if self._is_mxfp4:
|
||||
assert self.quant_config is not None
|
||||
if self.quant_config.get_name() == "mxfp4":
|
||||
# (FIXME) for gpt-oss all experts are combined
|
||||
if "bias" in weight_name:
|
||||
dim1 = loaded_weight.shape[1]
|
||||
param.data[:, :dim1].copy_(loaded_weight)
|
||||
else:
|
||||
dim1 = loaded_weight.shape[1]
|
||||
dim2 = loaded_weight.shape[2]
|
||||
param.data[:, :dim1, :dim2].copy_(loaded_weight)
|
||||
return True if return_success else None
|
||||
elif self.quant_config.get_name() == "quark":
|
||||
# When self._is_mxfp4 is true, model_dtype must be gpt_oss
|
||||
expert_data = param.data[expert_id]
|
||||
if "input_scale" in weight_name:
|
||||
assert loaded_weight.numel() == 1
|
||||
expert_data.data.copy_(loaded_weight)
|
||||
return True if return_success else None
|
||||
|
||||
shard_dim = (
|
||||
0 if shard_id in ("w1", "w3") or "bias" in weight_name else 1
|
||||
)
|
||||
if shard_id == "w2":
|
||||
shard_size = loaded_weight.shape[shard_dim] // self.tp_size
|
||||
loaded_weight = loaded_weight.narrow(
|
||||
shard_dim, shard_size * self.tp_rank, shard_size
|
||||
)
|
||||
if "bias" in weight_name:
|
||||
dim1 = loaded_weight.shape[0]
|
||||
expert_data.data[:dim1].copy_(loaded_weight)
|
||||
else:
|
||||
dim1 = loaded_weight.shape[0]
|
||||
dim2 = loaded_weight.shape[1]
|
||||
expert_data.data[:dim1, :dim2].copy_(loaded_weight)
|
||||
return True if return_success else None
|
||||
|
||||
quant_method_name = self.quant_method.__class__.__name__
|
||||
global_expert_id = expert_id
|
||||
@ -2070,6 +2102,26 @@ class FusedMoE(CustomOp):
|
||||
|
||||
return s
|
||||
|
||||
def is_mxfp4_quant(self, quant_config: QuantizationConfig | None = None) -> bool:
|
||||
if quant_config is None:
|
||||
return False
|
||||
|
||||
name = quant_config.get_name()
|
||||
if name == "mxfp4":
|
||||
return True
|
||||
elif name == "quark":
|
||||
from vllm.config import get_current_vllm_config
|
||||
|
||||
vllm_config = get_current_vllm_config()
|
||||
model_type = getattr(vllm_config.model_config.hf_config, "model_type", None)
|
||||
from vllm.model_executor.layers.quantization.quark.quark import QuarkConfig
|
||||
|
||||
assert isinstance(quant_config, QuarkConfig)
|
||||
# Padding for triton kernel only is enabled when it is gpt_oss
|
||||
return quant_config.is_global_mxfp4 and model_type == "gpt_oss"
|
||||
|
||||
return False
|
||||
|
||||
|
||||
def moe_forward(
|
||||
hidden_states: torch.Tensor,
|
||||
|
||||
@ -59,6 +59,19 @@ class QuarkConfig(QuantizationConfig):
|
||||
self.kv_cache_group = kv_cache_group
|
||||
self.kv_cache_config = kv_cache_config
|
||||
self.pack_method = pack_method
|
||||
self._is_global_mxfp4()
|
||||
|
||||
def _is_global_mxfp4(self):
|
||||
# Check if it is MXFP4 to determine if pre-padding should be applied.
|
||||
# This must be created during the initialization of moe.
|
||||
global_quant_config = cast(
|
||||
dict[str, Any], self.quant_config.get("global_quant_config")
|
||||
)
|
||||
weight_quant = global_quant_config.get("weight")
|
||||
input_quant = global_quant_config.get("input_tensors")
|
||||
self.is_global_mxfp4 = self._is_mx_fp4(
|
||||
weight_quant=weight_quant, input_quant=input_quant
|
||||
)
|
||||
|
||||
def get_linear_method(self) -> "QuarkLinearMethod":
|
||||
return QuarkLinearMethod(self)
|
||||
@ -277,6 +290,44 @@ class QuarkConfig(QuantizationConfig):
|
||||
# Only symmetric weight quantization supported.
|
||||
return is_int8_dtype and is_tensor and is_weight_symmetric and is_static
|
||||
|
||||
def _is_mx_fp4(
|
||||
self, weight_quant: dict[str, Any] | None, input_quant: dict[str, Any] | None
|
||||
) -> bool:
|
||||
# Confirm weights quantized.
|
||||
# Confirm weights and input quantized.
|
||||
if weight_quant is None or input_quant is None:
|
||||
return False
|
||||
|
||||
# Input and weight dtype needs to be fp4.
|
||||
if weight_quant.get("dtype") != "fp4":
|
||||
logger.debug("Quark model is not in MX-FP4 format: weight dtype not fp4")
|
||||
return False
|
||||
|
||||
# Input and weight qscheme needs to be per group.
|
||||
if weight_quant.get("qscheme") != "per_group":
|
||||
logger.debug("Quark model is not in MX-FP4 format: not per_group")
|
||||
return False
|
||||
|
||||
# Input and weight group size needs to be 32.
|
||||
if weight_quant.get("group_size") != 32:
|
||||
logger.debug("Quark model is not in MX-FP4 format: not group_size=32")
|
||||
return False
|
||||
|
||||
# Activations and weight scales need to be in e8m0 format.
|
||||
if weight_quant.get("scale_format") != "e8m0":
|
||||
logger.debug("Quark model is not in MX-FP4 format: not scale_format e8m0")
|
||||
return False
|
||||
|
||||
# Input dtype needs to be one of {'fp4', 'fp6_e2m3', 'fp8_e4m3'}.
|
||||
if input_quant.get("dtype") not in ("fp4", "fp6_e2m3", "fp8_e4m3"):
|
||||
logger.debug(
|
||||
"Quark model is not in MX-FP4 format: expected input dtype "
|
||||
"to be one of {'fp4', 'fp6_e2m3', 'fp8_e4m3'}"
|
||||
)
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
def _is_ocp_mx(
|
||||
self,
|
||||
weight_quant: dict[str, Any] | None,
|
||||
|
||||
@ -5,8 +5,8 @@ from typing import Any
|
||||
|
||||
import torch
|
||||
|
||||
import vllm.envs as envs
|
||||
from vllm import _custom_ops as ops
|
||||
from vllm import envs
|
||||
from vllm._aiter_ops import rocm_aiter_ops
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.fused_moe import (
|
||||
@ -18,12 +18,14 @@ from vllm.model_executor.layers.fused_moe import (
|
||||
from vllm.model_executor.layers.fused_moe.config import (
|
||||
FusedMoEQuantConfig,
|
||||
fp8_w8a8_moe_quant_config,
|
||||
mxfp4_w4a4_moe_quant_config,
|
||||
ocp_mx_moe_quant_config,
|
||||
)
|
||||
from vllm.model_executor.layers.fused_moe.fused_marlin_moe import fused_marlin_moe
|
||||
from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import (
|
||||
prepare_moe_fp8_layer_for_marlin,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.utils.mxfp4_utils import _swizzle_mxfp4
|
||||
from vllm.model_executor.layers.quantization.utils.ocp_mx_utils import (
|
||||
OCP_MX_BLOCK_SIZE,
|
||||
OCP_MX_Scheme,
|
||||
@ -37,10 +39,17 @@ from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
|
||||
from vllm.model_executor.utils import set_weight_attrs
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.scalar_type import scalar_types
|
||||
from vllm.utils.math_utils import round_up
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
__all__ = ["QuarkMoEMethod", "QuarkW8A8Fp8MoEMethod", "QuarkOCP_MX_MoEMethod"]
|
||||
__all__ = [
|
||||
"QuarkMoEMethod",
|
||||
"QuarkW8A8Fp8MoEMethod",
|
||||
"QuarkW4MXFp4MoEMethod_OSS",
|
||||
"QuarkW4MXFp4MoEMethod",
|
||||
"QuarkOCP_MX_MoEMethod",
|
||||
]
|
||||
|
||||
|
||||
class QuarkMoEMethod(FusedMoEMethodBase):
|
||||
@ -64,8 +73,16 @@ class QuarkMoEMethod(FusedMoEMethodBase):
|
||||
weight_config = layer_quant_config.get("weight")
|
||||
input_config = layer_quant_config.get("input_tensors")
|
||||
|
||||
from vllm.config import get_current_vllm_config
|
||||
vllm_config = get_current_vllm_config()
|
||||
model_type = getattr(vllm_config.model_config.hf_config, "model_type", None)
|
||||
|
||||
if quant_config._is_fp8_w8a8(weight_config, input_config):
|
||||
return QuarkW8A8Fp8MoEMethod(weight_config, input_config, module.moe_config)
|
||||
elif quant_config._is_mx_fp4(weight_config, input_config) and model_type == "gpt_oss":
|
||||
return QuarkW4MXFp4MoEMethod_OSS(
|
||||
weight_config, input_config, module.moe_config
|
||||
)
|
||||
elif quant_config._is_ocp_mx(weight_config, input_config):
|
||||
return QuarkOCP_MX_MoEMethod(weight_config, input_config, module.moe_config)
|
||||
else:
|
||||
@ -620,3 +637,599 @@ class QuarkOCP_MX_MoEMethod(QuarkMoEMethod):
|
||||
)
|
||||
|
||||
return out
|
||||
|
||||
|
||||
class QuarkW4MXFp4MoEMethodBase(QuarkMoEMethod):
|
||||
def __init__(
|
||||
self,
|
||||
weight_config: dict[str, Any],
|
||||
input_config: dict[str, Any],
|
||||
moe: FusedMoEConfig,
|
||||
):
|
||||
super().__init__(moe)
|
||||
self.weight_quant = weight_config
|
||||
self.input_quant = input_config
|
||||
self.weight_qscheme = self.weight_quant.get("qscheme")
|
||||
self.input_qscheme = self.input_quant.get("qscheme")
|
||||
self.static_input_scales = not self.input_quant.get("is_dynamic")
|
||||
|
||||
def create_common_weights(
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
num_experts: int,
|
||||
hidden_size: int,
|
||||
intermediate_size_per_partition: int,
|
||||
weight_scale: torch.dtype,
|
||||
weight_scale_dtype: torch.dtype,
|
||||
weight_scale_block_size: int,
|
||||
**extra_weight_attrs,
|
||||
):
|
||||
# WEIGHTS
|
||||
# Fused gate_up_proj (column parallel)
|
||||
w13_weight = torch.nn.Parameter(
|
||||
torch.empty(
|
||||
num_experts,
|
||||
2 * intermediate_size_per_partition,
|
||||
hidden_size // 2,
|
||||
dtype=weight_scale,
|
||||
),
|
||||
requires_grad=False,
|
||||
)
|
||||
layer.register_parameter("w13_weight", w13_weight)
|
||||
set_weight_attrs(w13_weight, extra_weight_attrs)
|
||||
|
||||
w13_weight_scale = torch.nn.Parameter(
|
||||
torch.ones(
|
||||
num_experts,
|
||||
2 * intermediate_size_per_partition,
|
||||
hidden_size // weight_scale_block_size,
|
||||
dtype=weight_scale_dtype,
|
||||
),
|
||||
requires_grad=False,
|
||||
)
|
||||
set_weight_attrs(w13_weight_scale, extra_weight_attrs)
|
||||
layer.register_parameter("w13_weight_scale", w13_weight_scale)
|
||||
|
||||
# down_proj (row parallel)
|
||||
w2_weight = torch.nn.Parameter(
|
||||
torch.empty(
|
||||
num_experts,
|
||||
hidden_size,
|
||||
intermediate_size_per_partition // 2,
|
||||
dtype=weight_scale,
|
||||
),
|
||||
requires_grad=False,
|
||||
)
|
||||
layer.register_parameter("w2_weight", w2_weight)
|
||||
set_weight_attrs(w2_weight, extra_weight_attrs)
|
||||
|
||||
# WEIGHT_SCALES
|
||||
w2_weight_scale = torch.nn.Parameter(
|
||||
torch.ones(
|
||||
num_experts,
|
||||
hidden_size,
|
||||
intermediate_size_per_partition // weight_scale_block_size,
|
||||
dtype=weight_scale_dtype,
|
||||
),
|
||||
requires_grad=False,
|
||||
)
|
||||
set_weight_attrs(w2_weight_scale, extra_weight_attrs)
|
||||
layer.register_parameter("w2_weight_scale", w2_weight_scale)
|
||||
|
||||
|
||||
class QuarkW4MXFp4MoEMethod(QuarkW4MXFp4MoEMethodBase):
|
||||
def __init__(
|
||||
self,
|
||||
weight_config: dict[str, Any],
|
||||
input_config: dict[str, Any],
|
||||
moe: FusedMoEConfig,
|
||||
):
|
||||
super().__init__(weight_config, input_config, moe)
|
||||
if not (
|
||||
self.weight_qscheme == "per_group" and self.input_qscheme == "per_group"
|
||||
):
|
||||
raise ValueError(
|
||||
"For MX(FP4) Fused MoE layers, only per-group scales "
|
||||
"for weights and activations are supported. Found "
|
||||
f"{self.weight_qscheme}, {self.input_qscheme}"
|
||||
) # noqa E501
|
||||
|
||||
if self.static_input_scales:
|
||||
raise NotImplementedError(
|
||||
"QuarkW4MXFp4MoEMethod with static input scales is currently "
|
||||
"not implemented. Please open an issue."
|
||||
)
|
||||
|
||||
self.emulate = not current_platform.supports_mx() or not (
|
||||
rocm_aiter_ops.is_mxfp4_aiter_moe()
|
||||
)
|
||||
|
||||
if self.emulate:
|
||||
logger.warning_once(
|
||||
f"The current mode (supports_mx={current_platform.supports_mx()}, "
|
||||
f"use_mxfp4_aiter_moe={rocm_aiter_ops.is_mxfp4_aiter_moe()}, "
|
||||
"does not support native MXFP4/MXFP6 "
|
||||
"computation. Simulated weight dequantization and activation "
|
||||
"QDQ (quantize and dequantize) will be used, with the linear "
|
||||
"layers computed in high precision."
|
||||
)
|
||||
else:
|
||||
logger.info_once("The current mode supports native MoE MXFP4 computation")
|
||||
|
||||
def create_weights(
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
num_experts: int,
|
||||
hidden_size: int,
|
||||
intermediate_size_per_partition: int,
|
||||
params_dtype: torch.dtype,
|
||||
**extra_weight_attrs,
|
||||
):
|
||||
# Add the quantization method used (per tensor/grouped/channel)
|
||||
# to ensure the weight scales are loaded in properly
|
||||
extra_weight_attrs.update(
|
||||
{"quant_method": FusedMoeWeightScaleSupported.BLOCK.value}
|
||||
)
|
||||
self.create_common_weights(
|
||||
layer,
|
||||
num_experts,
|
||||
hidden_size,
|
||||
intermediate_size_per_partition,
|
||||
torch.uint8,
|
||||
torch.uint8,
|
||||
OCP_MX_BLOCK_SIZE,
|
||||
**extra_weight_attrs,
|
||||
)
|
||||
|
||||
def process_weights_after_loading(self, layer):
|
||||
if self.emulate:
|
||||
return
|
||||
|
||||
from aiter.utility.fp4_utils import e8m0_shuffle
|
||||
|
||||
# Pre-shuffle weight scales
|
||||
s0, s1, _ = layer.w13_weight_scale.shape
|
||||
w13_weight_scale = layer.w13_weight_scale.view(s0 * s1, -1)
|
||||
w13_weight_scale = e8m0_shuffle(w13_weight_scale)
|
||||
layer.w13_weight_scale.data = w13_weight_scale.view(s0, s1, -1)
|
||||
|
||||
s0, s1, _ = layer.w2_weight_scale.shape
|
||||
w2_weight_scale = layer.w2_weight_scale.view(s0 * s1, -1)
|
||||
w2_weight_scale = e8m0_shuffle(w2_weight_scale)
|
||||
layer.w2_weight_scale.data = w2_weight_scale.view(s0, s1, -1)
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
def get_fused_moe_quant_config(
|
||||
self, layer: torch.nn.Module
|
||||
) -> FusedMoEQuantConfig | None:
|
||||
return mxfp4_w4a4_moe_quant_config(
|
||||
w1_bias=layer.w13_bias,
|
||||
w2_bias=layer.w2_bias,
|
||||
w1_scale=layer.w1_weight_scale,
|
||||
w2_scale=layer.w2_weight_scale,
|
||||
)
|
||||
|
||||
def apply(
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
x: torch.Tensor,
|
||||
router_logits: torch.Tensor,
|
||||
expert_map: torch.Tensor | None = None,
|
||||
) -> torch.Tensor:
|
||||
if layer.enable_eplb:
|
||||
raise NotImplementedError(
|
||||
"EPLB not supported for `QuarkW4MXFp4MoEMethod` yet."
|
||||
)
|
||||
|
||||
topk_weights, topk_ids = layer.select_experts(
|
||||
hidden_states=x,
|
||||
router_logits=router_logits,
|
||||
)
|
||||
|
||||
if not self.emulate:
|
||||
from aiter import ActivationType, QuantType
|
||||
from aiter.fused_moe import fused_moe
|
||||
|
||||
aiter_acts = {
|
||||
ActivationType.No.name.lower(): ActivationType.No,
|
||||
ActivationType.Silu.name.lower(): ActivationType.Silu,
|
||||
ActivationType.Gelu.name.lower(): ActivationType.Gelu,
|
||||
}
|
||||
assert layer.activation in aiter_acts, (
|
||||
f"Aiter CK fp4 MoE doesn't support activation {layer.activation}"
|
||||
)
|
||||
if hasattr(torch, "float4_e2m1fn_x2"):
|
||||
w13_weight = layer.w13_weight.view(torch.float4_e2m1fn_x2)
|
||||
w2_weight = layer.w2_weight.view(torch.float4_e2m1fn_x2)
|
||||
else:
|
||||
w13_weight = layer.w13_weight
|
||||
w2_weight = layer.w2_weight
|
||||
|
||||
out = fused_moe(
|
||||
x,
|
||||
w13_weight,
|
||||
w2_weight,
|
||||
topk_weights,
|
||||
topk_ids,
|
||||
quant_type=QuantType.per_1x32,
|
||||
w1_scale=layer.w13_weight_scale,
|
||||
w2_scale=layer.w2_weight_scale,
|
||||
activation=aiter_acts[layer.activation],
|
||||
doweight_stage1=False,
|
||||
)
|
||||
else:
|
||||
from vllm.model_executor.layers.fused_moe import fused_experts
|
||||
|
||||
out = fused_experts(
|
||||
x,
|
||||
layer.w13_weight,
|
||||
layer.w2_weight,
|
||||
topk_weights=topk_weights,
|
||||
topk_ids=topk_ids,
|
||||
inplace=True,
|
||||
activation=layer.activation,
|
||||
global_num_experts=layer.global_num_experts,
|
||||
apply_router_weight_on_input=layer.apply_router_weight_on_input,
|
||||
expert_map=expert_map,
|
||||
quant_config=self.moe_quant_config,
|
||||
)
|
||||
return out
|
||||
|
||||
|
||||
class QuarkW4MXFp4MoEMethod_OSS(QuarkW4MXFp4MoEMethodBase):
|
||||
def __init__(
|
||||
self,
|
||||
weight_config: dict[str, Any],
|
||||
input_config: dict[str, Any],
|
||||
moe: FusedMoEConfig,
|
||||
):
|
||||
super().__init__(weight_config, input_config, moe)
|
||||
|
||||
if not (self.weight_qscheme == "per_group"):
|
||||
raise ValueError(
|
||||
"For MX(FP4) Fused MoE layers, only per-group scales "
|
||||
"for weights and activations are supported. Found "
|
||||
f"{self.weight_qscheme}, {self.input_qscheme}"
|
||||
) # noqa E501
|
||||
|
||||
self.static_input_scales = not self.input_quant.get("is_dynamic")
|
||||
self.emulate = not current_platform.supports_mx()
|
||||
if self.emulate:
|
||||
logger.warning_once(
|
||||
"The current platform does not support native MXFP4 "
|
||||
"computation. Simulated weight dequantization and activation "
|
||||
"QDQ (quantize and dequantize) will be used, with the linear "
|
||||
"layers computed in high precision."
|
||||
)
|
||||
else:
|
||||
logger.warning_once(
|
||||
"The current platform supports native MXFP4 "
|
||||
"computation, but kernels are not yet integrated in vLLM. "
|
||||
"Simulated weight dequantization and activation "
|
||||
"QDQ (quantize and dequantize) will be used, with the linear "
|
||||
"layers computed in high precision."
|
||||
)
|
||||
|
||||
def create_weights(
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
num_experts: int,
|
||||
hidden_size: int,
|
||||
intermediate_size_per_partition: int,
|
||||
params_dtype: torch.dtype,
|
||||
**extra_weight_attrs,
|
||||
):
|
||||
self.num_experts = num_experts
|
||||
# Add the quantization method used (per tensor/grouped/channel)
|
||||
# to ensure the weight scales are loaded in properly
|
||||
extra_weight_attrs.update(
|
||||
{"quant_method": FusedMoeWeightScaleSupported.BLOCK.value}
|
||||
)
|
||||
mxfp4_block = 32
|
||||
weight_dtype = torch.uint8
|
||||
weight_scale_dtype = torch.uint8
|
||||
per_tensor_fp8_act_scale_dtype = torch.bfloat16
|
||||
self.intermediate_size_per_partition = intermediate_size_per_partition
|
||||
intermediate_size_per_partition_after_pad = intermediate_size_per_partition
|
||||
|
||||
if current_platform.is_rocm():
|
||||
intermediate_size_per_partition_after_pad = round_up(
|
||||
intermediate_size_per_partition, 256
|
||||
) # 2880 -> 2944
|
||||
else:
|
||||
intermediate_size_per_partition_after_pad = round_up(
|
||||
intermediate_size_per_partition, 64
|
||||
)
|
||||
|
||||
self.unpadded_hidden_size = extra_weight_attrs.get(
|
||||
"unpadded_hidden_size", hidden_size
|
||||
)
|
||||
self.hidden_pad = hidden_size - self.unpadded_hidden_size
|
||||
self.intermediate_pad = (
|
||||
intermediate_size_per_partition_after_pad - intermediate_size_per_partition
|
||||
)
|
||||
|
||||
self.create_common_weights(
|
||||
layer,
|
||||
num_experts,
|
||||
hidden_size,
|
||||
intermediate_size_per_partition_after_pad,
|
||||
weight_dtype,
|
||||
weight_scale_dtype,
|
||||
mxfp4_block,
|
||||
**extra_weight_attrs,
|
||||
)
|
||||
|
||||
w13_bias = torch.nn.Parameter(
|
||||
torch.zeros(
|
||||
num_experts,
|
||||
2 * intermediate_size_per_partition_after_pad,
|
||||
dtype=torch.bfloat16,
|
||||
),
|
||||
requires_grad=False,
|
||||
)
|
||||
layer.register_parameter("w13_bias", w13_bias)
|
||||
set_weight_attrs(w13_bias, extra_weight_attrs)
|
||||
|
||||
w2_bias = torch.nn.Parameter(
|
||||
torch.zeros(
|
||||
num_experts,
|
||||
hidden_size,
|
||||
dtype=torch.bfloat16,
|
||||
),
|
||||
requires_grad=False,
|
||||
)
|
||||
layer.register_parameter("w2_bias", w2_bias)
|
||||
set_weight_attrs(w2_bias, extra_weight_attrs)
|
||||
|
||||
if self.static_input_scales:
|
||||
w13_input_scale = torch.nn.Parameter(
|
||||
torch.zeros(
|
||||
num_experts,
|
||||
dtype=per_tensor_fp8_act_scale_dtype,
|
||||
),
|
||||
requires_grad=False,
|
||||
)
|
||||
layer.register_parameter("w13_input_scale", w13_input_scale)
|
||||
set_weight_attrs(w13_input_scale, extra_weight_attrs)
|
||||
|
||||
w2_input_scale = torch.nn.Parameter(
|
||||
torch.zeros(
|
||||
num_experts,
|
||||
dtype=per_tensor_fp8_act_scale_dtype,
|
||||
),
|
||||
requires_grad=False,
|
||||
)
|
||||
layer.register_parameter("w2_input_scale", w2_input_scale)
|
||||
set_weight_attrs(w2_input_scale, extra_weight_attrs)
|
||||
|
||||
def process_weights_after_loading(self, layer):
    """Post-load weight preparation for the quark mxfp4 MoE path.

    Converts biases to fp32, then prepares the expert weights for one of
    two kernels selected by ``VLLM_ROCM_USE_AITER_FUSED_MOE_A16W4``:
    the aiter CK-tile a16w4 kernels (weights/scales shuffled into the
    layout those kernels expect) or the triton_kernels matmul_ogs path
    (weights swizzled via ``_swizzle_mxfp4`` and wrapped in
    ``PrecisionConfig``). The original ``layer.w13_weight`` /
    ``layer.w2_weight`` are released afterwards to save GPU memory.
    """
    from triton_kernels.matmul_ogs import FlexCtx, PrecisionConfig

    # Biases are kept in fp32 regardless of the compute path.
    w13_bias = layer.w13_bias.to(torch.float32)
    w2_bias = layer.w2_bias.to(torch.float32)

    layer.w13_bias = torch.nn.Parameter(w13_bias, requires_grad=False)
    layer.w2_bias = torch.nn.Parameter(w2_bias, requires_grad=False)

    # FIXME warp need to be adjusted based on batch size
    # only apply to batched mode
    if self.moe.use_ep:
        num_warps = 4 if envs.VLLM_MOE_DP_CHUNK_SIZE <= 512 else 8
    else:
        num_warps = 8

    if envs.VLLM_ROCM_USE_AITER_FUSED_MOE_A16W4:
        from aiter.ops.shuffle import shuffle_scale_a16w4, shuffle_weight_a16w4

        w13_aiter_weight = layer.w13_weight.contiguous()
        w13_aiter_scale = layer.w13_weight_scale.contiguous()
        w2_aiter_weight = layer.w2_weight.contiguous()
        w2_aiter_scale = layer.w2_weight_scale.contiguous()

        # Interleave the gate and up halves of w13 (rows [0, n/2) and
        # [n/2, n)) so gate/up pairs are adjacent, as the CK-tile kernel
        # expects; the scale tensor is reordered identically.
        e, n, k = w13_aiter_weight.shape
        w13_aiter_weight = (
            w13_aiter_weight.view(e, n // 2, 2, k)
            .permute(0, 2, 1, 3)
            .contiguous()
            .view(e, n, k)
        )
        w13_aiter_scale = (
            w13_aiter_scale.view(e, n // 2, 2, -1)
            .permute(0, 2, 1, 3)
            .contiguous()
            .view(e, n, -1)
        )

        # Reinterpret the packed uint8 storage as paired fp4 values and
        # flatten scales to 2D for the shuffle helpers.
        w13_aiter_weight = w13_aiter_weight.view(torch.float4_e2m1fn_x2)
        w13_aiter_scale = w13_aiter_scale.view(-1, w13_aiter_scale.shape[-1])
        w2_aiter_weight = w2_aiter_weight.view(torch.float4_e2m1fn_x2)
        w2_aiter_scale = w2_aiter_scale.view(-1, w2_aiter_scale.shape[-1])

        self.w13_weight_aiter_tensor = shuffle_weight_a16w4(
            w13_aiter_weight, 16, True
        )
        self.w13_scale_aiter_tensor = shuffle_scale_a16w4(
            w13_aiter_scale, self.num_experts, True
        )
        self.w2_weight_aiter_tensor = shuffle_weight_a16w4(
            w2_aiter_weight, 16, False
        )
        self.w2_scale_aiter_tensor = shuffle_scale_a16w4(
            w2_aiter_scale, self.num_experts, False
        )
        # Bias must follow the same gate/up interleave as the weight rows.
        self.w13_bias_aiter_tensor = (
            layer.w13_bias.view(-1, n // 2, 2)
            .permute(0, 2, 1)
            .contiguous()
            .view(-1, n)
        )
    else:
        # Triton path: swizzle mxfp4 weights into the matmul_ogs layout.
        w13_weight, w13_flex, w13_scale = _swizzle_mxfp4(
            layer.w13_weight, layer.w13_weight_scale, num_warps
        )
        w2_weight, w2_flex, w2_scale = _swizzle_mxfp4(
            layer.w2_weight, layer.w2_weight_scale, num_warps
        )

        self.w13_weight_triton_tensor = w13_weight
        self.w2_weight_triton_tensor = w2_weight

    # need to delete the original weights to save memory on single GPU
    # NOTE(review): placement at method level (both branches) inferred from
    # the surrounding control flow; the original indentation was lost in
    # extraction — confirm it is not intended for the triton branch only.
    del layer.w13_weight
    del layer.w2_weight
    layer.w13_weight = None
    layer.w2_weight = None
    torch.cuda.empty_cache()

    if not envs.VLLM_ROCM_USE_AITER_FUSED_MOE_A16W4:
        # Static activation scales are required on the triton path.
        if layer.w13_input_scale is None or layer.w2_input_scale is None:
            raise ValueError(
                "QuantConfig has static quantization, but found "
                "activation scales are None."
            )
        if not all_close_1d(layer.w13_input_scale) or not all_close_1d(
            layer.w2_input_scale
        ):
            logger.warning_once(
                "Found input_scales that are not equal for "
                "fp8 MoE layer. Using the maximum across experts "
                "for each layer."
            )

        # Collapse per-expert input scales to a single fp32 max per layer.
        layer.w13_input_scale = torch.nn.Parameter(
            layer.w13_input_scale.max().to(torch.float32), requires_grad=False
        )
        layer.w2_input_scale = torch.nn.Parameter(
            layer.w2_input_scale.max().to(torch.float32), requires_grad=False
        )

        from triton_kernels.numerics import InFlexData

        # lhs (activation) scales feed the flex context alongside the
        # rhs (weight) flex data returned by _swizzle_mxfp4.
        lhs_data13 = InFlexData(scale=layer.w13_input_scale)
        lhs_data2 = InFlexData(scale=layer.w2_input_scale)

        self.w13_precision_config = PrecisionConfig(
            weight_scale=w13_scale,
            flex_ctx=FlexCtx(rhs_data=w13_flex, lhs_data=lhs_data13),
        )
        self.w2_precision_config = PrecisionConfig(
            weight_scale=w2_scale,
            flex_ctx=FlexCtx(rhs_data=w2_flex, lhs_data=lhs_data2),
        )
|
||||
|
||||
def get_fused_moe_quant_config(
    self, layer: torch.nn.Module
) -> FusedMoEQuantConfig | None:
    """Return the mxfp4 w4a4 quant config for this MoE layer.

    Both kernel paths share the same biases; only the scale objects
    differ (shuffled aiter scale tensors vs. triton PrecisionConfigs,
    chosen by ``VLLM_ROCM_USE_AITER_FUSED_MOE_A16W4``).
    """
    if envs.VLLM_ROCM_USE_AITER_FUSED_MOE_A16W4:
        scale_13 = self.w13_scale_aiter_tensor
        scale_2 = self.w2_scale_aiter_tensor
    else:
        # TODO: how to set scale?
        scale_13 = self.w13_precision_config
        scale_2 = self.w2_precision_config

    return mxfp4_w4a4_moe_quant_config(
        w1_bias=layer.w13_bias,
        w2_bias=layer.w2_bias,
        w1_scale=scale_13,
        w2_scale=scale_2,
    )
|
||||
|
||||
def apply(
    self,
    layer: torch.nn.Module,
    x: torch.Tensor,
    router_logits: torch.Tensor,
    expert_map: torch.Tensor | None = None,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
    """Run the MoE forward pass on the prepared mxfp4 weights.

    Dispatches to the aiter CK-tile two-stage GEMM kernels when
    ``VLLM_ROCM_USE_AITER_FUSED_MOE_A16W4`` is set, otherwise to the
    triton_kernels-based ``triton_kernel_moe_oss_forward``.

    Raises:
        NotImplementedError: if EPLB is enabled for this layer.
    """
    if layer.enable_eplb:
        raise NotImplementedError(
            "EPLB not supported for `QuarkW4MXFp4MoEMethod_OSS` yet."
        )

    if envs.VLLM_ROCM_USE_AITER_FUSED_MOE_A16W4:
        from aiter import moe_cktile2stages_gemm1, moe_cktile2stages_gemm2
        from aiter.fused_moe import fused_topk, moe_sorting

        token_num = x.shape[0]
        # Smaller sorting block for small batches, larger for throughput.
        BLOCKM = 16 if token_num < 2048 else 32
        topk_weights, topk_ids = fused_topk(x, router_logits, layer.top_k, True)
        # moe_sorting groups tokens by expert and pre-allocates the
        # bf16 output buffer the second GEMM writes into.
        sorted_ids, sorted_weights, sorted_expert_ids, num_valid_ids, moe_out = (
            moe_sorting(
                topk_ids,
                topk_weights,
                self.num_experts,
                x.shape[1],
                torch.bfloat16,
                BLOCKM,
            )
        )
        _, n1, k1 = self.w13_weight_aiter_tensor.shape
        _, k2, n2 = self.w2_weight_aiter_tensor.shape
        # Intermediate width per expert; doubled when the packed fp4
        # w2 K-dim does not match w13's K (packed-pair storage).
        D = n2 if k2 == k1 else n2 * 2
        cktile_moe_out1 = torch.empty(
            (token_num, layer.top_k, D), dtype=torch.bfloat16, device=x.device
        )
        # Stage 1: fused gate/up projection + activation into cktile_moe_out1.
        moe_cktile2stages_gemm1(
            x,
            self.w13_weight_aiter_tensor,
            cktile_moe_out1,
            sorted_ids,
            sorted_expert_ids,
            num_valid_ids,
            layer.top_k,
            self.intermediate_pad // 64 * 64 * 2,
            self.hidden_pad // 128 * 128,  # k_pad_zeros
            None,  # sorted_weights
            None,
            self.w13_scale_aiter_tensor,
            self.w13_bias_aiter_tensor,
            BLOCKM,  # block_size
        )
        # Stage 2: down projection; router weights applied here
        # (sorted_weights passed), accumulating into moe_out.
        moe_cktile2stages_gemm2(
            cktile_moe_out1,
            self.w2_weight_aiter_tensor,
            moe_out,
            sorted_ids,
            sorted_expert_ids,
            num_valid_ids,
            layer.top_k,
            self.hidden_pad // 64 * 64,  # n_pad_zeros
            self.intermediate_pad // 128 * 128,
            sorted_weights,  # sorted_weights
            None,
            self.w2_scale_aiter_tensor,
            layer.w2_bias,
            BLOCKM,  # block_size
        )
        return moe_out
    else:
        from vllm.model_executor.layers.fused_moe.gpt_oss_triton_kernels_moe import (  # noqa: E501
            triton_kernel_moe_oss_forward,
        )

        # The unpadded_* arguments let the kernel ignore the zero padding
        # added to hidden/intermediate dims at weight-creation time.
        return triton_kernel_moe_oss_forward(
            hidden_states=x,
            w1=self.w13_weight_triton_tensor,
            w2=self.w2_weight_triton_tensor,
            gating_output=router_logits,
            topk=layer.top_k,
            renormalize=layer.renormalize,
            global_num_experts=layer.global_num_experts,
            expert_map=expert_map,
            quant_config=self.moe_quant_config,
            apply_router_weight_on_input=layer.apply_router_weight_on_input,
            unpadded_N_w1=self.intermediate_size_per_partition * 2,
            unpadded_K_w1=self.unpadded_hidden_size,
            unpadded_N_w2=self.unpadded_hidden_size,
            unpadded_K_w2=self.intermediate_size_per_partition,
        )
|
||||
|
||||
@ -1,5 +1,6 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import typing
|
||||
from collections.abc import Iterable
|
||||
|
||||
import torch
|
||||
@ -520,6 +521,287 @@ class GptOssModel(nn.Module):
|
||||
loaded_params.add(name)
|
||||
return loaded_params
|
||||
|
||||
def _load_weights_quark(
    self,
    ep_rank_start: int,
    ep_rank_end: int,
    heads_per_rank: int,
    head_start: int,
    weights: Iterable[tuple[str, torch.Tensor]],
    stacked_params_mapping: list[tuple[str, ...]],
) -> set[str]:
    """Load quark-format checkpoint weights, sharding per-expert tensors
    for tensor parallelism.

    Per-expert names like ``...experts.<id>.w13_weight`` are rewritten to
    the fused parameter name and routed through that parameter's
    ``weight_loader`` with the expert id. Non-expert weights fall through
    to the stacked-params / expert-params mapping logic shared with the
    other loaders.

    NOTE(review): the visible call site passes ``ep_rank_end,
    ep_rank_start`` positionally, i.e. swapped relative to this
    signature. Harmless today because ``use_ep`` is asserted False below
    (every EP slice is unreachable), but the order should be reconciled
    with ``_load_weights_other`` — confirm against the caller.

    Returns:
        The set of (post-rewrite) parameter names that were loaded.
    """
    params_dict = dict(self.named_parameters())
    loaded_params: set[str] = set()

    # mxfp4 packs scales per 32-element block along the K dimension.
    mxfp4_block = 32
    use_ep = self.parallel_config.enable_expert_parallel
    assert not use_ep, "Expert parallelism is not support for quark MoE"

    tp_rank = get_tensor_model_parallel_rank()
    tp_size = get_tensor_model_parallel_world_size()

    intermediate_size = self.config.intermediate_size
    per_rank_intermediate_size = cdiv(intermediate_size, tp_size)
    # Calculate common slicing bounds for current rank
    tp_rank_start = tp_rank * per_rank_intermediate_size
    tp_rank_end = min((tp_rank + 1) * per_rank_intermediate_size, intermediate_size)
    expert_params_mapping = self.get_expert_mapping()
    for name, loaded_weight in weights:
        if "sinks" in name:
            # Handle attention sinks (distributed across ranks)
            param = params_dict[name]
            narrow_weight = loaded_weight.narrow(0, head_start, heads_per_rank)
            param.data.copy_(narrow_weight)
            loaded_params.add(name)
            continue

        # mapping to convert individual experts input_scale into fused_moe.
        elif "input_scale" in name:  # w2 w13 input_scale
            # "...experts.<id>.w13_input_scale" -> "...experts.w13_input_scale"
            parts = name.split(".")
            expert_id = int(parts[-2])
            name = ".".join(parts[:-2] + parts[-1:])
            param = params_dict[name]

            weight_loader = getattr(param, "weight_loader", default_weight_loader)
            weight_loader(
                param,
                loaded_weight,
                weight_name=name,
                shard_id=None,
                expert_id=expert_id,
            )
            loaded_params.add(name)
            continue

        elif ".w13_weight_scale" in name:
            parts = name.split(".")
            expert_id = int(parts[-2])
            name = ".".join(parts[:-2] + parts[-1:])
            # Handle MLP gate and up projection weights scale
            if use_ep:
                narrow_weight = loaded_weight[ep_rank_start:ep_rank_end, ...]
            else:
                # Factor 2: gate and up rows are interleaved/stacked.
                narrow_weight = loaded_weight[
                    2 * tp_rank_start : 2 * tp_rank_end, ...
                ]

            param = params_dict[name]
            weight_loader = getattr(param, "weight_loader", default_weight_loader)
            weight_loader(
                param,
                narrow_weight,
                weight_name=name,
                shard_id=None,
                expert_id=expert_id,
            )
            loaded_params.add(name)
            continue

        elif ".w2_weight_scale" in name:
            parts = name.split(".")
            expert_id = int(parts[-2])
            name = ".".join(parts[:-2] + parts[-1:])
            if use_ep:
                narrow_weight = loaded_weight[ep_rank_start:ep_rank_end, ...]
            else:
                # w2 is sharded along K; scales are per mxfp4 block of 32.
                narrow_weight = loaded_weight[
                    ..., tp_rank_start // mxfp4_block : tp_rank_end // mxfp4_block
                ]

            param = params_dict[name]
            weight_loader = getattr(param, "weight_loader", default_weight_loader)
            weight_loader(
                param,
                narrow_weight,
                weight_name=name,
                shard_id=None,
                expert_id=expert_id,
            )
            loaded_params.add(name)
            continue

        # mapping to convert weight and bias of individual
        # experts gate_up_proj into fused_moe.
        elif ".w13_weight" in name:
            parts = name.split(".")
            expert_id = int(parts[-2])
            name = ".".join(parts[:-2] + parts[-1:])

            # Extract gate and up projection parts
            # since the weight is shuffled, we can slice directly
            if use_ep:
                narrow_weight = loaded_weight[ep_rank_start:ep_rank_end, ...]
            else:
                narrow_weight = loaded_weight[
                    2 * tp_rank_start : 2 * tp_rank_end, ...
                ]

            param = params_dict[name]
            weight_loader = getattr(param, "weight_loader", default_weight_loader)
            weight_loader(
                param,
                narrow_weight,
                weight_name=name,
                shard_id=None,
                expert_id=expert_id,
            )
            loaded_params.add(name)
            continue

        elif ".w2_weight" in name:
            parts = name.split(".")
            expert_id = int(parts[-2])
            name = ".".join(parts[:-2] + parts[-1:])

            if use_ep:
                narrow_weight = loaded_weight[ep_rank_start:ep_rank_end, ...]
            else:
                # K-dim slice; // 2 because two fp4 values share a byte.
                narrow_weight = loaded_weight[
                    ..., tp_rank_start // 2 : tp_rank_end // 2
                ]

            param = params_dict[name]
            weight_loader = getattr(param, "weight_loader", default_weight_loader)
            weight_loader(
                param,
                narrow_weight,
                weight_name=name,
                shard_id=None,
                expert_id=expert_id,
            )
            loaded_params.add(name)
            continue

        elif ".w13_bias" in name:
            parts = name.split(".")
            expert_id = int(parts[-2])
            name = ".".join(parts[:-2] + parts[-1:])
            if use_ep:
                narrow_weight = loaded_weight[ep_rank_start:ep_rank_end, ...]
            else:
                narrow_weight = loaded_weight[2 * tp_rank_start : 2 * tp_rank_end]

            param = params_dict[name]
            weight_loader = getattr(param, "weight_loader", default_weight_loader)
            weight_loader(
                param,
                narrow_weight,
                weight_name=name,
                shard_id=None,
                expert_id=expert_id,
            )
            loaded_params.add(name)
            continue

        elif ".w2_bias" in name:
            parts = name.split(".")
            expert_id = int(parts[-2])
            name = ".".join(parts[:-2] + parts[-1:])
            # Handle MLP down projection bias
            param = params_dict[name]
            weight_loader = getattr(param, "weight_loader", default_weight_loader)
            if use_ep:
                loaded_weight = loaded_weight[ep_rank_start:ep_rank_end, ...]
            else:
                # (only load on rank 0 to avoid duplication)
                if tp_rank != 0:
                    loaded_weight.zero_()
            weight_loader(
                param,
                loaded_weight,
                weight_name=name,
                shard_id=None,
                expert_id=expert_id,
            )
            loaded_params.add(name)
            continue

        for param_name, weight_name, shard_id in stacked_params_mapping:
            # Skip non-stacked layers and experts (experts handled below).
            if weight_name not in name:
                continue
            # We have mlp.experts[0].gate_proj in the checkpoint.
            # Since we handle the experts below in expert_params_mapping,
            # we need to skip here BEFORE we update the name, otherwise
            # name will be updated to mlp.experts[0].gate_up_proj, which
            # will then be updated below in expert_params_mapping
            # for mlp.experts[0].gate_gate_up_proj, which breaks load.
            if ("mlp.experts." in name) and name not in params_dict:
                continue
            name = name.replace(weight_name, param_name)

            if is_pp_missing_parameter(name, self):
                continue

            param = params_dict[name]
            weight_loader = param.weight_loader
            weight_loader(param, loaded_weight, shard_id)
            break
        else:
            is_expert_weight = False
            for mapping in expert_params_mapping:
                param_name, weight_name, expert_id, shard_id = mapping
                if weight_name not in name:
                    continue

                # Anyway, this is an expert weight and should not be
                # attempted to load as other weights later
                is_expert_weight = True

                # Do not modify `name` since the loop may continue here
                # Instead, create a new variable
                name_mapped = name.replace(weight_name, param_name)

                if is_pp_missing_parameter(name_mapped, self):
                    continue

                param = params_dict[name_mapped]
                # We should ask the weight loader to return success or not
                # here since otherwise we may skip experts with other
                # available replicas.
                weight_loader = typing.cast(
                    typing.Callable[..., bool], param.weight_loader
                )
                success = weight_loader(
                    param,
                    loaded_weight,
                    name_mapped,
                    shard_id=shard_id,
                    expert_id=expert_id,
                    return_success=True,
                )
                if success:
                    name = name_mapped
                    break
            else:
                if is_expert_weight:
                    # We've checked that this is an expert weight
                    # However it's not mapped locally to this rank
                    # So we simply skip it
                    continue

                if is_pp_missing_parameter(name, self):
                    continue

                param = params_dict[name]
                weight_loader = getattr(
                    param, "weight_loader", default_weight_loader
                )
                weight_loader(param, loaded_weight)
        loaded_params.add(name)

    return loaded_params
|
||||
|
||||
def get_expert_mapping(self) -> list[tuple[str, str, int, str]]:
    """Build the expert-parameter mapping for fused-MoE loading.

    Each entry is ``(param_name, weight_name, expert_id, shard_id)``,
    covering weights plus fp8 weight/activation scales.
    """
    mapping = FusedMoE.make_expert_params_mapping(
        ckpt_gate_proj_name="gate_proj",
        ckpt_up_proj_name="up_proj",
        ckpt_down_proj_name="down_proj",
        num_experts=self.config.num_local_experts,
    )
    return mapping
|
||||
|
||||
def _load_weights_other(
|
||||
self,
|
||||
ep_rank_end: int,
|
||||
@ -669,6 +951,15 @@ class GptOssModel(nn.Module):
|
||||
weights,
|
||||
stacked_params_mapping,
|
||||
)
|
||||
elif quant_method == "quark":
|
||||
return self._load_weights_quark(
|
||||
ep_rank_end,
|
||||
ep_rank_start,
|
||||
heads_per_rank,
|
||||
head_start,
|
||||
weights,
|
||||
stacked_params_mapping,
|
||||
)
|
||||
else:
|
||||
return self._load_weights_other(
|
||||
ep_rank_end,
|
||||
@ -701,6 +992,15 @@ class GptOssForCausalLM(nn.Module, SupportsPP, SupportsEagle3, SupportsLoRA):
|
||||
# MoE Bias
|
||||
".gate_up_proj_bias": ".w13_bias",
|
||||
".down_proj_bias": ".w2_bias",
|
||||
# For quark format
|
||||
".gate_up_proj.weight": ".w13_weight",
|
||||
".gate_up_proj.weight_scale": ".w13_weight_scale",
|
||||
".gate_up_proj.bias": ".w13_bias",
|
||||
".gate_up_proj.input_scale": ".w13_input_scale",
|
||||
".down_proj.weight": ".w2_weight",
|
||||
".down_proj.weight_scale": ".w2_weight_scale",
|
||||
".down_proj.bias": ".w2_bias",
|
||||
".down_proj.input_scale": ".w2_input_scale",
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user