Load tuned fused_moe_lora shrink and expand kernel configs separately (#27435)
Signed-off-by: Yu Gong <yu3.gong@gmail.com>
Co-authored-by: Jee Jee Li <pandaleefree@gmail.com>
parent 4022a9d279
commit 2ec401bc39
@@ -19,13 +19,24 @@ from torch.utils.benchmark import Measurement as TMeasurement
from utils import ArgPool, Bench, CudaGraphBenchParams
from weight_shapes import WEIGHT_SHAPES

from vllm.triton_utils import HAS_TRITON
from vllm.lora.ops.triton_ops.utils import get_lora_op_configs
from vllm.triton_utils import HAS_TRITON, triton

if HAS_TRITON:
    from vllm.lora.ops.triton_ops import LoRAKernelMeta, lora_expand, lora_shrink
    from vllm.lora.ops.triton_ops import (  ## added fused_moe_lora
        LoRAKernelMeta,
        fused_moe_lora_expand,
        fused_moe_lora_shrink,
        lora_expand,
        lora_shrink,
    )
    from vllm.lora.ops.triton_ops.fused_moe_lora_op import (
        _LORA_PTR_DICT,  ## added _LORA_PTR_DICT for fused_moe_lora
    )
    from vllm.lora.ops.triton_ops.utils import _LORA_A_PTR_DICT, _LORA_B_PTR_DICT

from vllm import _custom_ops as ops
from vllm.utils.argparse_utils import FlexibleArgumentParser
from vllm.utils.math_utils import round_up

DEFAULT_MODELS = list(WEIGHT_SHAPES.keys())
DEFAULT_TP_SIZES = [1]

@@ -59,6 +70,8 @@ DEFAULT_NUM_LORAS = [1, 2, 3, 4]
DEFAULT_SORT_BY_LORA_IDS = [False, True]
DEFAULT_SEQ_LENGTHS = [1]
DEFAULT_EXPAND_FN_ADD_INPUTS = [True, False]
DEFAULT_TOP_K_NUMS = [1]  # Added for MoE LoRA top_k
DEFAULT_NUM_EXPERTS = [8]  # Added for MoE LoRA num_experts


# Utilities
@ -191,6 +204,11 @@ class OpType(Enum):
|
||||
|
||||
LORA_SHRINK = auto()
|
||||
LORA_EXPAND = auto()
|
||||
## Adding support for fused moe lora
|
||||
FUSED_MOE_LORA_GATE_UP_SHRINK = auto() ## Gate/Up projection variant with shrink
|
||||
FUSED_MOE_LORA_GATE_UP_EXPAND = auto() ## Gate/Up projection variant with expand
|
||||
FUSED_MOE_LORA_DOWN_SHRINK = auto() ## Down projection variant with shrink
|
||||
FUSED_MOE_LORA_DOWN_EXPAND = auto() ## Down projection variant with expand
|
||||
|
||||
@staticmethod
|
||||
def from_str(s: str) -> "OpType":
|
||||
@ -198,6 +216,15 @@ class OpType(Enum):
|
||||
return OpType.LORA_SHRINK
|
||||
if s.lower() == "lora_expand":
|
||||
return OpType.LORA_EXPAND
|
||||
# Adding support for fused moe lora, both in gate_up and down
|
||||
if s.lower() == "fused_moe_lora_gate_up_shrink": ## Gate/Up variant with shrink
|
||||
return OpType.FUSED_MOE_LORA_GATE_UP_SHRINK
|
||||
if s.lower() == "fused_moe_lora_gate_up_expand": ## Gate/Up variant with expand
|
||||
return OpType.FUSED_MOE_LORA_GATE_UP_EXPAND
|
||||
if s.lower() == "fused_moe_lora_down_shrink": ## Down variant with shrink
|
||||
return OpType.FUSED_MOE_LORA_DOWN_SHRINK
|
||||
if s.lower() == "fused_moe_lora_down_expand": ## Down variant with expand
|
||||
return OpType.FUSED_MOE_LORA_DOWN_EXPAND
|
||||
raise ValueError(f"Unrecognized str {s} to convert to OpType")
|
||||
|
||||
def is_shrink_fn(self) -> bool:
|
||||
@ -206,19 +233,56 @@ class OpType(Enum):
|
||||
def is_expand_fn(self) -> bool:
|
||||
return self in [OpType.LORA_EXPAND]
|
||||
|
||||
def is_fused_moe_lora_fn(self) -> bool: ## adding for fused MoE LoRA
|
||||
return self in [
|
||||
OpType.FUSED_MOE_LORA_GATE_UP_SHRINK,
|
||||
OpType.FUSED_MOE_LORA_DOWN_SHRINK,
|
||||
OpType.FUSED_MOE_LORA_GATE_UP_EXPAND,
|
||||
OpType.FUSED_MOE_LORA_DOWN_EXPAND,
|
||||
]
|
||||
|
||||
def is_fused_moe_lora_gate_up_fn(
|
||||
self,
|
||||
) -> bool: ## adding for fused MoE LoRA Gate/Up
|
||||
return self in [
|
||||
OpType.FUSED_MOE_LORA_GATE_UP_SHRINK,
|
||||
OpType.FUSED_MOE_LORA_GATE_UP_EXPAND,
|
||||
]
|
||||
|
||||
def is_fused_moe_lora_down_fn(self) -> bool: ## adding for fused MoE LoRA Down
|
||||
return self in [
|
||||
OpType.FUSED_MOE_LORA_DOWN_SHRINK,
|
||||
OpType.FUSED_MOE_LORA_DOWN_EXPAND,
|
||||
]
|
||||
|
||||
def is_fused_moe_lora_shrink_fn(self) -> bool:
|
||||
return self in [
|
||||
OpType.FUSED_MOE_LORA_GATE_UP_SHRINK,
|
||||
OpType.FUSED_MOE_LORA_DOWN_SHRINK,
|
||||
]
|
||||
|
||||
def is_fused_moe_lora_expand_fn(self) -> bool:
|
||||
return self in [
|
||||
OpType.FUSED_MOE_LORA_GATE_UP_EXPAND,
|
||||
OpType.FUSED_MOE_LORA_DOWN_EXPAND,
|
||||
]
|
||||
|
||||
def num_slices(self) -> list[int]:
|
||||
if self.is_fused_moe_lora_gate_up_fn():
|
||||
return [2]
|
||||
elif self.is_fused_moe_lora_down_fn():
|
||||
return [1]
|
||||
return [1, 2, 3]
|
||||
|
||||
def mkn(
|
||||
self, batch_size: int, seq_length: int, hidden_size: int, lora_rank: int
|
||||
) -> tuple[int, int, int]:
|
||||
num_tokens = batch_size * seq_length
|
||||
if self.is_shrink_fn():
|
||||
if self.is_shrink_fn() or self.is_fused_moe_lora_fn():
|
||||
m = num_tokens
|
||||
k = hidden_size
|
||||
n = lora_rank
|
||||
else:
|
||||
assert self.is_expand_fn()
|
||||
elif self.is_expand_fn():
|
||||
m = num_tokens
|
||||
k = lora_rank
|
||||
n = hidden_size
|
||||
@ -232,9 +296,36 @@ class OpType(Enum):
|
||||
"""
|
||||
if self.is_shrink_fn():
|
||||
return op_dtype, op_dtype, torch.float32
|
||||
else:
|
||||
assert self.is_expand_fn()
|
||||
elif self.is_expand_fn():
|
||||
return torch.float32, op_dtype, op_dtype
|
||||
else:
|
||||
assert self.is_fused_moe_lora_fn()
|
||||
return op_dtype, op_dtype, op_dtype
|
||||
|
||||
def matmul_shapes_fused_moe_lora(
|
||||
self,
|
||||
m: int,
|
||||
n: int,
|
||||
k: int,
|
||||
num_loras: int,
|
||||
num_slices: int,
|
||||
top_k_num: int,
|
||||
num_experts: int,
|
||||
) -> tuple[tuple[int], tuple[int], tuple[int], tuple[int]]:
|
||||
if self.is_fused_moe_lora_shrink_fn():
|
||||
input_shape = (
|
||||
(m * top_k_num, n)
|
||||
if self in [OpType.FUSED_MOE_LORA_DOWN_SHRINK]
|
||||
else (m, n)
|
||||
)
|
||||
output_shape = (num_slices, m, top_k_num, k)
|
||||
weight_shape = (num_loras, num_experts, k, n)
|
||||
else:
|
||||
assert self.is_fused_moe_lora_expand_fn()
|
||||
input_shape = (num_slices, m, top_k_num, k)
|
||||
output_shape = (m, top_k_num, n * num_slices)
|
||||
weight_shape = (num_loras, num_experts, n, k)
|
||||
return (input_shape, weight_shape, output_shape)
|
||||
|
||||
def matmul_shapes(
|
||||
self,
|
||||
@ -244,6 +335,8 @@ class OpType(Enum):
|
||||
lora_rank: int,
|
||||
num_loras: int,
|
||||
num_slices: int,
|
||||
top_k_num: int | None = None,
|
||||
num_experts: int | None = None,
|
||||
) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]]:
|
||||
"""
|
||||
Given num_slices, return the shapes of the A, B, and C matrices
|
||||
@ -258,6 +351,16 @@ class OpType(Enum):
|
||||
if self in [OpType.LORA_EXPAND]:
|
||||
# LoRA expand kernels support num_slices inherently in the kernel
|
||||
return ((num_slices, m, k), b_shape, (m, n * num_slices))
|
||||
if self.is_fused_moe_lora_fn():
|
||||
return self.matmul_shapes_fused_moe_lora(
|
||||
m,
|
||||
k,
|
||||
n,
|
||||
num_loras,
|
||||
num_slices,
|
||||
top_k_num,
|
||||
num_experts,
|
||||
)
|
||||
raise ValueError(f"Unrecognized op_type {self}")
|
||||
|
||||
def bench_fn(self) -> Callable:
|
||||
@ -265,6 +368,16 @@ class OpType(Enum):
|
||||
return lora_shrink
|
||||
if self == OpType.LORA_EXPAND:
|
||||
return lora_expand
|
||||
if self in [
|
||||
OpType.FUSED_MOE_LORA_GATE_UP_SHRINK,
|
||||
OpType.FUSED_MOE_LORA_DOWN_SHRINK,
|
||||
]:
|
||||
return fused_moe_lora_shrink
|
||||
if self in [
|
||||
OpType.FUSED_MOE_LORA_GATE_UP_EXPAND,
|
||||
OpType.FUSED_MOE_LORA_DOWN_EXPAND,
|
||||
]:
|
||||
return fused_moe_lora_expand
|
||||
|
||||
raise ValueError(f"Unrecognized optype {self}")
|
||||
|
||||
@ -318,6 +431,8 @@ class BenchmarkContext:
|
||||
sort_by_lora_id: bool
|
||||
dtype: torch.dtype
|
||||
seq_length: int | None = None
|
||||
num_experts: int | None = None # num_experts for MoE based ops
|
||||
top_k_num: int | None = None # top_k for MoE based ops
|
||||
num_slices: int | None = None # num_slices for slice based ops
|
||||
|
||||
def with_seq_length(self, seq_length: int) -> "BenchmarkContext":
|
||||
@ -373,6 +488,11 @@ class BenchmarkTensors:
|
||||
f"{dtype_to_str(self.output.dtype)}"
|
||||
)
|
||||
|
||||
def get_num_tokens(self, size: int, top_k_num: int, op_type: OpType):
|
||||
return (
|
||||
size * top_k_num if op_type in [OpType.FUSED_MOE_LORA_DOWN_SHRINK] else size
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def make(
|
||||
ctx: BenchmarkContext, op_type: OpType, device: str = "cuda"
|
||||
@ -385,6 +505,8 @@ class BenchmarkTensors:
|
||||
ctx.lora_rank,
|
||||
ctx.num_loras,
|
||||
ctx.num_slices,
|
||||
ctx.top_k_num,
|
||||
ctx.num_experts,
|
||||
)
|
||||
a_type, b_type, c_type = op_type.matmul_dtypes(ctx.dtype)
|
||||
input_tensor, lora_weights, output_tensor = make_rand_tensors(
|
||||
@ -432,17 +554,27 @@ class BenchmarkTensors:
|
||||
prompt_lora_indices_tensor,
|
||||
)
|
||||
|
||||
def sanity_check(self) -> None:
|
||||
def sanity_check(self, ctx: BenchmarkContext, op_type: OpType) -> None:
|
||||
"""
|
||||
Fails asserts when non-conformality is detected.
|
||||
"""
|
||||
num_tokens = self.input.shape[-2]
|
||||
num_tokens = (
|
||||
self.input.shape[1]
|
||||
if op_type.is_fused_moe_lora_expand_fn()
|
||||
else self.input.shape[-2]
|
||||
)
|
||||
# check metadata tensors
|
||||
assert torch.sum(self.seq_lens) == num_tokens
|
||||
## In down shrink case, each token is repeated top_k_num times
|
||||
assert num_tokens == self.get_num_tokens(
|
||||
torch.sum(self.seq_lens), ctx.top_k_num, op_type
|
||||
), f"Expected {num_tokens} tokens, but got {torch.sum(self.seq_lens)}"
|
||||
num_seqs = self.seq_lens.shape[0]
|
||||
# assert self.seq_start_loc.shape[0] == num_seqs
|
||||
## In down shrink case, each prompt corresponds to top_k_num sequences
|
||||
assert self.prompt_lora_mapping.shape[0] == num_seqs
|
||||
assert self.lora_kernel_meta.token_lora_mapping.shape[0] == num_tokens
|
||||
assert self.get_num_tokens(
|
||||
self.lora_kernel_meta.token_lora_mapping.shape[0], ctx.top_k_num, op_type
|
||||
)
|
||||
|
||||
def to_device(self, device: str):
|
||||
"""
|
||||
@ -471,21 +603,111 @@ class BenchmarkTensors:
|
||||
to_device(field) if field_name != "no_lora_flag_cpu" else field,
|
||||
)
|
||||
|
||||
def metadata(self) -> tuple[int, int, int]:
|
||||
def metadata(self, ctx: BenchmarkContext, op_type: OpType) -> tuple[int, int, int]:
|
||||
"""
|
||||
Return num_seqs, num_tokens and max_seq_len
|
||||
"""
|
||||
num_seqs = self.seq_lens.shape[0]
|
||||
num_tokens = self.lora_kernel_meta.token_lora_mapping.shape[0]
|
||||
num_tokens = self.get_num_tokens(
|
||||
self.lora_kernel_meta.token_lora_mapping.shape[0], ctx.top_k_num, op_type
|
||||
)
|
||||
max_seq_len = torch.max(self.seq_lens).item()
|
||||
num_slices = len(self.lora_weights_lst)
|
||||
return num_seqs, num_tokens, max_seq_len, num_slices
|
||||
|
||||
def as_lora_shrink_kwargs(self) -> dict[str, Any]:
|
||||
self.sanity_check()
|
||||
def fused_moe_lora_data_prepare(
|
||||
self,
|
||||
block_size: int,
|
||||
token_lora_mapping: torch.Tensor,
|
||||
ctx: BenchmarkContext,
|
||||
):
|
||||
def moe_lora_align_block_size(
|
||||
topk_ids: torch.Tensor,
|
||||
token_lora_mapping: torch.Tensor,
|
||||
block_size: int,
|
||||
num_experts: int,
|
||||
max_loras: int,
|
||||
expert_map: torch.Tensor | None = None,
|
||||
pad_sorted_ids: bool = False,
|
||||
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
Aligns tokens and experts into block-sized chunks for LoRA-based
|
||||
mixture-of-experts (MoE) execution.
|
||||
"""
|
||||
max_num_tokens_padded = topk_ids.numel() + num_experts * (block_size - 1)
|
||||
if pad_sorted_ids:
|
||||
max_num_tokens_padded = round_up(max_num_tokens_padded, block_size)
|
||||
sorted_ids = torch.empty(
|
||||
(max_loras * max_num_tokens_padded,),
|
||||
dtype=torch.int32,
|
||||
device=topk_ids.device,
|
||||
)
|
||||
max_num_m_blocks = triton.cdiv(max_num_tokens_padded, block_size)
|
||||
# Expert ids must be set default to -1 to prevent a blank block
|
||||
expert_ids = torch.empty(
|
||||
(max_loras * max_num_m_blocks,),
|
||||
dtype=torch.int32,
|
||||
device=topk_ids.device,
|
||||
)
|
||||
num_tokens_post_pad = torch.empty(
|
||||
(max_loras), dtype=torch.int32, device=topk_ids.device
|
||||
)
|
||||
|
||||
ops.moe_lora_align_block_size(
|
||||
topk_ids,
|
||||
token_lora_mapping,
|
||||
num_experts,
|
||||
block_size,
|
||||
max_loras,
|
||||
max_num_tokens_padded,
|
||||
max_num_m_blocks,
|
||||
sorted_ids,
|
||||
expert_ids,
|
||||
num_tokens_post_pad,
|
||||
)
|
||||
if expert_map is not None:
|
||||
expert_ids = expert_map[expert_ids]
|
||||
|
||||
return sorted_ids, expert_ids, num_tokens_post_pad
|
||||
|
||||
num_tokens = ctx.batch_size
|
||||
curr_topk_ids = torch.randint(
|
||||
0,
|
||||
ctx.num_experts,
|
||||
(num_tokens, ctx.top_k_num),
|
||||
device="cuda",
|
||||
dtype=torch.int32,
|
||||
)
|
||||
topk_weights = torch.randint(
|
||||
0,
|
||||
ctx.num_experts,
|
||||
(num_tokens, ctx.top_k_num),
|
||||
device="cuda",
|
||||
dtype=torch.int32,
|
||||
)
|
||||
|
||||
(sorted_token_ids_lora, expert_ids_lora, num_tokens_post_padded_lora) = (
|
||||
moe_lora_align_block_size(
|
||||
topk_ids=curr_topk_ids,
|
||||
token_lora_mapping=token_lora_mapping,
|
||||
block_size=block_size,
|
||||
num_experts=ctx.num_experts,
|
||||
max_loras=ctx.num_loras,
|
||||
)
|
||||
)
|
||||
|
||||
sorted_token_ids = sorted_token_ids_lora.view(ctx.num_loras, -1)
|
||||
expert_ids = expert_ids_lora.view(ctx.num_loras, -1)
|
||||
num_tokens_post_padded = num_tokens_post_padded_lora
|
||||
return (topk_weights, sorted_token_ids, expert_ids, num_tokens_post_padded)
|
||||
|
||||
def as_lora_shrink_kwargs(
|
||||
self, ctx: BenchmarkContext, op_type: OpType
|
||||
) -> dict[str, Any]:
|
||||
self.sanity_check(ctx, op_type)
|
||||
self.to_device(self.input.device)
|
||||
|
||||
_, num_tokens, _, num_slices = self.metadata()
|
||||
_, num_tokens, _, num_slices = self.metadata(ctx, op_type)
|
||||
|
||||
# Sanity check matrix shapes.
|
||||
i_shape, lw_shape, o_shape = (
|
||||
@ -520,11 +742,13 @@ class BenchmarkTensors:
|
||||
"no_lora_flag_cpu": self.lora_kernel_meta.no_lora_flag_cpu,
|
||||
}
|
||||
|
||||
def as_lora_expand_kwargs(self, add_inputs: bool) -> dict[str, Any]:
|
||||
self.sanity_check()
|
||||
def as_lora_expand_kwargs(
|
||||
self, ctx: BenchmarkContext, op_type: OpType, add_inputs: bool
|
||||
) -> dict[str, Any]:
|
||||
self.sanity_check(ctx, op_type)
|
||||
self.to_device(self.input.device)
|
||||
|
||||
_, num_tokens, _, num_slices = self.metadata()
|
||||
_, num_tokens, _, num_slices = self.metadata(ctx, op_type)
|
||||
|
||||
# Sanity check matrix shapes.
|
||||
i_shape, lw_shape, o_shape = (
|
||||
@ -561,18 +785,173 @@ class BenchmarkTensors:
|
||||
"no_lora_flag_cpu": self.lora_kernel_meta.no_lora_flag_cpu,
|
||||
}
|
||||
|
||||
def bench_fn_kwargs(
|
||||
self, op_type: OpType, add_inputs: bool | None = None
|
||||
def as_fused_moe_lora_shrink_kwargs(
|
||||
self, ctx: BenchmarkContext, op_type: OpType
|
||||
) -> dict[str, Any]:
|
||||
if op_type.is_shrink_fn():
|
||||
self.sanity_check(ctx, op_type)
|
||||
self.to_device(self.input.device)
|
||||
|
||||
_, num_tokens, _, num_slices = self.metadata(ctx, op_type)
|
||||
|
||||
# Sanity check matrix shapes.
|
||||
i_shape, lw_shape, o_shape = (
|
||||
self.input.shape,
|
||||
self.lora_weights_lst[0].shape,
|
||||
self.output.shape,
|
||||
)
|
||||
# Expected input shape : [num_tokens, hidden_size] for gate_up
|
||||
# Expected input shape : [top_k_num * num_tokens, hidden_size] for down
|
||||
assert len(i_shape) == 2
|
||||
assert i_shape[0] == num_tokens
|
||||
hidden_size = i_shape[1]
|
||||
# Expected lora weight shape [max_lora, num_experts, lora_rank, hidden_size]
|
||||
assert len(lw_shape) == 4
|
||||
assert lw_shape[-1] == hidden_size
|
||||
lora_rank = lw_shape[-2]
|
||||
# Expected output shape : [num_slices, num_tokens, top_k_num, lora_rank]
|
||||
assert len(o_shape) == 4
|
||||
assert (
|
||||
o_shape
|
||||
== (num_slices, num_tokens // ctx.top_k_num, ctx.top_k_num, lora_rank)
|
||||
if op_type in [OpType.FUSED_MOE_LORA_DOWN_SHRINK]
|
||||
else o_shape == (num_slices, num_tokens, ctx.top_k_num, lora_rank)
|
||||
)
|
||||
kernel_config = get_lora_op_configs(
|
||||
op_type.name.lower(),
|
||||
max_loras=lw_shape[0],
|
||||
batch=num_tokens,
|
||||
hidden_size=hidden_size,
|
||||
rank=lora_rank,
|
||||
num_slices=num_slices,
|
||||
add_inputs=False,
|
||||
)
|
||||
|
||||
(topk_weights, sorted_token_ids, expert_ids, num_tokens_post_padded) = (
|
||||
self.fused_moe_lora_data_prepare(
|
||||
block_size=kernel_config["BLOCK_SIZE_M"],
|
||||
token_lora_mapping=self.lora_kernel_meta.token_lora_mapping,
|
||||
ctx=ctx,
|
||||
)
|
||||
)
|
||||
|
||||
return {
|
||||
"qcurr_hidden_states": self.input,
|
||||
"lora_a_stacked": self.lora_weights_lst,
|
||||
"a_intermediate_cache1": self.output,
|
||||
"topk_weights": topk_weights,
|
||||
"sorted_token_ids": sorted_token_ids,
|
||||
"expert_ids": expert_ids,
|
||||
"num_tokens_post_padded": num_tokens_post_padded,
|
||||
"top_k_num": ctx.top_k_num,
|
||||
"device": self.input.device,
|
||||
"N": lora_rank,
|
||||
"M": topk_weights.shape[0],
|
||||
"EM": sorted_token_ids.shape[1],
|
||||
"K": self.input.shape[1],
|
||||
"num_tokens": num_tokens,
|
||||
"num_experts": ctx.num_experts,
|
||||
"num_slices": num_slices,
|
||||
"shrink_block_size_m": kernel_config["BLOCK_SIZE_M"],
|
||||
"shrink_block_size_n": kernel_config["BLOCK_SIZE_N"],
|
||||
"shrink_block_size_k": kernel_config["BLOCK_SIZE_K"],
|
||||
"shrink_group_size_m": kernel_config["GROUP_SIZE_M"],
|
||||
"shrink_num_warps": kernel_config["NUM_WARPS"],
|
||||
"shrink_num_stages": kernel_config["NUM_STAGES"],
|
||||
"shrink_split_k": kernel_config.get("SPLIT_K", 1),
|
||||
"mul_routed_weight": op_type.is_fused_moe_lora_down_fn(),
|
||||
}
|
||||
|
||||
def as_fused_moe_lora_expand_kwargs(
|
||||
self, ctx: BenchmarkContext, op_type: OpType
|
||||
) -> dict[str, Any]:
|
||||
self.sanity_check(ctx, op_type)
|
||||
self.to_device(self.input.device)
|
||||
|
||||
_, num_tokens, _, num_slices = self.metadata(ctx, op_type)
|
||||
|
||||
# Sanity check matrix shapes.
|
||||
i_shape, lw_shape, o_shape = (
|
||||
self.input.shape,
|
||||
self.lora_weights_lst[0].shape,
|
||||
self.output.shape,
|
||||
)
|
||||
|
||||
# Expected input shape : [num_slices, num_tokens, top_k_num, lora_rank]
|
||||
assert len(i_shape) == 4
|
||||
assert i_shape[0] == num_slices
|
||||
assert i_shape[1] == num_tokens
|
||||
lora_rank = i_shape[-1]
|
||||
# Expected lora weight shape : [num_loras, num_experts, hidden_size, lora_rank]
|
||||
assert len(lw_shape) == 4
|
||||
assert lw_shape[-1] == lora_rank
|
||||
hidden_size = lw_shape[-2]
|
||||
# Expected output shape : [num_tokens, top_k_num, hidden_size * num_slices]
|
||||
assert len(o_shape) == 3
|
||||
assert o_shape == (num_tokens, ctx.top_k_num, hidden_size * num_slices)
|
||||
|
||||
kernel_config = get_lora_op_configs(
|
||||
op_type.name.lower(),
|
||||
max_loras=lw_shape[0],
|
||||
batch=num_tokens,
|
||||
hidden_size=hidden_size,
|
||||
rank=lora_rank,
|
||||
num_slices=num_slices,
|
||||
add_inputs=False,
|
||||
)
|
||||
|
||||
(topk_weights, sorted_token_ids, expert_ids, num_tokens_post_padded) = (
|
||||
self.fused_moe_lora_data_prepare(
|
||||
block_size=kernel_config["BLOCK_SIZE_M"],
|
||||
token_lora_mapping=self.lora_kernel_meta.token_lora_mapping,
|
||||
ctx=ctx,
|
||||
)
|
||||
)
|
||||
|
||||
return {
|
||||
"a_intermediate_cache1": self.input,
|
||||
"lora_b_stacked": self.lora_weights_lst,
|
||||
"output": self.output,
|
||||
"topk_weights": topk_weights,
|
||||
"sorted_token_ids": sorted_token_ids,
|
||||
"expert_ids": expert_ids,
|
||||
"num_tokens_post_padded": num_tokens_post_padded,
|
||||
"top_k_num": ctx.top_k_num,
|
||||
"device": self.input.device,
|
||||
"N": lora_rank,
|
||||
"M": topk_weights.shape[0],
|
||||
"EM": sorted_token_ids.shape[1],
|
||||
"K": self.input.shape[1],
|
||||
"num_tokens": num_tokens,
|
||||
"num_experts": ctx.num_experts,
|
||||
"num_slices": num_slices,
|
||||
"max_lora_rank": lora_rank,
|
||||
"w1_output_dim_size": lw_shape[2],
|
||||
"expand_block_size_m": kernel_config["BLOCK_SIZE_M"],
|
||||
"expand_block_size_n": kernel_config["BLOCK_SIZE_N"],
|
||||
"expand_block_size_k": kernel_config["BLOCK_SIZE_K"],
|
||||
"expand_group_size_m": kernel_config["GROUP_SIZE_M"],
|
||||
"expand_num_warps": kernel_config["NUM_WARPS"],
|
||||
"expand_num_stages": kernel_config["NUM_STAGES"],
|
||||
"expand_split_k": kernel_config.get("SPLIT_K", 1),
|
||||
"mul_routed_weight": op_type.is_fused_moe_lora_down_fn(),
|
||||
}
|
||||
|
||||
def bench_fn_kwargs(
|
||||
self, ctx: BenchmarkContext, op_type: OpType, add_inputs: bool | None = None
|
||||
) -> dict[str, Any]:
|
||||
if op_type.is_shrink_fn() or op_type.is_fused_moe_lora_fn():
|
||||
assert add_inputs is None
|
||||
else:
|
||||
assert add_inputs is not None
|
||||
|
||||
if op_type == OpType.LORA_SHRINK:
|
||||
return self.as_lora_shrink_kwargs()
|
||||
return self.as_lora_shrink_kwargs(ctx, op_type)
|
||||
if op_type == OpType.LORA_EXPAND:
|
||||
return self.as_lora_expand_kwargs(add_inputs)
|
||||
return self.as_lora_expand_kwargs(ctx, op_type, add_inputs)
|
||||
if op_type.is_fused_moe_lora_shrink_fn():
|
||||
return self.as_fused_moe_lora_shrink_kwargs(ctx, op_type)
|
||||
if op_type.is_fused_moe_lora_expand_fn():
|
||||
return self.as_fused_moe_lora_expand_kwargs(ctx, op_type)
|
||||
raise ValueError(f"Unrecognized optype {self}")
|
||||
|
||||
def test_correctness(
|
||||
@ -617,7 +996,7 @@ def bench_optype(
|
||||
test_correctness: bool = False,
|
||||
) -> TMeasurement:
|
||||
assert arg_pool_size >= 1
|
||||
if op_type.is_shrink_fn():
|
||||
if op_type.is_shrink_fn() or op_type.is_fused_moe_lora_fn():
|
||||
assert expand_fn_add_inputs is None
|
||||
else:
|
||||
assert expand_fn_add_inputs is not None
|
||||
@ -627,23 +1006,30 @@ def bench_optype(
|
||||
BenchmarkTensors.make(ctx, op_type) for _ in range(arg_pool_size)
|
||||
]
|
||||
for bt in bench_tensors:
|
||||
bt.sanity_check()
|
||||
bt.sanity_check(ctx, op_type)
|
||||
|
||||
# Test correctness of our implementation.
|
||||
if test_correctness:
|
||||
assert op_type in [OpType.LORA_SHRINK, OpType.LORA_EXPAND], (
|
||||
f"Correctness testing is not supported for {op_type.name}."
|
||||
)
|
||||
assert all(
|
||||
[bt.test_correctness(op_type, expand_fn_add_inputs) for bt in bench_tensors]
|
||||
[
|
||||
bt.test_correctness(ctx, op_type, expand_fn_add_inputs)
|
||||
for bt in bench_tensors
|
||||
]
|
||||
)
|
||||
|
||||
# BenchmarkTensors -> dict (kwargs)
|
||||
kwargs_list = [
|
||||
bt.bench_fn_kwargs(op_type, add_inputs=expand_fn_add_inputs)
|
||||
bt.bench_fn_kwargs(ctx, op_type, add_inputs=expand_fn_add_inputs)
|
||||
for bt in bench_tensors
|
||||
]
|
||||
|
||||
# Clear LoRA optimization hash-maps.
|
||||
_LORA_A_PTR_DICT.clear()
|
||||
_LORA_B_PTR_DICT.clear()
|
||||
_LORA_PTR_DICT.clear()
|
||||
# Run bench function so that _LORA_A_PTR_DICT and _LORA_B_PTR_DICT are set up
|
||||
for kwargs in kwargs_list:
|
||||
op_type.bench_fn()(**kwargs)
|
||||
@ -793,7 +1179,9 @@ def run(args: argparse.Namespace, bench_ctxs: list[BenchmarkContext]):
|
||||
|
||||
# Benchmark bench_op
|
||||
expand_fn_add_inputs = (
|
||||
[None] if bench_op.is_shrink_fn() else args.expand_fn_add_inputs
|
||||
[None]
|
||||
if bench_op.is_shrink_fn() or bench_op.is_fused_moe_lora_fn()
|
||||
else args.expand_fn_add_inputs
|
||||
)
|
||||
for add_input_arg in expand_fn_add_inputs:
|
||||
seq_len_timers.append(
|
||||
@ -831,12 +1219,22 @@ def as_benchmark_contexts(
|
||||
hidden_sizes: list[int], lora_ranks: list[int], args: argparse.Namespace
|
||||
) -> list[BenchmarkContext]:
|
||||
ctxs: list[BenchmarkContext] = []
|
||||
for batch_size, hidden_size, lora_rank, num_loras, sort_by_lora_id in product( # noqa
|
||||
for (
|
||||
batch_size,
|
||||
hidden_size,
|
||||
lora_rank,
|
||||
num_loras,
|
||||
sort_by_lora_id,
|
||||
top_k_num,
|
||||
num_experts,
|
||||
) in product( # noqa
|
||||
args.batch_sizes,
|
||||
list(hidden_sizes),
|
||||
lora_ranks,
|
||||
args.num_loras,
|
||||
args.sort_by_lora_id,
|
||||
args.top_k_nums,
|
||||
args.num_experts,
|
||||
):
|
||||
ctxs.append(
|
||||
BenchmarkContext(
|
||||
@ -851,6 +1249,8 @@ def as_benchmark_contexts(
|
||||
seq_length=None,
|
||||
sort_by_lora_id=sort_by_lora_id,
|
||||
dtype=args.dtype,
|
||||
top_k_num=top_k_num,
|
||||
num_experts=num_experts,
|
||||
# To be filled based on the OpType to benchmark
|
||||
num_slices=None,
|
||||
)
|
||||
@ -1012,6 +1412,22 @@ if __name__ == "__main__":
|
||||
),
|
||||
)
|
||||
|
||||
p.add_argument(
|
||||
"--top-k-nums",
|
||||
nargs="+",
|
||||
type=int,
|
||||
default=DEFAULT_TOP_K_NUMS,
|
||||
help="Top-K values for MoE LoRA operations",
|
||||
)
|
||||
|
||||
p.add_argument(
|
||||
"--num-experts",
|
||||
nargs="+",
|
||||
type=int,
|
||||
default=DEFAULT_NUM_EXPERTS,
|
||||
help="Number of experts for MoE LoRA operations",
|
||||
)
|
||||
|
||||
parser = FlexibleArgumentParser(
|
||||
description=f"""
|
||||
Benchmark LoRA kernels:
|
||||
|
||||
@@ -158,6 +158,8 @@ def use_fused_moe_lora_kernel(
        "BLOCK_SIZE_N": 32,
        "BLOCK_SIZE_K": 64,
        "GROUP_SIZE_M": 1,
        "NUM_WARPS": 4,
        "NUM_STAGES": 3,
        "SPLIT_K": 1,
    }

@@ -182,6 +184,15 @@ def use_fused_moe_lora_kernel(
        config["BLOCK_SIZE_N"],
        config["BLOCK_SIZE_K"],
        config["GROUP_SIZE_M"],
        config["NUM_WARPS"],
        config["NUM_STAGES"],
        config["SPLIT_K"],
        config["BLOCK_SIZE_M"],
        config["BLOCK_SIZE_N"],
        config["BLOCK_SIZE_K"],
        config["GROUP_SIZE_M"],
        config["NUM_WARPS"],
        config["NUM_STAGES"],
        config["SPLIT_K"],
        mul_routed_weight,
    )
@@ -13,6 +13,7 @@ from vllm.distributed.parallel_state import (
    get_tensor_model_parallel_world_size,
)
from vllm.lora.layers.base import BaseLayerWithLoRA
from vllm.lora.ops.triton_ops.utils import get_lora_op_configs
from vllm.model_executor.layers.fused_moe import FusedMoE
from vllm.model_executor.layers.fused_moe.config import (
    _get_config_dtype_str,
@@ -39,6 +40,64 @@ class FusedMoEWithLoRA(BaseLayerWithLoRA):
        self.device = base_layer.w2_weight.device
        self._inject_lora_into_fused_moe()

    def _normalize_keys(self, config: dict[str, int | None]) -> dict[str, int | None]:
        normalized_config = {}
        for key, value in config.items():
            if key.islower():
                if key.startswith("block_"):
                    normalized_key = "BLOCK_SIZE_" + key.split("_")[-1].upper()
                else:
                    normalized_key = key.upper()
            else:
                normalized_key = key
            normalized_config[normalized_key] = value
        return normalized_config
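The helper above only reconciles key spelling: the config sources used here do not agree on casing (e.g. `block_m`/`num_warps` vs. `BLOCK_SIZE_M`/`NUM_WARPS`), so lowercase keys are upper-cased and the `block_*` shorthand is expanded. A hypothetical before/after, for illustration only:

```python
# Illustration only (hypothetical input): lowercase keys are upper-cased and
# "block_*" shorthand becomes "BLOCK_SIZE_*"; already-uppercase keys pass through.
raw = {"block_m": 64, "block_n": 64, "num_warps": 4, "GROUP_SIZE_M": 8}
# _normalize_keys(raw) would yield:
# {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 64, "NUM_WARPS": 4, "GROUP_SIZE_M": 8}
```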

    def _get_lora_moe_configs(
        self,
        op_prefix: str,
        lora_a_stacked: torch.Tensor,
        lora_b_stacked: torch.Tensor,
        num_slices: int,
        M: int,
        layer: FusedMoE,
        top_k: int,
        config_dtype: str,
    ):
        if envs.VLLM_TUNED_CONFIG_FOLDER:
            shrink_config = get_lora_op_configs(
                op_type=f"fused_moe_lora_{op_prefix}_shrink",
                max_loras=lora_a_stacked.shape[0],
                batch=M,
                hidden_size=lora_a_stacked.shape[-1],
                rank=lora_a_stacked.shape[-2],
                num_slices=num_slices,
                moe_intermediate_size=lora_b_stacked.shape[-2],
            )
            expand_config = get_lora_op_configs(
                op_type=f"fused_moe_lora_{op_prefix}_expand",
                max_loras=lora_a_stacked.shape[0],
                batch=M,
                hidden_size=lora_a_stacked.shape[-1],
                rank=lora_a_stacked.shape[-2],
                num_slices=num_slices,
                moe_intermediate_size=lora_b_stacked.shape[-2],
            )
        else:  # fall back to the default config
            get_config_func = functools.partial(
                try_get_optimal_moe_config,
                layer.w13_weight.size(),
                layer.w2_weight.size(),
                top_k,
                config_dtype,
                block_shape=layer.quant_method.moe_quant_config.block_shape,
            )
            shrink_config = get_config_func(M)
            expand_config = get_config_func(M)
        shrink_config = self._normalize_keys(shrink_config)
        expand_config = self._normalize_keys(expand_config)
        return shrink_config, expand_config

def _inject_lora_into_fused_moe(self):
|
||||
moe_state_dict = {}
|
||||
top_k = self.base_layer.top_k
|
||||
@ -90,17 +149,19 @@ class FusedMoEWithLoRA(BaseLayerWithLoRA):
|
||||
num_tokens = hidden_states.size(0)
|
||||
M = min(num_tokens, CHUNK_SIZE)
|
||||
|
||||
get_config_func = functools.partial(
|
||||
try_get_optimal_moe_config,
|
||||
layer.w13_weight.size(),
|
||||
layer.w2_weight.size(),
|
||||
top_k,
|
||||
config_dtype,
|
||||
block_shape=layer.quant_method.moe_quant_config.block_shape,
|
||||
shrink_config, expand_config = self._get_lora_moe_configs(
|
||||
op_prefix="w13",
|
||||
lora_a_stacked=self.w1_lora_a_stacked,
|
||||
lora_b_stacked=self.w1_lora_b_stacked,
|
||||
num_slices=2,
|
||||
M=M,
|
||||
layer=layer,
|
||||
top_k=top_k,
|
||||
config_dtype=config_dtype,
|
||||
)
|
||||
|
||||
# get the block size of m from customized config or default config
|
||||
max_loras = self.w1_lora_a_stacked.shape[0]
|
||||
config = get_config_func(M)
|
||||
(
|
||||
sorted_token_ids_lora,
|
||||
expert_ids_lora,
|
||||
@ -108,7 +169,7 @@ class FusedMoEWithLoRA(BaseLayerWithLoRA):
|
||||
) = self.punica_wrapper.moe_lora_align_block_size(
|
||||
curr_topk_ids,
|
||||
num_tokens,
|
||||
config["BLOCK_SIZE_M"],
|
||||
shrink_config["BLOCK_SIZE_M"],
|
||||
self.base_layer.local_num_experts,
|
||||
max_loras,
|
||||
self.adapter_enabled,
|
||||
@ -138,7 +199,8 @@ class FusedMoEWithLoRA(BaseLayerWithLoRA):
|
||||
num_tokens_post_padded_lora,
|
||||
max_lora_rank,
|
||||
top_k,
|
||||
config,
|
||||
shrink_config, ## pass the shrink config
|
||||
expand_config, ## pass the expand config
|
||||
self.adapter_enabled,
|
||||
)
|
||||
|
||||
@ -164,17 +226,17 @@ class FusedMoEWithLoRA(BaseLayerWithLoRA):
|
||||
num_tokens = hidden_states.size(0)
|
||||
M = min(num_tokens, CHUNK_SIZE)
|
||||
|
||||
get_config_func = functools.partial(
|
||||
try_get_optimal_moe_config,
|
||||
layer.w13_weight.size(),
|
||||
layer.w2_weight.size(),
|
||||
top_k,
|
||||
config_dtype,
|
||||
block_shape=layer.quant_method.moe_quant_config.block_shape,
|
||||
shrink_config, expand_config = self._get_lora_moe_configs(
|
||||
op_prefix="w2",
|
||||
lora_a_stacked=self.w2_lora_a_stacked,
|
||||
lora_b_stacked=self.w2_lora_b_stacked,
|
||||
num_slices=1,
|
||||
M=M,
|
||||
layer=layer,
|
||||
top_k=top_k,
|
||||
config_dtype=config_dtype,
|
||||
)
|
||||
|
||||
config = get_config_func(M)
|
||||
|
||||
sorted_token_ids_lora = moe_state_dict["sorted_token_ids_lora"]
|
||||
expert_ids_lora = moe_state_dict["expert_ids_lora"]
|
||||
num_tokens_post_padded_lora = moe_state_dict[
|
||||
@ -197,7 +259,8 @@ class FusedMoEWithLoRA(BaseLayerWithLoRA):
|
||||
num_tokens_post_padded_lora,
|
||||
max_lora_rank,
|
||||
top_k,
|
||||
config,
|
||||
shrink_config, ## pass the shrink config
|
||||
expand_config, ## pass the expand config
|
||||
self.adapter_enabled,
|
||||
True,
|
||||
)
|
||||
|
||||
@@ -44,8 +44,17 @@ For `shrink`, the config file is named as `{gpu_name}_SHRINK.json`, e.g. `NVIDIA

For `expand`, the config file is named as `{gpu_name}_EXPAND_{add_input}.json`, e.g. `NVIDIA_H200_EXPAND_TRUE.json`.

For `fused_moe_lora_w13_shrink`, the config file is named as `{gpu_name}_FUSED_MOE_LORA_W13_SHRINK.json`, e.g. `NVIDIA_H200_FUSED_MOE_LORA_W13_SHRINK.json`.

For `fused_moe_lora_w13_expand`, the config file is named as `{gpu_name}_FUSED_MOE_LORA_W13_EXPAND.json`, e.g. `NVIDIA_H200_FUSED_MOE_LORA_W13_EXPAND.json`.

For `fused_moe_lora_w2_shrink`, the config file is named as `{gpu_name}_FUSED_MOE_LORA_W2_SHRINK.json`, e.g. `NVIDIA_H200_FUSED_MOE_LORA_W2_SHRINK.json`.

For `fused_moe_lora_w2_expand`, the config file is named as `{gpu_name}_FUSED_MOE_LORA_W2_EXPAND.json`, e.g. `NVIDIA_H200_FUSED_MOE_LORA_W2_EXPAND.json`.

The `gpu_name` can be automatically detected by calling `torch.cuda.get_device_name()`.

### Json Structure

Optimal kernel configuration files are saved as JSON files with the structure `config_data[max_loras][num_slices][m][k][n]`
Optimal kernel configuration files are saved as JSON files with the structure `config_data[max_loras][num_slices][m][k][n][i]`,
where `i` is an optional dimension in the `fused_moe_lora` configuration, representing the intermediate size of the MoE layer.
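For example, a tuned `fused_moe_lora` config file might be organized like the sketch below. Every dimension key and kernel parameter here is a hypothetical value shown only to illustrate the nesting order; the parameter names mirror the defaults in `get_lora_op_configs`.

```python
# A minimal sketch (not a shipped config) of the nesting walked by
# get_lora_op_configs() for a fused_moe_lora op. All numbers are hypothetical.
config_data = {
    "1": {                          # max_loras
        "2": {                      # num_slices
            "256": {                # m (token-count bucket)
                "4096": {           # k (hidden size)
                    "16": {         # n (LoRA rank)
                        "1408": {   # i (MoE intermediate size, fused_moe_lora only)
                            "block_m": 64,
                            "block_n": 64,
                            "block_k": 32,
                            "group_size_m": 8,
                            "num_warps": 4,
                            "num_stages": 3,
                            "split_k": 1,
                        }
                    }
                }
            }
        }
    }
}
```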
@@ -1,7 +1,12 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from vllm.lora.ops.triton_ops.fused_moe_lora_op import fused_moe_lora
from vllm.lora.ops.triton_ops.fused_moe_lora_op import (
    fused_moe_lora,
    fused_moe_lora_expand,
    fused_moe_lora_shrink,
)
from vllm.lora.ops.triton_ops.lora_expand_op import lora_expand
from vllm.lora.ops.triton_ops.lora_kernel_metadata import LoRAKernelMeta
from vllm.lora.ops.triton_ops.lora_shrink_op import lora_shrink
@@ -11,4 +16,6 @@ __all__ = [
    "lora_shrink",
    "LoRAKernelMeta",
    "fused_moe_lora",
    "fused_moe_lora_shrink",
    "fused_moe_lora_expand",
]
@ -176,88 +176,50 @@ def _fused_moe_lora_kernel(
|
||||
|
||||
|
||||
@torch.inference_mode()
|
||||
def _fused_moe_lora(
|
||||
output: torch.Tensor, # (num_tokens, top_k_num, N*len(lora_a_stacked),)
|
||||
def _fused_moe_lora_shrink(
|
||||
a_intermediate_cache1: torch.Tensor,
|
||||
# (num_slices, num_tokens, top_k_num, max_lora_rank)
|
||||
qcurr_hidden_states: torch.Tensor, # (num_tokens, K,)
|
||||
lora_a_stacked: list[
|
||||
torch.Tensor
|
||||
], # [(max_loras, num_experts, max_lora_rank, K,),...]
|
||||
lora_b_stacked: list[
|
||||
torch.Tensor
|
||||
], # [(max_loras, num_experts, N, max_lora_rank,),...]
|
||||
topk_weights: torch.Tensor, # (num_tokens, top_k_num)
|
||||
sorted_token_ids: torch.Tensor, # (max_loras, _)
|
||||
expert_ids: torch.Tensor, # (max_loras, _ ,)
|
||||
num_tokens_post_padded: torch.Tensor, # (max_loras, )
|
||||
max_lora_rank: int,
|
||||
top_k_num: int,
|
||||
lora_ids: torch.Tensor,
|
||||
adapter_enabled: torch.Tensor,
|
||||
## adding for kernel
|
||||
device: torch.device,
|
||||
N: int,
|
||||
M: int,
|
||||
EM: int,
|
||||
K: int,
|
||||
num_tokens: int,
|
||||
num_experts: int,
|
||||
num_slices: int,
|
||||
block_size_m: int,
|
||||
block_size_n: int,
|
||||
block_size_k: int,
|
||||
group_size_m: int,
|
||||
num_warps: int,
|
||||
num_stages: int,
|
||||
split_k: int,
|
||||
mul_routed_weight: bool = False,
|
||||
) -> None:
|
||||
assert len(lora_a_stacked) == len(lora_b_stacked) > 0
|
||||
assert (
|
||||
sorted_token_ids.dim()
|
||||
== expert_ids.dim()
|
||||
== topk_weights.dim()
|
||||
== qcurr_hidden_states.dim()
|
||||
== 2
|
||||
)
|
||||
assert (
|
||||
sorted_token_ids.shape[0]
|
||||
== expert_ids.shape[0]
|
||||
== num_tokens_post_padded.shape[0]
|
||||
)
|
||||
assert len(lora_b_stacked) * lora_b_stacked[0].shape[-2] == output.shape[-1]
|
||||
assert output.shape[0] == topk_weights.shape[0]
|
||||
assert top_k_num == topk_weights.shape[1]
|
||||
w1_lora_a_stacked = lora_a_stacked[0]
|
||||
|
||||
for lora_a, lora_b in zip(lora_a_stacked, lora_b_stacked):
|
||||
assert lora_a.dtype == lora_b.dtype == output.dtype == qcurr_hidden_states.dtype
|
||||
assert lora_a.dtype in [torch.float16, torch.bfloat16]
|
||||
|
||||
device = qcurr_hidden_states.device
|
||||
num_slices = len(lora_a_stacked)
|
||||
|
||||
config = {
|
||||
shrink_config = {
|
||||
"BLOCK_SIZE_M": block_size_m,
|
||||
"BLOCK_SIZE_N": block_size_n,
|
||||
"BLOCK_SIZE_K": block_size_k,
|
||||
"GROUP_SIZE_M": group_size_m,
|
||||
"num_warps": num_warps,
|
||||
"num_stages": num_stages,
|
||||
"SPLIT_K": split_k,
|
||||
}
|
||||
|
||||
w1_lora_a_stacked = lora_a_stacked[0]
|
||||
w1_lora_b_stacked = lora_b_stacked[0]
|
||||
num_experts = lora_a_stacked[0].shape[1]
|
||||
|
||||
N = max_lora_rank
|
||||
M = topk_weights.shape[0]
|
||||
EM = sorted_token_ids.shape[1]
|
||||
K = qcurr_hidden_states.shape[1]
|
||||
num_tokens = M * top_k_num
|
||||
w1_output_dim_size = w1_lora_b_stacked.shape[2]
|
||||
|
||||
lora_intermediate_cache1 = torch.zeros(
|
||||
(num_slices * M * top_k_num * (max_lora_rank + w1_output_dim_size)),
|
||||
dtype=output.dtype,
|
||||
device=device,
|
||||
)
|
||||
|
||||
# slices
|
||||
a_intermediate_size = num_slices * M * top_k_num * max_lora_rank
|
||||
a_intermediate_cache1 = lora_intermediate_cache1[:a_intermediate_size].view(
|
||||
num_slices, M, top_k_num, max_lora_rank
|
||||
)
|
||||
b_intermediate_cache1 = lora_intermediate_cache1[a_intermediate_size:].view(
|
||||
num_slices, M, top_k_num, w1_output_dim_size
|
||||
)
|
||||
|
||||
b_ptr = _get_ptr(lora_a_stacked, device)
|
||||
|
||||
grid = lambda META: (
|
||||
@ -299,19 +261,70 @@ def _fused_moe_lora(
|
||||
num_slice_c=num_slices,
|
||||
top_k=1 if mul_routed_weight else top_k_num,
|
||||
MUL_ROUTED_WEIGHT=False,
|
||||
**config,
|
||||
**shrink_config,
|
||||
)
|
||||
|
||||
|
||||
@torch.inference_mode()
|
||||
def _fused_moe_lora_expand(
|
||||
output: torch.Tensor, # (num_tokens, top_k_num, N*len(lora_a_stacked),)
|
||||
a_intermediate_cache1: torch.Tensor, # (num_slices, M, top_k_num, max_lora_rank)
|
||||
lora_b_stacked: list[
|
||||
torch.Tensor
|
||||
], # [(max_loras, num_experts, max_lora_rank, K,),...]
|
||||
topk_weights: torch.Tensor, # (num_tokens, top_k_num)
|
||||
sorted_token_ids: torch.Tensor, # (max_loras, _)
|
||||
expert_ids: torch.Tensor, # (max_loras, _ ,)
|
||||
num_tokens_post_padded: torch.Tensor, # (max_loras, )
|
||||
top_k_num: int,
|
||||
lora_ids: torch.Tensor,
|
||||
adapter_enabled: torch.Tensor,
|
||||
## adding for kernel
|
||||
device: torch.device,
|
||||
N: int,
|
||||
M: int,
|
||||
EM: int,
|
||||
K: int,
|
||||
num_tokens: int,
|
||||
num_experts: int,
|
||||
num_slices: int,
|
||||
max_lora_rank: int,
|
||||
w1_output_dim_size: int,
|
||||
block_size_m: int,
|
||||
block_size_n: int,
|
||||
block_size_k: int,
|
||||
group_size_m: int,
|
||||
num_warps: int,
|
||||
num_stages: int,
|
||||
split_k: int,
|
||||
mul_routed_weight: bool = False,
|
||||
) -> None:
|
||||
b_ptr = _get_ptr(lora_b_stacked, device)
|
||||
K = max_lora_rank
|
||||
N = w1_output_dim_size
|
||||
|
||||
w1_lora_b_stacked = lora_b_stacked[0]
|
||||
|
||||
a_intermediate_cache1 = a_intermediate_cache1.view(
|
||||
-1, a_intermediate_cache1.shape[3]
|
||||
)
|
||||
|
||||
# Set split_k = 1 for expand calls
|
||||
config["SPLIT_K"] = 1
|
||||
b_intermediate_cache1 = torch.zeros(
|
||||
(num_slices, M, top_k_num, w1_output_dim_size),
|
||||
dtype=output.dtype,
|
||||
device=device,
|
||||
)
|
||||
|
||||
expand_config = {
|
||||
"BLOCK_SIZE_M": block_size_m,
|
||||
"BLOCK_SIZE_N": block_size_n,
|
||||
"BLOCK_SIZE_K": block_size_k,
|
||||
"GROUP_SIZE_M": group_size_m,
|
||||
"num_warps": num_warps,
|
||||
"num_stages": num_stages,
|
||||
"SPLIT_K": split_k, # Set split_k = 1 for expand calls
|
||||
}
|
||||
|
||||
grid = lambda META: (
|
||||
triton.cdiv(EM, META["BLOCK_SIZE_M"]) * triton.cdiv(N, META["BLOCK_SIZE_N"]),
|
||||
len(lora_b_stacked),
|
||||
@ -348,12 +361,142 @@ def _fused_moe_lora(
|
||||
num_slice_c=num_slices,
|
||||
top_k=1,
|
||||
MUL_ROUTED_WEIGHT=mul_routed_weight,
|
||||
**config,
|
||||
**expand_config,
|
||||
)
|
||||
for i in range(num_slices):
|
||||
output[:, :, i * N : (i + 1) * N] += b_intermediate_cache1[i]
|
||||
|
||||
|
||||
@torch.inference_mode()
|
||||
def _fused_moe_lora(
|
||||
output: torch.Tensor, # (num_tokens, top_k_num, N*len(lora_a_stacked),)
|
||||
qcurr_hidden_states: torch.Tensor, # (num_tokens, K,)
|
||||
lora_a_stacked: list[
|
||||
torch.Tensor
|
||||
], # [(max_loras, num_experts, max_lora_rank, K,),...]
|
||||
lora_b_stacked: list[
|
||||
torch.Tensor
|
||||
], # [(max_loras, num_experts, N, max_lora_rank,),...]
|
||||
topk_weights: torch.Tensor, # (num_tokens, top_k_num)
|
||||
sorted_token_ids: torch.Tensor, # (max_loras, _)
|
||||
expert_ids: torch.Tensor, # (max_loras, _ ,)
|
||||
num_tokens_post_padded: torch.Tensor, # (max_loras, )
|
||||
max_lora_rank: int,
|
||||
top_k_num: int,
|
||||
lora_ids: torch.Tensor,
|
||||
adapter_enabled: torch.Tensor,
|
||||
shrink_block_size_m: int,
|
||||
shrink_block_size_n: int,
|
||||
shrink_block_size_k: int,
|
||||
shrink_group_size_m: int,
|
||||
shrink_num_warps: int,
|
||||
shrink_num_stages: int,
|
||||
shrink_split_k: int,
|
||||
expand_block_size_m: int,
|
||||
expand_block_size_n: int,
|
||||
expand_block_size_k: int,
|
||||
expand_group_size_m: int,
|
||||
expand_num_warps: int,
|
||||
expand_num_stages: int,
|
||||
expand_split_k: int,
|
||||
mul_routed_weight: bool = False,
|
||||
) -> None:
|
||||
assert len(lora_a_stacked) == len(lora_b_stacked) > 0
|
||||
assert (
|
||||
sorted_token_ids.dim()
|
||||
== expert_ids.dim()
|
||||
== topk_weights.dim()
|
||||
== qcurr_hidden_states.dim()
|
||||
== 2
|
||||
)
|
||||
assert (
|
||||
sorted_token_ids.shape[0]
|
||||
== expert_ids.shape[0]
|
||||
== num_tokens_post_padded.shape[0]
|
||||
)
|
||||
assert len(lora_b_stacked) * lora_b_stacked[0].shape[-2] == output.shape[-1]
|
||||
assert output.shape[0] == topk_weights.shape[0]
|
||||
assert top_k_num == topk_weights.shape[1]
|
||||
device = qcurr_hidden_states.device
|
||||
num_slices = len(lora_a_stacked)
|
||||
w1_lora_b_stacked = lora_b_stacked[0]
|
||||
num_experts = lora_a_stacked[0].shape[1]
|
||||
N = max_lora_rank
|
||||
M = topk_weights.shape[0]
|
||||
EM = sorted_token_ids.shape[1]
|
||||
K = qcurr_hidden_states.shape[1]
|
||||
num_tokens = M * top_k_num
|
||||
w1_output_dim_size = w1_lora_b_stacked.shape[2]
|
||||
|
||||
a_intermediate_cache1 = torch.zeros(
|
||||
(num_slices, M, top_k_num, max_lora_rank),
|
||||
dtype=output.dtype,
|
||||
device=device,
|
||||
)
|
||||
|
||||
_fused_moe_lora_shrink(
|
||||
a_intermediate_cache1,
|
||||
qcurr_hidden_states,
|
||||
lora_a_stacked,
|
||||
topk_weights,
|
||||
sorted_token_ids,
|
||||
expert_ids,
|
||||
num_tokens_post_padded,
|
||||
top_k_num,
|
||||
lora_ids,
|
||||
adapter_enabled,
|
||||
## adding for kernel
|
||||
device,
|
||||
N,
|
||||
M,
|
||||
EM,
|
||||
K,
|
||||
num_tokens,
|
||||
num_experts,
|
||||
num_slices,
|
||||
shrink_block_size_m,
|
||||
shrink_block_size_n,
|
||||
shrink_block_size_k,
|
||||
shrink_group_size_m,
|
||||
shrink_num_warps,
|
||||
shrink_num_stages,
|
||||
shrink_split_k,
|
||||
mul_routed_weight,
|
||||
)
|
||||
|
||||
_fused_moe_lora_expand(
|
||||
output,
|
||||
a_intermediate_cache1,
|
||||
lora_b_stacked,
|
||||
topk_weights,
|
||||
sorted_token_ids,
|
||||
expert_ids,
|
||||
num_tokens_post_padded,
|
||||
top_k_num,
|
||||
lora_ids,
|
||||
adapter_enabled,
|
||||
## adding for kernel
|
||||
device,
|
||||
N,
|
||||
M,
|
||||
EM,
|
||||
K,
|
||||
num_tokens,
|
||||
num_experts,
|
||||
num_slices,
|
||||
max_lora_rank,
|
||||
w1_output_dim_size,
|
||||
expand_block_size_m,
|
||||
expand_block_size_n,
|
||||
expand_block_size_k,
|
||||
expand_group_size_m,
|
||||
expand_num_warps,
|
||||
expand_num_stages,
|
||||
expand_split_k,
|
||||
mul_routed_weight,
|
||||
)
|
||||
|
||||
|
||||
def _fused_moe_lora_fake(
|
||||
output: torch.Tensor,
|
||||
qcurr_hidden_states: torch.Tensor,
|
||||
@ -367,10 +510,84 @@ def _fused_moe_lora_fake(
|
||||
top_k_num: int,
|
||||
lora_ids: torch.Tensor,
|
||||
adapter_enabled: torch.Tensor,
|
||||
shrink_block_size_m: int,
|
||||
shrink_block_size_n: int,
|
||||
shrink_block_size_k: int,
|
||||
shrink_group_size_m: int,
|
||||
shrink_num_warps: int,
|
||||
shrink_num_stages: int,
|
||||
shrink_split_k: int,
|
||||
expand_block_size_m: int,
|
||||
expand_block_size_n: int,
|
||||
expand_block_size_k: int,
|
||||
expand_group_size_m: int,
|
||||
expand_num_warps: int,
|
||||
expand_num_stages: int,
|
||||
expand_split_k: int,
|
||||
mul_routed_weight: bool = False,
|
||||
) -> None:
|
||||
return
|
||||
|
||||
|
||||
def _fused_moe_lora_shrink_fake(
|
||||
a_intermediate_cache1: torch.Tensor,
|
||||
qcurr_hidden_states: torch.Tensor,
|
||||
lora_a_stacked: list[torch.Tensor],
|
||||
topk_weights: torch.Tensor,
|
||||
sorted_token_ids: torch.Tensor,
|
||||
expert_ids: torch.Tensor,
|
||||
num_tokens_post_padded: torch.Tensor,
|
||||
top_k_num: int,
|
||||
lora_ids: torch.Tensor,
|
||||
adapter_enabled: torch.Tensor,
|
||||
device: torch.device,
|
||||
N: int,
|
||||
M: int,
|
||||
EM: int,
|
||||
K: int,
|
||||
num_tokens: int,
|
||||
num_experts: int,
|
||||
num_slices: int,
|
||||
block_size_m: int,
|
||||
block_size_n: int,
|
||||
block_size_k: int,
|
||||
group_size_m: int,
|
||||
num_warps: int,
|
||||
num_stages: int,
|
||||
split_k: int,
|
||||
mul_routed_weight: bool = False,
|
||||
) -> None:
|
||||
return
|
||||
|
||||
|
||||
def _fused_moe_lora_expand_fake(
|
||||
output: torch.Tensor,
|
||||
a_intermediate_cache1: torch.Tensor,
|
||||
lora_b_stacked: list[torch.Tensor],
|
||||
topk_weights: torch.Tensor,
|
||||
sorted_token_ids: torch.Tensor,
|
||||
expert_ids: torch.Tensor,
|
||||
num_tokens_post_padded: torch.Tensor,
|
||||
top_k_num: int,
|
||||
lora_ids: torch.Tensor,
|
||||
adapter_enabled: torch.Tensor,
|
||||
device: torch.device,
|
||||
N: int,
|
||||
M: int,
|
||||
EM: int,
|
||||
K: int,
|
||||
num_tokens: int,
|
||||
num_experts: int,
|
||||
num_slices: int,
|
||||
max_lora_rank: int,
|
||||
w1_output_dim_size: int,
|
||||
block_size_m: int,
|
||||
block_size_n: int,
|
||||
block_size_k: int,
|
||||
group_size_m: int,
|
||||
num_warps: int,
|
||||
num_stages: int,
|
||||
split_k: int,
|
||||
mul_routed_weight: bool = False,
|
||||
) -> None:
|
||||
return
|
||||
@ -383,7 +600,26 @@ try:
|
||||
mutates_args=["output"],
|
||||
fake_impl=_fused_moe_lora_fake,
|
||||
)
|
||||
|
||||
direct_register_custom_op(
|
||||
op_name="fused_moe_lora_shrink",
|
||||
op_func=_fused_moe_lora_shrink,
|
||||
mutates_args=["a_intermediate_cache1"],
|
||||
fake_impl=_fused_moe_lora_shrink_fake,
|
||||
)
|
||||
|
||||
direct_register_custom_op(
|
||||
op_name="fused_moe_lora_expand",
|
||||
op_func=_fused_moe_lora_expand,
|
||||
mutates_args=["output"],
|
||||
fake_impl=_fused_moe_lora_expand_fake,
|
||||
)
|
||||
|
||||
fused_moe_lora = torch.ops.vllm.fused_moe_lora
|
||||
fused_moe_lora_shrink = torch.ops.vllm.fused_moe_lora_shrink
|
||||
fused_moe_lora_expand = torch.ops.vllm.fused_moe_lora_expand
|
||||
|
||||
except AttributeError:
|
||||
fused_moe_lora = _fused_moe_lora
|
||||
fused_moe_lora_shrink = _fused_moe_lora_shrink
|
||||
fused_moe_lora_expand = _fused_moe_lora_expand
|
||||
|
||||
@@ -154,13 +154,13 @@ def load_lora_op_config(op_type: str, add_inputs: bool | None) -> dict | None:
    gpu_name = gpu_name.replace("-", "_")

    config_fname = None
    if op_type == "shrink":
        config_fname = f"{gpu_name}_{op_type.upper()}.json"
    else:
        assert op_type == "expand"
    # only expand op needs to consider add_inputs
    if op_type == "expand":
        config_fname = (
            f"{gpu_name}_{op_type.upper()}_{str(add_inputs).upper()}.json"
        )
    else:
        config_fname = f"{gpu_name}_{op_type.upper()}.json"

    config_path = Path(f"{user_defined_config_folder}/{config_fname}")
    if not config_path.exists():
@@ -186,8 +186,17 @@ def get_lora_op_configs(
    rank: int,
    num_slices: int,
    add_inputs: bool | None = None,
    moe_intermediate_size: int | None = None,
) -> dict[str, int | None]:
    assert op_type in ["shrink", "expand"]
    # Add support for fused_moe_lora ops
    assert op_type in [
        "shrink",
        "expand",
        "fused_moe_lora_w13_shrink",
        "fused_moe_lora_w13_expand",
        "fused_moe_lora_w2_shrink",
        "fused_moe_lora_w2_expand",
    ]

    # default config
    default = {}
@@ -203,6 +212,22 @@ def get_lora_op_configs(
            "num_stages": 2,
            "max_nreg": None,
        }
    # The default config for fused_moe_lora ops
    elif op_type in [
        "fused_moe_lora_w13_shrink",
        "fused_moe_lora_w13_expand",
        "fused_moe_lora_w2_shrink",
        "fused_moe_lora_w2_expand",
    ]:
        default = {
            "block_m": 64,
            "block_n": 64,
            "block_k": 32,
            "num_warps": 4,
            "num_stages": 3,
            "group_size_m": 8,
            "split_k": 1,
        }
    else:
        default = {
            "block_m": 64,
@@ -247,5 +272,13 @@ def get_lora_op_configs(
        or config_data[min(config_data.keys(), key=lambda x: abs(int(x) - n))]
    )

    # slice by moe-intermediate-size if applicable
    if moe_intermediate_size is not None:
        i = moe_intermediate_size
        config_data = (
            config_data.get(str(i))
            or config_data[min(config_data.keys(), key=lambda x: abs(int(x) - i))]
        )

    assert config_data is not None
    return config_data
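As a usage sketch (all shapes and values below are hypothetical, not tuned results), the fused MoE LoRA layer can now request the shrink and expand configs of one projection separately, which is what `FusedMoEWithLoRA._get_lora_moe_configs` does when `VLLM_TUNED_CONFIG_FOLDER` is set:

```python
# Hedged sketch of the separate shrink/expand lookup for the w13 (gate/up)
# projection; numeric arguments are placeholders for the stacked LoRA shapes.
from vllm.lora.ops.triton_ops.utils import get_lora_op_configs

shrink_config = get_lora_op_configs(
    op_type="fused_moe_lora_w13_shrink",
    max_loras=1,                 # lora_a_stacked.shape[0]
    batch=256,                   # M (chunked token count)
    hidden_size=4096,            # lora_a_stacked.shape[-1]
    rank=16,                     # lora_a_stacked.shape[-2]
    num_slices=2,                # gate + up projections
    moe_intermediate_size=1408,  # lora_b_stacked.shape[-2]
)
expand_config = get_lora_op_configs(
    op_type="fused_moe_lora_w13_expand",
    max_loras=1,
    batch=256,
    hidden_size=4096,
    rank=16,
    num_slices=2,
    moe_intermediate_size=1408,
)
```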
@@ -479,7 +479,8 @@ class PunicaWrapperBase(PunicaWrapperABC):
        num_tokens_post_padded: torch.Tensor,
        max_lora_rank: int,
        top_k_num: int,
        config,
        shrink_config,
        expand_config,
        adapter_enabled: torch.Tensor,
        mul_routed_weight=False,
    ):

@@ -367,7 +367,8 @@ class PunicaWrapperGPU(PunicaWrapperBase):
        num_tokens_post_padded: torch.Tensor,
        max_lora_rank: int,
        top_k_num: int,
        config,
        shrink_config,
        expand_config,
        adapter_enabled: torch.Tensor,
        mul_routed_weight=False,
    ):
@@ -388,10 +389,19 @@ class PunicaWrapperGPU(PunicaWrapperBase):
            top_k_num,
            lora_ids,
            adapter_enabled,
            config["BLOCK_SIZE_M"],
            config["BLOCK_SIZE_N"],
            config["BLOCK_SIZE_K"],
            config["GROUP_SIZE_M"],
            config.get("SPLIT_K", 1),
            shrink_config.get("BLOCK_SIZE_M", 64),
            shrink_config.get("BLOCK_SIZE_N", 64),
            shrink_config.get("BLOCK_SIZE_K", 32),
            shrink_config.get("GROUP_SIZE_M", 8),
            shrink_config.get("NUM_WARPS", 4),
            shrink_config.get("NUM_STAGES", 3),
            shrink_config.get("SPLIT_K", 1),
            expand_config.get("BLOCK_SIZE_M", 64),
            expand_config.get("BLOCK_SIZE_N", 64),
            expand_config.get("BLOCK_SIZE_K", 32),
            expand_config.get("GROUP_SIZE_M", 8),
            expand_config.get("NUM_WARPS", 4),
            expand_config.get("NUM_STAGES", 3),
            expand_config.get("SPLIT_K", 1),
            mul_routed_weight,
        )