Mirror of https://git.datalinker.icu/vllm-project/vllm.git (synced 2025-12-10 04:05:01 +08:00)

[Kernel] W8A16 Int8 inside FusedMoE (#7415)
Commit: 7fc23be81c
Parent: e837b624f2
@@ -30,19 +30,36 @@ def benchmark_config(
     hidden_size: int,
     topk: int,
     dtype: torch.dtype,
-    use_fp8: bool,
+    use_fp8_w8a8: bool,
+    use_int8_w8a16: bool,
     num_iters: int = 100,
 ) -> float:
-    init_dtype = torch.float16 if use_fp8 else dtype
+    init_dtype = torch.float16 if use_fp8_w8a8 else dtype
     x = torch.randn(num_tokens, hidden_size, dtype=dtype)
-    w1 = torch.randn(num_experts,
-                     shard_intermediate_size,
-                     hidden_size,
-                     dtype=init_dtype)
-    w2 = torch.randn(num_experts,
-                     hidden_size,
-                     shard_intermediate_size // 2,
-                     dtype=init_dtype)
+    if use_int8_w8a16:
+        w1 = torch.randint(-127,
+                           127, (
+                               num_experts,
+                               shard_intermediate_size,
+                               hidden_size,
+                           ),
+                           dtype=torch.int8)
+        w2 = torch.randint(-127,
+                           127, (
+                               num_experts,
+                               hidden_size,
+                               shard_intermediate_size // 2,
+                           ),
+                           dtype=torch.int8)
+    else:
+        w1 = torch.randn(num_experts,
+                         shard_intermediate_size,
+                         hidden_size,
+                         dtype=init_dtype)
+        w2 = torch.randn(num_experts,
+                         hidden_size,
+                         shard_intermediate_size // 2,
+                         dtype=init_dtype)
     gating_output = torch.randn(num_iters,
                                 num_tokens,
                                 num_experts,
@@ -52,7 +69,11 @@ def benchmark_config(
     w2_scale = None
     a1_scale = None
     a2_scale = None
-    if use_fp8:
+    if use_int8_w8a16:
+        w1_scale = torch.randn((num_experts, 2 * shard_intermediate_size),
+                               dtype=torch.float32)
+        w2_scale = torch.randn((hidden_size, num_experts), dtype=torch.float32)
+    if use_fp8_w8a8:
         w1_scale = torch.randn(num_experts, dtype=torch.float32)
         w2_scale = torch.randn(num_experts, dtype=torch.float32)
         a1_scale = torch.randn(1, dtype=torch.float32)
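The int8 branch above pairs torch.int8 weights with float32 per-output-channel scales; W8A16 keeps the activation in 16/32-bit, multiplies against the widened int8 weight, and applies the scale afterwards. A minimal, self-contained sketch of that dequantization contract (sizes here are illustrative, not the benchmark's):

import torch

# Illustrative sizes (hypothetical; a real MoE shard is far larger).
num_experts, out_ch, in_ch = 4, 8, 16

w_int8 = torch.randint(-127, 127, (num_experts, out_ch, in_ch), dtype=torch.int8)
w_scale = torch.rand(num_experts, out_ch, dtype=torch.float32)  # one scale per expert per output channel
x = torch.randn(in_ch, dtype=torch.float32)                     # activation stays high precision

e = 2  # any expert index
# Matmul on the widened int8 weight, then scale the output channels...
y = (x @ w_int8[e].t().float()) * w_scale[e]
# ...which matches dequantizing the weight up front.
w_deq = w_int8[e].float() * w_scale[e][:, None]
assert torch.allclose(y, x @ w_deq.t(), rtol=1e-4, atol=1e-3)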
@@ -76,7 +97,8 @@ def benchmark_config(
             renormalize=True,
             inplace=True,
             override_config=config,
-            use_fp8=use_fp8,
+            use_fp8_w8a8=use_fp8_w8a8,
+            use_int8_w8a16=use_int8_w8a16,
             w1_scale=w1_scale,
             w2_scale=w2_scale,
             a1_scale=a1_scale,
@@ -155,11 +177,13 @@ class BenchmarkWorker:
         hidden_size: int,
         topk: int,
         dtype: torch.dtype,
-        use_fp8: bool,
+        use_fp8_w8a8: bool,
+        use_int8_w8a16: bool,
     ) -> Tuple[Dict[str, int], float]:
         torch.cuda.manual_seed_all(self.seed)
-
-        dtype_str = "float8" if use_fp8 else None
+        dtype_str = get_config_dtype_str(dtype,
+                                         use_int8_w8a16=use_int8_w8a16,
+                                         use_fp8_w8a8=use_fp8_w8a8)
         # NOTE(woosuk): The current naming convention uses w2.shape[2], which
         # is the intermediate size after silu_and_mul.
         op_config = get_moe_configs(num_experts, shard_intermediate_size // 2,
@@ -173,7 +197,8 @@ class BenchmarkWorker:
                 key=lambda x: abs(x - num_tokens))]
         kernel_time = benchmark_config(config, num_tokens, num_experts,
                                        shard_intermediate_size, hidden_size,
-                                       topk, dtype, use_fp8)
+                                       topk, dtype, use_fp8_w8a8,
+                                       use_int8_w8a16)
         return config, kernel_time

     def tune(
@@ -184,9 +209,10 @@ class BenchmarkWorker:
         hidden_size: int,
         topk: int,
         dtype: torch.dtype,
-        use_fp8: bool,
-        search_space: List[BenchmarkConfig],
-    ) -> BenchmarkConfig:
+        use_fp8_w8a8: bool,
+        use_int8_w8a16: bool,
+        search_space: List[Dict[str, int]],
+    ) -> Dict[str, int]:
         best_config = None
         best_time = float("inf")
         for config in tqdm(search_space):
@@ -198,7 +224,8 @@ class BenchmarkWorker:
                                                hidden_size,
                                                topk,
                                                dtype,
-                                               use_fp8,
+                                               use_fp8_w8a8,
+                                               use_int8_w8a16,
                                                num_iters=10)
             except triton.runtime.autotuner.OutOfResources:
                 # Some configurations may be invalid and fail to compile.
@@ -224,20 +251,19 @@ def sort_config(config: BenchmarkConfig) -> BenchmarkConfig:
     }


-def save_configs(
-    configs: Dict[int, BenchmarkConfig],
-    num_experts: int,
-    shard_intermediate_size: int,
-    hidden_size: int,
-    topk: int,
-    dtype: torch.dtype,
-    use_fp8: bool,
-) -> None:
-    dtype_str = "float8" if use_fp8 else None
+def save_configs(configs: Dict[int, BenchmarkConfig], num_experts: int,
+                 shard_intermediate_size: int, hidden_size: int, topk: int,
+                 dtype: torch.dtype, use_fp8_w8a8: bool,
+                 use_int8_w8a16: bool) -> None:
+    dtype_str = get_config_dtype_str(dtype,
+                                     use_int8_w8a16=use_int8_w8a16,
+                                     use_fp8_w8a8=use_fp8_w8a8)
+
     # NOTE(woosuk): The current naming convention uses w2.shape[2], which
     # is the intermediate size after silu_and_mul.
     filename = get_config_file_name(num_experts, shard_intermediate_size // 2,
                                     dtype_str)
+
     print(f"Writing best config to {filename}...")
     with open(filename, "w") as f:
         json.dump(configs, f, indent=4)
@@ -253,6 +279,11 @@ def main(args: argparse.Namespace):
         topk = config.ffn_config.moe_top_k
         intermediate_size = config.ffn_config.ffn_hidden_size
         shard_intermediate_size = 2 * intermediate_size // args.tp_size
+    elif config.architectures[0] == "JambaForCausalLM":
+        E = config.num_experts
+        topk = config.num_experts_per_tok
+        intermediate_size = config.intermediate_size
+        shard_intermediate_size = 2 * intermediate_size // args.tp_size
     else:
         # Default: Mixtral.
         E = config.num_local_experts
@@ -262,7 +293,8 @@ def main(args: argparse.Namespace):

     hidden_size = config.hidden_size
     dtype = config.torch_dtype
-    use_fp8 = args.dtype == "fp8"
+    use_fp8_w8a8 = args.dtype == "fp8_w8a8"
+    use_int8_w8a16 = args.dtype == "int8_w8a16"

     if args.batch_size is None:
         batch_sizes = [
@@ -294,21 +326,21 @@ def main(args: argparse.Namespace):
         start = time.time()
         configs = _distribute(
             "tune", [(batch_size, E, shard_intermediate_size, hidden_size,
-                      topk, dtype, use_fp8, search_space)
+                      topk, dtype, use_fp8_w8a8, use_int8_w8a16, search_space)
                      for batch_size in batch_sizes])
         best_configs = {
             M: sort_config(config)
             for M, config in zip(batch_sizes, configs)
         }
         save_configs(best_configs, E, shard_intermediate_size, hidden_size,
-                     topk, dtype, use_fp8)
+                     topk, dtype, use_fp8_w8a8, use_int8_w8a16)
         end = time.time()
         print(f"Tuning took {end - start:.2f} seconds")
     else:
-        outputs = _distribute("benchmark",
-                              [(batch_size, E, shard_intermediate_size,
-                                hidden_size, topk, dtype, use_fp8)
+        outputs = _distribute(
+            "benchmark", [(batch_size, E, shard_intermediate_size, hidden_size,
+                           topk, dtype, use_fp8_w8a8, use_int8_w8a16)
                           for batch_size in batch_sizes])

         for batch_size, (config, kernel_time) in zip(batch_sizes, outputs):
             print(f"Batch size: {batch_size}, config: {config}")
@@ -323,7 +355,7 @@ if __name__ == "__main__":
     parser.add_argument("--tp-size", "-tp", type=int, default=2)
    parser.add_argument("--dtype",
                        type=str,
-                        choices=["auto", "fp8"],
+                        choices=["auto", "fp8_w8a8", "int8_w8a16"],
                        default="auto")
    parser.add_argument("--seed", type=int, default=0)
    parser.add_argument("--batch-size", type=int, required=False)
@@ -6,9 +6,12 @@ from vllm.worker.model_runner import _get_graph_batch_size
 MODELS = ["ai21labs/Jamba-tiny-random"]


+# Fails due to usage of MoE as MLP(E=1_, which is different than the HF impl
+# TODO: Fix this with trained model
+@pytest.mark.skip()
 @pytest.mark.parametrize("model", MODELS)
-@pytest.mark.parametrize("dtype", ["float"])
-@pytest.mark.parametrize("max_tokens", [20])
+@pytest.mark.parametrize("dtype", ["bfloat16"])
+@pytest.mark.parametrize("max_tokens", [10])
 def test_models(
     hf_runner,
     vllm_runner,
@@ -17,8 +20,6 @@ def test_models(
     dtype: str,
     max_tokens: int,
 ) -> None:
-    # To pass the small model tests, we need full precision.
-    assert dtype == "float"

     with hf_runner(model, dtype=dtype) as hf_model:
         hf_outputs = hf_model.generate_greedy(example_prompts, max_tokens)
@@ -36,8 +37,8 @@ def test_models(


 @pytest.mark.parametrize("model", MODELS)
-@pytest.mark.parametrize("dtype", ["float"])
-@pytest.mark.parametrize("max_tokens", [20])
+@pytest.mark.parametrize("dtype", ["half"])
+@pytest.mark.parametrize("max_tokens", [5])
 def test_batching(
     vllm_runner,
     example_prompts,
tests/quantization/test_experts_int8.py (new file, 28 lines)
@@ -0,0 +1,28 @@
+# flake8: noqa
+"""Tests experts_int8 quantization startup and generation,
+doesn't test correctness
+"""
+import pytest
+
+from tests.quantization.utils import is_quant_method_supported
+
+MODELS = ["ai21labs/Jamba-tiny-random"]
+
+
+@pytest.mark.skipif(not is_quant_method_supported("experts_int8"),
+                    reason="ExpertsInt8 is not supported on this GPU type.")
+@pytest.mark.parametrize("model", MODELS)
+@pytest.mark.parametrize("dtype", ["bfloat16"])
+@pytest.mark.parametrize("max_tokens", [10])
+def test_model_experts_int8_startup(
+    hf_runner,
+    vllm_runner,
+    example_prompts,
+    model: str,
+    dtype: str,
+    max_tokens: int,
+) -> None:
+
+    with vllm_runner(model, dtype=dtype,
+                     quantization="experts_int8") as vllm_model:
+        vllm_model.generate_greedy(example_prompts, max_tokens)
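For end-to-end use, the same switch appears on the engine constructor. A hedged sketch mirroring the test above (the model name is reused from the test; a CUDA GPU that passes the ExpertsInt8 capability check is assumed):

from vllm import LLM, SamplingParams

# Assumes a GPU supported by the experts_int8 method.
llm = LLM(model="ai21labs/Jamba-tiny-random",
          dtype="bfloat16",
          quantization="experts_int8")
outputs = llm.generate(["The capital of France is"],
                       SamplingParams(temperature=0.0, max_tokens=10))
print(outputs[0].outputs[0].text)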
@@ -243,7 +243,8 @@ class ModelConfig:
         rocm_supported_quantization = ["gptq", "squeezellm", "fp8"]
         optimized_quantization_methods = [
             "fp8", "marlin", "gptq_marlin_24", "gptq_marlin", "awq_marlin",
-            "fbgemm_fp8", "compressed_tensors", "compressed-tensors"
+            "fbgemm_fp8", "compressed_tensors", "compressed-tensors",
+            "experts_int8"
         ]
         tpu_supported_quantization = ["tpu_int8"]
         if self.quantization is not None:
@@ -17,42 +17,44 @@ logger = init_logger(__name__)

 @triton.jit
 def fused_moe_kernel(
     # Pointers to matrices
     a_ptr,
     b_ptr,
     c_ptr,
     a_scale_ptr,
     b_scale_ptr,
     topk_weights_ptr,
     sorted_token_ids_ptr,
     expert_ids_ptr,
     num_tokens_post_padded_ptr,
     # Matrix dimensions
     N,
     K,
     EM,
     num_valid_tokens,
     # The stride variables represent how much to increase the ptr by when
     # moving by 1 element in a particular dimension. E.g. `stride_am` is
     # how much to increase `a_ptr` by to get the element one row down
     # (A has M rows).
     stride_am,
     stride_ak,
     stride_be,
     stride_bk,
     stride_bn,
     stride_cm,
     stride_cn,
+    stride_bse,
+    stride_bsn,
     # Meta-parameters
     BLOCK_SIZE_M: tl.constexpr,
     BLOCK_SIZE_N: tl.constexpr,
     BLOCK_SIZE_K: tl.constexpr,
     GROUP_SIZE_M: tl.constexpr,
     MUL_ROUTED_WEIGHT: tl.constexpr,
     top_k: tl.constexpr,
     compute_type: tl.constexpr,
-    use_fp8: tl.constexpr,
-):
+    use_fp8_w8a8: tl.constexpr,
+    use_int8_w8a16: tl.constexpr):
     """
     Implements the fused computation for a Mixture of Experts (MOE) using
     token and expert matrices.
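The two new strides describe the int8 weight's scale tensor, which this commit lays out per expert and per output channel ([num_experts, N]): stride_bse steps between experts, stride_bsn between output columns. A small PyTorch sketch of the equivalent indexing the kernel performs through b_scale_ptr (names and sizes are illustrative):

import torch

num_experts, N = 8, 32
b_scale = torch.rand(num_experts, N, dtype=torch.float32)

stride_bse, stride_bsn = b_scale.stride()  # (N, 1) for a contiguous tensor
off_experts = 3                            # expert id handled by this tile
offs_bn = torch.arange(16, 24)             # output columns covered by the tile

# The kernel's pointer arithmetic, written against a flat view.
flat = b_scale.reshape(-1)
tile_scales = flat[off_experts * stride_bse + offs_bn * stride_bsn]
assert torch.equal(tile_scales, b_scale[off_experts, 16:24])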
@@ -113,8 +115,12 @@ def fused_moe_kernel(
     off_experts = tl.load(expert_ids_ptr + pid_m)
     b_ptrs = b_ptr + off_experts * stride_be + (offs_k[:, None] * stride_bk +
                                                 offs_bn[None, :] * stride_bn)
+    if use_int8_w8a16:
+        b_scale_ptrs = b_scale_ptr + off_experts * stride_bse + offs_bn[
+            None, :] * stride_bsn
+        b_scale = tl.load(b_scale_ptrs)

-    if use_fp8:
+    if use_fp8_w8a8:
         a_scale = tl.load(a_scale_ptr)
         b_scale = tl.load(b_scale_ptr + off_experts)

@@ -136,7 +142,9 @@ def fused_moe_kernel(
                     mask=offs_k[:, None] < K - k * BLOCK_SIZE_K,
                     other=0.0)
         # We accumulate along the K dimension.
-        if use_fp8:
+        if use_int8_w8a16:
+            accumulator = tl.dot(a, b.to(compute_type), acc=accumulator)
+        elif use_fp8_w8a8:
             accumulator = tl.dot(a, b, acc=accumulator)
         else:
             accumulator += tl.dot(a, b)
@@ -149,8 +157,9 @@ def fused_moe_kernel(
                              mask=token_mask,
                              other=0)
         accumulator = accumulator * moe_weight[:, None]
-
-    if use_fp8:
+    if use_int8_w8a16:
+        accumulator = (accumulator * b_scale).to(compute_type)
+    elif use_fp8_w8a8:
         accumulator = (accumulator * a_scale * b_scale).to(compute_type)
     else:
         accumulator = accumulator.to(compute_type)
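Because the per-channel scale is constant along K, it factors out of the blocked accumulation, so the kernel can defer the multiply to this epilogue instead of scaling inside every K-block. A small sketch of that identity under blocking (sizes are illustrative):

import torch

K, BLOCK_K, N = 64, 16, 8
a = torch.randn(K)
b = torch.randint(-127, 127, (K, N), dtype=torch.int8)
b_scale = torch.rand(N)

# Blocked accumulation with the scale applied once at the end (the kernel's scheme).
acc = torch.zeros(N)
for k0 in range(0, K, BLOCK_K):
    acc += a[k0:k0 + BLOCK_K] @ b[k0:k0 + BLOCK_K].float()
deferred = acc * b_scale

# Scaling every block gives the same result, just with more multiplies.
eager = sum((a[k0:k0 + BLOCK_K] @ b[k0:k0 + BLOCK_K].float()) * b_scale
            for k0 in range(0, K, BLOCK_K))
assert torch.allclose(deferred, eager, rtol=1e-5, atol=1e-5)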
@@ -229,16 +238,18 @@ def invoke_fused_moe_kernel(A: torch.Tensor, B: torch.Tensor, C: torch.Tensor,
                             num_tokens_post_padded: torch.Tensor,
                             mul_routed_weight: bool, top_k: int,
                             config: Dict[str, Any], compute_type: tl.dtype,
-                            use_fp8: bool) -> None:
+                            use_fp8_w8a8: bool, use_int8_w8a16: bool) -> None:
     assert topk_weights.stride(1) == 1
     assert sorted_token_ids.stride(0) == 1

-    if not use_fp8:
-        assert A_scale is None
-        assert B_scale is None
-    else:
+    if use_fp8_w8a8:
         A, A_scale = ops.scaled_fp8_quant(A, A_scale)
         assert B_scale is not None
+    elif use_int8_w8a16:
+        assert B_scale is not None
+    else:
+        assert A_scale is None
+        assert B_scale is None

     grid = lambda META: (triton.cdiv(sorted_token_ids.shape[0], META[
         'BLOCK_SIZE_M']) * triton.cdiv(B.shape[1], META['BLOCK_SIZE_N']), )
@@ -264,10 +275,13 @@ def invoke_fused_moe_kernel(A: torch.Tensor, B: torch.Tensor, C: torch.Tensor,
         B.stride(1),
         C.stride(1),
         C.stride(2),
+        B_scale.stride(0) if B_scale is not None and use_int8_w8a16 else 0,
+        B_scale.stride(1) if B_scale is not None and use_int8_w8a16 else 0,
         MUL_ROUTED_WEIGHT=mul_routed_weight,
         top_k=top_k,
         compute_type=compute_type,
-        use_fp8=use_fp8,
+        use_fp8_w8a8=use_fp8_w8a8,
+        use_int8_w8a16=use_int8_w8a16,
         **config,
     )

@@ -426,6 +440,20 @@ def grouped_topk(hidden_states: torch.Tensor,
     return topk_weights, topk_ids


+def get_config_dtype_str(dtype: torch.dtype,
+                         use_int8_w8a16: Optional[bool] = False,
+                         use_fp8_w8a8: Optional[bool] = False):
+    if use_fp8_w8a8:
+        return "fp8_w8a8"
+    elif use_int8_w8a16:
+        return "int8_w8a16"
+    elif dtype == torch.float:
+        # avoiding cases where kernel fails when float32 MoE
+        # use fp16/bfloat16 configs
+        return "float32"
+    return None
+
+
 def fused_experts(hidden_states: torch.Tensor,
                   w1: torch.Tensor,
                   w2: torch.Tensor,
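The helper added above picks the dtype suffix used when looking up tuned kernel configs. A quick check of its mapping, using a local copy of the function so it runs standalone without a vLLM install:

import torch
from typing import Optional

def get_config_dtype_str(dtype: torch.dtype,
                         use_int8_w8a16: Optional[bool] = False,
                         use_fp8_w8a8: Optional[bool] = False):
    # Local copy of the helper shown in the diff above.
    if use_fp8_w8a8:
        return "fp8_w8a8"
    elif use_int8_w8a16:
        return "int8_w8a16"
    elif dtype == torch.float:
        # float32 MoE falls back to dedicated float32 configs
        return "float32"
    return None

assert get_config_dtype_str(torch.bfloat16, use_int8_w8a16=True) == "int8_w8a16"
assert get_config_dtype_str(torch.float16, use_fp8_w8a8=True) == "fp8_w8a8"
assert get_config_dtype_str(torch.float32) == "float32"
assert get_config_dtype_str(torch.float16) is None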
@@ -433,7 +461,8 @@ def fused_experts(hidden_states: torch.Tensor,
                   topk_ids: torch.Tensor,
                   inplace: bool = False,
                   override_config: Optional[Dict[str, Any]] = None,
-                  use_fp8: bool = False,
+                  use_fp8_w8a8: bool = False,
+                  use_int8_w8a16: bool = False,
                   w1_scale: Optional[torch.Tensor] = None,
                   w2_scale: Optional[torch.Tensor] = None,
                   a1_scale: Optional[torch.Tensor] = None,
@@ -454,13 +483,16 @@ def fused_experts(hidden_states: torch.Tensor,
     # https://github.com/vllm-project/vllm/issues/5938
     CHUNK_SIZE = envs.VLLM_FUSED_MOE_CHUNK_SIZE
     M = min(num_tokens, CHUNK_SIZE)
+    config_dtype = get_config_dtype_str(use_fp8_w8a8=use_fp8_w8a8,
+                                        use_int8_w8a16=use_int8_w8a16,
+                                        dtype=hidden_states.dtype)

     get_config_func = functools.partial(
         try_get_optimal_moe_config,
         w1.shape,
         w2.shape,
         topk_ids.shape[1],
-        "float8" if use_fp8 else None,
+        config_dtype,
         override_config=override_config,
     )

@@ -524,7 +556,8 @@ def fused_experts(hidden_states: torch.Tensor,
                                 topk_ids.shape[1],
                                 config,
                                 compute_type=compute_type,
-                                use_fp8=use_fp8)
+                                use_fp8_w8a8=use_fp8_w8a8,
+                                use_int8_w8a16=use_int8_w8a16)

         ops.silu_and_mul(intermediate_cache2, intermediate_cache1.view(-1, N))

@@ -542,7 +575,8 @@ def fused_experts(hidden_states: torch.Tensor,
                                 1,
                                 config,
                                 compute_type=compute_type,
-                                use_fp8=use_fp8)
+                                use_fp8_w8a8=use_fp8_w8a8,
+                                use_int8_w8a16=use_int8_w8a16)

         torch.sum(intermediate_cache3.view(*intermediate_cache3.shape),
                   dim=1,
@@ -562,7 +596,8 @@ def fused_moe(
     use_grouped_topk: bool = False,
     num_expert_group: Optional[int] = None,
     topk_group: Optional[int] = None,
-    use_fp8: bool = False,
+    use_fp8_w8a8: bool = False,
+    use_int8_w8a16: bool = False,
     w1_scale: Optional[torch.Tensor] = None,
     w2_scale: Optional[torch.Tensor] = None,
     a1_scale: Optional[torch.Tensor] = None,
@@ -588,7 +623,9 @@ def fused_moe(
     - topk_group: Optional[int]: additional parameter for grouped_topk
     - use_grouped_topk: If True, use grouped_topk instead of fused_topk
         note: Deepseekv2 model uses grouped_topk
-    - use_fp8 (bool): If True, use fp8 arithmetic to compute the inner
+    - use_fp8_w8a8 (bool): If True, use fp8 arithmetic to compute the inner
+        products for w1 and w2. Defaults to False.
+    - use_int8_w8a16 (bool): If True, use fp8 arithmetic to compute the inner
         products for w1 and w2. Defaults to False.
     - w1_scale (Optional[torch.Tensor]): Optional scale to be used for
         w1.
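A hedged call sketch for the new int8 path of fused_moe, with assumed shapes that follow the weight/scale layout used elsewhere in this commit (E experts, hidden size K, per-partition intermediate size N). The import path and a CUDA device are assumptions; the Triton kernel only runs on GPU:

import torch
from vllm.model_executor.layers.fused_moe import fused_moe

E, K, N, M, topk = 8, 1024, 2048, 16, 2
dev = "cuda"

x = torch.randn(M, K, dtype=torch.bfloat16, device=dev)
gating = torch.randn(M, E, dtype=torch.float32, device=dev)

# int8 weights with one float32 scale per expert per output channel.
w1 = torch.randint(-127, 127, (E, 2 * N, K), dtype=torch.int8, device=dev)
w2 = torch.randint(-127, 127, (E, K, N), dtype=torch.int8, device=dev)
w1_scale = torch.rand(E, 2 * N, dtype=torch.float32, device=dev)
w2_scale = torch.rand(E, K, dtype=torch.float32, device=dev)

out = fused_moe(x, w1, w2, gating, topk,
                renormalize=True,
                use_int8_w8a16=True,
                w1_scale=w1_scale,
                w2_scale=w2_scale)
print(out.shape)  # expected (M, K)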
@@ -617,7 +654,8 @@ def fused_moe(
                          topk_ids,
                          inplace=inplace,
                          override_config=override_config,
-                         use_fp8=use_fp8,
+                         use_fp8_w8a8=use_fp8_w8a8,
+                         use_int8_w8a16=use_int8_w8a16,
                          w1_scale=w1_scale,
                          w2_scale=w2_scale,
                          a1_scale=a1_scale,
@@ -11,6 +11,8 @@ from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tenso
     CompressedTensorsConfig)
 from vllm.model_executor.layers.quantization.deepspeedfp import (
     DeepSpeedFPConfig)
+from vllm.model_executor.layers.quantization.experts_int8 import (
+    ExpertsInt8Config)
 from vllm.model_executor.layers.quantization.fbgemm_fp8 import FBGEMMFp8Config
 from vllm.model_executor.layers.quantization.fp8 import Fp8Config
 from vllm.model_executor.layers.quantization.gguf import GGUFConfig
@@ -43,6 +45,7 @@ QUANTIZATION_METHODS: Dict[str, Type[QuantizationConfig]] = {
     "compressed-tensors": CompressedTensorsConfig,
     "bitsandbytes": BitsAndBytesConfig,
     "qqq": QQQConfig,
+    "experts_int8": ExpertsInt8Config,
 }


vllm/model_executor/layers/quantization/experts_int8.py (new file, 175 lines)
@@ -0,0 +1,175 @@
+from typing import Any, Dict, List, Optional
+
+import torch
+
+from vllm.distributed import get_tensor_model_parallel_rank, get_tp_group
+from vllm.model_executor.layers.fused_moe import FusedMoE, FusedMoEMethodBase
+from vllm.model_executor.layers.linear import (LinearBase,
+                                               UnquantizedLinearMethod)
+from vllm.model_executor.layers.quantization.base_config import (
+    QuantizationConfig, QuantizeMethodBase)
+from vllm.model_executor.utils import set_weight_attrs
+
+
+class ExpertsInt8Config(QuantizationConfig):
+    """Config class for Int8 experts quantization."""
+
+    def __init__(self) -> None:
+        pass
+
+    @classmethod
+    def get_name(cls) -> str:
+        return "experts_int8"
+
+    @classmethod
+    def get_supported_act_dtypes(cls) -> List[torch.dtype]:
+        return [torch.bfloat16, torch.half]
+
+    @classmethod
+    def get_min_capability(cls) -> int:
+        return 80
+
+    @classmethod
+    def get_config_filenames(cls) -> List[str]:
+        return []
+
+    @classmethod
+    def from_config(cls, config: Dict[str, Any]) -> "ExpertsInt8Config":
+        return cls()
+
+    def get_quant_method(self, layer: torch.nn.Module,
+                         prefix: str) -> Optional["QuantizeMethodBase"]:
+        if isinstance(layer, LinearBase):
+            return UnquantizedLinearMethod()
+        elif isinstance(layer, FusedMoE):
+            return ExpertsInt8MoEMethod(self)
+        return None
+
+    def get_scaled_act_names(self) -> List[str]:
+        return []
+
+
+class ExpertsInt8MoEMethod(FusedMoEMethodBase):
+
+    def __init__(self, quant_config: ExpertsInt8Config):
+        self.quant_config = quant_config
+
+    def create_weights(self, layer: torch.nn.Module, num_experts: int,
+                       hidden_size: int, intermediate_size: int,
+                       params_dtype: torch.dtype, **extra_weight_attrs):
+
+        int8_dtype = torch.int8
+
+        assert 'weight_loader' in extra_weight_attrs
+        weight_loader = extra_weight_attrs['weight_loader']
+        wrapped_weight_loader = ExpertsInt8MoEMethod.quantizing_weight_loader(
+            layer, weight_loader)
+        extra_weight_attrs['weight_loader'] = wrapped_weight_loader
+
+        # Fused gate_up_proj (column parallel)
+        w13_weight = torch.nn.Parameter(torch.empty(num_experts,
+                                                    2 * intermediate_size,
+                                                    hidden_size,
+                                                    dtype=int8_dtype),
+                                        requires_grad=False)
+        layer.register_parameter("w13_weight", w13_weight)
+        set_weight_attrs(w13_weight, extra_weight_attrs)
+
+        # down_proj (row parallel)
+        w2_weight = torch.nn.Parameter(torch.empty(num_experts,
+                                                   hidden_size,
+                                                   intermediate_size,
+                                                   dtype=int8_dtype),
+                                       requires_grad=False)
+        layer.register_parameter("w2_weight", w2_weight)
+        set_weight_attrs(w2_weight, extra_weight_attrs)
+
+        w13_scale = torch.nn.Parameter(torch.zeros(num_experts,
+                                                   2 * intermediate_size,
+                                                   dtype=torch.float32),
+                                       requires_grad=False)
+        layer.register_parameter("w13_scale", w13_scale)
+
+        w2_scale = torch.nn.Parameter(torch.zeros(num_experts,
+                                                  hidden_size,
+                                                  dtype=torch.float32),
+                                      requires_grad=False)
+        layer.register_parameter("w2_scale", w2_scale)
+
+    def apply(self,
+              layer: torch.nn.Module,
+              x: torch.Tensor,
+              router_logits: torch.Tensor,
+              top_k: int,
+              renormalize: bool = True,
+              use_grouped_topk: bool = False,
+              num_expert_group: Optional[int] = None,
+              topk_group: Optional[int] = None) -> torch.Tensor:
+        from vllm.model_executor.layers.fused_moe import fused_experts
+
+        topk_weights, topk_ids = FusedMoE.select_experts(
+            hidden_states=x,
+            router_logits=router_logits,
+            use_grouped_topk=use_grouped_topk,
+            top_k=top_k,
+            renormalize=renormalize,
+            topk_group=topk_group,
+            num_expert_group=num_expert_group)
+
+        return fused_experts(x,
+                             layer.w13_weight,
+                             layer.w2_weight,
+                             topk_weights=topk_weights,
+                             topk_ids=topk_ids,
+                             inplace=True,
+                             use_int8_w8a16=True,
+                             w1_scale=layer.w13_scale,
+                             w2_scale=layer.w2_scale)
+
+    @staticmethod
+    def quantizing_weight_loader(layer, weight_loader):
+
+        def quantize_and_call_weight_loader(param: torch.nn.Parameter,
+                                            loaded_weight: torch.Tensor,
+                                            weight_name: str, shard_id: int,
+                                            expert_id: int):
+            tp_rank = get_tensor_model_parallel_rank()
+            shard_size = layer.intermediate_size_per_partition
+            shard = slice(tp_rank * shard_size, (tp_rank + 1) * shard_size)
+            device = get_tp_group().device
+            loaded_weight = loaded_weight.to(device)
+            # w1, gate_proj case: Load into first shard of w13.
+            if shard_id == "w1":
+                scales = quantize_in_place_and_get_scales(
+                    loaded_weight[shard, :])
+                layer.w13_scale.data[expert_id, 0:shard_size].copy_(scales[:,
+                                                                           0])
+            # w3, up_proj case: Load into second shard of w13.
+            elif shard_id == "w3":
+                scales = quantize_in_place_and_get_scales(
+                    loaded_weight[shard, :])
+                layer.w13_scale.data[expert_id, shard_size:2 *
+                                     shard_size].copy_(scales[:, 0])
+            # w2, down_proj case: Load into only shard of w2.
+            elif shard_id == "w2":
+                scales = quantize_in_place_and_get_scales(loaded_weight[:,
+                                                                        shard])
+                layer.w2_scale.data[expert_id, :].copy_(scales[:, 0])
+            else:
+                raise ValueError(
+                    f"Shard id must be in [0,1,2] but got {shard_id}")
+            weight_loader(param, loaded_weight, weight_name, shard_id,
+                          expert_id)
+
+        return quantize_and_call_weight_loader
+
+
+def quantize_in_place_and_get_scales(weight: torch.Tensor) -> torch.Tensor:
+    vmax = torch.iinfo(torch.int8).max
+    scales = (torch.max(torch.abs(weight), dim=1, keepdim=True)[0] / vmax)
+
+    weight.div_(scales)
+    weight.round_()
+    weight.clamp_(-vmax, vmax)
+
+    return scales
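quantize_in_place_and_get_scales above is plain symmetric per-row int8 quantization: one scale = max|w| / 127 per output row, with the weight overwritten in place by rounded values. A small round-trip check of that behaviour, using a local copy of the function so it runs without vLLM installed:

import torch

def quantize_in_place_and_get_scales(weight: torch.Tensor) -> torch.Tensor:
    # Local copy of the helper shown in the new file above.
    vmax = torch.iinfo(torch.int8).max
    scales = (torch.max(torch.abs(weight), dim=1, keepdim=True)[0] / vmax)
    weight.div_(scales)
    weight.round_()
    weight.clamp_(-vmax, vmax)
    return scales

w = torch.randn(16, 64)       # one row per output channel
w_ref = w.clone()
scales = quantize_in_place_and_get_scales(w)

assert w.abs().max() <= 127   # values now fit the int8 range
# Dequantizing recovers the original weight to within half a quantization step.
err = (w * scales - w_ref).abs().max()
assert err <= scales.max() / 2 + 1e-6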
@@ -488,7 +488,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
                              topk_weights=topk_weights,
                              topk_ids=topk_ids,
                              inplace=True,
-                             use_fp8=True,
+                             use_fp8_w8a8=True,
                              w1_scale=layer.w13_weight_scale,
                              w2_scale=layer.w2_weight_scale,
                              a1_scale=layer.w13_input_scale,
@@ -16,7 +16,6 @@ from vllm.attention.layer import Attention
 from vllm.config import CacheConfig, LoRAConfig, SchedulerConfig
 from vllm.distributed import (get_tensor_model_parallel_rank,
                               get_tensor_model_parallel_world_size)
-from vllm.model_executor.layers.activation import SiluAndMul
 from vllm.model_executor.layers.fused_moe import FusedMoE
 from vllm.model_executor.layers.layernorm import RMSNorm
 from vllm.model_executor.layers.linear import (ColumnParallelLinear,
@@ -249,37 +248,6 @@ class JambaMambaMixer(nn.Module):
         return hidden_states


-class JambaMLP(nn.Module):
-
-    def __init__(
-        self,
-        config: JambaConfig,
-        quant_config: Optional[QuantizationConfig] = None,
-    ) -> None:
-        super().__init__()
-        hidden_size = config.hidden_size
-        intermediate_size = config.intermediate_size
-        hidden_act = config.hidden_act
-        self.gate_up_proj = MergedColumnParallelLinear(
-            hidden_size, [intermediate_size] * 2,
-            bias=False,
-            quant_config=quant_config)
-        self.down_proj = RowParallelLinear(intermediate_size,
-                                           hidden_size,
-                                           bias=False,
-                                           quant_config=quant_config)
-        if hidden_act != "silu":
-            raise ValueError(f"Unsupported activation: {hidden_act}. "
-                             "Only silu is supported for now.")
-        self.act_fn = SiluAndMul()
-
-    def forward(self, x):
-        gate_up, _ = self.gate_up_proj(x)
-        x = self.act_fn(gate_up)
-        x, _ = self.down_proj(x)
-        return x
-
-
 class JambaMoE(nn.Module):

     def __init__(self,
@@ -327,6 +295,21 @@ class JambaMoE(nn.Module):
         return hidden_states.view(orig_shape)


+class JambaMLP(JambaMoE):
+
+    def __init__(self,
+                 config: JambaConfig,
+                 params_dtype: Optional[torch.dtype] = None,
+                 tp_size: Optional[int] = None,
+                 quant_config: Optional[QuantizationConfig] = None):
+        super().__init__(config,
+                         num_experts=1,
+                         top_k=1,
+                         params_dtype=params_dtype,
+                         tp_size=tp_size,
+                         quant_config=quant_config)
+
+
 class JambaMambaDecoderLayer(nn.Module):

     def __init__(self,
@@ -884,8 +867,6 @@ class JambaForCausalLM(nn.Module, HasInnerState):
             ("qkv_proj", "q_proj", "q"),
             ("qkv_proj", "k_proj", "k"),
             ("qkv_proj", "v_proj", "v"),
-            ("gate_up_proj", "gate_proj", 0),
-            ("gate_up_proj", "up_proj", 1),
         ]

         # Params for weights, fp8 weight scales, fp8 activation scales
@@ -907,6 +888,10 @@ class JambaForCausalLM(nn.Module, HasInnerState):
             if ".self_attn." in name:
                 name = name.replace(".self_attn", "")

+            if "feed_forward" in name and not _is_moe_layer(name):
+                ## map MLP layers to expert with ID=0
+                name = name.replace("feed_forward", "feed_forward.experts.0")
+
             for param_name, weight_name, shard_id in stacked_params_mapping:
                 if weight_name not in name:
                     continue
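With JambaMLP now implemented as a single-expert JambaMoE, checkpoint weights for the dense MLP layers have to be routed to expert 0; the hunk above does that by rewriting parameter names, and _is_moe_layer (added at the end of this diff) keeps real MoE layers untouched. A small sketch of the rewrite, with hypothetical checkpoint keys for illustration:

def _is_moe_layer(name: str):
    # Same check as in the diff: true MoE layers already mention experts/router.
    return any(experts_name in name for experts_name in ["experts", "router"])

def remap(name: str) -> str:
    if "feed_forward" in name and not _is_moe_layer(name):
        # Map dense-MLP layers to the expert with ID=0.
        name = name.replace("feed_forward", "feed_forward.experts.0")
    return name

# Hypothetical checkpoint keys:
assert remap("model.layers.1.feed_forward.gate_proj.weight") == \
    "model.layers.1.feed_forward.experts.0.gate_proj.weight"
moe_key = "model.layers.2.feed_forward.experts.3.down_proj.weight"
assert remap(moe_key) == moe_key  # real MoE layer is left untouched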
@@ -921,16 +906,21 @@ class JambaForCausalLM(nn.Module, HasInnerState):
                     weight_loader(param, loaded_weight, shard_id)
                     break
                 else:
-                    for mapping in expert_params_mapping:
-                        param_name, weight_name, expert_id, shard_id = mapping
+                    for (
+                            param_name,
+                            weight_name,
+                            expert_id,
+                            shard_id,
+                    ) in expert_params_mapping:
                         if weight_name not in name:
                             continue
+
                         name = name.replace(weight_name, param_name)
                         param = params_dict[name]
                         weight_loader = param.weight_loader
                         weight_loader(param,
                                       loaded_weight,
-                                      name,
+                                      weight_name,
                                       shard_id=shard_id,
                                       expert_id=expert_id)
                         break
@@ -943,3 +933,11 @@ class JambaForCausalLM(nn.Module, HasInnerState):
                 weight_loader = getattr(param, "weight_loader",
                                         default_weight_loader)
                 weight_loader(param, loaded_weight)
+
+
+def _is_moe_layer(name: str):
+    return any(
+        [experts_name in name for experts_name in [
+            "experts",
+            "router",
+        ]])