[Bugfix] Fix moe weight losing all extra attrs after process_weights_after_loading. (#16854)

Signed-off-by: charlifu <charlifu@amd.com>
This commit is contained in:
Charlie Fu 2025-04-28 16:05:07 -05:00 committed by GitHub
parent cc5befbced
commit ed2462030f
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 6 additions and 11 deletions

View File

@ -113,12 +113,9 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
def process_weights_after_loading(self, layer: torch.nn.Module) -> None: def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
super().process_weights_after_loading(layer) super().process_weights_after_loading(layer)
layer.w13_weight = torch.nn.Parameter(self._maybe_pad_weight( # Padding the weight for better performance on ROCm
layer.w13_weight.data), layer.w13_weight.data = self._maybe_pad_weight(layer.w13_weight.data)
requires_grad=False) layer.w2_weight.data = self._maybe_pad_weight(layer.w2_weight.data)
layer.w2_weight = torch.nn.Parameter(self._maybe_pad_weight(
layer.w2_weight.data),
requires_grad=False)
# Lazy import to avoid importing triton. # Lazy import to avoid importing triton.
from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import ( from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (
is_rocm_aiter_moe_enabled, shuffle_weights) is_rocm_aiter_moe_enabled, shuffle_weights)
@ -127,10 +124,8 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
shuffled_w13, shuffled_w2 = shuffle_weights( shuffled_w13, shuffled_w2 = shuffle_weights(
layer.w13_weight.data, layer.w2_weight.data) layer.w13_weight.data, layer.w2_weight.data)
layer.w13_weight = torch.nn.Parameter(shuffled_w13, layer.w13_weight.data = shuffled_w13
requires_grad=False) layer.w2_weight.data = shuffled_w2
layer.w2_weight = torch.nn.Parameter(shuffled_w2,
requires_grad=False)
if current_platform.is_cpu(): if current_platform.is_cpu():
if current_platform.get_cpu_architecture() == CpuArchEnum.X86: if current_platform.get_cpu_architecture() == CpuArchEnum.X86:

View File

@ -156,7 +156,7 @@ def rocm_per_tensor_w8a8_scaled_mm(*, qinput: torch.Tensor,
input_2d: torch.Tensor, input_2d: torch.Tensor,
output_shape: List) -> torch.Tensor: output_shape: List) -> torch.Tensor:
from vllm.platforms.rocm import on_mi250_mi300 from vllm.platforms.rocm import on_mi250_mi300
if envs.VLLM_ROCM_USE_SKINNY_GEMM and not on_mi250_mi300( if envs.VLLM_ROCM_USE_SKINNY_GEMM and on_mi250_mi300(
) and qinput.shape[0] == 1 and qinput.shape[1] % 16 == 0: ) and qinput.shape[0] == 1 and qinput.shape[1] % 16 == 0:
output = ops.wvSplitKQ(weight.t(), qinput, out_dtype, scale_a, scale_b, output = ops.wvSplitKQ(weight.t(), qinput, out_dtype, scale_a, scale_b,
current_platform.get_cu_count()) current_platform.get_cu_count())