mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-03-16 14:17:16 +08:00
[torch.compile] Add torch inductor pass for fusing silu_and_mul with subsequent scaled_fp8_quant operations (#10867)
Signed-off-by: Sage Moore <sage@neuralmagic.com>
This commit is contained in:
parent
28566d73b3
commit
460a2b1100
@ -241,6 +241,7 @@ set(VLLM_EXT_SRC
|
||||
"csrc/quantization/fp8/common.cu"
|
||||
"csrc/quantization/fused_kernels/fused_layernorm_dynamic_per_token_quant.cu"
|
||||
"csrc/quantization/gguf/gguf_kernel.cu"
|
||||
"csrc/quantization/activation_kernels.cu"
|
||||
"csrc/cuda_utils_kernels.cu"
|
||||
"csrc/prepare_inputs/advance_step.cu"
|
||||
"csrc/custom_all_reduce.cu"
|
||||
|
||||
@ -7,3 +7,22 @@ inline constexpr uint32_t next_pow_2(uint32_t const num) {
|
||||
if (num <= 1) return num;
|
||||
return 1 << (CHAR_BIT * sizeof(num) - __builtin_clz(num - 1));
|
||||
}
|
||||
|
||||
template <typename A, typename B>
|
||||
static inline constexpr auto div_ceil(A a, B b) {
|
||||
return (a + b - 1) / b;
|
||||
}
|
||||
|
||||
// Round a down to the next multiple of b. The caller is responsible for making
|
||||
// sure that b is non-zero
|
||||
template <typename T>
|
||||
inline constexpr T round_to_previous_multiple_of(T a, T b) {
|
||||
return a % b == 0 ? a : (a / b) * b;
|
||||
}
|
||||
|
||||
// Round a up to the next multiple of b. The caller is responsible for making
|
||||
// sure that b is non-zero
|
||||
template <typename T>
|
||||
inline constexpr T round_to_next_multiple_of(T a, T b) {
|
||||
return a % b == 0 ? a : ((a / b) + 1) * b;
|
||||
}
|
||||
|
||||
@ -97,6 +97,9 @@ void batched_rotary_embedding(torch::Tensor& positions, torch::Tensor& query,
|
||||
|
||||
void silu_and_mul(torch::Tensor& out, torch::Tensor& input);
|
||||
|
||||
void silu_and_mul_quant(torch::Tensor& out, torch::Tensor& input,
|
||||
torch::Tensor& scale);
|
||||
|
||||
void mul_and_silu(torch::Tensor& out, torch::Tensor& input);
|
||||
|
||||
void gelu_and_mul(torch::Tensor& out, torch::Tensor& input);
|
||||
|
||||
120
csrc/quantization/activation_kernels.cu
Normal file
120
csrc/quantization/activation_kernels.cu
Normal file
@ -0,0 +1,120 @@
|
||||
#include <ATen/cuda/CUDAContext.h>
|
||||
#include <torch/all.h>
|
||||
#include <c10/cuda/CUDAGuard.h>
|
||||
|
||||
#include <cmath>
|
||||
#include "core/math.hpp"
|
||||
#include "cuda_compat.h"
|
||||
#include "dispatch_utils.h"
|
||||
|
||||
#include "quantization/fp8/common.cuh"
|
||||
|
||||
namespace vllm {
|
||||
|
||||
template <typename T>
|
||||
__device__ __forceinline__ T silu_kernel(const T& x) {
|
||||
// x * sigmoid(x)
|
||||
return (T)(((float)x) / (1.0f + expf((float)-x)));
|
||||
}
|
||||
|
||||
// Activation and gating kernel template.
|
||||
template <typename scalar_t, scalar_t (*ACT_FN)(const scalar_t&),
|
||||
typename fp8_type>
|
||||
__global__ void act_and_mul_quant_kernel(
|
||||
fp8_type* __restrict__ out, // [..., d]
|
||||
const scalar_t* __restrict__ input, // [..., 2, d]
|
||||
const float* scale, const int d) {
|
||||
const int32_t blocks_per_token = gridDim.y;
|
||||
|
||||
const int32_t elems_per_128bit_load = (128 / 8) / sizeof(scalar_t);
|
||||
|
||||
// We don't expect the hidden dimension to exceed 32 bits so int32 should
|
||||
// be safe here.
|
||||
const int32_t tgt_elems_per_block = div_ceil(d, blocks_per_token);
|
||||
const int32_t elems_per_block =
|
||||
round_to_next_multiple_of(tgt_elems_per_block, elems_per_128bit_load);
|
||||
const int32_t block_start = blockIdx.y * elems_per_block;
|
||||
int32_t block_end = block_start + elems_per_block;
|
||||
block_end = block_end > d ? d : block_end;
|
||||
|
||||
// token_idx is 64 bit to prevent 32 bit overflow when the number of tokens
|
||||
// is very large
|
||||
const int64_t token_idx = blockIdx.x;
|
||||
const scalar_t* __restrict__ x_ptr = input + token_idx * 2 * d;
|
||||
const scalar_t* __restrict__ y_ptr = input + token_idx * 2 * d + d;
|
||||
fp8_type* __restrict__ out_ptr = out + token_idx * d;
|
||||
|
||||
// 128-bit vectorized code
|
||||
const int32_t vec_loop_end =
|
||||
round_to_previous_multiple_of(elems_per_128bit_load, block_end);
|
||||
const int32_t vec_end_idx = vec_loop_end / elems_per_128bit_load;
|
||||
const int32_t vec_start_idx = block_start / elems_per_128bit_load;
|
||||
|
||||
const int4* __restrict__ x_128bit_ptr = reinterpret_cast<const int4*>(x_ptr);
|
||||
const int4* __restrict__ y_128bit_ptr = reinterpret_cast<const int4*>(y_ptr);
|
||||
int2* __restrict__ out_128bit_ptr = reinterpret_cast<int2*>(out_ptr);
|
||||
|
||||
float inverted_scale = 1 / *scale;
|
||||
#pragma unroll
|
||||
for (int32_t vec_idx = vec_start_idx + threadIdx.x; vec_idx < vec_end_idx;
|
||||
vec_idx += blockDim.x) {
|
||||
const int4 x_128bit = VLLM_LDG(&x_128bit_ptr[vec_idx]);
|
||||
const int4 y_128bit = VLLM_LDG(&y_128bit_ptr[vec_idx]);
|
||||
using scalar_128bit_vec_t = std::array<scalar_t, elems_per_128bit_load>;
|
||||
using scalar_64bit_vec_t = std::array<fp8_type, elems_per_128bit_load>;
|
||||
|
||||
scalar_64bit_vec_t out_vec;
|
||||
const auto x_vec = reinterpret_cast<scalar_128bit_vec_t const&>(x_128bit);
|
||||
const auto y_vec = reinterpret_cast<scalar_128bit_vec_t const&>(y_128bit);
|
||||
|
||||
#pragma unroll
|
||||
for (int i = 0; i < elems_per_128bit_load; i++) {
|
||||
out_vec[i] = scaled_fp8_conversion<true, fp8_type>(
|
||||
ACT_FN(x_vec[i]) * y_vec[i], inverted_scale);
|
||||
}
|
||||
|
||||
out_128bit_ptr[vec_idx] = reinterpret_cast<const int2&>(out_vec);
|
||||
}
|
||||
|
||||
// Scalar cleanup code
|
||||
if (block_end > vec_loop_end) {
|
||||
for (int64_t idx = vec_loop_end + threadIdx.x; idx < block_end;
|
||||
idx += blockDim.x) {
|
||||
const scalar_t x = VLLM_LDG(&x_ptr[idx]);
|
||||
const scalar_t y = VLLM_LDG(&y_ptr[idx]);
|
||||
out_ptr[idx] =
|
||||
scaled_fp8_conversion<true, fp8_type>(ACT_FN(x) * y, inverted_scale);
|
||||
}
|
||||
}
|
||||
}
|
||||
} // namespace vllm
|
||||
|
||||
// Launch activation, gating, and quantize kernel.
|
||||
#define LAUNCH_ACTIVATION_GATE_KERNEL(KERNEL) \
|
||||
int d = input.size(-1) / 2; \
|
||||
int64_t num_tokens = input.numel() / input.size(-1); \
|
||||
dim3 grid(num_tokens, num_tokens > 16 ? num_tokens > 32 ? 1 : 2 : 4); \
|
||||
dim3 block(std::min(d, 512)); \
|
||||
const at::cuda::OptionalCUDAGuard device_guard(device_of(input)); \
|
||||
const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); \
|
||||
VLLM_DISPATCH_FLOATING_TYPES( \
|
||||
input.scalar_type(), "act_and_mul_kernel", [&] { \
|
||||
VLLM_DISPATCH_FP8_TYPES( \
|
||||
out.scalar_type(), "fused_add_rms_norm_kernel_fp8_type", [&] { \
|
||||
vllm::act_and_mul_quant_kernel<scalar_t, KERNEL<scalar_t>, \
|
||||
fp8_t> \
|
||||
<<<grid, block, 0, stream>>>(out.data_ptr<fp8_t>(), \
|
||||
input.data_ptr<scalar_t>(), \
|
||||
scale.data_ptr<float>(), d); \
|
||||
}); \
|
||||
});
|
||||
|
||||
void silu_and_mul_quant(torch::Tensor& out, // [..., d]
|
||||
torch::Tensor& input, // [..., 2 * d]
|
||||
torch::Tensor& scale) {
|
||||
TORCH_CHECK(out.dtype() == torch::kFloat8_e4m3fn);
|
||||
TORCH_CHECK(input.dtype() == torch::kFloat16 ||
|
||||
input.dtype() == torch::kBFloat16);
|
||||
TORCH_CHECK(input.size(-1) % 2 == 0);
|
||||
LAUNCH_ACTIVATION_GATE_KERNEL(vllm::silu_kernel);
|
||||
}
|
||||
@ -81,9 +81,13 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
|
||||
|
||||
// Activation ops
|
||||
// Activation function used in SwiGLU.
|
||||
ops.def("silu_and_mul(Tensor! out, Tensor input) -> ()");
|
||||
ops.def("silu_and_mul(Tensor! result, Tensor input) -> ()");
|
||||
ops.impl("silu_and_mul", torch::kCUDA, &silu_and_mul);
|
||||
|
||||
ops.def(
|
||||
"silu_and_mul_quant(Tensor! result, Tensor input, Tensor scale) -> ()");
|
||||
ops.impl("silu_and_mul_quant", torch::kCUDA, &silu_and_mul_quant);
|
||||
|
||||
ops.def("mul_and_silu(Tensor! out, Tensor input) -> ()");
|
||||
ops.impl("mul_and_silu", torch::kCUDA, &mul_and_silu);
|
||||
|
||||
|
||||
@ -5,6 +5,7 @@ import torch
|
||||
|
||||
import vllm.envs as envs
|
||||
from vllm import LLM, SamplingParams
|
||||
from vllm.compilation.activation_quant_fusion import ActivationQuantFusionPass
|
||||
from vllm.compilation.fix_functionalization import FixFunctionalizationPass
|
||||
from vllm.compilation.fusion import (FUSED_OPS, FusionPass, QuantKey,
|
||||
kFp8DynamicTokenSym, kFp8StaticTensorSym)
|
||||
@ -17,7 +18,6 @@ from .backend import TestBackend
|
||||
OPS_IN_MODEL = [
|
||||
torch.ops._C.rotary_embedding.default,
|
||||
torch.ops._C.fused_add_rms_norm.default,
|
||||
torch.ops._C.silu_and_mul.default,
|
||||
]
|
||||
|
||||
RMS_OP = torch.ops._C.rms_norm.default
|
||||
@ -29,6 +29,9 @@ RMS_QUANT_OPS = {
|
||||
],
|
||||
}
|
||||
|
||||
SILU_MUL_OP = torch.ops._C.silu_and_mul.default
|
||||
|
||||
SILU_MUL_QUANT_OP = torch.ops._C.silu_and_mul_quant.default
|
||||
prompts = [
|
||||
"Hello, my name is",
|
||||
"The president of the United States is",
|
||||
@ -55,8 +58,10 @@ def test_fix_functionalization(model: str, quant_key: QuantKey,
|
||||
enable_noop=True))
|
||||
noop_pass = NoOpEliminationPass(vllm_config)
|
||||
fusion_pass = FusionPass.instance(vllm_config)
|
||||
act_quant_fusion_pass = ActivationQuantFusionPass(vllm_config)
|
||||
|
||||
passes = [noop_pass, fusion_pass] if do_fusion else [noop_pass]
|
||||
passes = [noop_pass, fusion_pass, act_quant_fusion_pass
|
||||
] if do_fusion else [noop_pass]
|
||||
func_pass = FixFunctionalizationPass(vllm_config)
|
||||
backend_func = TestBackend(*passes, func_pass)
|
||||
backend_no_func = TestBackend(*passes)
|
||||
@ -79,6 +84,7 @@ def test_fix_functionalization(model: str, quant_key: QuantKey,
|
||||
model_runner.model = torch.compile(orig_model,
|
||||
fullgraph=True,
|
||||
backend=backend_no_func)
|
||||
|
||||
gen_no_func = llm.generate(prompts, sampling_params)
|
||||
|
||||
for output_func, output_no_func in zip(gen_func, gen_no_func):
|
||||
@ -88,7 +94,12 @@ def test_fix_functionalization(model: str, quant_key: QuantKey,
|
||||
# and replaced by fused quantized ops in RMS_QUANT_OPS.
|
||||
rms_ops = [FUSED_OPS[(quant_key, True)], FUSED_OPS[(quant_key, False)]
|
||||
] if do_fusion else [RMS_OP]
|
||||
ops = OPS_IN_MODEL + rms_ops
|
||||
silu_mul_ops = [SILU_MUL_QUANT_OP] if do_fusion and \
|
||||
quant_key == kFp8StaticTensorSym else [
|
||||
SILU_MUL_OP
|
||||
]
|
||||
|
||||
ops = OPS_IN_MODEL + rms_ops + silu_mul_ops
|
||||
|
||||
for op in ops:
|
||||
find_auto_fn(backend_no_func.graph_post_pass.nodes, op)
|
||||
|
||||
74
tests/compile/test_silu_mul_quant_fusion.py
Normal file
74
tests/compile/test_silu_mul_quant_fusion.py
Normal file
@ -0,0 +1,74 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
import vllm.envs as envs
|
||||
from vllm._custom_ops import scaled_fp8_quant
|
||||
from vllm.compilation.activation_quant_fusion import ActivationQuantFusionPass
|
||||
from vllm.compilation.fx_utils import find_auto_fn, find_auto_fn_maybe
|
||||
from vllm.config import CompilationConfig, VllmConfig
|
||||
from vllm.model_executor.layers.activation import SiluAndMul
|
||||
|
||||
from .backend import TestBackend
|
||||
|
||||
|
||||
class TestModel(torch.nn.Module):
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.silu_and_mul = SiluAndMul()
|
||||
self.scale = torch.rand(1, dtype=torch.float32)
|
||||
|
||||
def forward(self, x):
|
||||
y = self.silu_and_mul(x)
|
||||
x2 = scaled_fp8_quant(y, self.scale)
|
||||
return x2
|
||||
|
||||
|
||||
@pytest.mark.parametrize("num_tokens", [256])
|
||||
@pytest.mark.parametrize("hidden_size", [64])
|
||||
@pytest.mark.skipif(envs.VLLM_TARGET_DEVICE != "cuda",
|
||||
reason="Only test on CUDA")
|
||||
def test_fusion_silu_and_mul_quant(num_tokens, hidden_size):
|
||||
torch.set_default_device("cuda")
|
||||
torch.set_default_dtype(torch.float16)
|
||||
|
||||
# Reshape pass is needed for the fusion pass to work
|
||||
config = VllmConfig()
|
||||
config.compilation_config = CompilationConfig(
|
||||
pass_config=CompilationConfig.PassConfig(enable_fusion=True,
|
||||
enable_reshape=True))
|
||||
fusion_pass = ActivationQuantFusionPass(config)
|
||||
|
||||
backend = TestBackend(fusion_pass)
|
||||
model = TestModel()
|
||||
|
||||
# First dimension dynamic
|
||||
x = torch.rand(num_tokens, hidden_size)
|
||||
torch._dynamo.mark_dynamic(x, 0)
|
||||
|
||||
result = model(x)
|
||||
|
||||
model2 = torch.compile(model, backend=backend)
|
||||
result2 = model2(x)
|
||||
|
||||
# Check that it gives the same answer
|
||||
torch.testing.assert_close(result[0].to(dtype=torch.float16),
|
||||
result2[0].to(dtype=torch.float16),
|
||||
atol=1e-3,
|
||||
rtol=1e-3)
|
||||
|
||||
# Check substitution worked
|
||||
pre_nodes = backend.graph_pre_pass.nodes
|
||||
post_nodes = backend.graph_post_pass.nodes
|
||||
|
||||
silu_and_mul_quant = torch.ops._C.silu_and_mul_quant.default
|
||||
fp8_quant = torch.ops._C.static_scaled_fp8_quant.default
|
||||
|
||||
# In pre-nodes, fp8 quant should be present and fused kernels should not
|
||||
assert find_auto_fn_maybe(pre_nodes, silu_and_mul_quant) is None
|
||||
find_auto_fn(pre_nodes, fp8_quant)
|
||||
|
||||
# In post-nodes, fused kernels should be present and fp8 quant should not
|
||||
find_auto_fn(post_nodes, silu_and_mul_quant)
|
||||
assert find_auto_fn_maybe(post_nodes, fp8_quant) is None
|
||||
69
tests/kernels/test_fused_quant_activation.py
Normal file
69
tests/kernels/test_fused_quant_activation.py
Normal file
@ -0,0 +1,69 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
import vllm._custom_ops as ops
|
||||
from tests.kernels.utils import opcheck
|
||||
from vllm.model_executor.layers.activation import SiluAndMul
|
||||
|
||||
DTYPES = [torch.bfloat16, torch.float16]
|
||||
QUANT_DTYPES = [torch.float8_e4m3fn]
|
||||
NUM_TOKENS = [1, 17, 86, 1234, 3045] # Arbitrary values for testing
|
||||
HIDDEN_SIZES = [16, 48, 128, 1562, 4096] # Arbitrary values for testing
|
||||
SEEDS = [0]
|
||||
CUDA_DEVICES = [
|
||||
f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2)
|
||||
]
|
||||
|
||||
|
||||
def ref_impl(silu_and_mul: SiluAndMul, x: torch.Tensor,
|
||||
scale: torch.Tensor) -> torch.Tensor:
|
||||
silu_and_mul_out = silu_and_mul.forward_native(x)
|
||||
out, scales = ops.scaled_fp8_quant(silu_and_mul_out, scale)
|
||||
return out
|
||||
|
||||
|
||||
def ops_impl(x: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
|
||||
out_shape = (x.shape[0], x.shape[1] // 2)
|
||||
out = torch.empty(out_shape,
|
||||
dtype=torch.torch.float8_e4m3fn,
|
||||
device=x.device)
|
||||
torch.ops._C.silu_and_mul_quant(out, x, scale)
|
||||
return out
|
||||
|
||||
|
||||
@pytest.mark.parametrize("num_tokens", NUM_TOKENS)
|
||||
@pytest.mark.parametrize("hidden_size", HIDDEN_SIZES)
|
||||
@pytest.mark.parametrize("dtype", DTYPES)
|
||||
@pytest.mark.parametrize("quant_dtype", QUANT_DTYPES)
|
||||
@pytest.mark.parametrize("seed", SEEDS)
|
||||
@pytest.mark.parametrize("device", CUDA_DEVICES)
|
||||
@torch.inference_mode()
|
||||
def test_silu_and_mul(
|
||||
num_tokens: int,
|
||||
hidden_size: int,
|
||||
dtype: torch.dtype,
|
||||
quant_dtype: torch.dtype,
|
||||
seed: int,
|
||||
device: str,
|
||||
) -> None:
|
||||
torch.random.manual_seed(seed)
|
||||
if torch.cuda.is_available():
|
||||
torch.cuda.manual_seed(seed)
|
||||
torch.set_default_device(device)
|
||||
|
||||
layer = SiluAndMul()
|
||||
|
||||
# Make inputs
|
||||
scale = (torch.randn((1), device=device, dtype=torch.float32))
|
||||
x = torch.randn(num_tokens, hidden_size, dtype=dtype)
|
||||
|
||||
ref_out = ref_impl(layer, x, scale)
|
||||
ops_out = ops_impl(x, scale)
|
||||
|
||||
assert ref_out.dtype == quant_dtype
|
||||
assert ops_out.dtype == quant_dtype
|
||||
assert ref_out.shape == ops_out.shape
|
||||
assert torch.allclose(ref_out.to(dtype=torch.float32),
|
||||
ops_out.to(dtype=torch.float32))
|
||||
opcheck(torch.ops._C.silu_and_mul_quant, (ops_out, x, scale))
|
||||
87
vllm/compilation/activation_quant_fusion.py
Normal file
87
vllm/compilation/activation_quant_fusion.py
Normal file
@ -0,0 +1,87 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
import torch
|
||||
from torch._higher_order_ops.auto_functionalize import auto_functionalized
|
||||
from torch._inductor.pattern_matcher import (PatternMatcherPass, fwd_only,
|
||||
register_replacement)
|
||||
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.logger import init_logger
|
||||
|
||||
from .vllm_inductor_pass import VllmInductorPass
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
def silu_mul_pattern_static(result: torch.Tensor,
|
||||
result_silu_mul: torch.Tensor, input: torch.Tensor,
|
||||
scale: torch.Tensor):
|
||||
at1 = auto_functionalized(torch.ops._C.silu_and_mul.default,
|
||||
result=result_silu_mul,
|
||||
input=input)
|
||||
at2 = auto_functionalized(torch.ops._C.static_scaled_fp8_quant.default,
|
||||
result=result,
|
||||
input=at1[1],
|
||||
scale=scale)
|
||||
return at2[1]
|
||||
|
||||
|
||||
def silu_mul_replacement_static(result: torch.Tensor,
|
||||
result_silu_mul: torch.Tensor,
|
||||
input: torch.Tensor, scale: torch.Tensor):
|
||||
at = auto_functionalized(torch.ops._C.silu_and_mul_quant.default,
|
||||
result=result,
|
||||
input=input,
|
||||
scale=scale)
|
||||
return at[1]
|
||||
|
||||
|
||||
def empty_bf16(*args, **kwargs):
|
||||
return torch.empty(*args, **kwargs, dtype=torch.bfloat16, device="cuda")
|
||||
|
||||
|
||||
def empty_fp8(*args, **kwargs):
|
||||
fp8 = torch.float8_e4m3fn
|
||||
return torch.empty(*args, **kwargs, dtype=fp8, device="cuda")
|
||||
|
||||
|
||||
def empty_fp32(*args, **kwargs):
|
||||
return torch.empty(*args, **kwargs, dtype=torch.float32, device="cuda")
|
||||
|
||||
|
||||
class ActivationQuantFusionPass(VllmInductorPass):
|
||||
"""
|
||||
This pass fuses a pre-defined set of custom ops into fused ops.
|
||||
It uses the torch pattern matcher to find the patterns and replace them.
|
||||
|
||||
Because patterns can only be registered once, the pass is a singleton.
|
||||
This will be addressed in a future version of PyTorch:
|
||||
https://github.com/pytorch/pytorch/pull/139321#issuecomment-2452354980
|
||||
"""
|
||||
|
||||
def __init__(self, config: VllmConfig):
|
||||
super().__init__(config)
|
||||
|
||||
self.patterns: PatternMatcherPass = PatternMatcherPass(
|
||||
pass_name="activation_quant_fusion_pass")
|
||||
|
||||
inputs = [
|
||||
empty_fp8(5, 4), # Quant output
|
||||
empty_bf16(5, 4), # Silu_and_mul output
|
||||
empty_bf16(5, 4), # Input
|
||||
empty_fp32(1, 1) # Scale
|
||||
]
|
||||
register_replacement(silu_mul_pattern_static,
|
||||
silu_mul_replacement_static, inputs, fwd_only,
|
||||
self.patterns)
|
||||
|
||||
def __call__(self, graph: torch.fx.Graph):
|
||||
self.begin()
|
||||
self.dump_graph(graph, "before_act_quant_fusion")
|
||||
|
||||
count = self.patterns.apply(graph)
|
||||
logger.debug("Replaced %s patterns in ActivationQuantFusionPass",
|
||||
count)
|
||||
|
||||
self.dump_graph(graph, "after_act_quant_fusion")
|
||||
self.end_and_log()
|
||||
@ -68,18 +68,25 @@ class FixFunctionalizationPass(VllmInductorPass):
|
||||
self.defunctionalize(graph, node, mutated_args)
|
||||
elif at_target in [
|
||||
torch.ops._C.rms_norm.default,
|
||||
torch.ops._C.rms_norm_static_fp8_quant.default
|
||||
torch.ops._C.rms_norm_static_fp8_quant.default,
|
||||
]:
|
||||
mutated_args = {1: 'result'}
|
||||
self.defunctionalize(graph, node, mutated_args)
|
||||
|
||||
# For some reason we need to specify the args for both
|
||||
# silu_and_mul and silu_and_mul_quant. The kwargs
|
||||
# pathway gets the wrong answer.
|
||||
elif at_target == torch.ops._C.silu_and_mul.default:
|
||||
mutated_args = {1: 'out'}
|
||||
# Because we have an 'out', need to specify args directly
|
||||
mutated_args = {1: 'result'}
|
||||
self.defunctionalize(graph,
|
||||
node,
|
||||
mutated_args,
|
||||
args=('out', 'input'))
|
||||
args=('result', 'input'))
|
||||
elif at_target == torch.ops._C.silu_and_mul_quant.default:
|
||||
mutated_args = {1: 'result'}
|
||||
self.defunctionalize(graph,
|
||||
node,
|
||||
mutated_args,
|
||||
args=('result', 'input', 'scale'))
|
||||
else:
|
||||
continue # skip the count
|
||||
|
||||
|
||||
@ -7,6 +7,7 @@ from torch import fx as fx
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.logger import init_logger
|
||||
|
||||
from .activation_quant_fusion import ActivationQuantFusionPass
|
||||
from .fix_functionalization import FixFunctionalizationPass
|
||||
from .fusion import FusionPass
|
||||
from .inductor_pass import CustomGraphPass, InductorPass, get_pass_context
|
||||
@ -51,6 +52,7 @@ class PostGradPassManager(CustomGraphPass):
|
||||
|
||||
if self.pass_config.enable_fusion:
|
||||
self.passes += [FusionPass.instance(config)]
|
||||
self.passes += [ActivationQuantFusionPass(config)]
|
||||
|
||||
if self.pass_config.enable_sequence_parallelism:
|
||||
self.passes += [SequenceParallelismPass(config)]
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user