From 19bee6d12d985c231b16374c99836376fc0c5706 Mon Sep 17 00:00:00 2001 From: Varun Sundar Rabindranath Date: Wed, 3 Dec 2025 13:04:59 -0500 Subject: [PATCH] [Performance][DP/EP] Add silu_mul_per_token_group_quant_fp8_colmajor kernel (#29470) Signed-off-by: Varun Sundar Rabindranath Co-authored-by: Varun Sundar Rabindranath Co-authored-by: Tyler Michael Smith --- .../benchmark_2d_silu_mul_fp8_quant.py | 244 ++++++++++++++++++ ..._mul_per_token_group_quant_fp8_colmajor.py | 86 ++++++ .../layers/fused_moe/deep_gemm_moe.py | 114 +++----- .../layers/quantization/utils/fp8_utils.py | 133 ++++++++++ 4 files changed, 496 insertions(+), 81 deletions(-) create mode 100644 benchmarks/kernels/benchmark_2d_silu_mul_fp8_quant.py create mode 100644 tests/kernels/moe/test_silu_mul_per_token_group_quant_fp8_colmajor.py diff --git a/benchmarks/kernels/benchmark_2d_silu_mul_fp8_quant.py b/benchmarks/kernels/benchmark_2d_silu_mul_fp8_quant.py new file mode 100644 index 000000000000..04921dafbdbe --- /dev/null +++ b/benchmarks/kernels/benchmark_2d_silu_mul_fp8_quant.py @@ -0,0 +1,244 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from dataclasses import dataclass +from enum import Enum +from itertools import product +from typing import Any + +import torch +import torch.utils.benchmark as TBenchmark +from torch.utils.benchmark import Measurement as TMeasurement + +from vllm.model_executor.layers.quantization.utils.fp8_utils import ( + _per_token_group_quant_fp8_colmajor, + silu_mul_per_token_group_quant_fp8_colmajor, +) +from vllm.triton_utils import triton +from vllm.utils.deep_gemm import is_deep_gemm_e8m0_used + +from .utils import ArgPool, Bench, CudaGraphBenchParams + +GROUP_SIZE = 128 +FLOAT8_T = torch.float8_e4m3fn + + +def print_timers(timers: list[TMeasurement], cuda_graph_nops: int): + print( + f"Note : The timings reported above is for {cuda_graph_nops} " + "consecutive invocations of the benchmarking functions. " + f"Please divide by {cuda_graph_nops} for single invocation " + "timings." + ) + compare = TBenchmark.Compare(timers) + compare.print() + + +class ImplType(Enum): + SILU_MUL_PER_TOKEN_GROUP_QUANT_FP8_COLMAJOR = 1 + REFERENCE = 2 + + def get_impl(self): + if self == ImplType.SILU_MUL_PER_TOKEN_GROUP_QUANT_FP8_COLMAJOR: + return silu_mul_per_token_group_quant_fp8_colmajor + elif self == ImplType.REFERENCE: + return reference + raise ValueError(f"Unrecognized ImplType {self}") + + +@dataclass +class BenchmarkTensors: + input: torch.Tensor + output: torch.Tensor + + # Reference act output tensor + ref_act_out: torch.Tensor + ref_quant_out: torch.Tensor + + @staticmethod + def make(T: int, N: int) -> "BenchmarkTensors": + assert T % GROUP_SIZE == 0 + assert N % (GROUP_SIZE * 2) == 0 + + input = torch.rand((T, N), dtype=torch.bfloat16, device="cuda") + + # silu_mul_per_token_group_quant_fp8_colmajor output. + output = torch.rand((T, N // 2), dtype=torch.bfloat16, device="cuda").to( + FLOAT8_T + ) + + # reference output. 
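+        # Note: the reference path runs silu_and_mul and the group quant as two
+        # separate kernels, so it needs its own bf16 activation buffer in
+        # addition to the fp8 quant output buffer.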
+ ref_act_out = torch.empty((T, N // 2), dtype=torch.bfloat16, device="cuda") + ref_quant_out = torch.empty( + (T, N // 2), dtype=torch.bfloat16, device="cuda" + ).to(FLOAT8_T) + + return BenchmarkTensors( + input=input, + output=output, + ref_act_out=ref_act_out, + ref_quant_out=ref_quant_out, + ) + + @property + def T(self): + return self.input.size(0) + + @property + def N(self): + return self.input.size(1) + + def make_impl_kwargs(self, impl_type: ImplType) -> dict[str, Any]: + if impl_type == ImplType.SILU_MUL_PER_TOKEN_GROUP_QUANT_FP8_COLMAJOR: + return { + "input": self.input, + "output": self.output, + "use_ue8m0": is_deep_gemm_e8m0_used(), + } + elif impl_type == ImplType.REFERENCE: + return { + "input": self.input, + "act_out": self.ref_act_out, + "quant_out": self.ref_quant_out, + "use_ue8m0": is_deep_gemm_e8m0_used(), + } + raise ValueError(f"Unrecognized impl_type {impl_type}") + + +def reference_quant(x: torch.Tensor, quant_out: torch.Tensor, use_ue8m0: bool): + """ + Reference triton quant kernel from, + vllm.model_executor.layers.quantization.utils.fp8_utils + """ + assert quant_out.size() == x.size() + # Allocate the scale tensor column-major format. + shape = (x.shape[-1] // GROUP_SIZE,) + x.shape[:-1] + x_q = quant_out + x_s = torch.empty(shape, device=x.device, dtype=torch.float32).permute(-1, -2) + + M = x.numel() // GROUP_SIZE + N = GROUP_SIZE + BLOCK = triton.next_power_of_2(N) + # heuristics for number of warps + num_warps = min(max(BLOCK // 256, 1), 8) + num_stages = 1 + + finfo = torch.finfo(FLOAT8_T) + fp8_min = finfo.min + fp8_max = finfo.max + + _per_token_group_quant_fp8_colmajor[(M,)]( + x, + x_q, + x_s, + GROUP_SIZE, + x.shape[1], + x.stride(0), + x_s.stride(1), + eps=1e-10, + fp8_min=fp8_min, + fp8_max=fp8_max, + use_ue8m0=use_ue8m0, + BLOCK=BLOCK, + num_warps=num_warps, + num_stages=num_stages, + ) + return x_q, x_s + + +def reference( + input: torch.Tensor, + act_out: torch.Tensor, + quant_out: torch.Tensor, + use_ue8m0: bool, +) -> tuple[torch.Tensor, torch.Tensor]: + torch.ops._C.silu_and_mul(act_out, input) + return reference_quant(act_out, quant_out, use_ue8m0) + + +def bench_impl( + bench_tensors: list[BenchmarkTensors], impl_type: ImplType +) -> TMeasurement: + T = bench_tensors[0].T + N = bench_tensors[0].N + + arg_pool_size = len(bench_tensors) + kwargs_list = [bt.make_impl_kwargs(impl_type) for bt in bench_tensors] + + # warmup + for kwargs in kwargs_list: + impl_type.get_impl()(**kwargs) + torch.cuda.synchronize() + + # Merge into a single kwargs and qualify arguments as ArgPool + kwargs = {k: ArgPool([]) for k in kwargs_list[0]} + for _kwargs in kwargs_list: + for k, v in _kwargs.items(): + kwargs[k].values.append(v) + + cuda_graph_params = None + cuda_graph_params = CudaGraphBenchParams(arg_pool_size) + timer = None + with Bench( + cuda_graph_params, + "silu-mul-quant", + f"num_tokens={T}, N={N}", + impl_type.name, + impl_type.get_impl(), + **kwargs, + ) as bench: + timer = bench.run() + return timer + + +def test_correctness(T: int, N: int): + print(f"Testing num_tokens={T}, N={N} ...") + + bench_tensor = BenchmarkTensors.make(T, N) + + def output_from_impl(impl: ImplType) -> tuple[torch.Tensor, torch.Tensor]: + return impl.get_impl()(**bench_tensor.make_impl_kwargs(impl)) + + # reference output + ref_out_q, ref_out_s = output_from_impl(ImplType.REFERENCE) + + # test ouptut + out_q, out_s = output_from_impl( + ImplType.SILU_MUL_PER_TOKEN_GROUP_QUANT_FP8_COLMAJOR + ) + + torch.testing.assert_close(ref_out_q.to(torch.float32), 
out_q.to(torch.float32)) + torch.testing.assert_close(ref_out_s, out_s) + + +def run(Ts: list[int], Ns: list[int], arg_pool_size: int) -> list[TMeasurement]: + timers = [] + for N, T in product(Ns, Ts): + test_correctness(T, N) + + bench_tensors: list[BenchmarkTensors] = [ + BenchmarkTensors.make(T, N) for _ in range(arg_pool_size) + ] + + silu_mul_quant_timer = bench_impl( + bench_tensors, ImplType.SILU_MUL_PER_TOKEN_GROUP_QUANT_FP8_COLMAJOR + ) + timers.append(silu_mul_quant_timer) + reference_timer = bench_impl(bench_tensors, ImplType.REFERENCE) + timers.append(reference_timer) + + print_timers( + [silu_mul_quant_timer, reference_timer], cuda_graph_nops=arg_pool_size + ) + + print_timers(timers, cuda_graph_nops=arg_pool_size) + + return timers + + +if __name__ == "__main__": + T = [128 * i for i in range(1, 16)] + [2048 * i for i in range(1, 65)] + N = [2048, 4096, 8192] + + print(f"T = {T}, N = {N}") + run(T, N, arg_pool_size=8) diff --git a/tests/kernels/moe/test_silu_mul_per_token_group_quant_fp8_colmajor.py b/tests/kernels/moe/test_silu_mul_per_token_group_quant_fp8_colmajor.py new file mode 100644 index 000000000000..e4617072cd52 --- /dev/null +++ b/tests/kernels/moe/test_silu_mul_per_token_group_quant_fp8_colmajor.py @@ -0,0 +1,86 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import pytest +import torch + +from vllm.model_executor.layers.quantization.utils.fp8_utils import ( + _per_token_group_quant_fp8_colmajor, + silu_mul_per_token_group_quant_fp8_colmajor, +) +from vllm.platforms import current_platform +from vllm.triton_utils import triton +from vllm.utils.deep_gemm import is_deep_gemm_e8m0_used + +FLOAT8_DTYPE = torch.float8_e4m3fn +GROUP_SIZE = 128 + + +def reference_quant(x: torch.Tensor, use_ue8m0: bool): + """ + Reference triton quant kernel from, + vllm.model_executor.layers.quantization.utils.fp8_utils + """ + + x_q = torch.empty_like(x, device=x.device, dtype=FLOAT8_DTYPE) + + # Allocate the scale tensor in column-major format. 
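+    # The scales have logical shape (M, N // GROUP_SIZE); allocating with the
+    # group dimension outermost and then permuting makes the scales for a given
+    # group contiguous across tokens, matching the colmajor kernel's output.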
+ shape = (x.shape[-1] // GROUP_SIZE,) + x.shape[:-1] + x_s = torch.empty(shape, device=x.device, dtype=torch.float32).permute(-1, -2) + + M = x.numel() // GROUP_SIZE + N = GROUP_SIZE + BLOCK = triton.next_power_of_2(N) + # heuristics for number of warps + num_warps = min(max(BLOCK // 256, 1), 8) + num_stages = 1 + + finfo = torch.finfo(FLOAT8_DTYPE) + fp8_min = finfo.min + fp8_max = finfo.max + + _per_token_group_quant_fp8_colmajor[(M,)]( + x, + x_q, + x_s, + GROUP_SIZE, + x.shape[1], + x.stride(0), + x_s.stride(1), + eps=1e-10, + fp8_min=fp8_min, + fp8_max=fp8_max, + use_ue8m0=use_ue8m0, + BLOCK=BLOCK, + num_warps=num_warps, + num_stages=num_stages, + ) + return x_q, x_s + + +def reference(x: torch.Tensor, use_ue8m0: bool) -> tuple[torch.Tensor, torch.Tensor]: + T, N = x.size() + ref_act_out = torch.empty((T, N // 2), dtype=torch.bfloat16, device="cuda") + torch.ops._C.silu_and_mul(ref_act_out, x) + return reference_quant(ref_act_out, use_ue8m0) + + +@pytest.mark.parametrize("T", [128, 256, 512]) +@pytest.mark.parametrize("N", [128 * 2, 256 * 2, 768 * 2, 2048 * 2, 7168 * 2]) +def test_silu_mul_fp8_quant_deep_gemm(T: int, N: int): + current_platform.seed_everything(42) + + input = torch.rand((T, N), dtype=torch.bfloat16, device="cuda") + + use_ue8m0 = is_deep_gemm_e8m0_used() + + # Test + output, output_scales = silu_mul_per_token_group_quant_fp8_colmajor( + input, use_ue8m0=use_ue8m0 + ) + + # Reference + ref_output, ref_output_scales = reference(input, use_ue8m0) + + torch.testing.assert_close(output.to(torch.float32), ref_output.to(torch.float32)) + torch.testing.assert_close(output_scales, ref_output_scales) diff --git a/vllm/model_executor/layers/fused_moe/deep_gemm_moe.py b/vllm/model_executor/layers/fused_moe/deep_gemm_moe.py index 86cdd25f2c87..9f47e692d5ae 100644 --- a/vllm/model_executor/layers/fused_moe/deep_gemm_moe.py +++ b/vllm/model_executor/layers/fused_moe/deep_gemm_moe.py @@ -2,9 +2,7 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import torch -from tqdm import tqdm -import vllm.envs as env import vllm.model_executor.layers.fused_moe.modular_kernel as mk from vllm.logger import init_logger from vllm.model_executor.layers.fused_moe.config import ( @@ -25,12 +23,12 @@ from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import ( from vllm.model_executor.layers.fused_moe.utils import _resize_cache from vllm.model_executor.layers.quantization.utils.fp8_utils import ( per_token_group_quant_fp8, + silu_mul_per_token_group_quant_fp8_colmajor, ) from vllm.utils.deep_gemm import ( get_mk_alignment_for_contiguous_layout, m_grouped_fp8_gemm_nt_contiguous, ) -from vllm.utils.func_utils import run_once from vllm.utils.import_utils import has_deep_gemm logger = init_logger(__name__) @@ -108,70 +106,6 @@ def _valid_deep_gemm( return True -@run_once -def warmup_deepgemm_gg_contiguous_kernels( - w1: torch.Tensor, - w2: torch.Tensor, - w1_scale: torch.Tensor, - w2_scale: torch.Tensor, - num_topk: int, -): - """ - DeepGemm JITs the grouped-gemm kernels. The JIT'ing happens based on the - input tensor shapes. In this function, we construct all possible input - tensor shapes so all the kernels are JIT'ed and cached. - Note that this warmup is expected to happen during the model profile - call and not during actual model inference. 
- """ - - assert w1.size(0) == w2.size(0), "w1 and w2 must have the same number of experts" - - block_m = get_mk_alignment_for_contiguous_layout()[0] - num_experts = w1.size(0) - device = w1.device - - # This is the maximum GroupedGemm M size that we expect to run - # the grouped_gemm with. - MAX_M = compute_aligned_M( - env.VLLM_FUSED_MOE_CHUNK_SIZE, - num_topk, - num_experts, - block_m, - expert_tokens_meta=None, - ) - # Distribute expert-ids evenly. - MAX_BLOCKS = MAX_M // block_m - expert_ids_block = torch.randint( - low=0, high=num_experts, size=(MAX_BLOCKS,), device=device, dtype=torch.int32 - ) - expert_ids = torch.repeat_interleave(expert_ids_block, block_m, dim=0) - - def _warmup(w: torch.Tensor, w_scale: torch.Tensor): - _, n, k = w.size() - a1q = torch.empty((MAX_M, k), device=device).to(torch.float8_e4m3fn) - a1q_scales = torch.empty( - (MAX_M, k // block_m), device=device, dtype=torch.float32 - ) - out = torch.empty((MAX_M, n), device=device, dtype=torch.bfloat16) - - pbar = tqdm( - total=MAX_BLOCKS, desc=f"DeepGemmExperts GEMM warmup (MAX_M={MAX_M})" - ) - num_tokens = MAX_M - while num_tokens > 0: - m_grouped_fp8_gemm_nt_contiguous( - (a1q[:num_tokens], a1q_scales[:num_tokens]), - (w, w_scale), - out[:num_tokens], - expert_ids[:num_tokens], - ) - pbar.update(1) - num_tokens = num_tokens - block_m - - _warmup(w1, w1_scale) - _warmup(w2, w2_scale) - - class DeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute): def __init__(self, quant_config: FusedMoEQuantConfig): super().__init__(quant_config) @@ -215,11 +149,32 @@ class DeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute): ) assert M_sum % block_m == 0 - workspace1 = (M_sum, N) - workspace2 = (M_sum, max(N // 2, K)) + workspace1 = (M_sum, max(N // 2, K)) + workspace2 = (M_sum, max(N, K)) output = (M, K) return (workspace1, workspace2, output) + def _act_mul_quant( + self, input: torch.Tensor, output: torch.Tensor, activation: str + ) -> tuple[torch.Tensor, torch.Tensor]: + if activation == "silu": + return silu_mul_per_token_group_quant_fp8_colmajor( + input=input, output=output + ) + else: + # This is a fallback path. If we find ourselves using any activation other + # than silu, we should add that activation to + # silu_mul_per_token_group_quant_fp8_colmajor kernel as it is much faster. 
+ M_sum, N = input.size() + act_out = torch.empty( + (M_sum, N // 2), dtype=input.dtype, device=input.device + ) + self.activation(activation, act_out, input) + assert self.block_shape is not None + return per_token_group_quant_fp8( + act_out, self.block_shape[1], column_major_scales=True, out_q=output + ) + def apply( self, output: torch.Tensor, @@ -261,14 +216,9 @@ class DeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute): expert_tokens_meta=expert_tokens_meta, ) - a1q_perm = _resize_cache(workspace2.view(dtype=torch.float8_e4m3fn), (M_sum, K)) - mm1_out = _resize_cache(workspace13, (M_sum, N)) - act_out = _resize_cache(workspace2, (M_sum, N // 2)) - quant_out = _resize_cache( - workspace13.view(dtype=torch.float8_e4m3fn), (M_sum, N // 2) + a1q_perm = _resize_cache( + workspace13.view(dtype=torch.float8_e4m3fn), (M_sum, K) ) - mm2_out = _resize_cache(workspace2, (M_sum, K)) - a1q, a1q_scale, expert_ids, inv_perm = deepgemm_moe_permute( aq=a1q, aq_scale=a1q_scale, @@ -280,17 +230,19 @@ class DeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute): ) assert a1q.size(0) == M_sum + mm1_out = _resize_cache(workspace2, (M_sum, N)) m_grouped_fp8_gemm_nt_contiguous( (a1q, a1q_scale), (w1, self.w1_scale), mm1_out, expert_ids ) - self.activation(activation, act_out, mm1_out.view(-1, N)) - - a2q_scale: torch.Tensor | None = None - a2q, a2q_scale = per_token_group_quant_fp8( - act_out, self.block_shape[1], column_major_scales=True, out_q=quant_out + quant_out = _resize_cache( + workspace13.view(dtype=torch.float8_e4m3fn), (M_sum, N // 2) + ) + a2q, a2q_scale = self._act_mul_quant( + input=mm1_out.view(-1, N), output=quant_out, activation=activation ) + mm2_out = _resize_cache(workspace2, (M_sum, K)) m_grouped_fp8_gemm_nt_contiguous( (a2q, a2q_scale), (w2, self.w2_scale), mm2_out, expert_ids ) diff --git a/vllm/model_executor/layers/quantization/utils/fp8_utils.py b/vllm/model_executor/layers/quantization/utils/fp8_utils.py index ae63b4a76726..6e73833d1ae1 100644 --- a/vllm/model_executor/layers/quantization/utils/fp8_utils.py +++ b/vllm/model_executor/layers/quantization/utils/fp8_utils.py @@ -492,6 +492,139 @@ def _per_token_group_quant_fp8( tl.store(y_s_ptr, y_s) +@triton.jit +def _silu_mul_per_token_group_quant_fp8_colmajor( + y_ptr, # [M, N] + y_q_ptr, # [M, N // 2] + y_s_ptr, # [M, (N // 2) // GROUP_SIZE] + M, # num tokens + N, # intermediate size + # Stride + y_s_col_stride: tl.int64, + # Information for float8 + eps, + fp8_min, + fp8_max, + use_ue8m0: tl.constexpr, + # Meta-parameters + GROUP_SIZE: tl.constexpr, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, +): + # TODO(varun) : Add expert_ids so we may early-exit no-op thread blocks. + """ + Each thread block (BLOCK_N) computes [BLOCK_M, GROUP_SIZE] act-mul outputs. Then + the thread block quantizes the [BLOCK_M, GROUP_SIZE] block of values and fills + the outputs tensors at the right positions. 
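+
+    The launch grid is (M // BLOCK_M, (N // 2) // BLOCK_N) with BLOCK_N equal to
+    GROUP_SIZE: each program loads the gate half of y at column offset n and the
+    mul half at n + N // 2, computes silu(gate) * mul, derives one scale per row
+    from the group's absmax, then stores the fp8 values row-major into y_q and
+    the scales into the column-major y_s via y_s_col_stride.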
+ """ + + pid_m = tl.program_id(0) + pid_n = tl.program_id(1) + N_2 = N // 2 + + m_offset = pid_m * BLOCK_M + n_offset = pid_n * BLOCK_N + if m_offset >= M: + return + + offs_n = tl.arange(0, BLOCK_N).to(tl.int64) + offs_m = tl.arange(0, BLOCK_M).to(tl.int64) + + base_y_ptr = y_ptr + m_offset * N + n_offset + + act_in_ptrs = base_y_ptr + offs_m[:, None] * N + offs_n[None, :] + + act_in = tl.load(act_in_ptrs) + mul_in = tl.load(act_in_ptrs + N_2) + + # silu & mul + act_in = act_in.to(tl.float32) + one_f32 = tl.cast(1, tl.float32) + silu_out = (act_in / (one_f32 + tl.exp(-act_in))).to(y_ptr.dtype.element_ty) + y = (silu_out * mul_in).to(tl.float32) + + # quant + _absmax = tl.maximum(tl.max(tl.abs(y), axis=1), eps) + scale_raw = _absmax / fp8_max + y_s = tl.math.exp2(tl.ceil(tl.log2(scale_raw))) if use_ue8m0 else scale_raw + y_s = tl.reshape(y_s, (BLOCK_M, 1)) + y_q = tl.clamp(y / y_s, fp8_min, fp8_max).to(y_q_ptr.dtype.element_ty) + + # store y_q + base_y_q_ptr = y_q_ptr + m_offset * N_2 + n_offset + y_q_ptrs = base_y_q_ptr + offs_m[:, None] * N_2 + offs_n[None, :] + tl.store(y_q_ptrs, y_q) + + # store y_s + group_id = n_offset // GROUP_SIZE + base_y_s_ptr = y_s_ptr + group_id * y_s_col_stride + m_offset + y_s_ptrs = base_y_s_ptr + offs_m + y_s = tl.reshape(y_s, (BLOCK_M,)) + tl.store(y_s_ptrs, y_s) + + +def silu_mul_per_token_group_quant_fp8_colmajor( + input: torch.Tensor, # [M, N] + output: torch.Tensor | None = None, # [M, N // 2] + use_ue8m0: bool | None = None, + eps: float = 1e-10, +): + """ + silu+mul + block-fp8 quant with group size 128. + """ + GROUP_SIZE = 128 + assert input.ndim == 2 + if output is not None: + assert output.ndim == 2 + assert input.size(0) % GROUP_SIZE == 0 + assert input.size(1) % (GROUP_SIZE * 2) == 0 + + if use_ue8m0 is None: + use_ue8m0 = is_deep_gemm_e8m0_used() + + M, N = input.size() + N_2 = N // 2 + + if output is None: + output = torch.empty((M, N_2), dtype=torch.float8_e4m3fn, device=input.device) + + output_scales = torch.empty( + ((N_2 // GROUP_SIZE), M), dtype=torch.float32, device=input.device + ).transpose(0, 1) + + BLOCK_M = 8 + BLOCK_N = GROUP_SIZE + assert M % BLOCK_M == 0 + assert N_2 % BLOCK_N == 0 + + finfo = torch.finfo(torch.float8_e4m3fn) + fp8_min = finfo.min + fp8_max = finfo.max + + # Force even division so we can avoid edgecases within the kernel. + assert M % BLOCK_M == 0 + assert N_2 % BLOCK_N == 0 + grid = (M // BLOCK_M, N_2 // BLOCK_N) + + _silu_mul_per_token_group_quant_fp8_colmajor[grid]( + input, + output, + output_scales, + M, + N, + output_scales.stride(-1), + eps, + fp8_min, + fp8_max, + use_ue8m0, + GROUP_SIZE, + BLOCK_M, + BLOCK_N, + ) + + return output, output_scales + + @triton.jit def _per_token_group_quant_fp8_colmajor( # Pointers to inputs and output