Compare commits

...

15 Commits

Author SHA1 Message Date
Fanli Lin
9420d6768e
Merge 6ab025c10abed1dc5427fd49c34f6a45cb6c6235 into 3125d7995020f4ae1b7777e41af6c3a6e714cb1c
2025-10-17 19:05:44 +00:00
Isotr0py
3125d79950
[Chore] Remove unused PolyNorm layer (#27110)
Signed-off-by: Isotr0py <mozf@mail2.sysu.edu.cn>
2025-10-17 19:03:43 +00:00
vllmellm
e33ee23ee3
[Bugfix] [AITER] [ROCm] Fix Quark MoE Quant Config and AITER Fused MoE quant type logic (#27029)
Signed-off-by: vllmellm <vllm.ellm@embeddedllm.com>
2025-10-17 12:51:10 -06:00
rasmith
b10c64c834
[ROCm][Bugfix][Model] Fix illegal memory access when running qwen3_moe models with rms_norm (Qwen3-235B-A22B, Qwen3-30B-A3B, etc.) (#26192)
Signed-off-by: Randall Smith <ransmith@amd.com>
Signed-off-by: Randall Smith <Randall.Smith@amd.com>
Signed-off-by: rasmith <Randall.Smith@amd.com>
Co-authored-by: Randall Smith <ransmith@amd.com>
Co-authored-by: Luka Govedič <ProExpertProg@users.noreply.github.com>
2025-10-17 14:17:18 -04:00
Aleksandr Malyshev
0925b28a8e
[ROCM] MoE fp4 CK kernel (#26545)
Signed-off-by: Aleksandr Malyshev <maleksan@amd.com>
Co-authored-by: Aleksandr Malyshev <maleksan@amd.com>
2025-10-17 14:06:33 -04:00
Nicolò Lucchesi
99722d5f0e
[CI] Remove forbidden slash (#27112)
Signed-off-by: NickLucche <nlucches@redhat.com>
2025-10-17 09:38:00 -07:00
4c91a28e30
[bugfix] Qwen3-VL fix video incorrect timestamp calculations while do_sample_frames=True (#27104)
Co-authored-by: 松灵 <wpf272043@alibaba-inc.com>
2025-10-17 16:26:33 +00:00
Patrick von Platen
b038d9c40c
[Data-parallel] Allow DP>1 for world_size > num_gpus on node (8) (#26367)
Signed-off-by: Patrick von Platen <patrick.v.platen@gmail.com>
Signed-off-by: Rui Qiao <ruisearch42@gmail.com>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
Co-authored-by: Rui Qiao <ruisearch42@gmail.com>
2025-10-17 08:24:42 -07:00
Nicolò Lucchesi
2ba60ec7fe
[CI] Nixl integration tests (#27010)
Signed-off-by: NickLucche <nlucches@redhat.com>
2025-10-17 07:13:31 -07:00
Luka Govedič
bd7157a071
[torch.compile] Enable attention and allreduce fusion without custom ops enabled (#24604)
Signed-off-by: Luka Govedič <lgovedic@redhat.com>
Signed-off-by: Luka Govedič <ProExpertProg@users.noreply.github.com>
2025-10-17 08:10:23 -06:00
Yongtao Huang
be429d0cfd
Fix incorrect docstring for stop_profile() method (#27101)
Signed-off-by: Yongtao Huang <yongtaoh2022@gmail.com>
2025-10-17 06:30:23 -07:00
Reima Karhila (AMD)
c253745eb8
[Hardware][AMD][Model] Triton MoE tuning configs for GLM-4.5 for MI350 and MI355 (#25586)
Signed-off-by: Reima Karhila <reima.karhila@amd.com>
Signed-off-by: xaguilar <Xavier.AguilarFruto@amd.com>
Co-authored-by: xaguilar <Xavier.AguilarFruto@amd.com>
2025-10-17 04:56:12 -07:00
Jee Jee Li
daec4d2624
[Model]Improve Qwen3VLMoeForConditionalGeneration packed_modules_mapping (#27096)
Signed-off-by: Jee Jee Li <pandaleefree@gmail.com>
2025-10-17 04:47:00 -07:00
Fanli Lin
6ab025c10a
Merge branch 'main' into fix_ut
2025-10-17 17:05:59 +08:00
Fanli Lin
0700fdddc7
use FP1 instead
Signed-off-by: Fanli Lin <fanli.lin@intel.com>
2025-10-17 01:40:45 -07:00
50 changed files with 2300 additions and 1315 deletions

View File

@ -416,8 +416,8 @@ steps:
- pytest -v -s compile/test_basic_correctness.py
- pytest -v -s compile/piecewise/
- label: PyTorch Fullgraph Test # 20min
timeout_in_minutes: 30
- label: PyTorch Fullgraph Test # 22min
timeout_in_minutes: 35
mirror_hardwares: [amdexperimental]
torch_nightly: true
source_file_dependencies:
@ -425,6 +425,7 @@ steps:
- tests/compile
commands:
- pytest -v -s compile/test_full_graph.py
- pytest -v -s compile/test_fusions_e2e.py
- label: Kernels Core Operation Test # 48min
timeout_in_minutes: 75
@ -807,8 +808,8 @@ steps:
# Whisper needs spawn method to avoid deadlock
- VLLM_WORKER_MULTIPROC_METHOD=spawn python3 examples/offline_inference/audio_language.py --model-type whisper
- label: Blackwell Test # 38 min
timeout_in_minutes: 60
- label: Blackwell Test # 21 min
timeout_in_minutes: 30
working_dir: "/vllm-workspace/"
gpu: b200
# optional: true
@ -821,8 +822,6 @@ steps:
- vllm/model_executor/layers/fused_moe/flashinfer_cutlass_prepare_finalize.py
- vllm/model_executor/layers/quantization/utils/flashinfer_utils.py
- vllm/v1/attention/backends/flashinfer.py
- vllm/compilation/fusion.py
- vllm/compilation/fusion_attn.py
commands:
- nvidia-smi
- python3 examples/offline_inference/basic/chat.py
@ -839,15 +838,32 @@ steps:
- pytest -v -s tests/kernels/quantization/test_nvfp4_scaled_mm.py
- pytest -v -s tests/kernels/quantization/test_flashinfer_scaled_mm.py
- pytest -v -s tests/kernels/quantization/test_flashinfer_nvfp4_scaled_mm.py
- pytest -v -s tests/kernels/moe/test_nvfp4_moe.py
- pytest -v -s tests/kernels/moe/test_ocp_mx_moe.py
# Fusion
- pytest -v -s tests/compile/test_fusion_all_reduce.py
- pytest -v -s tests/compile/test_fusion_attn.py::test_attention_quant_pattern
- pytest -v -s tests/kernels/moe/test_flashinfer.py
- pytest -v -s tests/compile/test_silu_mul_quant_fusion.py
- pytest -v -s tests/kernels/quantization/test_nvfp4_qutlass.py
- pytest -v -s tests/kernels/quantization/test_mxfp4_qutlass.py
- pytest -v -s tests/kernels/moe/test_nvfp4_moe.py
- pytest -v -s tests/kernels/moe/test_ocp_mx_moe.py
- pytest -v -s tests/kernels/moe/test_flashinfer.py
- label: Blackwell Fusion Tests # 30 min
timeout_in_minutes: 40
working_dir: "/vllm-workspace/"
gpu: b200
source_file_dependencies:
- csrc/quantization/fp4/
- vllm/model_executor/layers/quantization/utils/flashinfer_utils.py
- vllm/v1/attention/backends/flashinfer.py
- vllm/compilation/
# can affect pattern matching
- vllm/model_executor/layers/layernorm.py
- vllm/model_executor/layers/activation.py
- vllm/model_executor/layers/quantization/input_quant_fp8.py
commands:
- nvidia-smi
- pytest -v -s tests/compile/test_fusion_attn.py
- pytest -v -s tests/compile/test_silu_mul_quant_fusion.py
# this runner has 2 GPUs available even though num_gpus=2 is not set
- pytest -v -s tests/compile/test_fusion_all_reduce.py
- pytest -v -s tests/compile/test_fusions_e2e.py
- label: Blackwell GPT-OSS Eval
timeout_in_minutes: 60
@ -1068,6 +1084,17 @@ steps:
- tests/weight_loading
commands:
- bash weight_loading/run_model_weight_loading_test.sh -c weight_loading/models-large.txt
- label: NixlConnector PD accuracy tests (Distributed) # 30min
timeout_in_minutes: 30
working_dir: "/vllm-workspace/tests"
num_gpus: 4
source_file_dependencies:
- vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py
- tests/v1/kv_connector/nixl_integration/
commands:
- uv pip install --system -r /vllm-workspace/requirements/kv_connectors.txt
- bash v1/kv_connector/nixl_integration/tp_config_sweep_accuracy_test.sh
##### multi gpus test #####
@ -1100,7 +1127,7 @@ steps:
- pytest -s -v test_lm_eval_correctness.py --config-list-file=configs/models-large.txt --tp-size=4
##### H200 test #####
- label: Distrubted Tests (H200) # optional
- label: Distributed Tests (H200) # optional
gpu: h200
optional: true
working_dir: "/vllm-workspace/"
@ -1108,6 +1135,8 @@ steps:
commands:
- pytest -v -s tests/compile/test_async_tp.py
- pytest -v -s tests/compile/test_sequence_parallelism.py
- pytest -v -s tests/compile/test_fusion_all_reduce.py
- pytest -v -s tests/compile/test_fusions_e2e.py::test_tp2_attn_quant_allreduce_rmsnorm
- pytest -v -s tests/distributed/test_context_parallel.py
- CUDA_VISIBLE_DEVICES=1,2 VLLM_ALL2ALL_BACKEND=deepep_high_throughput VLLM_USE_DEEP_GEMM=1 VLLM_LOGGING_LEVEL=DEBUG python3 examples/offline_inference/data_parallel.py --model Qwen/Qwen1.5-MoE-A2.7B --tp-size=1 --dp-size=2 --max-model-len 2048

View File

@ -1,155 +0,0 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import itertools
import torch
from vllm import _custom_ops as vllm_ops
from vllm.triton_utils import triton
def polynorm_naive(
x: torch.Tensor,
weight: torch.Tensor,
bias: torch.Tensor,
eps: float = 1e-6,
):
orig_shape = x.shape
x = x.view(-1, x.shape[-1])
def norm(x, eps: float):
return x / torch.sqrt(x.pow(2).mean(-1, keepdim=True) + eps)
x = x.float()
return (
(
weight[0] * norm(x**3, eps)
+ weight[1] * norm(x**2, eps)
+ weight[2] * norm(x, eps)
+ bias
)
.to(weight.dtype)
.view(orig_shape)
)
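As a reading aid (this note is not part of the PR), the `polynorm_naive` reference above computes, with elementwise powers and the mean taken over the hidden dimension d:

\[
\mathrm{PolyNorm}(x) = w_0\,\frac{x^{3}}{\mathrm{RMS}(x^{3})} + w_1\,\frac{x^{2}}{\mathrm{RMS}(x^{2})} + w_2\,\frac{x}{\mathrm{RMS}(x)} + b,
\qquad
\mathrm{RMS}(v) = \sqrt{\tfrac{1}{d}\sum_{i=1}^{d} v_i^{2} + \varepsilon}
\]

This is the same quantity the removed CUDA kernel later in this diff reconstructs from its three running sums of x^2, x^4, and x^6 (the per-row mean squares of x, x^2, and x^3).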
def polynorm_vllm(
x: torch.Tensor,
weight: torch.Tensor,
bias: torch.Tensor,
eps: float = 1e-6,
):
orig_shape = x.shape
x = x.view(-1, x.shape[-1])
out = torch.empty_like(x)
vllm_ops.poly_norm(out, x, weight, bias, eps)
output = out
output = output.view(orig_shape)
return output
def calculate_diff(batch_size, seq_len, hidden_dim):
dtype = torch.bfloat16
x = torch.randn(batch_size, seq_len, hidden_dim, dtype=dtype, device="cuda")
weight = torch.ones(3, dtype=dtype, device="cuda")
bias = torch.ones(1, dtype=dtype, device="cuda")
output_naive = polynorm_naive(x, weight, bias)
output_vllm = polynorm_vllm(x, weight, bias)
if torch.allclose(output_naive, output_vllm, atol=1e-2, rtol=1e-2):
print("✅ All implementations match")
else:
print("❌ Implementations differ")
batch_size_range = [2**i for i in range(0, 7, 2)]
seq_length_range = [2**i for i in range(6, 11, 1)]
dim_range = [2048, 4096]
configs = list(itertools.product(dim_range, batch_size_range, seq_length_range))
def get_benchmark():
@triton.testing.perf_report(
triton.testing.Benchmark(
x_names=["dim", "batch_size", "seq_len"],
x_vals=[list(_) for _ in configs],
line_arg="provider",
line_vals=["naive", "vllm"],
line_names=["Naive", "vLLM"],
styles=[("blue", "-"), ("red", "-")],
ylabel="us",
plot_name="polynorm-perf",
args={},
)
)
def benchmark(dim, batch_size, seq_len, provider):
dtype = torch.bfloat16
hidden_dim = dim * 4
x = torch.randn(batch_size, seq_len, hidden_dim, dtype=dtype, device="cuda")
weight = torch.ones(3, dtype=dtype, device="cuda")
bias = torch.ones(1, dtype=dtype, device="cuda")
quantiles = [0.5, 0.2, 0.8]
if provider == "naive":
ms, min_ms, max_ms = triton.testing.do_bench(
lambda: polynorm_naive(x, weight, bias),
quantiles=quantiles,
)
else:
ms, min_ms, max_ms = triton.testing.do_bench(
lambda: polynorm_vllm(x, weight, bias),
quantiles=quantiles,
)
return 1000 * ms, 1000 * max_ms, 1000 * min_ms
return benchmark
if __name__ == "__main__":
import argparse
parser = argparse.ArgumentParser()
parser.add_argument(
"--batch-size",
type=int,
default=4,
help="Batch size",
)
parser.add_argument(
"--seq-len",
type=int,
default=128,
help="Sequence length",
)
parser.add_argument(
"--hidden-dim",
type=int,
default=8192,
help="Intermediate size of MLP",
)
parser.add_argument(
"--save-path",
type=str,
default="./configs/polnorm/",
help="Path to save polnorm benchmark results",
)
args = parser.parse_args()
# Run correctness test
calculate_diff(
batch_size=args.batch_size,
seq_len=args.seq_len,
hidden_dim=args.hidden_dim,
)
benchmark = get_benchmark()
# Run performance benchmark
benchmark.run(print_data=True, save_path=args.save_path)

View File

@ -148,211 +148,6 @@ fused_add_rms_norm_kernel(
}
}
/* Function specialization in the case of FP16/BF16 tensors.
Additional optimizations we can make in this case are
packed and vectorized operations, which help with the
memory latency bottleneck.
_f16VecPN struct extends _f16Vec to add operations specifically required for
polynomial normalization (poly norm).
The original _f16Vec does not include the sum-of-powers computation or
in-place polynomial normalization logic. */
template <typename scalar_t, int width>
struct alignas(16) _f16VecPN : _f16Vec<scalar_t, width> {
using Base = _f16Vec<scalar_t, width>;
using Converter = typename Base::Converter;
using T1 = typename Base::T1;
using T2 = typename Base::T2;
using Base::data;
__device__ auto sum_pows() const {
float s2 = 0.0f, s4 = 0.0f, s6 = 0.0f;
#pragma unroll
for (int i = 0; i < width; i += 2) {
float2 z = Converter::convert(T2{data[i], data[i + 1]});
float x2 = z.x * z.x;
float x4 = x2 * x2;
float x6 = x4 * x2;
float y2 = z.y * z.y;
float y4 = y2 * y2;
float y6 = y4 * y2;
s2 += x2 + y2;
s4 += x4 + y4;
s6 += x6 + y6;
}
return std::make_tuple(s2, s4, s6);
}
__device__ void poly_norm_inplace(const float w2_inv_std,
const float w1_inv_std2,
const float w0_inv_std3, const float bias) {
#pragma unroll
for (int i = 0; i < width; i += 2) {
float2 z = Converter::convert(T2{data[i], data[i + 1]});
float x2 = z.x * z.x;
float x3 = x2 * z.x;
z.x = w2_inv_std * z.x + w1_inv_std2 * x2 + w0_inv_std3 * x3 + bias;
float y2 = z.y * z.y;
float y3 = y2 * z.y;
z.y = w2_inv_std * z.y + w1_inv_std2 * y2 + w0_inv_std3 * y3 + bias;
auto out = Converter::convert(z);
data[i] = out.x;
data[i + 1] = out.y;
}
}
};
template <typename scalar_t, int width>
__global__ std::enable_if_t<(width > 0) && _typeConvert<scalar_t>::exists>
poly_norm_kernel(scalar_t* __restrict__ out, // [..., hidden_size]
const scalar_t* __restrict__ input, // [..., hidden_size]
const scalar_t* __restrict__ weight, // [3]
const scalar_t* __restrict__ bias, // [1]
const float epsilon, const int hidden_size) {
// Sanity checks on our vector struct and type-punned pointer arithmetic
static_assert(std::is_pod_v<_f16VecPN<scalar_t, width>>);
static_assert(sizeof(_f16VecPN<scalar_t, width>) == sizeof(scalar_t) * width);
/* These and the argument pointers are all declared `restrict` as they are
not aliased in practice. Argument pointers should not be dereferenced
in this kernel as that would be undefined behavior */
auto* __restrict__ input_v =
reinterpret_cast<const _f16VecPN<scalar_t, width>*>(input);
const int vec_hidden_size = hidden_size / width;
float variance = 0.0f;
float variance2 = 0.0f;
float variance3 = 0.0f;
for (int idx = threadIdx.x; idx < vec_hidden_size; idx += blockDim.x) {
int id = blockIdx.x * vec_hidden_size + idx;
_f16VecPN<scalar_t, width> temp = input_v[id];
auto [x2, x4, x6] = temp.sum_pows();
variance += x2;
variance2 += x4;
variance3 += x6;
}
float3 thread_variances = make_float3(variance, variance2, variance3);
struct SumOp {
__device__ float3 operator()(const float3& a, const float3& b) const {
return make_float3(a.x + b.x, a.y + b.y, a.z + b.z);
}
};
using BlockReduce = cub::BlockReduce<float3, 1024>;
__shared__ typename BlockReduce::TempStorage reduceStore;
float3 block_variances =
BlockReduce(reduceStore).Reduce(thread_variances, SumOp{}, blockDim.x);
variance = block_variances.x;
variance2 = block_variances.y;
variance3 = block_variances.z;
__shared__ float s_w2_inv_std;
__shared__ float s_w1_inv_std2;
__shared__ float s_w0_inv_std3;
__shared__ float s_bias;
if (threadIdx.x == 0) {
float w0 = (float)weight[0];
float w1 = (float)weight[1];
float w2 = (float)weight[2];
s_bias = (float)bias[0];
s_w2_inv_std = w2 * rsqrtf(variance / hidden_size + epsilon);
s_w1_inv_std2 = w1 * rsqrtf(variance2 / hidden_size + epsilon);
s_w0_inv_std3 = w0 * rsqrtf(variance3 / hidden_size + epsilon);
}
__syncthreads();
auto* __restrict__ out_v = reinterpret_cast<_f16VecPN<scalar_t, width>*>(out);
for (int idx = threadIdx.x; idx < vec_hidden_size; idx += blockDim.x) {
int id = blockIdx.x * vec_hidden_size + idx;
_f16VecPN<scalar_t, width> temp = input_v[id];
temp.poly_norm_inplace(s_w2_inv_std, s_w1_inv_std2, s_w0_inv_std3, s_bias);
out_v[id] = temp;
}
}
/* Generic poly_norm_kernel
The width field is not used here but necessary for other specializations.
*/
template <typename scalar_t, int width>
__global__ std::enable_if_t<(width == 0) || !_typeConvert<scalar_t>::exists>
poly_norm_kernel(scalar_t* __restrict__ out, // [..., hidden_size]
const scalar_t* __restrict__ input, // [..., hidden_size]
const scalar_t* __restrict__ weight, // [3]
const scalar_t* __restrict__ bias, // [1]
const float epsilon, const int hidden_size) {
float variance = 0.0f;
float variance2 = 0.0f;
float variance3 = 0.0f;
for (int idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) {
float x = (float)input[blockIdx.x * hidden_size + idx];
float x2 = x * x;
float x4 = x2 * x2;
float x6 = x4 * x2;
variance += x2;
variance2 += x4;
variance3 += x6;
}
float3 thread_variances = make_float3(variance, variance2, variance3);
struct SumOp {
__device__ float3 operator()(const float3& a, const float3& b) const {
return make_float3(a.x + b.x, a.y + b.y, a.z + b.z);
}
};
using BlockReduce = cub::BlockReduce<float3, 1024>;
__shared__ typename BlockReduce::TempStorage reduceStore;
float3 block_variances =
BlockReduce(reduceStore).Reduce(thread_variances, SumOp{}, blockDim.x);
variance = block_variances.x;
variance2 = block_variances.y;
variance3 = block_variances.z;
__shared__ float s_w2_inv_std;
__shared__ float s_w1_inv_std2;
__shared__ float s_w0_inv_std3;
__shared__ float s_bias;
if (threadIdx.x == 0) {
float w0 = (float)weight[0];
float w1 = (float)weight[1];
float w2 = (float)weight[2];
s_bias = (float)bias[0];
s_w2_inv_std = w2 * rsqrtf(variance / hidden_size + epsilon);
s_w1_inv_std2 = w1 * rsqrtf(variance2 / hidden_size + epsilon);
s_w0_inv_std3 = w0 * rsqrtf(variance3 / hidden_size + epsilon);
}
__syncthreads();
for (int idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) {
float x = (float)input[blockIdx.x * hidden_size + idx];
float x2 = x * x;
float x3 = x2 * x;
out[blockIdx.x * hidden_size + idx] =
(scalar_t)(x * s_w2_inv_std + x2 * s_w1_inv_std2 + x3 * s_w0_inv_std3 +
s_bias);
}
}
} // namespace vllm
void rms_norm(torch::Tensor& out, // [..., hidden_size]
@ -364,18 +159,26 @@ void rms_norm(torch::Tensor& out, // [..., hidden_size]
TORCH_CHECK(weight.is_contiguous());
int hidden_size = input.size(-1);
int num_tokens = input.numel() / hidden_size;
int64_t input_stride = input.stride(-2);
// We cannot just use `input.stride(-2)` if the tensor is not row-major.
// Instead, we use a 2d view to get the second-innermost stride.
// That way the dimensions (except the last one) can be arbitrarily permuted.
torch::Tensor input_view = input.view({-1, hidden_size});
int num_tokens = input_view.numel() / hidden_size;
int64_t input_stride = input_view.stride(-2);
dim3 grid(num_tokens);
dim3 block(std::min(hidden_size, 1024));
const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
const at::cuda::OptionalCUDAGuard device_guard(device_of(input_view));
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
VLLM_DISPATCH_FLOATING_TYPES(input.scalar_type(), "rms_norm_kernel", [&] {
vllm::rms_norm_kernel<scalar_t><<<grid, block, 0, stream>>>(
out.data_ptr<scalar_t>(), input.data_ptr<scalar_t>(), input_stride,
weight.data_ptr<scalar_t>(), epsilon, num_tokens, hidden_size);
});
VLLM_DISPATCH_FLOATING_TYPES(
input_view.scalar_type(), "rms_norm_kernel", [&] {
vllm::rms_norm_kernel<scalar_t><<<grid, block, 0, stream>>>(
out.data_ptr<scalar_t>(), input_view.data_ptr<scalar_t>(),
input_stride, weight.data_ptr<scalar_t>(), epsilon, num_tokens,
hidden_size);
});
}
#define LAUNCH_FUSED_ADD_RMS_NORM(width) \
@ -392,6 +195,8 @@ void fused_add_rms_norm(torch::Tensor& input, // [..., hidden_size]
torch::Tensor& residual, // [..., hidden_size]
torch::Tensor& weight, // [hidden_size]
double epsilon) {
TORCH_CHECK(weight.scalar_type() == input.scalar_type());
TORCH_CHECK(input.scalar_type() == residual.scalar_type());
TORCH_CHECK(residual.is_contiguous());
TORCH_CHECK(weight.is_contiguous());
int hidden_size = input.size(-1);
@ -434,50 +239,3 @@ void fused_add_rms_norm(torch::Tensor& input, // [..., hidden_size]
LAUNCH_FUSED_ADD_RMS_NORM(0);
}
}
#define LAUNCH_FUSED_POLY_NORM(width) \
VLLM_DISPATCH_FLOATING_TYPES(input.scalar_type(), "poly_norm_kernel", [&] { \
vllm::poly_norm_kernel<scalar_t, width><<<grid, block, 0, stream>>>( \
out.data_ptr<scalar_t>(), input.data_ptr<scalar_t>(), \
weight.data_ptr<scalar_t>(), bias.data_ptr<scalar_t>(), epsilon, \
hidden_size); \
});
void poly_norm(torch::Tensor& out, // [..., hidden_size]
torch::Tensor& input, // [..., hidden_size]
torch::Tensor& weight, // [3]
torch::Tensor& bias, // [1]
double epsilon) {
TORCH_CHECK(out.is_contiguous());
TORCH_CHECK(input.is_contiguous());
TORCH_CHECK(out.data_ptr() != input.data_ptr());
int hidden_size = input.size(-1);
int num_tokens = input.numel() / hidden_size;
dim3 grid(num_tokens);
/* This kernel is memory-latency bound in many scenarios.
When num_tokens is large, a smaller block size allows
for increased block occupancy on CUs and better latency
hiding on global mem ops. */
const int max_block_size = (num_tokens < 256) ? 1024 : 256;
dim3 block(std::min(hidden_size, max_block_size));
const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
/*If the tensor types are FP16/BF16, try to use the optimized kernel
with packed + vectorized ops.
Max optimization is achieved with a width-8 vector of FP16/BF16s
since we can load at most 128 bits at once in a global memory op.
However, this requires each tensor's data to be aligned to 16
bytes.
*/
auto inp_ptr = reinterpret_cast<std::uintptr_t>(input.data_ptr());
auto out_ptr = reinterpret_cast<std::uintptr_t>(out.data_ptr());
bool ptrs_are_aligned = inp_ptr % 16 == 0 && out_ptr % 16 == 0;
bool batch_invariant_launch = vllm::vllm_is_batch_invariant();
if (ptrs_are_aligned && hidden_size % 8 == 0 && !batch_invariant_launch) {
LAUNCH_FUSED_POLY_NORM(8);
} else {
LAUNCH_FUSED_POLY_NORM(0);
}
}

View File

@ -229,6 +229,8 @@ void fused_add_rms_norm_static_fp8_quant(
double epsilon) {
TORCH_CHECK(out.is_contiguous());
TORCH_CHECK(residual.is_contiguous());
TORCH_CHECK(residual.scalar_type() == input.scalar_type());
TORCH_CHECK(weight.scalar_type() == input.scalar_type());
int hidden_size = input.size(-1);
int input_stride = input.stride(-2);
int num_tokens = input.numel() / hidden_size;

View File

@ -92,9 +92,6 @@ void rms_norm(torch::Tensor& out, torch::Tensor& input, torch::Tensor& weight,
void fused_add_rms_norm(torch::Tensor& input, torch::Tensor& residual,
torch::Tensor& weight, double epsilon);
void poly_norm(torch::Tensor& out, torch::Tensor& input, torch::Tensor& weight,
torch::Tensor& bias, double epsilon);
void apply_repetition_penalties_(torch::Tensor& logits,
const torch::Tensor& prompt_mask,
const torch::Tensor& output_mask,

View File

@ -145,7 +145,11 @@ void rms_norm_dynamic_per_token_quant(
if (scale_ub.has_value()) {
TORCH_CHECK(out.dtype() == kFp8Type);
}
TORCH_CHECK(weight.dtype() == input.dtype());
TORCH_CHECK(scales.dtype() == torch::kFloat32);
if (residual) {
TORCH_CHECK(residual->scalar_type() == input.scalar_type());
}
VLLM_DISPATCH_FLOATING_TYPES(
input.scalar_type(), "rms_norm_dynamic_per_token_quant_dispatch", [&] {

View File

@ -175,12 +175,6 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
"float epsilon) -> ()");
ops.impl("fused_add_rms_norm", torch::kCUDA, &fused_add_rms_norm);
// Polynomial Normalization.
ops.def(
"poly_norm(Tensor! out, Tensor input, Tensor weight, Tensor bias, float "
"epsilon) -> ()");
ops.impl("poly_norm", torch::kCUDA, &poly_norm);
// Apply repetition penalties to logits in-place
ops.def(
"apply_repetition_penalties_(Tensor! logits, Tensor prompt_mask, "

View File

@ -69,6 +69,7 @@ There are several notable differences when using Ray:
- A single launch command (on any node) is needed to start all local and remote DP ranks, therefore it is more convenient compared to launching on each node
- There is no need to specify `--data-parallel-address`, and the node where the command is run is used as `--data-parallel-address`
- There is no need to specify `--data-parallel-rpc-port`
- When a single DP group requires multiple nodes, *e.g.* when a single model replica needs to run on at least two nodes, make sure to set `VLLM_RAY_DP_PACK_STRATEGY="span"`; in that case `--data-parallel-size-local` is ignored and determined automatically (see the example below this list)
- Remote DP ranks will be allocated based on node resources of the Ray cluster
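For illustration only (this example is not part of the PR diff): a Ray-backed launch along the lines described above might look like the sketch below. `VLLM_RAY_DP_PACK_STRATEGY` and the `--data-parallel-size-local`, `--data-parallel-address`, and `--data-parallel-rpc-port` flags are the names used in this document; `--data-parallel-backend ray` and the exact CLI shape are assumptions and should be checked against `vllm serve --help` for your vLLM version.

# Run once, on any node of the Ray cluster; that node is used as the
# --data-parallel-address, and no --data-parallel-rpc-port needs to be given.
# "span" lets a single DP group (one model replica) stretch across nodes;
# --data-parallel-size-local is then ignored and determined automatically.
VLLM_RAY_DP_PACK_STRATEGY="span" \
  vllm serve <model> \
  --data-parallel-size 4 \
  --data-parallel-backend ray  # assumed flag name; verify before use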
Currently, the internal DP load balancing is done within the API server process(es) and is based on the running and waiting queues in each of the engines. This could be made more sophisticated in future by incorporating KV cache aware logic.

View File

@ -3,16 +3,22 @@
import weakref
from collections.abc import Callable, Sequence
from contextlib import nullcontext
from copy import deepcopy
import depyf
from torch import fx
from torch._ops import OpOverload
from torch.fx._utils import lazy_format_graph_code
from vllm.compilation.fx_utils import find_op_nodes
from vllm.compilation.inductor_pass import InductorPass
from vllm.compilation.pass_manager import with_pattern_match_debug
from vllm.compilation.vllm_inductor_pass import VllmInductorPass
from vllm.config import VllmConfig, get_current_vllm_config
from vllm.logger import init_logger
logger = init_logger("vllm.tests.compile.backend")
class LazyInitPass(InductorPass):
@ -45,20 +51,32 @@ class TestBackend:
def __init__(self, *passes: InductorPass | Callable[[fx.Graph], None]):
self.custom_passes = list(passes)
compile_config = get_current_vllm_config().compilation_config
self.inductor_config = compile_config.inductor_compile_config
vllm_config = get_current_vllm_config()
compile_config = vllm_config.compilation_config
# Deepcopy to allow multiple TestBackend instances to use the same VllmConfig
self.inductor_config = deepcopy(compile_config.inductor_compile_config)
self.inductor_config["force_disable_caches"] = True
self.inductor_config["post_grad_custom_post_pass"] = self.post_pass
if debug_dump_path := vllm_config.compile_debug_dump_path():
logger.debug("Dumping depyf output to %s", debug_dump_path)
self.debug_ctx = depyf.prepare_debug(debug_dump_path.as_posix())
else:
self.debug_ctx = nullcontext()
def __call__(self, graph: fx.GraphModule, example_inputs):
self.graph_pre_compile = deepcopy(graph)
from torch._inductor.compile_fx import compile_fx
return compile_fx(graph, example_inputs, config_patches=self.inductor_config)
with self.debug_ctx:
return compile_fx(
graph, example_inputs, config_patches=self.inductor_config
)
@with_pattern_match_debug
def post_pass(self, graph: fx.Graph):
self.graph_pre_pass = deepcopy(graph)
lazy_format_graph_code("graph_pre_pass", graph.owning_module)
VllmInductorPass.dump_prefix = 0
for pass_ in self.custom_passes:
@ -68,6 +86,7 @@ class TestBackend:
VllmInductorPass.dump_prefix = None
self.graph_post_pass = deepcopy(graph)
lazy_format_graph_code("graph_post_pass", graph.owning_module)
# assign by reference, will reflect the final state of the graph
self.final_graph = graph

View File

@ -1,8 +1,8 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import logging
import tempfile
from pathlib import Path
from typing import Any
import pytest
@ -10,8 +10,6 @@ import torch
from tests.quantization.utils import is_quant_method_supported
from vllm import LLM, SamplingParams
from vllm.attention.backends.registry import _Backend
from vllm.attention.selector import global_force_attn_backend_context_manager
from vllm.config import CompilationConfig, CompilationMode, CUDAGraphMode, PassConfig
from vllm.platforms import current_platform
from vllm.utils import is_torch_equal_or_newer
@ -22,23 +20,24 @@ from ..utils import create_new_process_for_each_test
def models_list(*, all: bool = True, keywords: list[str] | None = None):
TEST_MODELS: list[tuple[str, dict[str, Any]]] = [
("facebook/opt-125m", {}),
(
"nm-testing/tinyllama-oneshot-w8w8-test-static-shape-change",
{
"dtype": torch.float16,
},
),
(
"neuralmagic/Llama-3.2-1B-Instruct-FP8-dynamic",
{
"dtype": torch.float16,
},
{"dtype": torch.float16},
),
("neuralmagic/Llama-3.2-1B-Instruct-quantized.w8a8", {}),
("meta-llama/Llama-3.2-1B-Instruct", {}),
]
if all:
TEST_MODELS.extend(
[
("neuralmagic/Llama-3.2-1B-Instruct-quantized.w8a8", {}),
(
"nm-testing/tinyllama-oneshot-w8w8-test-static-shape-change",
{"dtype": torch.float16},
),
]
)
# TODO: figure out why this fails.
if False and is_quant_method_supported("gguf"): # noqa: SIM223
TEST_MODELS.append(
@ -83,31 +82,38 @@ def models_list(*, all: bool = True, keywords: list[str] | None = None):
"compilation_mode",
[CompilationMode.DYNAMO_TRACE_ONCE, CompilationMode.VLLM_COMPILE],
)
@pytest.mark.parametrize("model_info", models_list(all=True))
@pytest.mark.parametrize("model, model_kwargs", models_list(all=True))
@create_new_process_for_each_test()
def test_full_graph(
monkeypatch: pytest.MonkeyPatch,
model_info: tuple[str, dict[str, Any]],
model: str,
model_kwargs: dict[str, Any],
compilation_mode: int,
):
model, model_kwargs = model_info
if (
"w8a8" in model
or "w8w8" in model
and current_platform.has_device_capability((10, 0))
):
# int8 removed on Blackwell:
pytest.skip("int8 support removed on Blackwell")
with monkeypatch.context():
print(f"MODEL={model}")
run_model(compilation_mode, model, model_kwargs)
run_model(compilation_mode, model, **model_kwargs)
# TODO(luka) add other supported compilation config scenarios here
@pytest.mark.parametrize(
"compilation_config, model_info",
"compilation_config, model, model_kwargs",
[
# additional compile sizes, only some of the models
(
CompilationConfig(mode=CompilationMode.VLLM_COMPILE, compile_sizes=[1, 2]),
model,
*model_info,
)
for model in models_list(all=False)
for model_info in models_list(all=False)
]
+ [
# RMSNorm + quant fusion, only 8-bit quant models
@ -117,18 +123,19 @@ def test_full_graph(
custom_ops=["+rms_norm"],
pass_config=PassConfig(enable_fusion=True, enable_noop=True),
),
model,
*model_info,
)
for model in models_list(keywords=["FP8-dynamic", "quantized.w8a8"])
for model_info in models_list(keywords=["FP8-dynamic", "quantized.w8a8"])
]
+ [
# Test depyf integration works
(
CompilationConfig(
mode=CompilationMode.VLLM_COMPILE,
debug_dump_path=tempfile.gettempdir(),
debug_dump_path=Path(tempfile.gettempdir()),
),
("facebook/opt-125m", {}),
"facebook/opt-125m",
{},
),
]
+ [
@ -142,9 +149,9 @@ def test_full_graph(
cudagraph_mode=CUDAGraphMode.PIECEWISE,
compile_sizes=[1, 2],
),
model,
*model_info,
)
for model in models_list(all=False)
for model_info in models_list(all=False)
if is_torch_equal_or_newer("2.9.0.dev")
],
)
@ -152,16 +159,24 @@ def test_full_graph(
@create_new_process_for_each_test()
def test_custom_compile_config(
compilation_config: CompilationConfig,
model_info: tuple[str, dict[str, Any]],
model: str,
model_kwargs: dict[str, Any],
):
if (
"w8a8" in model
or "w8w8" in model
and current_platform.has_device_capability((10, 0))
):
# int8 removed on Blackwell:
pytest.skip("int8 support removed on Blackwell")
if compilation_config.use_inductor_graph_partition and not is_torch_equal_or_newer(
"2.9.0.dev"
):
pytest.skip("inductor graph partition is only available in PyTorch 2.9+")
model, model_kwargs = model_info
print(f"MODEL={model}")
run_model(compilation_config, model, model_kwargs)
run_model(compilation_config, model, **model_kwargs)
@pytest.mark.parametrize(
@ -176,50 +191,16 @@ def test_fp8_kv_scale_compile(compilation_mode: int):
"calculate_kv_scales": True,
"max_model_len": 512,
}
run_model(compilation_mode, model, model_kwargs)
run_model(compilation_mode, model, **model_kwargs)
def test_inductor_graph_partition_attn_fusion(caplog_vllm):
if not is_torch_equal_or_newer("2.9.0.dev"):
pytest.skip("inductor graph partition is only available in PyTorch 2.9+")
model = "nvidia/Llama-4-Scout-17B-16E-Instruct-FP8"
compilation_config = CompilationConfig(
mode=CompilationMode.VLLM_COMPILE,
use_inductor_graph_partition=True,
cudagraph_mode=CUDAGraphMode.PIECEWISE,
custom_ops=["+quant_fp8"],
pass_config=PassConfig(enable_attn_fusion=True, enable_noop=True),
def run_model(compile_config: int | CompilationConfig, model: str, **model_kwargs):
compilation_config = (
compile_config
if isinstance(compile_config, CompilationConfig)
else CompilationConfig(level=compile_config)
)
model_kwargs = {
"kv_cache_dtype": "fp8",
"max_model_len": 1024,
}
with (
caplog_vllm.at_level(logging.DEBUG),
global_force_attn_backend_context_manager(_Backend.FLASHINFER),
):
run_model(compilation_config, model, model_kwargs)
try:
assert "Fused quantization onto 48 attention nodes" in caplog_vllm.text, (
caplog_vllm.text
)
except AssertionError:
# Note: this message is only triggered when the compilation goes
# through the custom pass. Due to multiple layers of cache on
# PyTorch side, the compilation of a graph may be cached such
# that custom pass directly goes through cache. In this case,
# we go through this branch and assert that the pass is not
# triggered.
assert "Fused quantization" not in caplog_vllm.text
def run_model(
compile_config: int | CompilationConfig,
model: str,
model_kwargs: dict[str, Any],
):
prompts = [
"Hello, my name is",
"The president of the United States is",
@ -227,12 +208,17 @@ def run_model(
"The future of AI is",
]
sampling_params = SamplingParams(temperature=0)
# Allow override from model_kwargs
model_kwargs = {"tensor_parallel_size": 1, **model_kwargs}
model_kwargs = {"disable_custom_all_reduce": True, **model_kwargs}
# No cudagraphs by default
if compilation_config.cudagraph_mode is None:
compilation_config.cudagraph_mode = CUDAGraphMode.NONE
llm = LLM(
model=model,
enforce_eager=True,
tensor_parallel_size=1,
disable_custom_all_reduce=True,
compilation_config=compile_config,
compilation_config=compilation_config,
**model_kwargs,
)
outputs = llm.generate(prompts, sampling_params)

View File

@ -11,7 +11,13 @@ from vllm.compilation.fusion import RMSNormQuantFusionPass
from vllm.compilation.fx_utils import find_auto_fn, find_auto_fn_maybe, is_func
from vllm.compilation.noop_elimination import NoOpEliminationPass
from vllm.compilation.post_cleanup import PostCleanupPass
from vllm.config import CompilationConfig, PassConfig, VllmConfig
from vllm.config import (
CompilationConfig,
ModelConfig,
PassConfig,
VllmConfig,
set_current_vllm_config,
)
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.quantization.utils.quant_utils import GroupShape
@ -48,8 +54,7 @@ class TestSiluMul(torch.nn.Module):
return y
def example_inputs(self, num_tokens=32, hidden_size=128):
dtype = torch.float16 if TEST_FP8 else torch.float32
return (torch.rand(num_tokens, hidden_size * 2, dtype=dtype),)
return (torch.rand(num_tokens, hidden_size * 2),)
def ops_in_model(self, do_fusion):
if TEST_FP8 and do_fusion:
@ -67,15 +72,11 @@ class TestFusedAddRMSNorm(torch.nn.Module):
self.hidden_size = hidden_size
self.intermediate_size = intermediate_size
dtype = torch.float16 if TEST_FP8 else torch.float32
self.gate_proj = torch.nn.Parameter(
torch.empty((intermediate_size, hidden_size), dtype=dtype)
torch.empty((intermediate_size, hidden_size))
)
self.norm = RMSNorm(intermediate_size, 1e-05)
self.norm.weight = torch.nn.Parameter(
torch.ones(intermediate_size, dtype=dtype)
)
self.norm.weight = torch.nn.Parameter(torch.ones(intermediate_size))
torch.nn.init.normal_(self.gate_proj, std=0.02)
@ -112,9 +113,8 @@ class TestFusedAddRMSNorm(torch.nn.Module):
return norm_output, residual_output
def example_inputs(self, batch_size=8, hidden_size=16, seq_len=16):
dtype = torch.float16 if TEST_FP8 else torch.float32
hidden_states = torch.randn((batch_size * seq_len, hidden_size), dtype=dtype)
residual = torch.randn((batch_size * seq_len, hidden_size), dtype=dtype)
hidden_states = torch.randn((batch_size * seq_len, hidden_size))
residual = torch.randn((batch_size * seq_len, hidden_size))
return (hidden_states, residual)
def ops_in_model(self, do_fusion):
@ -145,10 +145,9 @@ class TestRotaryEmbedding(torch.nn.Module):
return q_rotated, k_rotated
def example_inputs(self, num_tokens=32, head_dim=64):
dtype = torch.float16
positions = torch.arange(num_tokens, dtype=torch.long)
q = torch.randn(num_tokens, head_dim, dtype=dtype)
k = torch.randn(num_tokens, head_dim, dtype=dtype)
q = torch.randn(num_tokens, head_dim)
k = torch.randn(num_tokens, head_dim)
return (positions, q, k)
def ops_in_model(self, do_fusion):
@ -166,7 +165,7 @@ class TestRotaryEmbeddingSliceScatter(torch.nn.Module):
self.hidden_size = head_dim * num_heads
self.qkv_proj = torch.nn.Linear(
self.hidden_size, self.hidden_size * 3, bias=False, dtype=torch.float16
self.hidden_size, self.hidden_size * 3, bias=False
)
self.rotary_emb = get_rope(
@ -190,10 +189,9 @@ class TestRotaryEmbeddingSliceScatter(torch.nn.Module):
return qkv_updated
def example_inputs(self, num_tokens=32, head_dim=64, num_heads=4):
dtype = torch.float16
hidden_size = head_dim * num_heads
positions = torch.arange(num_tokens, dtype=torch.long)
hidden_states = torch.randn(num_tokens, hidden_size, dtype=dtype)
hidden_states = torch.randn(num_tokens, hidden_size)
return (positions, hidden_states)
def ops_in_model(self, do_fusion):
@ -211,48 +209,58 @@ MODELS = [
]
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
@pytest.mark.parametrize("model_class", MODELS)
@pytest.mark.parametrize("do_fusion", [True, False])
@pytest.mark.skipif(envs.VLLM_TARGET_DEVICE != "cuda", reason="Only test on CUDA")
def test_fix_functionalization(model_class: torch.nn.Module, do_fusion: bool):
def test_fix_functionalization(
model_class: torch.nn.Module, do_fusion: bool, dtype: torch.dtype
):
torch.set_default_device("cuda")
torch.set_default_dtype(dtype)
vllm_config = VllmConfig()
vllm_config.compilation_config = CompilationConfig(
pass_config=PassConfig(enable_fusion=do_fusion, enable_noop=True)
vllm_config = VllmConfig(
model_config=ModelConfig(dtype=dtype),
compilation_config=CompilationConfig(
custom_ops=["all"],
pass_config=PassConfig(enable_fusion=do_fusion, enable_noop=True),
),
)
noop_pass = NoOpEliminationPass(vllm_config)
fusion_pass = RMSNormQuantFusionPass(vllm_config)
cleanup_pass = PostCleanupPass(vllm_config)
act_quant_fusion_pass = ActivationQuantFusionPass(vllm_config)
passes = (
[noop_pass, fusion_pass, act_quant_fusion_pass, cleanup_pass]
if do_fusion
else [noop_pass, cleanup_pass]
)
func_pass = FixFunctionalizationPass(vllm_config)
with set_current_vllm_config(vllm_config):
assert RMSNorm.enabled()
noop_pass = NoOpEliminationPass(vllm_config)
fusion_pass = RMSNormQuantFusionPass(vllm_config)
cleanup_pass = PostCleanupPass(vllm_config)
act_quant_fusion_pass = ActivationQuantFusionPass(vllm_config)
backend_func = TestBackend(*passes, func_pass)
backend_no_func = TestBackend(*passes)
passes = (
[noop_pass, fusion_pass, act_quant_fusion_pass, cleanup_pass]
if do_fusion
else [noop_pass, cleanup_pass]
)
func_pass = FixFunctionalizationPass(vllm_config)
model = model_class()
torch.compile(model, backend=backend_func)(*model.example_inputs())
torch.compile(model, backend=backend_no_func)(*model.example_inputs())
backend_func = TestBackend(*passes, func_pass)
backend_no_func = TestBackend(*passes)
# check if the functionalization pass is applied
for op in model.ops_in_model(do_fusion):
find_auto_fn(backend_no_func.graph_post_pass.nodes, op)
assert find_auto_fn_maybe(backend_func.graph_post_pass.nodes, op) is None
model = model_class()
torch.compile(model, backend=backend_func)(*model.example_inputs())
torch.compile(model, backend=backend_no_func)(*model.example_inputs())
# make sure the ops were all de-functionalized
found = dict()
for node in backend_func.graph_post_pass.nodes:
# check if the functionalization pass is applied
for op in model.ops_in_model(do_fusion):
if is_func(node, op):
found[op] = True
for op in model.ops_not_in_model():
if is_func(node, op):
found[op] = True
assert all(found[op] for op in model.ops_in_model(do_fusion))
assert all(not found.get(op) for op in model.ops_not_in_model())
find_auto_fn(backend_no_func.graph_post_pass.nodes, op)
assert find_auto_fn_maybe(backend_func.graph_post_pass.nodes, op) is None
# make sure the ops were all de-functionalized
found = dict()
for node in backend_func.graph_post_pass.nodes:
for op in model.ops_in_model(do_fusion):
if is_func(node, op):
found[op] = True
for op in model.ops_not_in_model():
if is_func(node, op):
found[op] = True
assert all(found[op] for op in model.ops_in_model(do_fusion))
assert all(not found.get(op) for op in model.ops_not_in_model())

View File

@ -5,15 +5,18 @@ import pytest
import torch
import vllm.plugins
from vllm.compilation.fusion import (
FUSED_OPS,
QUANT_OPS,
FusedRMSQuantKey,
RMSNormQuantFusionPass,
)
from vllm.compilation.fusion import FUSED_OPS, FusedRMSQuantKey, RMSNormQuantFusionPass
from vllm.compilation.fx_utils import find_op_nodes
from vllm.compilation.matcher_utils import QUANT_OPS
from vllm.compilation.noop_elimination import NoOpEliminationPass
from vllm.compilation.post_cleanup import PostCleanupPass
from vllm.config import CompilationConfig, CompilationMode, PassConfig, VllmConfig
from vllm.config import (
CompilationConfig,
CompilationMode,
ModelConfig,
PassConfig,
VllmConfig,
)
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.quantization.utils.quant_utils import (
GroupShape,
@ -32,6 +35,9 @@ from .backend import TestBackend
FP8_DTYPE = current_platform.fp8_dtype()
RMS_OP = torch.ops._C.rms_norm.default
RMS_ADD_OP = torch.ops._C.fused_add_rms_norm.default
class TestModel(torch.nn.Module):
def __init__(
@ -45,18 +51,18 @@ class TestModel(torch.nn.Module):
):
super().__init__(*args, **kwargs)
self.cuda_force_torch = cuda_force_torch
self.norm = [RMSNorm(hidden_size, eps) for _ in range(3)]
self.wscale = [torch.rand(1, dtype=torch.float32) for _ in range(2)]
self.norm = [RMSNorm(hidden_size, eps) for _ in range(4)]
self.wscale = [torch.rand(1, dtype=torch.float32) for _ in range(3)]
group_shape = GroupShape.PER_TENSOR if static else GroupShape.PER_TOKEN
quant_scale = ScaleDesc(torch.float32, static, group_shape)
self.key = QuantKey(dtype=FP8_DTYPE, scale=quant_scale, symmetric=True)
self.quant_key = QuantKey(dtype=FP8_DTYPE, scale=quant_scale, symmetric=True)
if static:
self.scale = [torch.rand(1, dtype=torch.float32) for _ in range(2)]
self.scale = [torch.rand(1, dtype=torch.float32) for _ in range(3)]
else:
self.scale = [None for _ in range(2)]
self.scale = [None for _ in range(3)]
self.w = [
torch.rand(hidden_size, hidden_size).to(dtype=FP8_DTYPE).t()
for _ in range(2)
for _ in range(3)
]
with override_cutlass_fp8_supported(not cuda_force_torch):
@ -65,8 +71,12 @@ class TestModel(torch.nn.Module):
act_quant_group_shape=group_shape,
)
self.enable_rms_norm_custom_op = self.norm[0].enabled()
self.enable_quant_fp8_custom_op = self.fp8_linear.quant_fp8.enabled()
def forward(self, x):
resid = torch.sqrt(x)
# avoid having graph input be an arg to a pattern directly
x = resid = torch.relu(x)
y = self.norm[0](x)
x2 = self.fp8_linear.apply(
@ -78,24 +88,44 @@ class TestModel(torch.nn.Module):
x3 = self.fp8_linear.apply(
y2, self.w[1], self.wscale[1], input_scale=self.scale[1]
)
y3, resid = self.norm[2](x3, resid) # use resid here
return y3
def ops_in_model_before(self):
return [QUANT_OPS[self.key]]
y3, resid = self.norm[2](x3, resid) # use resid here
x4 = self.fp8_linear.apply(
y3, self.w[2], self.wscale[2], input_scale=self.scale[2]
)
y4, resid = self.norm[3](x4, resid) # use resid here
return y4
def ops_in_model_after(self):
return [
FUSED_OPS[FusedRMSQuantKey(self.key, False)],
FUSED_OPS[FusedRMSQuantKey(self.key, True)],
FUSED_OPS[FusedRMSQuantKey(self.quant_key, True)],
FUSED_OPS[FusedRMSQuantKey(self.quant_key, False)],
]
def ops_in_model_before(self):
return (
[QUANT_OPS[self.quant_key]]
if self.enable_quant_fp8_custom_op
else [torch.ops.aten.reciprocal]
)
def ops_in_model_before_partial(self):
return (
[RMS_OP, RMS_ADD_OP]
if self.enable_rms_norm_custom_op
else [torch.ops.aten.rsqrt]
)
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
@pytest.mark.parametrize("hidden_size", [64])
@pytest.mark.parametrize("num_tokens", [257])
@pytest.mark.parametrize("eps", [1e-5, 1e-6])
@pytest.mark.parametrize("static", [True, False])
@pytest.mark.parametrize("enable_rms_norm_custom_op", [True, False])
@pytest.mark.parametrize("enable_quant_fp8_custom_op", [True, False])
# cuda_force_torch used to test torch code path on platforms that
# cutlass_fp8_supported() == True.
@pytest.mark.parametrize(
@ -105,19 +135,32 @@ class TestModel(torch.nn.Module):
not current_platform.is_cuda_alike(), reason="Only test on CUDA and ROCm"
)
def test_fusion_rmsnorm_quant(
dtype, hidden_size, num_tokens, eps, static, cuda_force_torch
dtype,
hidden_size,
num_tokens,
eps,
static,
enable_rms_norm_custom_op,
enable_quant_fp8_custom_op,
cuda_force_torch,
):
torch.set_default_device("cuda")
torch.set_default_dtype(dtype)
torch.manual_seed(1)
maybe_create_device_identity() # needed for certain non-cutlass fp8 paths
custom_ops = []
if enable_rms_norm_custom_op:
custom_ops.append("+rms_norm")
if enable_quant_fp8_custom_op:
custom_ops.append("+quant_fp8")
vllm_config = VllmConfig(
model_config=ModelConfig(dtype=dtype),
compilation_config=CompilationConfig(
mode=CompilationMode.VLLM_COMPILE,
custom_ops=["+rms_norm", "+quant_fp8"],
custom_ops=custom_ops,
pass_config=PassConfig(enable_fusion=True, enable_noop=True),
)
),
)
with vllm.config.set_current_vllm_config(vllm_config):
# Reshape pass is needed for the fusion pass to work
@ -126,31 +169,39 @@ def test_fusion_rmsnorm_quant(
cleanup_pass = PostCleanupPass(vllm_config)
backend = TestBackend(noop_pass, fusion_pass, cleanup_pass)
backend2 = TestBackend(noop_pass, cleanup_pass)
model = TestModel(hidden_size, eps, static, cuda_force_torch)
# First dimension dynamic
x = torch.rand(num_tokens, hidden_size)
torch._dynamo.mark_dynamic(x, 0)
result = model(x)
model_fused = torch.compile(model, backend=backend)
result_fused = model_fused(x)
model2 = torch.compile(model, backend=backend)
result2 = model2(x)
model_unfused = torch.compile(model, backend=backend2)
result_unfused = model_unfused(x)
# Higher tol for dynamic, even higher for bfloat16
if static:
ATOL, RTOL = (1e-3, 1e-3)
elif dtype == torch.float16:
if dtype == torch.float16:
ATOL, RTOL = (2e-3, 2e-3)
else:
ATOL, RTOL = (1e-2, 1e-2)
torch.testing.assert_close(result, result2, atol=ATOL, rtol=RTOL)
torch.testing.assert_close(result_fused, result_unfused, atol=ATOL, rtol=RTOL)
assert fusion_pass.matched_count == 2
# In pre-nodes, fp8 quant should be there and fused kernels should not
assert fusion_pass.matched_count == 3
backend.check_before_ops(model.ops_in_model_before())
# In post-nodes, fused kernels should be there and fp8 quant should not
backend.check_before_ops(
model.ops_in_model_before_partial(), fully_replaced=False
)
backend.check_after_ops(model.ops_in_model_after())
# If RMSNorm custom op is disabled (native/torch impl used),
# there's a risk that the fused add doesn't get included in the
# replacement and only the rms part gets fused with quant.
# Hence, we check only 2 add nodes are left (final fused rmsnorm add).
if not enable_rms_norm_custom_op:
n_add_nodes = lambda g: sum(1 for _ in find_op_nodes(torch.ops.aten.add, g))
# 7 = 1 (RMS) + 3x2 (3xRMS_ADD, 2 each)
assert n_add_nodes(backend.graph_pre_pass) == 7
assert n_add_nodes(backend.graph_post_pass) == 2

View File

@ -6,6 +6,7 @@ import pytest
import torch
import vllm.envs as envs
from vllm._custom_ops import cutlass_scaled_fp4_mm, scaled_fp4_quant
from vllm.compilation.collective_fusion import AllReduceFusionPass
from vllm.compilation.fix_functionalization import FixFunctionalizationPass
from vllm.compilation.noop_elimination import NoOpEliminationPass
@ -17,6 +18,7 @@ from vllm.config import (
ModelConfig,
PassConfig,
VllmConfig,
set_current_vllm_config,
)
from vllm.distributed import tensor_model_parallel_all_reduce
from vllm.distributed.parallel_state import (
@ -25,8 +27,8 @@ from vllm.distributed.parallel_state import (
)
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
Fp8LinearOp,
GroupShape,
QuantFP8,
)
from vllm.platforms import current_platform
from vllm.utils import update_environment_variables
@ -40,13 +42,30 @@ class TestAllReduceRMSNormModel(torch.nn.Module):
super().__init__()
self.hidden_size = hidden_size
self.eps = eps
self.norm = RMSNorm(hidden_size, eps)
self.norm = [RMSNorm(hidden_size, eps) for i in range(4)]
self.w = [torch.rand(hidden_size, hidden_size) for _ in range(3)]
def forward(self, hidden_states, residual):
view = hidden_states.reshape(-1, self.hidden_size)
all_reduce = tensor_model_parallel_all_reduce(view)
norm = self.norm(all_reduce)
return norm
def forward(self, x):
# avoid having graph input be an arg to a pattern directly
z = torch.relu(x)
x = resid = tensor_model_parallel_all_reduce(z)
y = self.norm[0](x)
z2 = torch.mm(y, self.w[0])
x2 = tensor_model_parallel_all_reduce(z2)
y2, resid = self.norm[1](x2, resid)
z3 = torch.mm(y2, self.w[1])
x3 = tensor_model_parallel_all_reduce(z3)
y3, resid = self.norm[2](x3, resid)
z4 = torch.mm(y3, self.w[2])
x4 = tensor_model_parallel_all_reduce(z4)
y4, resid = self.norm[3](x4, resid)
return y4
def ops_in_model_before(self):
return [torch.ops.vllm.all_reduce.default]
@ -55,44 +74,53 @@ class TestAllReduceRMSNormModel(torch.nn.Module):
return [torch.ops.vllm.flashinfer_trtllm_fused_allreduce_norm.default]
class TestAllReduceFusedAddRMSNormModel(torch.nn.Module):
class TestAllReduceRMSNormStaticQuantFP8Model(torch.nn.Module):
def __init__(self, hidden_size=16, token_num=16, eps=1e-6):
super().__init__()
self.hidden_size = hidden_size
self.eps = eps
self.norm = RMSNorm(hidden_size, eps)
self.norm = [RMSNorm(hidden_size, eps) for i in range(4)]
self.wscale = [torch.rand(1, dtype=torch.float32) for _ in range(3)]
self.w = [
torch.rand(hidden_size, hidden_size)
.to(dtype=current_platform.fp8_dtype())
.t()
for _ in range(3)
]
def forward(self, hidden_states, residual):
view = hidden_states.reshape(-1, self.hidden_size)
all_reduce = tensor_model_parallel_all_reduce(view)
norm, _ = self.norm(all_reduce, residual)
return norm
def ops_in_model_before(self):
return [torch.ops.vllm.all_reduce.default]
def ops_in_model_after(self):
return [torch.ops.vllm.flashinfer_trtllm_fused_allreduce_norm.default]
class TestAllReduceFusedAddRMSNormStaticQuantFP8Model(torch.nn.Module):
def __init__(self, hidden_size=16, token_num=16, eps=1e-6):
super().__init__()
self.hidden_size = hidden_size
self.eps = eps
self.norm = RMSNorm(hidden_size, eps)
self.quant_fp8 = QuantFP8(static=True, group_shape=GroupShape.PER_TENSOR)
self.scale = torch.rand(1, dtype=torch.float32)
self.output = torch.empty((token_num, hidden_size), dtype=torch.float32)
def forward(self, hidden_states, residual):
view = hidden_states.reshape(-1, self.hidden_size)
all_reduce = tensor_model_parallel_all_reduce(view)
norm_output, residual_output = self.norm(all_reduce, residual)
torch.ops._C.static_scaled_fp8_quant(
self.output, norm_output.contiguous(), self.scale
self.fp8_linear = Fp8LinearOp(
act_quant_static=True,
act_quant_group_shape=GroupShape.PER_TENSOR,
)
return self.output, residual_output
self.scale = [torch.rand(1, dtype=torch.float32) for _ in range(3)]
def forward(self, hidden_states):
# avoid having graph input be an arg to a pattern directly
z = torch.relu(hidden_states)
x = resid = tensor_model_parallel_all_reduce(z)
y = self.norm[0](x)
z2 = self.fp8_linear.apply(
y, self.w[0], self.wscale[0], input_scale=self.scale[0]
)
x2 = tensor_model_parallel_all_reduce(z2)
y2, resid = self.norm[1](x2, resid)
z3 = self.fp8_linear.apply(
y2, self.w[1], self.wscale[1], input_scale=self.scale[1]
)
x3 = tensor_model_parallel_all_reduce(z3)
y3, resid = self.norm[2](x3, resid) # use resid here
z4 = self.fp8_linear.apply(
y3, self.w[2], self.wscale[2], input_scale=self.scale[2]
)
x4 = tensor_model_parallel_all_reduce(z4)
y4, resid = self.norm[3](x4, resid) # use resid here
return y4
def ops_in_model_after(self):
return [torch.ops.vllm.flashinfer_trtllm_fused_allreduce_norm.default]
@ -100,7 +128,9 @@ class TestAllReduceFusedAddRMSNormStaticQuantFP8Model(torch.nn.Module):
def ops_in_model_before(self):
return [
torch.ops.vllm.all_reduce.default,
torch.ops._C.static_scaled_fp8_quant.default,
torch.ops._C.static_scaled_fp8_quant.default
if self.fp8_linear.quant_fp8.enabled()
else torch.ops.aten.reciprocal.default,
]
@ -109,25 +139,48 @@ class TestAllReduceFusedAddRMSNormStaticQuantFP4Model(torch.nn.Module):
super().__init__()
self.hidden_size = hidden_size
self.eps = eps
self.norm = RMSNorm(hidden_size, eps)
self.scale = torch.rand(1, dtype=torch.float32)
self.output = torch.empty((token_num, hidden_size), dtype=torch.float32)
self.norm = [RMSNorm(hidden_size, eps) for i in range(4)]
round_up = lambda x, y: (x + y - 1) // y * y
rounded_m = round_up(token_num, 128)
scale_n = hidden_size // 16
rounded_n = round_up(scale_n, 4)
self.output_scale = torch.empty((rounded_m, rounded_n // 4), dtype=torch.int32)
self.w = [torch.rand(hidden_size, hidden_size) for _ in range(3)]
self.agscale = [torch.rand(1, dtype=torch.float32) for _ in range(3)]
wgscale = [torch.rand(1, dtype=torch.float32) for _ in range(3)]
self.alpha = [1 / (w * a) for w, a in zip(wgscale, self.agscale)]
def forward(self, hidden_states, residual):
view = hidden_states.reshape(-1, self.hidden_size)
all_reduce = tensor_model_parallel_all_reduce(view)
norm_output, residual_output = self.norm(all_reduce, residual)
norm_output = norm_output.reshape(-1, norm_output.shape[-1])
torch.ops._C.scaled_fp4_quant(
self.output, norm_output, self.output_scale, self.scale
wq_gen, wscale_gen = zip(
*(scaled_fp4_quant(w, wg) for w, wg in zip(self.w, wgscale))
)
return self.output, residual_output, self.output_scale
self.wq, self.wscale = list(wq_gen), list(wscale_gen)
print(f"{self.wq=}, {self.wscale=}")
def forward(self, hidden_states):
# avoid having graph input be an arg to a pattern directly
z = torch.relu(hidden_states)
x = resid = tensor_model_parallel_all_reduce(z)
y = self.norm[0](x)
yq, y_scale = scaled_fp4_quant(y, self.agscale[0])
z2 = cutlass_scaled_fp4_mm(
yq, self.wq[0], y_scale, self.wscale[0], self.alpha[0], out_dtype=y.dtype
)
x2 = tensor_model_parallel_all_reduce(z2)
y2, resid = self.norm[1](x2, resid)
yq2, y_scale2 = scaled_fp4_quant(y2, self.agscale[1])
z3 = cutlass_scaled_fp4_mm(
yq2, self.wq[1], y_scale2, self.wscale[1], self.alpha[1], out_dtype=y2.dtype
)
x3 = tensor_model_parallel_all_reduce(z3)
y3, resid = self.norm[2](x3, resid) # use resid here
yq3, y_scale3 = scaled_fp4_quant(y3, self.agscale[2])
z4 = cutlass_scaled_fp4_mm(
yq3, self.wq[2], y_scale3, self.wscale[2], self.alpha[2], out_dtype=y3.dtype
)
x4 = tensor_model_parallel_all_reduce(z4)
y4, resid = self.norm[3](x4, resid) # use resid here
return y4
def ops_in_model_after(self):
return [torch.ops.vllm.flashinfer_trtllm_fused_allreduce_norm.default]
@ -141,19 +194,19 @@ class TestAllReduceFusedAddRMSNormStaticQuantFP4Model(torch.nn.Module):
@multi_gpu_test(num_gpus=2)
@pytest.mark.parametrize(
"test_model",
"test_model, enable_quant_fp8_custom_op",
[
TestAllReduceRMSNormModel,
TestAllReduceFusedAddRMSNormModel,
TestAllReduceFusedAddRMSNormStaticQuantFP8Model,
# TODO: Enable with torch==2.8.0
# TestAllReduceFusedAddRMSNormStaticQuantFP4Model,
(TestAllReduceRMSNormModel, False),
(TestAllReduceRMSNormStaticQuantFP8Model, True),
(TestAllReduceRMSNormStaticQuantFP8Model, False),
(TestAllReduceFusedAddRMSNormStaticQuantFP4Model, False),
],
)
@pytest.mark.parametrize("batch_size", [8])
@pytest.mark.parametrize("seq_len", [8])
@pytest.mark.parametrize("hidden_size", [16])
@pytest.mark.parametrize("hidden_size", [64])
@pytest.mark.parametrize("dtype", [torch.bfloat16])
@pytest.mark.parametrize("enable_rms_norm_custom_op", [True, False])
@pytest.mark.skipif(envs.VLLM_TARGET_DEVICE not in ["cuda"], reason="Only test on CUDA")
@pytest.mark.skipif(
not find_spec("flashinfer")
@ -167,6 +220,8 @@ def test_all_reduce_fusion_pass_replace(
seq_len: int,
hidden_size: int,
dtype: torch.dtype,
enable_rms_norm_custom_op,
enable_quant_fp8_custom_op,
):
num_processes = 2
if (
@ -181,7 +236,16 @@ def test_all_reduce_fusion_pass_replace(
def run_torch_spawn(fn, nprocs):
torch.multiprocessing.spawn(
fn,
args=(num_processes, test_model, batch_size, seq_len, hidden_size, dtype),
args=(
num_processes,
test_model,
batch_size,
seq_len,
hidden_size,
dtype,
enable_rms_norm_custom_op,
enable_quant_fp8_custom_op,
),
nprocs=nprocs,
)
@ -196,6 +260,8 @@ def all_reduce_fusion_pass_on_test_model(
seq_len: int,
hidden_size: int,
dtype: torch.dtype,
enable_rms_norm_custom_op,
enable_quant_fp8_custom_op,
):
current_platform.seed_everything(0)
@ -217,15 +283,22 @@ def all_reduce_fusion_pass_on_test_model(
init_distributed_environment()
initialize_model_parallel(tensor_model_parallel_size=world_size)
custom_ops = []
if enable_rms_norm_custom_op:
custom_ops.append("+rms_norm")
if enable_quant_fp8_custom_op:
custom_ops.append("+quant_fp8")
vllm_config = VllmConfig(
compilation_config=CompilationConfig(
mode=CompilationMode.VLLM_COMPILE, custom_ops=["+rms_norm", "+quant_fp8"]
mode=CompilationMode.VLLM_COMPILE, custom_ops=custom_ops
)
)
vllm_config.compilation_config.pass_config = PassConfig(
enable_fi_allreduce_fusion=True, enable_noop=True
)
vllm_config.device_config = DeviceConfig(device=torch.device("cuda"))
vllm_config.parallel_config.rank = local_rank # Setup rank for debug path
# this is a fake model name used to construct the model config
# in the vllm_config; it's not actually used.
@ -233,24 +306,27 @@ def all_reduce_fusion_pass_on_test_model(
vllm_config.model_config = ModelConfig(
model=model_name, trust_remote_code=True, dtype=dtype, seed=42
)
with set_current_vllm_config(vllm_config):
all_reduce_fusion_pass = AllReduceFusionPass(vllm_config)
noop_pass = NoOpEliminationPass(vllm_config)
func_pass = FixFunctionalizationPass(vllm_config)
cleanup_pass = PostCleanupPass(vllm_config)
all_reduce_fusion_pass = AllReduceFusionPass(vllm_config)
noop_pass = NoOpEliminationPass(vllm_config)
func_pass = FixFunctionalizationPass(vllm_config)
cleanup_pass = PostCleanupPass(vllm_config)
backend = TestBackend(
noop_pass, all_reduce_fusion_pass, func_pass, cleanup_pass
)
backend = TestBackend(all_reduce_fusion_pass, noop_pass, func_pass, cleanup_pass)
token_num = batch_size * seq_len
model = test_model_cls(hidden_size, token_num)
token_num = batch_size * seq_len
model = test_model_cls(hidden_size, token_num)
hidden_states = torch.randn((token_num, hidden_size), requires_grad=False)
hidden_states = torch.randn((token_num, hidden_size), requires_grad=False)
residual = torch.randn((token_num, hidden_size), requires_grad=False)
compiled_model = torch.compile(model, backend=backend)
compiled_model(hidden_states)
compiled_model = torch.compile(model, backend=backend)
compiled_model(hidden_states, residual)
assert all_reduce_fusion_pass.matched_count == 1
backend.check_before_ops(model.ops_in_model_before(), fully_replaced=False)
backend.check_after_ops(model.ops_in_model_after())
del all_reduce_fusion_pass
assert all_reduce_fusion_pass.matched_count == 4, (
f"{all_reduce_fusion_pass.matched_count=}"
)
backend.check_before_ops(model.ops_in_model_before(), fully_replaced=False)
backend.check_after_ops(model.ops_in_model_after())
del all_reduce_fusion_pass

View File

@ -6,14 +6,15 @@ import pytest
import torch._dynamo
from tests.compile.backend import LazyInitPass, TestBackend
from tests.utils import flat_product
from tests.v1.attention.utils import BatchSpec, create_common_attn_metadata
from vllm._custom_ops import cutlass_scaled_fp4_mm, scaled_fp4_quant
from vllm.attention import Attention, AttentionMetadata
from vllm.attention.backends.registry import _Backend
from vllm.attention.selector import global_force_attn_backend_context_manager
from vllm.compilation.fusion import QUANT_OPS
from vllm.compilation.fusion_attn import ATTN_OP, AttnFusionPass
from vllm.compilation.fx_utils import find_op_nodes
from vllm.compilation.matcher_utils import QUANT_OPS
from vllm.compilation.noop_elimination import NoOpEliminationPass
from vllm.compilation.post_cleanup import PostCleanupPass
from vllm.config import (
@ -28,21 +29,18 @@ from vllm.config import (
)
from vllm.forward_context import get_forward_context, set_forward_context
from vllm.model_executor.layers.quantization.utils.quant_utils import (
QuantKey,
kFp8StaticTensorSym,
kNvfp4Quant,
)
from vllm.model_executor.layers.quantization.utils.w8a8_utils import Fp8LinearOp
from vllm.platforms import current_platform
from vllm.utils import is_torch_equal_or_newer
from vllm.utils.flashinfer import has_flashinfer
from vllm.v1.kv_cache_interface import AttentionSpec
FP8_DTYPE = current_platform.fp8_dtype()
FP4_DTYPE = torch.uint8
# globals needed for string-import custom Dynamo backend field
backend: TestBackend | None = None
backend_unfused: TestBackend | None = None
class AttentionQuantPatternModel(torch.nn.Module):
"""Base model for AttentionQuantPattern fusion."""
@ -104,6 +102,7 @@ class AttentionQuantPatternModel(torch.nn.Module):
num_blocks = batch_size * max_blocks
backend = self.attn.backend
# TODO(luka) use get_kv_cache_stride_order
# Create dummy KV cache for the selected backend
if backend == _Backend.ROCM_ATTN:
# k/v as 1st dimension
@ -241,26 +240,40 @@ class TestAttentionNvfp4QuantPatternModel(AttentionQuantPatternModel):
)
MODELS_FP8: list[tuple[str, type]] = []
MODELS_FP4: list[tuple[str, type]] = []
HEADS: list[tuple[int, int]] = []
SPLIT_ATTENTION: list[bool] = []
BACKENDS_FP8: list[_Backend] = []
BACKENDS_FP4: list[_Backend] = []
if current_platform.is_cuda():
MODELS = [
HEADS = [(64, 8), (40, 8)]
MODELS_FP8 = [
(
"nvidia/Llama-4-Scout-17B-16E-Instruct-FP8",
TestAttentionFp8StaticQuantPatternModel,
),
)
]
MODELS_FP4 = [
(
"nvidia/Llama-4-Scout-17B-16E-Instruct-FP4",
TestAttentionNvfp4QuantPatternModel,
),
)
]
HEADS = [(64, 8), (40, 8)]
BACKENDS_FP8 = [_Backend.TRITON_ATTN, _Backend.FLASHINFER]
BACKENDS_FP4 = [_Backend.FLASHINFER]
elif current_platform.is_rocm():
MODELS = [
HEADS = [(32, 8), (40, 8)]
MODELS_FP8 = [
("amd/Llama-3.1-8B-Instruct-FP8-KV", TestAttentionFp8StaticQuantPatternModel)
]
HEADS = [(32, 8), (40, 8)]
else:
MODELS = []
HEADS = []
BACKENDS = [
_Backend.ROCM_AITER_UNIFIED_ATTN,
_Backend.ROCM_ATTN,
_Backend.TRITON_ATTN,
]
@pytest.mark.parametrize("num_qo_heads, num_kv_heads", HEADS)
@ -269,46 +282,36 @@ else:
"batch_size", [7, 256, 533] if current_platform.is_cuda() else [8]
)
@pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float16])
@pytest.mark.parametrize("model_name, model_class", MODELS)
@pytest.mark.parametrize(
"backend",
[_Backend.FLASHINFER]
if current_platform.is_cuda()
else [_Backend.ROCM_AITER_UNIFIED_ATTN, _Backend.ROCM_ATTN, _Backend.TRITON_ATTN],
)
# TODO(boyuan): test inductor graph partition on rocm
@pytest.mark.parametrize(
"use_inductor_graph_partition",
[False] if current_platform.is_rocm() else [False, True],
"backend, model_name, model_class, custom_ops",
# Test attention+quant_fp8 fusion with custom and torch impls of QuantFP8
list(flat_product(BACKENDS_FP8, MODELS_FP8, ["+quant_fp8", "-quant_fp8"]))
# quant_fp4 only has the custom impl
+ list(flat_product(BACKENDS_FP4, MODELS_FP4, [""])),
)
@pytest.mark.skipif(
not current_platform.is_cuda_alike(), reason="Only test ROCm or CUDA"
)
@pytest.mark.skipif(not current_platform.supports_fp8(), reason="Need FP8")
@pytest.mark.skipif(
current_platform.is_cuda() and not current_platform.is_device_capability((10, 0)),
reason="On CUDA only test on SM100(Blackwell)",
)
@pytest.mark.skipif(
not current_platform.is_cuda_alike(), reason="Only test ROCm or CUDA"
)
def test_attention_quant_pattern(
num_qo_heads: int,
num_kv_heads: int,
head_size: int,
batch_size: int,
dtype: torch.dtype,
custom_ops: str,
model_name: str,
model_class: type[AttentionQuantPatternModel],
backend: _Backend,
use_inductor_graph_partition: bool,
dist_init,
caplog_vllm,
):
"""Test AttentionStaticQuantPattern fusion pass"""
if backend == _Backend.FLASHINFER and (
not current_platform.is_device_capability((10, 0)) or not has_flashinfer()
):
pytest.skip("FlashInfer attn fusion requires Blackwell and flashinfer")
if use_inductor_graph_partition and not is_torch_equal_or_newer("2.9.0.dev"):
pytest.skip("inductor graph partition is only available in PyTorch 2.9+")
custom_ops_list = custom_ops.split(",") if custom_ops else []
device = torch.device("cuda:0")
torch.manual_seed(42)
@ -322,8 +325,7 @@ def test_attention_quant_pattern(
scheduler_config=SchedulerConfig(max_num_seqs=1024),
compilation_config=CompilationConfig(
mode=CompilationMode.VLLM_COMPILE,
custom_ops=["+quant_fp8"],
use_inductor_graph_partition=use_inductor_graph_partition,
custom_ops=custom_ops_list,
),
cache_config=CacheConfig(cache_dtype="fp8"),
)
@ -358,8 +360,9 @@ def test_attention_quant_pattern(
forward_ctx = get_forward_context()
forward_ctx.attn_metadata = model_unfused.build_attn_metadata(batch_size)
# Run model directly without compilation and fusion
result_unfused = model_unfused(q, k, v)
# Run model directly without fusion
# Still compile so the query QuantFP8 op produces numerics closer to the fused run
result_unfused = torch.compile(model_unfused, fullgraph=True)(q, k, v)
# Run model with attn fusion enabled
vllm_config.compilation_config.pass_config = PassConfig(
@ -414,16 +417,25 @@ def test_attention_quant_pattern(
)
# Check attn fusion support
quant_key = model_class.quant_key
quant_key: QuantKey = model_class.quant_key
attn_fusion_supported = [
layer.impl.fused_output_quant_supported(quant_key)
for key, layer in vllm_config.compilation_config.static_forward_context.items()
]
if any(attn_fusion_supported):
# Check quantization ops in the graph before and after fusion
# Note: fully_replaced=False because query quant ops remain in graph.
# Only output quant ops are fused into attention.
test_backend.check_before_ops([QUANT_OPS[quant_key]], fully_replaced=False)
assert sum(attn_fusion_supported) == len(attn_fusion_supported), (
"All layers should support attention fusion"
)
# Check quantization ops in the graph before and after fusion
quant_op = (
torch.ops.aten.reciprocal
if "-quant_fp8" in custom_ops_list
else QUANT_OPS[quant_key]
)
# Note: for fp8, fully_replaced=False because query quant ops remain in graph.
# Only output quant ops are fused into attention.
test_backend.check_before_ops([quant_op], fully_replaced=quant_key is kNvfp4Quant)
# access the underlying `AttnFusionPass` on the `LazyInitPass`
assert attn_pass.pass_.matched_count == sum(attn_fusion_supported)

View File

@ -0,0 +1,305 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from __future__ import annotations
import itertools
import logging
from collections.abc import Iterable
from typing import Any, NamedTuple
import pytest
import regex as re
from tests.v1.attention.utils import _Backend
from vllm import LLM, SamplingParams
from vllm.config import CompilationConfig, CompilationMode, CUDAGraphMode, PassConfig
from vllm.platforms import current_platform
from vllm.utils import is_torch_equal_or_newer
from vllm.utils.flashinfer import has_flashinfer
from ..utils import flat_product, multi_gpu_test
class ModelBackendTestCase(NamedTuple):
model_name: str
model_kwargs: dict[str, Any]
backend: _Backend
attention_fusions: int
allreduce_fusions: int | None = None
MODELS_FP8: list[ModelBackendTestCase] = []
MODELS_FP4: list[ModelBackendTestCase] = []
MODELS: list[ModelBackendTestCase] = [] # tp-only
if current_platform.is_cuda():
MODELS_FP8 = [
ModelBackendTestCase(
# Use smaller model for L40s in CI
model_name="RedHatAI/Meta-Llama-3.1-8B-Instruct-FP8",
model_kwargs=dict(max_model_len=1024),
backend=_Backend.TRITON_ATTN,
attention_fusions=32,
allreduce_fusions=65,
),
ModelBackendTestCase(
model_name="nvidia/Llama-4-Scout-17B-16E-Instruct-FP8",
model_kwargs=dict(max_model_len=1024, kv_cache_dtype="fp8"),
backend=_Backend.FLASHINFER,
attention_fusions=48,
allreduce_fusions=96,
),
]
MODELS_FP4 = [
ModelBackendTestCase(
model_name="nvidia/Llama-4-Scout-17B-16E-Instruct-FP4",
model_kwargs=dict(max_model_len=1024, kv_cache_dtype="fp8"),
backend=_Backend.FLASHINFER,
attention_fusions=48,
allreduce_fusions=96,
),
]
# TP only
MODELS = [
ModelBackendTestCase(
model_name="meta-llama/Llama-3.1-8B-Instruct",
model_kwargs=dict(max_model_len=1024),
backend=_Backend.TRITON_ATTN,
attention_fusions=0,
allreduce_fusions=65,
),
]
elif current_platform.is_rocm():
MODELS_FP8 = [
ModelBackendTestCase(
model_name="amd/Llama-3.1-8B-Instruct-FP8-KV",
model_kwargs=dict(max_model_len=1024),
backend=_Backend.TRITON_ATTN,
attention_fusions=32,
),
ModelBackendTestCase(
model_name="amd/Llama-3.1-8B-Instruct-FP8-KV",
model_kwargs=dict(max_model_len=1024),
backend=_Backend.ROCM_ATTN,
attention_fusions=32,
),
ModelBackendTestCase(
model_name="amd/Llama-3.1-8B-Instruct-FP8-KV",
model_kwargs=dict(max_model_len=1024),
backend=_Backend.ROCM_AITER_UNIFIED_ATTN,
attention_fusions=32,
),
]
# TODO(luka) test both in nightly
CUSTOM_OPS_FP8 = ["-quant_fp8"] # , "+quant_fp8"]
@pytest.mark.parametrize(
"model_name, model_kwargs, backend, "
"attention_fusions, allreduce_fusions, custom_ops",
# Test attention+quant_fp8 fusion with custom and torch impls of QuantFP8
list(flat_product(MODELS_FP8, CUSTOM_OPS_FP8))
# quant_fp4 only has the custom impl
+ list(flat_product(MODELS_FP4, [""])),
)
@pytest.mark.parametrize("inductor_graph_partition", [True, False])
def test_attn_quant(
model_name: str,
model_kwargs: dict[str, Any],
backend: _Backend,
attention_fusions: int,
allreduce_fusions: int,
custom_ops: str,
inductor_graph_partition: bool,
caplog_mp_spawn,
monkeypatch,
):
if backend == _Backend.FLASHINFER and (
not current_platform.is_device_capability((10, 0)) or not has_flashinfer()
):
pytest.skip("FlashInfer attn fusion requires Blackwell and flashinfer")
if inductor_graph_partition and not is_torch_equal_or_newer("2.9.0.dev"):
pytest.skip("Inductor graph partition requires torch>=2.9")
custom_ops_list = custom_ops.split(",") if custom_ops else []
if inductor_graph_partition:
mode = CUDAGraphMode.FULL_AND_PIECEWISE
splitting_ops: list[str] | None = None
else:
mode = CUDAGraphMode.FULL_DECODE_ONLY
splitting_ops = []
# Disable compile cache to make sure custom passes run.
# Otherwise, we can't verify fusion happened through the logs.
monkeypatch.setenv("VLLM_DISABLE_COMPILE_CACHE", "1")
# To capture subprocess logs, we need to know whether spawn or fork is used.
# Force spawn as it is more general.
monkeypatch.setenv("VLLM_WORKER_MULTIPROC_METHOD", "spawn")
monkeypatch.setenv("VLLM_ATTENTION_BACKEND", backend.name)
compilation_config = CompilationConfig(
# Testing properties
custom_ops=custom_ops_list,
use_inductor_graph_partition=inductor_graph_partition,
cudagraph_mode=mode,
splitting_ops=splitting_ops,
# Common
level=CompilationMode.VLLM_COMPILE,
pass_config=PassConfig(enable_attn_fusion=True, enable_noop=True),
# Inductor caches custom passes by default as well via uuid
inductor_compile_config={"force_disable_caches": True},
)
with caplog_mp_spawn(logging.DEBUG) as log_holder:
run_model(compilation_config, model_name, **model_kwargs)
matches = re.findall(
r"fusion_attn.py:\d+] Fused quant onto (\d+) attention nodes",
log_holder.text,
)
assert len(matches) == 1, log_holder.text
assert int(matches[0]) == attention_fusions
# TODO(luka) test both in nightly
CUSTOM_OPS_RMS_NORM = ["-rms_norm"] # , "+rms_norm"]
def custom_ops_product(*custom_ops_lists: list[str]) -> Iterable[str]:
for op_list in itertools.product(*custom_ops_lists):
yield ",".join(op_list)
@multi_gpu_test(num_gpus=2)
@pytest.mark.parametrize(
"model_name, model_kwargs, backend, "
"attention_fusions, allreduce_fusions, custom_ops",
# Toggle RMSNorm and QuantFP8 for FP8 models
list(
flat_product(
MODELS_FP8, custom_ops_product(CUSTOM_OPS_FP8, CUSTOM_OPS_RMS_NORM)
)
)
# Toggle RMSNorm for FP4 models and unquant models
+ list(flat_product(MODELS_FP4 + MODELS, CUSTOM_OPS_RMS_NORM)),
)
@pytest.mark.parametrize("inductor_graph_partition", [True, False])
@pytest.mark.skipif(
not current_platform.is_cuda()
or not has_flashinfer()
or not current_platform.has_device_capability(90),
reason="allreduce+rmsnorm fusion requires flashinfer",
)
def test_tp2_attn_quant_allreduce_rmsnorm(
model_name: str,
model_kwargs: dict,
backend: _Backend,
attention_fusions: int,
allreduce_fusions: int,
custom_ops: str,
inductor_graph_partition: bool,
caplog_mp_spawn,
monkeypatch,
):
if inductor_graph_partition and not is_torch_equal_or_newer("2.9.0.dev"):
pytest.skip("Inductor graph partition requires torch>=2.9")
custom_ops_list = custom_ops.split(",") if custom_ops else []
if inductor_graph_partition:
mode = CUDAGraphMode.FULL_AND_PIECEWISE
splitting_ops: list[str] | None = None
else:
mode = CUDAGraphMode.FULL_DECODE_ONLY
splitting_ops = []
# Disable compile cache to make sure custom passes run.
# Otherwise, we can't verify fusion happened through the logs.
monkeypatch.setenv("VLLM_DISABLE_COMPILE_CACHE", "1")
# To capture subprocess logs, we need to know whether spawn or fork is used.
# Force spawn as it is more general.
monkeypatch.setenv("VLLM_WORKER_MULTIPROC_METHOD", "spawn")
monkeypatch.setenv("VLLM_ATTENTION_BACKEND", backend.name)
compilation_config = CompilationConfig(
# Testing properties
use_inductor_graph_partition=inductor_graph_partition,
cudagraph_mode=mode,
custom_ops=custom_ops_list,
splitting_ops=splitting_ops,
# Common
level=CompilationMode.VLLM_COMPILE,
pass_config=PassConfig(
enable_attn_fusion=True,
enable_noop=True,
enable_fi_allreduce_fusion=True,
),
# Inductor caches custom passes by default as well via uuid
inductor_compile_config={"force_disable_caches": True},
)
with caplog_mp_spawn(logging.DEBUG) as log_holder:
run_model(
compilation_config, model_name, tensor_parallel_size=2, **model_kwargs
)
matches = re.findall(
r"fusion_attn.py:\d+] Fused quant onto (\d+) attention nodes",
log_holder.text,
)
assert len(matches) == 2, log_holder.text
assert int(matches[0]) == attention_fusions
assert int(matches[1]) == attention_fusions
matches = re.findall(
r"collective_fusion.py:\d+] Replaced (\d+) patterns",
log_holder.text,
)
assert len(matches) == 2, log_holder.text
assert int(matches[0]) == allreduce_fusions
assert int(matches[1]) == allreduce_fusions
def run_model(compile_config: int | CompilationConfig, model: str, **model_kwargs):
compilation_config = (
compile_config
if isinstance(compile_config, CompilationConfig)
else CompilationConfig(level=compile_config)
)
prompts = [
"Hello, my name is",
"The president of the United States is",
"The capital of France is",
"The future of AI is",
]
sampling_params = SamplingParams(temperature=0)
# Allow override from model_kwargs
model_kwargs = {"tensor_parallel_size": 1, **model_kwargs}
model_kwargs = {"disable_custom_all_reduce": True, **model_kwargs}
# No cudagraphs by default
if compilation_config.cudagraph_mode is None:
compilation_config.cudagraph_mode = CUDAGraphMode.NONE
llm = LLM(
model=model,
compilation_config=compilation_config,
**model_kwargs,
)
outputs = llm.generate(prompts, sampling_params)
# Print the outputs.
for output in outputs:
prompt = output.prompt
generated_text = output.outputs[0].text
print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")

View File

@ -7,7 +7,7 @@ import torch
from vllm.compilation.inductor_pass import CallableInductorPass, InductorPass
from vllm.compilation.pass_manager import PostGradPassManager
from vllm.config import VllmConfig
from vllm.config import ModelConfig, VllmConfig
# dummy custom pass that doesn't inherit
@ -42,7 +42,8 @@ class ProperPass(InductorPass):
],
)
def test_pass_manager_uuid(callable):
config = VllmConfig()
# Some passes need dtype to be set
config = VllmConfig(model_config=ModelConfig(dtype=torch.bfloat16))
pass_manager = PostGradPassManager()
pass_manager.configure(config)

View File

@ -18,6 +18,8 @@ from vllm.config import (
ModelConfig,
PassConfig,
VllmConfig,
get_current_vllm_config,
set_current_vllm_config,
)
from vllm.distributed import tensor_model_parallel_all_reduce
from vllm.distributed.parallel_state import (
@ -42,9 +44,7 @@ prompts = [
class TestModel(torch.nn.Module):
def __init__(
self, hidden_size=16, intermediate_size=32, vllm_config: VllmConfig = None
):
def __init__(self, hidden_size=16, intermediate_size=32):
super().__init__()
self.hidden_size = hidden_size
self.intermediate_size = intermediate_size
@ -95,13 +95,11 @@ class TestModel(torch.nn.Module):
class TestQuantModel(torch.nn.Module):
def __init__(
self, hidden_size=16, intermediate_size=32, vllm_config: VllmConfig = None
):
def __init__(self, hidden_size=16, intermediate_size=32):
super().__init__()
self.hidden_size = hidden_size
self.intermediate_size = intermediate_size
self.vllm_config = vllm_config
self.vllm_config = get_current_vllm_config()
self.gate_proj = torch.nn.Parameter(
torch.empty((intermediate_size, hidden_size)), requires_grad=False
)
@ -266,76 +264,84 @@ def sequence_parallelism_pass_on_test_model(
initialize_model_parallel(tensor_model_parallel_size=world_size)
# configure vllm config for SequenceParallelismPass
vllm_config = VllmConfig()
vllm_config.compilation_config = CompilationConfig(
compilation_config = CompilationConfig(
pass_config=PassConfig(
enable_sequence_parallelism=True,
enable_fusion=enable_fusion,
enable_noop=True,
)
) # NoOp needed for fusion
vllm_config.device_config = DeviceConfig(device=torch.device("cuda"))
device_config = DeviceConfig(device=torch.device("cuda"))
# this is a fake model name used to construct the model config
# in the vllm_config; it's not actually used.
model_name = "RedHatAI/Llama-3.2-1B-Instruct-FP8"
vllm_config.model_config = ModelConfig(
model_config = ModelConfig(
model=model_name, trust_remote_code=True, dtype=dtype, seed=42
)
noop_pass = NoOpEliminationPass(vllm_config)
sequence_parallelism_pass = SequenceParallelismPass(vllm_config)
assert (
sequence_parallelism_pass.compilation_config.splitting_ops
== vllm_config.compilation_config.splitting_ops
vllm_config = VllmConfig(
model_config=model_config,
device_config=device_config,
compilation_config=compilation_config,
)
assert (
sequence_parallelism_pass.compilation_config.use_inductor_graph_partition
== vllm_config.compilation_config.use_inductor_graph_partition
)
func_pass = FixFunctionalizationPass(vllm_config)
cleanup_pass = PostCleanupPass(vllm_config)
passes_for_backend: list[VllmInductorPass] = [noop_pass, sequence_parallelism_pass]
with set_current_vllm_config(vllm_config):
noop_pass = NoOpEliminationPass(vllm_config)
sequence_parallelism_pass = SequenceParallelismPass(vllm_config)
func_pass = FixFunctionalizationPass(vllm_config)
cleanup_pass = PostCleanupPass(vllm_config)
assert (
sequence_parallelism_pass.compilation_config.splitting_ops
== vllm_config.compilation_config.splitting_ops
)
assert (
sequence_parallelism_pass.compilation_config.use_inductor_graph_partition
== vllm_config.compilation_config.use_inductor_graph_partition
)
passes_for_backend: list[VllmInductorPass] = [
noop_pass,
sequence_parallelism_pass,
]
if enable_fusion:
fusion_pass = RMSNormQuantFusionPass(vllm_config)
passes_for_backend.append(fusion_pass)
if enable_fusion:
fusion_pass = RMSNormQuantFusionPass(vllm_config)
passes_for_backend.append(fusion_pass)
passes_for_backend.append(cleanup_pass)
passes_for_backend.append(cleanup_pass)
backend_no_func = TestBackend(*passes_for_backend)
backend_func = TestBackend(*passes_for_backend, func_pass)
backend_no_func = TestBackend(*passes_for_backend)
backend_func = TestBackend(*passes_for_backend, func_pass)
model = test_model_cls(hidden_size, hidden_size * 2, vllm_config=vllm_config)
model = test_model_cls(hidden_size, hidden_size * 2)
hidden_states = torch.randn((batch_size * seq_len, hidden_size), dtype=dtype)
residual = torch.randn((batch_size * seq_len, hidden_size), dtype=dtype)
hidden_states = torch.randn((batch_size * seq_len, hidden_size), dtype=dtype)
residual = torch.randn((batch_size * seq_len, hidden_size), dtype=dtype)
compiled_model_no_func = torch.compile(model, backend=backend_no_func)
compiled_model_no_func(hidden_states, residual)
compiled_model_func = torch.compile(model, backend=backend_func)
compiled_model_func(hidden_states, residual)
compiled_model_no_func = torch.compile(model, backend=backend_no_func)
compiled_model_no_func(hidden_states, residual)
compiled_model_func = torch.compile(model, backend=backend_func)
compiled_model_func(hidden_states, residual)
assert sequence_parallelism_pass.matched_count == 1
assert sequence_parallelism_pass.matched_count == 1
# In pre-nodes, all reduce should be there,
# reduce scatter and all gather should not
backend_no_func.check_before_ops(model.ops_in_model_before())
# In pre-nodes, all reduce should be there,
# reduce scatter and all gather should not
backend_no_func.check_before_ops(model.ops_in_model_before())
# In post-nodes, reduce scatter and all gather should be there,
# all reduce should not
backend_no_func.check_after_ops(model.ops_in_model_after())
# In post-nodes, reduce scatter and all gather should be there,
# all reduce should not
backend_no_func.check_after_ops(model.ops_in_model_after())
# check if the functionalization pass is applied
for op in model.ops_in_model():
find_auto_fn(backend_no_func.graph_post_pass.nodes, op)
assert find_auto_fn_maybe(backend_func.graph_post_pass.nodes, op) is None
# make sure the ops were all de-functionalized
found = dict()
for node in backend_func.graph_post_pass.nodes:
# check if the functionalization pass is applied
for op in model.ops_in_model():
if is_func(node, op):
found[op] = True
assert all(found[op] for op in model.ops_in_model())
find_auto_fn(backend_no_func.graph_post_pass.nodes, op)
assert find_auto_fn_maybe(backend_func.graph_post_pass.nodes, op) is None
# make sure the ops were all de-functionalized
found = dict()
for node in backend_func.graph_post_pass.nodes:
for op in model.ops_in_model():
if is_func(node, op):
found[op] = True
assert all(found[op] for op in model.ops_in_model())

View File

@ -1,10 +1,13 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# ruff: noqa
import contextlib
import pathlib
from copy import deepcopy
from tblib import pickling_support
# ruff: noqa
# Install support for pickling exceptions so that we can nicely propagate
# failures from tests running in a subprocess.
# This should be run before any custom exception subclasses are defined.
@ -40,7 +43,7 @@ from transformers import (
from transformers.models.auto.auto_factory import _BaseAutoModelClass
from tests.models.utils import TokensTextLogprobs, TokensTextLogprobsPromptLogprobs
from vllm import LLM, SamplingParams
from vllm import LLM, SamplingParams, envs
from vllm.assets.audio import AudioAsset
from vllm.assets.image import ImageAsset
from vllm.assets.video import VideoAsset
@ -1070,6 +1073,101 @@ def caplog_vllm(temporary_enable_log_propagate, caplog):
yield caplog
@pytest.fixture()
def caplog_mp_fork():
"""
This fixture enables capturing logs from a forked MP subprocess.
It should be used in conjunction with caplog_vllm.
By default, subprocess logs do not go through the parent process.
We instead create a queue listener in the parent process which
forwards logs to the logger's other handlers, and add a QueueHandler
to the root logger. Forked subprocesses will inherit the root logger
and pass their messages to the queue; the listener then forwards them
to the parent's handlers, where caplog can capture them.
Note that this workaround only works for fork; with spawn, the subprocess
reinitializes logging and does not automatically inherit the queue.
We'd have to manually pass the queue to the subprocess at the spawn point.
See caplog_mp_spawn below.
"""
@contextlib.contextmanager
def ctx():
import logging.handlers
import multiprocessing as mp
logger_queue: mp.Queue[logging.LogRecord] = mp.Queue()
logger = logging.getLogger()
handlers = logger.handlers
# The listener works on a background thread, not inherited by the child.
queue_listener = logging.handlers.QueueListener(logger_queue, *handlers)
queue_listener.start()
# Add queue handler after creating the listener to avoid cycle
logger.addHandler(logging.handlers.QueueHandler(logger_queue))
yield
queue_listener.stop()
return ctx
class LogHolder:
def __init__(self):
self.text = None
@pytest.fixture()
def caplog_mp_spawn(tmp_path, monkeypatch):
"""
This fixture enables capturing logs from a spawned MP subprocess.
It does not require caplog_vllm (but it only contains logs from the child).
By default, subprocess logs do not go through the parent process.
We instead add a FileHandler to the config so the spawned child process
writes its logs to a temp file.
In the parent, we read the file and return the contents.
Note: this method could be extended to fork by either reconfiguring logging
in the parent or using a SocketHandler:
https://docs.python.org/3/howto/logging-cookbook.html#sending-and-receiving-logging-events-across-a-network # noqa: E501
"""
@contextlib.contextmanager
def ctx(level: int | str):
from vllm.logger import DEFAULT_LOGGING_CONFIG
config_path = tmp_path / "vllm_logging_config.json"
log_path = tmp_path / "vllm.log"
log_holder = LogHolder()
config = deepcopy(DEFAULT_LOGGING_CONFIG)
if envs.VLLM_LOGGING_CONFIG_PATH:
path = pathlib.Path(envs.VLLM_LOGGING_CONFIG_PATH)
assert path.exists()
config = json.loads(path.read_text())
config["loggers"]["vllm"]["handlers"] += ["vllm_file"]
config["handlers"]["vllm_file"] = {
"class": "logging.FileHandler",
"formatter": "vllm",
"level": level,
"filename": log_path.as_posix(),
}
config_path.write_text(json.dumps(config))
with monkeypatch.context() as monkeypatch_ctx:
monkeypatch_ctx.setenv("VLLM_LOGGING_CONFIG_PATH", config_path.as_posix())
monkeypatch_ctx.setenv("VLLM_CONFIGURE_LOGGING", "1")
yield log_holder
log_holder.text = log_path.read_text()
return ctx
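# Minimal usage sketch (mirrors how the compile fusion tests in this diff consume the
# fixture; run_model and compilation_config come from that test file):
#
#     with caplog_mp_spawn(logging.DEBUG) as log_holder:
#         run_model(compilation_config, model_name)  # spawned workers log to the temp file
#     assert "Fused quant" in log_holder.text  # text is populated once the context exits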
@pytest.fixture(scope="session")
def num_gpus_available():
"""Get number of GPUs without initializing the CUDA context

View File

@ -6,7 +6,7 @@ import torch
from tests.kernels.quant_utils import FP8_DTYPE
from tests.kernels.utils import opcheck
from vllm.model_executor.layers.layernorm import PolyNorm, RMSNorm
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.platforms import current_platform
DTYPES = [torch.half, torch.bfloat16, torch.float]
@ -70,38 +70,6 @@ def test_rms_norm(
)
@pytest.mark.parametrize("num_tokens", NUM_TOKENS)
@pytest.mark.parametrize("hidden_size", HIDDEN_SIZES)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.parametrize("device", CUDA_DEVICES)
@torch.inference_mode()
def test_poly_norm(
num_tokens: int,
hidden_size: int,
dtype: torch.dtype,
seed: int,
device: str,
) -> None:
current_platform.seed_everything(seed)
torch.set_default_device(device)
layer = PolyNorm().to(dtype=dtype)
layer.weight.data.normal_(mean=1.0, std=0.1)
layer.bias.data.normal_(mean=1.0, std=0.1)
scale = 1 / (2 * hidden_size)
x = torch.randn(num_tokens, hidden_size, dtype=dtype)
x *= scale
ref_out = layer.forward_native(x)
out = layer(x)
torch.testing.assert_close(out, ref_out, atol=1e-2, rtol=1e-2)
opcheck(
torch.ops._C.poly_norm,
(out, x, layer.weight.data, layer.bias.data, layer.variance_epsilon),
)
@pytest.mark.parametrize("num_tokens", NUM_TOKENS)
@pytest.mark.parametrize("hidden_size", HIDDEN_SIZES)
@pytest.mark.parametrize("add_residual", ADD_RESIDUAL)

View File

@ -103,7 +103,7 @@ def ref_dynamic_per_tensor_fp8_quant(
.clamp(fp8_traits_min, fp8_traits_max)
.to(FP8_DTYPE)
)
return ref_out, ref_scale.view((1,))
return ref_out, ref_scale.view((1, 1))
def native_w8a8_block_matmul(

View File

@ -501,3 +501,49 @@ def test_streaming_complete_logs_full_text_content():
assert call_args[1] == "test-streaming-full-text"
assert call_args[2] == " (streaming complete)"
assert call_args[5] == "streaming_complete"
# Add vllm prefix to make sure logs go through the vllm logger
test_logger = init_logger("vllm.test_logger")
def mp_function(**kwargs):
# This function runs in a subprocess
test_logger.warning("This is a subprocess: %s", kwargs.get("a"))
test_logger.error("This is a subprocess error.")
test_logger.debug("This is a subprocess debug message: %s.", kwargs.get("b"))
def test_caplog_mp_fork(caplog_vllm, caplog_mp_fork):
with caplog_vllm.at_level(logging.DEBUG), caplog_mp_fork():
import multiprocessing
ctx = multiprocessing.get_context("fork")
p = ctx.Process(
target=mp_function,
name=f"SubProcess{1}",
kwargs={"a": "AAAA", "b": "BBBBB"},
)
p.start()
p.join()
assert "AAAA" in caplog_vllm.text
assert "BBBBB" in caplog_vllm.text
def test_caplog_mp_spawn(caplog_mp_spawn):
with caplog_mp_spawn(logging.DEBUG) as log_holder:
import multiprocessing
ctx = multiprocessing.get_context("spawn")
p = ctx.Process(
target=mp_function,
name=f"SubProcess{1}",
kwargs={"a": "AAAA", "b": "BBBBB"},
)
p.start()
p.join()
assert "AAAA" in log_holder.text
assert "BBBBB" in log_holder.text

View File

@ -6,6 +6,7 @@ import contextlib
import copy
import functools
import importlib
import itertools
import json
import os
import random
@ -15,7 +16,7 @@ import sys
import tempfile
import time
import warnings
from collections.abc import Callable
from collections.abc import Callable, Iterable
from contextlib import ExitStack, contextmanager, suppress
from multiprocessing import Process
from pathlib import Path
@ -1261,3 +1262,23 @@ def check_answers(
frac_ok = numok / len(answer)
print(f"Num OK: {numok}/{len(answer)} {frac_ok}")
assert frac_ok >= accept_rate
def flat_product(*iterables: Iterable[Any]):
"""
Compute the cartesian product of the iterables and flatten any tuple elements.
Useful when we want to avoid nested tuples to allow
test params to be unpacked directly from the decorator.
Example:
flat_product([(1, 2), (3, 4)], ["a", "b"]) ->
[
(1, 2, "a"),
(1, 2, "b"),
(3, 4, "a"),
(3, 4, "b"),
]
"""
for element in itertools.product(*iterables):
normalized = (e if isinstance(e, tuple) else (e,) for e in element)
yield tuple(itertools.chain(*normalized))
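# Usage sketch (matches the attention-fusion tests in this diff; the test name below is
# illustrative):
#
#     @pytest.mark.parametrize(
#         "backend, model_name, model_class, custom_ops",
#         list(flat_product(BACKENDS_FP8, MODELS_FP8, ["+quant_fp8", "-quant_fp8"])),
#     )
#     def test_fusion(backend, model_name, model_class, custom_ops): ...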

View File

@ -40,7 +40,7 @@ from vllm.utils import (
unique_filepath,
)
from ..utils import create_new_process_for_each_test
from ..utils import create_new_process_for_each_test, flat_product
def test_get_open_port(monkeypatch: pytest.MonkeyPatch):
@ -771,3 +771,25 @@ def test_unique_filepath():
paths.add(path)
assert len(paths) == 10
assert len(list(Path(temp_dir).glob("*.txt"))) == 10
def test_flat_product():
# Check regular itertools.product behavior
result1 = list(flat_product([1, 2, 3], ["a", "b"]))
assert result1 == [
(1, "a"),
(1, "b"),
(2, "a"),
(2, "b"),
(3, "a"),
(3, "b"),
]
# check that the tuples get flattened
result2 = list(flat_product([(1, 2), (3, 4)], ["a", "b"], [(5, 6)]))
assert result2 == [
(1, 2, "a", 5, 6),
(1, 2, "b", 5, 6),
(3, 4, "a", 5, 6),
(3, 4, "b", 5, 6),
]

View File

@ -34,15 +34,21 @@ else
fi
# Models to run
MODELS=(
"Qwen/Qwen3-0.6B"
)
MODEL_NAMES=${MODEL_NAMES:-}
if [[ -n "$MODEL_NAMES" ]]; then
MODELS=("$MODEL_NAMES")
else
MODELS=(
"Qwen/Qwen3-0.6B"
)
fi
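# Example override (illustrative; assumes this script is invoked directly, e.g. by the
# TP-configuration runner added in this diff):
#   MODEL_NAMES=deepseek-ai/deepseek-vl2-tiny GPU_MEMORY_UTILIZATION=0.8 bash run_accuracy_test.sh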
# Number of prefill and decode instances to create
NUM_PREFILL_INSTANCES=${NUM_PREFILL_INSTANCES:-1} # Default to 1
NUM_DECODE_INSTANCES=${NUM_DECODE_INSTANCES:-1} # Default to 1
PREFILLER_TP_SIZE=${PREFILLER_TP_SIZE:-1}
DECODER_TP_SIZE=${DECODER_TP_SIZE:-1}
GPU_MEMORY_UTILIZATION=${GPU_MEMORY_UTILIZATION:-0.2}
# Find the git repository root directory
GIT_ROOT=$(git rev-parse --show-toplevel)
@ -130,7 +136,7 @@ run_tests_for_model() {
vllm serve $model_name \
--port $PORT \
--enforce-eager \
--gpu-memory-utilization 0.2 \
--gpu-memory-utilization $GPU_MEMORY_UTILIZATION \
--tensor-parallel-size $PREFILLER_TP_SIZE \
--kv-transfer-config '$KV_CONFIG'"
@ -171,7 +177,7 @@ run_tests_for_model() {
vllm serve $model_name \
--port $PORT \
--enforce-eager \
--gpu-memory-utilization 0.2 \
--gpu-memory-utilization $GPU_MEMORY_UTILIZATION \
--tensor-parallel-size $DECODER_TP_SIZE \
--kv-transfer-config '$KV_CONFIG'"
@ -200,7 +206,7 @@ run_tests_for_model() {
done
# Build the command for the proxy server with all the hosts and ports
PROXY_CMD="python ${GIT_ROOT}/tests/v1/kv_connector/nixl_integration/toy_proxy_server.py --port 8192"
PROXY_CMD="python3 ${GIT_ROOT}/tests/v1/kv_connector/nixl_integration/toy_proxy_server.py --port 8192"
# Add all prefill hosts and ports
PROXY_CMD+=" --prefiller-hosts ${PREFILL_HOSTS[@]}"
@ -219,7 +225,7 @@ run_tests_for_model() {
# Run lm eval for this model
echo "Running tests for $model_name"
TEST_MODEL=$model_name python -m pytest -s -x ${GIT_ROOT}/tests/v1/kv_connector/nixl_integration/test_accuracy.py
TEST_MODEL=$model_name python3 -m pytest -s -x ${GIT_ROOT}/tests/v1/kv_connector/nixl_integration/test_accuracy.py
# Clean up before running next model
cleanup_instances

View File

@ -12,7 +12,12 @@ FILTER = "exact_match,strict-match"
RTOL = 0.03
# Model-specific expected values
EXPECTED_VALUES = {"Qwen/Qwen3-0.6B": 0.41, "deepseek-ai/deepseek-vl2-small": 0.59}
EXPECTED_VALUES = {
"Qwen/Qwen3-0.6B": 0.41,
"deepseek-ai/deepseek-vl2-small": 0.59,
"deepseek-ai/deepseek-vl2-tiny": 0.19,
"deepseek-ai/DeepSeek-V2-Lite-Chat": 0.65,
}
SIMPLE_PROMPT = (
"The best part about working on vLLM is that I got to meet so many people across "

View File

@ -76,7 +76,8 @@ def parse_args():
parser = argparse.ArgumentParser()
parser.add_argument("--port", type=int, default=8000)
parser.add_argument("--host", type=str, default="localhost")
# Always use 127.0.0.1 as localhost binds to IPv6 which is blocked on CI
parser.add_argument("--host", type=str, default="127.0.0.1")
# For prefiller instances
parser.add_argument(

View File

@ -0,0 +1,40 @@
#!/usr/bin/env bash
set -euo pipefail
# Utility to run integration tests sequentially with varying TP configurations.
SCRIPT="v1/kv_connector/nixl_integration/run_accuracy_test.sh"
# Define test configurations
configs=(
"GPU_MEMORY_UTILIZATION=0.6 PREFILLER_TP_SIZE=2 DECODER_TP_SIZE=2"
"GPU_MEMORY_UTILIZATION=0.6 PREFILLER_TP_SIZE=1 DECODER_TP_SIZE=2"
"GPU_MEMORY_UTILIZATION=0.8 MODEL_NAMES=deepseek-ai/deepseek-vl2-tiny" # MLA case
"GPU_MEMORY_UTILIZATION=0.8 PREFILLER_TP_SIZE=1 DECODER_TP_SIZE=2 MODEL_NAMES=deepseek-ai/deepseek-vl2-tiny"
)
run_tests() {
local label=$1
local extra_env=$2
echo "=== Running tests (${label}) ==="
for cfg in "${configs[@]}"; do
echo "-> Running with ${cfg} ${extra_env:+and ${extra_env}}"
# Use 'env' to safely set variables without eval
if ! env ${extra_env} ${cfg} bash "${SCRIPT}"; then
echo "❌ Test failed for config: ${cfg} ${extra_env:+(${extra_env})}"
exit 1
fi
done
echo "✅ All ${label} tests passed!"
}
# Run tests
run_tests "default backend" ""
# Check if FLASHINFER is set (non-empty)
if [[ -n "${FLASHINFER:-}" ]]; then
echo "FLASHINFER is set, rerunning with VLLM_ATTENTION_BACKEND=FLASHINFER"
run_tests "FLASHINFER backend" "VLLM_ATTENTION_BACKEND=FLASHINFER"
else
echo "FLASHINFER not set, skipping FLASHINFER runs."
fi

View File

@ -11,7 +11,7 @@ PROMPT = "Hello my name is Robert and I"
@pytest.fixture(scope="module")
def llm() -> LLM:
return LLM(MODEL, enforce_eager=True)
return LLM(MODEL, enforce_eager=True, dtype="half")
def test_n_gt_1(llm):

View File

@ -339,18 +339,6 @@ def fused_add_rms_norm(
torch.ops._C.fused_add_rms_norm(input, residual, weight, epsilon)
def poly_norm(
out: torch.Tensor,
input: torch.Tensor,
weight: torch.Tensor,
bias: torch.Tensor,
epsilon: float,
) -> None:
# TODO: Remove this contiguous call when the kernel is updated to support non-contiguous input
input_contiguous = input.contiguous()
torch.ops._C.poly_norm(out, input_contiguous, weight, bias, epsilon)
def apply_repetition_penalties_torch(
logits: torch.Tensor,
prompt_mask: torch.Tensor,
@ -1507,7 +1495,7 @@ def scaled_fp8_quant(
output, input, scale, scale_ub
)
else:
scale = torch.empty(1, device=input.device, dtype=torch.float32)
scale = torch.empty((1, 1), device=input.device, dtype=torch.float32)
torch.ops._C.dynamic_scaled_fp8_quant(output, input, scale)
else:
assert scale.numel() == 1, f"{scale.shape}"

View File

@ -17,10 +17,14 @@ from vllm.distributed.parallel_state import (
get_tensor_model_parallel_world_size,
)
from vllm.logger import init_logger
from vllm.model_executor.layers.quantization.utils.quant_utils import (
kFp8StaticTensorSym,
)
from vllm.platforms import current_platform
from vllm.utils import direct_register_custom_op
from .inductor_pass import enable_fake_mode
from .matcher_utils import MatcherFusedAddRMSNorm, MatcherQuantFP8, MatcherRMSNorm
from .vllm_inductor_pass import VllmInductorPass, VllmPatternMatcherPass
FP8_DTYPE = current_platform.fp8_dtype()
@ -41,11 +45,8 @@ else:
logger = init_logger(__name__)
ALLREDUCE_OP = torch.ops.vllm.all_reduce.default
RMS_OP = torch.ops._C.rms_norm.default
RMS_ADD_OP = torch.ops._C.fused_add_rms_norm.default
STATIC_FP8_QUANT_OP = torch.ops._C.static_scaled_fp8_quant.default
STATIC_FP4_QUANT_OP = torch.ops._C.scaled_fp4_quant.default
if hasattr(torch.ops._C, "scaled_fp4_quant"):
STATIC_FP4_QUANT_OP = torch.ops._C.scaled_fp4_quant.default
class BasePattern:
@ -669,33 +670,24 @@ class AllReduceRMSNormPattern(BasePattern):
super().__init__(dtype, device)
self.epsilon = epsilon
self.allreduce_params = allreduce_params
self.rmsnorm_matcher = MatcherRMSNorm(epsilon)
def get_inputs(self):
input = torch.empty([1, 8, 4], device=self.device, dtype=self.dtype)
rms_result = torch.empty([1, 8, 4], device=self.device, dtype=self.dtype)
weight = torch.empty([4], device=self.device, dtype=self.dtype)
input, weight = self.rmsnorm_matcher.inputs()
return [input, rms_result, weight]
# input goes through allreduce first, always 16-bit
return [input.to(self.dtype), weight]
def register(self, pm_pass: PatternMatcherPass):
def pattern(
input: torch.Tensor, rms_result: torch.Tensor, weight: torch.Tensor
):
def pattern(input: torch.Tensor, weight: torch.Tensor):
allreduce_output = tensor_model_parallel_all_reduce(input)
rms = auto_functionalized(
RMS_OP,
result=rms_result,
input=allreduce_output,
weight=weight,
epsilon=self.epsilon,
)
# rms_result, allreduce_output
return rms[1], allreduce_output
rms = self.rmsnorm_matcher(allreduce_output, weight)
def replacement(
input: torch.Tensor, rms_result: torch.Tensor, weight: torch.Tensor
):
return rms, allreduce_output
def replacement(input: torch.Tensor, weight: torch.Tensor):
residual = torch.zeros_like(input)
rms_result = torch.empty_like(input)
allreduce = auto_functionalized(
flashinfer_trtllm_fused_allreduce_norm,
allreduce_in=input,
@ -733,29 +725,19 @@ class AllReduceFusedAddRMSNormPattern(BasePattern):
super().__init__(dtype, device)
self.epsilon = epsilon
self.allreduce_params = allreduce_params
self.rmsnorm_matcher = MatcherFusedAddRMSNorm(epsilon)
def get_inputs(self):
input = torch.empty([4, 4], device=self.device, dtype=self.dtype)
residual = torch.empty([4, 4], device=self.device, dtype=self.dtype)
weight = torch.empty([4, 4], device=self.device, dtype=self.dtype)
return [
residual,
input,
weight,
]
input, residual, weight = self.rmsnorm_matcher.inputs()
# input goes through allreduce first, always 16-bit
return [residual, input.to(self.dtype), weight]
def register(self, pm_pass: PatternMatcherPass):
def pattern(residual: torch.Tensor, input: torch.Tensor, weight: torch.Tensor):
allreduce_output = tensor_model_parallel_all_reduce(input)
rms = auto_functionalized(
RMS_ADD_OP,
input=allreduce_output,
residual=residual,
weight=weight,
epsilon=self.epsilon,
)
# input, residual
return rms[1], rms[2]
rms, residual = self.rmsnorm_matcher(allreduce_output, weight, residual)
return rms, residual
def replacement(
residual: torch.Tensor, input: torch.Tensor, weight: torch.Tensor
@ -779,6 +761,18 @@ class AllReduceFusedAddRMSNormPattern(BasePattern):
pattern, replacement, self.get_inputs(), pm.fwd_only, pm_pass
)
# Same pattern, but only return the output and not residual
# (helpful for end of graph where residual is not used again)
first_return_only = lambda fn: lambda a, b, c: fn(a, b, c)[0]
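# e.g. first_return_only(pattern)(residual, input, weight) traces the same graph but
# returns only the normalized output, matching call sites where the updated residual
# is dead at the end of the graph.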
pm.register_replacement(
first_return_only(pattern),
first_return_only(replacement),
self.get_inputs(),
pm.fwd_only,
pm_pass,
)
class AllReduceFusedRMSNormStaticQuantFP8Pattern(BasePattern):
"""
@ -799,60 +793,37 @@ class AllReduceFusedRMSNormStaticQuantFP8Pattern(BasePattern):
self.epsilon = epsilon
self.allreduce_params = allreduce_params
self.quant_dtype = torch.float8_e4m3fn
self.rmsnorm_matcher = MatcherRMSNorm(epsilon)
self.quant_matcher = MatcherQuantFP8(kFp8StaticTensorSym)
def register(self, pm_pass: PatternMatcherPass):
def get_inputs():
input = torch.zeros([1, 8, 4], device=self.device, dtype=self.dtype)
rmsnorm_result = torch.empty(
[1, 8, 4], device=self.device, dtype=self.dtype
)
quant_result = torch.empty(
[1, 8, 4], device=self.device, dtype=self.quant_dtype
)
weight = torch.empty([4], device=self.device, dtype=self.dtype)
scale = torch.tensor(1.0, device=self.device, dtype=torch.float32)
return [input, rmsnorm_result, quant_result, weight, scale]
input, weight = self.rmsnorm_matcher.inputs()
_, scale = self.quant_matcher.inputs()
# input goes through allreduce first, always 16-bit
return [input.to(self.dtype), weight, scale]
def pattern(
input: torch.Tensor,
rmsnorm_result: torch.Tensor,
quant_result: torch.Tensor,
weight: torch.Tensor,
scale: torch.Tensor,
):
all_reduce = tensor_model_parallel_all_reduce(input)
rmsnorm_out_tuple = auto_functionalized(
RMS_OP,
result=rmsnorm_result,
input=all_reduce,
weight=weight,
epsilon=self.epsilon,
)
rms = self.rmsnorm_matcher(all_reduce, weight)
quant, _ = self.quant_matcher(rms, scale)
return quant, all_reduce
quant_out_tuple = auto_functionalized(
STATIC_FP8_QUANT_OP,
result=quant_result,
input=rmsnorm_out_tuple[1],
scale=scale,
)
# quant_out, allreduce_output
return quant_out_tuple[1], all_reduce
def replacement(
input: torch.Tensor,
result_rms: torch.Tensor,
quant_result: torch.Tensor,
weight: torch.Tensor,
scale: torch.Tensor,
):
def replacement(input: torch.Tensor, weight: torch.Tensor, scale: torch.Tensor):
residual = torch.zeros_like(input)
result_rms = torch.empty_like(input)
result_quant = torch.empty_like(input, dtype=self.quant_dtype)
allreduce = auto_functionalized(
flashinfer_trtllm_fused_allreduce_norm,
allreduce_in=input,
residual=residual,
norm_out=result_rms,
quant_out=quant_result,
quant_out=result_quant,
scale_out=None,
rms_gamma=weight,
rms_eps=self.epsilon,
@ -892,64 +863,42 @@ class AllReduceFusedAddRMSNormStaticQuantFP8Pattern(BasePattern):
self.allreduce_params = allreduce_params
self.quant_dtype = torch.float8_e4m3fn
self.rmsnorm_matcher = MatcherFusedAddRMSNorm(epsilon)
self.quant_matcher = MatcherQuantFP8(kFp8StaticTensorSym)
def register(self, pm_pass: PatternMatcherPass):
def get_inputs():
input = torch.empty([4, 4], device=self.device, dtype=self.dtype)
input, residual, weight = self.rmsnorm_matcher.inputs()
_, scale = self.quant_matcher.inputs()
residual = torch.empty([4, 4], device=self.device, dtype=self.dtype)
weight = torch.empty([4, 4], device=self.device, dtype=self.dtype)
quant_result = torch.empty(
[4, 4], device=self.device, dtype=self.quant_dtype
)
scale = torch.empty([1, 1], device=self.device, dtype=torch.float32)
return [
quant_result,
residual,
input,
weight,
scale,
]
# input goes through allreduce first, always 16-bit
return [residual, input.to(self.dtype), weight, scale]
def pattern(
quant_result: torch.Tensor,
residual: torch.Tensor,
input: torch.Tensor,
weight: torch.Tensor,
scale: torch.Tensor,
):
allreduce_output = tensor_model_parallel_all_reduce(input)
rms, res = self.rmsnorm_matcher(allreduce_output, weight, residual)
quant, _ = self.quant_matcher(rms, scale)
fused_add_rmsnorm_out_tuple = auto_functionalized(
RMS_ADD_OP,
input=allreduce_output,
residual=residual,
weight=weight,
epsilon=self.epsilon,
)
quant_out_tuple = auto_functionalized(
STATIC_FP8_QUANT_OP,
result=quant_result,
input=fused_add_rmsnorm_out_tuple[1],
scale=scale,
)
# quant_out, allreduce_output
return quant_out_tuple[1], fused_add_rmsnorm_out_tuple[2]
return quant, res
def replacement(
quant_result: torch.Tensor,
residual: torch.Tensor,
input: torch.Tensor,
weight: torch.Tensor,
scale: torch.Tensor,
):
result_quant = torch.empty_like(input, dtype=self.quant_dtype)
allreduce = auto_functionalized(
flashinfer_trtllm_fused_allreduce_norm,
allreduce_in=input,
residual=residual,
norm_out=None,
quant_out=quant_result,
quant_out=result_quant,
scale_out=None,
rms_gamma=weight,
rms_eps=self.epsilon,
@ -986,14 +935,11 @@ class AllReduceFusedRMSNormStaticQuantNVFP4Pattern(BasePattern):
super().__init__(dtype, device)
self.epsilon = epsilon
self.allreduce_params = allreduce_params
self.rmsnorm_matcher = MatcherRMSNorm(epsilon)
def register(self, pm_pass: PatternMatcherPass):
def get_inputs():
input = torch.empty([1, 16, 16], device=self.device, dtype=self.dtype)
rmsnorm_result = torch.empty(
[1, 16, 16], device=self.device, dtype=self.dtype
)
quant_result = torch.empty((16, 8), device=self.device, dtype=torch.uint8)
input_global_scale = torch.empty(
[1, 1], device=self.device, dtype=torch.float32
@ -1001,36 +947,21 @@ class AllReduceFusedRMSNormStaticQuantNVFP4Pattern(BasePattern):
weight = torch.empty([16], device=self.device, dtype=self.dtype)
output_scale = torch.empty([128, 4], device=self.device, dtype=torch.int32)
return [
input,
rmsnorm_result,
quant_result,
weight,
input_global_scale,
output_scale,
]
return [input, quant_result, weight, input_global_scale, output_scale]
def pattern(
input: torch.Tensor,
rmsnorm_result: torch.Tensor,
quant_result: torch.Tensor,
weight: torch.Tensor,
input_global_scale: torch.Tensor,
output_scale: torch.Tensor,
):
all_reduce = tensor_model_parallel_all_reduce(input)
rmsnorm_out_tuple = auto_functionalized(
RMS_OP,
result=rmsnorm_result,
input=all_reduce,
weight=weight,
epsilon=self.epsilon,
)
rms = self.rmsnorm_matcher(all_reduce, weight)
quant_out_tuple = auto_functionalized(
STATIC_FP4_QUANT_OP,
output=quant_result,
input=rmsnorm_out_tuple[1],
input=rms,
output_scale=output_scale,
input_scale=input_global_scale,
)
@ -1040,13 +971,13 @@ class AllReduceFusedRMSNormStaticQuantNVFP4Pattern(BasePattern):
def replacement(
input: torch.Tensor,
result_rms: torch.Tensor,
quant_result: torch.Tensor,
weight: torch.Tensor,
input_global_scale: torch.Tensor,
output_scale: torch.Tensor,
):
residual = torch.zeros_like(input)
result_rms = torch.empty_like(input)
allreduce = auto_functionalized(
flashinfer_trtllm_fused_allreduce_norm,
allreduce_in=input,
@ -1090,6 +1021,7 @@ class AllReduceFusedAddRMSNormStaticQuantNVFP4Pattern(BasePattern):
super().__init__(dtype, device)
self.epsilon = epsilon
self.allreduce_params = allreduce_params
self.rmsnorm_matcher = MatcherFusedAddRMSNorm(epsilon)
def register(self, pm_pass: PatternMatcherPass):
def get_inputs():
@ -1121,28 +1053,17 @@ class AllReduceFusedAddRMSNormStaticQuantNVFP4Pattern(BasePattern):
input_global_scale: torch.Tensor,
):
allreduce_output = tensor_model_parallel_all_reduce(input)
fused_add_rmsnorm_out_tuple = auto_functionalized(
RMS_ADD_OP,
input=allreduce_output,
residual=residual,
weight=weight,
epsilon=self.epsilon,
)
rms, residual = self.rmsnorm_matcher(allreduce_output, weight, residual)
quant_out_tuple = auto_functionalized(
STATIC_FP4_QUANT_OP,
output=quant_result,
input=fused_add_rmsnorm_out_tuple[1],
input=rms,
output_scale=output_scale,
input_scale=input_global_scale,
)
# quant_out, allreduce_output, output_scale
return (
quant_out_tuple[1],
fused_add_rmsnorm_out_tuple[2],
quant_out_tuple[2],
)
return quant_out_tuple[1], residual, quant_out_tuple[2]
def replacement(
quant_result: torch.Tensor,

View File

@ -9,7 +9,7 @@ from torch._higher_order_ops.auto_functionalize import auto_functionalized
from torch._inductor.pattern_matcher import PatternMatcherPass
from torch._ops import OpOverload
from vllm.config import VllmConfig
from vllm.config import VllmConfig, get_current_vllm_config
from vllm.logger import init_logger
from vllm.model_executor.layers.quantization.utils.quant_utils import (
GroupShape,
@ -24,6 +24,7 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import (
from vllm.platforms import current_platform
from .inductor_pass import enable_fake_mode
from .matcher_utils import MatcherFusedAddRMSNorm, MatcherQuantFP8, MatcherRMSNorm
from .vllm_inductor_pass import VllmInductorPass, VllmPatternMatcherPass
logger = init_logger(__name__)
@ -92,13 +93,19 @@ class RMSNormQuantPattern:
def __init__(self, epsilon: float, key: FusedRMSQuantKey):
self.epsilon = epsilon
self.quant_dtype = key.quant.dtype
assert key.quant in QUANT_OPS, f"unsupported quantization scheme {key.quant}"
self.QUANT_OP = QUANT_OPS[key.quant]
config = get_current_vllm_config()
self.model_dtype = config.model_config.dtype if config.model_config else None
assert key in FUSED_OPS, f"unsupported fused rmsnorm+quant op for {key}"
self.FUSED_OP = FUSED_OPS[key]
self.rmsnorm_matcher = (
MatcherRMSNorm(epsilon)
if not key.fused_add
else MatcherFusedAddRMSNorm(epsilon)
)
self.quant_matcher = MatcherQuantFP8(key.quant)
class RMSNormStaticQuantPattern(RMSNormQuantPattern):
def __init__(self, epsilon: float, quant_dtype: torch.dtype, symmetric=True):
@ -112,34 +119,18 @@ class RMSNormStaticQuantPattern(RMSNormQuantPattern):
def register(self, pm_pass: PatternMatcherPass):
# Cannot use methods, as the self argument affects tracing
def pattern(
result: torch.Tensor,
result_rms: torch.Tensor,
input: torch.Tensor,
weight: torch.Tensor,
scale: torch.Tensor,
):
at1 = auto_functionalized(
RMS_OP,
result=result_rms,
input=input,
weight=weight,
epsilon=self.epsilon,
)
at2 = auto_functionalized(
self.QUANT_OP, result=result, input=at1[1], scale=scale
)
def pattern(input: torch.Tensor, weight: torch.Tensor, scale: torch.Tensor):
result_rms = self.rmsnorm_matcher(input, weight)
return self.quant_matcher(result_rms, scale)[0]
# result
return at2[1]
def replacement(input: torch.Tensor, weight: torch.Tensor, scale: torch.Tensor):
# In case we're matching native rms-norm, conversions might be
# optimized out. We convert here just to be safe.
input = input.to(dtype=self.model_dtype)
def replacement(
result: torch.Tensor,
result_rms: torch.Tensor,
input: torch.Tensor,
weight: torch.Tensor,
scale: torch.Tensor,
):
result = torch.empty(
input.shape, device=input.device, dtype=self.quant_dtype
)
at = auto_functionalized(
self.FUSED_OP,
result=result,
@ -153,12 +144,11 @@ class RMSNormStaticQuantPattern(RMSNormQuantPattern):
return at[1]
inputs = [
torch.empty(5, 4, device="cuda", dtype=self.quant_dtype), # result
empty_bf16(5, 4), # result_rms
empty_bf16(5, 4), # input
empty_bf16(1, 5), # weight
empty_fp32(1, 1), # scale
# input, weight
*self.rmsnorm_matcher.inputs(),
self.quant_matcher.inputs()[1], # scale
]
pattern(*inputs)
pm.register_replacement(pattern, replacement, inputs, pm.fwd_only, pm_pass)
@ -175,33 +165,27 @@ class FusedAddRMSNormStaticQuantPattern(RMSNormQuantPattern):
def register(self, pm_pass: PatternMatcherPass):
def pattern(
result: torch.Tensor,
input: torch.Tensor,
residual: torch.Tensor,
weight: torch.Tensor,
residual: torch.Tensor,
scale: torch.Tensor,
):
at = auto_functionalized(
RMS_ADD_OP,
input=input,
residual=residual,
weight=weight,
epsilon=self.epsilon,
)
at1 = auto_functionalized(
self.QUANT_OP, result=result, input=at[1], scale=scale
)
result_rms, residual = self.rmsnorm_matcher(input, weight, residual)
result, _ = self.quant_matcher(result_rms, scale)
# result, residual
return at1[1], at[2]
return result, residual
def replacement(
result: torch.Tensor,
input: torch.Tensor,
residual: torch.Tensor,
weight: torch.Tensor,
residual: torch.Tensor,
scale: torch.Tensor,
):
# In case we're matching native rms-norm, conversions might be
# optimized out. We convert here just to be safe.
input = input.to(dtype=self.model_dtype)
result = torch.empty_like(input, dtype=self.quant_dtype)
at = auto_functionalized(
self.FUSED_OP,
result=result,
@ -216,11 +200,9 @@ class FusedAddRMSNormStaticQuantPattern(RMSNormQuantPattern):
return at[1], at[2]
inputs = [
torch.empty(5, 4, device="cuda", dtype=self.quant_dtype), # result
empty_bf16(5, 4), # input
empty_bf16(5, 4), # residual
empty_bf16(1, 5), # weight
empty_fp32(1, 1), # scale
# input, weight, residual
*self.rmsnorm_matcher.inputs(),
self.quant_matcher.inputs()[1], # scale
]
pm.register_replacement(
@ -248,34 +230,18 @@ class RMSNormDynamicQuantPattern(RMSNormQuantPattern):
super().__init__(epsilon, key)
def register(self, pm_pass: PatternMatcherPass):
def pattern(
result: torch.Tensor,
result_rms: torch.Tensor,
input: torch.Tensor,
weight: torch.Tensor,
scale: torch.Tensor,
):
at1 = auto_functionalized(
RMS_OP,
result=result_rms,
input=input,
weight=weight,
epsilon=self.epsilon,
)
at2 = auto_functionalized(
self.QUANT_OP, result=result, input=at1[1], scale=scale, scale_ub=None
)
def pattern(input: torch.Tensor, weight: torch.Tensor):
result_rms = self.rmsnorm_matcher(input, weight)
# result, scale
return at2[1], at2[2]
return self.quant_matcher(result_rms)
def replacement(
result: torch.Tensor,
result_rms: torch.Tensor,
input: torch.Tensor,
weight: torch.Tensor,
scale: torch.Tensor,
):
def replacement(input: torch.Tensor, weight: torch.Tensor):
# In case we're matching native rms-norm, conversions might be
# optimized out. We convert here just to be safe.
input = input.to(dtype=self.model_dtype)
result = torch.empty_like(input, dtype=self.quant_dtype)
scale = self.quant_matcher.make_scale(input)
at = auto_functionalized(
self.FUSED_OP,
result=result,
@ -290,18 +256,10 @@ class RMSNormDynamicQuantPattern(RMSNormQuantPattern):
# result, scale
return at[1], at[2]
inputs = [
torch.empty(5, 4, device="cuda", dtype=self.quant_dtype), # result
empty_bf16(5, 4), # result_rms
empty_bf16(5, 4), # input
empty_bf16(1, 5), # weight
empty_fp32(1, 1), # scale
]
pm.register_replacement(
pattern,
replacement,
inputs,
self.rmsnorm_matcher.inputs(),
pm.fwd_only,
pm_pass,
)
@ -323,34 +281,21 @@ class FusedAddRMSNormDynamicQuantPattern(RMSNormQuantPattern):
super().__init__(epsilon, key)
def register(self, pm_pass: PatternMatcherPass):
def pattern(
result: torch.Tensor,
input: torch.Tensor,
residual: torch.Tensor,
weight: torch.Tensor,
scale: torch.Tensor,
):
at = auto_functionalized(
RMS_ADD_OP,
input=input,
residual=residual,
weight=weight,
epsilon=self.epsilon,
)
at1 = auto_functionalized(
self.QUANT_OP, result=result, input=at[1], scale=scale, scale_ub=None
)
def pattern(input: torch.Tensor, weight: torch.Tensor, residual: torch.Tensor):
result_rms, residual = self.rmsnorm_matcher(input, weight, residual)
result, scale = self.quant_matcher(result_rms)
# result, residual, scale
return at1[1], at[2], at1[2]
return result, residual, scale
def replacement(
result: torch.Tensor,
input: torch.Tensor,
residual: torch.Tensor,
weight: torch.Tensor,
scale: torch.Tensor,
input: torch.Tensor, weight: torch.Tensor, residual: torch.Tensor
):
# In case we're matching native rms-norm, conversions might be
# optimized out. We convert here just to be safe.
input = input.to(dtype=self.model_dtype)
result = torch.empty_like(input, dtype=self.quant_dtype)
scale = self.quant_matcher.make_scale(input)
at = auto_functionalized(
self.FUSED_OP,
result=result,
@ -365,18 +310,10 @@ class FusedAddRMSNormDynamicQuantPattern(RMSNormQuantPattern):
# result, residual, scale
return at[1], at[3], at[2]
inputs = [
torch.empty(5, 4, device="cuda", dtype=self.quant_dtype), # result
empty_bf16(5, 4), # input
empty_bf16(5, 4), # residual
empty_bf16(1, 5), # weight
empty_fp32(1, 1), # scale
]
pm.register_replacement(
pattern,
replacement,
inputs,
self.rmsnorm_matcher.inputs(),
pm.fwd_only,
pm_pass,
)
@ -396,23 +333,25 @@ class RMSNormQuantFusionPass(VllmPatternMatcherPass):
pass_name="rmsnorm_quant_fusion_pass"
)
# Make sure fused add patterns are before simple rms norm,
# as the latter is a subset of the former in torch ops
for epsilon in [1e-5, 1e-6]:
# Fuse rms_norm + static fp8 quant
RMSNormStaticQuantPattern(epsilon, FP8_DTYPE).register(self.patterns)
# Fuse fused_add_rms_norm + static fp8 quant
FusedAddRMSNormStaticQuantPattern(epsilon, FP8_DTYPE).register(
self.patterns
)
# Fuse rms_norm + dynamic per-token fp8 quant
RMSNormDynamicQuantPattern(epsilon, FP8_DTYPE).register(self.patterns)
# Fuse rms_norm + static fp8 quant
RMSNormStaticQuantPattern(epsilon, FP8_DTYPE).register(self.patterns)
# Fuse fused_add_rms_norm + dynamic per-token fp8 quant
FusedAddRMSNormDynamicQuantPattern(epsilon, FP8_DTYPE).register(
self.patterns
)
# Fuse rms_norm + dynamic per-token fp8 quant
RMSNormDynamicQuantPattern(epsilon, FP8_DTYPE).register(self.patterns)
self.dump_patterns(config, self.patterns)
@VllmInductorPass.time_and_log

View File

@ -2,9 +2,11 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from abc import ABC, abstractmethod
from collections.abc import Callable
import torch
import torch._inductor.pattern_matcher as pm
from torch import fx
from torch._higher_order_ops.auto_functionalize import auto_functionalized
from torch._inductor.pattern_matcher import PatternMatcherPass
@ -20,7 +22,9 @@ from vllm.platforms import current_platform
from vllm.utils import round_up
from .fusion import QUANT_OPS, empty_bf16, empty_fp32, empty_i32
from .fx_utils import is_func
from .inductor_pass import enable_fake_mode
from .matcher_utils import MatcherQuantFP8
from .vllm_inductor_pass import VllmInductorPass, VllmPatternMatcherPass
logger = init_logger(__name__)
@ -66,9 +70,13 @@ class AttentionQuantPattern(ABC):
return torch.empty(*args, **kwargs)
@staticmethod
def wrap_trace_fn(process_fx, trace_fn):
def wrap_trace_fn(trace_fn, *process_fx_fns: Callable[[fx.GraphModule], None]):
def wrapped(*args, **kwargs):
return process_fx(trace_fn(*args, **kwargs))
gm = trace_fn(*args, **kwargs)
for process_fx in process_fx_fns:
process_fx(gm)
return gm
return wrapped
@ -77,7 +85,20 @@ class AttentionQuantPattern(ABC):
from torch._inductor.fx_passes.post_grad import view_to_reshape
view_to_reshape(gm)
return gm
@staticmethod
def remove_noop_permutes(gm: torch.fx.GraphModule):
for node in gm.graph.nodes:
if not is_func(node, torch.ops.aten.permute.default):
continue
dims = node.args[1]
if any(dim != i for i, dim in enumerate(dims)):
continue
# this is now an identity op, remove
node.replace_all_uses_with(node.args[0])
gm.graph.erase_node(node)
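For intuition, here is a minimal standalone sketch (illustrative, not part of this diff) of why a permute whose dims are already in ascending order can be dropped: it is a pure view that moves no data.
import torch
x = torch.randn(2, 3, 4)
y = x.permute(0, 1, 2)               # dims already in order -> identity view
assert y.data_ptr() == x.data_ptr()  # same storage, nothing was moved
assert torch.equal(x, y)             # identical contents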
def register_if_supported(self, pm_pass: PatternMatcherPass):
if self.layer.impl.fused_output_quant_supported(self.quant_key):
@ -108,6 +129,7 @@ class AttentionFp8StaticQuantPattern(AttentionQuantPattern):
dtype=FP8_DTYPE, scale=kStaticTensorScale, symmetric=symmetric
)
super().__init__(layer, quant_key, dtype)
self.quant_matcher = MatcherQuantFP8(quant_key)
def _register(self, pm_pass: PatternMatcherPass):
def pattern(
@ -115,7 +137,6 @@ class AttentionFp8StaticQuantPattern(AttentionQuantPattern):
k: torch.Tensor,
v: torch.Tensor,
output_attn: torch.Tensor,
output_quant: torch.Tensor,
scale: torch.Tensor,
):
at1 = auto_functionalized(
@ -131,17 +152,14 @@ class AttentionFp8StaticQuantPattern(AttentionQuantPattern):
attn_out_view = RESHAPE_OP(
at1[1], [q.shape[0], self.num_heads * self.head_size]
)
at2 = auto_functionalized(
self.QUANT_OP, result=output_quant, input=attn_out_view, scale=scale
)
return at2[1]
return self.quant_matcher(attn_out_view, scale)[0]
def replacement(
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
output_attn: torch.Tensor,
output_quant: torch.Tensor,
scale: torch.Tensor,
):
# attn output in quant_dtype
@ -164,13 +182,10 @@ class AttentionFp8StaticQuantPattern(AttentionQuantPattern):
return RESHAPE_OP(at1[1], [-1, self.num_heads * self.head_size])
inputs = [
self.empty(5, self.num_heads, self.head_size, dtype=self.dtype), # q
self.empty(5, self.num_heads, self.head_size, dtype=self.dtype), # k
self.empty(5, self.num_heads, self.head_size, dtype=self.dtype), # v
self.empty(
5, self.num_heads, self.head_size, dtype=self.dtype
), # attn_output
self.empty_quant(5, self.num_heads * self.head_size), # quant_output
self.empty(5, self.num_heads, self.head_size), # q
self.empty(5, self.num_heads, self.head_size), # k
self.empty(5, self.num_heads, self.head_size), # v
self.empty(5, self.num_heads, self.head_size), # attn_output
empty_fp32(1, 1), # scale
]
@ -179,7 +194,9 @@ class AttentionFp8StaticQuantPattern(AttentionQuantPattern):
replacement,
inputs,
AttentionQuantPattern.wrap_trace_fn(
AttentionQuantPattern.fx_view_to_reshape, pm.fwd_only
pm.fwd_only,
AttentionQuantPattern.fx_view_to_reshape,
AttentionQuantPattern.remove_noop_permutes,
),
pm_pass,
)
@ -279,7 +296,9 @@ class AttentionNvfp4QuantPattern(AttentionQuantPattern):
replacement,
inputs,
AttentionQuantPattern.wrap_trace_fn(
AttentionQuantPattern.fx_view_to_reshape, pm.fwd_only
pm.fwd_only,
AttentionQuantPattern.fx_view_to_reshape,
AttentionQuantPattern.remove_noop_permutes,
),
pm_pass,
)

View File

@ -6,7 +6,7 @@ from collections.abc import Iterable, Iterator
from torch import fx
from torch._higher_order_ops.auto_functionalize import auto_functionalized
from torch._ops import OpOverload
from torch._ops import OpOverload, OpOverloadPacket
def is_func(node: fx.Node, target) -> bool:
@ -64,7 +64,17 @@ def find_getitem(node: fx.Node, idx: int) -> fx.Node:
# An auto-functionalization-aware utility for finding nodes with a specific op
def find_op_nodes(op: OpOverload, graph: fx.Graph) -> Iterator[fx.Node]:
# Also handles op overload packets and finds all overloads
def find_op_nodes(
op: OpOverload | OpOverloadPacket, graph: fx.Graph
) -> Iterator[fx.Node]:
if isinstance(op, OpOverloadPacket):
for overload in op.overloads():
overload_op = getattr(op, overload)
yield from find_op_nodes(overload_op, graph)
return
assert isinstance(op, OpOverload)
if not op._schema.is_mutable:
yield from graph.find_nodes(op="call_function", target=op)
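A quick illustration (an assumed example, not taken from this change) of the packet/overload distinction the helper now handles: an OpOverloadPacket groups every overload of an op, and getattr resolves an overload name to a concrete OpOverload.
import torch
from torch._ops import OpOverload, OpOverloadPacket
packet = torch.ops.aten.add               # OpOverloadPacket
assert isinstance(packet, OpOverloadPacket)
names = packet.overloads()                # e.g. ['Tensor', 'Scalar', 'out', ...]
overload = getattr(packet, names[0])      # a concrete OpOverload
assert isinstance(overload, OpOverload)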

View File

@ -0,0 +1,208 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from abc import ABC, abstractmethod
import torch
from torch._higher_order_ops import auto_functionalized
from torch._ops import OpOverload
from vllm.config import get_current_vllm_config
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.quantization.input_quant_fp8 import QuantFP8
from vllm.model_executor.layers.quantization.utils.quant_utils import (
QuantKey,
_normalize_quant_group_shape,
kFp8DynamicTensorSym,
kFp8DynamicTokenSym,
kFp8StaticTensorSym,
kNvfp4Quant,
)
from vllm.platforms import current_platform
RMS_OP = torch.ops._C.rms_norm.default
RMS_ADD_OP = torch.ops._C.fused_add_rms_norm.default
QUANT_OPS: dict[QuantKey, OpOverload] = {
kFp8StaticTensorSym: torch.ops._C.static_scaled_fp8_quant.default, # noqa: E501
kFp8DynamicTensorSym: torch.ops._C.dynamic_scaled_fp8_quant.default, # noqa: E501
kFp8DynamicTokenSym: torch.ops._C.dynamic_per_token_scaled_fp8_quant.default, # noqa: E501
}
if current_platform.is_cuda() and hasattr(torch.ops._C, "scaled_fp4_quant"):
QUANT_OPS[kNvfp4Quant] = torch.ops._C.scaled_fp4_quant.default # noqa: E501
class MatcherCustomOp(ABC):
def __init__(self, enabled: bool):
config = get_current_vllm_config()
self.model_dtype = config.model_config.dtype if config.model_config else None
self.device = config.device_config.device if config.device_config else None
self.enabled = enabled
self.forward = self.forward_custom if enabled else self.forward_native
@abstractmethod
def forward_custom(self, *args, **kws):
pass
@abstractmethod
def forward_native(self, *args, **kws):
pass
def __call__(self, *args, **kws):
return self.forward(*args, **kws)
def empty(self, *args, **kws):
return torch.empty(*args, dtype=self.model_dtype, device=self.device, **kws)
def empty_f32(self, *args, **kws):
return torch.empty(*args, dtype=torch.float32, device=self.device, **kws)
def inputs(self) -> list[torch.Tensor]:
"""Utility for inputs to the pattern"""
raise NotImplementedError
class MatcherRMSNorm(MatcherCustomOp):
def __init__(self, epsilon: float, enabled: bool | None = None):
if enabled is None:
enabled = RMSNorm.enabled()
super().__init__(enabled)
self.epsilon = epsilon
def inputs(self):
input = self.empty(5, 16) if self.enabled else self.empty_f32(5, 16)
weight = self.empty(16)
return [input, weight]
def forward_custom(
self,
input: torch.Tensor,
weight: torch.Tensor,
) -> torch.Tensor:
result = torch.empty_like(input)
_, result = auto_functionalized(
RMS_OP,
result=result,
input=input,
weight=weight,
epsilon=self.epsilon,
)
return result
def forward_native(
self,
input: torch.Tensor,
weight: torch.Tensor,
) -> torch.Tensor:
return RMSNorm.forward_static(
input, self.epsilon, input.size(-1), self.model_dtype, weight
)
class MatcherFusedAddRMSNorm(MatcherCustomOp):
def __init__(self, epsilon: float, enabled: bool | None = None):
if enabled is None:
enabled = RMSNorm.enabled()
super().__init__(enabled)
self.epsilon = epsilon
def inputs(self):
input = self.empty(5, 16) if self.enabled else self.empty_f32(5, 16)
weight = self.empty(16)
residual = self.empty(5, 16)
return [input, weight, residual]
def forward_custom(
self,
input: torch.Tensor,
weight: torch.Tensor,
residual: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]:
_, result, residual = auto_functionalized(
RMS_ADD_OP,
input=input,
residual=residual,
weight=weight,
epsilon=self.epsilon,
)
return result, residual
def forward_native(
self,
input: torch.Tensor,
weight: torch.Tensor,
residual: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]:
return RMSNorm.forward_static(
input, self.epsilon, input.size(-1), self.model_dtype, weight, residual
)
class MatcherQuantFP8(MatcherCustomOp):
def __init__(self, quant_key: QuantKey, enabled: bool | None = None):
if enabled is None:
enabled = QuantFP8.enabled()
super().__init__(enabled)
self.quant_key = quant_key
assert quant_key in QUANT_OPS, f"unsupported quantization scheme {quant_key}"
self.QUANT_OP = QUANT_OPS[quant_key]
assert quant_key.dtype == current_platform.fp8_dtype(), (
"Only QuantFP8 supported by"
)
assert quant_key.scale2 is None
self.quant_fp8 = QuantFP8(quant_key.scale.static, quant_key.scale.group_shape)
def forward_custom(
self,
input: torch.Tensor,
scale: torch.Tensor | None = None,
) -> tuple[torch.Tensor, torch.Tensor]:
result = torch.empty(
input.shape, device=input.device, dtype=self.quant_key.dtype
)
if self.quant_key.scale.static:
assert scale is not None
_, result = auto_functionalized(
self.QUANT_OP, result=result, input=input, scale=scale
)
return result, scale
else:
assert scale is None
scale = self.make_scale(input)
_, result, scale = auto_functionalized(
self.QUANT_OP, result=result, input=input, scale=scale, scale_ub=None
)
return result, scale
def forward_native(
self,
input: torch.Tensor,
scale: torch.Tensor | None = None,
) -> tuple[torch.Tensor, torch.Tensor]:
return self.quant_fp8(input, scale)
def make_scale(self, input: torch.Tensor):
normalized_group_shape = _normalize_quant_group_shape(
input, self.quant_key.scale.group_shape
)
scale_shape = (
input.shape[0] // normalized_group_shape[0],
input.shape[1] // normalized_group_shape[1],
)
return torch.empty(scale_shape, device=input.device, dtype=torch.float32)
def inputs(self) -> list[torch.Tensor]:
input = self.empty(5, 16)
if self.quant_key.scale.static:
return [input, self.empty_f32(1, 1)]
return [input]
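For intuition on the scale shapes make_scale produces, here is a hedged sketch that mirrors its arithmetic with illustrative group shapes (it deliberately avoids the internal _normalize_quant_group_shape helper):
import torch
def scale_shape(input_shape, normalized_group_shape):
    # one scale element per quantization group of the input
    return (input_shape[0] // normalized_group_shape[0],
            input_shape[1] // normalized_group_shape[1])
x = torch.empty(5, 16)
print(scale_shape(x.shape, (5, 16)))  # per-tensor group -> (1, 1)
print(scale_shape(x.shape, (1, 16)))  # per-token group  -> (5, 1)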

View File

@ -22,6 +22,7 @@ def start_monitoring_torch_compile(vllm_config: VllmConfig):
import depyf
path.mkdir(parents=True, exist_ok=True)
logger.debug("Dumping depyf output to %s", path)
global context_manager
context_manager = depyf.prepare_debug(path.as_posix())
context_manager.__enter__()

View File

@ -5,7 +5,7 @@ import functools
from torch import fx as fx
from vllm import envs
from vllm.config import VllmConfig
from vllm.config import VllmConfig, set_current_vllm_config
from vllm.logger import init_logger
from vllm.platforms import current_platform
from vllm.utils import set_env_var
@ -88,27 +88,30 @@ class PostGradPassManager(CustomGraphPass):
def configure(self, config: VllmConfig):
self.pass_config = config.compilation_config.pass_config
if self.pass_config.enable_noop:
self.passes += [NoOpEliminationPass(config)]
if self.pass_config.enable_sequence_parallelism:
self.passes += [SequenceParallelismPass(config)]
if self.pass_config.enable_async_tp:
self.passes += [AsyncTPPass(config)]
# Set the current vllm config to allow tracing CustomOp instances
with set_current_vllm_config(config, check_compile=False):
if self.pass_config.enable_noop:
self.passes += [NoOpEliminationPass(config)]
if self.pass_config.enable_fi_allreduce_fusion:
self.passes += [AllReduceFusionPass(config)]
if self.pass_config.enable_sequence_parallelism:
self.passes += [SequenceParallelismPass(config)]
if self.pass_config.enable_async_tp:
self.passes += [AsyncTPPass(config)]
if self.pass_config.enable_fusion:
self.passes += [RMSNormQuantFusionPass(config)]
self.passes += [ActivationQuantFusionPass(config)]
if self.pass_config.enable_fi_allreduce_fusion:
self.passes += [AllReduceFusionPass(config)]
if self.pass_config.enable_attn_fusion:
self.passes += [AttnFusionPass(config)]
if self.pass_config.enable_fusion:
self.passes += [RMSNormQuantFusionPass(config)]
self.passes += [ActivationQuantFusionPass(config)]
# needs a functional graph
self.post_cleanup = PostCleanupPass(config)
self.fix_functionalization = FixFunctionalizationPass(config)
if self.pass_config.enable_attn_fusion:
self.passes += [AttnFusionPass(config)]
# needs a functional graph
self.post_cleanup = PostCleanupPass(config)
self.fix_functionalization = FixFunctionalizationPass(config)
# [HACK: Bug with Inductor graph partition and torch.compile cache]
# In PyTorch 2.9, torch.compile has a bug where the graph

View File

@ -128,7 +128,8 @@ class VllmPatternMatcherPass(VllmInductorPass):
f" please add to dump_patterns if there are any errors.\n\n"
f"from torch._higher_order_ops.auto_functionalize import "
f"auto_functionalized as auto_functionalized\n"
f"from torch._inductor.pattern_matcher import *",
f"from torch._inductor.pattern_matcher import *\n"
f"vllm = torch.ops.vllm",
file=f,
)

View File

@ -1403,8 +1403,15 @@ class EngineArgs:
"data_parallel_size_local must be set to use data_parallel_hybrid_lb."
)
# Local DP size defaults to global DP size if not set.
data_parallel_size_local = self.data_parallel_size
if self.data_parallel_backend == "ray" and (
envs.VLLM_RAY_DP_PACK_STRATEGY == "span"
):
# Local DP size defaults to 1 when a single DP rank spans
# multiple nodes
data_parallel_size_local = 1
else:
# Otherwise local DP size defaults to global DP size if not set
data_parallel_size_local = self.data_parallel_size
# DP address, used in multi-node case for torch distributed group
# and ZMQ sockets.

View File

@ -110,7 +110,7 @@ class EngineClient(ABC):
@abstractmethod
async def stop_profile(self) -> None:
"""Start profiling the engine"""
"""Stop profiling the engine"""
...
@abstractmethod

View File

@ -139,7 +139,7 @@ if TYPE_CHECKING:
VLLM_DP_MASTER_PORT: int = 0
VLLM_MOE_DP_CHUNK_SIZE: int = 256
VLLM_RANDOMIZE_DP_DUMMY_INPUTS: bool = False
VLLM_RAY_DP_PACK_STRATEGY: str = "strict"
VLLM_RAY_DP_PACK_STRATEGY: Literal["strict", "fill", "span"] = "strict"
VLLM_MARLIN_USE_ATOMIC_ADD: bool = False
VLLM_MXFP4_USE_MARLIN: bool | None = None
VLLM_V0_USE_OUTLINES_CACHE: bool = False
@ -1039,6 +1039,9 @@ environment_variables: dict[str, Callable[[], Any]] = {
# for non-master nodes, allocate as many DP ranks as can fit;
# - "strict":
# allocate exactly data-parallel-size-local DP ranks to each picked node;
# - "span":
# Should be used only when a single DP rank requires multiple nodes.
# allocate one DP rank over as many nodes as required for set world_size;
# This environment variable is ignored if data-parallel-backend is not Ray.
"VLLM_RAY_DP_PACK_STRATEGY": lambda: os.getenv(
"VLLM_RAY_DP_PACK_STRATEGY", "strict"

View File

@ -0,0 +1,164 @@
{
"1": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 16,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 2,
"waves_per_eu": 0
},
"2": {
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 16,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 2,
"waves_per_eu": 0
},
"4": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 32,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 1,
"num_warps": 2,
"num_stages": 2,
"waves_per_eu": 0
},
"8": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 2,
"waves_per_eu": 0
},
"16": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 8,
"num_warps": 4,
"num_stages": 2,
"waves_per_eu": 0
},
"24": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 2,
"waves_per_eu": 0
},
"32": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 32,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 2,
"num_stages": 2,
"waves_per_eu": 0
},
"48": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 8,
"num_warps": 4,
"num_stages": 2,
"waves_per_eu": 0
},
"64": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 8,
"num_warps": 4,
"num_stages": 2,
"waves_per_eu": 0
},
"96": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 8,
"num_warps": 4,
"num_stages": 2,
"waves_per_eu": 0
},
"128": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 8,
"num_warps": 4,
"num_stages": 2,
"waves_per_eu": 0
},
"256": {
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 8,
"num_warps": 4,
"num_stages": 2,
"waves_per_eu": 0
},
"512": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 8,
"num_stages": 2,
"waves_per_eu": 0
},
"1024": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 1,
"num_warps": 8,
"num_stages": 2,
"waves_per_eu": 0
},
"1536": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 8,
"num_warps": 8,
"num_stages": 2,
"waves_per_eu": 0
},
"2048": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 4,
"num_warps": 8,
"num_stages": 2,
"waves_per_eu": 0
},
"3072": {
"BLOCK_SIZE_M": 256,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 32,
"num_warps": 8,
"num_stages": 2,
"waves_per_eu": 0
},
"4096": {
"BLOCK_SIZE_M": 256,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 4,
"num_warps": 8,
"num_stages": 2,
"waves_per_eu": 0
}
}

View File

@ -0,0 +1,164 @@
{
"1": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 16,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 2,
"waves_per_eu": 0
},
"2": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 32,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 1,
"num_warps": 2,
"num_stages": 2,
"waves_per_eu": 0
},
"4": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 2,
"waves_per_eu": 0
},
"8": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 2,
"waves_per_eu": 0
},
"16": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 32,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 4,
"num_warps": 1,
"num_stages": 2,
"waves_per_eu": 0
},
"24": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 32,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 2,
"num_stages": 2,
"waves_per_eu": 0
},
"32": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 32,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 2,
"num_stages": 2,
"waves_per_eu": 0
},
"48": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 4,
"num_warps": 2,
"num_stages": 2,
"waves_per_eu": 0
},
"64": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 8,
"num_stages": 2,
"waves_per_eu": 0
},
"96": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 8,
"num_stages": 2,
"waves_per_eu": 0
},
"128": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 32,
"num_warps": 8,
"num_stages": 2,
"waves_per_eu": 0
},
"256": {
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 8,
"num_warps": 4,
"num_stages": 2,
"waves_per_eu": 0
},
"512": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 8,
"num_warps": 8,
"num_stages": 2,
"waves_per_eu": 0
},
"1024": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 8,
"num_stages": 2,
"waves_per_eu": 0
},
"1536": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 8,
"num_warps": 8,
"num_stages": 2,
"waves_per_eu": 0
},
"2048": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 8,
"num_warps": 8,
"num_stages": 2,
"waves_per_eu": 0
},
"3072": {
"BLOCK_SIZE_M": 256,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 32,
"num_warps": 8,
"num_stages": 2,
"waves_per_eu": 0
},
"4096": {
"BLOCK_SIZE_M": 256,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 8,
"num_warps": 8,
"num_stages": 2,
"waves_per_eu": 0
}
}

View File

@ -0,0 +1,164 @@
{
"1": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 16,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 2,
"waves_per_eu": 0
},
"2": {
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 16,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 2,
"waves_per_eu": 0
},
"4": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 32,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 1,
"num_warps": 2,
"num_stages": 2,
"waves_per_eu": 0
},
"8": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 2,
"waves_per_eu": 0
},
"16": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 32,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 2,
"num_stages": 2,
"waves_per_eu": 0
},
"24": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 8,
"num_stages": 2,
"waves_per_eu": 0
},
"32": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 32,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 1,
"num_stages": 2,
"waves_per_eu": 0
},
"48": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 1,
"num_stages": 2,
"waves_per_eu": 0
},
"64": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 32,
"num_warps": 8,
"num_stages": 2,
"waves_per_eu": 0
},
"96": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 8,
"num_warps": 8,
"num_stages": 2,
"waves_per_eu": 0
},
"128": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 8,
"num_warps": 8,
"num_stages": 2,
"waves_per_eu": 0
},
"256": {
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 8,
"num_warps": 8,
"num_stages": 2,
"waves_per_eu": 0
},
"512": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 16,
"num_warps": 8,
"num_stages": 2,
"waves_per_eu": 0
},
"1024": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 8,
"num_stages": 2,
"waves_per_eu": 0
},
"1536": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 8,
"num_warps": 8,
"num_stages": 2,
"waves_per_eu": 0
},
"2048": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 4,
"num_warps": 8,
"num_stages": 2,
"waves_per_eu": 0
},
"3072": {
"BLOCK_SIZE_M": 256,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 8,
"num_warps": 8,
"num_stages": 2,
"waves_per_eu": 0
},
"4096": {
"BLOCK_SIZE_M": 256,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 16,
"num_warps": 8,
"num_stages": 2,
"waves_per_eu": 0
}
}

View File

@ -46,6 +46,11 @@ def is_rocm_aiter_moe_enabled() -> bool:
)
@cache
def use_mxfp4_aiter_moe() -> bool:
return current_platform.is_rocm() and envs.VLLM_ROCM_USE_AITER
@cache
def is_rocm_aiter_fusion_shared_expert_enabled() -> bool:
return (
@ -487,6 +492,8 @@ def rocm_aiter_fused_experts(
assert quant_config.w1_scale is not None
assert quant_config.w2_scale is not None
quant_method = QuantMethod.BLOCK_128x128.value
elif quant_config.use_fp8_w8a8 and quant_config.per_out_ch_quant:
quant_method = QuantMethod.PER_TOKEN.value
elif quant_config.use_fp8_w8a8:
# Currently only per tensor quantization method is enabled.
quant_method = QuantMethod.PER_TENSOR.value

View File

@ -58,22 +58,6 @@ def fused_add_rms_norm(
return x, residual
def poly_norm(
x: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor, variance_epsilon: float
) -> torch.Tensor:
from vllm import _custom_ops as ops
out = torch.empty_like(x)
ops.poly_norm(
out,
x,
weight,
bias,
variance_epsilon,
)
return out
def rocm_aiter_rms_norm_impl(
x: torch.Tensor, weight: torch.Tensor, variance_epsilon: float
) -> torch.Tensor:
@ -178,14 +162,11 @@ class RMSNorm(CustomOp):
self.variance_size_override = (
None if var_hidden_size == hidden_size else var_hidden_size
)
weight_dtype = dtype or torch.get_default_dtype()
self.has_weight = has_weight
if dtype is not None:
self.weight = torch.ones(hidden_size, dtype=dtype)
else:
self.weight = torch.ones(hidden_size)
self.weight = torch.ones(hidden_size, dtype=weight_dtype)
if self.has_weight:
self.weight = nn.Parameter(self.weight)
weight_dtype = self.weight.data.dtype
if current_platform.is_rocm():
self.rocm_norm_func = dispatch_rocm_rmsnorm_func(
@ -195,46 +176,68 @@ class RMSNorm(CustomOp):
with_fused_add=True, dtype=weight_dtype
)
@staticmethod
def forward_static(
x: torch.Tensor,
variance_epsilon: float,
hidden_size: int,
orig_dtype: torch.dtype,
weight: torch.Tensor | None = None,
residual: torch.Tensor | None = None,
variance_size_override: int | None = None,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
"""PyTorch-native implementation equivalent to forward()."""
x = x.to(torch.float32)
if residual is not None:
# residual promoted f16->f32 automatically,
# otherwise Inductor eliminates the casts to and from f16,
# increasing memory usage (and complicating pattern matching)
x = x + residual
residual = x.to(orig_dtype)
if x.shape[-1] != hidden_size:
raise ValueError(
f"Expected hidden_size to be {hidden_size}, but found: {x.shape[-1]}"
)
if variance_size_override is None:
x_var = x
else:
if hidden_size < variance_size_override:
raise ValueError(
"Expected hidden_size to be at least "
f"{variance_size_override}, but found: {hidden_size}"
)
x_var = x[:, :, :variance_size_override]
variance = x_var.pow(2).mean(dim=-1, keepdim=True)
x = x * torch.rsqrt(variance + variance_epsilon)
x = x.to(orig_dtype)
if weight is not None:
x = x * weight
if residual is None:
return x
else:
return x, residual
def forward_native(
self,
x: torch.Tensor,
residual: torch.Tensor | None = None,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
"""PyTorch-native implementation equivalent to forward()."""
orig_dtype = x.dtype
x = x.to(torch.float32)
if residual is not None:
x = x + residual.to(torch.float32)
residual = x.to(orig_dtype)
hidden_size = x.shape[-1]
if hidden_size != self.hidden_size:
raise ValueError(
"Expected hidden_size to be "
f"{self.hidden_size}, but found: {hidden_size}"
)
if self.variance_size_override is None:
x_var = x
else:
if hidden_size < self.variance_size_override:
raise ValueError(
"Expected hidden_size to be at least "
f"{self.variance_size_override}, but found: {hidden_size}"
)
x_var = x[:, :, : self.variance_size_override]
variance = x_var.pow(2).mean(dim=-1, keepdim=True)
x = x * torch.rsqrt(variance + self.variance_epsilon)
x = x.to(orig_dtype)
if self.has_weight:
x = x * self.weight
if residual is None:
return x
else:
return x, residual
return self.forward_static(
x,
self.variance_epsilon,
self.hidden_size,
x.dtype,
self.weight.data if self.has_weight else None,
residual,
self.variance_size_override,
)
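A quick numeric sanity check of the formula forward_static implements, y = x * rsqrt(mean(x^2) + eps) * weight (values illustrative):
import torch
x = torch.tensor([[1.0, 2.0, 3.0, 4.0]])
weight = torch.ones(4)
eps = 1e-6
variance = x.pow(2).mean(dim=-1, keepdim=True)  # mean(x^2) = 7.5
y = x * torch.rsqrt(variance + eps) * weight    # ~[[0.3651, 0.7303, 1.0954, 1.4606]]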
def forward_cuda(
self,
@ -366,53 +369,6 @@ class GemmaRMSNorm(CustomOp):
return self.forward_native(x, residual)
@CustomOp.register("poly_norm")
class PolyNorm(CustomOp):
"""Polynomial normalization.
Computes x -> w_0 * RMSNorm(x^3) + w_1 * RMSNorm(x^2) + w_2 * RMSNorm(x) + b
where w_n is the learned weight and b is the bias.
Refer to https://arxiv.org/html/2411.03884v1
"""
def __init__(
self,
eps: float = 1e-6,
) -> None:
super().__init__()
self.weight = torch.nn.Parameter(torch.ones(3) / 3)
self.bias = torch.nn.Parameter(torch.zeros(1))
self.variance_epsilon = eps
def _norm(self, x):
return x / torch.sqrt(x.pow(2).mean(-1, keepdim=True) + self.variance_epsilon)
def forward_native(
self,
x: torch.Tensor,
) -> torch.Tensor:
"""PyTorch-native implementation equivalent to forward().
Refer to https://github.com/BryceZhuo/PolyCom?tab=readme-ov-file/README.md
"""
orig_dtype = x.dtype
x_float = x.to(torch.float32)
output = (
self.weight[0] * self._norm(x_float**3)
+ self.weight[1] * self._norm(x_float**2)
+ self.weight[2] * self._norm(x_float)
+ self.bias
)
return output.to(orig_dtype)
def forward_cuda(
self,
x: torch.Tensor,
) -> torch.Tensor:
return poly_norm(x, self.weight, self.bias, self.variance_epsilon)
class LayerNorm(nn.Module):
"""
Layer Normalization.

View File

@ -23,6 +23,7 @@ from vllm.model_executor.layers.fused_moe.config import (
from vllm.model_executor.layers.fused_moe.fused_marlin_moe import fused_marlin_moe
from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (
is_rocm_aiter_moe_enabled,
use_mxfp4_aiter_moe,
)
from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import (
prepare_moe_fp8_layer_for_marlin,
@ -341,7 +342,8 @@ class QuarkW8A8Fp8MoEMethod(QuarkMoEMethod):
w2_scale=layer.w2_weight_scale,
a1_scale=layer.w13_input_scale,
a2_scale=layer.w2_input_scale,
per_act_token_quant=self.weight_qscheme == "per_channel",
per_act_token_quant=self.input_qscheme == "per_channel",
per_out_ch_quant=self.weight_qscheme == "per_channel",
)
def apply(
@ -472,22 +474,22 @@ class QuarkOCP_MX_MoEMethod(QuarkMoEMethod):
"not implemented. Please open an issue."
)
if not current_platform.supports_mx():
self.emulate = True
self.emulate = not current_platform.supports_mx() or not (
use_mxfp4_aiter_moe() and self.ocp_mx_scheme == "w_mxfp4_a_mxfp4"
)
if self.emulate:
logger.warning_once(
"The current platform does not support native MXFP4/MXFP6 "
f"The current mode (supports_mx={current_platform.supports_mx()}, "
f"use_mxfp4_aiter_moe={use_mxfp4_aiter_moe()}, "
f"ocp_mx_scheme={self.ocp_mx_scheme}) "
"does not support native MXFP4/MXFP6 "
"computation. Simulated weight dequantization and activation "
"QDQ (quantize and dequantize) will be used, with the linear "
"layers computed in high precision."
)
else:
self.emulate = True
logger.warning_once(
"The current platform supports native MXFP4/MXFP6 "
"computation, but kernels are not yet integrated in vLLM. "
"Simulated weight dequantization and activation "
"QDQ (quantize and dequantize) will be used, with the linear "
"layers computed in high precision."
"The current mode supports native MoE MXFP4 computation"
)
def get_packed_dim(self, dim: int, quant_dtype: str):
@ -568,6 +570,24 @@ class QuarkOCP_MX_MoEMethod(QuarkMoEMethod):
layer.register_parameter("w13_weight_scale", w13_weight_scale)
layer.register_parameter("w2_weight_scale", w2_weight_scale)
def process_weights_after_loading(self, layer):
if self.emulate:
return
from aiter.utility.fp4_utils import e8m0_shuffle
# Pre-shuffle weight scales
s0, s1, _ = layer.w13_weight_scale.shape
w13_weight_scale = layer.w13_weight_scale.view(s0 * s1, -1)
w13_weight_scale = e8m0_shuffle(w13_weight_scale)
layer.w13_weight_scale.data = w13_weight_scale.view(s0, s1, -1)
s0, s1, _ = layer.w2_weight_scale.shape
w2_weight_scale = layer.w2_weight_scale.view(s0 * s1, -1)
w2_weight_scale = e8m0_shuffle(w2_weight_scale)
layer.w2_weight_scale.data = w2_weight_scale.view(s0, s1, -1)
torch.cuda.empty_cache()
def get_fused_moe_quant_config(
self, layer: torch.nn.Module
) -> FusedMoEQuantConfig | None:
@ -611,8 +631,6 @@ class QuarkOCP_MX_MoEMethod(QuarkMoEMethod):
"EPLB not supported for `QuarkOCP_MX_MoEMethod` yet."
)
from vllm.model_executor.layers.fused_moe import fused_experts
topk_weights, topk_ids, _ = FusedMoE.select_experts(
hidden_states=x,
router_logits=router_logits,
@ -628,17 +646,44 @@ class QuarkOCP_MX_MoEMethod(QuarkMoEMethod):
indices_type=self.topk_indices_dtype,
)
out = fused_experts(
x,
layer.w13_weight,
layer.w2_weight,
topk_weights=topk_weights,
topk_ids=topk_ids,
inplace=True,
activation=activation,
global_num_experts=global_num_experts,
apply_router_weight_on_input=apply_router_weight_on_input,
expert_map=expert_map,
quant_config=self.moe_quant_config,
)
if not self.emulate:
from aiter import ActivationType, QuantType
from aiter.fused_moe import fused_moe
aiter_acts = {
ActivationType.No.name.lower(): ActivationType.No,
ActivationType.Silu.name.lower(): ActivationType.Silu,
ActivationType.Gelu.name.lower(): ActivationType.Gelu,
}
assert activation in aiter_acts, (
f"Aiter CK fp4 MoE doesn't support activation {activation}"
)
out = fused_moe(
x,
layer.w13_weight,
layer.w2_weight,
topk_weights,
topk_ids,
quant_type=QuantType.per_1x32,
w1_scale=layer.w13_weight_scale,
w2_scale=layer.w2_weight_scale,
activation=aiter_acts[activation],
doweight_stage1=False,
)
else:
from vllm.model_executor.layers.fused_moe import fused_experts
out = fused_experts(
x,
layer.w13_weight,
layer.w2_weight,
topk_weights=topk_weights,
topk_ids=topk_ids,
inplace=True,
activation=activation,
global_num_experts=global_num_experts,
apply_router_weight_on_input=apply_router_weight_on_input,
expert_map=expert_map,
quant_config=self.moe_quant_config,
)
return out

View File

@ -464,8 +464,16 @@ class Fp8LinearOp:
else:
qinput, x_scale = input_2d, input_scale
per_tensor_weights = weight_scale.numel() == 1
per_tensor_activations = x_scale.numel() == 1
# The dim() conditions are required:
# in the per-token quant scenario, when the number of tokens is 1,
# the scale has only one element.
# Without checking dim(),
# we cannot distinguish between per-tensor and per-token quant.
# Example:
# when the number of tokens is 1, the per-token scale is [[1]],
# whereas a per-tensor scale is [1] or ().
per_tensor_weights = (weight_scale.numel() == 1) and weight_scale.dim() < 2
per_tensor_activations = (x_scale.numel() == 1) and x_scale.dim() < 2
# TODO(luka) do this dispatch during init (after ScaledMM refactor)
w8a8_scaled_mm_func = dispatch_w8a8_scaled_mm(
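A minimal standalone sketch (illustrative, not part of this diff) of why the dim() check above is needed to tell per-tensor from per-token FP8 scales when both have a single element:
import torch
per_token_scale = torch.tensor([[0.5]])  # 1 token -> shape (1, 1), numel() == 1
per_tensor_scale = torch.tensor(0.5)     # scalar  -> shape (),     numel() == 1
def is_per_tensor(scale: torch.Tensor) -> bool:
    # numel() alone cannot distinguish the two cases; dim() can
    return scale.numel() == 1 and scale.dim() < 2
assert not is_per_tensor(per_token_scale)  # classified as per-token
assert is_per_tensor(per_tensor_scale)     # classified as per-tensor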

View File

@ -735,9 +735,9 @@ class Qwen3VLProcessingInfo(Qwen2VLProcessingInfo):
if do_sample_frames:
# here video_fps is the fps of the sampled video, and
# metadata["fps"] refers to the fps of the original video.
video_fps = sampled_fps if sampled_fps else video_processor.fps
sampled_fps = sampled_fps if sampled_fps else video_processor.fps
total_num_frames = metadata["total_num_frames"]
num_frames = int(total_num_frames / metadata["fps"] * video_fps)
num_frames = int(total_num_frames / metadata["fps"] * sampled_fps)
num_frames = min(
min(
max(num_frames, video_processor.min_frames),
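A worked example of the corrected frame-count formula (numbers illustrative): a 10-second clip recorded at 30 fps and resampled at 2 fps should yield 20 frames, before clamping to the processor's min/max frame limits.
total_num_frames = 300   # frames in the original video
original_fps = 30        # metadata["fps"]
sampled_fps = 2          # fps requested for sampling
num_frames = int(total_num_frames / original_fps * sampled_fps)  # -> 20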

View File

@ -350,6 +350,14 @@ class Qwen3MoeLLMForCausalLM(Qwen3MoeForCausalLM):
dummy_inputs=Qwen3VLDummyInputsBuilder,
)
class Qwen3VLMoeForConditionalGeneration(Qwen3VLForConditionalGeneration):
packed_modules_mapping = {
"qkv_proj": [
"q_proj",
"k_proj",
"v_proj",
],
}
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super(Qwen3VLForConditionalGeneration, self).__init__()
config: Qwen3VLMoeConfig = vllm_config.model_config.hf_config
@ -376,6 +384,11 @@ class Qwen3VLMoeForConditionalGeneration(Qwen3VLForConditionalGeneration):
self.language_model = Qwen3MoeLLMForCausalLM(
vllm_config=vllm_config, prefix=maybe_prefix(prefix, "language_model")
)
# Whether to include the gate_up_proj mapping is determined by
# the language model.
self.packed_modules_mapping = (
self.packed_modules_mapping | self.language_model.packed_modules_mapping
)
self.make_empty_intermediate_tensors = (
self.language_model.make_empty_intermediate_tensors

View File

@ -345,6 +345,7 @@ class CoreEngineActorManager:
world_size = vllm_config.parallel_config.world_size
placement_groups: list[PlacementGroup] = []
local_dp_ranks: list[int] = []
dp_master_ip_key = f"node:{dp_master_ip}"
nodes = sorted(
available_resources.values(), key=lambda x: dp_master_ip_key not in x
@ -355,9 +356,25 @@ class CoreEngineActorManager:
dp_master_ip,
)
device_str = current_platform.ray_device_key
n_node_devices: list[int] = [
int(node_resources[device_str])
for node_resources in nodes
if device_str in node_resources
]
assert n_node_devices, f"No {device_str} found in Ray cluster."
max_device_per_node = max(n_node_devices)
pack_strategy = envs.VLLM_RAY_DP_PACK_STRATEGY
_supported_pack_strategies = ("strict", "fill", "span")
if pack_strategy not in _supported_pack_strategies:
raise ValueError(
f"{envs.VLLM_RAY_DP_PACK_STRATEGY} is not supported. "
"Make sure to set `VLLM_RAY_DP_PACK_STRATEGY` "
f"to one of {_supported_pack_strategies}"
)
all2all_backend = vllm_config.parallel_config.all2all_backend
if envs.VLLM_RAY_DP_PACK_STRATEGY == "fill" and (
if pack_strategy == "fill" and (
all2all_backend == "deepep_high_throughput"
or all2all_backend == "deepep_low_latency"
):
@ -367,12 +384,42 @@ class CoreEngineActorManager:
"does not guarantee that. "
"Please use VLLM_RAY_DP_PACK_STRATEGY=strict instead."
)
logger.info(
"Using '%s' DP packing strategy based on VLLM_RAY_DP_PACK_STRATEGY",
envs.VLLM_RAY_DP_PACK_STRATEGY,
)
strict_local_size = envs.VLLM_RAY_DP_PACK_STRATEGY == "strict"
if pack_strategy in ("strict", "fill"):
placement_strategy = "STRICT_PACK"
else:
placement_strategy = "PACK"
assert world_size > max_device_per_node, (
f"World size {world_size} is smaller than the "
"maximum number of devices per node "
f"{max_device_per_node}. Make sure to set "
"`VLLM_RAY_DP_PACK_STRATEGY` to `strict` or `fill`"
)
# if we need multiple nodes per dp group, we require for now that
# available nodes are homogeneous
assert set(n_node_devices) == {max_device_per_node}, (
f"Nodes are not homogeneous: {nodes}"
)
assert world_size % max_device_per_node == 0, (
f"For multi-node data parallel groups, world_size ({world_size}) must "
f"be a multiple of number of devices per node ({max_device_per_node})."
)
assert len(n_node_devices) * max_device_per_node >= world_size * dp_size, (
f"Not enough total available nodes ({len(n_node_devices)}) "
f"and devices per node ({max_device_per_node}) "
f"to satisfy required world size {world_size} and data parallel size "
f"{dp_size}"
)
assert dp_size_local == 1, (
f"data-parallel-size-local {dp_size_local} should be set as the "
"default (1) for VLLM_RAY_DP_PACK_STRATEGY=span. "
"The actual data-parallel-size-local will be auto determined."
)
# bundles collected for a single DP rank from multiple nodes,
# for "span" pack strategy
collected_bundles = []
for node_resources in nodes:
node_ip_keys = [
key
@ -386,14 +433,14 @@ class CoreEngineActorManager:
node_ip_key = node_ip_keys[0]
node_ip = node_ip_key.split(":")[1]
# For now, each DP rank can only be assigned to one node
# TODO(rui): support allocating a single DP rank
# to multiple nodes
dp_size_available = (
int(node_resources[device_str]) // world_size
if device_str in node_resources
else 0
)
n_device_on_node = int(node_resources.get(device_str, 0))
if pack_strategy == "span" and n_device_on_node != 0:
# Strictly speaking,
# dp_size_available = n_device_on_node / world_size
# and is a fraction, but we use 1 for easier processing
dp_size_available = 1
else:
dp_size_available = n_device_on_node // world_size
if node_ip == dp_master_ip:
if dp_size_available < dp_size_local:
@ -405,7 +452,7 @@ class CoreEngineActorManager:
dp_size_available,
)
dp_size_to_allocate = dp_size_local
elif strict_local_size:
elif pack_strategy == "strict":
if dp_size_available < dp_size_local:
logger.info(
"Skipping node %s as %s DP ranks could not fit, "
@ -417,15 +464,31 @@ class CoreEngineActorManager:
continue
dp_size_to_allocate = dp_size_local
else:
# for "pack_strategy" in "fill" and "span"
# we always take everything that's available
dp_size_to_allocate = dp_size_available
for i in range(dp_size_to_allocate):
bundles = [{device_str: 1.0, "node:" + node_ip: 0.001}] * world_size + [
{"CPU": 1.0}
]
device_bundle = [{device_str: 1.0, "node:" + node_ip: 0.001}]
if pack_strategy == "span":
collected_bundles += device_bundle * n_device_on_node
assert len(collected_bundles) <= world_size, (
"collected_bundles should be <= world_size, "
f"but got {len(collected_bundles)=} and {world_size=}"
)
# we only create a placement group if we collected enough devices
if len(collected_bundles) < world_size:
continue
bundles = collected_bundles + [{"CPU": 1.0}]
collected_bundles = []
else:
bundles = device_bundle * world_size + [{"CPU": 1.0}]
pg = ray.util.placement_group(
name=f"dp_rank_{len(placement_groups)}",
strategy="STRICT_PACK",
strategy=placement_strategy,
bundles=bundles,
)
placement_groups.append(pg)
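A back-of-the-envelope sketch of the "span" accumulation above (numbers illustrative): with homogeneous 8-device nodes, a per-rank world size of 16, and dp_size=2, a placement group is emitted after every second node and four nodes are needed in total.
devices_per_node = 8
world_size = 16
dp_size = 2
nodes_per_dp_rank = world_size // devices_per_node   # -> 2
total_nodes_needed = dp_size * nodes_per_dp_rank     # -> 4
bundles_per_placement_group = world_size + 1         # 16 device bundles + 1 CPU bundle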