# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import functools

import torch
import torch.nn as nn
from transformers import PretrainedConfig

from vllm import envs
from vllm.config.lora import LoRAConfig
from vllm.distributed.parallel_state import (
    get_tensor_model_parallel_rank,
    get_tensor_model_parallel_world_size,
)
from vllm.distributed.utils import divide
from vllm.lora.layers.base import BaseLayerWithLoRA
from vllm.lora.ops.triton_ops.utils import get_lora_op_configs
from vllm.model_executor.layers.fused_moe import FusedMoE
from vllm.model_executor.layers.fused_moe.config import (
    _get_config_dtype_str,
)
from vllm.model_executor.layers.fused_moe.fused_marlin_moe import (
    modular_marlin_fused_moe,
)
from vllm.model_executor.layers.fused_moe.fused_moe import (
    modular_triton_fused_moe,
    try_get_optimal_moe_config,
)
from vllm.model_executor.layers.fused_moe.fused_moe_modular_method import (
    FusedMoEModularMethod,
)

from .utils import _get_lora_device


class FusedMoEWithLoRA(BaseLayerWithLoRA):
    """LoRA wrapper for `FusedMoE` that injects the LoRA shrink/expand kernels
    into the fused-MoE forward path."""

    def __init__(self, base_layer: FusedMoE) -> None:
        super().__init__()
        self.base_layer = base_layer

        assert not self.base_layer.use_ep, (
            "EP support for Fused MoE LoRA is not implemented yet."
        )
        self.tp_size = get_tensor_model_parallel_world_size()
        self.tp_rank = get_tensor_model_parallel_rank()
        self.device = _get_lora_device(base_layer)
        self._w13_slices = 2
        self._inject_lora_into_fused_moe()

    def _normalize_keys(self, config: dict[str, int | None]) -> dict[str, int | None]:
        """Normalize config keys to the upper-case names used by the kernels
        (e.g. ``block_m`` -> ``BLOCK_SIZE_M``)."""
        normalized_config = {}
        for key, value in config.items():
            if key.islower():
                if key.startswith("block_"):
                    normalized_key = "BLOCK_SIZE_" + key.split("_")[-1].upper()
                else:
                    normalized_key = key.upper()
            else:
                normalized_key = key
            normalized_config[normalized_key] = value
        return normalized_config

    def _get_lora_moe_configs(
        self,
        op_prefix: str,
        num_loras: int,
        rank: int,
        num_slices: int,
        M: int,
        layer: FusedMoE,
        top_k: int,
        config_dtype: str,
    ):
        """Return normalized (shrink, expand) kernel configs, preferring tuned
        configs from VLLM_TUNED_CONFIG_FOLDER and falling back to the default
        MoE config otherwise."""
        if envs.VLLM_TUNED_CONFIG_FOLDER:
            hidden_size = layer.hidden_size
            intermediate_size = layer.intermediate_size_per_partition
            shrink_config = get_lora_op_configs(
                op_type=f"fused_moe_lora_{op_prefix}_shrink",
                max_loras=num_loras,
                batch=M,
                hidden_size=hidden_size,
                rank=rank,
                num_slices=num_slices,
                moe_intermediate_size=intermediate_size,
            )
            expand_config = get_lora_op_configs(
                op_type=f"fused_moe_lora_{op_prefix}_expand",
                max_loras=num_loras,
                batch=M,
                hidden_size=hidden_size,  # lora_a_stacked.shape[-1]
                rank=rank,
                num_slices=num_slices,
                moe_intermediate_size=intermediate_size,  # lora_b_stacked.shape[-2]
            )
        else:  # fall back to the default config
            get_config_func = functools.partial(
                try_get_optimal_moe_config,
                layer.w13_weight.size(),
                layer.w2_weight.size(),
                top_k,
                config_dtype,
                block_shape=layer.quant_method.moe_quant_config.block_shape,
            )
            shrink_config = get_config_func(M)
            expand_config = get_config_func(M)
        shrink_config = self._normalize_keys(shrink_config)
        expand_config = self._normalize_keys(expand_config)
        return shrink_config, expand_config

    def _inject_lora_into_fused_moe(self):
        """Wrap the modular fused-MoE kernels so that the LoRA contributions
        for the w13 (gate/up) and w2 (down) projections are applied during the
        fused-experts computation."""
        moe_state_dict = {}
        top_k = self.base_layer.top_k

        self.base_layer.ensure_moe_quant_config_init()
        quant_config = self.base_layer.quant_method.moe_quant_config

        m_fused_moe_fn = (
            modular_triton_fused_moe(
                quant_config, shared_experts=self.base_layer.shared_experts
            )
            if not quant_config.use_mxfp4_w4a16
            else modular_marlin_fused_moe(
                quant_config, shared_experts=self.base_layer.shared_experts
            )
        )

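        # The decorators below close over `moe_state_dict` to pass state
        # between hooks: `fwd_decorator` records the routing tensors,
        # `act_decorator` adds the w13 (gate/up) LoRA contribution around the
        # activation, and `moe_sum_decorator` adds the w2 (down) LoRA
        # contribution before the final top-k reduction.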
        def fwd_decorator(layer, func):
            def wrapper(*args, **kwargs):
                moe_state_dict["hidden_states"] = kwargs["hidden_states"]
                moe_state_dict["topk_ids"] = kwargs["topk_ids"]
                moe_state_dict["topk_weights"] = kwargs["topk_weights"]
                moe_state_dict["expert_map"] = kwargs["expert_map"]
                moe_state_dict["apply_router_weight_on_input"] = kwargs[
                    "apply_router_weight_on_input"
                ]
                result = func(*args, **kwargs)
                return result

            return wrapper

        def act_decorator(layer, func):
            def wrapper(*args, **kwargs):
                _, output, input = args

                hidden_states = moe_state_dict["hidden_states"]
                topk_weights = moe_state_dict["topk_weights"]
                curr_topk_ids = moe_state_dict["topk_ids"]

                expert_map = moe_state_dict["expert_map"]

                config_dtype = _get_config_dtype_str(
                    dtype=hidden_states.dtype,
                    use_fp8_w8a8=False,
                    use_int8_w8a16=False,
                    use_int4_w4a16=False,
                )
                CHUNK_SIZE = envs.VLLM_FUSED_MOE_CHUNK_SIZE
                num_tokens = hidden_states.size(0)
                M = min(num_tokens, CHUNK_SIZE)
                max_lora_rank = self.w13_lora_a_stacked[0].shape[-2]
                shrink_config, expand_config = self._get_lora_moe_configs(
                    op_prefix="w13",
                    num_loras=self.max_loras,
                    rank=max_lora_rank,
                    num_slices=self._w13_slices,
                    M=M,
                    layer=layer,
                    top_k=top_k,
                    config_dtype=config_dtype,
                )

                # Get BLOCK_SIZE_M from the tuned config or the default config.
                (
                    sorted_token_ids_lora,
                    expert_ids_lora,
                    num_tokens_post_padded_lora,
                ) = self.punica_wrapper.moe_lora_align_block_size(
                    curr_topk_ids,
                    num_tokens,
                    shrink_config["BLOCK_SIZE_M"],
                    self.base_layer.local_num_experts,
                    self.max_loras,
                    self.adapter_enabled,
                    expert_map,
                )

                moe_state_dict["sorted_token_ids_lora"] = sorted_token_ids_lora
                moe_state_dict["expert_ids_lora"] = expert_ids_lora
                moe_state_dict["num_tokens_post_padded_lora"] = (
                    num_tokens_post_padded_lora
                )

                expert_ids_lora = expert_ids_lora.view(self.max_loras, -1)
                sorted_token_ids_lora = sorted_token_ids_lora.view(self.max_loras, -1)

                self.punica_wrapper.add_lora_fused_moe(
                    input.view(-1, top_k, input.shape[-1]),
                    hidden_states,
                    self.w13_lora_a_stacked,
                    self.w13_lora_b_stacked,
                    topk_weights,
                    sorted_token_ids_lora,
                    expert_ids_lora,
                    num_tokens_post_padded_lora,
                    max_lora_rank,
                    top_k,
                    shrink_config,  # pass the shrink config
                    expand_config,  # pass the expand config
                    self.adapter_enabled,
                    fully_sharded=self.fully_sharded,
                )

                result = func(*args, **kwargs)

                moe_state_dict["intermediate_cache2"] = output
                return result

            return wrapper

        def moe_sum_decorator(layer, func):
            def wrapper(*args, **kwargs):
                hidden_states = moe_state_dict["hidden_states"]
                topk_weights = moe_state_dict["topk_weights"]

                config_dtype = _get_config_dtype_str(
                    dtype=hidden_states.dtype,
                    use_fp8_w8a8=False,
                    use_int8_w8a16=False,
                    use_int4_w4a16=False,
                )
                CHUNK_SIZE = envs.VLLM_FUSED_MOE_CHUNK_SIZE
                num_tokens = hidden_states.size(0)
                M = min(num_tokens, CHUNK_SIZE)
                max_lora_rank = self.w2_lora_a_stacked[0].shape[-2]
                shrink_config, expand_config = self._get_lora_moe_configs(
                    op_prefix="w2",
                    num_loras=self.max_loras,
                    rank=max_lora_rank,
                    num_slices=1,
                    M=M,
                    layer=layer,
                    top_k=top_k,
                    config_dtype=config_dtype,
                )

                sorted_token_ids_lora = moe_state_dict["sorted_token_ids_lora"]
                expert_ids_lora = moe_state_dict["expert_ids_lora"]
                num_tokens_post_padded_lora = moe_state_dict[
                    "num_tokens_post_padded_lora"
                ]

                expert_ids_lora = expert_ids_lora.view(self.max_loras, -1)
                sorted_token_ids_lora = sorted_token_ids_lora.view(self.max_loras, -1)
                intermediate_cache2 = moe_state_dict["intermediate_cache2"]
                intermediate_cache3 = args[0]

                shard_size_w2 = divide(self.base_layer.hidden_size, self.tp_size)

                self.punica_wrapper.add_lora_fused_moe(
                    intermediate_cache3,
                    intermediate_cache2,
                    self.w2_lora_a_stacked,
                    self.w2_lora_b_stacked,
                    topk_weights,
                    sorted_token_ids_lora,
                    expert_ids_lora,
                    num_tokens_post_padded_lora,
                    max_lora_rank,
                    top_k,
                    shrink_config,  # pass the shrink config
                    expand_config,  # pass the expand config
                    self.adapter_enabled,
                    True,
                    fully_sharded=self.fully_sharded,
                    offset=shard_size_w2 * self.tp_rank if self.fully_sharded else 0,
                )

                result = func(*args, **kwargs)
                return result

            return wrapper

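        # Patch the modular kernel's entry points with the LoRA-aware wrappers
        # and re-register the quant method as a FusedMoEModularMethod.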
        fused_experts = m_fused_moe_fn.fused_experts

        m_fused_moe_fn.forward = fwd_decorator(self.base_layer, m_fused_moe_fn.forward)
        fused_experts.activation = act_decorator(
            self.base_layer, fused_experts.activation
        )
        fused_experts.moe_sum = moe_sum_decorator(
            self.base_layer, fused_experts.moe_sum
        )
        self.base_layer.quant_method = FusedMoEModularMethod(
            self.base_layer.quant_method, m_fused_moe_fn
        )

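    # LoRA weights are stacked per (lora slot, local expert). The w13 (gate/up)
    # projection keeps one stacked tensor per slice, while the w2 (down)
    # projection uses a single stacked tensor.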
    def _create_lora_a_weights(
        self,
        max_loras: int,
        lora_config: LoRAConfig,
    ):
        self.w13_lora_a_stacked: tuple[torch.Tensor, ...] = tuple(
            torch.zeros(
                (
                    max_loras,
                    self.base_layer.local_num_experts,
                    lora_config.max_lora_rank
                    if not self.fully_sharded
                    else divide(lora_config.max_lora_rank, self.tp_size),
                    self.base_layer.hidden_size,
                ),
                dtype=lora_config.lora_dtype,
                device=self.device,
            )
            for _ in range(self._w13_slices)
        )
        self.w2_lora_a_stacked: tuple[torch.Tensor, ...] = (
            torch.zeros(
                (
                    max_loras,
                    self.base_layer.local_num_experts,
                    lora_config.max_lora_rank,
                    self.base_layer.intermediate_size_per_partition,
                ),
                dtype=lora_config.lora_dtype,
                device=self.device,
            ),
        )

    def _create_lora_b_weights(self, max_loras: int, lora_config: LoRAConfig):
        self.w13_lora_b_stacked: tuple[torch.Tensor, ...] = tuple(
            torch.zeros(
                (
                    max_loras,
                    self.base_layer.local_num_experts,
                    self.base_layer.intermediate_size_per_partition,
                    lora_config.max_lora_rank,
                ),
                dtype=lora_config.lora_dtype,
                device=self.device,
            )
            for _ in range(self._w13_slices)
        )
        self.w2_lora_b_stacked: tuple[torch.Tensor, ...] = (
            torch.zeros(
                (
                    max_loras,
                    self.base_layer.local_num_experts,
                    self.base_layer.hidden_size
                    if not self.fully_sharded
                    else divide(self.base_layer.hidden_size, self.tp_size),
                    lora_config.max_lora_rank,
                ),
                dtype=lora_config.lora_dtype,
                device=self.device,
            ),
        )

    def create_lora_weights(
        self,
        max_loras: int,
        lora_config: LoRAConfig,
        model_config: PretrainedConfig | None = None,
    ) -> None:
        """Initializes lora matrices."""
        self.max_loras = lora_config.max_loras
        self.fully_sharded = lora_config.fully_sharded_loras

        self.adapter_enabled = torch.tensor(
            [0] * (max_loras + 1), dtype=torch.int, device=self.device
        )

        self._create_lora_a_weights(max_loras, lora_config)
        self._create_lora_b_weights(max_loras, lora_config)
        # These lists are used by 'LoRALayerWeights.create_dummy_lora_weights'
        # to create dummy LoRA weights.
        # TODO: Optimize this section
        self.lora_a_stacked = []
        self.lora_b_stacked = []
        for lora_id in range(max_loras):
            for experts_id in range(self.base_layer.local_num_experts):
                # gate_proj, down_proj, up_proj
                self.lora_a_stacked.append(
                    self.w13_lora_a_stacked[0][lora_id][experts_id]
                )
                self.lora_a_stacked.append(
                    self.w2_lora_a_stacked[0][lora_id][experts_id]
                )

                self.lora_b_stacked.append(
                    self.w13_lora_b_stacked[0][lora_id][experts_id]
                )
                self.lora_b_stacked.append(
                    self.w2_lora_b_stacked[0][lora_id][experts_id]
                )

                self.lora_a_stacked.append(
                    self.w13_lora_a_stacked[1][lora_id][experts_id]
                )
                self.lora_b_stacked.append(
                    self.w13_lora_b_stacked[1][lora_id][experts_id]
                )

    def reset_lora(self, index: int):
        """Resets the lora weights at index back to 0."""
        for pos in range(self._w13_slices):
            self.w13_lora_a_stacked[pos][index] = 0
            self.w13_lora_b_stacked[pos][index] = 0

        self.w2_lora_a_stacked[0][index] = 0
        self.w2_lora_b_stacked[0][index] = 0
        self.adapter_enabled[index] = 0

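    # `lora_a` / `lora_b` arrive as flat lists holding three matrices per
    # expert (w1/gate, w2/down, w3/up); entries may be None for experts that
    # have no LoRA weights.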
    def set_lora(
        self,
        index: int,
        lora_a: torch.Tensor | list[torch.Tensor],
        lora_b: torch.Tensor | list[torch.Tensor],
    ):
        """Overwrites lora tensors at index."""
        assert isinstance(lora_a, list)
        assert isinstance(lora_b, list)
        self.reset_lora(index)
        self.adapter_enabled[index] = 1
        for eid in range(len(lora_a) // 3):
            w1_lora_a = lora_a[eid * 3]
            w2_lora_a = lora_a[eid * 3 + 1]
            w3_lora_a = lora_a[eid * 3 + 2]
            w1_lora_b = lora_b[eid * 3]
            w2_lora_b = lora_b[eid * 3 + 1]
            w3_lora_b = lora_b[eid * 3 + 2]

            # Handle the case of adding LoRA to only a subset of experts
            if w1_lora_a is None or w2_lora_a is None or w3_lora_a is None:
                continue

            if self.tp_size > 1:
                shard_size = self.base_layer.intermediate_size_per_partition
                start_idx = self.tp_rank * shard_size
                end_idx = (self.tp_rank + 1) * shard_size

                w1_lora_b = w1_lora_b[start_idx:end_idx, :]
                w3_lora_b = w3_lora_b[start_idx:end_idx, :]
                w2_lora_a = w2_lora_a[:, start_idx:end_idx]

                if self.fully_sharded:
                    # Based on S-LoRA, we slice W1 and W3 A along the rank dim,
                    # and W2 B along the hidden_size dim.
                    w13_shard_size = self.w13_lora_a_stacked[0][index, eid].shape[0]
                    w13_start_idx = self.tp_rank * w13_shard_size
                    w13_end_idx = (self.tp_rank + 1) * w13_shard_size
                    w1_lora_a = w1_lora_a[w13_start_idx:w13_end_idx, :]
                    w3_lora_a = w3_lora_a[w13_start_idx:w13_end_idx, :]

                    w2_shard_size = self.w2_lora_b_stacked[0][index, eid].shape[0]
                    w2_start_idx = self.tp_rank * w2_shard_size
                    w2_end_idx = (self.tp_rank + 1) * w2_shard_size
                    w2_lora_b = w2_lora_b[w2_start_idx:w2_end_idx, :]
            # w1 lora_a
            self.w13_lora_a_stacked[0][
                index, eid, : w1_lora_a.shape[0], : w1_lora_a.shape[1]
            ].copy_(w1_lora_a, non_blocking=True)
            # w3 lora_a
            self.w13_lora_a_stacked[1][
                index, eid, : w3_lora_a.shape[0], : w3_lora_a.shape[1]
            ].copy_(w3_lora_a, non_blocking=True)

            # w1 lora_b
            self.w13_lora_b_stacked[0][
                index, eid, : w1_lora_b.shape[0], : w1_lora_b.shape[1]
            ].copy_(w1_lora_b, non_blocking=True)
            # w3 lora_b
            self.w13_lora_b_stacked[1][
                index, eid, : w3_lora_b.shape[0], : w3_lora_b.shape[1]
            ].copy_(w3_lora_b, non_blocking=True)

            self.w2_lora_a_stacked[0][
                index, eid, : w2_lora_a.shape[0], : w2_lora_a.shape[1]
            ].copy_(w2_lora_a, non_blocking=True)

            self.w2_lora_b_stacked[0][
                index, eid, : w2_lora_b.shape[0], : w2_lora_b.shape[1]
            ].copy_(w2_lora_b, non_blocking=True)

    def forward(self, *args, **kwargs):
        return self.base_layer.forward(*args, **kwargs)

    def maybe_all_reduce_tensor_model_parallel(self, *args, **kwargs):
        return self.base_layer.maybe_all_reduce_tensor_model_parallel(*args, **kwargs)

    @property
    def _shared_experts(self):
        return self.base_layer._shared_experts

    @property
    def quant_method(self):
        return self.base_layer.quant_method

    @property
    def is_internal_router(self) -> bool:
        return self.base_layer.is_internal_router

    @classmethod
    def can_replace_layer(
        cls,
        source_layer: nn.Module,
        lora_config: LoRAConfig,
        packed_modules_list: list,
        model_config: PretrainedConfig | None,
    ) -> bool:
        """Returns True if the layer can be replaced by this LoRA layer."""
        return type(source_layer) is FusedMoE and len(packed_modules_list) == 2


class FusedMoE3DWithLoRA(FusedMoEWithLoRA):
    """Variant of `FusedMoEWithLoRA` for adapters whose MoE LoRA weights are
    stored as single 3D tensors (one w13 tensor and one w2 tensor covering all
    experts) instead of per-expert w1/w2/w3 matrices."""

    def __init__(self, base_layer):
        super().__init__(base_layer)
        self._w13_slices = 1

    def _create_lora_b_weights(self, max_loras, lora_config):
        self.w13_lora_b_stacked: tuple[torch.Tensor] = tuple(
            torch.zeros(
                (
                    max_loras,
                    self.base_layer.local_num_experts,
                    self.base_layer.intermediate_size_per_partition * 2,
                    lora_config.max_lora_rank,
                ),
                dtype=lora_config.lora_dtype,
                device=self.device,
            )
            for _ in range(self._w13_slices)
        )
        self.w2_lora_b_stacked: tuple[torch.Tensor] = (
            torch.zeros(
                (
                    max_loras,
                    self.base_layer.local_num_experts,
                    self.base_layer.hidden_size
                    if not self.fully_sharded
                    else divide(self.base_layer.hidden_size, self.tp_size),
                    lora_config.max_lora_rank,
                ),
                dtype=lora_config.lora_dtype,
                device=self.device,
            ),
        )

    def create_lora_weights(
        self,
        max_loras: int,
        lora_config: LoRAConfig,
        model_config: PretrainedConfig | None = None,
    ) -> None:
        """Initializes lora matrices."""
        self.max_loras = lora_config.max_loras
        self.fully_sharded = lora_config.fully_sharded_loras

        self.adapter_enabled = torch.tensor(
            [0] * (max_loras + 1), dtype=torch.int, device=self.device
        )

        self._create_lora_a_weights(max_loras, lora_config)
        self._create_lora_b_weights(max_loras, lora_config)

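    # Slicing helpers that select the TP-local partition of each LoRA tensor.
    # With fully sharded LoRA, w13 A is additionally split along the rank dim
    # and w2 B along the hidden_size dim (following S-LoRA).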
    def _slice_w13_a(self, w13_lora_a: torch.Tensor) -> torch.Tensor:
        if self.tp_size == 1 or not self.fully_sharded:
            return w13_lora_a

        # w13_lora_a shape: (num_experts, rank, input_size)
        current_lora_rank = w13_lora_a.shape[1]
        assert current_lora_rank % self.tp_size == 0

        sliced_rank = current_lora_rank // self.tp_size
        start_idx = self.tp_rank * sliced_rank
        end_idx = (self.tp_rank + 1) * sliced_rank
        return w13_lora_a[:, start_idx:end_idx, :]

    def _slice_w13_b(self, w13_lora_b: torch.Tensor, is_interleave: bool = True):
        if self.tp_size == 1:
            return w13_lora_b

        # w13_lora_b shape: (num_experts, output_size, rank)
        shard_size = self.base_layer.intermediate_size_per_partition
        start_idx = self.tp_rank * shard_size
        end_idx = (self.tp_rank + 1) * shard_size
        if is_interleave:
            # For models like GPT-OSS, the w1 (gate_proj) and w3 (up_proj)
            # weights are stored in interleaved order, so the corresponding
            # LoRA weights need to be de-interleaved before slicing.
            w1_lora_b = w13_lora_b[:, ::2, :]
            w3_lora_b = w13_lora_b[:, 1::2, :]
            sliced_w1_lora_b = w1_lora_b[:, start_idx:end_idx, :]
            sliced_w3_lora_b = w3_lora_b[:, start_idx:end_idx, :]

            return torch.stack([sliced_w1_lora_b, sliced_w3_lora_b], dim=2).flatten(
                1, 2
            )
        else:
            slice_size = w13_lora_b.shape[1] // 2
            w1_lora_b = w13_lora_b[:, :slice_size, :]
            w3_lora_b = w13_lora_b[:, slice_size:, :]
            sliced_w1_lora_b = w1_lora_b[:, start_idx:end_idx, :]
            sliced_w3_lora_b = w3_lora_b[:, start_idx:end_idx, :]

            return torch.cat([sliced_w1_lora_b, sliced_w3_lora_b], dim=1)

    def _slice_w2_a(self, w2_lora_a: torch.Tensor) -> torch.Tensor:
        if self.tp_size == 1:
            return w2_lora_a
        # w2_lora_a shape: (num_experts, rank, input_size)
        shard_size = self.base_layer.intermediate_size_per_partition
        start_idx = self.tp_rank * shard_size
        end_idx = (self.tp_rank + 1) * shard_size

        return w2_lora_a[:, :, start_idx:end_idx]

    def _slice_w2_b(self, w2_lora_b: torch.Tensor) -> torch.Tensor:
        if self.tp_size == 1 or not self.fully_sharded:
            return w2_lora_b
        # Based on S-LoRA, we slice W2 B along the hidden_size dim.
        # w2_lora_b shape: (num_experts, output_size, rank)
        current_lora_size = w2_lora_b.shape[1]

        sliced_size = current_lora_size // self.tp_size
        start_idx = self.tp_rank * sliced_size
        end_idx = (self.tp_rank + 1) * sliced_size
        return w2_lora_b[:, start_idx:end_idx, :]

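    # Unlike the base class, `lora_a` / `lora_b` each hold exactly two tensors
    # (w13 and w2) covering all experts at once; they are reshaped and sliced
    # to the local TP partition below.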
    def set_lora(
        self,
        index: int,
        lora_a: torch.Tensor | list[torch.Tensor],
        lora_b: torch.Tensor | list[torch.Tensor],
    ):
        """Overwrites lora tensors at index."""
        # Make mypy happy
        assert isinstance(lora_a, list)
        assert isinstance(lora_b, list)
        assert len(lora_a) == len(lora_b) == 2

        self.reset_lora(index)
        self.adapter_enabled[index] = 1

        num_experts = self.w13_lora_a_stacked[0].shape[1]
        w13_lora_a, w2_lora_a = lora_a
        w13_lora_b, w2_lora_b = lora_b

        # (num_experts, rank, input_size)
        w13_lora_a = w13_lora_a.reshape(num_experts, -1, w13_lora_a.shape[-1])
        w2_lora_a = w2_lora_a.reshape(num_experts, -1, w2_lora_a.shape[-1])
        # (output_size, num_experts, rank)
        w13_lora_b = w13_lora_b.reshape(w13_lora_b.shape[0], num_experts, -1)
        w2_lora_b = w2_lora_b.reshape(w2_lora_b.shape[0], num_experts, -1)
        # (num_experts, output_size, rank)
        w13_lora_b = w13_lora_b.permute(1, 0, 2)
        w2_lora_b = w2_lora_b.permute(1, 0, 2)

        sliced_w13_lora_a = self._slice_w13_a(w13_lora_a)
        sliced_w13_lora_b = self._slice_w13_b(w13_lora_b, is_interleave=True)

        sliced_w2_lora_a = self._slice_w2_a(w2_lora_a)
        sliced_w2_lora_b = self._slice_w2_b(w2_lora_b)

        self.w13_lora_a_stacked[0][
            index, :, : sliced_w13_lora_a.shape[1], : sliced_w13_lora_a.shape[2]
        ].copy_(sliced_w13_lora_a, non_blocking=True)
        self.w2_lora_a_stacked[0][
            index, :, : sliced_w2_lora_a.shape[1], : sliced_w2_lora_a.shape[2]
        ].copy_(sliced_w2_lora_a, non_blocking=True)

        self.w13_lora_b_stacked[0][
            index, :, : sliced_w13_lora_b.shape[1], : sliced_w13_lora_b.shape[2]
        ].copy_(sliced_w13_lora_b, non_blocking=True)
        self.w2_lora_b_stacked[0][
            index, :, : sliced_w2_lora_b.shape[1], : sliced_w2_lora_b.shape[2]
        ].copy_(sliced_w2_lora_b, non_blocking=True)

    @property
    def w13_input_size(self):
        """Full size."""
        return self.w13_lora_a_stacked[0].shape[-1]

    @property
    def w13_output_size(self):
        """Full size."""
        return self.w13_lora_b_stacked[0].shape[-2] * self.tp_size

    @property
    def w2_input_size(self):
        """Full size."""
        return self.w2_lora_a_stacked[0].shape[-1] * self.tp_size

    @property
    def w2_output_size(self):
        """Full size."""
        return self.w2_lora_a_stacked[0].shape[-2]

    @classmethod
    def can_replace_layer(
        cls,
        source_layer: nn.Module,
        lora_config: LoRAConfig,
        packed_modules_list: list,
        model_config: PretrainedConfig | None,
    ) -> bool:
        """Returns True if the layer can be replaced by this LoRA layer."""
        return type(source_layer) is FusedMoE and len(packed_modules_list) == 1