[torch.compile] Add torch inductor pass for fusing silu_and_mul with subsequent scaled_fp8_quant operations (#10867)

Signed-off-by: Sage Moore <sage@neuralmagic.com>
This commit is contained in:
Sage Moore 2025-05-01 07:59:28 -07:00 committed by GitHub
parent 28566d73b3
commit 460a2b1100
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
11 changed files with 406 additions and 9 deletions

View File

@ -241,6 +241,7 @@ set(VLLM_EXT_SRC
"csrc/quantization/fp8/common.cu"
"csrc/quantization/fused_kernels/fused_layernorm_dynamic_per_token_quant.cu"
"csrc/quantization/gguf/gguf_kernel.cu"
"csrc/quantization/activation_kernels.cu"
"csrc/cuda_utils_kernels.cu"
"csrc/prepare_inputs/advance_step.cu"
"csrc/custom_all_reduce.cu"

View File

@ -7,3 +7,22 @@ inline constexpr uint32_t next_pow_2(uint32_t const num) {
if (num <= 1) return num;
return 1 << (CHAR_BIT * sizeof(num) - __builtin_clz(num - 1));
}
template <typename A, typename B>
static inline constexpr auto div_ceil(A a, B b) {
return (a + b - 1) / b;
}
// Round a down to the next multiple of b. The caller is responsible for making
// sure that b is non-zero
template <typename T>
inline constexpr T round_to_previous_multiple_of(T a, T b) {
return a % b == 0 ? a : (a / b) * b;
}
// Round a up to the next multiple of b. The caller is responsible for making
// sure that b is non-zero
template <typename T>
inline constexpr T round_to_next_multiple_of(T a, T b) {
return a % b == 0 ? a : ((a / b) + 1) * b;
}

View File

@ -97,6 +97,9 @@ void batched_rotary_embedding(torch::Tensor& positions, torch::Tensor& query,
void silu_and_mul(torch::Tensor& out, torch::Tensor& input);
void silu_and_mul_quant(torch::Tensor& out, torch::Tensor& input,
torch::Tensor& scale);
void mul_and_silu(torch::Tensor& out, torch::Tensor& input);
void gelu_and_mul(torch::Tensor& out, torch::Tensor& input);

View File

@ -0,0 +1,120 @@
#include <ATen/cuda/CUDAContext.h>
#include <torch/all.h>
#include <c10/cuda/CUDAGuard.h>
#include <cmath>
#include "core/math.hpp"
#include "cuda_compat.h"
#include "dispatch_utils.h"
#include "quantization/fp8/common.cuh"
namespace vllm {
template <typename T>
__device__ __forceinline__ T silu_kernel(const T& x) {
// x * sigmoid(x)
return (T)(((float)x) / (1.0f + expf((float)-x)));
}
// Activation and gating kernel template.
template <typename scalar_t, scalar_t (*ACT_FN)(const scalar_t&),
typename fp8_type>
__global__ void act_and_mul_quant_kernel(
fp8_type* __restrict__ out, // [..., d]
const scalar_t* __restrict__ input, // [..., 2, d]
const float* scale, const int d) {
const int32_t blocks_per_token = gridDim.y;
const int32_t elems_per_128bit_load = (128 / 8) / sizeof(scalar_t);
// We don't expect the hidden dimension to exceed 32 bits so int32 should
// be safe here.
const int32_t tgt_elems_per_block = div_ceil(d, blocks_per_token);
const int32_t elems_per_block =
round_to_next_multiple_of(tgt_elems_per_block, elems_per_128bit_load);
const int32_t block_start = blockIdx.y * elems_per_block;
int32_t block_end = block_start + elems_per_block;
block_end = block_end > d ? d : block_end;
// token_idx is 64 bit to prevent 32 bit overflow when the number of tokens
// is very large
const int64_t token_idx = blockIdx.x;
const scalar_t* __restrict__ x_ptr = input + token_idx * 2 * d;
const scalar_t* __restrict__ y_ptr = input + token_idx * 2 * d + d;
fp8_type* __restrict__ out_ptr = out + token_idx * d;
// 128-bit vectorized code
const int32_t vec_loop_end =
round_to_previous_multiple_of(elems_per_128bit_load, block_end);
const int32_t vec_end_idx = vec_loop_end / elems_per_128bit_load;
const int32_t vec_start_idx = block_start / elems_per_128bit_load;
const int4* __restrict__ x_128bit_ptr = reinterpret_cast<const int4*>(x_ptr);
const int4* __restrict__ y_128bit_ptr = reinterpret_cast<const int4*>(y_ptr);
int2* __restrict__ out_128bit_ptr = reinterpret_cast<int2*>(out_ptr);
float inverted_scale = 1 / *scale;
#pragma unroll
for (int32_t vec_idx = vec_start_idx + threadIdx.x; vec_idx < vec_end_idx;
vec_idx += blockDim.x) {
const int4 x_128bit = VLLM_LDG(&x_128bit_ptr[vec_idx]);
const int4 y_128bit = VLLM_LDG(&y_128bit_ptr[vec_idx]);
using scalar_128bit_vec_t = std::array<scalar_t, elems_per_128bit_load>;
using scalar_64bit_vec_t = std::array<fp8_type, elems_per_128bit_load>;
scalar_64bit_vec_t out_vec;
const auto x_vec = reinterpret_cast<scalar_128bit_vec_t const&>(x_128bit);
const auto y_vec = reinterpret_cast<scalar_128bit_vec_t const&>(y_128bit);
#pragma unroll
for (int i = 0; i < elems_per_128bit_load; i++) {
out_vec[i] = scaled_fp8_conversion<true, fp8_type>(
ACT_FN(x_vec[i]) * y_vec[i], inverted_scale);
}
out_128bit_ptr[vec_idx] = reinterpret_cast<const int2&>(out_vec);
}
// Scalar cleanup code
if (block_end > vec_loop_end) {
for (int64_t idx = vec_loop_end + threadIdx.x; idx < block_end;
idx += blockDim.x) {
const scalar_t x = VLLM_LDG(&x_ptr[idx]);
const scalar_t y = VLLM_LDG(&y_ptr[idx]);
out_ptr[idx] =
scaled_fp8_conversion<true, fp8_type>(ACT_FN(x) * y, inverted_scale);
}
}
}
} // namespace vllm
// Launch activation, gating, and quantize kernel.
#define LAUNCH_ACTIVATION_GATE_KERNEL(KERNEL) \
int d = input.size(-1) / 2; \
int64_t num_tokens = input.numel() / input.size(-1); \
dim3 grid(num_tokens, num_tokens > 16 ? num_tokens > 32 ? 1 : 2 : 4); \
dim3 block(std::min(d, 512)); \
const at::cuda::OptionalCUDAGuard device_guard(device_of(input)); \
const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); \
VLLM_DISPATCH_FLOATING_TYPES( \
input.scalar_type(), "act_and_mul_kernel", [&] { \
VLLM_DISPATCH_FP8_TYPES( \
out.scalar_type(), "fused_add_rms_norm_kernel_fp8_type", [&] { \
vllm::act_and_mul_quant_kernel<scalar_t, KERNEL<scalar_t>, \
fp8_t> \
<<<grid, block, 0, stream>>>(out.data_ptr<fp8_t>(), \
input.data_ptr<scalar_t>(), \
scale.data_ptr<float>(), d); \
}); \
});
void silu_and_mul_quant(torch::Tensor& out, // [..., d]
torch::Tensor& input, // [..., 2 * d]
torch::Tensor& scale) {
TORCH_CHECK(out.dtype() == torch::kFloat8_e4m3fn);
TORCH_CHECK(input.dtype() == torch::kFloat16 ||
input.dtype() == torch::kBFloat16);
TORCH_CHECK(input.size(-1) % 2 == 0);
LAUNCH_ACTIVATION_GATE_KERNEL(vllm::silu_kernel);
}

View File

@ -81,9 +81,13 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
// Activation ops
// Activation function used in SwiGLU.
ops.def("silu_and_mul(Tensor! out, Tensor input) -> ()");
ops.def("silu_and_mul(Tensor! result, Tensor input) -> ()");
ops.impl("silu_and_mul", torch::kCUDA, &silu_and_mul);
ops.def(
"silu_and_mul_quant(Tensor! result, Tensor input, Tensor scale) -> ()");
ops.impl("silu_and_mul_quant", torch::kCUDA, &silu_and_mul_quant);
ops.def("mul_and_silu(Tensor! out, Tensor input) -> ()");
ops.impl("mul_and_silu", torch::kCUDA, &mul_and_silu);

View File

@ -5,6 +5,7 @@ import torch
import vllm.envs as envs
from vllm import LLM, SamplingParams
from vllm.compilation.activation_quant_fusion import ActivationQuantFusionPass
from vllm.compilation.fix_functionalization import FixFunctionalizationPass
from vllm.compilation.fusion import (FUSED_OPS, FusionPass, QuantKey,
kFp8DynamicTokenSym, kFp8StaticTensorSym)
@ -17,7 +18,6 @@ from .backend import TestBackend
OPS_IN_MODEL = [
torch.ops._C.rotary_embedding.default,
torch.ops._C.fused_add_rms_norm.default,
torch.ops._C.silu_and_mul.default,
]
RMS_OP = torch.ops._C.rms_norm.default
@ -29,6 +29,9 @@ RMS_QUANT_OPS = {
],
}
SILU_MUL_OP = torch.ops._C.silu_and_mul.default
SILU_MUL_QUANT_OP = torch.ops._C.silu_and_mul_quant.default
prompts = [
"Hello, my name is",
"The president of the United States is",
@ -55,8 +58,10 @@ def test_fix_functionalization(model: str, quant_key: QuantKey,
enable_noop=True))
noop_pass = NoOpEliminationPass(vllm_config)
fusion_pass = FusionPass.instance(vllm_config)
act_quant_fusion_pass = ActivationQuantFusionPass(vllm_config)
passes = [noop_pass, fusion_pass] if do_fusion else [noop_pass]
passes = [noop_pass, fusion_pass, act_quant_fusion_pass
] if do_fusion else [noop_pass]
func_pass = FixFunctionalizationPass(vllm_config)
backend_func = TestBackend(*passes, func_pass)
backend_no_func = TestBackend(*passes)
@ -79,6 +84,7 @@ def test_fix_functionalization(model: str, quant_key: QuantKey,
model_runner.model = torch.compile(orig_model,
fullgraph=True,
backend=backend_no_func)
gen_no_func = llm.generate(prompts, sampling_params)
for output_func, output_no_func in zip(gen_func, gen_no_func):
@ -88,7 +94,12 @@ def test_fix_functionalization(model: str, quant_key: QuantKey,
# and replaced by fused quantized ops in RMS_QUANT_OPS.
rms_ops = [FUSED_OPS[(quant_key, True)], FUSED_OPS[(quant_key, False)]
] if do_fusion else [RMS_OP]
ops = OPS_IN_MODEL + rms_ops
silu_mul_ops = [SILU_MUL_QUANT_OP] if do_fusion and \
quant_key == kFp8StaticTensorSym else [
SILU_MUL_OP
]
ops = OPS_IN_MODEL + rms_ops + silu_mul_ops
for op in ops:
find_auto_fn(backend_no_func.graph_post_pass.nodes, op)

View File

@ -0,0 +1,74 @@
# SPDX-License-Identifier: Apache-2.0
import pytest
import torch
import vllm.envs as envs
from vllm._custom_ops import scaled_fp8_quant
from vllm.compilation.activation_quant_fusion import ActivationQuantFusionPass
from vllm.compilation.fx_utils import find_auto_fn, find_auto_fn_maybe
from vllm.config import CompilationConfig, VllmConfig
from vllm.model_executor.layers.activation import SiluAndMul
from .backend import TestBackend
class TestModel(torch.nn.Module):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.silu_and_mul = SiluAndMul()
self.scale = torch.rand(1, dtype=torch.float32)
def forward(self, x):
y = self.silu_and_mul(x)
x2 = scaled_fp8_quant(y, self.scale)
return x2
@pytest.mark.parametrize("num_tokens", [256])
@pytest.mark.parametrize("hidden_size", [64])
@pytest.mark.skipif(envs.VLLM_TARGET_DEVICE != "cuda",
reason="Only test on CUDA")
def test_fusion_silu_and_mul_quant(num_tokens, hidden_size):
torch.set_default_device("cuda")
torch.set_default_dtype(torch.float16)
# Reshape pass is needed for the fusion pass to work
config = VllmConfig()
config.compilation_config = CompilationConfig(
pass_config=CompilationConfig.PassConfig(enable_fusion=True,
enable_reshape=True))
fusion_pass = ActivationQuantFusionPass(config)
backend = TestBackend(fusion_pass)
model = TestModel()
# First dimension dynamic
x = torch.rand(num_tokens, hidden_size)
torch._dynamo.mark_dynamic(x, 0)
result = model(x)
model2 = torch.compile(model, backend=backend)
result2 = model2(x)
# Check that it gives the same answer
torch.testing.assert_close(result[0].to(dtype=torch.float16),
result2[0].to(dtype=torch.float16),
atol=1e-3,
rtol=1e-3)
# Check substitution worked
pre_nodes = backend.graph_pre_pass.nodes
post_nodes = backend.graph_post_pass.nodes
silu_and_mul_quant = torch.ops._C.silu_and_mul_quant.default
fp8_quant = torch.ops._C.static_scaled_fp8_quant.default
# In pre-nodes, fp8 quant should be present and fused kernels should not
assert find_auto_fn_maybe(pre_nodes, silu_and_mul_quant) is None
find_auto_fn(pre_nodes, fp8_quant)
# In post-nodes, fused kernels should be present and fp8 quant should not
find_auto_fn(post_nodes, silu_and_mul_quant)
assert find_auto_fn_maybe(post_nodes, fp8_quant) is None

View File

@ -0,0 +1,69 @@
# SPDX-License-Identifier: Apache-2.0
import pytest
import torch
import vllm._custom_ops as ops
from tests.kernels.utils import opcheck
from vllm.model_executor.layers.activation import SiluAndMul
DTYPES = [torch.bfloat16, torch.float16]
QUANT_DTYPES = [torch.float8_e4m3fn]
NUM_TOKENS = [1, 17, 86, 1234, 3045] # Arbitrary values for testing
HIDDEN_SIZES = [16, 48, 128, 1562, 4096] # Arbitrary values for testing
SEEDS = [0]
CUDA_DEVICES = [
f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2)
]
def ref_impl(silu_and_mul: SiluAndMul, x: torch.Tensor,
scale: torch.Tensor) -> torch.Tensor:
silu_and_mul_out = silu_and_mul.forward_native(x)
out, scales = ops.scaled_fp8_quant(silu_and_mul_out, scale)
return out
def ops_impl(x: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
out_shape = (x.shape[0], x.shape[1] // 2)
out = torch.empty(out_shape,
dtype=torch.torch.float8_e4m3fn,
device=x.device)
torch.ops._C.silu_and_mul_quant(out, x, scale)
return out
@pytest.mark.parametrize("num_tokens", NUM_TOKENS)
@pytest.mark.parametrize("hidden_size", HIDDEN_SIZES)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("quant_dtype", QUANT_DTYPES)
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.parametrize("device", CUDA_DEVICES)
@torch.inference_mode()
def test_silu_and_mul(
num_tokens: int,
hidden_size: int,
dtype: torch.dtype,
quant_dtype: torch.dtype,
seed: int,
device: str,
) -> None:
torch.random.manual_seed(seed)
if torch.cuda.is_available():
torch.cuda.manual_seed(seed)
torch.set_default_device(device)
layer = SiluAndMul()
# Make inputs
scale = (torch.randn((1), device=device, dtype=torch.float32))
x = torch.randn(num_tokens, hidden_size, dtype=dtype)
ref_out = ref_impl(layer, x, scale)
ops_out = ops_impl(x, scale)
assert ref_out.dtype == quant_dtype
assert ops_out.dtype == quant_dtype
assert ref_out.shape == ops_out.shape
assert torch.allclose(ref_out.to(dtype=torch.float32),
ops_out.to(dtype=torch.float32))
opcheck(torch.ops._C.silu_and_mul_quant, (ops_out, x, scale))

View File

@ -0,0 +1,87 @@
# SPDX-License-Identifier: Apache-2.0
import torch
from torch._higher_order_ops.auto_functionalize import auto_functionalized
from torch._inductor.pattern_matcher import (PatternMatcherPass, fwd_only,
register_replacement)
from vllm.config import VllmConfig
from vllm.logger import init_logger
from .vllm_inductor_pass import VllmInductorPass
logger = init_logger(__name__)
def silu_mul_pattern_static(result: torch.Tensor,
result_silu_mul: torch.Tensor, input: torch.Tensor,
scale: torch.Tensor):
at1 = auto_functionalized(torch.ops._C.silu_and_mul.default,
result=result_silu_mul,
input=input)
at2 = auto_functionalized(torch.ops._C.static_scaled_fp8_quant.default,
result=result,
input=at1[1],
scale=scale)
return at2[1]
def silu_mul_replacement_static(result: torch.Tensor,
result_silu_mul: torch.Tensor,
input: torch.Tensor, scale: torch.Tensor):
at = auto_functionalized(torch.ops._C.silu_and_mul_quant.default,
result=result,
input=input,
scale=scale)
return at[1]
def empty_bf16(*args, **kwargs):
return torch.empty(*args, **kwargs, dtype=torch.bfloat16, device="cuda")
def empty_fp8(*args, **kwargs):
fp8 = torch.float8_e4m3fn
return torch.empty(*args, **kwargs, dtype=fp8, device="cuda")
def empty_fp32(*args, **kwargs):
return torch.empty(*args, **kwargs, dtype=torch.float32, device="cuda")
class ActivationQuantFusionPass(VllmInductorPass):
"""
This pass fuses a pre-defined set of custom ops into fused ops.
It uses the torch pattern matcher to find the patterns and replace them.
Because patterns can only be registered once, the pass is a singleton.
This will be addressed in a future version of PyTorch:
https://github.com/pytorch/pytorch/pull/139321#issuecomment-2452354980
"""
def __init__(self, config: VllmConfig):
super().__init__(config)
self.patterns: PatternMatcherPass = PatternMatcherPass(
pass_name="activation_quant_fusion_pass")
inputs = [
empty_fp8(5, 4), # Quant output
empty_bf16(5, 4), # Silu_and_mul output
empty_bf16(5, 4), # Input
empty_fp32(1, 1) # Scale
]
register_replacement(silu_mul_pattern_static,
silu_mul_replacement_static, inputs, fwd_only,
self.patterns)
def __call__(self, graph: torch.fx.Graph):
self.begin()
self.dump_graph(graph, "before_act_quant_fusion")
count = self.patterns.apply(graph)
logger.debug("Replaced %s patterns in ActivationQuantFusionPass",
count)
self.dump_graph(graph, "after_act_quant_fusion")
self.end_and_log()

View File

@ -68,18 +68,25 @@ class FixFunctionalizationPass(VllmInductorPass):
self.defunctionalize(graph, node, mutated_args)
elif at_target in [
torch.ops._C.rms_norm.default,
torch.ops._C.rms_norm_static_fp8_quant.default
torch.ops._C.rms_norm_static_fp8_quant.default,
]:
mutated_args = {1: 'result'}
self.defunctionalize(graph, node, mutated_args)
# For some reason we need to specify the args for both
# silu_and_mul and silu_and_mul_quant. The kwargs
# pathway gets the wrong answer.
elif at_target == torch.ops._C.silu_and_mul.default:
mutated_args = {1: 'out'}
# Because we have an 'out', need to specify args directly
mutated_args = {1: 'result'}
self.defunctionalize(graph,
node,
mutated_args,
args=('out', 'input'))
args=('result', 'input'))
elif at_target == torch.ops._C.silu_and_mul_quant.default:
mutated_args = {1: 'result'}
self.defunctionalize(graph,
node,
mutated_args,
args=('result', 'input', 'scale'))
else:
continue # skip the count

View File

@ -7,6 +7,7 @@ from torch import fx as fx
from vllm.config import VllmConfig
from vllm.logger import init_logger
from .activation_quant_fusion import ActivationQuantFusionPass
from .fix_functionalization import FixFunctionalizationPass
from .fusion import FusionPass
from .inductor_pass import CustomGraphPass, InductorPass, get_pass_context
@ -51,6 +52,7 @@ class PostGradPassManager(CustomGraphPass):
if self.pass_config.enable_fusion:
self.passes += [FusionPass.instance(config)]
self.passes += [ActivationQuantFusionPass(config)]
if self.pass_config.enable_sequence_parallelism:
self.passes += [SequenceParallelismPass(config)]