[Kernel] W8A16 Int8 inside FusedMoE (#7415)
This commit is contained in:
parent e837b624f2
commit 7fc23be81c
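In this commit, W8A16 means the MoE expert weights are stored as int8 with one float32 scale per output channel, while activations stay in fp16/bf16; the kernel upcasts each int8 tile and applies the scale after accumulation. A minimal standalone sketch of that scheme (not taken from the diff; names and shapes are illustrative):

import torch

def w8a16_linear(x: torch.Tensor, w_int8: torch.Tensor,
                 w_scale: torch.Tensor) -> torch.Tensor:
    """x: (M, K) fp16/bf16, w_int8: (N, K) int8, w_scale: (N,) float32."""
    acc = x.float() @ w_int8.float().t()      # accumulate in fp32
    return (acc * w_scale).to(x.dtype)        # per-output-channel dequant

x = torch.randn(4, 64, dtype=torch.float16)
w = torch.randn(32, 64)
scale = w.abs().amax(dim=1) / 127.0           # symmetric per-channel scale
w_q = torch.clamp((w / scale[:, None]).round(), -127, 127).to(torch.int8)
y = w8a16_linear(x, w_q, scale)               # approximately x @ w.t() in fp16
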
@@ -30,19 +30,36 @@ def benchmark_config(
    hidden_size: int,
    topk: int,
    dtype: torch.dtype,
    use_fp8: bool,
    use_fp8_w8a8: bool,
    use_int8_w8a16: bool,
    num_iters: int = 100,
) -> float:
    init_dtype = torch.float16 if use_fp8 else dtype
    init_dtype = torch.float16 if use_fp8_w8a8 else dtype
    x = torch.randn(num_tokens, hidden_size, dtype=dtype)
    w1 = torch.randn(num_experts,
                     shard_intermediate_size,
                     hidden_size,
                     dtype=init_dtype)
    w2 = torch.randn(num_experts,
                     hidden_size,
                     shard_intermediate_size // 2,
                     dtype=init_dtype)
    if use_int8_w8a16:
        w1 = torch.randint(-127,
                           127, (
                               num_experts,
                               shard_intermediate_size,
                               hidden_size,
                           ),
                           dtype=torch.int8)
        w2 = torch.randint(-127,
                           127, (
                               num_experts,
                               hidden_size,
                               shard_intermediate_size // 2,
                           ),
                           dtype=torch.int8)
    else:
        w1 = torch.randn(num_experts,
                         shard_intermediate_size,
                         hidden_size,
                         dtype=init_dtype)
        w2 = torch.randn(num_experts,
                         hidden_size,
                         shard_intermediate_size // 2,
                         dtype=init_dtype)
    gating_output = torch.randn(num_iters,
                                num_tokens,
                                num_experts,
@@ -52,7 +69,11 @@ def benchmark_config(
    w2_scale = None
    a1_scale = None
    a2_scale = None
    if use_fp8:
    if use_int8_w8a16:
        w1_scale = torch.randn((num_experts, 2 * shard_intermediate_size),
                               dtype=torch.float32)
        w2_scale = torch.randn((hidden_size, num_experts), dtype=torch.float32)
    if use_fp8_w8a8:
        w1_scale = torch.randn(num_experts, dtype=torch.float32)
        w2_scale = torch.randn(num_experts, dtype=torch.float32)
        a1_scale = torch.randn(1, dtype=torch.float32)
@@ -76,7 +97,8 @@ def benchmark_config(
            renormalize=True,
            inplace=True,
            override_config=config,
            use_fp8=use_fp8,
            use_fp8_w8a8=use_fp8_w8a8,
            use_int8_w8a16=use_int8_w8a16,
            w1_scale=w1_scale,
            w2_scale=w2_scale,
            a1_scale=a1_scale,
@@ -155,11 +177,13 @@ class BenchmarkWorker:
        hidden_size: int,
        topk: int,
        dtype: torch.dtype,
        use_fp8: bool,
        use_fp8_w8a8: bool,
        use_int8_w8a16: bool,
    ) -> Tuple[Dict[str, int], float]:
        torch.cuda.manual_seed_all(self.seed)

        dtype_str = "float8" if use_fp8 else None
        dtype_str = get_config_dtype_str(dtype,
                                         use_int8_w8a16=use_int8_w8a16,
                                         use_fp8_w8a8=use_fp8_w8a8)
        # NOTE(woosuk): The current naming convention uses w2.shape[2], which
        # is the intermediate size after silu_and_mul.
        op_config = get_moe_configs(num_experts, shard_intermediate_size // 2,
@@ -173,7 +197,8 @@ class BenchmarkWorker:
                          key=lambda x: abs(x - num_tokens))]
        kernel_time = benchmark_config(config, num_tokens, num_experts,
                                       shard_intermediate_size, hidden_size,
                                       topk, dtype, use_fp8)
                                       topk, dtype, use_fp8_w8a8,
                                       use_int8_w8a16)
        return config, kernel_time

    def tune(
@@ -184,9 +209,10 @@ class BenchmarkWorker:
        hidden_size: int,
        topk: int,
        dtype: torch.dtype,
        use_fp8: bool,
        search_space: List[BenchmarkConfig],
    ) -> BenchmarkConfig:
        use_fp8_w8a8: bool,
        use_int8_w8a16: bool,
        search_space: List[Dict[str, int]],
    ) -> Dict[str, int]:
        best_config = None
        best_time = float("inf")
        for config in tqdm(search_space):
@@ -198,7 +224,8 @@ class BenchmarkWorker:
                    hidden_size,
                    topk,
                    dtype,
                    use_fp8,
                    use_fp8_w8a8,
                    use_int8_w8a16,
                    num_iters=10)
            except triton.runtime.autotuner.OutOfResources:
                # Some configurations may be invalid and fail to compile.
@@ -224,20 +251,19 @@ def sort_config(config: BenchmarkConfig) -> BenchmarkConfig:
    }


def save_configs(
    configs: Dict[int, BenchmarkConfig],
    num_experts: int,
    shard_intermediate_size: int,
    hidden_size: int,
    topk: int,
    dtype: torch.dtype,
    use_fp8: bool,
) -> None:
    dtype_str = "float8" if use_fp8 else None
def save_configs(configs: Dict[int, BenchmarkConfig], num_experts: int,
                 shard_intermediate_size: int, hidden_size: int, topk: int,
                 dtype: torch.dtype, use_fp8_w8a8: bool,
                 use_int8_w8a16: bool) -> None:
    dtype_str = get_config_dtype_str(dtype,
                                     use_int8_w8a16=use_int8_w8a16,
                                     use_fp8_w8a8=use_fp8_w8a8)

    # NOTE(woosuk): The current naming convention uses w2.shape[2], which
    # is the intermediate size after silu_and_mul.
    filename = get_config_file_name(num_experts, shard_intermediate_size // 2,
                                    dtype_str)

    print(f"Writing best config to {filename}...")
    with open(filename, "w") as f:
        json.dump(configs, f, indent=4)
@@ -253,6 +279,11 @@ def main(args: argparse.Namespace):
        topk = config.ffn_config.moe_top_k
        intermediate_size = config.ffn_config.ffn_hidden_size
        shard_intermediate_size = 2 * intermediate_size // args.tp_size
    elif config.architectures[0] == "JambaForCausalLM":
        E = config.num_experts
        topk = config.num_experts_per_tok
        intermediate_size = config.intermediate_size
        shard_intermediate_size = 2 * intermediate_size // args.tp_size
    else:
        # Default: Mixtral.
        E = config.num_local_experts
@@ -262,7 +293,8 @@ def main(args: argparse.Namespace):

    hidden_size = config.hidden_size
    dtype = config.torch_dtype
    use_fp8 = args.dtype == "fp8"
    use_fp8_w8a8 = args.dtype == "fp8_w8a8"
    use_int8_w8a16 = args.dtype == "int8_w8a16"

    if args.batch_size is None:
        batch_sizes = [
@@ -294,21 +326,21 @@ def main(args: argparse.Namespace):
        start = time.time()
        configs = _distribute(
            "tune", [(batch_size, E, shard_intermediate_size, hidden_size,
                      topk, dtype, use_fp8, search_space)
                      topk, dtype, use_fp8_w8a8, use_int8_w8a16, search_space)
                     for batch_size in batch_sizes])
        best_configs = {
            M: sort_config(config)
            for M, config in zip(batch_sizes, configs)
        }
        save_configs(best_configs, E, shard_intermediate_size, hidden_size,
                     topk, dtype, use_fp8)
                     topk, dtype, use_fp8_w8a8, use_int8_w8a16)
        end = time.time()
        print(f"Tuning took {end - start:.2f} seconds")
    else:
        outputs = _distribute("benchmark",
                              [(batch_size, E, shard_intermediate_size,
                                hidden_size, topk, dtype, use_fp8)
                               for batch_size in batch_sizes])
        outputs = _distribute(
            "benchmark", [(batch_size, E, shard_intermediate_size, hidden_size,
                           topk, dtype, use_fp8_w8a8, use_int8_w8a16)
                          for batch_size in batch_sizes])

        for batch_size, (config, kernel_time) in zip(batch_sizes, outputs):
            print(f"Batch size: {batch_size}, config: {config}")
@@ -323,7 +355,7 @@ if __name__ == "__main__":
    parser.add_argument("--tp-size", "-tp", type=int, default=2)
    parser.add_argument("--dtype",
                        type=str,
                        choices=["auto", "fp8"],
                        choices=["auto", "fp8_w8a8", "int8_w8a16"],
                        default="auto")
    parser.add_argument("--seed", type=int, default=0)
    parser.add_argument("--batch-size", type=int, required=False)

|
||||
MODELS = ["ai21labs/Jamba-tiny-random"]
|
||||
|
||||
|
||||
# Fails due to usage of MoE as MLP(E=1_, which is different than the HF impl
|
||||
# TODO: Fix this with trained model
|
||||
@pytest.mark.skip()
|
||||
@pytest.mark.parametrize("model", MODELS)
|
||||
@pytest.mark.parametrize("dtype", ["float"])
|
||||
@pytest.mark.parametrize("max_tokens", [20])
|
||||
@pytest.mark.parametrize("dtype", ["bfloat16"])
|
||||
@pytest.mark.parametrize("max_tokens", [10])
|
||||
def test_models(
|
||||
hf_runner,
|
||||
vllm_runner,
|
||||
@ -17,8 +20,6 @@ def test_models(
|
||||
dtype: str,
|
||||
max_tokens: int,
|
||||
) -> None:
|
||||
# To pass the small model tests, we need full precision.
|
||||
assert dtype == "float"
|
||||
|
||||
with hf_runner(model, dtype=dtype) as hf_model:
|
||||
hf_outputs = hf_model.generate_greedy(example_prompts, max_tokens)
|
||||
@ -36,8 +37,8 @@ def test_models(
|
||||
|
||||
|
||||
@pytest.mark.parametrize("model", MODELS)
|
||||
@pytest.mark.parametrize("dtype", ["float"])
|
||||
@pytest.mark.parametrize("max_tokens", [20])
|
||||
@pytest.mark.parametrize("dtype", ["half"])
|
||||
@pytest.mark.parametrize("max_tokens", [5])
|
||||
def test_batching(
|
||||
vllm_runner,
|
||||
example_prompts,
|
||||
|
||||
tests/quantization/test_experts_int8.py (new file, 28 lines)
@@ -0,0 +1,28 @@
# flake8: noqa
"""Tests experts_int8 quantization startup and generation,
doesn't test correctness
"""
import pytest

from tests.quantization.utils import is_quant_method_supported

MODELS = ["ai21labs/Jamba-tiny-random"]


@pytest.mark.skipif(not is_quant_method_supported("experts_int8"),
                    reason="ExpertsInt8 is not supported on this GPU type.")
@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("dtype", ["bfloat16"])
@pytest.mark.parametrize("max_tokens", [10])
def test_model_experts_int8_startup(
    hf_runner,
    vllm_runner,
    example_prompts,
    model: str,
    dtype: str,
    max_tokens: int,
) -> None:

    with vllm_runner(model, dtype=dtype,
                     quantization="experts_int8") as vllm_model:
        vllm_model.generate_greedy(example_prompts, max_tokens)

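The test above only exercises startup and generation through the test fixtures; outside the test suite, the new method is selected the same way as any other quantization method at load time. A hedged usage sketch (model id, prompt, and sampling settings are placeholders):

from vllm import LLM, SamplingParams

llm = LLM(model="ai21labs/Jamba-tiny-random",
          dtype="bfloat16",
          quantization="experts_int8")
outputs = llm.generate(["Hello, my name is"], SamplingParams(max_tokens=10))
print(outputs[0].outputs[0].text)
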
@@ -243,7 +243,8 @@ class ModelConfig:
        rocm_supported_quantization = ["gptq", "squeezellm", "fp8"]
        optimized_quantization_methods = [
            "fp8", "marlin", "gptq_marlin_24", "gptq_marlin", "awq_marlin",
            "fbgemm_fp8", "compressed_tensors", "compressed-tensors"
            "fbgemm_fp8", "compressed_tensors", "compressed-tensors",
            "experts_int8"
        ]
        tpu_supported_quantization = ["tpu_int8"]
        if self.quantization is not None:

@@ -17,42 +17,44 @@ logger = init_logger(__name__)

@triton.jit
def fused_moe_kernel(
    # Pointers to matrices
    a_ptr,
    b_ptr,
    c_ptr,
    a_scale_ptr,
    b_scale_ptr,
    topk_weights_ptr,
    sorted_token_ids_ptr,
    expert_ids_ptr,
    num_tokens_post_padded_ptr,
    # Matrix dimensions
    N,
    K,
    EM,
    num_valid_tokens,
    # The stride variables represent how much to increase the ptr by when
    # moving by 1 element in a particular dimension. E.g. `stride_am` is
    # how much to increase `a_ptr` by to get the element one row down
    # (A has M rows).
    stride_am,
    stride_ak,
    stride_be,
    stride_bk,
    stride_bn,
    stride_cm,
    stride_cn,
    # Meta-parameters
    BLOCK_SIZE_M: tl.constexpr,
    BLOCK_SIZE_N: tl.constexpr,
    BLOCK_SIZE_K: tl.constexpr,
    GROUP_SIZE_M: tl.constexpr,
    MUL_ROUTED_WEIGHT: tl.constexpr,
    top_k: tl.constexpr,
    compute_type: tl.constexpr,
    use_fp8: tl.constexpr,
):
    # Pointers to matrices
    a_ptr,
    b_ptr,
    c_ptr,
    a_scale_ptr,
    b_scale_ptr,
    topk_weights_ptr,
    sorted_token_ids_ptr,
    expert_ids_ptr,
    num_tokens_post_padded_ptr,
    # Matrix dimensions
    N,
    K,
    EM,
    num_valid_tokens,
    # The stride variables represent how much to increase the ptr by when
    # moving by 1 element in a particular dimension. E.g. `stride_am` is
    # how much to increase `a_ptr` by to get the element one row down
    # (A has M rows).
    stride_am,
    stride_ak,
    stride_be,
    stride_bk,
    stride_bn,
    stride_cm,
    stride_cn,
    stride_bse,
    stride_bsn,
    # Meta-parameters
    BLOCK_SIZE_M: tl.constexpr,
    BLOCK_SIZE_N: tl.constexpr,
    BLOCK_SIZE_K: tl.constexpr,
    GROUP_SIZE_M: tl.constexpr,
    MUL_ROUTED_WEIGHT: tl.constexpr,
    top_k: tl.constexpr,
    compute_type: tl.constexpr,
    use_fp8_w8a8: tl.constexpr,
    use_int8_w8a16: tl.constexpr):
    """
    Implements the fused computation for a Mixture of Experts (MOE) using
    token and expert matrices.
@@ -113,8 +115,12 @@ def fused_moe_kernel(
    off_experts = tl.load(expert_ids_ptr + pid_m)
    b_ptrs = b_ptr + off_experts * stride_be + (offs_k[:, None] * stride_bk +
                                                offs_bn[None, :] * stride_bn)
    if use_int8_w8a16:
        b_scale_ptrs = b_scale_ptr + off_experts * stride_bse + offs_bn[
            None, :] * stride_bsn
        b_scale = tl.load(b_scale_ptrs)

    if use_fp8:
    if use_fp8_w8a8:
        a_scale = tl.load(a_scale_ptr)
        b_scale = tl.load(b_scale_ptr + off_experts)

@@ -136,7 +142,9 @@ def fused_moe_kernel(
                    mask=offs_k[:, None] < K - k * BLOCK_SIZE_K,
                    other=0.0)
        # We accumulate along the K dimension.
        if use_fp8:
        if use_int8_w8a16:
            accumulator = tl.dot(a, b.to(compute_type), acc=accumulator)
        elif use_fp8_w8a8:
            accumulator = tl.dot(a, b, acc=accumulator)
        else:
            accumulator += tl.dot(a, b)
@@ -149,8 +157,9 @@ def fused_moe_kernel(
                              mask=token_mask,
                              other=0)
        accumulator = accumulator * moe_weight[:, None]

    if use_fp8:
    if use_int8_w8a16:
        accumulator = (accumulator * b_scale).to(compute_type)
    elif use_fp8_w8a8:
        accumulator = (accumulator * a_scale * b_scale).to(compute_type)
    else:
        accumulator = accumulator.to(compute_type)
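The int8_w8a16 branch above casts the int8 weight tile to the compute type inside tl.dot and multiplies by the per-output-channel scale once, after the K loop, instead of dequantizing every weight element first. The two orderings are mathematically equivalent; a standalone torch sketch of that identity (illustrative shapes, plain fp32 in place of the kernel's tiling):

import torch

a = torch.randn(4, 64)                                    # activations (M, K)
w_int8 = torch.randint(-127, 128, (64, 32), dtype=torch.int8)  # weights (K, N)
b_scale = torch.rand(32)                                   # one scale per output channel

scaled_after = (a @ w_int8.float()) * b_scale              # scale once after accumulation
dequant_first = a @ (w_int8.float() * b_scale)             # classic dequantize-then-matmul
assert torch.allclose(scaled_after, dequant_first, rtol=1e-4, atol=1e-3)
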
@@ -229,16 +238,18 @@ def invoke_fused_moe_kernel(A: torch.Tensor, B: torch.Tensor, C: torch.Tensor,
                            num_tokens_post_padded: torch.Tensor,
                            mul_routed_weight: bool, top_k: int,
                            config: Dict[str, Any], compute_type: tl.dtype,
                            use_fp8: bool) -> None:
                            use_fp8_w8a8: bool, use_int8_w8a16: bool) -> None:
    assert topk_weights.stride(1) == 1
    assert sorted_token_ids.stride(0) == 1

    if not use_fp8:
        assert A_scale is None
        assert B_scale is None
    else:
    if use_fp8_w8a8:
        A, A_scale = ops.scaled_fp8_quant(A, A_scale)
        assert B_scale is not None
    elif use_int8_w8a16:
        assert B_scale is not None
    else:
        assert A_scale is None
        assert B_scale is None

    grid = lambda META: (triton.cdiv(sorted_token_ids.shape[0], META[
        'BLOCK_SIZE_M']) * triton.cdiv(B.shape[1], META['BLOCK_SIZE_N']), )
@@ -264,10 +275,13 @@ def invoke_fused_moe_kernel(A: torch.Tensor, B: torch.Tensor, C: torch.Tensor,
        B.stride(1),
        C.stride(1),
        C.stride(2),
        B_scale.stride(0) if B_scale is not None and use_int8_w8a16 else 0,
        B_scale.stride(1) if B_scale is not None and use_int8_w8a16 else 0,
        MUL_ROUTED_WEIGHT=mul_routed_weight,
        top_k=top_k,
        compute_type=compute_type,
        use_fp8=use_fp8,
        use_fp8_w8a8=use_fp8_w8a8,
        use_int8_w8a16=use_int8_w8a16,
        **config,
    )

@@ -426,6 +440,20 @@ def grouped_topk(hidden_states: torch.Tensor,
    return topk_weights, topk_ids


def get_config_dtype_str(dtype: torch.dtype,
                         use_int8_w8a16: Optional[bool] = False,
                         use_fp8_w8a8: Optional[bool] = False):
    if use_fp8_w8a8:
        return "fp8_w8a8"
    elif use_int8_w8a16:
        return "int8_w8a16"
    elif dtype == torch.float:
        # avoiding cases where kernel fails when float32 MoE
        # use fp16/bfloat16 configs
        return "float32"
    return None

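The returned string is what the tuned-config lookup and config file name are keyed on (see save_configs and the benchmark changes above). Illustrative calls, assuming the signature shown here:

get_config_dtype_str(torch.bfloat16, use_int8_w8a16=True)  # -> "int8_w8a16"
get_config_dtype_str(torch.float16)                        # -> None
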
def fused_experts(hidden_states: torch.Tensor,
                  w1: torch.Tensor,
                  w2: torch.Tensor,
@@ -433,7 +461,8 @@ def fused_experts(hidden_states: torch.Tensor,
                  topk_ids: torch.Tensor,
                  inplace: bool = False,
                  override_config: Optional[Dict[str, Any]] = None,
                  use_fp8: bool = False,
                  use_fp8_w8a8: bool = False,
                  use_int8_w8a16: bool = False,
                  w1_scale: Optional[torch.Tensor] = None,
                  w2_scale: Optional[torch.Tensor] = None,
                  a1_scale: Optional[torch.Tensor] = None,
@@ -454,13 +483,16 @@ def fused_experts(hidden_states: torch.Tensor,
    # https://github.com/vllm-project/vllm/issues/5938
    CHUNK_SIZE = envs.VLLM_FUSED_MOE_CHUNK_SIZE
    M = min(num_tokens, CHUNK_SIZE)
    config_dtype = get_config_dtype_str(use_fp8_w8a8=use_fp8_w8a8,
                                        use_int8_w8a16=use_int8_w8a16,
                                        dtype=hidden_states.dtype)

    get_config_func = functools.partial(
        try_get_optimal_moe_config,
        w1.shape,
        w2.shape,
        topk_ids.shape[1],
        "float8" if use_fp8 else None,
        config_dtype,
        override_config=override_config,
    )

@@ -524,7 +556,8 @@ def fused_experts(hidden_states: torch.Tensor,
                                topk_ids.shape[1],
                                config,
                                compute_type=compute_type,
                                use_fp8=use_fp8)
                                use_fp8_w8a8=use_fp8_w8a8,
                                use_int8_w8a16=use_int8_w8a16)

        ops.silu_and_mul(intermediate_cache2, intermediate_cache1.view(-1, N))

@@ -542,7 +575,8 @@ def fused_experts(hidden_states: torch.Tensor,
                                1,
                                config,
                                compute_type=compute_type,
                                use_fp8=use_fp8)
                                use_fp8_w8a8=use_fp8_w8a8,
                                use_int8_w8a16=use_int8_w8a16)

        torch.sum(intermediate_cache3.view(*intermediate_cache3.shape),
                  dim=1,
@@ -562,7 +596,8 @@ def fused_moe(
    use_grouped_topk: bool = False,
    num_expert_group: Optional[int] = None,
    topk_group: Optional[int] = None,
    use_fp8: bool = False,
    use_fp8_w8a8: bool = False,
    use_int8_w8a16: bool = False,
    w1_scale: Optional[torch.Tensor] = None,
    w2_scale: Optional[torch.Tensor] = None,
    a1_scale: Optional[torch.Tensor] = None,
@@ -588,7 +623,9 @@ def fused_moe(
    - topk_group: Optional[int]: additional parameter for grouped_topk
    - use_grouped_topk: If True, use grouped_topk instead of fused_topk
        note: Deepseekv2 model uses grouped_topk
    - use_fp8 (bool): If True, use fp8 arithmetic to compute the inner
    - use_fp8_w8a8 (bool): If True, use fp8 arithmetic to compute the inner
        products for w1 and w2. Defaults to False.
    - use_int8_w8a16 (bool): If True, use int8 weights with 16-bit activations
        to compute the inner products for w1 and w2. Defaults to False.
    - w1_scale (Optional[torch.Tensor]): Optional scale to be used for
        w1.
@@ -617,7 +654,8 @@ def fused_moe(
                         topk_ids,
                         inplace=inplace,
                         override_config=override_config,
                         use_fp8=use_fp8,
                         use_fp8_w8a8=use_fp8_w8a8,
                         use_int8_w8a16=use_int8_w8a16,
                         w1_scale=w1_scale,
                         w2_scale=w2_scale,
                         a1_scale=a1_scale,

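Putting the new arguments together, a direct call into the fused MoE entry point with int8 weights and per-output-channel scales might look like the sketch below. It is illustrative only: values are random, shapes mirror the benchmark earlier in this diff, it needs a CUDA GPU with Triton, and the exact keyword set may differ between versions.

import torch
from vllm.model_executor.layers.fused_moe import fused_moe

E, N, K, topk = 8, 1024, 512, 2          # experts, 2*intermediate, hidden, top-k
x = torch.randn(16, K, dtype=torch.float16, device="cuda")
w1 = torch.randint(-127, 127, (E, N, K), dtype=torch.int8, device="cuda")
w2 = torch.randint(-127, 127, (E, K, N // 2), dtype=torch.int8, device="cuda")
w1_scale = torch.rand(E, N, dtype=torch.float32, device="cuda")
w2_scale = torch.rand(E, K, dtype=torch.float32, device="cuda")
gating = torch.randn(16, E, dtype=torch.float32, device="cuda")

out = fused_moe(x, w1, w2, gating, topk, renormalize=True,
                use_int8_w8a16=True, w1_scale=w1_scale, w2_scale=w2_scale)
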
@@ -11,6 +11,8 @@ from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tenso
    CompressedTensorsConfig)
from vllm.model_executor.layers.quantization.deepspeedfp import (
    DeepSpeedFPConfig)
from vllm.model_executor.layers.quantization.experts_int8 import (
    ExpertsInt8Config)
from vllm.model_executor.layers.quantization.fbgemm_fp8 import FBGEMMFp8Config
from vllm.model_executor.layers.quantization.fp8 import Fp8Config
from vllm.model_executor.layers.quantization.gguf import GGUFConfig
@@ -43,6 +45,7 @@ QUANTIZATION_METHODS: Dict[str, Type[QuantizationConfig]] = {
    "compressed-tensors": CompressedTensorsConfig,
    "bitsandbytes": BitsAndBytesConfig,
    "qqq": QQQConfig,
    "experts_int8": ExpertsInt8Config,
}

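Once registered, the new method resolves like any other entry in the table; an illustrative lookup:

from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS

config_cls = QUANTIZATION_METHODS["experts_int8"]
assert config_cls().get_name() == "experts_int8"
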
vllm/model_executor/layers/quantization/experts_int8.py (new file, 175 lines)
@@ -0,0 +1,175 @@
from typing import Any, Dict, List, Optional

import torch

from vllm.distributed import get_tensor_model_parallel_rank, get_tp_group
from vllm.model_executor.layers.fused_moe import FusedMoE, FusedMoEMethodBase
from vllm.model_executor.layers.linear import (LinearBase,
                                               UnquantizedLinearMethod)
from vllm.model_executor.layers.quantization.base_config import (
    QuantizationConfig, QuantizeMethodBase)
from vllm.model_executor.utils import set_weight_attrs


class ExpertsInt8Config(QuantizationConfig):
    """Config class for Int8 experts quantization."""

    def __init__(self) -> None:
        pass

    @classmethod
    def get_name(cls) -> str:
        return "experts_int8"

    @classmethod
    def get_supported_act_dtypes(cls) -> List[torch.dtype]:
        return [torch.bfloat16, torch.half]

    @classmethod
    def get_min_capability(cls) -> int:
        return 80

    @classmethod
    def get_config_filenames(cls) -> List[str]:
        return []

    @classmethod
    def from_config(cls, config: Dict[str, Any]) -> "ExpertsInt8Config":
        return cls()

    def get_quant_method(self, layer: torch.nn.Module,
                         prefix: str) -> Optional["QuantizeMethodBase"]:
        if isinstance(layer, LinearBase):
            return UnquantizedLinearMethod()
        elif isinstance(layer, FusedMoE):
            return ExpertsInt8MoEMethod(self)
        return None

    def get_scaled_act_names(self) -> List[str]:
        return []


class ExpertsInt8MoEMethod(FusedMoEMethodBase):

    def __init__(self, quant_config: ExpertsInt8Config):
        self.quant_config = quant_config

    def create_weights(self, layer: torch.nn.Module, num_experts: int,
                       hidden_size: int, intermediate_size: int,
                       params_dtype: torch.dtype, **extra_weight_attrs):

        int8_dtype = torch.int8

        assert 'weight_loader' in extra_weight_attrs
        weight_loader = extra_weight_attrs['weight_loader']
        wrapped_weight_loader = ExpertsInt8MoEMethod.quantizing_weight_loader(
            layer, weight_loader)
        extra_weight_attrs['weight_loader'] = wrapped_weight_loader

        # Fused gate_up_proj (column parallel)
        w13_weight = torch.nn.Parameter(torch.empty(num_experts,
                                                    2 * intermediate_size,
                                                    hidden_size,
                                                    dtype=int8_dtype),
                                        requires_grad=False)
        layer.register_parameter("w13_weight", w13_weight)
        set_weight_attrs(w13_weight, extra_weight_attrs)

        # down_proj (row parallel)
        w2_weight = torch.nn.Parameter(torch.empty(num_experts,
                                                   hidden_size,
                                                   intermediate_size,
                                                   dtype=int8_dtype),
                                       requires_grad=False)
        layer.register_parameter("w2_weight", w2_weight)
        set_weight_attrs(w2_weight, extra_weight_attrs)

        w13_scale = torch.nn.Parameter(torch.zeros(num_experts,
                                                   2 * intermediate_size,
                                                   dtype=torch.float32),
                                       requires_grad=False)
        layer.register_parameter("w13_scale", w13_scale)

        w2_scale = torch.nn.Parameter(torch.zeros(num_experts,
                                                  hidden_size,
                                                  dtype=torch.float32),
                                      requires_grad=False)
        layer.register_parameter("w2_scale", w2_scale)

    def apply(self,
              layer: torch.nn.Module,
              x: torch.Tensor,
              router_logits: torch.Tensor,
              top_k: int,
              renormalize: bool = True,
              use_grouped_topk: bool = False,
              num_expert_group: Optional[int] = None,
              topk_group: Optional[int] = None) -> torch.Tensor:
        from vllm.model_executor.layers.fused_moe import fused_experts

        topk_weights, topk_ids = FusedMoE.select_experts(
            hidden_states=x,
            router_logits=router_logits,
            use_grouped_topk=use_grouped_topk,
            top_k=top_k,
            renormalize=renormalize,
            topk_group=topk_group,
            num_expert_group=num_expert_group)

        return fused_experts(x,
                             layer.w13_weight,
                             layer.w2_weight,
                             topk_weights=topk_weights,
                             topk_ids=topk_ids,
                             inplace=True,
                             use_int8_w8a16=True,
                             w1_scale=layer.w13_scale,
                             w2_scale=layer.w2_scale)

    @staticmethod
    def quantizing_weight_loader(layer, weight_loader):

        def quantize_and_call_weight_loader(param: torch.nn.Parameter,
                                            loaded_weight: torch.Tensor,
                                            weight_name: str, shard_id: int,
                                            expert_id: int):
            tp_rank = get_tensor_model_parallel_rank()
            shard_size = layer.intermediate_size_per_partition
            shard = slice(tp_rank * shard_size, (tp_rank + 1) * shard_size)
            device = get_tp_group().device
            loaded_weight = loaded_weight.to(device)
            # w1, gate_proj case: Load into first shard of w13.
            if shard_id == "w1":
                scales = quantize_in_place_and_get_scales(
                    loaded_weight[shard, :])
                layer.w13_scale.data[expert_id, 0:shard_size].copy_(scales[:,
                                                                           0])
            # w3, up_proj case: Load into second shard of w13.
            elif shard_id == "w3":
                scales = quantize_in_place_and_get_scales(
                    loaded_weight[shard, :])
                layer.w13_scale.data[expert_id, shard_size:2 *
                                     shard_size].copy_(scales[:, 0])
            # w2, down_proj case: Load into only shard of w2.
            elif shard_id == "w2":
                scales = quantize_in_place_and_get_scales(loaded_weight[:,
                                                                        shard])
                layer.w2_scale.data[expert_id, :].copy_(scales[:, 0])
            else:
                raise ValueError(
                    f"Shard id must be in [0,1,2] but got {shard_id}")
            weight_loader(param, loaded_weight, weight_name, shard_id,
                          expert_id)

        return quantize_and_call_weight_loader


def quantize_in_place_and_get_scales(weight: torch.Tensor) -> torch.Tensor:
    vmax = torch.iinfo(torch.int8).max
    scales = (torch.max(torch.abs(weight), dim=1, keepdim=True)[0] / vmax)

    weight.div_(scales)
    weight.round_()
    weight.clamp_(-vmax, vmax)

    return scales
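quantize_in_place_and_get_scales implements symmetric per-row (per-output-channel) quantization: each row is divided by max(|row|)/127, rounded, and clamped, and the scale is returned so the kernel can rescale after accumulation. A quick standalone round-trip check of that scheme (not part of the diff):

import torch

w = torch.randn(4, 8, dtype=torch.float32)
original = w.clone()

vmax = torch.iinfo(torch.int8).max            # 127
scales = w.abs().max(dim=1, keepdim=True)[0] / vmax
w.div_(scales).round_().clamp_(-vmax, vmax)    # w now holds int8-range values

# dequantizing recovers the original weights up to rounding error
assert torch.allclose(w * scales, original, atol=scales.max().item())
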
@@ -488,7 +488,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
            topk_weights=topk_weights,
            topk_ids=topk_ids,
            inplace=True,
            use_fp8=True,
            use_fp8_w8a8=True,
            w1_scale=layer.w13_weight_scale,
            w2_scale=layer.w2_weight_scale,
            a1_scale=layer.w13_input_scale,

@@ -16,7 +16,6 @@ from vllm.attention.layer import Attention
from vllm.config import CacheConfig, LoRAConfig, SchedulerConfig
from vllm.distributed import (get_tensor_model_parallel_rank,
                              get_tensor_model_parallel_world_size)
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.fused_moe import FusedMoE
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
@@ -249,37 +248,6 @@ class JambaMambaMixer(nn.Module):
        return hidden_states


class JambaMLP(nn.Module):

    def __init__(
        self,
        config: JambaConfig,
        quant_config: Optional[QuantizationConfig] = None,
    ) -> None:
        super().__init__()
        hidden_size = config.hidden_size
        intermediate_size = config.intermediate_size
        hidden_act = config.hidden_act
        self.gate_up_proj = MergedColumnParallelLinear(
            hidden_size, [intermediate_size] * 2,
            bias=False,
            quant_config=quant_config)
        self.down_proj = RowParallelLinear(intermediate_size,
                                           hidden_size,
                                           bias=False,
                                           quant_config=quant_config)
        if hidden_act != "silu":
            raise ValueError(f"Unsupported activation: {hidden_act}. "
                             "Only silu is supported for now.")
        self.act_fn = SiluAndMul()

    def forward(self, x):
        gate_up, _ = self.gate_up_proj(x)
        x = self.act_fn(gate_up)
        x, _ = self.down_proj(x)
        return x


class JambaMoE(nn.Module):

    def __init__(self,
@@ -327,6 +295,21 @@ class JambaMoE(nn.Module):
        return hidden_states.view(orig_shape)


class JambaMLP(JambaMoE):

    def __init__(self,
                 config: JambaConfig,
                 params_dtype: Optional[torch.dtype] = None,
                 tp_size: Optional[int] = None,
                 quant_config: Optional[QuantizationConfig] = None):
        super().__init__(config,
                         num_experts=1,
                         top_k=1,
                         params_dtype=params_dtype,
                         tp_size=tp_size,
                         quant_config=quant_config)


class JambaMambaDecoderLayer(nn.Module):

    def __init__(self,
@@ -884,8 +867,6 @@ class JambaForCausalLM(nn.Module, HasInnerState):
            ("qkv_proj", "q_proj", "q"),
            ("qkv_proj", "k_proj", "k"),
            ("qkv_proj", "v_proj", "v"),
            ("gate_up_proj", "gate_proj", 0),
            ("gate_up_proj", "up_proj", 1),
        ]

        # Params for weights, fp8 weight scales, fp8 activation scales
@@ -907,6 +888,10 @@ class JambaForCausalLM(nn.Module, HasInnerState):
            if ".self_attn." in name:
                name = name.replace(".self_attn", "")

            if "feed_forward" in name and not _is_moe_layer(name):
                ## map MLP layers to expert with ID=0
                name = name.replace("feed_forward", "feed_forward.experts.0")

            for param_name, weight_name, shard_id in stacked_params_mapping:
                if weight_name not in name:
                    continue
@@ -921,16 +906,21 @@ class JambaForCausalLM(nn.Module, HasInnerState):
                weight_loader(param, loaded_weight, shard_id)
                break
            else:
                for mapping in expert_params_mapping:
                    param_name, weight_name, expert_id, shard_id = mapping
                for (
                        param_name,
                        weight_name,
                        expert_id,
                        shard_id,
                ) in expert_params_mapping:
                    if weight_name not in name:
                        continue

                    name = name.replace(weight_name, param_name)
                    param = params_dict[name]
                    weight_loader = param.weight_loader
                    weight_loader(param,
                                  loaded_weight,
                                  name,
                                  weight_name,
                                  shard_id=shard_id,
                                  expert_id=expert_id)
                    break
@@ -943,3 +933,11 @@ class JambaForCausalLM(nn.Module, HasInnerState):
                weight_loader = getattr(param, "weight_loader",
                                        default_weight_loader)
                weight_loader(param, loaded_weight)


def _is_moe_layer(name: str):
    return any(
        [experts_name in name for experts_name in [
            "experts",
            "router",
        ]])
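For the dense Jamba MLP layers that are now expressed as a single-expert JambaMoE (num_experts=1, top_k=1), the loader above rewrites checkpoint names onto expert 0 so the same FusedMoE weight loading, and hence ExpertsInt8 quantization, applies. A small illustration of that remap with a hypothetical checkpoint key:

def _is_moe_layer(name: str):
    # equivalent to the helper added above
    return any(experts_name in name for experts_name in ["experts", "router"])

name = "model.layers.3.feed_forward.down_proj.weight"
if "feed_forward" in name and not _is_moe_layer(name):
    name = name.replace("feed_forward", "feed_forward.experts.0")
assert name == "model.layers.3.feed_forward.experts.0.down_proj.weight"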