[fix]: add Arm 4bit fused moe support (#23809)
Signed-off-by: Nikhil Gupta <nikhil.gupta2@arm.com>
Signed-off-by: yewentao256 <zhyanwentao@126.com>
parent a986f17028
commit 6b87ce2ecd
@@ -258,7 +258,8 @@ set(VLLM_EXT_SRC
     "csrc/cpu/layernorm.cpp"
     "csrc/cpu/mla_decode.cpp"
     "csrc/cpu/pos_encoding.cpp"
-    "csrc/cpu/torch_bindings.cpp")
+    "csrc/cpu/torch_bindings.cpp"
+    "csrc/moe/dynamic_4bit_int_moe_cpu.cpp")
 
 if (AVX512_FOUND AND NOT AVX512_DISABLED)
     set(VLLM_EXT_SRC
@@ -88,8 +88,18 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
       " int tp_rank, int blocksparse_local_blocks,"
       " int blocksparse_vert_stride, int blocksparse_block_size,"
       " int blocksparse_head_sliding_step) -> ()");
   ops.impl("paged_attention_v1", torch::kCPU, &paged_attention_v1);
 
+  ops.def(
+      "dynamic_4bit_int_moe("
+      "Tensor x, Tensor topk_ids, Tensor topk_weights,"
+      "Tensor w13_packed, Tensor w2_packed, int H, int I, int I2,"
+      "int group_size, bool apply_router_weight_on_input, int activation_kind"
+      ") -> Tensor");
+
+  ops.impl("dynamic_4bit_int_moe", torch::kCPU, &dynamic_4bit_int_moe_cpu);
+
   // PagedAttention V2.
   ops.def(
       "paged_attention_v2("
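The schema above is registered in the `_C` namespace alongside the existing CPU ops, so once the extension is built the kernel is reachable as `torch.ops._C.dynamic_4bit_int_moe`. Below is a rough sketch of invoking it with made-up shapes; it assumes an AArch64 CPU build of vLLM that includes this commit, a PyTorch that exposes the `aten._dyn_quant_*` 4-bit ops, and it reuses the same nibble-packing scheme that `process_weights_after_loading` applies later in this diff (the `pack` helper and all sizes are illustrative, not part of the change):

import torch

E, H, I = 4, 64, 32          # experts, hidden size, intermediate size (made up)
T, K = 8, 2                  # tokens, experts routed per token
group_size = -1              # -1 -> channelwise scales

def pack(w_int4_as_int8, scales, in_features, out_features):
    # Same layout as _pack_matrix below: shift [-8,7] to unsigned nibbles,
    # pair adjacent input columns into one uint8, then let ATen pack it
    # into the KleidiAI 4-bit format.
    tmp = w_int4_as_int8.add(8)
    nibbles = ((tmp[:, 1::2] << 4) | tmp[:, ::2]).to(torch.uint8)
    return torch.ops.aten._dyn_quant_pack_4bit_weight(
        nibbles, scales, None,
        group_size if group_size != -1 else in_features,
        in_features, out_features)

w13_packed = torch.stack([
    pack(torch.randint(-8, 8, (2 * I, H), dtype=torch.int8),
         torch.rand(2 * I, 1, dtype=torch.float32), H, 2 * I)
    for _ in range(E)])
w2_packed = torch.stack([
    pack(torch.randint(-8, 8, (H, I), dtype=torch.int8),
         torch.rand(H, 1, dtype=torch.float32), I, H)
    for _ in range(E)])

x = torch.randn(T, H, dtype=torch.float32)
topk_ids = torch.randint(0, E, (T, K), dtype=torch.int64)
topk_weights = torch.rand(T, K, dtype=torch.float32)

out = torch.ops._C.dynamic_4bit_int_moe(
    x, topk_ids, topk_weights, w13_packed, w2_packed,
    H, I, 2 * I, group_size,
    False,   # apply_router_weight_on_input
    0)       # activation_kind 0 = SwiGLU: SiLU(gate) * up
print(out.shape)  # torch.Size([8, 64])

The integer arguments mirror the per-expert matmul dimensions: H is the hidden size (w2 output features), I the intermediate size (w2 input features), and I2 = 2*I the fused gate+up width of w13.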
csrc/moe/dynamic_4bit_int_moe_cpu.cpp (new file, 156 lines)
@@ -0,0 +1,156 @@
+#include <ATen/ATen.h>
+#include <ATen/Parallel.h>
+#include <torch/all.h>
+
+// _dyn_quant_matmul_4bit is only available on AArch64.
+#if defined(__aarch64__)
+#include <ATen/ops/_dyn_quant_matmul_4bit.h>
+#endif
+
+inline torch::Tensor mm(const torch::Tensor& a, const torch::Tensor& packed_w,
+                        int64_t group_size_eff, int64_t in_features,
+                        int64_t out_features) {
+#if defined(__aarch64__)
+  return at::_ops::_dyn_quant_matmul_4bit::call(a, packed_w, group_size_eff,
+                                                in_features, out_features);
+#else
+  TORCH_CHECK(false,
+              "dynamic 4-bit int MoE path requires AArch64 (ARM64); "
+              "_dyn_quant_matmul_4bit is unavailable on this architecture");
+  return {};
+#endif
+}
+
+enum ActivationKind : int64_t {
+  SwiGLU_Gu = 0,  // act = SiLU(g) * u
+  SwiGLUOAI = 1,  // act = SiLU(u) * g
+  SiLU = 2        // SiLU
+};
+
+torch::Tensor dynamic_4bit_int_moe_cpu(
+    torch::Tensor x, torch::Tensor topk_ids, torch::Tensor topk_weights,
+    torch::Tensor w13_packed, torch::Tensor w2_packed, int64_t H, int64_t I,
+    int64_t I2, int64_t group_size, bool apply_router_weight_on_input,
+    int64_t activation_kind) {
+  TORCH_CHECK(x.dim() == 2, "x must be 2D");
+  TORCH_CHECK(topk_ids.dim() == 2 && topk_weights.dim() == 2,
+              "topk tensors must be [T, K]");
+  TORCH_CHECK(
+      w13_packed.size(0) == w2_packed.size(0),
+      "w13_packed and w2_packed must have same number of experts in dim 0");
+  TORCH_CHECK(I2 == 2 * I, "I2 must equal 2*I");
+
+  const int64_t T = x.size(0);
+  const int64_t K = topk_ids.size(1);
+  const int64_t E = w13_packed.size(0);
+  const int64_t N = T * K;
+
+  auto x_c = x.contiguous();
+  auto ids_c = topk_ids.contiguous();
+  auto gates_c = topk_weights.to(at::kFloat).contiguous();
+
+  // bucketing tokens -> experts
+  c10::SmallVector<int64_t, 64> counts(
+      E, 0);  // SmallVector uses stack allocation
+  {
+    const auto* ids_ptr = ids_c.data_ptr<int64_t>();
+    for (int64_t i = 0; i < N; ++i) {
+      const int64_t e_id = ids_ptr[i];
+      TORCH_CHECK(0 <= e_id && e_id < E, "expert id out of range");
+      counts[e_id]++;
+    }
+  }
+  c10::SmallVector<int64_t, 65> offsets(E + 1, 0);  // (E + 1)
+  for (int64_t e = 0; e < E; ++e) offsets[e + 1] = offsets[e] + counts[e];
+
+  auto expert_tokens = at::empty({offsets[E]}, ids_c.options());
+  auto expert_gates = at::empty({offsets[E]}, gates_c.options());
+  {
+    c10::SmallVector<int64_t, 64> cursor(E, 0);
+    const auto* ids_ptr = ids_c.data_ptr<int64_t>();
+    const auto* gts_ptr = gates_c.data_ptr<float>();
+    auto* tok_ptr = expert_tokens.data_ptr<int64_t>();
+    auto* gate_ptr = expert_gates.data_ptr<float>();
+
+    for (int64_t t = 0; t < T; ++t) {
+      const int64_t base = t * K;
+      for (int64_t k = 0; k < K; ++k) {
+        const int64_t idx = base + k;
+        const int64_t e = ids_ptr[idx];
+        const int64_t p = offsets[e] + (cursor[e]++);
+        tok_ptr[p] = t;
+        gate_ptr[p] = gts_ptr[idx];
+      }
+    }
+  }
+
+  const int64_t g_eff_13 = (group_size != -1) ? group_size : H;
+  const int64_t g_eff_2 = (group_size != -1) ? group_size : I;
+
+  // Per-expert outputs filled in parallel
+  std::vector<torch::Tensor> y_list(E);
+  y_list.resize(E);
+
+  at::parallel_for(0, E, 1, [&](int64_t e_begin, int64_t e_end) {
+    for (int64_t e = e_begin; e < e_end; ++e) {
+      const int64_t te = counts[e];
+      if (te == 0) {
+        y_list[e] = at::empty({0, H}, x_c.options());
+        continue;
+      }
+
+      const int64_t start = offsets[e];
+
+      auto sel_tokens =
+          expert_tokens.narrow(/*dim=*/0, /*start=*/start, /*length=*/te);
+      auto gates_e =
+          expert_gates.narrow(/*dim=*/0, /*start=*/start, /*length=*/te);
+
+      auto x_e = x_c.index_select(/*dim=*/0, sel_tokens);
+
+      if (apply_router_weight_on_input) {
+        x_e = x_e.mul(gates_e.unsqueeze(1));
+      }
+
+      auto w13_e = w13_packed.select(/*dim=*/0, e);
+      auto w2_e = w2_packed.select(/*dim=*/0, e);
+
+      // W13
+      auto y13 =
+          mm(x_e, w13_e, g_eff_13, /*in_features=*/H, /*out_features=*/I2);
+
+      auto g_part = y13.narrow(/*dim=*/1, /*start=*/0, /*length=*/I);
+      auto u_part = y13.narrow(/*dim=*/1, /*start=*/I, /*length=*/I);
+
+      torch::Tensor act;
+      if (activation_kind == ActivationKind::SwiGLUOAI) {  // SwiGLUOAI
+        constexpr double kAlpha = 1.702;  // GPT-OSS default
+        constexpr double kLimit = 7.0;    // GPT-OSS default
+        auto gate_c = at::clamp_max(g_part, kLimit);
+        auto up_c = at::clamp(u_part, -kLimit, kLimit);
+        auto glu = gate_c.mul(at::sigmoid(gate_c.mul(kAlpha)));
+        act = up_c.add(1.0).mul(glu);
+      } else {  // SiLU, SwiGLU_Gu; vLLM maps silu to SiluAndMul()
+        act = at::silu(g_part).mul(u_part);
+      }
+
+      // W2
+      auto y = mm(act, w2_e, g_eff_2, /*in_features=*/I, /*out_features=*/H);
+
+      if (!apply_router_weight_on_input) {
+        y = y.mul(gates_e.unsqueeze(1));
+      }
+
+      // Store per-expert result
+      y_list[e] = y;
+    }
+  });
+
+  // Concatenate all expert outputs to match expert_tokens order
+  auto Y_all = at::cat(y_list, /*dim=*/0);
+  auto out = at::zeros({T, H}, x.options());
+  out =
+      at::index_add(out, /*dim=*/0, /*index=*/expert_tokens, /*source=*/Y_all);
+
+  return out;
+}
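The kernel above is essentially a dispatch/combine loop: count how many (token, slot) pairs route to each expert, build prefix-sum offsets, gather each expert's tokens and gate values, run the two packed matmuls plus the activation, then scatter-add the weighted results back into the output rows. A plain-Python equivalent with dense weights (illustrative only, not part of the diff; it substitutes float matmuls for the packed 4-bit kernels) makes that flow easier to follow:

import torch

def moe_reference(x, topk_ids, topk_weights, w13, w2,
                  apply_router_weight_on_input=False):
    """Dense-weight reference for the dispatch/combine logic of the CPU kernel.

    x:            [T, H]    activations
    topk_ids:     [T, K]    expert id per (token, slot)
    topk_weights: [T, K]    routing weight per (token, slot)
    w13:          [E, 2I, H] fused gate/up weights (dense, for illustration)
    w2:           [E, H, I]  down-projection weights
    """
    T, H = x.shape
    E, I2, _ = w13.shape
    I = I2 // 2
    out = torch.zeros(T, H, dtype=x.dtype)

    flat_ids = topk_ids.reshape(-1)
    flat_gates = topk_weights.reshape(-1).to(torch.float32)
    token_of = torch.arange(T).repeat_interleave(topk_ids.shape[1])

    for e in range(E):                          # one bucket per expert
        mask = flat_ids == e
        if not mask.any():
            continue
        tokens = token_of[mask]
        gates = flat_gates[mask].unsqueeze(1)
        x_e = x[tokens]
        if apply_router_weight_on_input:
            x_e = x_e * gates
        y13 = x_e @ w13[e].t()                  # [n_e, 2I]
        g, u = y13[:, :I], y13[:, I:]
        act = torch.nn.functional.silu(g) * u   # SwiGLU: SiLU(gate) * up
        y = act @ w2[e].t()                     # [n_e, H]
        if not apply_router_weight_on_input:
            y = y * gates
        out.index_add_(0, tokens, y.to(out.dtype))  # combine back per token
    return out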
@@ -328,6 +328,12 @@ void selective_scan_fwd(const torch::Tensor& u, const torch::Tensor& delta,
                         const std::optional<torch::Tensor>& has_initial_state,
                         const torch::Tensor& ssm_states, int64_t pad_slot_id);
 
+torch::Tensor dynamic_4bit_int_moe_cpu(
+    torch::Tensor x, torch::Tensor topk_ids, torch::Tensor topk_weights,
+    torch::Tensor w13_packed, torch::Tensor w2_packed, int64_t H, int64_t I,
+    int64_t I2, int64_t group_size, bool apply_router_weight_on_input,
+    int64_t activation_kind);
+
 using fptr_t = int64_t;
 fptr_t init_custom_ar(const std::vector<int64_t>& fake_ipc_ptrs,
                       torch::Tensor& rank_data, int64_t rank,
@@ -98,13 +98,16 @@ def select_experts(
             e_score_correction_bias=e_score_correction_bias)
     elif custom_routing_function is None:
         assert scoring_func == "softmax"
-        topk_weights = torch.nn.functional.softmax(router_logits,
-                                                   dim=1,
-                                                   dtype=torch.float32)
-        topk_weights, topk_ids = torch.topk(topk_weights, top_k, dim=-1)
+        topk_logit_vals, topk_idx = torch.topk(router_logits,
+                                               k=top_k,
+                                               dim=-1,
+                                               sorted=False)
         if renormalize:
-            topk_weights /= topk_weights.sum(dim=-1, keepdim=True)
-        return topk_weights, topk_ids.to(torch.int32)
+            topk_vals = torch.softmax(topk_logit_vals, dim=-1)
+        else:
+            logZ = torch.logsumexp(router_logits, dim=-1, keepdim=True)
+            topk_vals = (topk_logit_vals - logZ).exp()
+        return topk_vals.to(torch.float32), topk_idx.to(torch.int32)
     else:
         return custom_routing_function(hidden_states=hidden_states,
                                        gating_output=router_logits,
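The routing change above skips the full softmax over all experts: taking top-k on the raw logits and softmaxing only those k values yields exactly the renormalized weights of the old path, because the full-softmax denominator cancels under renormalization; without renormalization, softmax(x)_i = exp(x_i - logsumexp(x)) recovers the same probabilities via the logsumexp branch. A quick numerical check of that equivalence (illustrative, not part of the diff):

import torch

torch.manual_seed(0)
router_logits = torch.randn(4, 16)  # [tokens, experts]
top_k = 2

# Old path: softmax over all experts, top-k, then renormalize.
full = torch.nn.functional.softmax(router_logits, dim=1, dtype=torch.float32)
old_w, old_ids = torch.topk(full, top_k, dim=-1)
old_w = old_w / old_w.sum(dim=-1, keepdim=True)

# New path: top-k on the logits, softmax over just those k logits.
new_logits, new_ids = torch.topk(router_logits, k=top_k, dim=-1)
new_w = torch.softmax(new_logits, dim=-1)

assert torch.equal(old_ids, new_ids)
assert torch.allclose(old_w, new_w, atol=1e-6)

# Without renormalization: softmax(x)_i == exp(x_i - logsumexp(x)).
no_renorm = (new_logits
             - torch.logsumexp(router_logits, dim=-1, keepdim=True)).exp()
assert torch.allclose(no_renorm, torch.topk(full, top_k, dim=-1).values,
                      atol=1e-6)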
@@ -69,8 +69,6 @@ else:
     if is_rocm_aiter_moe_enabled():
         from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (  # noqa: E501
             rocm_aiter_grouped_topk as grouped_topk)
-    elif current_platform.is_cpu():
-        pass
     else:
         from vllm.model_executor.layers.fused_moe.fused_moe import grouped_topk
     if current_platform.is_tpu():
@@ -22,6 +22,7 @@ from vllm.model_executor.layers.fused_moe.config import (
     FusedMoEQuantConfig, fp8_w8a8_moe_quant_config,
     int4_w4a16_moe_quant_config, int8_w8a8_moe_quant_config,
     int8_w8a16_moe_quant_config, nvfp4_moe_quant_config)
+from vllm.model_executor.layers.fused_moe.cpu_fused_moe import select_experts
 from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe import (
     is_valid_flashinfer_cutlass_fused_moe)
 from vllm.model_executor.layers.quantization.compressed_tensors.schemes.compressed_tensors_wNa16 import (  # noqa
@@ -47,7 +48,7 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import (
 from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
     all_close_1d, normalize_e4m3fn_to_e4m3fnuz, per_tensor_dequantize)
 from vllm.model_executor.utils import set_weight_attrs
-from vllm.platforms import current_platform
+from vllm.platforms import CpuArchEnum, current_platform
 from vllm.scalar_type import scalar_types
 from vllm.utils.deep_gemm import is_deep_gemm_e8m0_used
 
@@ -63,7 +64,7 @@ __all__ = [
     "CompressedTensorsMoEMethod", "CompressedTensorsW8A8Fp8MoEMethod",
     "CompressedTensorsW8A8Int8MoEMethod",
     "CompressedTensorsWNA16MarlinMoEMethod", "CompressedTensorsWNA16MoEMethod",
-    "CompressedTensorsW4A4MoeMethod"
+    "CompressedTensorsW4A4MoeMethod", "CompressedTensorsW4A8Int8MoEMethod"
 ]
 
 
@@ -139,6 +140,10 @@ class CompressedTensorsMoEMethod(FusedMoEMethodBase):
         elif quant_config._is_dynamic_token_w8a8(weight_quant, input_quant):
             return CompressedTensorsW8A8Int8MoEMethod(quant_config,
                                                       layer.moe_config)
+        elif quant_config._is_dynamic_token_w4a8_int(weight_quant,
+                                                     input_quant):
+            return CompressedTensorsW4A8Int8MoEMethod(quant_config,
+                                                      layer.moe_config)
         else:
             raise RuntimeError(
                 f"Unsupported FusedMoe scheme: {weight_quant}, {input_quant}")
@@ -1769,3 +1774,301 @@ class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod):
             expert_map=expert_map,
             quant_config=self.moe_quant_config,
         )
+
+
+class CompressedTensorsW4A8Int8MoEMethod(CompressedTensorsMoEMethod):
+    """
+    CPU-only MoE method using dynamic 4-bit matmul kernels on Arm platforms.
+    - Weights: int4 (stored as int8 values in [-8,7], packed to uint8 nibbles)
+    - Scales: fp32 for channelwise, bf16 for groupwise quantization
+    - Bias: same data type as the original weights
+    - Activations: FP32/BF16 dynamic per-token (A8 int),
+      quantized inside the kernel
+    """
+
+    def __init__(
+            self,
+            quant_config: "CompressedTensorsConfig",  # type: ignore # noqa E501
+            moe: FusedMoEConfig):
+        super().__init__(moe)
+        self.has_bias = self.moe.has_bias
+        self.quant_config = quant_config
+
+        # Validate scheme: weights=W4 (channel or group),
+        # activations=dynamic TOKEN (A8)
+        wq = self.quant_config.target_scheme_map["Linear"].get("weights")
+        aq = self.quant_config.target_scheme_map["Linear"].get(
+            "input_activations")
+
+        # Must be dynamic per-token activations
+        if aq.strategy != QuantizationStrategy.TOKEN or not aq.dynamic:
+            raise ValueError(
+                "W4A8-int MoE needs dynamic per-token activation quantization."
+            )
+
+        # Weight can be channel-wise (group_size=None) or group-wise
+        self.group_size = wq.group_size if (wq.group_size is not None) else -1
+        if wq.num_bits != 4:
+            raise ValueError(
+                "This method only supports 4-bit weights (num_bits=4).")
+
+        # CPU only
+        if not current_platform.is_cpu():
+            raise ValueError("CompressedTensorsW4A8Int8MoEMethod is CPU-only.")
+
+        # Arm: check _dyn ops availability
+        if current_platform.get_cpu_architecture() == CpuArchEnum.ARM:
+            try:
+                _ = torch.ops.aten._dyn_quant_matmul_4bit
+                _ = torch.ops.aten._dyn_quant_pack_4bit_weight
+            except AttributeError as err:
+                raise RuntimeError(
+                    f"""PyTorch {torch.__version__} lacks _dyn_quant_* 4bit ops;
+                    install a newer build.""") from err
+        self.static_input_scales = False  # always dynamic per token
+
+    # ---- parameter creation ----
+    def create_weights(self, layer: torch.nn.Module, num_experts: int,
+                       hidden_size: int, intermediate_size_per_partition: int,
+                       params_dtype: torch.dtype, **extra_weight_attrs):
+
+        # Shapes per local rank (TP/EP):
+        #   w13: [E, 2*I_local, H] int8 (int4 values in [-8,7])
+        #   w2 : [E, H, I_local]   int8
+        # Scales:
+        #   channel-wise: group_size=-1 -> per-output-row, single scale per row
+        #   group-wise  : group_size=g  ->
+        #       per-output-row, (in_features/g) scales
+
+        E = num_experts
+        H = hidden_size
+        IN = intermediate_size_per_partition
+        g = self.group_size
+
+        # Per-row scale columns
+        def _n_scale_cols(in_features: int) -> int:
+            return 1 if g == -1 else (in_features // g)
+
+        # Register unpacked int4-as-int8 weights the loader will fill.
+        w13 = torch.nn.Parameter(torch.empty(E, 2 * IN, H, dtype=torch.int8),
+                                 requires_grad=False)
+        set_weight_attrs(w13, extra_weight_attrs)
+        layer.register_parameter("w13_weight", w13)
+
+        w2 = torch.nn.Parameter(torch.empty(E, H, IN, dtype=torch.int8),
+                                requires_grad=False)
+        set_weight_attrs(w2, extra_weight_attrs)
+        layer.register_parameter("w2_weight", w2)
+
+        # Register scales
+        # KleidiAI channelwise kernels accept float32 scales;
+        # groupwise kernels accept bfloat16 scales.
+        scale_dtype = torch.float32 if g == -1 else torch.bfloat16
+
+        w13_s = torch.nn.Parameter(torch.ones(E,
+                                              2 * IN,
+                                              _n_scale_cols(H),
+                                              dtype=scale_dtype),
+                                   requires_grad=False)
+        set_weight_attrs(
+            w13_s, {
+                "quant_method": "channel" if g == -1 else "group",
+                **extra_weight_attrs
+            })
+        layer.register_parameter("w13_weight_scale", w13_s)
+
+        w2_s = torch.nn.Parameter(torch.ones(E,
+                                             H,
+                                             _n_scale_cols(IN),
+                                             dtype=scale_dtype),
+                                  requires_grad=False)
+        set_weight_attrs(
+            w2_s, {
+                "quant_method": "channel" if g == -1 else "group",
+                **extra_weight_attrs
+            })
+        layer.register_parameter("w2_weight_scale", w2_s)
+
+        if self.has_bias:
+            w13_bias = torch.nn.Parameter(torch.zeros(E,
+                                                      2 * IN,
+                                                      dtype=params_dtype),
+                                          requires_grad=False)
+            layer.register_parameter("w13_bias", w13_bias)
+            set_weight_attrs(w13_bias, extra_weight_attrs)
+
+            w2_bias = torch.nn.Parameter(torch.zeros(num_experts,
+                                                     hidden_size,
+                                                     dtype=params_dtype),
+                                         requires_grad=False)
+            layer.register_parameter("w2_bias", w2_bias)
+            set_weight_attrs(w2_bias, extra_weight_attrs)
+
+        # Placeholders for packed weights (will be replaced after packing)
+        layer.register_parameter(
+            "w13_weight_packed",
+            torch.nn.Parameter(torch.empty(0), requires_grad=False))
+        set_weight_attrs(layer.w13_weight_packed, extra_weight_attrs)
+
+        layer.register_parameter(
+            "w2_weight_packed",
+            torch.nn.Parameter(torch.empty(0), requires_grad=False))
+        set_weight_attrs(layer.w2_weight_packed, extra_weight_attrs)
+
+        # dims for 4-bit fused matmuls
+        layer.w13_in_features = H
+        layer.w13_out_features = 2 * IN
+        layer.w2_in_features = IN
+        layer.w2_out_features = H
+        layer.group_size = g
+
+    # post-load packing to the dyn-4bit KleidiAI kernel's format
+    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
+        E = layer.w13_weight.shape[0]
+        H = layer.w13_in_features
+        I2 = layer.w13_out_features
+        IN = layer.w2_in_features
+        g = layer.group_size
+
+        def _pack_matrix(int4_as_int8_2d: torch.Tensor,
+                         scales_2d: torch.Tensor,
+                         bias_1d: Optional[torch.Tensor], in_features: int,
+                         out_features: int) -> torch.Tensor:
+            # int4 values are stored as int8 in [-8,7].
+            # Shift to unsigned nibble and pack pairs along input-dim.
+            tmp = int4_as_int8_2d.add(8)  # [out, in]
+            uint8_nibbles = ((tmp[:, 1::2] << 4) | tmp[:, ::2]).to(
+                torch.uint8)  # [out, in//2]
+
+            # KleidiAI channelwise kernels accept float32 scales;
+            # groupwise kernels accept bfloat16 scales.
+            scale_dtype = torch.float32 if g == -1 else torch.bfloat16
+            scales = scales_2d.to(scale_dtype)
+            bias = None if bias_1d is None else bias_1d.to(torch.float32)
+            return torch.ops.aten._dyn_quant_pack_4bit_weight(
+                uint8_nibbles, scales, bias, g if g != -1 else in_features,
+                in_features, out_features)
+
+        # Pack per expert
+        w13_packed_list = []
+        w2_packed_list = []
+
+        has_w13_bias = hasattr(layer,
+                               "w13_bias") and layer.w13_bias is not None
+        has_w2_bias = hasattr(layer, "w2_bias") and layer.w2_bias is not None
+
+        for e in range(E):
+            w13_packed_list.append(
+                _pack_matrix(
+                    layer.w13_weight[e],  # [2I, H]
+                    layer.w13_weight_scale[e],  # [2I, H/g or 1]
+                    layer.w13_bias[e] if has_w13_bias else None,  # [2I]
+                    H,
+                    I2))
+            w2_packed_list.append(
+                _pack_matrix(
+                    # w2 shape is [H, IN]; we need [out, in] == [H, IN].
+                    layer.w2_weight[e],  # [H, IN]
+                    layer.w2_weight_scale[e],  # [H, IN/g or 1]
+                    layer.w2_bias[e] if has_w2_bias else None,  # [H]
+                    IN,
+                    layer.w2_out_features  # in_features=IN, out_features=H
+                ))
+
+        # each packed tensor has identical shape per expert; stack on dim 0
+        w13_packed = torch.stack(w13_packed_list, dim=0)
+        w2_packed = torch.stack(w2_packed_list, dim=0)
+
+        replace_parameter(layer, "w13_weight_packed",
+                          torch.nn.Parameter(w13_packed, requires_grad=False))
+        replace_parameter(layer, "w2_weight_packed",
+                          torch.nn.Parameter(w2_packed, requires_grad=False))
+
+        # free raw tensors/scales/bias now that they're packed into the payload.
+        replace_parameter(
+            layer, "w13_weight",
+            torch.nn.Parameter(torch.empty(0), requires_grad=False))
+        replace_parameter(
+            layer, "w2_weight",
+            torch.nn.Parameter(torch.empty(0), requires_grad=False))
+        replace_parameter(
+            layer, "w13_weight_scale",
+            torch.nn.Parameter(torch.empty(0), requires_grad=False))
+        replace_parameter(
+            layer, "w2_weight_scale",
+            torch.nn.Parameter(torch.empty(0), requires_grad=False))
+        if has_w13_bias:
+            replace_parameter(
+                layer, "w13_bias",
+                torch.nn.Parameter(torch.empty(0), requires_grad=False))
+        if has_w2_bias:
+            replace_parameter(
+                layer, "w2_bias",
+                torch.nn.Parameter(torch.empty(0), requires_grad=False))
+
+    def get_fused_moe_quant_config(
+            self, layer: torch.nn.Module) -> Optional[FusedMoEQuantConfig]:
+        # CPU dynamic 4-bit MoE path does not use modular kernels or
+        # fused_experts; quant config is not needed.
+        return None
+
+    def apply(
+        self,
+        layer: torch.nn.Module,
+        x: torch.Tensor,
+        router_logits: torch.Tensor,
+        top_k: int,
+        renormalize: bool,
+        use_grouped_topk: bool = False,
+        topk_group: Optional[int] = None,
+        num_expert_group: Optional[int] = None,
+        global_num_experts: int = -1,
+        expert_map: Optional[torch.Tensor] = None,
+        custom_routing_function: Optional[Callable] = None,
+        scoring_func: str = "softmax",
+        routed_scaling_factor: float = 1.0,
+        e_score_correction_bias: Optional[torch.Tensor] = None,
+        apply_router_weight_on_input: bool = False,
+        activation: str = "silu",
+        enable_eplb: bool = False,
+        expert_load_view: Optional[torch.Tensor] = None,
+        logical_to_physical_map: Optional[torch.Tensor] = None,
+        logical_replica_count: Optional[torch.Tensor] = None,
+    ) -> torch.Tensor:
+        assert not enable_eplb, "EPLB not supported for W4A8-int MoE yet."
+        assert activation in (
+            "silu", "swigluoai",
+            "swiglu"), "Only SiLU/SwiGLUGU/SwiGLUUG are supported."
+        assert expert_map is None, """expert_map/EP not implemented
+            for CPU dyn-4bit MoE."""
+
+        def _act_kind(s: str) -> int:
+            # 0 = SwiGLU_Gu (SiLU(g)*u), 1 = SwiGLU_Ug (SiLU(u)*g), 2 = SiLU
+            if s == "swiglu":
+                return 0
+            if s == "swigluoai":
+                return 1
+            if s == "silu":
+                return 2
+            raise ValueError(f"Unknown activation '{s}'")
+
+        # Apply topk softmax on router output
+        topk_weights, topk_ids = select_experts(
+            hidden_states=x,
+            router_logits=router_logits,
+            use_grouped_topk=use_grouped_topk,
+            top_k=top_k,
+            renormalize=renormalize,
+            topk_group=topk_group,
+            num_expert_group=num_expert_group,
+            custom_routing_function=custom_routing_function,
+            scoring_func=scoring_func,
+            routed_scaling_factor=routed_scaling_factor,
+            e_score_correction_bias=e_score_correction_bias,
+        )
+
+        return torch.ops._C.dynamic_4bit_int_moe(
+            x, topk_ids.to(torch.long), topk_weights, layer.w13_weight_packed,
+            layer.w2_weight_packed, layer.w2_out_features,
+            layer.w2_in_features, layer.w13_out_features, layer.group_size,
+            apply_router_weight_on_input, int(_act_kind(activation)))
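For reference, the nibble layout used by _pack_matrix stores two int4 weights per byte: values arrive as int8 in [-8, 7], are shifted to unsigned nibbles in [0, 15], and adjacent input columns are paired with the odd column in the high nibble. A small round-trip check of that layout (illustrative only, not part of the diff; it widens to int16 before shifting so the intermediate cannot overflow):

import torch

w = torch.randint(-8, 8, (4, 6), dtype=torch.int8)   # [out_features, in_features]

# Pack: shift to [0, 15]; even columns -> low nibble, odd columns -> high nibble.
tmp = w.to(torch.int16) + 8
packed = ((tmp[:, 1::2] << 4) | tmp[:, ::2]).to(torch.uint8)   # [out, in // 2]

# Unpack and undo the +8 shift to recover the signed int4 values.
low = (packed & 0x0F).to(torch.int8) - 8
high = ((packed >> 4) & 0x0F).to(torch.int8) - 8
restored = torch.stack((low, high), dim=-1).reshape(w.shape)

assert torch.equal(restored, w)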