[Kernel] W8A16 Int8 inside FusedMoE (#7415)

This commit is contained in:
Mor Zusman 2024-08-16 20:06:51 +03:00 committed by GitHub
parent e837b624f2
commit 7fc23be81c
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
15 changed files with 412 additions and 136 deletions

View File

@ -30,19 +30,36 @@ def benchmark_config(
hidden_size: int, hidden_size: int,
topk: int, topk: int,
dtype: torch.dtype, dtype: torch.dtype,
use_fp8: bool, use_fp8_w8a8: bool,
use_int8_w8a16: bool,
num_iters: int = 100, num_iters: int = 100,
) -> float: ) -> float:
init_dtype = torch.float16 if use_fp8 else dtype init_dtype = torch.float16 if use_fp8_w8a8 else dtype
x = torch.randn(num_tokens, hidden_size, dtype=dtype) x = torch.randn(num_tokens, hidden_size, dtype=dtype)
w1 = torch.randn(num_experts, if use_int8_w8a16:
shard_intermediate_size, w1 = torch.randint(-127,
hidden_size, 127, (
dtype=init_dtype) num_experts,
w2 = torch.randn(num_experts, shard_intermediate_size,
hidden_size, hidden_size,
shard_intermediate_size // 2, ),
dtype=init_dtype) dtype=torch.int8)
w2 = torch.randint(-127,
127, (
num_experts,
hidden_size,
shard_intermediate_size // 2,
),
dtype=torch.int8)
else:
w1 = torch.randn(num_experts,
shard_intermediate_size,
hidden_size,
dtype=init_dtype)
w2 = torch.randn(num_experts,
hidden_size,
shard_intermediate_size // 2,
dtype=init_dtype)
gating_output = torch.randn(num_iters, gating_output = torch.randn(num_iters,
num_tokens, num_tokens,
num_experts, num_experts,
@ -52,7 +69,11 @@ def benchmark_config(
w2_scale = None w2_scale = None
a1_scale = None a1_scale = None
a2_scale = None a2_scale = None
if use_fp8: if use_int8_w8a16:
w1_scale = torch.randn((num_experts, 2 * shard_intermediate_size),
dtype=torch.float32)
w2_scale = torch.randn((hidden_size, num_experts), dtype=torch.float32)
if use_fp8_w8a8:
w1_scale = torch.randn(num_experts, dtype=torch.float32) w1_scale = torch.randn(num_experts, dtype=torch.float32)
w2_scale = torch.randn(num_experts, dtype=torch.float32) w2_scale = torch.randn(num_experts, dtype=torch.float32)
a1_scale = torch.randn(1, dtype=torch.float32) a1_scale = torch.randn(1, dtype=torch.float32)
@ -76,7 +97,8 @@ def benchmark_config(
renormalize=True, renormalize=True,
inplace=True, inplace=True,
override_config=config, override_config=config,
use_fp8=use_fp8, use_fp8_w8a8=use_fp8_w8a8,
use_int8_w8a16=use_int8_w8a16,
w1_scale=w1_scale, w1_scale=w1_scale,
w2_scale=w2_scale, w2_scale=w2_scale,
a1_scale=a1_scale, a1_scale=a1_scale,
@ -155,11 +177,13 @@ class BenchmarkWorker:
hidden_size: int, hidden_size: int,
topk: int, topk: int,
dtype: torch.dtype, dtype: torch.dtype,
use_fp8: bool, use_fp8_w8a8: bool,
use_int8_w8a16: bool,
) -> Tuple[Dict[str, int], float]: ) -> Tuple[Dict[str, int], float]:
torch.cuda.manual_seed_all(self.seed) torch.cuda.manual_seed_all(self.seed)
dtype_str = get_config_dtype_str(dtype,
dtype_str = "float8" if use_fp8 else None use_int8_w8a16=use_int8_w8a16,
use_fp8_w8a8=use_fp8_w8a8)
# NOTE(woosuk): The current naming convention uses w2.shape[2], which # NOTE(woosuk): The current naming convention uses w2.shape[2], which
# is the intermediate size after silu_and_mul. # is the intermediate size after silu_and_mul.
op_config = get_moe_configs(num_experts, shard_intermediate_size // 2, op_config = get_moe_configs(num_experts, shard_intermediate_size // 2,
@ -173,7 +197,8 @@ class BenchmarkWorker:
key=lambda x: abs(x - num_tokens))] key=lambda x: abs(x - num_tokens))]
kernel_time = benchmark_config(config, num_tokens, num_experts, kernel_time = benchmark_config(config, num_tokens, num_experts,
shard_intermediate_size, hidden_size, shard_intermediate_size, hidden_size,
topk, dtype, use_fp8) topk, dtype, use_fp8_w8a8,
use_int8_w8a16)
return config, kernel_time return config, kernel_time
def tune( def tune(
@ -184,9 +209,10 @@ class BenchmarkWorker:
hidden_size: int, hidden_size: int,
topk: int, topk: int,
dtype: torch.dtype, dtype: torch.dtype,
use_fp8: bool, use_fp8_w8a8: bool,
search_space: List[BenchmarkConfig], use_int8_w8a16: bool,
) -> BenchmarkConfig: search_space: List[Dict[str, int]],
) -> Dict[str, int]:
best_config = None best_config = None
best_time = float("inf") best_time = float("inf")
for config in tqdm(search_space): for config in tqdm(search_space):
@ -198,7 +224,8 @@ class BenchmarkWorker:
hidden_size, hidden_size,
topk, topk,
dtype, dtype,
use_fp8, use_fp8_w8a8,
use_int8_w8a16,
num_iters=10) num_iters=10)
except triton.runtime.autotuner.OutOfResources: except triton.runtime.autotuner.OutOfResources:
# Some configurations may be invalid and fail to compile. # Some configurations may be invalid and fail to compile.
@ -224,20 +251,19 @@ def sort_config(config: BenchmarkConfig) -> BenchmarkConfig:
} }
def save_configs( def save_configs(configs: Dict[int, BenchmarkConfig], num_experts: int,
configs: Dict[int, BenchmarkConfig], shard_intermediate_size: int, hidden_size: int, topk: int,
num_experts: int, dtype: torch.dtype, use_fp8_w8a8: bool,
shard_intermediate_size: int, use_int8_w8a16: bool) -> None:
hidden_size: int, dtype_str = get_config_dtype_str(dtype,
topk: int, use_int8_w8a16=use_int8_w8a16,
dtype: torch.dtype, use_fp8_w8a8=use_fp8_w8a8)
use_fp8: bool,
) -> None:
dtype_str = "float8" if use_fp8 else None
# NOTE(woosuk): The current naming convention uses w2.shape[2], which # NOTE(woosuk): The current naming convention uses w2.shape[2], which
# is the intermediate size after silu_and_mul. # is the intermediate size after silu_and_mul.
filename = get_config_file_name(num_experts, shard_intermediate_size // 2, filename = get_config_file_name(num_experts, shard_intermediate_size // 2,
dtype_str) dtype_str)
print(f"Writing best config to {filename}...") print(f"Writing best config to {filename}...")
with open(filename, "w") as f: with open(filename, "w") as f:
json.dump(configs, f, indent=4) json.dump(configs, f, indent=4)
@ -253,6 +279,11 @@ def main(args: argparse.Namespace):
topk = config.ffn_config.moe_top_k topk = config.ffn_config.moe_top_k
intermediate_size = config.ffn_config.ffn_hidden_size intermediate_size = config.ffn_config.ffn_hidden_size
shard_intermediate_size = 2 * intermediate_size // args.tp_size shard_intermediate_size = 2 * intermediate_size // args.tp_size
elif config.architectures[0] == "JambaForCausalLM":
E = config.num_experts
topk = config.num_experts_per_tok
intermediate_size = config.intermediate_size
shard_intermediate_size = 2 * intermediate_size // args.tp_size
else: else:
# Default: Mixtral. # Default: Mixtral.
E = config.num_local_experts E = config.num_local_experts
@ -262,7 +293,8 @@ def main(args: argparse.Namespace):
hidden_size = config.hidden_size hidden_size = config.hidden_size
dtype = config.torch_dtype dtype = config.torch_dtype
use_fp8 = args.dtype == "fp8" use_fp8_w8a8 = args.dtype == "fp8_w8a8"
use_int8_w8a16 = args.dtype == "int8_w8a16"
if args.batch_size is None: if args.batch_size is None:
batch_sizes = [ batch_sizes = [
@ -294,21 +326,21 @@ def main(args: argparse.Namespace):
start = time.time() start = time.time()
configs = _distribute( configs = _distribute(
"tune", [(batch_size, E, shard_intermediate_size, hidden_size, "tune", [(batch_size, E, shard_intermediate_size, hidden_size,
topk, dtype, use_fp8, search_space) topk, dtype, use_fp8_w8a8, use_int8_w8a16, search_space)
for batch_size in batch_sizes]) for batch_size in batch_sizes])
best_configs = { best_configs = {
M: sort_config(config) M: sort_config(config)
for M, config in zip(batch_sizes, configs) for M, config in zip(batch_sizes, configs)
} }
save_configs(best_configs, E, shard_intermediate_size, hidden_size, save_configs(best_configs, E, shard_intermediate_size, hidden_size,
topk, dtype, use_fp8) topk, dtype, use_fp8_w8a8, use_int8_w8a16)
end = time.time() end = time.time()
print(f"Tuning took {end - start:.2f} seconds") print(f"Tuning took {end - start:.2f} seconds")
else: else:
outputs = _distribute("benchmark", outputs = _distribute(
[(batch_size, E, shard_intermediate_size, "benchmark", [(batch_size, E, shard_intermediate_size, hidden_size,
hidden_size, topk, dtype, use_fp8) topk, dtype, use_fp8_w8a8, use_int8_w8a16)
for batch_size in batch_sizes]) for batch_size in batch_sizes])
for batch_size, (config, kernel_time) in zip(batch_sizes, outputs): for batch_size, (config, kernel_time) in zip(batch_sizes, outputs):
print(f"Batch size: {batch_size}, config: {config}") print(f"Batch size: {batch_size}, config: {config}")
@ -323,7 +355,7 @@ if __name__ == "__main__":
parser.add_argument("--tp-size", "-tp", type=int, default=2) parser.add_argument("--tp-size", "-tp", type=int, default=2)
parser.add_argument("--dtype", parser.add_argument("--dtype",
type=str, type=str,
choices=["auto", "fp8"], choices=["auto", "fp8_w8a8", "int8_w8a16"],
default="auto") default="auto")
parser.add_argument("--seed", type=int, default=0) parser.add_argument("--seed", type=int, default=0)
parser.add_argument("--batch-size", type=int, required=False) parser.add_argument("--batch-size", type=int, required=False)

View File

@ -6,9 +6,12 @@ from vllm.worker.model_runner import _get_graph_batch_size
MODELS = ["ai21labs/Jamba-tiny-random"] MODELS = ["ai21labs/Jamba-tiny-random"]
# Fails due to usage of MoE as MLP(E=1_, which is different than the HF impl
# TODO: Fix this with trained model
@pytest.mark.skip()
@pytest.mark.parametrize("model", MODELS) @pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("dtype", ["float"]) @pytest.mark.parametrize("dtype", ["bfloat16"])
@pytest.mark.parametrize("max_tokens", [20]) @pytest.mark.parametrize("max_tokens", [10])
def test_models( def test_models(
hf_runner, hf_runner,
vllm_runner, vllm_runner,
@ -17,8 +20,6 @@ def test_models(
dtype: str, dtype: str,
max_tokens: int, max_tokens: int,
) -> None: ) -> None:
# To pass the small model tests, we need full precision.
assert dtype == "float"
with hf_runner(model, dtype=dtype) as hf_model: with hf_runner(model, dtype=dtype) as hf_model:
hf_outputs = hf_model.generate_greedy(example_prompts, max_tokens) hf_outputs = hf_model.generate_greedy(example_prompts, max_tokens)
@ -36,8 +37,8 @@ def test_models(
@pytest.mark.parametrize("model", MODELS) @pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("dtype", ["float"]) @pytest.mark.parametrize("dtype", ["half"])
@pytest.mark.parametrize("max_tokens", [20]) @pytest.mark.parametrize("max_tokens", [5])
def test_batching( def test_batching(
vllm_runner, vllm_runner,
example_prompts, example_prompts,

View File

@ -0,0 +1,28 @@
# flake8: noqa
"""Tests experts_int8 quantization startup and generation,
doesn't test correctness
"""
import pytest
from tests.quantization.utils import is_quant_method_supported
MODELS = ["ai21labs/Jamba-tiny-random"]
@pytest.mark.skipif(not is_quant_method_supported("experts_int8"),
reason="ExpertsInt8 is not supported on this GPU type.")
@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("dtype", ["bfloat16"])
@pytest.mark.parametrize("max_tokens", [10])
def test_model_experts_int8_startup(
hf_runner,
vllm_runner,
example_prompts,
model: str,
dtype: str,
max_tokens: int,
) -> None:
with vllm_runner(model, dtype=dtype,
quantization="experts_int8") as vllm_model:
vllm_model.generate_greedy(example_prompts, max_tokens)

View File

@ -243,7 +243,8 @@ class ModelConfig:
rocm_supported_quantization = ["gptq", "squeezellm", "fp8"] rocm_supported_quantization = ["gptq", "squeezellm", "fp8"]
optimized_quantization_methods = [ optimized_quantization_methods = [
"fp8", "marlin", "gptq_marlin_24", "gptq_marlin", "awq_marlin", "fp8", "marlin", "gptq_marlin_24", "gptq_marlin", "awq_marlin",
"fbgemm_fp8", "compressed_tensors", "compressed-tensors" "fbgemm_fp8", "compressed_tensors", "compressed-tensors",
"experts_int8"
] ]
tpu_supported_quantization = ["tpu_int8"] tpu_supported_quantization = ["tpu_int8"]
if self.quantization is not None: if self.quantization is not None:

View File

@ -17,42 +17,44 @@ logger = init_logger(__name__)
@triton.jit @triton.jit
def fused_moe_kernel( def fused_moe_kernel(
# Pointers to matrices # Pointers to matrices
a_ptr, a_ptr,
b_ptr, b_ptr,
c_ptr, c_ptr,
a_scale_ptr, a_scale_ptr,
b_scale_ptr, b_scale_ptr,
topk_weights_ptr, topk_weights_ptr,
sorted_token_ids_ptr, sorted_token_ids_ptr,
expert_ids_ptr, expert_ids_ptr,
num_tokens_post_padded_ptr, num_tokens_post_padded_ptr,
# Matrix dimensions # Matrix dimensions
N, N,
K, K,
EM, EM,
num_valid_tokens, num_valid_tokens,
# The stride variables represent how much to increase the ptr by when # The stride variables represent how much to increase the ptr by when
# moving by 1 element in a particular dimension. E.g. `stride_am` is # moving by 1 element in a particular dimension. E.g. `stride_am` is
# how much to increase `a_ptr` by to get the element one row down # how much to increase `a_ptr` by to get the element one row down
# (A has M rows). # (A has M rows).
stride_am, stride_am,
stride_ak, stride_ak,
stride_be, stride_be,
stride_bk, stride_bk,
stride_bn, stride_bn,
stride_cm, stride_cm,
stride_cn, stride_cn,
# Meta-parameters stride_bse,
BLOCK_SIZE_M: tl.constexpr, stride_bsn,
BLOCK_SIZE_N: tl.constexpr, # Meta-parameters
BLOCK_SIZE_K: tl.constexpr, BLOCK_SIZE_M: tl.constexpr,
GROUP_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr,
MUL_ROUTED_WEIGHT: tl.constexpr, BLOCK_SIZE_K: tl.constexpr,
top_k: tl.constexpr, GROUP_SIZE_M: tl.constexpr,
compute_type: tl.constexpr, MUL_ROUTED_WEIGHT: tl.constexpr,
use_fp8: tl.constexpr, top_k: tl.constexpr,
): compute_type: tl.constexpr,
use_fp8_w8a8: tl.constexpr,
use_int8_w8a16: tl.constexpr):
""" """
Implements the fused computation for a Mixture of Experts (MOE) using Implements the fused computation for a Mixture of Experts (MOE) using
token and expert matrices. token and expert matrices.
@ -113,8 +115,12 @@ def fused_moe_kernel(
off_experts = tl.load(expert_ids_ptr + pid_m) off_experts = tl.load(expert_ids_ptr + pid_m)
b_ptrs = b_ptr + off_experts * stride_be + (offs_k[:, None] * stride_bk + b_ptrs = b_ptr + off_experts * stride_be + (offs_k[:, None] * stride_bk +
offs_bn[None, :] * stride_bn) offs_bn[None, :] * stride_bn)
if use_int8_w8a16:
b_scale_ptrs = b_scale_ptr + off_experts * stride_bse + offs_bn[
None, :] * stride_bsn
b_scale = tl.load(b_scale_ptrs)
if use_fp8: if use_fp8_w8a8:
a_scale = tl.load(a_scale_ptr) a_scale = tl.load(a_scale_ptr)
b_scale = tl.load(b_scale_ptr + off_experts) b_scale = tl.load(b_scale_ptr + off_experts)
@ -136,7 +142,9 @@ def fused_moe_kernel(
mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K,
other=0.0) other=0.0)
# We accumulate along the K dimension. # We accumulate along the K dimension.
if use_fp8: if use_int8_w8a16:
accumulator = tl.dot(a, b.to(compute_type), acc=accumulator)
elif use_fp8_w8a8:
accumulator = tl.dot(a, b, acc=accumulator) accumulator = tl.dot(a, b, acc=accumulator)
else: else:
accumulator += tl.dot(a, b) accumulator += tl.dot(a, b)
@ -149,8 +157,9 @@ def fused_moe_kernel(
mask=token_mask, mask=token_mask,
other=0) other=0)
accumulator = accumulator * moe_weight[:, None] accumulator = accumulator * moe_weight[:, None]
if use_int8_w8a16:
if use_fp8: accumulator = (accumulator * b_scale).to(compute_type)
elif use_fp8_w8a8:
accumulator = (accumulator * a_scale * b_scale).to(compute_type) accumulator = (accumulator * a_scale * b_scale).to(compute_type)
else: else:
accumulator = accumulator.to(compute_type) accumulator = accumulator.to(compute_type)
@ -229,16 +238,18 @@ def invoke_fused_moe_kernel(A: torch.Tensor, B: torch.Tensor, C: torch.Tensor,
num_tokens_post_padded: torch.Tensor, num_tokens_post_padded: torch.Tensor,
mul_routed_weight: bool, top_k: int, mul_routed_weight: bool, top_k: int,
config: Dict[str, Any], compute_type: tl.dtype, config: Dict[str, Any], compute_type: tl.dtype,
use_fp8: bool) -> None: use_fp8_w8a8: bool, use_int8_w8a16: bool) -> None:
assert topk_weights.stride(1) == 1 assert topk_weights.stride(1) == 1
assert sorted_token_ids.stride(0) == 1 assert sorted_token_ids.stride(0) == 1
if not use_fp8: if use_fp8_w8a8:
assert A_scale is None
assert B_scale is None
else:
A, A_scale = ops.scaled_fp8_quant(A, A_scale) A, A_scale = ops.scaled_fp8_quant(A, A_scale)
assert B_scale is not None assert B_scale is not None
elif use_int8_w8a16:
assert B_scale is not None
else:
assert A_scale is None
assert B_scale is None
grid = lambda META: (triton.cdiv(sorted_token_ids.shape[0], META[ grid = lambda META: (triton.cdiv(sorted_token_ids.shape[0], META[
'BLOCK_SIZE_M']) * triton.cdiv(B.shape[1], META['BLOCK_SIZE_N']), ) 'BLOCK_SIZE_M']) * triton.cdiv(B.shape[1], META['BLOCK_SIZE_N']), )
@ -264,10 +275,13 @@ def invoke_fused_moe_kernel(A: torch.Tensor, B: torch.Tensor, C: torch.Tensor,
B.stride(1), B.stride(1),
C.stride(1), C.stride(1),
C.stride(2), C.stride(2),
B_scale.stride(0) if B_scale is not None and use_int8_w8a16 else 0,
B_scale.stride(1) if B_scale is not None and use_int8_w8a16 else 0,
MUL_ROUTED_WEIGHT=mul_routed_weight, MUL_ROUTED_WEIGHT=mul_routed_weight,
top_k=top_k, top_k=top_k,
compute_type=compute_type, compute_type=compute_type,
use_fp8=use_fp8, use_fp8_w8a8=use_fp8_w8a8,
use_int8_w8a16=use_int8_w8a16,
**config, **config,
) )
@ -426,6 +440,20 @@ def grouped_topk(hidden_states: torch.Tensor,
return topk_weights, topk_ids return topk_weights, topk_ids
def get_config_dtype_str(dtype: torch.dtype,
use_int8_w8a16: Optional[bool] = False,
use_fp8_w8a8: Optional[bool] = False):
if use_fp8_w8a8:
return "fp8_w8a8"
elif use_int8_w8a16:
return "int8_w8a16"
elif dtype == torch.float:
# avoiding cases where kernel fails when float32 MoE
# use fp16/bfloat16 configs
return "float32"
return None
def fused_experts(hidden_states: torch.Tensor, def fused_experts(hidden_states: torch.Tensor,
w1: torch.Tensor, w1: torch.Tensor,
w2: torch.Tensor, w2: torch.Tensor,
@ -433,7 +461,8 @@ def fused_experts(hidden_states: torch.Tensor,
topk_ids: torch.Tensor, topk_ids: torch.Tensor,
inplace: bool = False, inplace: bool = False,
override_config: Optional[Dict[str, Any]] = None, override_config: Optional[Dict[str, Any]] = None,
use_fp8: bool = False, use_fp8_w8a8: bool = False,
use_int8_w8a16: bool = False,
w1_scale: Optional[torch.Tensor] = None, w1_scale: Optional[torch.Tensor] = None,
w2_scale: Optional[torch.Tensor] = None, w2_scale: Optional[torch.Tensor] = None,
a1_scale: Optional[torch.Tensor] = None, a1_scale: Optional[torch.Tensor] = None,
@ -454,13 +483,16 @@ def fused_experts(hidden_states: torch.Tensor,
# https://github.com/vllm-project/vllm/issues/5938 # https://github.com/vllm-project/vllm/issues/5938
CHUNK_SIZE = envs.VLLM_FUSED_MOE_CHUNK_SIZE CHUNK_SIZE = envs.VLLM_FUSED_MOE_CHUNK_SIZE
M = min(num_tokens, CHUNK_SIZE) M = min(num_tokens, CHUNK_SIZE)
config_dtype = get_config_dtype_str(use_fp8_w8a8=use_fp8_w8a8,
use_int8_w8a16=use_int8_w8a16,
dtype=hidden_states.dtype)
get_config_func = functools.partial( get_config_func = functools.partial(
try_get_optimal_moe_config, try_get_optimal_moe_config,
w1.shape, w1.shape,
w2.shape, w2.shape,
topk_ids.shape[1], topk_ids.shape[1],
"float8" if use_fp8 else None, config_dtype,
override_config=override_config, override_config=override_config,
) )
@ -524,7 +556,8 @@ def fused_experts(hidden_states: torch.Tensor,
topk_ids.shape[1], topk_ids.shape[1],
config, config,
compute_type=compute_type, compute_type=compute_type,
use_fp8=use_fp8) use_fp8_w8a8=use_fp8_w8a8,
use_int8_w8a16=use_int8_w8a16)
ops.silu_and_mul(intermediate_cache2, intermediate_cache1.view(-1, N)) ops.silu_and_mul(intermediate_cache2, intermediate_cache1.view(-1, N))
@ -542,7 +575,8 @@ def fused_experts(hidden_states: torch.Tensor,
1, 1,
config, config,
compute_type=compute_type, compute_type=compute_type,
use_fp8=use_fp8) use_fp8_w8a8=use_fp8_w8a8,
use_int8_w8a16=use_int8_w8a16)
torch.sum(intermediate_cache3.view(*intermediate_cache3.shape), torch.sum(intermediate_cache3.view(*intermediate_cache3.shape),
dim=1, dim=1,
@ -562,7 +596,8 @@ def fused_moe(
use_grouped_topk: bool = False, use_grouped_topk: bool = False,
num_expert_group: Optional[int] = None, num_expert_group: Optional[int] = None,
topk_group: Optional[int] = None, topk_group: Optional[int] = None,
use_fp8: bool = False, use_fp8_w8a8: bool = False,
use_int8_w8a16: bool = False,
w1_scale: Optional[torch.Tensor] = None, w1_scale: Optional[torch.Tensor] = None,
w2_scale: Optional[torch.Tensor] = None, w2_scale: Optional[torch.Tensor] = None,
a1_scale: Optional[torch.Tensor] = None, a1_scale: Optional[torch.Tensor] = None,
@ -588,7 +623,9 @@ def fused_moe(
- topk_group: Optional[int]: additional parameter for grouped_topk - topk_group: Optional[int]: additional parameter for grouped_topk
- use_grouped_topk: If True, use grouped_topk instead of fused_topk - use_grouped_topk: If True, use grouped_topk instead of fused_topk
note: Deepseekv2 model uses grouped_topk note: Deepseekv2 model uses grouped_topk
- use_fp8 (bool): If True, use fp8 arithmetic to compute the inner - use_fp8_w8a8 (bool): If True, use fp8 arithmetic to compute the inner
products for w1 and w2. Defaults to False.
- use_int8_w8a16 (bool): If True, use fp8 arithmetic to compute the inner
products for w1 and w2. Defaults to False. products for w1 and w2. Defaults to False.
- w1_scale (Optional[torch.Tensor]): Optional scale to be used for - w1_scale (Optional[torch.Tensor]): Optional scale to be used for
w1. w1.
@ -617,7 +654,8 @@ def fused_moe(
topk_ids, topk_ids,
inplace=inplace, inplace=inplace,
override_config=override_config, override_config=override_config,
use_fp8=use_fp8, use_fp8_w8a8=use_fp8_w8a8,
use_int8_w8a16=use_int8_w8a16,
w1_scale=w1_scale, w1_scale=w1_scale,
w2_scale=w2_scale, w2_scale=w2_scale,
a1_scale=a1_scale, a1_scale=a1_scale,

View File

@ -11,6 +11,8 @@ from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tenso
CompressedTensorsConfig) CompressedTensorsConfig)
from vllm.model_executor.layers.quantization.deepspeedfp import ( from vllm.model_executor.layers.quantization.deepspeedfp import (
DeepSpeedFPConfig) DeepSpeedFPConfig)
from vllm.model_executor.layers.quantization.experts_int8 import (
ExpertsInt8Config)
from vllm.model_executor.layers.quantization.fbgemm_fp8 import FBGEMMFp8Config from vllm.model_executor.layers.quantization.fbgemm_fp8 import FBGEMMFp8Config
from vllm.model_executor.layers.quantization.fp8 import Fp8Config from vllm.model_executor.layers.quantization.fp8 import Fp8Config
from vllm.model_executor.layers.quantization.gguf import GGUFConfig from vllm.model_executor.layers.quantization.gguf import GGUFConfig
@ -43,6 +45,7 @@ QUANTIZATION_METHODS: Dict[str, Type[QuantizationConfig]] = {
"compressed-tensors": CompressedTensorsConfig, "compressed-tensors": CompressedTensorsConfig,
"bitsandbytes": BitsAndBytesConfig, "bitsandbytes": BitsAndBytesConfig,
"qqq": QQQConfig, "qqq": QQQConfig,
"experts_int8": ExpertsInt8Config,
} }

View File

@ -0,0 +1,175 @@
from typing import Any, Dict, List, Optional
import torch
from vllm.distributed import get_tensor_model_parallel_rank, get_tp_group
from vllm.model_executor.layers.fused_moe import FusedMoE, FusedMoEMethodBase
from vllm.model_executor.layers.linear import (LinearBase,
UnquantizedLinearMethod)
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig, QuantizeMethodBase)
from vllm.model_executor.utils import set_weight_attrs
class ExpertsInt8Config(QuantizationConfig):
"""Config class for Int8 experts quantization."""
def __init__(self) -> None:
pass
@classmethod
def get_name(cls) -> str:
return "experts_int8"
@classmethod
def get_supported_act_dtypes(cls) -> List[torch.dtype]:
return [torch.bfloat16, torch.half]
@classmethod
def get_min_capability(cls) -> int:
return 80
@classmethod
def get_config_filenames(cls) -> List[str]:
return []
@classmethod
def from_config(cls, config: Dict[str, Any]) -> "ExpertsInt8Config":
return cls()
def get_quant_method(self, layer: torch.nn.Module,
prefix: str) -> Optional["QuantizeMethodBase"]:
if isinstance(layer, LinearBase):
return UnquantizedLinearMethod()
elif isinstance(layer, FusedMoE):
return ExpertsInt8MoEMethod(self)
return None
def get_scaled_act_names(self) -> List[str]:
return []
class ExpertsInt8MoEMethod(FusedMoEMethodBase):
def __init__(self, quant_config: ExpertsInt8Config):
self.quant_config = quant_config
def create_weights(self, layer: torch.nn.Module, num_experts: int,
hidden_size: int, intermediate_size: int,
params_dtype: torch.dtype, **extra_weight_attrs):
int8_dtype = torch.int8
assert 'weight_loader' in extra_weight_attrs
weight_loader = extra_weight_attrs['weight_loader']
wrapped_weight_loader = ExpertsInt8MoEMethod.quantizing_weight_loader(
layer, weight_loader)
extra_weight_attrs['weight_loader'] = wrapped_weight_loader
# Fused gate_up_proj (column parallel)
w13_weight = torch.nn.Parameter(torch.empty(num_experts,
2 * intermediate_size,
hidden_size,
dtype=int8_dtype),
requires_grad=False)
layer.register_parameter("w13_weight", w13_weight)
set_weight_attrs(w13_weight, extra_weight_attrs)
# down_proj (row parallel)
w2_weight = torch.nn.Parameter(torch.empty(num_experts,
hidden_size,
intermediate_size,
dtype=int8_dtype),
requires_grad=False)
layer.register_parameter("w2_weight", w2_weight)
set_weight_attrs(w2_weight, extra_weight_attrs)
w13_scale = torch.nn.Parameter(torch.zeros(num_experts,
2 * intermediate_size,
dtype=torch.float32),
requires_grad=False)
layer.register_parameter("w13_scale", w13_scale)
w2_scale = torch.nn.Parameter(torch.zeros(num_experts,
hidden_size,
dtype=torch.float32),
requires_grad=False)
layer.register_parameter("w2_scale", w2_scale)
def apply(self,
layer: torch.nn.Module,
x: torch.Tensor,
router_logits: torch.Tensor,
top_k: int,
renormalize: bool = True,
use_grouped_topk: bool = False,
num_expert_group: Optional[int] = None,
topk_group: Optional[int] = None) -> torch.Tensor:
from vllm.model_executor.layers.fused_moe import fused_experts
topk_weights, topk_ids = FusedMoE.select_experts(
hidden_states=x,
router_logits=router_logits,
use_grouped_topk=use_grouped_topk,
top_k=top_k,
renormalize=renormalize,
topk_group=topk_group,
num_expert_group=num_expert_group)
return fused_experts(x,
layer.w13_weight,
layer.w2_weight,
topk_weights=topk_weights,
topk_ids=topk_ids,
inplace=True,
use_int8_w8a16=True,
w1_scale=layer.w13_scale,
w2_scale=layer.w2_scale)
@staticmethod
def quantizing_weight_loader(layer, weight_loader):
def quantize_and_call_weight_loader(param: torch.nn.Parameter,
loaded_weight: torch.Tensor,
weight_name: str, shard_id: int,
expert_id: int):
tp_rank = get_tensor_model_parallel_rank()
shard_size = layer.intermediate_size_per_partition
shard = slice(tp_rank * shard_size, (tp_rank + 1) * shard_size)
device = get_tp_group().device
loaded_weight = loaded_weight.to(device)
# w1, gate_proj case: Load into first shard of w13.
if shard_id == "w1":
scales = quantize_in_place_and_get_scales(
loaded_weight[shard, :])
layer.w13_scale.data[expert_id, 0:shard_size].copy_(scales[:,
0])
# w3, up_proj case: Load into second shard of w13.
elif shard_id == "w3":
scales = quantize_in_place_and_get_scales(
loaded_weight[shard, :])
layer.w13_scale.data[expert_id, shard_size:2 *
shard_size].copy_(scales[:, 0])
# w2, down_proj case: Load into only shard of w2.
elif shard_id == "w2":
scales = quantize_in_place_and_get_scales(loaded_weight[:,
shard])
layer.w2_scale.data[expert_id, :].copy_(scales[:, 0])
else:
raise ValueError(
f"Shard id must be in [0,1,2] but got {shard_id}")
weight_loader(param, loaded_weight, weight_name, shard_id,
expert_id)
return quantize_and_call_weight_loader
def quantize_in_place_and_get_scales(weight: torch.Tensor) -> torch.Tensor:
vmax = torch.iinfo(torch.int8).max
scales = (torch.max(torch.abs(weight), dim=1, keepdim=True)[0] / vmax)
weight.div_(scales)
weight.round_()
weight.clamp_(-vmax, vmax)
return scales

View File

@ -488,7 +488,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
topk_weights=topk_weights, topk_weights=topk_weights,
topk_ids=topk_ids, topk_ids=topk_ids,
inplace=True, inplace=True,
use_fp8=True, use_fp8_w8a8=True,
w1_scale=layer.w13_weight_scale, w1_scale=layer.w13_weight_scale,
w2_scale=layer.w2_weight_scale, w2_scale=layer.w2_weight_scale,
a1_scale=layer.w13_input_scale, a1_scale=layer.w13_input_scale,

View File

@ -16,7 +16,6 @@ from vllm.attention.layer import Attention
from vllm.config import CacheConfig, LoRAConfig, SchedulerConfig from vllm.config import CacheConfig, LoRAConfig, SchedulerConfig
from vllm.distributed import (get_tensor_model_parallel_rank, from vllm.distributed import (get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size) get_tensor_model_parallel_world_size)
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.fused_moe import FusedMoE from vllm.model_executor.layers.fused_moe import FusedMoE
from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (ColumnParallelLinear, from vllm.model_executor.layers.linear import (ColumnParallelLinear,
@ -249,37 +248,6 @@ class JambaMambaMixer(nn.Module):
return hidden_states return hidden_states
class JambaMLP(nn.Module):
def __init__(
self,
config: JambaConfig,
quant_config: Optional[QuantizationConfig] = None,
) -> None:
super().__init__()
hidden_size = config.hidden_size
intermediate_size = config.intermediate_size
hidden_act = config.hidden_act
self.gate_up_proj = MergedColumnParallelLinear(
hidden_size, [intermediate_size] * 2,
bias=False,
quant_config=quant_config)
self.down_proj = RowParallelLinear(intermediate_size,
hidden_size,
bias=False,
quant_config=quant_config)
if hidden_act != "silu":
raise ValueError(f"Unsupported activation: {hidden_act}. "
"Only silu is supported for now.")
self.act_fn = SiluAndMul()
def forward(self, x):
gate_up, _ = self.gate_up_proj(x)
x = self.act_fn(gate_up)
x, _ = self.down_proj(x)
return x
class JambaMoE(nn.Module): class JambaMoE(nn.Module):
def __init__(self, def __init__(self,
@ -327,6 +295,21 @@ class JambaMoE(nn.Module):
return hidden_states.view(orig_shape) return hidden_states.view(orig_shape)
class JambaMLP(JambaMoE):
def __init__(self,
config: JambaConfig,
params_dtype: Optional[torch.dtype] = None,
tp_size: Optional[int] = None,
quant_config: Optional[QuantizationConfig] = None):
super().__init__(config,
num_experts=1,
top_k=1,
params_dtype=params_dtype,
tp_size=tp_size,
quant_config=quant_config)
class JambaMambaDecoderLayer(nn.Module): class JambaMambaDecoderLayer(nn.Module):
def __init__(self, def __init__(self,
@ -884,8 +867,6 @@ class JambaForCausalLM(nn.Module, HasInnerState):
("qkv_proj", "q_proj", "q"), ("qkv_proj", "q_proj", "q"),
("qkv_proj", "k_proj", "k"), ("qkv_proj", "k_proj", "k"),
("qkv_proj", "v_proj", "v"), ("qkv_proj", "v_proj", "v"),
("gate_up_proj", "gate_proj", 0),
("gate_up_proj", "up_proj", 1),
] ]
# Params for weights, fp8 weight scales, fp8 activation scales # Params for weights, fp8 weight scales, fp8 activation scales
@ -907,6 +888,10 @@ class JambaForCausalLM(nn.Module, HasInnerState):
if ".self_attn." in name: if ".self_attn." in name:
name = name.replace(".self_attn", "") name = name.replace(".self_attn", "")
if "feed_forward" in name and not _is_moe_layer(name):
## map MLP layers to expert with ID=0
name = name.replace("feed_forward", "feed_forward.experts.0")
for param_name, weight_name, shard_id in stacked_params_mapping: for param_name, weight_name, shard_id in stacked_params_mapping:
if weight_name not in name: if weight_name not in name:
continue continue
@ -921,16 +906,21 @@ class JambaForCausalLM(nn.Module, HasInnerState):
weight_loader(param, loaded_weight, shard_id) weight_loader(param, loaded_weight, shard_id)
break break
else: else:
for mapping in expert_params_mapping: for (
param_name, weight_name, expert_id, shard_id = mapping param_name,
weight_name,
expert_id,
shard_id,
) in expert_params_mapping:
if weight_name not in name: if weight_name not in name:
continue continue
name = name.replace(weight_name, param_name) name = name.replace(weight_name, param_name)
param = params_dict[name] param = params_dict[name]
weight_loader = param.weight_loader weight_loader = param.weight_loader
weight_loader(param, weight_loader(param,
loaded_weight, loaded_weight,
name, weight_name,
shard_id=shard_id, shard_id=shard_id,
expert_id=expert_id) expert_id=expert_id)
break break
@ -943,3 +933,11 @@ class JambaForCausalLM(nn.Module, HasInnerState):
weight_loader = getattr(param, "weight_loader", weight_loader = getattr(param, "weight_loader",
default_weight_loader) default_weight_loader)
weight_loader(param, loaded_weight) weight_loader(param, loaded_weight)
def _is_moe_layer(name: str):
return any(
[experts_name in name for experts_name in [
"experts",
"router",
]])