Mirror of https://git.datalinker.icu/vllm-project/vllm.git (synced 2025-12-10 11:06:15 +08:00)
[LoRA] Cleanup FusedMoEWithLoRA (#29187)
Signed-off-by: Jee Jee Li <pandaleefree@gmail.com>
commit 1489902b53 (parent 933f67ecd8)
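At a high level, the diff below consolidates the separate w1_*/w3_* LoRA buffers of FusedMoEWithLoRA into per-slice tuples (w13_lora_a_stacked / w13_lora_b_stacked, indexed by a new self.w13_slices = 2), caches max_loras on the layer, passes explicit num_loras/rank into _get_lora_moe_configs instead of re-deriving them from tensor shapes, and annotates the Punica wrapper's stacked-weight arguments as tuples. A minimal sketch of the new storage layout and the loop-based reset, using made-up sizes rather than vLLM defaults:

```python
import torch

# Hypothetical sizes for illustration only (not vLLM defaults).
max_loras, num_experts, rank, hidden, inter = 2, 4, 8, 32, 64
w13_slices = 2  # slice 0 replaces the old w1_* buffers, slice 1 the old w3_*

w13_lora_a_stacked = tuple(
    torch.zeros(max_loras, num_experts, rank, hidden) for _ in range(w13_slices)
)
w13_lora_b_stacked = tuple(
    torch.zeros(max_loras, num_experts, inter, rank) for _ in range(w13_slices)
)

# reset_lora(index) now clears both slices in one loop instead of naming w1/w3:
index = 0
for pos in range(w13_slices):
    w13_lora_a_stacked[pos][index] = 0
    w13_lora_b_stacked[pos][index] = 0
```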
@@ -42,6 +42,7 @@ class FusedMoEWithLoRA(BaseLayerWithLoRA):
         self.tp_size = get_tensor_model_parallel_world_size()
         self.tp_rank = get_tensor_model_parallel_rank()
         self.device = base_layer.w2_weight.device
+        self.w13_slices = 2
         self._inject_lora_into_fused_moe()

     def _normalize_keys(self, config: dict[str, int | None]) -> dict[str, int | None]:
@@ -60,8 +61,8 @@ class FusedMoEWithLoRA(BaseLayerWithLoRA):
     def _get_lora_moe_configs(
         self,
         op_prefix: str,
-        lora_a_stacked: torch.Tensor,
-        lora_b_stacked: torch.Tensor,
+        num_loras: int,
+        rank: int,
         num_slices: int,
         M: int,
         layer: FusedMoE,
@@ -69,23 +70,25 @@ class FusedMoEWithLoRA(BaseLayerWithLoRA):
         config_dtype: str,
     ):
         if envs.VLLM_TUNED_CONFIG_FOLDER:
+            hidden_size = layer.hidden_size
+            intermediate_size = layer.intermediate_size_per_partition
             shrink_config = get_lora_op_configs(
                 op_type=f"fused_moe_lora_{op_prefix}_shrink",
-                max_loras=lora_a_stacked.shape[0],
+                max_loras=num_loras,
                 batch=M,
-                hidden_size=lora_a_stacked.shape[-1],
-                rank=lora_a_stacked.shape[-2],
+                hidden_size=hidden_size,
+                rank=rank,
                 num_slices=num_slices,
-                moe_intermediate_size=lora_b_stacked.shape[-2],
+                moe_intermediate_size=intermediate_size,
             )
             expand_config = get_lora_op_configs(
                 op_type=f"fused_moe_lora_{op_prefix}_expand",
-                max_loras=lora_a_stacked.shape[0],
+                max_loras=num_loras,
                 batch=M,
-                hidden_size=lora_a_stacked.shape[-1],
-                rank=lora_a_stacked.shape[-2],
+                hidden_size=hidden_size,  # lora_a_stacked.shape[-1],
+                rank=rank,
                 num_slices=num_slices,
-                moe_intermediate_size=lora_b_stacked.shape[-2],
+                moe_intermediate_size=intermediate_size,  # lora_b_stacked.shape[-2],
             )
         else:  # fall back to the default config
             get_config_func = functools.partial(
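The hunks above replace the tensor-derived sizes in `_get_lora_moe_configs` with explicit arguments and values read from the layer. A hypothetical check, with made-up sizes, of how the old shape lookups map onto the new parameters, assuming the stacked-buffer layout used later in this diff:

```python
import torch

# Made-up example sizes; layout: lora_a (max_loras, experts, rank, hidden),
# lora_b (max_loras, experts, intermediate, rank).
max_loras, num_experts, rank, hidden, inter = 2, 4, 8, 32, 64
lora_a_stacked = torch.zeros(max_loras, num_experts, rank, hidden)
lora_b_stacked = torch.zeros(max_loras, num_experts, inter, rank)

assert lora_a_stacked.shape[0] == max_loras  # old max_loras=...  -> num_loras
assert lora_a_stacked.shape[-2] == rank      # old rank=...       -> rank
assert lora_a_stacked.shape[-1] == hidden    # old hidden_size=... -> layer.hidden_size
assert lora_b_stacked.shape[-2] == inter     # old moe_intermediate_size=...
                                             # -> layer.intermediate_size_per_partition
```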
@@ -152,12 +155,12 @@ class FusedMoEWithLoRA(BaseLayerWithLoRA):
             CHUNK_SIZE = envs.VLLM_FUSED_MOE_CHUNK_SIZE
             num_tokens = hidden_states.size(0)
             M = min(num_tokens, CHUNK_SIZE)
+            max_lora_rank = self.w13_lora_a_stacked[0].shape[-2]
             shrink_config, expand_config = self._get_lora_moe_configs(
                 op_prefix="w13",
-                lora_a_stacked=self.w1_lora_a_stacked,
-                lora_b_stacked=self.w1_lora_b_stacked,
-                num_slices=2,
+                num_loras=self.max_loras,
+                rank=max_lora_rank,
+                num_slices=self.w13_slices,
                 M=M,
                 layer=layer,
                 top_k=top_k,
@@ -165,7 +168,6 @@ class FusedMoEWithLoRA(BaseLayerWithLoRA):
             )

             # get the block size of m from customized config or default config
-            max_loras = self.w1_lora_a_stacked.shape[0]
             (
                 sorted_token_ids_lora,
                 expert_ids_lora,
@@ -175,7 +177,7 @@ class FusedMoEWithLoRA(BaseLayerWithLoRA):
                 num_tokens,
                 shrink_config["BLOCK_SIZE_M"],
                 self.base_layer.local_num_experts,
-                max_loras,
+                self.max_loras,
                 self.adapter_enabled,
                 expert_map,
             )
@@ -186,17 +188,15 @@ class FusedMoEWithLoRA(BaseLayerWithLoRA):
                 num_tokens_post_padded_lora
             )

-            w13_lora_a_stacked = [self.w1_lora_a_stacked, self.w3_lora_a_stacked]
-            w13_lora_b_stacked = [self.w1_lora_b_stacked, self.w3_lora_b_stacked]
-            max_lora_rank = self.w1_lora_a_stacked.shape[-2]
-            expert_ids_lora = expert_ids_lora.view(max_loras, -1)
-            sorted_token_ids_lora = sorted_token_ids_lora.view(max_loras, -1)
+            expert_ids_lora = expert_ids_lora.view(self.max_loras, -1)
+            sorted_token_ids_lora = sorted_token_ids_lora.view(self.max_loras, -1)
+            #

             self.punica_wrapper.add_lora_fused_moe(
                 input.view(-1, top_k, input.shape[-1]),
                 hidden_states,
-                w13_lora_a_stacked,
-                w13_lora_b_stacked,
+                self.w13_lora_a_stacked,
+                self.w13_lora_b_stacked,
                 topk_weights,
                 sorted_token_ids_lora,
                 expert_ids_lora,
@@ -230,11 +230,11 @@ class FusedMoEWithLoRA(BaseLayerWithLoRA):
             CHUNK_SIZE = envs.VLLM_FUSED_MOE_CHUNK_SIZE
             num_tokens = hidden_states.size(0)
             M = min(num_tokens, CHUNK_SIZE)
+            max_lora_rank = self.w2_lora_a_stacked.shape[-2]
             shrink_config, expand_config = self._get_lora_moe_configs(
                 op_prefix="w2",
-                lora_a_stacked=self.w2_lora_a_stacked,
-                lora_b_stacked=self.w2_lora_b_stacked,
+                num_loras=self.max_loras,
+                rank=max_lora_rank,
                 num_slices=1,
                 M=M,
                 layer=layer,
@@ -247,20 +247,19 @@ class FusedMoEWithLoRA(BaseLayerWithLoRA):
             num_tokens_post_padded_lora = moe_state_dict[
                 "num_tokens_post_padded_lora"
             ]
-            max_loras = self.w1_lora_a_stacked.shape[0]
-            expert_ids_lora = expert_ids_lora.view(max_loras, -1)
-            sorted_token_ids_lora = sorted_token_ids_lora.view(max_loras, -1)
+            expert_ids_lora = expert_ids_lora.view(self.max_loras, -1)
+            sorted_token_ids_lora = sorted_token_ids_lora.view(self.max_loras, -1)
             intermediate_cache2 = moe_state_dict["intermediate_cache2"]
             intermediate_cache3 = args[0]
-            max_lora_rank = self.w2_lora_a_stacked.shape[-2]

             shard_size_w2 = divide(self.base_layer.hidden_size, self.tp_size)

             self.punica_wrapper.add_lora_fused_moe(
                 intermediate_cache3,
                 intermediate_cache2,
-                [self.w2_lora_a_stacked],
-                [self.w2_lora_b_stacked],
+                (self.w2_lora_a_stacked,),
+                (self.w2_lora_b_stacked,),
                 topk_weights,
                 sorted_token_ids_lora,
                 expert_ids_lora,
@@ -289,7 +288,6 @@ class FusedMoEWithLoRA(BaseLayerWithLoRA):
         fused_experts.moe_sum = moe_sum_decorator(
             self.base_layer, fused_experts.moe_sum
         )
-
         self.base_layer.quant_method = FusedMoEModularMethod(
             self.base_layer.quant_method, m_fused_moe_fn
         )
@@ -301,33 +299,42 @@ class FusedMoEWithLoRA(BaseLayerWithLoRA):
         model_config: PretrainedConfig | None = None,
     ) -> None:
         """Initializes lora matrices."""
+        assert self.w13_slices == 2
+        self.max_loras = lora_config.max_loras
         self.fully_sharded = lora_config.fully_sharded_loras

         self.adapter_enabled = torch.tensor(
             [0] * (max_loras + 1), dtype=torch.int, device=self.device
         )

-        self.w1_lora_a_stacked = torch.zeros(
-            (
-                max_loras,
-                self.base_layer.local_num_experts,
-                lora_config.max_lora_rank
-                if not self.fully_sharded
-                else divide(lora_config.max_lora_rank, self.tp_size),
-                self.base_layer.hidden_size,
-            ),
-            dtype=lora_config.lora_dtype,
-            device=self.device,
-        )
-        self.w1_lora_b_stacked = torch.zeros(
-            (
-                max_loras,
-                self.base_layer.local_num_experts,
-                self.base_layer.intermediate_size_per_partition,
-                lora_config.max_lora_rank,
-            ),
-            dtype=lora_config.lora_dtype,
-            device=self.device,
-        )
+        self.w13_lora_a_stacked = tuple(
+            torch.zeros(
+                (
+                    max_loras,
+                    self.base_layer.local_num_experts,
+                    lora_config.max_lora_rank
+                    if not self.fully_sharded
+                    else divide(lora_config.max_lora_rank, self.tp_size),
+                    self.base_layer.hidden_size,
+                ),
+                dtype=lora_config.lora_dtype,
+                device=self.device,
+            )
+            for _ in range(self.w13_slices)
+        )
+
+        self.w13_lora_b_stacked = tuple(
+            torch.zeros(
+                (
+                    max_loras,
+                    self.base_layer.local_num_experts,
+                    self.base_layer.intermediate_size_per_partition,
+                    lora_config.max_lora_rank,
+                ),
+                dtype=lora_config.lora_dtype,
+                device=self.device,
+            )
+            for _ in range(self.w13_slices)
+        )

         self.w2_lora_a_stacked = torch.zeros(
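One detail of the tuple construction above worth noting: the generator expression allocates an independent buffer per w13 slice, which `(tensor,) * 2` would not. A small standalone sketch, with hypothetical sizes, of that property:

```python
import torch

# Hypothetical sizes; mirrors the per-slice allocation pattern, not vLLM code.
w13_slices = 2
stacked = tuple(torch.zeros(2, 4, 8, 16) for _ in range(w13_slices))

stacked[0][0] = 1.0
assert stacked[1].sum() == 0                           # slice 1 untouched
assert stacked[0].data_ptr() != stacked[1].data_ptr()  # separate storage
```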
@@ -353,29 +360,6 @@ class FusedMoEWithLoRA(BaseLayerWithLoRA):
             device=self.device,
         )

-        self.w3_lora_a_stacked = torch.zeros(
-            (
-                max_loras,
-                self.base_layer.local_num_experts,
-                lora_config.max_lora_rank
-                if not self.fully_sharded
-                else divide(lora_config.max_lora_rank, self.tp_size),
-                self.base_layer.hidden_size,
-            ),
-            dtype=lora_config.lora_dtype,
-            device=self.device,
-        )
-        self.w3_lora_b_stacked = torch.zeros(
-            (
-                max_loras,
-                self.base_layer.local_num_experts,
-                self.base_layer.intermediate_size_per_partition,
-                lora_config.max_lora_rank,
-            ),
-            dtype=lora_config.lora_dtype,
-            device=self.device,
-        )
-
         # They will be used by 'LoRALayerWeights.create_dummy_lora_weights'
         # to create a dummy LoRA weights.
         self.lora_a_stacked = []
@@ -383,20 +367,28 @@ class FusedMoEWithLoRA(BaseLayerWithLoRA):
         for lora_id in range(max_loras):
             for experts_id in range(self.base_layer.local_num_experts):
                 # gate_proj,down_proj,up_proj
-                self.lora_a_stacked.append(self.w1_lora_a_stacked[lora_id][experts_id])
+                self.lora_a_stacked.append(
+                    self.w13_lora_a_stacked[0][lora_id][experts_id]
+                )
                 self.lora_a_stacked.append(self.w2_lora_a_stacked[lora_id][experts_id])
-                self.lora_a_stacked.append(self.w3_lora_a_stacked[lora_id][experts_id])
+                self.lora_a_stacked.append(
+                    self.w13_lora_a_stacked[1][lora_id][experts_id]
+                )

-                self.lora_b_stacked.append(self.w1_lora_b_stacked[lora_id][experts_id])
+                self.lora_b_stacked.append(
+                    self.w13_lora_b_stacked[0][lora_id][experts_id]
+                )
                 self.lora_b_stacked.append(self.w2_lora_b_stacked[lora_id][experts_id])
-                self.lora_b_stacked.append(self.w3_lora_b_stacked[lora_id][experts_id])
+                self.lora_b_stacked.append(
+                    self.w13_lora_b_stacked[1][lora_id][experts_id]
+                )

     def reset_lora(self, index: int):
         """Resets the lora weights at index back to 0."""
-        self.w1_lora_a_stacked[index] = 0
-        self.w1_lora_b_stacked[index] = 0
-        self.w3_lora_a_stacked[index] = 0
-        self.w3_lora_b_stacked[index] = 0
+        for pos in range(self.w13_slices):
+            self.w13_lora_a_stacked[pos][index] = 0
+            self.w13_lora_b_stacked[pos][index] = 0
         self.w2_lora_a_stacked[index] = 0
         self.w2_lora_b_stacked[index] = 0
         self.adapter_enabled[index] = 0
@@ -434,7 +426,7 @@ class FusedMoEWithLoRA(BaseLayerWithLoRA):
         if self.fully_sharded:
             # Based on S-LoRA, we slice W1 and W3 A along the rank dim,
             # and W2 B along the hidden_size dim.
-            w13_shard_size = self.w1_lora_a_stacked[index, eid].shape[0]
+            w13_shard_size = self.w13_lora_a_stacked[0][index, eid].shape[0]
            w13_start_idx = self.tp_rank * w13_shard_size
            w13_end_idx = (self.tp_rank + 1) * w13_shard_size
            w1_lora_a = w1_lora_a[w13_start_idx:w13_end_idx, :]
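For the fully-sharded path above, each tensor-parallel rank keeps only its chunk of the LoRA-A rank dimension (and, later in the diff, of W2's LoRA-B hidden dimension). A minimal standalone sketch of that slicing, with hypothetical sizes:

```python
import torch

# Hypothetical sizes; standalone version of the rank-dim slicing shown above.
tp_rank, tp_size, max_lora_rank, hidden_size = 1, 4, 16, 64
w1_lora_a = torch.randn(max_lora_rank, hidden_size)  # full LoRA-A for one expert

w13_shard_size = max_lora_rank // tp_size  # matches stacked[index, eid].shape[0]
w13_start_idx = tp_rank * w13_shard_size
w13_end_idx = (tp_rank + 1) * w13_shard_size
w1_lora_a = w1_lora_a[w13_start_idx:w13_end_idx, :]  # this rank's slice

assert w1_lora_a.shape == (w13_shard_size, hidden_size)
```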
@@ -444,29 +436,32 @@ class FusedMoEWithLoRA(BaseLayerWithLoRA):
             w2_start_idx = self.tp_rank * w2_shard_size
             w2_end_idx = (self.tp_rank + 1) * w2_shard_size
             w2_lora_b = w2_lora_b[w2_start_idx:w2_end_idx, :]
-        self.w1_lora_a_stacked[
+        # w1 lora_a
+        self.w13_lora_a_stacked[0][
             index, eid, : w1_lora_a.shape[0], : w1_lora_a.shape[1]
         ].copy_(w1_lora_a, non_blocking=True)
-        self.w3_lora_a_stacked[
+        # w3 lora_a
+        self.w13_lora_a_stacked[1][
             index, eid, : w3_lora_a.shape[0], : w3_lora_a.shape[1]
         ].copy_(w3_lora_a, non_blocking=True)

+        # w1 lora_b
+        self.w13_lora_b_stacked[0][
+            index, eid, : w1_lora_b.shape[0], : w1_lora_b.shape[1]
+        ].copy_(w1_lora_b, non_blocking=True)
+        # w3 lora_b
+        self.w13_lora_b_stacked[1][
+            index, eid, : w3_lora_b.shape[0], : w3_lora_b.shape[1]
+        ].copy_(w3_lora_b, non_blocking=True)
+
+        self.w2_lora_a_stacked[
+            index, eid, : w2_lora_a.shape[0], : w2_lora_a.shape[1]
+        ].copy_(w2_lora_a, non_blocking=True)
+
         self.w2_lora_b_stacked[
             index, eid, : w2_lora_b.shape[0], : w2_lora_b.shape[1]
         ].copy_(w2_lora_b, non_blocking=True)
-
-        self.w1_lora_b_stacked[
-            index, eid, : w1_lora_b.shape[0], : w1_lora_b.shape[1]
-        ].copy_(w1_lora_b, non_blocking=True)
-        self.w3_lora_b_stacked[
-            index, eid, : w3_lora_b.shape[0], : w3_lora_b.shape[1]
-        ].copy_(w3_lora_b, non_blocking=True)
-        self.w2_lora_a_stacked[
-            index, eid, : w2_lora_a.shape[0], : w2_lora_a.shape[1]
-        ].copy_(w2_lora_a, non_blocking=True)

     @classmethod
     def can_replace_layer(
         cls,
@@ -470,8 +470,8 @@ class PunicaWrapperBase(PunicaWrapperABC):
         self,
         y: torch.Tensor,
         x: torch.Tensor,
-        lora_a_stacked: list[torch.Tensor],
-        lora_b_stacked: list[torch.Tensor],
+        lora_a_stacked: tuple[torch.Tensor, ...],
+        lora_b_stacked: tuple[torch.Tensor, ...],
         topk_weights: torch.Tensor,
         sorted_token_ids: torch.Tensor,
         expert_ids: torch.Tensor,
@@ -360,8 +360,8 @@ class PunicaWrapperGPU(PunicaWrapperBase):
         self,
         y: torch.Tensor,
         x: torch.Tensor,
-        lora_a_stacked: list[torch.Tensor],
-        lora_b_stacked: list[torch.Tensor],
+        lora_a_stacked: tuple[torch.Tensor, ...],
+        lora_b_stacked: tuple[torch.Tensor, ...],
         topk_weights: torch.Tensor,
         sorted_token_ids: torch.Tensor,
         expert_ids: torch.Tensor,
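The two PunicaWrapper hunks only tighten the annotations of the stacked LoRA weights from list[torch.Tensor] to tuple[torch.Tensor, ...], matching the tuples the layer now passes (the two-slice w13 tuples, and (self.w2_lora_a_stacked,) for the single w2 slice). A toy function, not part of the vLLM API, using the same variadic-tuple annotation:

```python
import torch

def count_stacked_params(stacked: tuple[torch.Tensor, ...]) -> int:
    """Accepts any number of stacked LoRA weight tensors."""
    return sum(t.numel() for t in stacked)

w2_a = torch.zeros(2, 4, 8, 32)
w13_a = tuple(torch.zeros(2, 4, 8, 32) for _ in range(2))
print(count_stacked_params((w2_a,)), count_stacked_params(w13_a))
```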