mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-08 23:15:01 +08:00
Compare commits
14 Commits
87b5dd8c2b
...
ba77a1ab74
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
ba77a1ab74 | ||
|
|
3125d79950 | ||
|
|
e33ee23ee3 | ||
|
|
b10c64c834 | ||
|
|
0925b28a8e | ||
|
|
99722d5f0e | ||
|
|
4c91a28e30 | ||
|
|
b038d9c40c | ||
|
|
2ba60ec7fe | ||
|
|
bd7157a071 | ||
|
|
be429d0cfd | ||
|
|
c253745eb8 | ||
|
|
daec4d2624 | ||
|
|
a2d5ef088a |
@ -416,8 +416,8 @@ steps:
|
||||
- pytest -v -s compile/test_basic_correctness.py
|
||||
- pytest -v -s compile/piecewise/
|
||||
|
||||
- label: PyTorch Fullgraph Test # 20min
|
||||
timeout_in_minutes: 30
|
||||
- label: PyTorch Fullgraph Test # 22min
|
||||
timeout_in_minutes: 35
|
||||
mirror_hardwares: [amdexperimental]
|
||||
torch_nightly: true
|
||||
source_file_dependencies:
|
||||
@ -425,6 +425,7 @@ steps:
|
||||
- tests/compile
|
||||
commands:
|
||||
- pytest -v -s compile/test_full_graph.py
|
||||
- pytest -v -s compile/test_fusions_e2e.py
|
||||
|
||||
- label: Kernels Core Operation Test # 48min
|
||||
timeout_in_minutes: 75
|
||||
@ -807,8 +808,8 @@ steps:
|
||||
# Whisper needs spawn method to avoid deadlock
|
||||
- VLLM_WORKER_MULTIPROC_METHOD=spawn python3 examples/offline_inference/audio_language.py --model-type whisper
|
||||
|
||||
- label: Blackwell Test # 38 min
|
||||
timeout_in_minutes: 60
|
||||
- label: Blackwell Test # 21 min
|
||||
timeout_in_minutes: 30
|
||||
working_dir: "/vllm-workspace/"
|
||||
gpu: b200
|
||||
# optional: true
|
||||
@ -821,8 +822,6 @@ steps:
|
||||
- vllm/model_executor/layers/fused_moe/flashinfer_cutlass_prepare_finalize.py
|
||||
- vllm/model_executor/layers/quantization/utils/flashinfer_utils.py
|
||||
- vllm/v1/attention/backends/flashinfer.py
|
||||
- vllm/compilation/fusion.py
|
||||
- vllm/compilation/fusion_attn.py
|
||||
commands:
|
||||
- nvidia-smi
|
||||
- python3 examples/offline_inference/basic/chat.py
|
||||
@ -839,15 +838,32 @@ steps:
|
||||
- pytest -v -s tests/kernels/quantization/test_nvfp4_scaled_mm.py
|
||||
- pytest -v -s tests/kernels/quantization/test_flashinfer_scaled_mm.py
|
||||
- pytest -v -s tests/kernels/quantization/test_flashinfer_nvfp4_scaled_mm.py
|
||||
- pytest -v -s tests/kernels/moe/test_nvfp4_moe.py
|
||||
- pytest -v -s tests/kernels/moe/test_ocp_mx_moe.py
|
||||
# Fusion
|
||||
- pytest -v -s tests/compile/test_fusion_all_reduce.py
|
||||
- pytest -v -s tests/compile/test_fusion_attn.py::test_attention_quant_pattern
|
||||
- pytest -v -s tests/kernels/moe/test_flashinfer.py
|
||||
- pytest -v -s tests/compile/test_silu_mul_quant_fusion.py
|
||||
- pytest -v -s tests/kernels/quantization/test_nvfp4_qutlass.py
|
||||
- pytest -v -s tests/kernels/quantization/test_mxfp4_qutlass.py
|
||||
- pytest -v -s tests/kernels/moe/test_nvfp4_moe.py
|
||||
- pytest -v -s tests/kernels/moe/test_ocp_mx_moe.py
|
||||
- pytest -v -s tests/kernels/moe/test_flashinfer.py
|
||||
|
||||
- label: Blackwell Fusion Tests # 30 min
|
||||
timeout_in_minutes: 40
|
||||
working_dir: "/vllm-workspace/"
|
||||
gpu: b200
|
||||
source_file_dependencies:
|
||||
- csrc/quantization/fp4/
|
||||
- vllm/model_executor/layers/quantization/utils/flashinfer_utils.py
|
||||
- vllm/v1/attention/backends/flashinfer.py
|
||||
- vllm/compilation/
|
||||
# can affect pattern matching
|
||||
- vllm/model_executor/layers/layernorm.py
|
||||
- vllm/model_executor/layers/activation.py
|
||||
- vllm/model_executor/layers/quantization/input_quant_fp8.py
|
||||
commands:
|
||||
- nvidia-smi
|
||||
- pytest -v -s tests/compile/test_fusion_attn.py
|
||||
- pytest -v -s tests/compile/test_silu_mul_quant_fusion.py
|
||||
# this runner has 2 GPUs available even though num_gpus=2 is not set
|
||||
- pytest -v -s tests/compile/test_fusion_all_reduce.py
|
||||
- pytest -v -s tests/compile/test_fusions_e2e.py
|
||||
|
||||
- label: Blackwell GPT-OSS Eval
|
||||
timeout_in_minutes: 60
|
||||
@ -1068,6 +1084,17 @@ steps:
|
||||
- tests/weight_loading
|
||||
commands:
|
||||
- bash weight_loading/run_model_weight_loading_test.sh -c weight_loading/models-large.txt
|
||||
|
||||
- label: NixlConnector PD accuracy tests (Distributed) # 30min
|
||||
timeout_in_minutes: 30
|
||||
working_dir: "/vllm-workspace/tests"
|
||||
num_gpus: 4
|
||||
source_file_dependencies:
|
||||
- vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py
|
||||
- tests/v1/kv_connector/nixl_integration/
|
||||
commands:
|
||||
- uv pip install --system -r /vllm-workspace/requirements/kv_connectors.txt
|
||||
- bash v1/kv_connector/nixl_integration/tp_config_sweep_accuracy_test.sh
|
||||
|
||||
|
||||
##### multi gpus test #####
|
||||
@ -1100,7 +1127,7 @@ steps:
|
||||
- pytest -s -v test_lm_eval_correctness.py --config-list-file=configs/models-large.txt --tp-size=4
|
||||
|
||||
##### H200 test #####
|
||||
- label: Distrubted Tests (H200) # optional
|
||||
- label: Distributed Tests (H200) # optional
|
||||
gpu: h200
|
||||
optional: true
|
||||
working_dir: "/vllm-workspace/"
|
||||
@ -1108,6 +1135,8 @@ steps:
|
||||
commands:
|
||||
- pytest -v -s tests/compile/test_async_tp.py
|
||||
- pytest -v -s tests/compile/test_sequence_parallelism.py
|
||||
- pytest -v -s tests/compile/test_fusion_all_reduce.py
|
||||
- pytest -v -s tests/compile/test_fusions_e2e.py::test_tp2_attn_quant_allreduce_rmsnorm
|
||||
- pytest -v -s tests/distributed/test_context_parallel.py
|
||||
- CUDA_VISIBLE_DEVICES=1,2 VLLM_ALL2ALL_BACKEND=deepep_high_throughput VLLM_USE_DEEP_GEMM=1 VLLM_LOGGING_LEVEL=DEBUG python3 examples/offline_inference/data_parallel.py --model Qwen/Qwen1.5-MoE-A2.7B --tp-size=1 --dp-size=2 --max-model-len 2048
|
||||
|
||||
|
||||
@ -1,155 +0,0 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import itertools
|
||||
|
||||
import torch
|
||||
|
||||
from vllm import _custom_ops as vllm_ops
|
||||
from vllm.triton_utils import triton
|
||||
|
||||
|
||||
def polynorm_naive(
|
||||
x: torch.Tensor,
|
||||
weight: torch.Tensor,
|
||||
bias: torch.Tensor,
|
||||
eps: float = 1e-6,
|
||||
):
|
||||
orig_shape = x.shape
|
||||
x = x.view(-1, x.shape[-1])
|
||||
|
||||
def norm(x, eps: float):
|
||||
return x / torch.sqrt(x.pow(2).mean(-1, keepdim=True) + eps)
|
||||
|
||||
x = x.float()
|
||||
return (
|
||||
(
|
||||
weight[0] * norm(x**3, eps)
|
||||
+ weight[1] * norm(x**2, eps)
|
||||
+ weight[2] * norm(x, eps)
|
||||
+ bias
|
||||
)
|
||||
.to(weight.dtype)
|
||||
.view(orig_shape)
|
||||
)
|
||||
|
||||
|
||||
def polynorm_vllm(
|
||||
x: torch.Tensor,
|
||||
weight: torch.Tensor,
|
||||
bias: torch.Tensor,
|
||||
eps: float = 1e-6,
|
||||
):
|
||||
orig_shape = x.shape
|
||||
x = x.view(-1, x.shape[-1])
|
||||
|
||||
out = torch.empty_like(x)
|
||||
vllm_ops.poly_norm(out, x, weight, bias, eps)
|
||||
output = out
|
||||
|
||||
output = output.view(orig_shape)
|
||||
return output
|
||||
|
||||
|
||||
def calculate_diff(batch_size, seq_len, hidden_dim):
|
||||
dtype = torch.bfloat16
|
||||
x = torch.randn(batch_size, seq_len, hidden_dim, dtype=dtype, device="cuda")
|
||||
weight = torch.ones(3, dtype=dtype, device="cuda")
|
||||
bias = torch.ones(1, dtype=dtype, device="cuda")
|
||||
|
||||
output_naive = polynorm_naive(x, weight, bias)
|
||||
output_vllm = polynorm_vllm(x, weight, bias)
|
||||
|
||||
if torch.allclose(output_naive, output_vllm, atol=1e-2, rtol=1e-2):
|
||||
print("✅ All implementations match")
|
||||
else:
|
||||
print("❌ Implementations differ")
|
||||
|
||||
|
||||
batch_size_range = [2**i for i in range(0, 7, 2)]
|
||||
seq_length_range = [2**i for i in range(6, 11, 1)]
|
||||
dim_range = [2048, 4096]
|
||||
configs = list(itertools.product(dim_range, batch_size_range, seq_length_range))
|
||||
|
||||
|
||||
def get_benchmark():
|
||||
@triton.testing.perf_report(
|
||||
triton.testing.Benchmark(
|
||||
x_names=["dim", "batch_size", "seq_len"],
|
||||
x_vals=[list(_) for _ in configs],
|
||||
line_arg="provider",
|
||||
line_vals=["naive", "vllm"],
|
||||
line_names=["Naive", "vLLM"],
|
||||
styles=[("blue", "-"), ("red", "-")],
|
||||
ylabel="us",
|
||||
plot_name="polynorm-perf",
|
||||
args={},
|
||||
)
|
||||
)
|
||||
def benchmark(dim, batch_size, seq_len, provider):
|
||||
dtype = torch.bfloat16
|
||||
hidden_dim = dim * 4
|
||||
|
||||
x = torch.randn(batch_size, seq_len, hidden_dim, dtype=dtype, device="cuda")
|
||||
weight = torch.ones(3, dtype=dtype, device="cuda")
|
||||
bias = torch.ones(1, dtype=dtype, device="cuda")
|
||||
|
||||
quantiles = [0.5, 0.2, 0.8]
|
||||
|
||||
if provider == "naive":
|
||||
ms, min_ms, max_ms = triton.testing.do_bench(
|
||||
lambda: polynorm_naive(x, weight, bias),
|
||||
quantiles=quantiles,
|
||||
)
|
||||
else:
|
||||
ms, min_ms, max_ms = triton.testing.do_bench(
|
||||
lambda: polynorm_vllm(x, weight, bias),
|
||||
quantiles=quantiles,
|
||||
)
|
||||
|
||||
return 1000 * ms, 1000 * max_ms, 1000 * min_ms
|
||||
|
||||
return benchmark
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import argparse
|
||||
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument(
|
||||
"--batch-size",
|
||||
type=int,
|
||||
default=4,
|
||||
help="Batch size",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--seq-len",
|
||||
type=int,
|
||||
default=128,
|
||||
help="Sequence length",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--hidden-dim",
|
||||
type=int,
|
||||
default=8192,
|
||||
help="Intermediate size of MLP",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--save-path",
|
||||
type=str,
|
||||
default="./configs/polnorm/",
|
||||
help="Path to save polnorm benchmark results",
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
# Run correctness test
|
||||
calculate_diff(
|
||||
batch_size=args.batch_size,
|
||||
seq_len=args.seq_len,
|
||||
hidden_dim=args.hidden_dim,
|
||||
)
|
||||
|
||||
benchmark = get_benchmark()
|
||||
# Run performance benchmark
|
||||
benchmark.run(print_data=True, save_path=args.save_path)
|
||||
@ -148,211 +148,6 @@ fused_add_rms_norm_kernel(
|
||||
}
|
||||
}
|
||||
|
||||
/* Function specialization in the case of FP16/BF16 tensors.
|
||||
Additional optimizations we can make in this case are
|
||||
packed and vectorized operations, which help with the
|
||||
memory latency bottleneck.
|
||||
|
||||
_f16VecPN struct extends _f16Vec to add operations specifically required for
|
||||
polynomial normalization (poly norm).
|
||||
The original _f16Vec does not include the sum-of-powers computation or
|
||||
in-place polynomial normalization logic. */
|
||||
template <typename scalar_t, int width>
|
||||
struct alignas(16) _f16VecPN : _f16Vec<scalar_t, width> {
|
||||
using Base = _f16Vec<scalar_t, width>;
|
||||
using Converter = typename Base::Converter;
|
||||
using T1 = typename Base::T1;
|
||||
using T2 = typename Base::T2;
|
||||
using Base::data;
|
||||
|
||||
__device__ auto sum_pows() const {
|
||||
float s2 = 0.0f, s4 = 0.0f, s6 = 0.0f;
|
||||
|
||||
#pragma unroll
|
||||
for (int i = 0; i < width; i += 2) {
|
||||
float2 z = Converter::convert(T2{data[i], data[i + 1]});
|
||||
float x2 = z.x * z.x;
|
||||
float x4 = x2 * x2;
|
||||
float x6 = x4 * x2;
|
||||
|
||||
float y2 = z.y * z.y;
|
||||
float y4 = y2 * y2;
|
||||
float y6 = y4 * y2;
|
||||
|
||||
s2 += x2 + y2;
|
||||
s4 += x4 + y4;
|
||||
s6 += x6 + y6;
|
||||
}
|
||||
return std::make_tuple(s2, s4, s6);
|
||||
}
|
||||
|
||||
__device__ void poly_norm_inplace(const float w2_inv_std,
|
||||
const float w1_inv_std2,
|
||||
const float w0_inv_std3, const float bias) {
|
||||
#pragma unroll
|
||||
for (int i = 0; i < width; i += 2) {
|
||||
float2 z = Converter::convert(T2{data[i], data[i + 1]});
|
||||
|
||||
float x2 = z.x * z.x;
|
||||
float x3 = x2 * z.x;
|
||||
z.x = w2_inv_std * z.x + w1_inv_std2 * x2 + w0_inv_std3 * x3 + bias;
|
||||
|
||||
float y2 = z.y * z.y;
|
||||
float y3 = y2 * z.y;
|
||||
z.y = w2_inv_std * z.y + w1_inv_std2 * y2 + w0_inv_std3 * y3 + bias;
|
||||
|
||||
auto out = Converter::convert(z);
|
||||
data[i] = out.x;
|
||||
data[i + 1] = out.y;
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
template <typename scalar_t, int width>
|
||||
__global__ std::enable_if_t<(width > 0) && _typeConvert<scalar_t>::exists>
|
||||
poly_norm_kernel(scalar_t* __restrict__ out, // [..., hidden_size]
|
||||
const scalar_t* __restrict__ input, // [..., hidden_size]
|
||||
const scalar_t* __restrict__ weight, // [3]
|
||||
const scalar_t* __restrict__ bias, // [1]
|
||||
const float epsilon, const int hidden_size) {
|
||||
// Sanity checks on our vector struct and type-punned pointer arithmetic
|
||||
static_assert(std::is_pod_v<_f16VecPN<scalar_t, width>>);
|
||||
static_assert(sizeof(_f16VecPN<scalar_t, width>) == sizeof(scalar_t) * width);
|
||||
|
||||
/* These and the argument pointers are all declared `restrict` as they are
|
||||
not aliased in practice. Argument pointers should not be dereferenced
|
||||
in this kernel as that would be undefined behavior */
|
||||
auto* __restrict__ input_v =
|
||||
reinterpret_cast<const _f16VecPN<scalar_t, width>*>(input);
|
||||
const int vec_hidden_size = hidden_size / width;
|
||||
float variance = 0.0f;
|
||||
float variance2 = 0.0f;
|
||||
float variance3 = 0.0f;
|
||||
|
||||
for (int idx = threadIdx.x; idx < vec_hidden_size; idx += blockDim.x) {
|
||||
int id = blockIdx.x * vec_hidden_size + idx;
|
||||
_f16VecPN<scalar_t, width> temp = input_v[id];
|
||||
auto [x2, x4, x6] = temp.sum_pows();
|
||||
|
||||
variance += x2;
|
||||
variance2 += x4;
|
||||
variance3 += x6;
|
||||
}
|
||||
|
||||
float3 thread_variances = make_float3(variance, variance2, variance3);
|
||||
|
||||
struct SumOp {
|
||||
__device__ float3 operator()(const float3& a, const float3& b) const {
|
||||
return make_float3(a.x + b.x, a.y + b.y, a.z + b.z);
|
||||
}
|
||||
};
|
||||
|
||||
using BlockReduce = cub::BlockReduce<float3, 1024>;
|
||||
__shared__ typename BlockReduce::TempStorage reduceStore;
|
||||
float3 block_variances =
|
||||
BlockReduce(reduceStore).Reduce(thread_variances, SumOp{}, blockDim.x);
|
||||
|
||||
variance = block_variances.x;
|
||||
variance2 = block_variances.y;
|
||||
variance3 = block_variances.z;
|
||||
|
||||
__shared__ float s_w2_inv_std;
|
||||
__shared__ float s_w1_inv_std2;
|
||||
__shared__ float s_w0_inv_std3;
|
||||
__shared__ float s_bias;
|
||||
|
||||
if (threadIdx.x == 0) {
|
||||
float w0 = (float)weight[0];
|
||||
float w1 = (float)weight[1];
|
||||
float w2 = (float)weight[2];
|
||||
s_bias = (float)bias[0];
|
||||
|
||||
s_w2_inv_std = w2 * rsqrtf(variance / hidden_size + epsilon);
|
||||
s_w1_inv_std2 = w1 * rsqrtf(variance2 / hidden_size + epsilon);
|
||||
s_w0_inv_std3 = w0 * rsqrtf(variance3 / hidden_size + epsilon);
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
auto* __restrict__ out_v = reinterpret_cast<_f16VecPN<scalar_t, width>*>(out);
|
||||
|
||||
for (int idx = threadIdx.x; idx < vec_hidden_size; idx += blockDim.x) {
|
||||
int id = blockIdx.x * vec_hidden_size + idx;
|
||||
_f16VecPN<scalar_t, width> temp = input_v[id];
|
||||
temp.poly_norm_inplace(s_w2_inv_std, s_w1_inv_std2, s_w0_inv_std3, s_bias);
|
||||
out_v[id] = temp;
|
||||
}
|
||||
}
|
||||
|
||||
/* Generic poly_norm_kernel
|
||||
The width field is not used here but necessary for other specializations.
|
||||
*/
|
||||
template <typename scalar_t, int width>
|
||||
__global__ std::enable_if_t<(width == 0) || !_typeConvert<scalar_t>::exists>
|
||||
poly_norm_kernel(scalar_t* __restrict__ out, // [..., hidden_size]
|
||||
const scalar_t* __restrict__ input, // [..., hidden_size]
|
||||
const scalar_t* __restrict__ weight, // [3]
|
||||
const scalar_t* __restrict__ bias, // [1]
|
||||
const float epsilon, const int hidden_size) {
|
||||
float variance = 0.0f;
|
||||
float variance2 = 0.0f;
|
||||
float variance3 = 0.0f;
|
||||
|
||||
for (int idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) {
|
||||
float x = (float)input[blockIdx.x * hidden_size + idx];
|
||||
float x2 = x * x;
|
||||
float x4 = x2 * x2;
|
||||
float x6 = x4 * x2;
|
||||
|
||||
variance += x2;
|
||||
variance2 += x4;
|
||||
variance3 += x6;
|
||||
}
|
||||
|
||||
float3 thread_variances = make_float3(variance, variance2, variance3);
|
||||
|
||||
struct SumOp {
|
||||
__device__ float3 operator()(const float3& a, const float3& b) const {
|
||||
return make_float3(a.x + b.x, a.y + b.y, a.z + b.z);
|
||||
}
|
||||
};
|
||||
|
||||
using BlockReduce = cub::BlockReduce<float3, 1024>;
|
||||
__shared__ typename BlockReduce::TempStorage reduceStore;
|
||||
float3 block_variances =
|
||||
BlockReduce(reduceStore).Reduce(thread_variances, SumOp{}, blockDim.x);
|
||||
|
||||
variance = block_variances.x;
|
||||
variance2 = block_variances.y;
|
||||
variance3 = block_variances.z;
|
||||
|
||||
__shared__ float s_w2_inv_std;
|
||||
__shared__ float s_w1_inv_std2;
|
||||
__shared__ float s_w0_inv_std3;
|
||||
__shared__ float s_bias;
|
||||
|
||||
if (threadIdx.x == 0) {
|
||||
float w0 = (float)weight[0];
|
||||
float w1 = (float)weight[1];
|
||||
float w2 = (float)weight[2];
|
||||
s_bias = (float)bias[0];
|
||||
|
||||
s_w2_inv_std = w2 * rsqrtf(variance / hidden_size + epsilon);
|
||||
s_w1_inv_std2 = w1 * rsqrtf(variance2 / hidden_size + epsilon);
|
||||
s_w0_inv_std3 = w0 * rsqrtf(variance3 / hidden_size + epsilon);
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
for (int idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) {
|
||||
float x = (float)input[blockIdx.x * hidden_size + idx];
|
||||
float x2 = x * x;
|
||||
float x3 = x2 * x;
|
||||
|
||||
out[blockIdx.x * hidden_size + idx] =
|
||||
(scalar_t)(x * s_w2_inv_std + x2 * s_w1_inv_std2 + x3 * s_w0_inv_std3 +
|
||||
s_bias);
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace vllm
|
||||
|
||||
void rms_norm(torch::Tensor& out, // [..., hidden_size]
|
||||
@ -364,18 +159,26 @@ void rms_norm(torch::Tensor& out, // [..., hidden_size]
|
||||
TORCH_CHECK(weight.is_contiguous());
|
||||
|
||||
int hidden_size = input.size(-1);
|
||||
int num_tokens = input.numel() / hidden_size;
|
||||
int64_t input_stride = input.stride(-2);
|
||||
|
||||
// We cannot just use `input.stride(-2)` if the tensor is not row-major.
|
||||
// Instead, we use a 2d view to get the second-innermost stride.
|
||||
// That way the dimensions (except the last one) can be arbitrarily permuted.
|
||||
torch::Tensor input_view = input.view({-1, hidden_size});
|
||||
|
||||
int num_tokens = input_view.numel() / hidden_size;
|
||||
int64_t input_stride = input_view.stride(-2);
|
||||
|
||||
dim3 grid(num_tokens);
|
||||
dim3 block(std::min(hidden_size, 1024));
|
||||
const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
|
||||
const at::cuda::OptionalCUDAGuard device_guard(device_of(input_view));
|
||||
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
||||
VLLM_DISPATCH_FLOATING_TYPES(input.scalar_type(), "rms_norm_kernel", [&] {
|
||||
vllm::rms_norm_kernel<scalar_t><<<grid, block, 0, stream>>>(
|
||||
out.data_ptr<scalar_t>(), input.data_ptr<scalar_t>(), input_stride,
|
||||
weight.data_ptr<scalar_t>(), epsilon, num_tokens, hidden_size);
|
||||
});
|
||||
VLLM_DISPATCH_FLOATING_TYPES(
|
||||
input_view.scalar_type(), "rms_norm_kernel", [&] {
|
||||
vllm::rms_norm_kernel<scalar_t><<<grid, block, 0, stream>>>(
|
||||
out.data_ptr<scalar_t>(), input_view.data_ptr<scalar_t>(),
|
||||
input_stride, weight.data_ptr<scalar_t>(), epsilon, num_tokens,
|
||||
hidden_size);
|
||||
});
|
||||
}
|
||||
|
||||
#define LAUNCH_FUSED_ADD_RMS_NORM(width) \
|
||||
@ -392,6 +195,8 @@ void fused_add_rms_norm(torch::Tensor& input, // [..., hidden_size]
|
||||
torch::Tensor& residual, // [..., hidden_size]
|
||||
torch::Tensor& weight, // [hidden_size]
|
||||
double epsilon) {
|
||||
TORCH_CHECK(weight.scalar_type() == input.scalar_type());
|
||||
TORCH_CHECK(input.scalar_type() == residual.scalar_type());
|
||||
TORCH_CHECK(residual.is_contiguous());
|
||||
TORCH_CHECK(weight.is_contiguous());
|
||||
int hidden_size = input.size(-1);
|
||||
@ -434,50 +239,3 @@ void fused_add_rms_norm(torch::Tensor& input, // [..., hidden_size]
|
||||
LAUNCH_FUSED_ADD_RMS_NORM(0);
|
||||
}
|
||||
}
|
||||
|
||||
#define LAUNCH_FUSED_POLY_NORM(width) \
|
||||
VLLM_DISPATCH_FLOATING_TYPES(input.scalar_type(), "poly_norm_kernel", [&] { \
|
||||
vllm::poly_norm_kernel<scalar_t, width><<<grid, block, 0, stream>>>( \
|
||||
out.data_ptr<scalar_t>(), input.data_ptr<scalar_t>(), \
|
||||
weight.data_ptr<scalar_t>(), bias.data_ptr<scalar_t>(), epsilon, \
|
||||
hidden_size); \
|
||||
});
|
||||
|
||||
void poly_norm(torch::Tensor& out, // [..., hidden_size]
|
||||
torch::Tensor& input, // [..., hidden_size]
|
||||
torch::Tensor& weight, // [3]
|
||||
torch::Tensor& bias, // [1]
|
||||
double epsilon) {
|
||||
TORCH_CHECK(out.is_contiguous());
|
||||
TORCH_CHECK(input.is_contiguous());
|
||||
TORCH_CHECK(out.data_ptr() != input.data_ptr());
|
||||
|
||||
int hidden_size = input.size(-1);
|
||||
int num_tokens = input.numel() / hidden_size;
|
||||
|
||||
dim3 grid(num_tokens);
|
||||
/* This kernel is memory-latency bound in many scenarios.
|
||||
When num_tokens is large, a smaller block size allows
|
||||
for increased block occupancy on CUs and better latency
|
||||
hiding on global mem ops. */
|
||||
const int max_block_size = (num_tokens < 256) ? 1024 : 256;
|
||||
dim3 block(std::min(hidden_size, max_block_size));
|
||||
const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
|
||||
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
||||
/*If the tensor types are FP16/BF16, try to use the optimized kernel
|
||||
with packed + vectorized ops.
|
||||
Max optimization is achieved with a width-8 vector of FP16/BF16s
|
||||
since we can load at most 128 bits at once in a global memory op.
|
||||
However, this requires each tensor's data to be aligned to 16
|
||||
bytes.
|
||||
*/
|
||||
auto inp_ptr = reinterpret_cast<std::uintptr_t>(input.data_ptr());
|
||||
auto out_ptr = reinterpret_cast<std::uintptr_t>(out.data_ptr());
|
||||
bool ptrs_are_aligned = inp_ptr % 16 == 0 && out_ptr % 16 == 0;
|
||||
bool batch_invariant_launch = vllm::vllm_is_batch_invariant();
|
||||
if (ptrs_are_aligned && hidden_size % 8 == 0 && !batch_invariant_launch) {
|
||||
LAUNCH_FUSED_POLY_NORM(8);
|
||||
} else {
|
||||
LAUNCH_FUSED_POLY_NORM(0);
|
||||
}
|
||||
}
|
||||
|
||||
@ -229,6 +229,8 @@ void fused_add_rms_norm_static_fp8_quant(
|
||||
double epsilon) {
|
||||
TORCH_CHECK(out.is_contiguous());
|
||||
TORCH_CHECK(residual.is_contiguous());
|
||||
TORCH_CHECK(residual.scalar_type() == input.scalar_type());
|
||||
TORCH_CHECK(weight.scalar_type() == input.scalar_type());
|
||||
int hidden_size = input.size(-1);
|
||||
int input_stride = input.stride(-2);
|
||||
int num_tokens = input.numel() / hidden_size;
|
||||
|
||||
@ -92,9 +92,6 @@ void rms_norm(torch::Tensor& out, torch::Tensor& input, torch::Tensor& weight,
|
||||
void fused_add_rms_norm(torch::Tensor& input, torch::Tensor& residual,
|
||||
torch::Tensor& weight, double epsilon);
|
||||
|
||||
void poly_norm(torch::Tensor& out, torch::Tensor& input, torch::Tensor& weight,
|
||||
torch::Tensor& bias, double epsilon);
|
||||
|
||||
void apply_repetition_penalties_(torch::Tensor& logits,
|
||||
const torch::Tensor& prompt_mask,
|
||||
const torch::Tensor& output_mask,
|
||||
|
||||
@ -145,7 +145,11 @@ void rms_norm_dynamic_per_token_quant(
|
||||
if (scale_ub.has_value()) {
|
||||
TORCH_CHECK(out.dtype() == kFp8Type);
|
||||
}
|
||||
TORCH_CHECK(weight.dtype() == input.dtype());
|
||||
TORCH_CHECK(scales.dtype() == torch::kFloat32);
|
||||
if (residual) {
|
||||
TORCH_CHECK(residual->scalar_type() == input.scalar_type());
|
||||
}
|
||||
|
||||
VLLM_DISPATCH_FLOATING_TYPES(
|
||||
input.scalar_type(), "rms_norm_dynamic_per_token_quant_dispatch", [&] {
|
||||
|
||||
@ -175,12 +175,6 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
|
||||
"float epsilon) -> ()");
|
||||
ops.impl("fused_add_rms_norm", torch::kCUDA, &fused_add_rms_norm);
|
||||
|
||||
// Polynomial Normalization.
|
||||
ops.def(
|
||||
"poly_norm(Tensor! out, Tensor input, Tensor weight, Tensor bias, float "
|
||||
"epsilon) -> ()");
|
||||
ops.impl("poly_norm", torch::kCUDA, &poly_norm);
|
||||
|
||||
// Apply repetition penalties to logits in-place
|
||||
ops.def(
|
||||
"apply_repetition_penalties_(Tensor! logits, Tensor prompt_mask, "
|
||||
|
||||
@ -69,6 +69,7 @@ There are several notable differences when using Ray:
|
||||
- A single launch command (on any node) is needed to start all local and remote DP ranks, therefore it is more convenient compared to launching on each node
|
||||
- There is no need to specify `--data-parallel-address`, and the node where the command is run is used as `--data-parallel-address`
|
||||
- There is no need to specify `--data-parallel-rpc-port`
|
||||
- When a single DP group requires multiple nodes, *e.g.* in case a single model replica needs to run on at least two nodes, make sure to set `VLLM_RAY_DP_PACK_STRATEGY="span"` in which case `--data-parallel-size-local` is ignored and will be automatically determined
|
||||
- Remote DP ranks will be allocated based on node resources of the Ray cluster
|
||||
|
||||
Currently, the internal DP load balancing is done within the API server process(es) and is based on the running and waiting queues in each of the engines. This could be made more sophisticated in future by incorporating KV cache aware logic.
|
||||
|
||||
@ -3,16 +3,22 @@
|
||||
|
||||
import weakref
|
||||
from collections.abc import Callable, Sequence
|
||||
from contextlib import nullcontext
|
||||
from copy import deepcopy
|
||||
|
||||
import depyf
|
||||
from torch import fx
|
||||
from torch._ops import OpOverload
|
||||
from torch.fx._utils import lazy_format_graph_code
|
||||
|
||||
from vllm.compilation.fx_utils import find_op_nodes
|
||||
from vllm.compilation.inductor_pass import InductorPass
|
||||
from vllm.compilation.pass_manager import with_pattern_match_debug
|
||||
from vllm.compilation.vllm_inductor_pass import VllmInductorPass
|
||||
from vllm.config import VllmConfig, get_current_vllm_config
|
||||
from vllm.logger import init_logger
|
||||
|
||||
logger = init_logger("vllm.tests.compile.backend")
|
||||
|
||||
|
||||
class LazyInitPass(InductorPass):
|
||||
@ -45,20 +51,32 @@ class TestBackend:
|
||||
|
||||
def __init__(self, *passes: InductorPass | Callable[[fx.Graph], None]):
|
||||
self.custom_passes = list(passes)
|
||||
compile_config = get_current_vllm_config().compilation_config
|
||||
self.inductor_config = compile_config.inductor_compile_config
|
||||
vllm_config = get_current_vllm_config()
|
||||
compile_config = vllm_config.compilation_config
|
||||
# Deepcopy to allow multiple TestBackend instances to use the same VllmConfig
|
||||
self.inductor_config = deepcopy(compile_config.inductor_compile_config)
|
||||
self.inductor_config["force_disable_caches"] = True
|
||||
self.inductor_config["post_grad_custom_post_pass"] = self.post_pass
|
||||
|
||||
if debug_dump_path := vllm_config.compile_debug_dump_path():
|
||||
logger.debug("Dumping depyf output to %s", debug_dump_path)
|
||||
self.debug_ctx = depyf.prepare_debug(debug_dump_path.as_posix())
|
||||
else:
|
||||
self.debug_ctx = nullcontext()
|
||||
|
||||
def __call__(self, graph: fx.GraphModule, example_inputs):
|
||||
self.graph_pre_compile = deepcopy(graph)
|
||||
from torch._inductor.compile_fx import compile_fx
|
||||
|
||||
return compile_fx(graph, example_inputs, config_patches=self.inductor_config)
|
||||
with self.debug_ctx:
|
||||
return compile_fx(
|
||||
graph, example_inputs, config_patches=self.inductor_config
|
||||
)
|
||||
|
||||
@with_pattern_match_debug
|
||||
def post_pass(self, graph: fx.Graph):
|
||||
self.graph_pre_pass = deepcopy(graph)
|
||||
lazy_format_graph_code("graph_pre_pass", graph.owning_module)
|
||||
|
||||
VllmInductorPass.dump_prefix = 0
|
||||
for pass_ in self.custom_passes:
|
||||
@ -68,6 +86,7 @@ class TestBackend:
|
||||
VllmInductorPass.dump_prefix = None
|
||||
|
||||
self.graph_post_pass = deepcopy(graph)
|
||||
lazy_format_graph_code("graph_post_pass", graph.owning_module)
|
||||
# assign by reference, will reflect the final state of the graph
|
||||
self.final_graph = graph
|
||||
|
||||
|
||||
@ -1,8 +1,8 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import logging
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
import pytest
|
||||
@ -10,8 +10,6 @@ import torch
|
||||
|
||||
from tests.quantization.utils import is_quant_method_supported
|
||||
from vllm import LLM, SamplingParams
|
||||
from vllm.attention.backends.registry import _Backend
|
||||
from vllm.attention.selector import global_force_attn_backend_context_manager
|
||||
from vllm.config import CompilationConfig, CompilationMode, CUDAGraphMode, PassConfig
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.utils import is_torch_equal_or_newer
|
||||
@ -22,23 +20,24 @@ from ..utils import create_new_process_for_each_test
|
||||
def models_list(*, all: bool = True, keywords: list[str] | None = None):
|
||||
TEST_MODELS: list[tuple[str, dict[str, Any]]] = [
|
||||
("facebook/opt-125m", {}),
|
||||
(
|
||||
"nm-testing/tinyllama-oneshot-w8w8-test-static-shape-change",
|
||||
{
|
||||
"dtype": torch.float16,
|
||||
},
|
||||
),
|
||||
(
|
||||
"neuralmagic/Llama-3.2-1B-Instruct-FP8-dynamic",
|
||||
{
|
||||
"dtype": torch.float16,
|
||||
},
|
||||
{"dtype": torch.float16},
|
||||
),
|
||||
("neuralmagic/Llama-3.2-1B-Instruct-quantized.w8a8", {}),
|
||||
("meta-llama/Llama-3.2-1B-Instruct", {}),
|
||||
]
|
||||
|
||||
if all:
|
||||
TEST_MODELS.extend(
|
||||
[
|
||||
("neuralmagic/Llama-3.2-1B-Instruct-quantized.w8a8", {}),
|
||||
(
|
||||
"nm-testing/tinyllama-oneshot-w8w8-test-static-shape-change",
|
||||
{"dtype": torch.float16},
|
||||
),
|
||||
]
|
||||
)
|
||||
|
||||
# TODO: figure out why this fails.
|
||||
if False and is_quant_method_supported("gguf"): # noqa: SIM223
|
||||
TEST_MODELS.append(
|
||||
@ -83,31 +82,38 @@ def models_list(*, all: bool = True, keywords: list[str] | None = None):
|
||||
"compilation_mode",
|
||||
[CompilationMode.DYNAMO_TRACE_ONCE, CompilationMode.VLLM_COMPILE],
|
||||
)
|
||||
@pytest.mark.parametrize("model_info", models_list(all=True))
|
||||
@pytest.mark.parametrize("model, model_kwargs", models_list(all=True))
|
||||
@create_new_process_for_each_test()
|
||||
def test_full_graph(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
model_info: tuple[str, dict[str, Any]],
|
||||
model: str,
|
||||
model_kwargs: dict[str, Any],
|
||||
compilation_mode: int,
|
||||
):
|
||||
model, model_kwargs = model_info
|
||||
if (
|
||||
"w8a8" in model
|
||||
or "w8w8" in model
|
||||
and current_platform.has_device_capability((10, 0))
|
||||
):
|
||||
# int8 removed on Blackwell:
|
||||
pytest.skip("int8 support removed on Blackwell")
|
||||
|
||||
with monkeypatch.context():
|
||||
print(f"MODEL={model}")
|
||||
|
||||
run_model(compilation_mode, model, model_kwargs)
|
||||
run_model(compilation_mode, model, **model_kwargs)
|
||||
|
||||
|
||||
# TODO(luka) add other supported compilation config scenarios here
|
||||
@pytest.mark.parametrize(
|
||||
"compilation_config, model_info",
|
||||
"compilation_config, model, model_kwargs",
|
||||
[
|
||||
# additional compile sizes, only some of the models
|
||||
(
|
||||
CompilationConfig(mode=CompilationMode.VLLM_COMPILE, compile_sizes=[1, 2]),
|
||||
model,
|
||||
*model_info,
|
||||
)
|
||||
for model in models_list(all=False)
|
||||
for model_info in models_list(all=False)
|
||||
]
|
||||
+ [
|
||||
# RMSNorm + quant fusion, only 8-bit quant models
|
||||
@ -117,18 +123,19 @@ def test_full_graph(
|
||||
custom_ops=["+rms_norm"],
|
||||
pass_config=PassConfig(enable_fusion=True, enable_noop=True),
|
||||
),
|
||||
model,
|
||||
*model_info,
|
||||
)
|
||||
for model in models_list(keywords=["FP8-dynamic", "quantized.w8a8"])
|
||||
for model_info in models_list(keywords=["FP8-dynamic", "quantized.w8a8"])
|
||||
]
|
||||
+ [
|
||||
# Test depyf integration works
|
||||
(
|
||||
CompilationConfig(
|
||||
mode=CompilationMode.VLLM_COMPILE,
|
||||
debug_dump_path=tempfile.gettempdir(),
|
||||
debug_dump_path=Path(tempfile.gettempdir()),
|
||||
),
|
||||
("facebook/opt-125m", {}),
|
||||
"facebook/opt-125m",
|
||||
{},
|
||||
),
|
||||
]
|
||||
+ [
|
||||
@ -142,9 +149,9 @@ def test_full_graph(
|
||||
cudagraph_mode=CUDAGraphMode.PIECEWISE,
|
||||
compile_sizes=[1, 2],
|
||||
),
|
||||
model,
|
||||
*model_info,
|
||||
)
|
||||
for model in models_list(all=False)
|
||||
for model_info in models_list(all=False)
|
||||
if is_torch_equal_or_newer("2.9.0.dev")
|
||||
],
|
||||
)
|
||||
@ -152,16 +159,24 @@ def test_full_graph(
|
||||
@create_new_process_for_each_test()
|
||||
def test_custom_compile_config(
|
||||
compilation_config: CompilationConfig,
|
||||
model_info: tuple[str, dict[str, Any]],
|
||||
model: str,
|
||||
model_kwargs: dict[str, Any],
|
||||
):
|
||||
if (
|
||||
"w8a8" in model
|
||||
or "w8w8" in model
|
||||
and current_platform.has_device_capability((10, 0))
|
||||
):
|
||||
# int8 removed on Blackwell:
|
||||
pytest.skip("int8 support removed on Blackwell")
|
||||
|
||||
if compilation_config.use_inductor_graph_partition and not is_torch_equal_or_newer(
|
||||
"2.9.0.dev"
|
||||
):
|
||||
pytest.skip("inductor graph partition is only available in PyTorch 2.9+")
|
||||
|
||||
model, model_kwargs = model_info
|
||||
print(f"MODEL={model}")
|
||||
run_model(compilation_config, model, model_kwargs)
|
||||
run_model(compilation_config, model, **model_kwargs)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
@ -176,50 +191,16 @@ def test_fp8_kv_scale_compile(compilation_mode: int):
|
||||
"calculate_kv_scales": True,
|
||||
"max_model_len": 512,
|
||||
}
|
||||
run_model(compilation_mode, model, model_kwargs)
|
||||
run_model(compilation_mode, model, **model_kwargs)
|
||||
|
||||
|
||||
def test_inductor_graph_partition_attn_fusion(caplog_vllm):
|
||||
if not is_torch_equal_or_newer("2.9.0.dev"):
|
||||
pytest.skip("inductor graph partition is only available in PyTorch 2.9+")
|
||||
|
||||
model = "nvidia/Llama-4-Scout-17B-16E-Instruct-FP8"
|
||||
compilation_config = CompilationConfig(
|
||||
mode=CompilationMode.VLLM_COMPILE,
|
||||
use_inductor_graph_partition=True,
|
||||
cudagraph_mode=CUDAGraphMode.PIECEWISE,
|
||||
custom_ops=["+quant_fp8"],
|
||||
pass_config=PassConfig(enable_attn_fusion=True, enable_noop=True),
|
||||
def run_model(compile_config: int | CompilationConfig, model: str, **model_kwargs):
|
||||
compilation_config = (
|
||||
compile_config
|
||||
if isinstance(compile_config, CompilationConfig)
|
||||
else CompilationConfig(level=compile_config)
|
||||
)
|
||||
model_kwargs = {
|
||||
"kv_cache_dtype": "fp8",
|
||||
"max_model_len": 1024,
|
||||
}
|
||||
with (
|
||||
caplog_vllm.at_level(logging.DEBUG),
|
||||
global_force_attn_backend_context_manager(_Backend.FLASHINFER),
|
||||
):
|
||||
run_model(compilation_config, model, model_kwargs)
|
||||
|
||||
try:
|
||||
assert "Fused quantization onto 48 attention nodes" in caplog_vllm.text, (
|
||||
caplog_vllm.text
|
||||
)
|
||||
except AssertionError:
|
||||
# Note: this message is only triggered when the compilation goes
|
||||
# through the custom pass. Due to multiple layers of cache on
|
||||
# PyTorch side, the compilation of a graph may be cached such
|
||||
# that custom pass directly goes through cache. In this case,
|
||||
# we go through this branch and assert that the pass is not
|
||||
# triggered.
|
||||
assert "Fused quantization" not in caplog_vllm.text
|
||||
|
||||
|
||||
def run_model(
|
||||
compile_config: int | CompilationConfig,
|
||||
model: str,
|
||||
model_kwargs: dict[str, Any],
|
||||
):
|
||||
prompts = [
|
||||
"Hello, my name is",
|
||||
"The president of the United States is",
|
||||
@ -227,12 +208,17 @@ def run_model(
|
||||
"The future of AI is",
|
||||
]
|
||||
sampling_params = SamplingParams(temperature=0)
|
||||
# Allow override from model_kwargs
|
||||
model_kwargs = {"tensor_parallel_size": 1, **model_kwargs}
|
||||
model_kwargs = {"disable_custom_all_reduce": True, **model_kwargs}
|
||||
|
||||
# No cudagraphs by default
|
||||
if compilation_config.cudagraph_mode is None:
|
||||
compilation_config.cudagraph_mode = CUDAGraphMode.NONE
|
||||
|
||||
llm = LLM(
|
||||
model=model,
|
||||
enforce_eager=True,
|
||||
tensor_parallel_size=1,
|
||||
disable_custom_all_reduce=True,
|
||||
compilation_config=compile_config,
|
||||
compilation_config=compilation_config,
|
||||
**model_kwargs,
|
||||
)
|
||||
outputs = llm.generate(prompts, sampling_params)
|
||||
|
||||
@ -11,7 +11,13 @@ from vllm.compilation.fusion import RMSNormQuantFusionPass
|
||||
from vllm.compilation.fx_utils import find_auto_fn, find_auto_fn_maybe, is_func
|
||||
from vllm.compilation.noop_elimination import NoOpEliminationPass
|
||||
from vllm.compilation.post_cleanup import PostCleanupPass
|
||||
from vllm.config import CompilationConfig, PassConfig, VllmConfig
|
||||
from vllm.config import (
|
||||
CompilationConfig,
|
||||
ModelConfig,
|
||||
PassConfig,
|
||||
VllmConfig,
|
||||
set_current_vllm_config,
|
||||
)
|
||||
from vllm.model_executor.layers.activation import SiluAndMul
|
||||
from vllm.model_executor.layers.layernorm import RMSNorm
|
||||
from vllm.model_executor.layers.quantization.utils.quant_utils import GroupShape
|
||||
@ -48,8 +54,7 @@ class TestSiluMul(torch.nn.Module):
|
||||
return y
|
||||
|
||||
def example_inputs(self, num_tokens=32, hidden_size=128):
|
||||
dtype = torch.float16 if TEST_FP8 else torch.float32
|
||||
return (torch.rand(num_tokens, hidden_size * 2, dtype=dtype),)
|
||||
return (torch.rand(num_tokens, hidden_size * 2),)
|
||||
|
||||
def ops_in_model(self, do_fusion):
|
||||
if TEST_FP8 and do_fusion:
|
||||
@ -67,15 +72,11 @@ class TestFusedAddRMSNorm(torch.nn.Module):
|
||||
self.hidden_size = hidden_size
|
||||
self.intermediate_size = intermediate_size
|
||||
|
||||
dtype = torch.float16 if TEST_FP8 else torch.float32
|
||||
|
||||
self.gate_proj = torch.nn.Parameter(
|
||||
torch.empty((intermediate_size, hidden_size), dtype=dtype)
|
||||
torch.empty((intermediate_size, hidden_size))
|
||||
)
|
||||
self.norm = RMSNorm(intermediate_size, 1e-05)
|
||||
self.norm.weight = torch.nn.Parameter(
|
||||
torch.ones(intermediate_size, dtype=dtype)
|
||||
)
|
||||
self.norm.weight = torch.nn.Parameter(torch.ones(intermediate_size))
|
||||
|
||||
torch.nn.init.normal_(self.gate_proj, std=0.02)
|
||||
|
||||
@ -112,9 +113,8 @@ class TestFusedAddRMSNorm(torch.nn.Module):
|
||||
return norm_output, residual_output
|
||||
|
||||
def example_inputs(self, batch_size=8, hidden_size=16, seq_len=16):
|
||||
dtype = torch.float16 if TEST_FP8 else torch.float32
|
||||
hidden_states = torch.randn((batch_size * seq_len, hidden_size), dtype=dtype)
|
||||
residual = torch.randn((batch_size * seq_len, hidden_size), dtype=dtype)
|
||||
hidden_states = torch.randn((batch_size * seq_len, hidden_size))
|
||||
residual = torch.randn((batch_size * seq_len, hidden_size))
|
||||
return (hidden_states, residual)
|
||||
|
||||
def ops_in_model(self, do_fusion):
|
||||
@ -145,10 +145,9 @@ class TestRotaryEmbedding(torch.nn.Module):
|
||||
return q_rotated, k_rotated
|
||||
|
||||
def example_inputs(self, num_tokens=32, head_dim=64):
|
||||
dtype = torch.float16
|
||||
positions = torch.arange(num_tokens, dtype=torch.long)
|
||||
q = torch.randn(num_tokens, head_dim, dtype=dtype)
|
||||
k = torch.randn(num_tokens, head_dim, dtype=dtype)
|
||||
q = torch.randn(num_tokens, head_dim)
|
||||
k = torch.randn(num_tokens, head_dim)
|
||||
return (positions, q, k)
|
||||
|
||||
def ops_in_model(self, do_fusion):
|
||||
@ -166,7 +165,7 @@ class TestRotaryEmbeddingSliceScatter(torch.nn.Module):
|
||||
self.hidden_size = head_dim * num_heads
|
||||
|
||||
self.qkv_proj = torch.nn.Linear(
|
||||
self.hidden_size, self.hidden_size * 3, bias=False, dtype=torch.float16
|
||||
self.hidden_size, self.hidden_size * 3, bias=False
|
||||
)
|
||||
|
||||
self.rotary_emb = get_rope(
|
||||
@ -190,10 +189,9 @@ class TestRotaryEmbeddingSliceScatter(torch.nn.Module):
|
||||
return qkv_updated
|
||||
|
||||
def example_inputs(self, num_tokens=32, head_dim=64, num_heads=4):
|
||||
dtype = torch.float16
|
||||
hidden_size = head_dim * num_heads
|
||||
positions = torch.arange(num_tokens, dtype=torch.long)
|
||||
hidden_states = torch.randn(num_tokens, hidden_size, dtype=dtype)
|
||||
hidden_states = torch.randn(num_tokens, hidden_size)
|
||||
return (positions, hidden_states)
|
||||
|
||||
def ops_in_model(self, do_fusion):
|
||||
@ -211,48 +209,58 @@ MODELS = [
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
|
||||
@pytest.mark.parametrize("model_class", MODELS)
|
||||
@pytest.mark.parametrize("do_fusion", [True, False])
|
||||
@pytest.mark.skipif(envs.VLLM_TARGET_DEVICE != "cuda", reason="Only test on CUDA")
|
||||
def test_fix_functionalization(model_class: torch.nn.Module, do_fusion: bool):
|
||||
def test_fix_functionalization(
|
||||
model_class: torch.nn.Module, do_fusion: bool, dtype: torch.dtype
|
||||
):
|
||||
torch.set_default_device("cuda")
|
||||
torch.set_default_dtype(dtype)
|
||||
|
||||
vllm_config = VllmConfig()
|
||||
vllm_config.compilation_config = CompilationConfig(
|
||||
pass_config=PassConfig(enable_fusion=do_fusion, enable_noop=True)
|
||||
vllm_config = VllmConfig(
|
||||
model_config=ModelConfig(dtype=dtype),
|
||||
compilation_config=CompilationConfig(
|
||||
custom_ops=["all"],
|
||||
pass_config=PassConfig(enable_fusion=do_fusion, enable_noop=True),
|
||||
),
|
||||
)
|
||||
noop_pass = NoOpEliminationPass(vllm_config)
|
||||
fusion_pass = RMSNormQuantFusionPass(vllm_config)
|
||||
cleanup_pass = PostCleanupPass(vllm_config)
|
||||
act_quant_fusion_pass = ActivationQuantFusionPass(vllm_config)
|
||||
|
||||
passes = (
|
||||
[noop_pass, fusion_pass, act_quant_fusion_pass, cleanup_pass]
|
||||
if do_fusion
|
||||
else [noop_pass, cleanup_pass]
|
||||
)
|
||||
func_pass = FixFunctionalizationPass(vllm_config)
|
||||
with set_current_vllm_config(vllm_config):
|
||||
assert RMSNorm.enabled()
|
||||
noop_pass = NoOpEliminationPass(vllm_config)
|
||||
fusion_pass = RMSNormQuantFusionPass(vllm_config)
|
||||
cleanup_pass = PostCleanupPass(vllm_config)
|
||||
act_quant_fusion_pass = ActivationQuantFusionPass(vllm_config)
|
||||
|
||||
backend_func = TestBackend(*passes, func_pass)
|
||||
backend_no_func = TestBackend(*passes)
|
||||
passes = (
|
||||
[noop_pass, fusion_pass, act_quant_fusion_pass, cleanup_pass]
|
||||
if do_fusion
|
||||
else [noop_pass, cleanup_pass]
|
||||
)
|
||||
func_pass = FixFunctionalizationPass(vllm_config)
|
||||
|
||||
model = model_class()
|
||||
torch.compile(model, backend=backend_func)(*model.example_inputs())
|
||||
torch.compile(model, backend=backend_no_func)(*model.example_inputs())
|
||||
backend_func = TestBackend(*passes, func_pass)
|
||||
backend_no_func = TestBackend(*passes)
|
||||
|
||||
# check if the functionalization pass is applied
|
||||
for op in model.ops_in_model(do_fusion):
|
||||
find_auto_fn(backend_no_func.graph_post_pass.nodes, op)
|
||||
assert find_auto_fn_maybe(backend_func.graph_post_pass.nodes, op) is None
|
||||
model = model_class()
|
||||
torch.compile(model, backend=backend_func)(*model.example_inputs())
|
||||
torch.compile(model, backend=backend_no_func)(*model.example_inputs())
|
||||
|
||||
# make sure the ops were all de-functionalized
|
||||
found = dict()
|
||||
for node in backend_func.graph_post_pass.nodes:
|
||||
# check if the functionalization pass is applied
|
||||
for op in model.ops_in_model(do_fusion):
|
||||
if is_func(node, op):
|
||||
found[op] = True
|
||||
for op in model.ops_not_in_model():
|
||||
if is_func(node, op):
|
||||
found[op] = True
|
||||
assert all(found[op] for op in model.ops_in_model(do_fusion))
|
||||
assert all(not found.get(op) for op in model.ops_not_in_model())
|
||||
find_auto_fn(backend_no_func.graph_post_pass.nodes, op)
|
||||
assert find_auto_fn_maybe(backend_func.graph_post_pass.nodes, op) is None
|
||||
|
||||
# make sure the ops were all de-functionalized
|
||||
found = dict()
|
||||
for node in backend_func.graph_post_pass.nodes:
|
||||
for op in model.ops_in_model(do_fusion):
|
||||
if is_func(node, op):
|
||||
found[op] = True
|
||||
for op in model.ops_not_in_model():
|
||||
if is_func(node, op):
|
||||
found[op] = True
|
||||
assert all(found[op] for op in model.ops_in_model(do_fusion))
|
||||
assert all(not found.get(op) for op in model.ops_not_in_model())
|
||||
|
||||
@ -5,15 +5,18 @@ import pytest
|
||||
import torch
|
||||
|
||||
import vllm.plugins
|
||||
from vllm.compilation.fusion import (
|
||||
FUSED_OPS,
|
||||
QUANT_OPS,
|
||||
FusedRMSQuantKey,
|
||||
RMSNormQuantFusionPass,
|
||||
)
|
||||
from vllm.compilation.fusion import FUSED_OPS, FusedRMSQuantKey, RMSNormQuantFusionPass
|
||||
from vllm.compilation.fx_utils import find_op_nodes
|
||||
from vllm.compilation.matcher_utils import QUANT_OPS
|
||||
from vllm.compilation.noop_elimination import NoOpEliminationPass
|
||||
from vllm.compilation.post_cleanup import PostCleanupPass
|
||||
from vllm.config import CompilationConfig, CompilationMode, PassConfig, VllmConfig
|
||||
from vllm.config import (
|
||||
CompilationConfig,
|
||||
CompilationMode,
|
||||
ModelConfig,
|
||||
PassConfig,
|
||||
VllmConfig,
|
||||
)
|
||||
from vllm.model_executor.layers.layernorm import RMSNorm
|
||||
from vllm.model_executor.layers.quantization.utils.quant_utils import (
|
||||
GroupShape,
|
||||
@ -32,6 +35,9 @@ from .backend import TestBackend
|
||||
|
||||
FP8_DTYPE = current_platform.fp8_dtype()
|
||||
|
||||
RMS_OP = torch.ops._C.rms_norm.default
|
||||
RMS_ADD_OP = torch.ops._C.fused_add_rms_norm.default
|
||||
|
||||
|
||||
class TestModel(torch.nn.Module):
|
||||
def __init__(
|
||||
@ -45,18 +51,18 @@ class TestModel(torch.nn.Module):
|
||||
):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.cuda_force_torch = cuda_force_torch
|
||||
self.norm = [RMSNorm(hidden_size, eps) for _ in range(3)]
|
||||
self.wscale = [torch.rand(1, dtype=torch.float32) for _ in range(2)]
|
||||
self.norm = [RMSNorm(hidden_size, eps) for _ in range(4)]
|
||||
self.wscale = [torch.rand(1, dtype=torch.float32) for _ in range(3)]
|
||||
group_shape = GroupShape.PER_TENSOR if static else GroupShape.PER_TOKEN
|
||||
quant_scale = ScaleDesc(torch.float32, static, group_shape)
|
||||
self.key = QuantKey(dtype=FP8_DTYPE, scale=quant_scale, symmetric=True)
|
||||
self.quant_key = QuantKey(dtype=FP8_DTYPE, scale=quant_scale, symmetric=True)
|
||||
if static:
|
||||
self.scale = [torch.rand(1, dtype=torch.float32) for _ in range(2)]
|
||||
self.scale = [torch.rand(1, dtype=torch.float32) for _ in range(3)]
|
||||
else:
|
||||
self.scale = [None for _ in range(2)]
|
||||
self.scale = [None for _ in range(3)]
|
||||
self.w = [
|
||||
torch.rand(hidden_size, hidden_size).to(dtype=FP8_DTYPE).t()
|
||||
for _ in range(2)
|
||||
for _ in range(3)
|
||||
]
|
||||
|
||||
with override_cutlass_fp8_supported(not cuda_force_torch):
|
||||
@ -65,8 +71,12 @@ class TestModel(torch.nn.Module):
|
||||
act_quant_group_shape=group_shape,
|
||||
)
|
||||
|
||||
self.enable_rms_norm_custom_op = self.norm[0].enabled()
|
||||
self.enable_quant_fp8_custom_op = self.fp8_linear.quant_fp8.enabled()
|
||||
|
||||
def forward(self, x):
|
||||
resid = torch.sqrt(x)
|
||||
# avoid having graph input be an arg to a pattern directly
|
||||
x = resid = torch.relu(x)
|
||||
y = self.norm[0](x)
|
||||
|
||||
x2 = self.fp8_linear.apply(
|
||||
@ -78,24 +88,44 @@ class TestModel(torch.nn.Module):
|
||||
x3 = self.fp8_linear.apply(
|
||||
y2, self.w[1], self.wscale[1], input_scale=self.scale[1]
|
||||
)
|
||||
y3, resid = self.norm[2](x3, resid) # use resid here
|
||||
return y3
|
||||
|
||||
def ops_in_model_before(self):
|
||||
return [QUANT_OPS[self.key]]
|
||||
y3, resid = self.norm[2](x3, resid) # use resid here
|
||||
|
||||
x4 = self.fp8_linear.apply(
|
||||
y3, self.w[2], self.wscale[2], input_scale=self.scale[2]
|
||||
)
|
||||
|
||||
y4, resid = self.norm[3](x4, resid) # use resid here
|
||||
return y4
|
||||
|
||||
def ops_in_model_after(self):
|
||||
return [
|
||||
FUSED_OPS[FusedRMSQuantKey(self.key, False)],
|
||||
FUSED_OPS[FusedRMSQuantKey(self.key, True)],
|
||||
FUSED_OPS[FusedRMSQuantKey(self.quant_key, True)],
|
||||
FUSED_OPS[FusedRMSQuantKey(self.quant_key, False)],
|
||||
]
|
||||
|
||||
def ops_in_model_before(self):
|
||||
return (
|
||||
[QUANT_OPS[self.quant_key]]
|
||||
if self.enable_quant_fp8_custom_op
|
||||
else [torch.ops.aten.reciprocal]
|
||||
)
|
||||
|
||||
def ops_in_model_before_partial(self):
|
||||
return (
|
||||
[RMS_OP, RMS_ADD_OP]
|
||||
if self.enable_rms_norm_custom_op
|
||||
else [torch.ops.aten.rsqrt]
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
|
||||
@pytest.mark.parametrize("hidden_size", [64])
|
||||
@pytest.mark.parametrize("num_tokens", [257])
|
||||
@pytest.mark.parametrize("eps", [1e-5, 1e-6])
|
||||
@pytest.mark.parametrize("static", [True, False])
|
||||
@pytest.mark.parametrize("enable_rms_norm_custom_op", [True, False])
|
||||
@pytest.mark.parametrize("enable_quant_fp8_custom_op", [True, False])
|
||||
# cuda_force_torch used to test torch code path on platforms that
|
||||
# cutlass_fp8_supported() == True.
|
||||
@pytest.mark.parametrize(
|
||||
@ -105,19 +135,32 @@ class TestModel(torch.nn.Module):
|
||||
not current_platform.is_cuda_alike(), reason="Only test on CUDA and ROCm"
|
||||
)
|
||||
def test_fusion_rmsnorm_quant(
|
||||
dtype, hidden_size, num_tokens, eps, static, cuda_force_torch
|
||||
dtype,
|
||||
hidden_size,
|
||||
num_tokens,
|
||||
eps,
|
||||
static,
|
||||
enable_rms_norm_custom_op,
|
||||
enable_quant_fp8_custom_op,
|
||||
cuda_force_torch,
|
||||
):
|
||||
torch.set_default_device("cuda")
|
||||
torch.set_default_dtype(dtype)
|
||||
torch.manual_seed(1)
|
||||
maybe_create_device_identity() # needed for certain non-cutlass fp8 paths
|
||||
|
||||
custom_ops = []
|
||||
if enable_rms_norm_custom_op:
|
||||
custom_ops.append("+rms_norm")
|
||||
if enable_quant_fp8_custom_op:
|
||||
custom_ops.append("+quant_fp8")
|
||||
vllm_config = VllmConfig(
|
||||
model_config=ModelConfig(dtype=dtype),
|
||||
compilation_config=CompilationConfig(
|
||||
mode=CompilationMode.VLLM_COMPILE,
|
||||
custom_ops=["+rms_norm", "+quant_fp8"],
|
||||
custom_ops=custom_ops,
|
||||
pass_config=PassConfig(enable_fusion=True, enable_noop=True),
|
||||
)
|
||||
),
|
||||
)
|
||||
with vllm.config.set_current_vllm_config(vllm_config):
|
||||
# Reshape pass is needed for the fusion pass to work
|
||||
@ -126,31 +169,39 @@ def test_fusion_rmsnorm_quant(
|
||||
cleanup_pass = PostCleanupPass(vllm_config)
|
||||
|
||||
backend = TestBackend(noop_pass, fusion_pass, cleanup_pass)
|
||||
backend2 = TestBackend(noop_pass, cleanup_pass)
|
||||
model = TestModel(hidden_size, eps, static, cuda_force_torch)
|
||||
|
||||
# First dimension dynamic
|
||||
x = torch.rand(num_tokens, hidden_size)
|
||||
torch._dynamo.mark_dynamic(x, 0)
|
||||
|
||||
result = model(x)
|
||||
model_fused = torch.compile(model, backend=backend)
|
||||
result_fused = model_fused(x)
|
||||
|
||||
model2 = torch.compile(model, backend=backend)
|
||||
result2 = model2(x)
|
||||
model_unfused = torch.compile(model, backend=backend2)
|
||||
result_unfused = model_unfused(x)
|
||||
|
||||
# Higher tol for dynamic, even higher for bfloat16
|
||||
if static:
|
||||
ATOL, RTOL = (1e-3, 1e-3)
|
||||
elif dtype == torch.float16:
|
||||
if dtype == torch.float16:
|
||||
ATOL, RTOL = (2e-3, 2e-3)
|
||||
else:
|
||||
ATOL, RTOL = (1e-2, 1e-2)
|
||||
|
||||
torch.testing.assert_close(result, result2, atol=ATOL, rtol=RTOL)
|
||||
torch.testing.assert_close(result_fused, result_unfused, atol=ATOL, rtol=RTOL)
|
||||
|
||||
assert fusion_pass.matched_count == 2
|
||||
|
||||
# In pre-nodes, fp8 quant should be there and fused kernels should not
|
||||
assert fusion_pass.matched_count == 3
|
||||
backend.check_before_ops(model.ops_in_model_before())
|
||||
|
||||
# In post-nodes, fused kernels should be there and fp8 quant should not
|
||||
backend.check_before_ops(
|
||||
model.ops_in_model_before_partial(), fully_replaced=False
|
||||
)
|
||||
backend.check_after_ops(model.ops_in_model_after())
|
||||
|
||||
# If RMSNorm custom op is disabled (native/torch impl used),
|
||||
# there's a risk that the fused add doesn't get included in the
|
||||
# replacement and only the rms part gets fused with quant.
|
||||
# Hence, we check only 2 add nodes are left (final fused rmsnorm add).
|
||||
if not enable_rms_norm_custom_op:
|
||||
n_add_nodes = lambda g: sum(1 for _ in find_op_nodes(torch.ops.aten.add, g))
|
||||
# 7 = 1 (RMS) + 3x2 (3xRMS_ADD, 2 each)
|
||||
assert n_add_nodes(backend.graph_pre_pass) == 7
|
||||
assert n_add_nodes(backend.graph_post_pass) == 2
|
||||
|
||||
@ -6,6 +6,7 @@ import pytest
|
||||
import torch
|
||||
|
||||
import vllm.envs as envs
|
||||
from vllm._custom_ops import cutlass_scaled_fp4_mm, scaled_fp4_quant
|
||||
from vllm.compilation.collective_fusion import AllReduceFusionPass
|
||||
from vllm.compilation.fix_functionalization import FixFunctionalizationPass
|
||||
from vllm.compilation.noop_elimination import NoOpEliminationPass
|
||||
@ -17,6 +18,7 @@ from vllm.config import (
|
||||
ModelConfig,
|
||||
PassConfig,
|
||||
VllmConfig,
|
||||
set_current_vllm_config,
|
||||
)
|
||||
from vllm.distributed import tensor_model_parallel_all_reduce
|
||||
from vllm.distributed.parallel_state import (
|
||||
@ -25,8 +27,8 @@ from vllm.distributed.parallel_state import (
|
||||
)
|
||||
from vllm.model_executor.layers.layernorm import RMSNorm
|
||||
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
|
||||
Fp8LinearOp,
|
||||
GroupShape,
|
||||
QuantFP8,
|
||||
)
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.utils import update_environment_variables
|
||||
@ -40,13 +42,30 @@ class TestAllReduceRMSNormModel(torch.nn.Module):
|
||||
super().__init__()
|
||||
self.hidden_size = hidden_size
|
||||
self.eps = eps
|
||||
self.norm = RMSNorm(hidden_size, eps)
|
||||
self.norm = [RMSNorm(hidden_size, eps) for i in range(4)]
|
||||
self.w = [torch.rand(hidden_size, hidden_size) for _ in range(3)]
|
||||
|
||||
def forward(self, hidden_states, residual):
|
||||
view = hidden_states.reshape(-1, self.hidden_size)
|
||||
all_reduce = tensor_model_parallel_all_reduce(view)
|
||||
norm = self.norm(all_reduce)
|
||||
return norm
|
||||
def forward(self, x):
|
||||
# avoid having graph input be an arg to a pattern directly
|
||||
z = torch.relu(x)
|
||||
x = resid = tensor_model_parallel_all_reduce(z)
|
||||
y = self.norm[0](x)
|
||||
|
||||
z2 = torch.mm(y, self.w[0])
|
||||
x2 = tensor_model_parallel_all_reduce(z2)
|
||||
|
||||
y2, resid = self.norm[1](x2, resid)
|
||||
|
||||
z3 = torch.mm(y2, self.w[1])
|
||||
x3 = tensor_model_parallel_all_reduce(z3)
|
||||
|
||||
y3, resid = self.norm[2](x3, resid)
|
||||
|
||||
z4 = torch.mm(y3, self.w[2])
|
||||
x4 = tensor_model_parallel_all_reduce(z4)
|
||||
|
||||
y4, resid = self.norm[3](x4, resid)
|
||||
return y4
|
||||
|
||||
def ops_in_model_before(self):
|
||||
return [torch.ops.vllm.all_reduce.default]
|
||||
@ -55,44 +74,53 @@ class TestAllReduceRMSNormModel(torch.nn.Module):
|
||||
return [torch.ops.vllm.flashinfer_trtllm_fused_allreduce_norm.default]
|
||||
|
||||
|
||||
class TestAllReduceFusedAddRMSNormModel(torch.nn.Module):
|
||||
class TestAllReduceRMSNormStaticQuantFP8Model(torch.nn.Module):
|
||||
def __init__(self, hidden_size=16, token_num=16, eps=1e-6):
|
||||
super().__init__()
|
||||
self.hidden_size = hidden_size
|
||||
self.eps = eps
|
||||
self.norm = RMSNorm(hidden_size, eps)
|
||||
self.norm = [RMSNorm(hidden_size, eps) for i in range(4)]
|
||||
self.wscale = [torch.rand(1, dtype=torch.float32) for _ in range(3)]
|
||||
self.w = [
|
||||
torch.rand(hidden_size, hidden_size)
|
||||
.to(dtype=current_platform.fp8_dtype())
|
||||
.t()
|
||||
for _ in range(3)
|
||||
]
|
||||
|
||||
def forward(self, hidden_states, residual):
|
||||
view = hidden_states.reshape(-1, self.hidden_size)
|
||||
all_reduce = tensor_model_parallel_all_reduce(view)
|
||||
norm, _ = self.norm(all_reduce, residual)
|
||||
return norm
|
||||
|
||||
def ops_in_model_before(self):
|
||||
return [torch.ops.vllm.all_reduce.default]
|
||||
|
||||
def ops_in_model_after(self):
|
||||
return [torch.ops.vllm.flashinfer_trtllm_fused_allreduce_norm.default]
|
||||
|
||||
|
||||
class TestAllReduceFusedAddRMSNormStaticQuantFP8Model(torch.nn.Module):
|
||||
def __init__(self, hidden_size=16, token_num=16, eps=1e-6):
|
||||
super().__init__()
|
||||
self.hidden_size = hidden_size
|
||||
self.eps = eps
|
||||
self.norm = RMSNorm(hidden_size, eps)
|
||||
self.quant_fp8 = QuantFP8(static=True, group_shape=GroupShape.PER_TENSOR)
|
||||
self.scale = torch.rand(1, dtype=torch.float32)
|
||||
self.output = torch.empty((token_num, hidden_size), dtype=torch.float32)
|
||||
|
||||
def forward(self, hidden_states, residual):
|
||||
view = hidden_states.reshape(-1, self.hidden_size)
|
||||
all_reduce = tensor_model_parallel_all_reduce(view)
|
||||
norm_output, residual_output = self.norm(all_reduce, residual)
|
||||
torch.ops._C.static_scaled_fp8_quant(
|
||||
self.output, norm_output.contiguous(), self.scale
|
||||
self.fp8_linear = Fp8LinearOp(
|
||||
act_quant_static=True,
|
||||
act_quant_group_shape=GroupShape.PER_TENSOR,
|
||||
)
|
||||
return self.output, residual_output
|
||||
|
||||
self.scale = [torch.rand(1, dtype=torch.float32) for _ in range(3)]
|
||||
|
||||
def forward(self, hidden_states):
|
||||
# avoid having graph input be an arg to a pattern directly
|
||||
z = torch.relu(hidden_states)
|
||||
x = resid = tensor_model_parallel_all_reduce(z)
|
||||
y = self.norm[0](x)
|
||||
|
||||
z2 = self.fp8_linear.apply(
|
||||
y, self.w[0], self.wscale[0], input_scale=self.scale[0]
|
||||
)
|
||||
|
||||
x2 = tensor_model_parallel_all_reduce(z2)
|
||||
y2, resid = self.norm[1](x2, resid)
|
||||
|
||||
z3 = self.fp8_linear.apply(
|
||||
y2, self.w[1], self.wscale[1], input_scale=self.scale[1]
|
||||
)
|
||||
|
||||
x3 = tensor_model_parallel_all_reduce(z3)
|
||||
y3, resid = self.norm[2](x3, resid) # use resid here
|
||||
|
||||
z4 = self.fp8_linear.apply(
|
||||
y3, self.w[2], self.wscale[2], input_scale=self.scale[2]
|
||||
)
|
||||
x4 = tensor_model_parallel_all_reduce(z4)
|
||||
y4, resid = self.norm[3](x4, resid) # use resid here
|
||||
return y4
|
||||
|
||||
def ops_in_model_after(self):
|
||||
return [torch.ops.vllm.flashinfer_trtllm_fused_allreduce_norm.default]
|
||||
@ -100,7 +128,9 @@ class TestAllReduceFusedAddRMSNormStaticQuantFP8Model(torch.nn.Module):
|
||||
def ops_in_model_before(self):
|
||||
return [
|
||||
torch.ops.vllm.all_reduce.default,
|
||||
torch.ops._C.static_scaled_fp8_quant.default,
|
||||
torch.ops._C.static_scaled_fp8_quant.default
|
||||
if self.fp8_linear.quant_fp8.enabled()
|
||||
else torch.ops.aten.reciprocal.default,
|
||||
]
|
||||
|
||||
|
||||
@ -109,25 +139,48 @@ class TestAllReduceFusedAddRMSNormStaticQuantFP4Model(torch.nn.Module):
|
||||
super().__init__()
|
||||
self.hidden_size = hidden_size
|
||||
self.eps = eps
|
||||
self.norm = RMSNorm(hidden_size, eps)
|
||||
self.scale = torch.rand(1, dtype=torch.float32)
|
||||
self.output = torch.empty((token_num, hidden_size), dtype=torch.float32)
|
||||
self.norm = [RMSNorm(hidden_size, eps) for i in range(4)]
|
||||
|
||||
round_up = lambda x, y: (x + y - 1) // y * y
|
||||
rounded_m = round_up(token_num, 128)
|
||||
scale_n = hidden_size // 16
|
||||
rounded_n = round_up(scale_n, 4)
|
||||
self.output_scale = torch.empty((rounded_m, rounded_n // 4), dtype=torch.int32)
|
||||
self.w = [torch.rand(hidden_size, hidden_size) for _ in range(3)]
|
||||
self.agscale = [torch.rand(1, dtype=torch.float32) for _ in range(3)]
|
||||
wgscale = [torch.rand(1, dtype=torch.float32) for _ in range(3)]
|
||||
self.alpha = [1 / (w * a) for w, a in zip(wgscale, self.agscale)]
|
||||
|
||||
def forward(self, hidden_states, residual):
|
||||
view = hidden_states.reshape(-1, self.hidden_size)
|
||||
all_reduce = tensor_model_parallel_all_reduce(view)
|
||||
norm_output, residual_output = self.norm(all_reduce, residual)
|
||||
norm_output = norm_output.reshape(-1, norm_output.shape[-1])
|
||||
torch.ops._C.scaled_fp4_quant(
|
||||
self.output, norm_output, self.output_scale, self.scale
|
||||
wq_gen, wscale_gen = zip(
|
||||
*(scaled_fp4_quant(w, wg) for w, wg in zip(self.w, wgscale))
|
||||
)
|
||||
return self.output, residual_output, self.output_scale
|
||||
self.wq, self.wscale = list(wq_gen), list(wscale_gen)
|
||||
print(f"{self.wq=}, {self.wscale=}")
|
||||
|
||||
def forward(self, hidden_states):
|
||||
# avoid having graph input be an arg to a pattern directly
|
||||
z = torch.relu(hidden_states)
|
||||
x = resid = tensor_model_parallel_all_reduce(z)
|
||||
y = self.norm[0](x)
|
||||
|
||||
yq, y_scale = scaled_fp4_quant(y, self.agscale[0])
|
||||
z2 = cutlass_scaled_fp4_mm(
|
||||
yq, self.wq[0], y_scale, self.wscale[0], self.alpha[0], out_dtype=y.dtype
|
||||
)
|
||||
|
||||
x2 = tensor_model_parallel_all_reduce(z2)
|
||||
y2, resid = self.norm[1](x2, resid)
|
||||
|
||||
yq2, y_scale2 = scaled_fp4_quant(y2, self.agscale[1])
|
||||
z3 = cutlass_scaled_fp4_mm(
|
||||
yq2, self.wq[1], y_scale2, self.wscale[1], self.alpha[1], out_dtype=y2.dtype
|
||||
)
|
||||
|
||||
x3 = tensor_model_parallel_all_reduce(z3)
|
||||
y3, resid = self.norm[2](x3, resid) # use resid here
|
||||
|
||||
yq3, y_scale3 = scaled_fp4_quant(y3, self.agscale[2])
|
||||
z4 = cutlass_scaled_fp4_mm(
|
||||
yq3, self.wq[2], y_scale3, self.wscale[2], self.alpha[2], out_dtype=y3.dtype
|
||||
)
|
||||
x4 = tensor_model_parallel_all_reduce(z4)
|
||||
y4, resid = self.norm[3](x4, resid) # use resid here
|
||||
return y4
|
||||
|
||||
def ops_in_model_after(self):
|
||||
return [torch.ops.vllm.flashinfer_trtllm_fused_allreduce_norm.default]
|
||||
@ -141,19 +194,19 @@ class TestAllReduceFusedAddRMSNormStaticQuantFP4Model(torch.nn.Module):
|
||||
|
||||
@multi_gpu_test(num_gpus=2)
|
||||
@pytest.mark.parametrize(
|
||||
"test_model",
|
||||
"test_model, enable_quant_fp8_custom_op",
|
||||
[
|
||||
TestAllReduceRMSNormModel,
|
||||
TestAllReduceFusedAddRMSNormModel,
|
||||
TestAllReduceFusedAddRMSNormStaticQuantFP8Model,
|
||||
# TODO: Enable with torch==2.8.0
|
||||
# TestAllReduceFusedAddRMSNormStaticQuantFP4Model,
|
||||
(TestAllReduceRMSNormModel, False),
|
||||
(TestAllReduceRMSNormStaticQuantFP8Model, True),
|
||||
(TestAllReduceRMSNormStaticQuantFP8Model, False),
|
||||
(TestAllReduceFusedAddRMSNormStaticQuantFP4Model, False),
|
||||
],
|
||||
)
|
||||
@pytest.mark.parametrize("batch_size", [8])
|
||||
@pytest.mark.parametrize("seq_len", [8])
|
||||
@pytest.mark.parametrize("hidden_size", [16])
|
||||
@pytest.mark.parametrize("hidden_size", [64])
|
||||
@pytest.mark.parametrize("dtype", [torch.bfloat16])
|
||||
@pytest.mark.parametrize("enable_rms_norm_custom_op", [True, False])
|
||||
@pytest.mark.skipif(envs.VLLM_TARGET_DEVICE not in ["cuda"], reason="Only test on CUDA")
|
||||
@pytest.mark.skipif(
|
||||
not find_spec("flashinfer")
|
||||
@ -167,6 +220,8 @@ def test_all_reduce_fusion_pass_replace(
|
||||
seq_len: int,
|
||||
hidden_size: int,
|
||||
dtype: torch.dtype,
|
||||
enable_rms_norm_custom_op,
|
||||
enable_quant_fp8_custom_op,
|
||||
):
|
||||
num_processes = 2
|
||||
if (
|
||||
@ -181,7 +236,16 @@ def test_all_reduce_fusion_pass_replace(
|
||||
def run_torch_spawn(fn, nprocs):
|
||||
torch.multiprocessing.spawn(
|
||||
fn,
|
||||
args=(num_processes, test_model, batch_size, seq_len, hidden_size, dtype),
|
||||
args=(
|
||||
num_processes,
|
||||
test_model,
|
||||
batch_size,
|
||||
seq_len,
|
||||
hidden_size,
|
||||
dtype,
|
||||
enable_rms_norm_custom_op,
|
||||
enable_quant_fp8_custom_op,
|
||||
),
|
||||
nprocs=nprocs,
|
||||
)
|
||||
|
||||
@ -196,6 +260,8 @@ def all_reduce_fusion_pass_on_test_model(
|
||||
seq_len: int,
|
||||
hidden_size: int,
|
||||
dtype: torch.dtype,
|
||||
enable_rms_norm_custom_op,
|
||||
enable_quant_fp8_custom_op,
|
||||
):
|
||||
current_platform.seed_everything(0)
|
||||
|
||||
@ -217,15 +283,22 @@ def all_reduce_fusion_pass_on_test_model(
|
||||
init_distributed_environment()
|
||||
initialize_model_parallel(tensor_model_parallel_size=world_size)
|
||||
|
||||
custom_ops = []
|
||||
if enable_rms_norm_custom_op:
|
||||
custom_ops.append("+rms_norm")
|
||||
if enable_quant_fp8_custom_op:
|
||||
custom_ops.append("+quant_fp8")
|
||||
|
||||
vllm_config = VllmConfig(
|
||||
compilation_config=CompilationConfig(
|
||||
mode=CompilationMode.VLLM_COMPILE, custom_ops=["+rms_norm", "+quant_fp8"]
|
||||
mode=CompilationMode.VLLM_COMPILE, custom_ops=custom_ops
|
||||
)
|
||||
)
|
||||
vllm_config.compilation_config.pass_config = PassConfig(
|
||||
enable_fi_allreduce_fusion=True, enable_noop=True
|
||||
)
|
||||
vllm_config.device_config = DeviceConfig(device=torch.device("cuda"))
|
||||
vllm_config.parallel_config.rank = local_rank # Setup rank for debug path
|
||||
|
||||
# this is a fake model name to construct the model config
|
||||
# in the vllm_config, it's not really used.
|
||||
@ -233,24 +306,27 @@ def all_reduce_fusion_pass_on_test_model(
|
||||
vllm_config.model_config = ModelConfig(
|
||||
model=model_name, trust_remote_code=True, dtype=dtype, seed=42
|
||||
)
|
||||
with set_current_vllm_config(vllm_config):
|
||||
all_reduce_fusion_pass = AllReduceFusionPass(vllm_config)
|
||||
noop_pass = NoOpEliminationPass(vllm_config)
|
||||
func_pass = FixFunctionalizationPass(vllm_config)
|
||||
cleanup_pass = PostCleanupPass(vllm_config)
|
||||
|
||||
all_reduce_fusion_pass = AllReduceFusionPass(vllm_config)
|
||||
noop_pass = NoOpEliminationPass(vllm_config)
|
||||
func_pass = FixFunctionalizationPass(vllm_config)
|
||||
cleanup_pass = PostCleanupPass(vllm_config)
|
||||
backend = TestBackend(
|
||||
noop_pass, all_reduce_fusion_pass, func_pass, cleanup_pass
|
||||
)
|
||||
|
||||
backend = TestBackend(all_reduce_fusion_pass, noop_pass, func_pass, cleanup_pass)
|
||||
token_num = batch_size * seq_len
|
||||
model = test_model_cls(hidden_size, token_num)
|
||||
|
||||
token_num = batch_size * seq_len
|
||||
model = test_model_cls(hidden_size, token_num)
|
||||
hidden_states = torch.randn((token_num, hidden_size), requires_grad=False)
|
||||
|
||||
hidden_states = torch.randn((token_num, hidden_size), requires_grad=False)
|
||||
residual = torch.randn((token_num, hidden_size), requires_grad=False)
|
||||
compiled_model = torch.compile(model, backend=backend)
|
||||
compiled_model(hidden_states)
|
||||
|
||||
compiled_model = torch.compile(model, backend=backend)
|
||||
compiled_model(hidden_states, residual)
|
||||
|
||||
assert all_reduce_fusion_pass.matched_count == 1
|
||||
backend.check_before_ops(model.ops_in_model_before(), fully_replaced=False)
|
||||
backend.check_after_ops(model.ops_in_model_after())
|
||||
del all_reduce_fusion_pass
|
||||
assert all_reduce_fusion_pass.matched_count == 4, (
|
||||
f"{all_reduce_fusion_pass.matched_count=}"
|
||||
)
|
||||
backend.check_before_ops(model.ops_in_model_before(), fully_replaced=False)
|
||||
backend.check_after_ops(model.ops_in_model_after())
|
||||
del all_reduce_fusion_pass
|
||||
|
||||
@ -6,14 +6,15 @@ import pytest
|
||||
import torch._dynamo
|
||||
|
||||
from tests.compile.backend import LazyInitPass, TestBackend
|
||||
from tests.utils import flat_product
|
||||
from tests.v1.attention.utils import BatchSpec, create_common_attn_metadata
|
||||
from vllm._custom_ops import cutlass_scaled_fp4_mm, scaled_fp4_quant
|
||||
from vllm.attention import Attention, AttentionMetadata
|
||||
from vllm.attention.backends.registry import _Backend
|
||||
from vllm.attention.selector import global_force_attn_backend_context_manager
|
||||
from vllm.compilation.fusion import QUANT_OPS
|
||||
from vllm.compilation.fusion_attn import ATTN_OP, AttnFusionPass
|
||||
from vllm.compilation.fx_utils import find_op_nodes
|
||||
from vllm.compilation.matcher_utils import QUANT_OPS
|
||||
from vllm.compilation.noop_elimination import NoOpEliminationPass
|
||||
from vllm.compilation.post_cleanup import PostCleanupPass
|
||||
from vllm.config import (
|
||||
@ -28,21 +29,18 @@ from vllm.config import (
|
||||
)
|
||||
from vllm.forward_context import get_forward_context, set_forward_context
|
||||
from vllm.model_executor.layers.quantization.utils.quant_utils import (
|
||||
QuantKey,
|
||||
kFp8StaticTensorSym,
|
||||
kNvfp4Quant,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.utils.w8a8_utils import Fp8LinearOp
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.utils import is_torch_equal_or_newer
|
||||
from vllm.utils.flashinfer import has_flashinfer
|
||||
from vllm.v1.kv_cache_interface import AttentionSpec
|
||||
|
||||
FP8_DTYPE = current_platform.fp8_dtype()
|
||||
FP4_DTYPE = torch.uint8
|
||||
|
||||
# globals needed for string-import custom Dynamo backend field
|
||||
backend: TestBackend | None = None
|
||||
backend_unfused: TestBackend | None = None
|
||||
|
||||
|
||||
class AttentionQuantPatternModel(torch.nn.Module):
|
||||
"""Base model for AttentionQuantPattern fusion."""
|
||||
@ -104,6 +102,7 @@ class AttentionQuantPatternModel(torch.nn.Module):
|
||||
num_blocks = batch_size * max_blocks
|
||||
backend = self.attn.backend
|
||||
|
||||
# TODO(luka) use get_kv_cache_stride_order
|
||||
# Create dummy KV cache for the selected backend
|
||||
if backend == _Backend.ROCM_ATTN:
|
||||
# k/v as 1st dimention
|
||||
@ -241,26 +240,40 @@ class TestAttentionNvfp4QuantPatternModel(AttentionQuantPatternModel):
|
||||
)
|
||||
|
||||
|
||||
MODELS_FP8: list[tuple[str, type]] = []
|
||||
MODELS_FP4: list[tuple[str, type]] = []
|
||||
HEADS: list[tuple[int, int]] = []
|
||||
SPLIT_ATTENTION: list[bool] = []
|
||||
BACKENDS_FP8: list[_Backend] = []
|
||||
BACKENDS_FP4: list[_Backend] = []
|
||||
|
||||
if current_platform.is_cuda():
|
||||
MODELS = [
|
||||
HEADS = [(64, 8), (40, 8)]
|
||||
MODELS_FP8 = [
|
||||
(
|
||||
"nvidia/Llama-4-Scout-17B-16E-Instruct-FP8",
|
||||
TestAttentionFp8StaticQuantPatternModel,
|
||||
),
|
||||
)
|
||||
]
|
||||
MODELS_FP4 = [
|
||||
(
|
||||
"nvidia/Llama-4-Scout-17B-16E-Instruct-FP4",
|
||||
TestAttentionNvfp4QuantPatternModel,
|
||||
),
|
||||
)
|
||||
]
|
||||
HEADS = [(64, 8), (40, 8)]
|
||||
BACKENDS_FP8 = [_Backend.TRITON_ATTN, _Backend.FLASHINFER]
|
||||
BACKENDS_FP4 = [_Backend.FLASHINFER]
|
||||
|
||||
elif current_platform.is_rocm():
|
||||
MODELS = [
|
||||
HEADS = [(32, 8), (40, 8)]
|
||||
MODELS_FP8 = [
|
||||
("amd/Llama-3.1-8B-Instruct-FP8-KV", TestAttentionFp8StaticQuantPatternModel)
|
||||
]
|
||||
HEADS = [(32, 8), (40, 8)]
|
||||
else:
|
||||
MODELS = []
|
||||
HEADS = []
|
||||
BACKENDS = [
|
||||
_Backend.ROCM_AITER_UNIFIED_ATTN,
|
||||
_Backend.ROCM_ATTN,
|
||||
_Backend.TRITON_ATTN,
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.parametrize("num_qo_heads, num_kv_heads", HEADS)
|
||||
@ -269,46 +282,36 @@ else:
|
||||
"batch_size", [7, 256, 533] if current_platform.is_cuda() else [8]
|
||||
)
|
||||
@pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float16])
|
||||
@pytest.mark.parametrize("model_name, model_class", MODELS)
|
||||
@pytest.mark.parametrize(
|
||||
"backend",
|
||||
[_Backend.FLASHINFER]
|
||||
if current_platform.is_cuda()
|
||||
else [_Backend.ROCM_AITER_UNIFIED_ATTN, _Backend.ROCM_ATTN, _Backend.TRITON_ATTN],
|
||||
)
|
||||
# TODO(boyuan): test inductor graph partition on rocm
|
||||
@pytest.mark.parametrize(
|
||||
"use_inductor_graph_partition",
|
||||
[False] if current_platform.is_rocm() else [False, True],
|
||||
"backend, model_name, model_class, custom_ops",
|
||||
# Test attention+quant_fp8 fusion with custom and torch impls of QuantFP8
|
||||
list(flat_product(BACKENDS_FP8, MODELS_FP8, ["+quant_fp8", "-quant_fp8"]))
|
||||
# quant_fp4 only has the custom impl
|
||||
+ list(flat_product(BACKENDS_FP4, MODELS_FP4, [""])),
|
||||
)
|
||||
@pytest.mark.skipif(
|
||||
not current_platform.is_cuda_alike(), reason="Only test ROCm or CUDA"
|
||||
)
|
||||
@pytest.mark.skipif(not current_platform.supports_fp8(), reason="Need FP8")
|
||||
@pytest.mark.skipif(
|
||||
current_platform.is_cuda() and not current_platform.is_device_capability((10, 0)),
|
||||
reason="On CUDA only test on SM100(Blackwell)",
|
||||
)
|
||||
@pytest.mark.skipif(
|
||||
not current_platform.is_cuda_alike(), reason="Only test ROCm or CUDA"
|
||||
)
|
||||
def test_attention_quant_pattern(
|
||||
num_qo_heads: int,
|
||||
num_kv_heads: int,
|
||||
head_size: int,
|
||||
batch_size: int,
|
||||
dtype: torch.dtype,
|
||||
custom_ops: str,
|
||||
model_name: str,
|
||||
model_class: type[AttentionQuantPatternModel],
|
||||
backend: _Backend,
|
||||
use_inductor_graph_partition: bool,
|
||||
dist_init,
|
||||
caplog_vllm,
|
||||
):
|
||||
"""Test AttentionStaticQuantPattern fusion pass"""
|
||||
if backend == _Backend.FLASHINFER and (
|
||||
not current_platform.is_device_capability((10, 0)) or not has_flashinfer()
|
||||
):
|
||||
pytest.skip("FlashInfer attn fusion requires Blackwell and flashinfer")
|
||||
|
||||
if use_inductor_graph_partition and not is_torch_equal_or_newer("2.9.0.dev"):
|
||||
pytest.skip("inductor graph partition is only available in PyTorch 2.9+")
|
||||
custom_ops_list = custom_ops.split(",") if custom_ops else []
|
||||
|
||||
device = torch.device("cuda:0")
|
||||
torch.manual_seed(42)
|
||||
@ -322,8 +325,7 @@ def test_attention_quant_pattern(
|
||||
scheduler_config=SchedulerConfig(max_num_seqs=1024),
|
||||
compilation_config=CompilationConfig(
|
||||
mode=CompilationMode.VLLM_COMPILE,
|
||||
custom_ops=["+quant_fp8"],
|
||||
use_inductor_graph_partition=use_inductor_graph_partition,
|
||||
custom_ops=custom_ops_list,
|
||||
),
|
||||
cache_config=CacheConfig(cache_dtype="fp8"),
|
||||
)
|
||||
@ -358,8 +360,9 @@ def test_attention_quant_pattern(
|
||||
forward_ctx = get_forward_context()
|
||||
forward_ctx.attn_metadata = model_unfused.build_attn_metadata(batch_size)
|
||||
|
||||
# Run model directly without compilation and fusion
|
||||
result_unfused = model_unfused(q, k, v)
|
||||
# Run model directly without fusion
|
||||
# Still compile so query QuantFP8 has closer numerics
|
||||
result_unfused = torch.compile(model_unfused, fullgraph=True)(q, k, v)
|
||||
|
||||
# Run model with attn fusion enabled
|
||||
vllm_config.compilation_config.pass_config = PassConfig(
|
||||
@ -414,16 +417,25 @@ def test_attention_quant_pattern(
|
||||
)
|
||||
|
||||
# Check attn fusion support
|
||||
quant_key = model_class.quant_key
|
||||
quant_key: QuantKey = model_class.quant_key
|
||||
attn_fusion_supported = [
|
||||
layer.impl.fused_output_quant_supported(quant_key)
|
||||
for key, layer in vllm_config.compilation_config.static_forward_context.items()
|
||||
]
|
||||
if any(attn_fusion_supported):
|
||||
# Check quantization ops in the graph before and after fusion
|
||||
# Note: fully_replaced=False because query quant ops remain in graph.
|
||||
# Only output quant ops are fused into attention.
|
||||
test_backend.check_before_ops([QUANT_OPS[quant_key]], fully_replaced=False)
|
||||
assert sum(attn_fusion_supported) == len(attn_fusion_supported), (
|
||||
"All layers should support attention fusion"
|
||||
)
|
||||
|
||||
# Check quantization ops in the graph before and after fusion
|
||||
quant_op = (
|
||||
torch.ops.aten.reciprocal
|
||||
if "-quant_fp8" in custom_ops_list
|
||||
else QUANT_OPS[quant_key]
|
||||
)
|
||||
|
||||
# Note: for fp8, fully_replaced=False because query quant ops remain in graph.
|
||||
# Only output quant ops are fused into attention.
|
||||
test_backend.check_before_ops([quant_op], fully_replaced=quant_key is kNvfp4Quant)
|
||||
|
||||
# access the underlying `AttnFusionPass` on the `LazyInitPass`
|
||||
assert attn_pass.pass_.matched_count == sum(attn_fusion_supported)
|
||||
|
||||
305
tests/compile/test_fusions_e2e.py
Normal file
305
tests/compile/test_fusions_e2e.py
Normal file
@ -0,0 +1,305 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import itertools
|
||||
import logging
|
||||
from collections.abc import Iterable
|
||||
from typing import Any, NamedTuple
|
||||
|
||||
import pytest
|
||||
import regex as re
|
||||
|
||||
from tests.v1.attention.utils import _Backend
|
||||
from vllm import LLM, SamplingParams
|
||||
from vllm.config import CompilationConfig, CompilationMode, CUDAGraphMode, PassConfig
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.utils import is_torch_equal_or_newer
|
||||
from vllm.utils.flashinfer import has_flashinfer
|
||||
|
||||
from ..utils import flat_product, multi_gpu_test
|
||||
|
||||
|
||||
class ModelBackendTestCase(NamedTuple):
|
||||
model_name: str
|
||||
model_kwargs: dict[str, Any]
|
||||
backend: _Backend
|
||||
attention_fusions: int
|
||||
allreduce_fusions: int | None = None
|
||||
|
||||
|
||||
MODELS_FP8: list[ModelBackendTestCase] = []
|
||||
MODELS_FP4: list[ModelBackendTestCase] = []
|
||||
MODELS: list[ModelBackendTestCase] = [] # tp-only
|
||||
|
||||
if current_platform.is_cuda():
|
||||
MODELS_FP8 = [
|
||||
ModelBackendTestCase(
|
||||
# Use smaller model for L40s in CI
|
||||
model_name="RedHatAI/Meta-Llama-3.1-8B-Instruct-FP8",
|
||||
model_kwargs=dict(max_model_len=1024),
|
||||
backend=_Backend.TRITON_ATTN,
|
||||
attention_fusions=32,
|
||||
allreduce_fusions=65,
|
||||
),
|
||||
ModelBackendTestCase(
|
||||
model_name="nvidia/Llama-4-Scout-17B-16E-Instruct-FP8",
|
||||
model_kwargs=dict(max_model_len=1024, kv_cache_dtype="fp8"),
|
||||
backend=_Backend.FLASHINFER,
|
||||
attention_fusions=48,
|
||||
allreduce_fusions=96,
|
||||
),
|
||||
]
|
||||
|
||||
MODELS_FP4 = [
|
||||
ModelBackendTestCase(
|
||||
model_name="nvidia/Llama-4-Scout-17B-16E-Instruct-FP4",
|
||||
model_kwargs=dict(max_model_len=1024, kv_cache_dtype="fp8"),
|
||||
backend=_Backend.FLASHINFER,
|
||||
attention_fusions=48,
|
||||
allreduce_fusions=96,
|
||||
),
|
||||
]
|
||||
|
||||
# TP only
|
||||
MODELS = [
|
||||
ModelBackendTestCase(
|
||||
model_name="meta-llama/Llama-3.1-8B-Instruct",
|
||||
model_kwargs=dict(max_model_len=1024),
|
||||
backend=_Backend.TRITON_ATTN,
|
||||
attention_fusions=0,
|
||||
allreduce_fusions=65,
|
||||
),
|
||||
]
|
||||
|
||||
elif current_platform.is_rocm():
|
||||
MODELS_FP8 = [
|
||||
ModelBackendTestCase(
|
||||
model_name="amd/Llama-3.1-8B-Instruct-FP8-KV",
|
||||
model_kwargs=dict(max_model_len=1024),
|
||||
backend=_Backend.TRITON_ATTN,
|
||||
attention_fusions=32,
|
||||
),
|
||||
ModelBackendTestCase(
|
||||
model_name="amd/Llama-3.1-8B-Instruct-FP8-KV",
|
||||
model_kwargs=dict(max_model_len=1024),
|
||||
backend=_Backend.ROCM_ATTN,
|
||||
attention_fusions=32,
|
||||
),
|
||||
ModelBackendTestCase(
|
||||
model_name="amd/Llama-3.1-8B-Instruct-FP8-KV",
|
||||
model_kwargs=dict(max_model_len=1024),
|
||||
backend=_Backend.ROCM_AITER_UNIFIED_ATTN,
|
||||
attention_fusions=32,
|
||||
),
|
||||
]
|
||||
|
||||
# TODO(luka) test both in nightly
|
||||
CUSTOM_OPS_FP8 = ["-quant_fp8"] # , "+quant_fp8"]
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"model_name, model_kwargs, backend, "
|
||||
"attention_fusions, allreduce_fusions, custom_ops",
|
||||
# Test attention+quant_fp8 fusion with custom and torch impls of QuantFP8
|
||||
list(flat_product(MODELS_FP8, CUSTOM_OPS_FP8))
|
||||
# quant_fp4 only has the custom impl
|
||||
+ list(flat_product(MODELS_FP4, [""])),
|
||||
)
|
||||
@pytest.mark.parametrize("inductor_graph_partition", [True, False])
|
||||
def test_attn_quant(
|
||||
model_name: str,
|
||||
model_kwargs: dict[str, Any],
|
||||
backend: _Backend,
|
||||
attention_fusions: int,
|
||||
allreduce_fusions: int,
|
||||
custom_ops: str,
|
||||
inductor_graph_partition: bool,
|
||||
caplog_mp_spawn,
|
||||
monkeypatch,
|
||||
):
|
||||
if backend == _Backend.FLASHINFER and (
|
||||
not current_platform.is_device_capability((10, 0)) or not has_flashinfer()
|
||||
):
|
||||
pytest.skip("FlashInfer attn fusion requires Blackwell and flashinfer")
|
||||
if inductor_graph_partition and not is_torch_equal_or_newer("2.9.0.dev"):
|
||||
pytest.skip("Inductor graph partition requires torch>=2.9")
|
||||
|
||||
custom_ops_list = custom_ops.split(",") if custom_ops else []
|
||||
|
||||
if inductor_graph_partition:
|
||||
mode = CUDAGraphMode.FULL_AND_PIECEWISE
|
||||
splitting_ops: list[str] | None = None
|
||||
else:
|
||||
mode = CUDAGraphMode.FULL_DECODE_ONLY
|
||||
splitting_ops = []
|
||||
|
||||
# Disable, compile cache to make sure custom passes run.
|
||||
# Otherwise, we can't verify fusion happened through the logs.
|
||||
monkeypatch.setenv("VLLM_DISABLE_COMPILE_CACHE", "1")
|
||||
|
||||
# To capture subprocess logs, we need to know whether spawn or fork is used.
|
||||
# Force spawn as it is more general.
|
||||
monkeypatch.setenv("VLLM_WORKER_MULTIPROC_METHOD", "spawn")
|
||||
monkeypatch.setenv("VLLM_ATTENTION_BACKEND", backend.name)
|
||||
|
||||
compilation_config = CompilationConfig(
|
||||
# Testing properties
|
||||
custom_ops=custom_ops_list,
|
||||
use_inductor_graph_partition=inductor_graph_partition,
|
||||
cudagraph_mode=mode,
|
||||
splitting_ops=splitting_ops,
|
||||
# Common
|
||||
level=CompilationMode.VLLM_COMPILE,
|
||||
pass_config=PassConfig(enable_attn_fusion=True, enable_noop=True),
|
||||
# Inductor caches custom passes by default as well via uuid
|
||||
inductor_compile_config={"force_disable_caches": True},
|
||||
)
|
||||
|
||||
with caplog_mp_spawn(logging.DEBUG) as log_holder:
|
||||
run_model(compilation_config, model_name, **model_kwargs)
|
||||
|
||||
matches = re.findall(
|
||||
r"fusion_attn.py:\d+] Fused quant onto (\d+) attention nodes",
|
||||
log_holder.text,
|
||||
)
|
||||
assert len(matches) == 1, log_holder.text
|
||||
assert int(matches[0]) == attention_fusions
|
||||
|
||||
|
||||
# TODO(luka) test both in nightly
|
||||
CUSTOM_OPS_RMS_NORM = ["-rms_norm"] # , "+rms_norm"]
|
||||
|
||||
|
||||
def custom_ops_product(*custom_ops_lists: list[str]) -> Iterable[str]:
|
||||
for op_list in itertools.product(*custom_ops_lists):
|
||||
yield ",".join(op_list)
|
||||
|
||||
|
||||
@multi_gpu_test(num_gpus=2)
|
||||
@pytest.mark.parametrize(
|
||||
"model_name, model_kwargs, backend, "
|
||||
"attention_fusions, allreduce_fusions, custom_ops",
|
||||
# Toggle RMSNorm and QuantFP8 for FP8 models
|
||||
list(
|
||||
flat_product(
|
||||
MODELS_FP8, custom_ops_product(CUSTOM_OPS_FP8, CUSTOM_OPS_RMS_NORM)
|
||||
)
|
||||
)
|
||||
# Toggle RMSNorm for FP4 models and unquant models
|
||||
+ list(flat_product(MODELS_FP4 + MODELS, CUSTOM_OPS_RMS_NORM)),
|
||||
)
|
||||
@pytest.mark.parametrize("inductor_graph_partition", [True, False])
|
||||
@pytest.mark.skipif(
|
||||
not current_platform.is_cuda()
|
||||
or not has_flashinfer()
|
||||
or not current_platform.has_device_capability(90),
|
||||
reason="allreduce+rmsnorm fusion requires flashinfer",
|
||||
)
|
||||
def test_tp2_attn_quant_allreduce_rmsnorm(
|
||||
model_name: str,
|
||||
model_kwargs: dict,
|
||||
backend: _Backend,
|
||||
attention_fusions: int,
|
||||
allreduce_fusions: int,
|
||||
custom_ops: str,
|
||||
inductor_graph_partition: bool,
|
||||
caplog_mp_spawn,
|
||||
monkeypatch,
|
||||
):
|
||||
if inductor_graph_partition and not is_torch_equal_or_newer("2.9.0.dev"):
|
||||
pytest.skip("Inductor graph partition requires torch>=2.9")
|
||||
|
||||
custom_ops_list = custom_ops.split(",") if custom_ops else []
|
||||
|
||||
if inductor_graph_partition:
|
||||
mode = CUDAGraphMode.FULL_AND_PIECEWISE
|
||||
splitting_ops: list[str] | None = None
|
||||
else:
|
||||
mode = CUDAGraphMode.FULL_DECODE_ONLY
|
||||
splitting_ops = []
|
||||
|
||||
# Disable, compile cache to make sure custom passes run.
|
||||
# Otherwise, we can't verify fusion happened through the logs.
|
||||
monkeypatch.setenv("VLLM_DISABLE_COMPILE_CACHE", "1")
|
||||
|
||||
# To capture subprocess logs, we need to know whether spawn or fork is used.
|
||||
# Force spawn as it is more general.
|
||||
monkeypatch.setenv("VLLM_WORKER_MULTIPROC_METHOD", "spawn")
|
||||
monkeypatch.setenv("VLLM_ATTENTION_BACKEND", backend.name)
|
||||
|
||||
compilation_config = CompilationConfig(
|
||||
# Testing properties
|
||||
use_inductor_graph_partition=inductor_graph_partition,
|
||||
cudagraph_mode=mode,
|
||||
custom_ops=custom_ops_list,
|
||||
splitting_ops=splitting_ops,
|
||||
# Common
|
||||
level=CompilationMode.VLLM_COMPILE,
|
||||
pass_config=PassConfig(
|
||||
enable_attn_fusion=True,
|
||||
enable_noop=True,
|
||||
enable_fi_allreduce_fusion=True,
|
||||
),
|
||||
# Inductor caches custom passes by default as well via uuid
|
||||
inductor_compile_config={"force_disable_caches": True},
|
||||
)
|
||||
|
||||
with caplog_mp_spawn(logging.DEBUG) as log_holder:
|
||||
run_model(
|
||||
compilation_config, model_name, tensor_parallel_size=2, **model_kwargs
|
||||
)
|
||||
matches = re.findall(
|
||||
r"fusion_attn.py:\d+] Fused quant onto (\d+) attention nodes",
|
||||
log_holder.text,
|
||||
)
|
||||
assert len(matches) == 2, log_holder.text
|
||||
|
||||
assert int(matches[0]) == attention_fusions
|
||||
assert int(matches[1]) == attention_fusions
|
||||
|
||||
matches = re.findall(
|
||||
r"collective_fusion.py:\d+] Replaced (\d+) patterns",
|
||||
log_holder.text,
|
||||
)
|
||||
assert len(matches) == 2, log_holder.text
|
||||
|
||||
assert int(matches[0]) == allreduce_fusions
|
||||
assert int(matches[1]) == allreduce_fusions
|
||||
|
||||
|
||||
def run_model(compile_config: int | CompilationConfig, model: str, **model_kwargs):
|
||||
compilation_config = (
|
||||
compile_config
|
||||
if isinstance(compile_config, CompilationConfig)
|
||||
else CompilationConfig(level=compile_config)
|
||||
)
|
||||
|
||||
prompts = [
|
||||
"Hello, my name is",
|
||||
"The president of the United States is",
|
||||
"The capital of France is",
|
||||
"The future of AI is",
|
||||
]
|
||||
sampling_params = SamplingParams(temperature=0)
|
||||
# Allow override from model_kwargs
|
||||
model_kwargs = {"tensor_parallel_size": 1, **model_kwargs}
|
||||
model_kwargs = {"disable_custom_all_reduce": True, **model_kwargs}
|
||||
|
||||
# No cudagraphs by default
|
||||
if compilation_config.cudagraph_mode is None:
|
||||
compilation_config.cudagraph_mode = CUDAGraphMode.NONE
|
||||
|
||||
llm = LLM(
|
||||
model=model,
|
||||
compilation_config=compilation_config,
|
||||
**model_kwargs,
|
||||
)
|
||||
outputs = llm.generate(prompts, sampling_params)
|
||||
|
||||
# Print the outputs.
|
||||
for output in outputs:
|
||||
prompt = output.prompt
|
||||
generated_text = output.outputs[0].text
|
||||
print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
|
||||
@ -7,7 +7,7 @@ import torch
|
||||
|
||||
from vllm.compilation.inductor_pass import CallableInductorPass, InductorPass
|
||||
from vllm.compilation.pass_manager import PostGradPassManager
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.config import ModelConfig, VllmConfig
|
||||
|
||||
|
||||
# dummy custom pass that doesn't inherit
|
||||
@ -42,7 +42,8 @@ class ProperPass(InductorPass):
|
||||
],
|
||||
)
|
||||
def test_pass_manager_uuid(callable):
|
||||
config = VllmConfig()
|
||||
# Some passes need dtype to be set
|
||||
config = VllmConfig(model_config=ModelConfig(dtype=torch.bfloat16))
|
||||
|
||||
pass_manager = PostGradPassManager()
|
||||
pass_manager.configure(config)
|
||||
|
||||
@ -18,6 +18,8 @@ from vllm.config import (
|
||||
ModelConfig,
|
||||
PassConfig,
|
||||
VllmConfig,
|
||||
get_current_vllm_config,
|
||||
set_current_vllm_config,
|
||||
)
|
||||
from vllm.distributed import tensor_model_parallel_all_reduce
|
||||
from vllm.distributed.parallel_state import (
|
||||
@ -42,9 +44,7 @@ prompts = [
|
||||
|
||||
|
||||
class TestModel(torch.nn.Module):
|
||||
def __init__(
|
||||
self, hidden_size=16, intermediate_size=32, vllm_config: VllmConfig = None
|
||||
):
|
||||
def __init__(self, hidden_size=16, intermediate_size=32):
|
||||
super().__init__()
|
||||
self.hidden_size = hidden_size
|
||||
self.intermediate_size = intermediate_size
|
||||
@ -95,13 +95,11 @@ class TestModel(torch.nn.Module):
|
||||
|
||||
|
||||
class TestQuantModel(torch.nn.Module):
|
||||
def __init__(
|
||||
self, hidden_size=16, intermediate_size=32, vllm_config: VllmConfig = None
|
||||
):
|
||||
def __init__(self, hidden_size=16, intermediate_size=32):
|
||||
super().__init__()
|
||||
self.hidden_size = hidden_size
|
||||
self.intermediate_size = intermediate_size
|
||||
self.vllm_config = vllm_config
|
||||
self.vllm_config = get_current_vllm_config()
|
||||
self.gate_proj = torch.nn.Parameter(
|
||||
torch.empty((intermediate_size, hidden_size)), requires_grad=False
|
||||
)
|
||||
@ -266,76 +264,84 @@ def sequence_parallelism_pass_on_test_model(
|
||||
initialize_model_parallel(tensor_model_parallel_size=world_size)
|
||||
|
||||
# configure vllm config for SequenceParallelismPass
|
||||
vllm_config = VllmConfig()
|
||||
vllm_config.compilation_config = CompilationConfig(
|
||||
compilation_config = CompilationConfig(
|
||||
pass_config=PassConfig(
|
||||
enable_sequence_parallelism=True,
|
||||
enable_fusion=enable_fusion,
|
||||
enable_noop=True,
|
||||
)
|
||||
) # NoOp needed for fusion
|
||||
vllm_config.device_config = DeviceConfig(device=torch.device("cuda"))
|
||||
device_config = DeviceConfig(device=torch.device("cuda"))
|
||||
|
||||
# this is a fake model name to construct the model config
|
||||
# in the vllm_config, it's not really used.
|
||||
model_name = "RedHatAI/Llama-3.2-1B-Instruct-FP8"
|
||||
vllm_config.model_config = ModelConfig(
|
||||
model_config = ModelConfig(
|
||||
model=model_name, trust_remote_code=True, dtype=dtype, seed=42
|
||||
)
|
||||
|
||||
noop_pass = NoOpEliminationPass(vllm_config)
|
||||
sequence_parallelism_pass = SequenceParallelismPass(vllm_config)
|
||||
assert (
|
||||
sequence_parallelism_pass.compilation_config.splitting_ops
|
||||
== vllm_config.compilation_config.splitting_ops
|
||||
vllm_config = VllmConfig(
|
||||
model_config=model_config,
|
||||
device_config=device_config,
|
||||
compilation_config=compilation_config,
|
||||
)
|
||||
assert (
|
||||
sequence_parallelism_pass.compilation_config.use_inductor_graph_partition
|
||||
== vllm_config.compilation_config.use_inductor_graph_partition
|
||||
)
|
||||
func_pass = FixFunctionalizationPass(vllm_config)
|
||||
cleanup_pass = PostCleanupPass(vllm_config)
|
||||
|
||||
passes_for_backend: list[VllmInductorPass] = [noop_pass, sequence_parallelism_pass]
|
||||
with set_current_vllm_config(vllm_config):
|
||||
noop_pass = NoOpEliminationPass(vllm_config)
|
||||
sequence_parallelism_pass = SequenceParallelismPass(vllm_config)
|
||||
func_pass = FixFunctionalizationPass(vllm_config)
|
||||
cleanup_pass = PostCleanupPass(vllm_config)
|
||||
assert (
|
||||
sequence_parallelism_pass.compilation_config.splitting_ops
|
||||
== vllm_config.compilation_config.splitting_ops
|
||||
)
|
||||
assert (
|
||||
sequence_parallelism_pass.compilation_config.use_inductor_graph_partition
|
||||
== vllm_config.compilation_config.use_inductor_graph_partition
|
||||
)
|
||||
passes_for_backend: list[VllmInductorPass] = [
|
||||
noop_pass,
|
||||
sequence_parallelism_pass,
|
||||
]
|
||||
|
||||
if enable_fusion:
|
||||
fusion_pass = RMSNormQuantFusionPass(vllm_config)
|
||||
passes_for_backend.append(fusion_pass)
|
||||
if enable_fusion:
|
||||
fusion_pass = RMSNormQuantFusionPass(vllm_config)
|
||||
passes_for_backend.append(fusion_pass)
|
||||
|
||||
passes_for_backend.append(cleanup_pass)
|
||||
passes_for_backend.append(cleanup_pass)
|
||||
|
||||
backend_no_func = TestBackend(*passes_for_backend)
|
||||
backend_func = TestBackend(*passes_for_backend, func_pass)
|
||||
backend_no_func = TestBackend(*passes_for_backend)
|
||||
backend_func = TestBackend(*passes_for_backend, func_pass)
|
||||
|
||||
model = test_model_cls(hidden_size, hidden_size * 2, vllm_config=vllm_config)
|
||||
model = test_model_cls(hidden_size, hidden_size * 2)
|
||||
|
||||
hidden_states = torch.randn((batch_size * seq_len, hidden_size), dtype=dtype)
|
||||
residual = torch.randn((batch_size * seq_len, hidden_size), dtype=dtype)
|
||||
hidden_states = torch.randn((batch_size * seq_len, hidden_size), dtype=dtype)
|
||||
residual = torch.randn((batch_size * seq_len, hidden_size), dtype=dtype)
|
||||
|
||||
compiled_model_no_func = torch.compile(model, backend=backend_no_func)
|
||||
compiled_model_no_func(hidden_states, residual)
|
||||
compiled_model_func = torch.compile(model, backend=backend_func)
|
||||
compiled_model_func(hidden_states, residual)
|
||||
compiled_model_no_func = torch.compile(model, backend=backend_no_func)
|
||||
compiled_model_no_func(hidden_states, residual)
|
||||
compiled_model_func = torch.compile(model, backend=backend_func)
|
||||
compiled_model_func(hidden_states, residual)
|
||||
|
||||
assert sequence_parallelism_pass.matched_count == 1
|
||||
assert sequence_parallelism_pass.matched_count == 1
|
||||
|
||||
# In pre-nodes, all reduce should be there,
|
||||
# reduce scatter and all gather should not
|
||||
backend_no_func.check_before_ops(model.ops_in_model_before())
|
||||
# In pre-nodes, all reduce should be there,
|
||||
# reduce scatter and all gather should not
|
||||
backend_no_func.check_before_ops(model.ops_in_model_before())
|
||||
|
||||
# In post-nodes, reduce scatter and all gather should be there,
|
||||
# all reduce should not
|
||||
backend_no_func.check_after_ops(model.ops_in_model_after())
|
||||
# In post-nodes, reduce scatter and all gather should be there,
|
||||
# all reduce should not
|
||||
backend_no_func.check_after_ops(model.ops_in_model_after())
|
||||
|
||||
# check if the functionalization pass is applied
|
||||
for op in model.ops_in_model():
|
||||
find_auto_fn(backend_no_func.graph_post_pass.nodes, op)
|
||||
assert find_auto_fn_maybe(backend_func.graph_post_pass.nodes, op) is None
|
||||
|
||||
# make sure the ops were all de-functionalized
|
||||
found = dict()
|
||||
for node in backend_func.graph_post_pass.nodes:
|
||||
# check if the functionalization pass is applied
|
||||
for op in model.ops_in_model():
|
||||
if is_func(node, op):
|
||||
found[op] = True
|
||||
assert all(found[op] for op in model.ops_in_model())
|
||||
find_auto_fn(backend_no_func.graph_post_pass.nodes, op)
|
||||
assert find_auto_fn_maybe(backend_func.graph_post_pass.nodes, op) is None
|
||||
|
||||
# make sure the ops were all de-functionalized
|
||||
found = dict()
|
||||
for node in backend_func.graph_post_pass.nodes:
|
||||
for op in model.ops_in_model():
|
||||
if is_func(node, op):
|
||||
found[op] = True
|
||||
assert all(found[op] for op in model.ops_in_model())
|
||||
|
||||
@ -1,10 +1,13 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
# ruff: noqa
|
||||
import contextlib
|
||||
import pathlib
|
||||
from copy import deepcopy
|
||||
|
||||
from tblib import pickling_support
|
||||
|
||||
# ruff: noqa
|
||||
|
||||
# Install support for pickling exceptions so that we can nicely propagate
|
||||
# failures from tests running in a subprocess.
|
||||
# This should be run before any custom exception subclasses are defined.
|
||||
@ -40,7 +43,7 @@ from transformers import (
|
||||
from transformers.models.auto.auto_factory import _BaseAutoModelClass
|
||||
|
||||
from tests.models.utils import TokensTextLogprobs, TokensTextLogprobsPromptLogprobs
|
||||
from vllm import LLM, SamplingParams
|
||||
from vllm import LLM, SamplingParams, envs
|
||||
from vllm.assets.audio import AudioAsset
|
||||
from vllm.assets.image import ImageAsset
|
||||
from vllm.assets.video import VideoAsset
|
||||
@ -1070,6 +1073,101 @@ def caplog_vllm(temporary_enable_log_propagate, caplog):
|
||||
yield caplog
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def caplog_mp_fork():
|
||||
"""
|
||||
This fixture enables capturing logs from a forked MP subprocess.
|
||||
It should be used in conjunction with caplog_vllm.
|
||||
|
||||
By default, subprocess logs do not go through the parent process.
|
||||
We instead create a queue listener in the parent process which
|
||||
forwards logs to the logger's other handlers, and add a QueueHandler
|
||||
to the root logger. Forked subprocesses will inherit the root logger
|
||||
and pass their messages to the queue, which the listener will forward
|
||||
to the root logger, which can be captured by caplog.
|
||||
|
||||
Note that this workaround only works for fork; with spawn, the subprocess
|
||||
reinitializes logging and does not automatically inherit the queue.
|
||||
We'd have to manually pass the queue to the subprocess at the spawn point.
|
||||
See caplog_mp_spawn below.
|
||||
"""
|
||||
|
||||
@contextlib.contextmanager
|
||||
def ctx():
|
||||
import logging.handlers
|
||||
import multiprocessing as mp
|
||||
|
||||
logger_queue: mp.Queue[logging.LogRecord] = mp.Queue()
|
||||
logger = logging.getLogger()
|
||||
handlers = logger.handlers
|
||||
|
||||
# The listener works on a background thread, not inherited by the child.
|
||||
queue_listener = logging.handlers.QueueListener(logger_queue, *handlers)
|
||||
queue_listener.start()
|
||||
|
||||
# Add queue handler after creating the listener to avoid cycle
|
||||
logger.addHandler(logging.handlers.QueueHandler(logger_queue))
|
||||
yield
|
||||
queue_listener.stop()
|
||||
|
||||
return ctx
|
||||
|
||||
|
||||
class LogHolder:
|
||||
def __init__(self):
|
||||
self.text = None
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def caplog_mp_spawn(tmp_path, monkeypatch):
|
||||
"""
|
||||
This fixture enables capturing logs from a forked MP subprocess.
|
||||
It does not require caplog_vllm (but it only contains logs from the child).
|
||||
|
||||
By default, subprocess logs do not go through the parent process.
|
||||
We instead add a FileHandler to the config so the spawned child process
|
||||
writes its logs to a temp file.
|
||||
In the parent, we read the file and return the contents.
|
||||
|
||||
Note: this method could be extended to fork by either reconfiguring logging
|
||||
in the parent or using a SocketHandler:
|
||||
https://docs.python.org/3/howto/logging-cookbook.html#sending-and-receiving-logging-events-across-a-network # noqa: E501
|
||||
"""
|
||||
|
||||
@contextlib.contextmanager
|
||||
def ctx(level: int | str):
|
||||
from vllm.logger import DEFAULT_LOGGING_CONFIG
|
||||
|
||||
config_path = tmp_path / "vllm_logging_config.json"
|
||||
log_path = tmp_path / "vllm.log"
|
||||
log_holder = LogHolder()
|
||||
|
||||
config = deepcopy(DEFAULT_LOGGING_CONFIG)
|
||||
if envs.VLLM_LOGGING_CONFIG_PATH:
|
||||
path = pathlib.Path(envs.VLLM_LOGGING_CONFIG_PATH)
|
||||
assert path.exists()
|
||||
config = json.loads(path.read_text())
|
||||
|
||||
config["loggers"]["vllm"]["handlers"] += ["vllm_file"]
|
||||
config["handlers"]["vllm_file"] = {
|
||||
"class": "logging.FileHandler",
|
||||
"formatter": "vllm",
|
||||
"level": level,
|
||||
"filename": log_path.as_posix(),
|
||||
}
|
||||
|
||||
config_path.write_text(json.dumps(config))
|
||||
|
||||
with monkeypatch.context() as monkeypatch_ctx:
|
||||
monkeypatch_ctx.setenv("VLLM_LOGGING_CONFIG_PATH", config_path.as_posix())
|
||||
monkeypatch_ctx.setenv("VLLM_CONFIGURE_LOGGING", "1")
|
||||
yield log_holder
|
||||
|
||||
log_holder.text = log_path.read_text()
|
||||
|
||||
return ctx
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def num_gpus_available():
|
||||
"""Get number of GPUs without initializing the CUDA context
|
||||
|
||||
@ -6,7 +6,7 @@ import torch
|
||||
|
||||
from tests.kernels.quant_utils import FP8_DTYPE
|
||||
from tests.kernels.utils import opcheck
|
||||
from vllm.model_executor.layers.layernorm import PolyNorm, RMSNorm
|
||||
from vllm.model_executor.layers.layernorm import RMSNorm
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
DTYPES = [torch.half, torch.bfloat16, torch.float]
|
||||
@ -70,38 +70,6 @@ def test_rms_norm(
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("num_tokens", NUM_TOKENS)
|
||||
@pytest.mark.parametrize("hidden_size", HIDDEN_SIZES)
|
||||
@pytest.mark.parametrize("dtype", DTYPES)
|
||||
@pytest.mark.parametrize("seed", SEEDS)
|
||||
@pytest.mark.parametrize("device", CUDA_DEVICES)
|
||||
@torch.inference_mode()
|
||||
def test_poly_norm(
|
||||
num_tokens: int,
|
||||
hidden_size: int,
|
||||
dtype: torch.dtype,
|
||||
seed: int,
|
||||
device: str,
|
||||
) -> None:
|
||||
current_platform.seed_everything(seed)
|
||||
torch.set_default_device(device)
|
||||
layer = PolyNorm().to(dtype=dtype)
|
||||
layer.weight.data.normal_(mean=1.0, std=0.1)
|
||||
layer.bias.data.normal_(mean=1.0, std=0.1)
|
||||
scale = 1 / (2 * hidden_size)
|
||||
x = torch.randn(num_tokens, hidden_size, dtype=dtype)
|
||||
x *= scale
|
||||
|
||||
ref_out = layer.forward_native(x)
|
||||
out = layer(x)
|
||||
torch.testing.assert_close(out, ref_out, atol=1e-2, rtol=1e-2)
|
||||
|
||||
opcheck(
|
||||
torch.ops._C.poly_norm,
|
||||
(out, x, layer.weight.data, layer.bias.data, layer.variance_epsilon),
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("num_tokens", NUM_TOKENS)
|
||||
@pytest.mark.parametrize("hidden_size", HIDDEN_SIZES)
|
||||
@pytest.mark.parametrize("add_residual", ADD_RESIDUAL)
|
||||
|
||||
@ -103,7 +103,7 @@ def ref_dynamic_per_tensor_fp8_quant(
|
||||
.clamp(fp8_traits_min, fp8_traits_max)
|
||||
.to(FP8_DTYPE)
|
||||
)
|
||||
return ref_out, ref_scale.view((1,))
|
||||
return ref_out, ref_scale.view((1, 1))
|
||||
|
||||
|
||||
def native_w8a8_block_matmul(
|
||||
|
||||
@ -501,3 +501,49 @@ def test_streaming_complete_logs_full_text_content():
|
||||
assert call_args[1] == "test-streaming-full-text"
|
||||
assert call_args[2] == " (streaming complete)"
|
||||
assert call_args[5] == "streaming_complete"
|
||||
|
||||
|
||||
# Add vllm prefix to make sure logs go through the vllm logger
|
||||
test_logger = init_logger("vllm.test_logger")
|
||||
|
||||
|
||||
def mp_function(**kwargs):
|
||||
# This function runs in a subprocess
|
||||
|
||||
test_logger.warning("This is a subprocess: %s", kwargs.get("a"))
|
||||
test_logger.error("This is a subprocess error.")
|
||||
test_logger.debug("This is a subprocess debug message: %s.", kwargs.get("b"))
|
||||
|
||||
|
||||
def test_caplog_mp_fork(caplog_vllm, caplog_mp_fork):
|
||||
with caplog_vllm.at_level(logging.DEBUG), caplog_mp_fork():
|
||||
import multiprocessing
|
||||
|
||||
ctx = multiprocessing.get_context("fork")
|
||||
p = ctx.Process(
|
||||
target=mp_function,
|
||||
name=f"SubProcess{1}",
|
||||
kwargs={"a": "AAAA", "b": "BBBBB"},
|
||||
)
|
||||
p.start()
|
||||
p.join()
|
||||
|
||||
assert "AAAA" in caplog_vllm.text
|
||||
assert "BBBBB" in caplog_vllm.text
|
||||
|
||||
|
||||
def test_caplog_mp_spawn(caplog_mp_spawn):
|
||||
with caplog_mp_spawn(logging.DEBUG) as log_holder:
|
||||
import multiprocessing
|
||||
|
||||
ctx = multiprocessing.get_context("spawn")
|
||||
p = ctx.Process(
|
||||
target=mp_function,
|
||||
name=f"SubProcess{1}",
|
||||
kwargs={"a": "AAAA", "b": "BBBBB"},
|
||||
)
|
||||
p.start()
|
||||
p.join()
|
||||
|
||||
assert "AAAA" in log_holder.text
|
||||
assert "BBBBB" in log_holder.text
|
||||
|
||||
@ -6,6 +6,7 @@ import contextlib
|
||||
import copy
|
||||
import functools
|
||||
import importlib
|
||||
import itertools
|
||||
import json
|
||||
import os
|
||||
import random
|
||||
@ -15,7 +16,7 @@ import sys
|
||||
import tempfile
|
||||
import time
|
||||
import warnings
|
||||
from collections.abc import Callable
|
||||
from collections.abc import Callable, Iterable
|
||||
from contextlib import ExitStack, contextmanager, suppress
|
||||
from multiprocessing import Process
|
||||
from pathlib import Path
|
||||
@ -1261,3 +1262,23 @@ def check_answers(
|
||||
frac_ok = numok / len(answer)
|
||||
print(f"Num OK: {numok}/{len(answer)} {frac_ok}")
|
||||
assert frac_ok >= accept_rate
|
||||
|
||||
|
||||
def flat_product(*iterables: Iterable[Any]):
|
||||
"""
|
||||
Flatten lists of tuples of the cartesian product.
|
||||
Useful when we want to avoid nested tuples to allow
|
||||
test params to be unpacked directly from the decorator.
|
||||
|
||||
Example:
|
||||
flat_product([(1, 2), (3, 4)], ["a", "b"]) ->
|
||||
[
|
||||
(1, 2, "a"),
|
||||
(1, 2, "b"),
|
||||
(3, 4, "a"),
|
||||
(3, 4, "b"),
|
||||
]
|
||||
"""
|
||||
for element in itertools.product(*iterables):
|
||||
normalized = (e if isinstance(e, tuple) else (e,) for e in element)
|
||||
yield tuple(itertools.chain(*normalized))
|
||||
|
||||
@ -40,7 +40,7 @@ from vllm.utils import (
|
||||
unique_filepath,
|
||||
)
|
||||
|
||||
from ..utils import create_new_process_for_each_test
|
||||
from ..utils import create_new_process_for_each_test, flat_product
|
||||
|
||||
|
||||
def test_get_open_port(monkeypatch: pytest.MonkeyPatch):
|
||||
@ -771,3 +771,25 @@ def test_unique_filepath():
|
||||
paths.add(path)
|
||||
assert len(paths) == 10
|
||||
assert len(list(Path(temp_dir).glob("*.txt"))) == 10
|
||||
|
||||
|
||||
def test_flat_product():
|
||||
# Check regular itertools.product behavior
|
||||
result1 = list(flat_product([1, 2, 3], ["a", "b"]))
|
||||
assert result1 == [
|
||||
(1, "a"),
|
||||
(1, "b"),
|
||||
(2, "a"),
|
||||
(2, "b"),
|
||||
(3, "a"),
|
||||
(3, "b"),
|
||||
]
|
||||
|
||||
# check that the tuples get flattened
|
||||
result2 = list(flat_product([(1, 2), (3, 4)], ["a", "b"], [(5, 6)]))
|
||||
assert result2 == [
|
||||
(1, 2, "a", 5, 6),
|
||||
(1, 2, "b", 5, 6),
|
||||
(3, 4, "a", 5, 6),
|
||||
(3, 4, "b", 5, 6),
|
||||
]
|
||||
|
||||
@ -34,15 +34,21 @@ else
|
||||
fi
|
||||
|
||||
# Models to run
|
||||
MODELS=(
|
||||
"Qwen/Qwen3-0.6B"
|
||||
)
|
||||
MODEL_NAMES=${MODEL_NAMES:-}
|
||||
if [[ -n "$MODEL_NAMES" ]]; then
|
||||
MODELS=("$MODEL_NAMES")
|
||||
else
|
||||
MODELS=(
|
||||
"Qwen/Qwen3-0.6B"
|
||||
)
|
||||
fi
|
||||
|
||||
# Number of prefill and decode instances to create
|
||||
NUM_PREFILL_INSTANCES=${NUM_PREFILL_INSTANCES:-1} # Default to 1
|
||||
NUM_DECODE_INSTANCES=${NUM_DECODE_INSTANCES:-1} # Default to 1
|
||||
PREFILLER_TP_SIZE=${PREFILLER_TP_SIZE:-1}
|
||||
DECODER_TP_SIZE=${DECODER_TP_SIZE:-1}
|
||||
GPU_MEMORY_UTILIZATION=${GPU_MEMORY_UTILIZATION:-0.2}
|
||||
|
||||
# Find the git repository root directory
|
||||
GIT_ROOT=$(git rev-parse --show-toplevel)
|
||||
@ -130,7 +136,7 @@ run_tests_for_model() {
|
||||
vllm serve $model_name \
|
||||
--port $PORT \
|
||||
--enforce-eager \
|
||||
--gpu-memory-utilization 0.2 \
|
||||
--gpu-memory-utilization $GPU_MEMORY_UTILIZATION \
|
||||
--tensor-parallel-size $PREFILLER_TP_SIZE \
|
||||
--kv-transfer-config '$KV_CONFIG'"
|
||||
|
||||
@ -171,7 +177,7 @@ run_tests_for_model() {
|
||||
vllm serve $model_name \
|
||||
--port $PORT \
|
||||
--enforce-eager \
|
||||
--gpu-memory-utilization 0.2 \
|
||||
--gpu-memory-utilization $GPU_MEMORY_UTILIZATION \
|
||||
--tensor-parallel-size $DECODER_TP_SIZE \
|
||||
--kv-transfer-config '$KV_CONFIG'"
|
||||
|
||||
@ -200,7 +206,7 @@ run_tests_for_model() {
|
||||
done
|
||||
|
||||
# Build the command for the proxy server with all the hosts and ports
|
||||
PROXY_CMD="python ${GIT_ROOT}/tests/v1/kv_connector/nixl_integration/toy_proxy_server.py --port 8192"
|
||||
PROXY_CMD="python3 ${GIT_ROOT}/tests/v1/kv_connector/nixl_integration/toy_proxy_server.py --port 8192"
|
||||
|
||||
# Add all prefill hosts and ports
|
||||
PROXY_CMD+=" --prefiller-hosts ${PREFILL_HOSTS[@]}"
|
||||
@ -219,7 +225,7 @@ run_tests_for_model() {
|
||||
|
||||
# Run lm eval for this model
|
||||
echo "Running tests for $model_name"
|
||||
TEST_MODEL=$model_name python -m pytest -s -x ${GIT_ROOT}/tests/v1/kv_connector/nixl_integration/test_accuracy.py
|
||||
TEST_MODEL=$model_name python3 -m pytest -s -x ${GIT_ROOT}/tests/v1/kv_connector/nixl_integration/test_accuracy.py
|
||||
|
||||
# Clean up before running next model
|
||||
cleanup_instances
|
||||
|
||||
@ -12,7 +12,12 @@ FILTER = "exact_match,strict-match"
|
||||
RTOL = 0.03
|
||||
|
||||
# Model-specific expected values
|
||||
EXPECTED_VALUES = {"Qwen/Qwen3-0.6B": 0.41, "deepseek-ai/deepseek-vl2-small": 0.59}
|
||||
EXPECTED_VALUES = {
|
||||
"Qwen/Qwen3-0.6B": 0.41,
|
||||
"deepseek-ai/deepseek-vl2-small": 0.59,
|
||||
"deepseek-ai/deepseek-vl2-tiny": 0.19,
|
||||
"deepseek-ai/DeepSeek-V2-Lite-Chat": 0.65,
|
||||
}
|
||||
|
||||
SIMPLE_PROMPT = (
|
||||
"The best part about working on vLLM is that I got to meet so many people across "
|
||||
|
||||
@ -76,7 +76,8 @@ def parse_args():
|
||||
parser = argparse.ArgumentParser()
|
||||
|
||||
parser.add_argument("--port", type=int, default=8000)
|
||||
parser.add_argument("--host", type=str, default="localhost")
|
||||
# Always use 127.0.0.1 as localhost binds to IPv6 which is blocked on CI
|
||||
parser.add_argument("--host", type=str, default="127.0.0.1")
|
||||
|
||||
# For prefiller instances
|
||||
parser.add_argument(
|
||||
|
||||
40
tests/v1/kv_connector/nixl_integration/tp_config_sweep_accuracy_test.sh
Executable file
40
tests/v1/kv_connector/nixl_integration/tp_config_sweep_accuracy_test.sh
Executable file
@ -0,0 +1,40 @@
|
||||
#!/usr/bin/env bash
|
||||
set -euo pipefail
|
||||
|
||||
# Utility to run integration tests sequentially with varying TP configurations.
|
||||
SCRIPT="v1/kv_connector/nixl_integration/run_accuracy_test.sh"
|
||||
|
||||
# Define test configurations
|
||||
configs=(
|
||||
"GPU_MEMORY_UTILIZATION=0.6 PREFILLER_TP_SIZE=2 DECODER_TP_SIZE=2"
|
||||
"GPU_MEMORY_UTILIZATION=0.6 PREFILLER_TP_SIZE=1 DECODER_TP_SIZE=2"
|
||||
"GPU_MEMORY_UTILIZATION=0.8 MODEL_NAMES=deepseek-ai/deepseek-vl2-tiny" # MLA case
|
||||
"GPU_MEMORY_UTILIZATION=0.8 PREFILLER_TP_SIZE=1 DECODER_TP_SIZE=2 MODEL_NAMES=deepseek-ai/deepseek-vl2-tiny"
|
||||
)
|
||||
|
||||
run_tests() {
|
||||
local label=$1
|
||||
local extra_env=$2
|
||||
|
||||
echo "=== Running tests (${label}) ==="
|
||||
for cfg in "${configs[@]}"; do
|
||||
echo "-> Running with ${cfg} ${extra_env:+and ${extra_env}}"
|
||||
# Use 'env' to safely set variables without eval
|
||||
if ! env ${extra_env} ${cfg} bash "${SCRIPT}"; then
|
||||
echo "❌ Test failed for config: ${cfg} ${extra_env:+(${extra_env})}"
|
||||
exit 1
|
||||
fi
|
||||
done
|
||||
echo "✅ All ${label} tests passed!"
|
||||
}
|
||||
|
||||
# Run tests
|
||||
run_tests "default backend" ""
|
||||
|
||||
# Check if FLASHINFER is set (non-empty)
|
||||
if [[ -n "${FLASHINFER:-}" ]]; then
|
||||
echo "FLASHINFER is set, rerunning with VLLM_ATTENTION_BACKEND=FLASHINFER"
|
||||
run_tests "FLASHINFER backend" "VLLM_ATTENTION_BACKEND=FLASHINFER"
|
||||
else
|
||||
echo "FLASHINFER not set, skipping FLASHINFER runs."
|
||||
fi
|
||||
@ -339,18 +339,6 @@ def fused_add_rms_norm(
|
||||
torch.ops._C.fused_add_rms_norm(input, residual, weight, epsilon)
|
||||
|
||||
|
||||
def poly_norm(
|
||||
out: torch.Tensor,
|
||||
input: torch.Tensor,
|
||||
weight: torch.Tensor,
|
||||
bias: torch.Tensor,
|
||||
epsilon: float,
|
||||
) -> None:
|
||||
# TODO: Remove this contiguous call when the kernel is updated to support non-contiguous input
|
||||
input_contiguous = input.contiguous()
|
||||
torch.ops._C.poly_norm(out, input_contiguous, weight, bias, epsilon)
|
||||
|
||||
|
||||
def apply_repetition_penalties_torch(
|
||||
logits: torch.Tensor,
|
||||
prompt_mask: torch.Tensor,
|
||||
@ -1507,7 +1495,7 @@ def scaled_fp8_quant(
|
||||
output, input, scale, scale_ub
|
||||
)
|
||||
else:
|
||||
scale = torch.empty(1, device=input.device, dtype=torch.float32)
|
||||
scale = torch.empty((1, 1), device=input.device, dtype=torch.float32)
|
||||
torch.ops._C.dynamic_scaled_fp8_quant(output, input, scale)
|
||||
else:
|
||||
assert scale.numel() == 1, f"{scale.shape}"
|
||||
|
||||
@ -17,10 +17,14 @@ from vllm.distributed.parallel_state import (
|
||||
get_tensor_model_parallel_world_size,
|
||||
)
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.quantization.utils.quant_utils import (
|
||||
kFp8StaticTensorSym,
|
||||
)
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.utils import direct_register_custom_op
|
||||
|
||||
from .inductor_pass import enable_fake_mode
|
||||
from .matcher_utils import MatcherFusedAddRMSNorm, MatcherQuantFP8, MatcherRMSNorm
|
||||
from .vllm_inductor_pass import VllmInductorPass, VllmPatternMatcherPass
|
||||
|
||||
FP8_DTYPE = current_platform.fp8_dtype()
|
||||
@ -41,11 +45,8 @@ else:
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
ALLREDUCE_OP = torch.ops.vllm.all_reduce.default
|
||||
RMS_OP = torch.ops._C.rms_norm.default
|
||||
RMS_ADD_OP = torch.ops._C.fused_add_rms_norm.default
|
||||
STATIC_FP8_QUANT_OP = torch.ops._C.static_scaled_fp8_quant.default
|
||||
STATIC_FP4_QUANT_OP = torch.ops._C.scaled_fp4_quant.default
|
||||
if hasattr(torch.ops._C, "scaled_fp4_quant"):
|
||||
STATIC_FP4_QUANT_OP = torch.ops._C.scaled_fp4_quant.default
|
||||
|
||||
|
||||
class BasePattern:
|
||||
@ -669,33 +670,24 @@ class AllReduceRMSNormPattern(BasePattern):
|
||||
super().__init__(dtype, device)
|
||||
self.epsilon = epsilon
|
||||
self.allreduce_params = allreduce_params
|
||||
self.rmsnorm_matcher = MatcherRMSNorm(epsilon)
|
||||
|
||||
def get_inputs(self):
|
||||
input = torch.empty([1, 8, 4], device=self.device, dtype=self.dtype)
|
||||
rms_result = torch.empty([1, 8, 4], device=self.device, dtype=self.dtype)
|
||||
weight = torch.empty([4], device=self.device, dtype=self.dtype)
|
||||
input, weight = self.rmsnorm_matcher.inputs()
|
||||
|
||||
return [input, rms_result, weight]
|
||||
# input goes through allreduce first, always 16-bit
|
||||
return [input.to(self.dtype), weight]
|
||||
|
||||
def register(self, pm_pass: PatternMatcherPass):
|
||||
def pattern(
|
||||
input: torch.Tensor, rms_result: torch.Tensor, weight: torch.Tensor
|
||||
):
|
||||
def pattern(input: torch.Tensor, weight: torch.Tensor):
|
||||
allreduce_output = tensor_model_parallel_all_reduce(input)
|
||||
rms = auto_functionalized(
|
||||
RMS_OP,
|
||||
result=rms_result,
|
||||
input=allreduce_output,
|
||||
weight=weight,
|
||||
epsilon=self.epsilon,
|
||||
)
|
||||
# rms_result, allreduce_output
|
||||
return rms[1], allreduce_output
|
||||
rms = self.rmsnorm_matcher(allreduce_output, weight)
|
||||
|
||||
def replacement(
|
||||
input: torch.Tensor, rms_result: torch.Tensor, weight: torch.Tensor
|
||||
):
|
||||
return rms, allreduce_output
|
||||
|
||||
def replacement(input: torch.Tensor, weight: torch.Tensor):
|
||||
residual = torch.zeros_like(input)
|
||||
rms_result = torch.empty_like(input)
|
||||
allreduce = auto_functionalized(
|
||||
flashinfer_trtllm_fused_allreduce_norm,
|
||||
allreduce_in=input,
|
||||
@ -733,29 +725,19 @@ class AllReduceFusedAddRMSNormPattern(BasePattern):
|
||||
super().__init__(dtype, device)
|
||||
self.epsilon = epsilon
|
||||
self.allreduce_params = allreduce_params
|
||||
self.rmsnorm_matcher = MatcherFusedAddRMSNorm(epsilon)
|
||||
|
||||
def get_inputs(self):
|
||||
input = torch.empty([4, 4], device=self.device, dtype=self.dtype)
|
||||
residual = torch.empty([4, 4], device=self.device, dtype=self.dtype)
|
||||
weight = torch.empty([4, 4], device=self.device, dtype=self.dtype)
|
||||
return [
|
||||
residual,
|
||||
input,
|
||||
weight,
|
||||
]
|
||||
input, residual, weight = self.rmsnorm_matcher.inputs()
|
||||
|
||||
# input goes through allreduce first, always 16-bit
|
||||
return [residual, input.to(self.dtype), weight]
|
||||
|
||||
def register(self, pm_pass: PatternMatcherPass):
|
||||
def pattern(residual: torch.Tensor, input: torch.Tensor, weight: torch.Tensor):
|
||||
allreduce_output = tensor_model_parallel_all_reduce(input)
|
||||
rms = auto_functionalized(
|
||||
RMS_ADD_OP,
|
||||
input=allreduce_output,
|
||||
residual=residual,
|
||||
weight=weight,
|
||||
epsilon=self.epsilon,
|
||||
)
|
||||
# input, residual
|
||||
return rms[1], rms[2]
|
||||
rms, residual = self.rmsnorm_matcher(allreduce_output, weight, residual)
|
||||
return rms, residual
|
||||
|
||||
def replacement(
|
||||
residual: torch.Tensor, input: torch.Tensor, weight: torch.Tensor
|
||||
@ -779,6 +761,18 @@ class AllReduceFusedAddRMSNormPattern(BasePattern):
|
||||
pattern, replacement, self.get_inputs(), pm.fwd_only, pm_pass
|
||||
)
|
||||
|
||||
# Same pattern, but only return the output and not residual
|
||||
# (helpful for end of graph where residual is not used again)
|
||||
first_return_only = lambda fn: lambda a, b, c: fn(a, b, c)[0]
|
||||
|
||||
pm.register_replacement(
|
||||
first_return_only(pattern),
|
||||
first_return_only(replacement),
|
||||
self.get_inputs(),
|
||||
pm.fwd_only,
|
||||
pm_pass,
|
||||
)
|
||||
|
||||
|
||||
class AllReduceFusedRMSNormStaticQuantFP8Pattern(BasePattern):
|
||||
"""
|
||||
@ -799,60 +793,37 @@ class AllReduceFusedRMSNormStaticQuantFP8Pattern(BasePattern):
|
||||
self.epsilon = epsilon
|
||||
self.allreduce_params = allreduce_params
|
||||
self.quant_dtype = torch.float8_e4m3fn
|
||||
self.rmsnorm_matcher = MatcherRMSNorm(epsilon)
|
||||
self.quant_matcher = MatcherQuantFP8(kFp8StaticTensorSym)
|
||||
|
||||
def register(self, pm_pass: PatternMatcherPass):
|
||||
def get_inputs():
|
||||
input = torch.zeros([1, 8, 4], device=self.device, dtype=self.dtype)
|
||||
rmsnorm_result = torch.empty(
|
||||
[1, 8, 4], device=self.device, dtype=self.dtype
|
||||
)
|
||||
quant_result = torch.empty(
|
||||
[1, 8, 4], device=self.device, dtype=self.quant_dtype
|
||||
)
|
||||
weight = torch.empty([4], device=self.device, dtype=self.dtype)
|
||||
scale = torch.tensor(1.0, device=self.device, dtype=torch.float32)
|
||||
return [input, rmsnorm_result, quant_result, weight, scale]
|
||||
input, weight = self.rmsnorm_matcher.inputs()
|
||||
_, scale = self.quant_matcher.inputs()
|
||||
|
||||
# input goes through allreduce first, always 16-bit
|
||||
return [input.to(self.dtype), weight, scale]
|
||||
|
||||
def pattern(
|
||||
input: torch.Tensor,
|
||||
rmsnorm_result: torch.Tensor,
|
||||
quant_result: torch.Tensor,
|
||||
weight: torch.Tensor,
|
||||
scale: torch.Tensor,
|
||||
):
|
||||
all_reduce = tensor_model_parallel_all_reduce(input)
|
||||
rmsnorm_out_tuple = auto_functionalized(
|
||||
RMS_OP,
|
||||
result=rmsnorm_result,
|
||||
input=all_reduce,
|
||||
weight=weight,
|
||||
epsilon=self.epsilon,
|
||||
)
|
||||
rms = self.rmsnorm_matcher(all_reduce, weight)
|
||||
quant, _ = self.quant_matcher(rms, scale)
|
||||
return quant, all_reduce
|
||||
|
||||
quant_out_tuple = auto_functionalized(
|
||||
STATIC_FP8_QUANT_OP,
|
||||
result=quant_result,
|
||||
input=rmsnorm_out_tuple[1],
|
||||
scale=scale,
|
||||
)
|
||||
|
||||
# quant_out, allreduce_output
|
||||
return quant_out_tuple[1], all_reduce
|
||||
|
||||
def replacement(
|
||||
input: torch.Tensor,
|
||||
result_rms: torch.Tensor,
|
||||
quant_result: torch.Tensor,
|
||||
weight: torch.Tensor,
|
||||
scale: torch.Tensor,
|
||||
):
|
||||
def replacement(input: torch.Tensor, weight: torch.Tensor, scale: torch.Tensor):
|
||||
residual = torch.zeros_like(input)
|
||||
result_rms = torch.empty_like(input)
|
||||
result_quant = torch.empty_like(input, dtype=self.quant_dtype)
|
||||
allreduce = auto_functionalized(
|
||||
flashinfer_trtllm_fused_allreduce_norm,
|
||||
allreduce_in=input,
|
||||
residual=residual,
|
||||
norm_out=result_rms,
|
||||
quant_out=quant_result,
|
||||
quant_out=result_quant,
|
||||
scale_out=None,
|
||||
rms_gamma=weight,
|
||||
rms_eps=self.epsilon,
|
||||
@ -892,64 +863,42 @@ class AllReduceFusedAddRMSNormStaticQuantFP8Pattern(BasePattern):
|
||||
self.allreduce_params = allreduce_params
|
||||
self.quant_dtype = torch.float8_e4m3fn
|
||||
|
||||
self.rmsnorm_matcher = MatcherFusedAddRMSNorm(epsilon)
|
||||
self.quant_matcher = MatcherQuantFP8(kFp8StaticTensorSym)
|
||||
|
||||
def register(self, pm_pass: PatternMatcherPass):
|
||||
def get_inputs():
|
||||
input = torch.empty([4, 4], device=self.device, dtype=self.dtype)
|
||||
input, residual, weight = self.rmsnorm_matcher.inputs()
|
||||
_, scale = self.quant_matcher.inputs()
|
||||
|
||||
residual = torch.empty([4, 4], device=self.device, dtype=self.dtype)
|
||||
weight = torch.empty([4, 4], device=self.device, dtype=self.dtype)
|
||||
quant_result = torch.empty(
|
||||
[4, 4], device=self.device, dtype=self.quant_dtype
|
||||
)
|
||||
scale = torch.empty([1, 1], device=self.device, dtype=torch.float32)
|
||||
|
||||
return [
|
||||
quant_result,
|
||||
residual,
|
||||
input,
|
||||
weight,
|
||||
scale,
|
||||
]
|
||||
# input goes through allreduce first, always 16-bit
|
||||
return [residual, input.to(self.dtype), weight, scale]
|
||||
|
||||
def pattern(
|
||||
quant_result: torch.Tensor,
|
||||
residual: torch.Tensor,
|
||||
input: torch.Tensor,
|
||||
weight: torch.Tensor,
|
||||
scale: torch.Tensor,
|
||||
):
|
||||
allreduce_output = tensor_model_parallel_all_reduce(input)
|
||||
rms, res = self.rmsnorm_matcher(allreduce_output, weight, residual)
|
||||
quant, _ = self.quant_matcher(rms, scale)
|
||||
|
||||
fused_add_rmsnorm_out_tuple = auto_functionalized(
|
||||
RMS_ADD_OP,
|
||||
input=allreduce_output,
|
||||
residual=residual,
|
||||
weight=weight,
|
||||
epsilon=self.epsilon,
|
||||
)
|
||||
quant_out_tuple = auto_functionalized(
|
||||
STATIC_FP8_QUANT_OP,
|
||||
result=quant_result,
|
||||
input=fused_add_rmsnorm_out_tuple[1],
|
||||
scale=scale,
|
||||
)
|
||||
|
||||
# quant_out, allreduce_output
|
||||
return quant_out_tuple[1], fused_add_rmsnorm_out_tuple[2]
|
||||
return quant, res
|
||||
|
||||
def replacement(
|
||||
quant_result: torch.Tensor,
|
||||
residual: torch.Tensor,
|
||||
input: torch.Tensor,
|
||||
weight: torch.Tensor,
|
||||
scale: torch.Tensor,
|
||||
):
|
||||
result_quant = torch.empty_like(input, dtype=self.quant_dtype)
|
||||
allreduce = auto_functionalized(
|
||||
flashinfer_trtllm_fused_allreduce_norm,
|
||||
allreduce_in=input,
|
||||
residual=residual,
|
||||
norm_out=None,
|
||||
quant_out=quant_result,
|
||||
quant_out=result_quant,
|
||||
scale_out=None,
|
||||
rms_gamma=weight,
|
||||
rms_eps=self.epsilon,
|
||||
@ -986,14 +935,11 @@ class AllReduceFusedRMSNormStaticQuantNVFP4Pattern(BasePattern):
|
||||
super().__init__(dtype, device)
|
||||
self.epsilon = epsilon
|
||||
self.allreduce_params = allreduce_params
|
||||
self.rmsnorm_matcher = MatcherRMSNorm(epsilon)
|
||||
|
||||
def register(self, pm_pass: PatternMatcherPass):
|
||||
def get_inputs():
|
||||
input = torch.empty([1, 16, 16], device=self.device, dtype=self.dtype)
|
||||
|
||||
rmsnorm_result = torch.empty(
|
||||
[1, 16, 16], device=self.device, dtype=self.dtype
|
||||
)
|
||||
quant_result = torch.empty((16, 8), device=self.device, dtype=torch.uint8)
|
||||
input_global_scale = torch.empty(
|
||||
[1, 1], device=self.device, dtype=torch.float32
|
||||
@ -1001,36 +947,21 @@ class AllReduceFusedRMSNormStaticQuantNVFP4Pattern(BasePattern):
|
||||
weight = torch.empty([16], device=self.device, dtype=self.dtype)
|
||||
output_scale = torch.empty([128, 4], device=self.device, dtype=torch.int32)
|
||||
|
||||
return [
|
||||
input,
|
||||
rmsnorm_result,
|
||||
quant_result,
|
||||
weight,
|
||||
input_global_scale,
|
||||
output_scale,
|
||||
]
|
||||
return [input, quant_result, weight, input_global_scale, output_scale]
|
||||
|
||||
def pattern(
|
||||
input: torch.Tensor,
|
||||
rmsnorm_result: torch.Tensor,
|
||||
quant_result: torch.Tensor,
|
||||
weight: torch.Tensor,
|
||||
input_global_scale: torch.Tensor,
|
||||
output_scale: torch.Tensor,
|
||||
):
|
||||
all_reduce = tensor_model_parallel_all_reduce(input)
|
||||
rmsnorm_out_tuple = auto_functionalized(
|
||||
RMS_OP,
|
||||
result=rmsnorm_result,
|
||||
input=all_reduce,
|
||||
weight=weight,
|
||||
epsilon=self.epsilon,
|
||||
)
|
||||
|
||||
rms = self.rmsnorm_matcher(all_reduce, weight)
|
||||
quant_out_tuple = auto_functionalized(
|
||||
STATIC_FP4_QUANT_OP,
|
||||
output=quant_result,
|
||||
input=rmsnorm_out_tuple[1],
|
||||
input=rms,
|
||||
output_scale=output_scale,
|
||||
input_scale=input_global_scale,
|
||||
)
|
||||
@ -1040,13 +971,13 @@ class AllReduceFusedRMSNormStaticQuantNVFP4Pattern(BasePattern):
|
||||
|
||||
def replacement(
|
||||
input: torch.Tensor,
|
||||
result_rms: torch.Tensor,
|
||||
quant_result: torch.Tensor,
|
||||
weight: torch.Tensor,
|
||||
input_global_scale: torch.Tensor,
|
||||
output_scale: torch.Tensor,
|
||||
):
|
||||
residual = torch.zeros_like(input)
|
||||
result_rms = torch.empty_like(input)
|
||||
allreduce = auto_functionalized(
|
||||
flashinfer_trtllm_fused_allreduce_norm,
|
||||
allreduce_in=input,
|
||||
@ -1090,6 +1021,7 @@ class AllReduceFusedAddRMSNormStaticQuantNVFP4Pattern(BasePattern):
|
||||
super().__init__(dtype, device)
|
||||
self.epsilon = epsilon
|
||||
self.allreduce_params = allreduce_params
|
||||
self.rmsnorm_matcher = MatcherFusedAddRMSNorm(epsilon)
|
||||
|
||||
def register(self, pm_pass: PatternMatcherPass):
|
||||
def get_inputs():
|
||||
@ -1121,28 +1053,17 @@ class AllReduceFusedAddRMSNormStaticQuantNVFP4Pattern(BasePattern):
|
||||
input_global_scale: torch.Tensor,
|
||||
):
|
||||
allreduce_output = tensor_model_parallel_all_reduce(input)
|
||||
|
||||
fused_add_rmsnorm_out_tuple = auto_functionalized(
|
||||
RMS_ADD_OP,
|
||||
input=allreduce_output,
|
||||
residual=residual,
|
||||
weight=weight,
|
||||
epsilon=self.epsilon,
|
||||
)
|
||||
rms, residual = self.rmsnorm_matcher(allreduce_output, weight, residual)
|
||||
quant_out_tuple = auto_functionalized(
|
||||
STATIC_FP4_QUANT_OP,
|
||||
output=quant_result,
|
||||
input=fused_add_rmsnorm_out_tuple[1],
|
||||
input=rms,
|
||||
output_scale=output_scale,
|
||||
input_scale=input_global_scale,
|
||||
)
|
||||
|
||||
# quant_out, allreduce_output, output_scale
|
||||
return (
|
||||
quant_out_tuple[1],
|
||||
fused_add_rmsnorm_out_tuple[2],
|
||||
quant_out_tuple[2],
|
||||
)
|
||||
return quant_out_tuple[1], residual, quant_out_tuple[2]
|
||||
|
||||
def replacement(
|
||||
quant_result: torch.Tensor,
|
||||
|
||||
@ -9,7 +9,7 @@ from torch._higher_order_ops.auto_functionalize import auto_functionalized
|
||||
from torch._inductor.pattern_matcher import PatternMatcherPass
|
||||
from torch._ops import OpOverload
|
||||
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.config import VllmConfig, get_current_vllm_config
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.quantization.utils.quant_utils import (
|
||||
GroupShape,
|
||||
@ -24,6 +24,7 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import (
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
from .inductor_pass import enable_fake_mode
|
||||
from .matcher_utils import MatcherFusedAddRMSNorm, MatcherQuantFP8, MatcherRMSNorm
|
||||
from .vllm_inductor_pass import VllmInductorPass, VllmPatternMatcherPass
|
||||
|
||||
logger = init_logger(__name__)
|
||||
@ -92,13 +93,19 @@ class RMSNormQuantPattern:
|
||||
def __init__(self, epsilon: float, key: FusedRMSQuantKey):
|
||||
self.epsilon = epsilon
|
||||
self.quant_dtype = key.quant.dtype
|
||||
|
||||
assert key.quant in QUANT_OPS, f"unsupported quantization scheme {key.quant}"
|
||||
self.QUANT_OP = QUANT_OPS[key.quant]
|
||||
config = get_current_vllm_config()
|
||||
self.model_dtype = config.model_config.dtype if config.model_config else None
|
||||
|
||||
assert key in FUSED_OPS, f"unsupported fused rmsnorm+quant op for {key}"
|
||||
self.FUSED_OP = FUSED_OPS[key]
|
||||
|
||||
self.rmsnorm_matcher = (
|
||||
MatcherRMSNorm(epsilon)
|
||||
if not key.fused_add
|
||||
else MatcherFusedAddRMSNorm(epsilon)
|
||||
)
|
||||
self.quant_matcher = MatcherQuantFP8(key.quant)
|
||||
|
||||
|
||||
class RMSNormStaticQuantPattern(RMSNormQuantPattern):
|
||||
def __init__(self, epsilon: float, quant_dtype: torch.dtype, symmetric=True):
|
||||
@ -112,34 +119,18 @@ class RMSNormStaticQuantPattern(RMSNormQuantPattern):
|
||||
|
||||
def register(self, pm_pass: PatternMatcherPass):
|
||||
# Cannot use methods, as the self argument affects tracing
|
||||
def pattern(
|
||||
result: torch.Tensor,
|
||||
result_rms: torch.Tensor,
|
||||
input: torch.Tensor,
|
||||
weight: torch.Tensor,
|
||||
scale: torch.Tensor,
|
||||
):
|
||||
at1 = auto_functionalized(
|
||||
RMS_OP,
|
||||
result=result_rms,
|
||||
input=input,
|
||||
weight=weight,
|
||||
epsilon=self.epsilon,
|
||||
)
|
||||
at2 = auto_functionalized(
|
||||
self.QUANT_OP, result=result, input=at1[1], scale=scale
|
||||
)
|
||||
def pattern(input: torch.Tensor, weight: torch.Tensor, scale: torch.Tensor):
|
||||
result_rms = self.rmsnorm_matcher(input, weight)
|
||||
return self.quant_matcher(result_rms, scale)[0]
|
||||
|
||||
# result
|
||||
return at2[1]
|
||||
def replacement(input: torch.Tensor, weight: torch.Tensor, scale: torch.Tensor):
|
||||
# In case we're matching native rms-norm, conversions might be
|
||||
# optimized out. We convert here just to be safe.
|
||||
input = input.to(dtype=self.model_dtype)
|
||||
|
||||
def replacement(
|
||||
result: torch.Tensor,
|
||||
result_rms: torch.Tensor,
|
||||
input: torch.Tensor,
|
||||
weight: torch.Tensor,
|
||||
scale: torch.Tensor,
|
||||
):
|
||||
result = torch.empty(
|
||||
input.shape, device=input.device, dtype=self.quant_dtype
|
||||
)
|
||||
at = auto_functionalized(
|
||||
self.FUSED_OP,
|
||||
result=result,
|
||||
@ -153,12 +144,11 @@ class RMSNormStaticQuantPattern(RMSNormQuantPattern):
|
||||
return at[1]
|
||||
|
||||
inputs = [
|
||||
torch.empty(5, 4, device="cuda", dtype=self.quant_dtype), # result
|
||||
empty_bf16(5, 4), # result_rms
|
||||
empty_bf16(5, 4), # input
|
||||
empty_bf16(1, 5), # weight
|
||||
empty_fp32(1, 1), # scale
|
||||
# input, weight
|
||||
*self.rmsnorm_matcher.inputs(),
|
||||
self.quant_matcher.inputs()[1], # scale
|
||||
]
|
||||
pattern(*inputs)
|
||||
|
||||
pm.register_replacement(pattern, replacement, inputs, pm.fwd_only, pm_pass)
|
||||
|
||||
@ -175,33 +165,27 @@ class FusedAddRMSNormStaticQuantPattern(RMSNormQuantPattern):
|
||||
|
||||
def register(self, pm_pass: PatternMatcherPass):
|
||||
def pattern(
|
||||
result: torch.Tensor,
|
||||
input: torch.Tensor,
|
||||
residual: torch.Tensor,
|
||||
weight: torch.Tensor,
|
||||
residual: torch.Tensor,
|
||||
scale: torch.Tensor,
|
||||
):
|
||||
at = auto_functionalized(
|
||||
RMS_ADD_OP,
|
||||
input=input,
|
||||
residual=residual,
|
||||
weight=weight,
|
||||
epsilon=self.epsilon,
|
||||
)
|
||||
at1 = auto_functionalized(
|
||||
self.QUANT_OP, result=result, input=at[1], scale=scale
|
||||
)
|
||||
result_rms, residual = self.rmsnorm_matcher(input, weight, residual)
|
||||
result, _ = self.quant_matcher(result_rms, scale)
|
||||
|
||||
# result, residual
|
||||
return at1[1], at[2]
|
||||
return result, residual
|
||||
|
||||
def replacement(
|
||||
result: torch.Tensor,
|
||||
input: torch.Tensor,
|
||||
residual: torch.Tensor,
|
||||
weight: torch.Tensor,
|
||||
residual: torch.Tensor,
|
||||
scale: torch.Tensor,
|
||||
):
|
||||
# In case we're matching native rms-norm, conversions might be
|
||||
# optimized out. We convert here just to be safe.
|
||||
input = input.to(dtype=self.model_dtype)
|
||||
|
||||
result = torch.empty_like(input, dtype=self.quant_dtype)
|
||||
at = auto_functionalized(
|
||||
self.FUSED_OP,
|
||||
result=result,
|
||||
@ -216,11 +200,9 @@ class FusedAddRMSNormStaticQuantPattern(RMSNormQuantPattern):
|
||||
return at[1], at[2]
|
||||
|
||||
inputs = [
|
||||
torch.empty(5, 4, device="cuda", dtype=self.quant_dtype), # result
|
||||
empty_bf16(5, 4), # input
|
||||
empty_bf16(5, 4), # residual
|
||||
empty_bf16(1, 5), # weight
|
||||
empty_fp32(1, 1), # scale
|
||||
# input, weight, residual
|
||||
*self.rmsnorm_matcher.inputs(),
|
||||
self.quant_matcher.inputs()[1], # scale
|
||||
]
|
||||
|
||||
pm.register_replacement(
|
||||
@ -248,34 +230,18 @@ class RMSNormDynamicQuantPattern(RMSNormQuantPattern):
|
||||
super().__init__(epsilon, key)
|
||||
|
||||
def register(self, pm_pass: PatternMatcherPass):
|
||||
def pattern(
|
||||
result: torch.Tensor,
|
||||
result_rms: torch.Tensor,
|
||||
input: torch.Tensor,
|
||||
weight: torch.Tensor,
|
||||
scale: torch.Tensor,
|
||||
):
|
||||
at1 = auto_functionalized(
|
||||
RMS_OP,
|
||||
result=result_rms,
|
||||
input=input,
|
||||
weight=weight,
|
||||
epsilon=self.epsilon,
|
||||
)
|
||||
at2 = auto_functionalized(
|
||||
self.QUANT_OP, result=result, input=at1[1], scale=scale, scale_ub=None
|
||||
)
|
||||
|
||||
def pattern(input: torch.Tensor, weight: torch.Tensor):
|
||||
result_rms = self.rmsnorm_matcher(input, weight)
|
||||
# result, scale
|
||||
return at2[1], at2[2]
|
||||
return self.quant_matcher(result_rms)
|
||||
|
||||
def replacement(
|
||||
result: torch.Tensor,
|
||||
result_rms: torch.Tensor,
|
||||
input: torch.Tensor,
|
||||
weight: torch.Tensor,
|
||||
scale: torch.Tensor,
|
||||
):
|
||||
def replacement(input: torch.Tensor, weight: torch.Tensor):
|
||||
# In case we're matching native rms-norm, conversions might be
|
||||
# optimized out. We convert here just to be safe.
|
||||
input = input.to(dtype=self.model_dtype)
|
||||
|
||||
result = torch.empty_like(input, dtype=self.quant_dtype)
|
||||
scale = self.quant_matcher.make_scale(input)
|
||||
at = auto_functionalized(
|
||||
self.FUSED_OP,
|
||||
result=result,
|
||||
@ -290,18 +256,10 @@ class RMSNormDynamicQuantPattern(RMSNormQuantPattern):
|
||||
# result, scale
|
||||
return at[1], at[2]
|
||||
|
||||
inputs = [
|
||||
torch.empty(5, 4, device="cuda", dtype=self.quant_dtype), # result
|
||||
empty_bf16(5, 4), # result_rms
|
||||
empty_bf16(5, 4), # input
|
||||
empty_bf16(1, 5), # weight
|
||||
empty_fp32(1, 1), # scale
|
||||
]
|
||||
|
||||
pm.register_replacement(
|
||||
pattern,
|
||||
replacement,
|
||||
inputs,
|
||||
self.rmsnorm_matcher.inputs(),
|
||||
pm.fwd_only,
|
||||
pm_pass,
|
||||
)
|
||||
@ -323,34 +281,21 @@ class FusedAddRMSNormDynamicQuantPattern(RMSNormQuantPattern):
|
||||
super().__init__(epsilon, key)
|
||||
|
||||
def register(self, pm_pass: PatternMatcherPass):
|
||||
def pattern(
|
||||
result: torch.Tensor,
|
||||
input: torch.Tensor,
|
||||
residual: torch.Tensor,
|
||||
weight: torch.Tensor,
|
||||
scale: torch.Tensor,
|
||||
):
|
||||
at = auto_functionalized(
|
||||
RMS_ADD_OP,
|
||||
input=input,
|
||||
residual=residual,
|
||||
weight=weight,
|
||||
epsilon=self.epsilon,
|
||||
)
|
||||
at1 = auto_functionalized(
|
||||
self.QUANT_OP, result=result, input=at[1], scale=scale, scale_ub=None
|
||||
)
|
||||
def pattern(input: torch.Tensor, weight: torch.Tensor, residual: torch.Tensor):
|
||||
result_rms, residual = self.rmsnorm_matcher(input, weight, residual)
|
||||
result, scale = self.quant_matcher(result_rms)
|
||||
|
||||
# result, residual, scale
|
||||
return at1[1], at[2], at1[2]
|
||||
return result, residual, scale
|
||||
|
||||
def replacement(
|
||||
result: torch.Tensor,
|
||||
input: torch.Tensor,
|
||||
residual: torch.Tensor,
|
||||
weight: torch.Tensor,
|
||||
scale: torch.Tensor,
|
||||
input: torch.Tensor, weight: torch.Tensor, residual: torch.Tensor
|
||||
):
|
||||
# In case we're matching native rms-norm, conversions might be
|
||||
# optimized out. We convert here just to be safe.
|
||||
input = input.to(dtype=self.model_dtype)
|
||||
|
||||
result = torch.empty_like(input, dtype=self.quant_dtype)
|
||||
scale = self.quant_matcher.make_scale(input)
|
||||
at = auto_functionalized(
|
||||
self.FUSED_OP,
|
||||
result=result,
|
||||
@ -365,18 +310,10 @@ class FusedAddRMSNormDynamicQuantPattern(RMSNormQuantPattern):
|
||||
# result, residual, scale
|
||||
return at[1], at[3], at[2]
|
||||
|
||||
inputs = [
|
||||
torch.empty(5, 4, device="cuda", dtype=self.quant_dtype), # result
|
||||
empty_bf16(5, 4), # input
|
||||
empty_bf16(5, 4), # residual
|
||||
empty_bf16(1, 5), # weight
|
||||
empty_fp32(1, 1), # scale
|
||||
]
|
||||
|
||||
pm.register_replacement(
|
||||
pattern,
|
||||
replacement,
|
||||
inputs,
|
||||
self.rmsnorm_matcher.inputs(),
|
||||
pm.fwd_only,
|
||||
pm_pass,
|
||||
)
|
||||
@ -396,23 +333,25 @@ class RMSNormQuantFusionPass(VllmPatternMatcherPass):
|
||||
pass_name="rmsnorm_quant_fusion_pass"
|
||||
)
|
||||
|
||||
# Make sure fused add patterns are before simple rms norm,
|
||||
# as the latter is a subset of the former in torch ops
|
||||
for epsilon in [1e-5, 1e-6]:
|
||||
# Fuse rms_norm + static fp8 quant
|
||||
RMSNormStaticQuantPattern(epsilon, FP8_DTYPE).register(self.patterns)
|
||||
|
||||
# Fuse fused_add_rms_norm + static fp8 quant
|
||||
FusedAddRMSNormStaticQuantPattern(epsilon, FP8_DTYPE).register(
|
||||
self.patterns
|
||||
)
|
||||
|
||||
# Fuse rms_norm + dynamic per-token fp8 quant
|
||||
RMSNormDynamicQuantPattern(epsilon, FP8_DTYPE).register(self.patterns)
|
||||
# Fuse rms_norm + static fp8 quant
|
||||
RMSNormStaticQuantPattern(epsilon, FP8_DTYPE).register(self.patterns)
|
||||
|
||||
# Fuse fused_add_rms_norm + dynamic per-token fp8 quant
|
||||
FusedAddRMSNormDynamicQuantPattern(epsilon, FP8_DTYPE).register(
|
||||
self.patterns
|
||||
)
|
||||
|
||||
# Fuse rms_norm + dynamic per-token fp8 quant
|
||||
RMSNormDynamicQuantPattern(epsilon, FP8_DTYPE).register(self.patterns)
|
||||
|
||||
self.dump_patterns(config, self.patterns)
|
||||
|
||||
@VllmInductorPass.time_and_log
|
||||
|
||||
@ -2,9 +2,11 @@
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from collections.abc import Callable
|
||||
|
||||
import torch
|
||||
import torch._inductor.pattern_matcher as pm
|
||||
from torch import fx
|
||||
from torch._higher_order_ops.auto_functionalize import auto_functionalized
|
||||
from torch._inductor.pattern_matcher import PatternMatcherPass
|
||||
|
||||
@ -20,7 +22,9 @@ from vllm.platforms import current_platform
|
||||
from vllm.utils import round_up
|
||||
|
||||
from .fusion import QUANT_OPS, empty_bf16, empty_fp32, empty_i32
|
||||
from .fx_utils import is_func
|
||||
from .inductor_pass import enable_fake_mode
|
||||
from .matcher_utils import MatcherQuantFP8
|
||||
from .vllm_inductor_pass import VllmInductorPass, VllmPatternMatcherPass
|
||||
|
||||
logger = init_logger(__name__)
|
||||
@ -66,9 +70,13 @@ class AttentionQuantPattern(ABC):
|
||||
return torch.empty(*args, **kwargs)
|
||||
|
||||
@staticmethod
|
||||
def wrap_trace_fn(process_fx, trace_fn):
|
||||
def wrap_trace_fn(trace_fn, *process_fx_fns: Callable[[fx.GraphModule], None]):
|
||||
def wrapped(*args, **kwargs):
|
||||
return process_fx(trace_fn(*args, **kwargs))
|
||||
gm = trace_fn(*args, **kwargs)
|
||||
for process_fx in process_fx_fns:
|
||||
process_fx(gm)
|
||||
|
||||
return gm
|
||||
|
||||
return wrapped
|
||||
|
||||
@ -77,7 +85,20 @@ class AttentionQuantPattern(ABC):
|
||||
from torch._inductor.fx_passes.post_grad import view_to_reshape
|
||||
|
||||
view_to_reshape(gm)
|
||||
return gm
|
||||
|
||||
@staticmethod
|
||||
def remove_noop_permutes(gm: torch.fx.GraphModule):
|
||||
for node in gm.graph.nodes:
|
||||
if not is_func(node, torch.ops.aten.permute.default):
|
||||
continue
|
||||
|
||||
dims = node.args[1]
|
||||
if any(dim != i for i, dim in enumerate(dims)):
|
||||
continue
|
||||
|
||||
# this is now an identity op, remove
|
||||
node.replace_all_uses_with(node.args[0])
|
||||
gm.graph.erase_node(node)
|
||||
|
||||
def register_if_supported(self, pm_pass: PatternMatcherPass):
|
||||
if self.layer.impl.fused_output_quant_supported(self.quant_key):
|
||||
@ -108,6 +129,7 @@ class AttentionFp8StaticQuantPattern(AttentionQuantPattern):
|
||||
dtype=FP8_DTYPE, scale=kStaticTensorScale, symmetric=symmetric
|
||||
)
|
||||
super().__init__(layer, quant_key, dtype)
|
||||
self.quant_matcher = MatcherQuantFP8(quant_key)
|
||||
|
||||
def _register(self, pm_pass: PatternMatcherPass):
|
||||
def pattern(
|
||||
@ -115,7 +137,6 @@ class AttentionFp8StaticQuantPattern(AttentionQuantPattern):
|
||||
k: torch.Tensor,
|
||||
v: torch.Tensor,
|
||||
output_attn: torch.Tensor,
|
||||
output_quant: torch.Tensor,
|
||||
scale: torch.Tensor,
|
||||
):
|
||||
at1 = auto_functionalized(
|
||||
@ -131,17 +152,14 @@ class AttentionFp8StaticQuantPattern(AttentionQuantPattern):
|
||||
attn_out_view = RESHAPE_OP(
|
||||
at1[1], [q.shape[0], self.num_heads * self.head_size]
|
||||
)
|
||||
at2 = auto_functionalized(
|
||||
self.QUANT_OP, result=output_quant, input=attn_out_view, scale=scale
|
||||
)
|
||||
return at2[1]
|
||||
|
||||
return self.quant_matcher(attn_out_view, scale)[0]
|
||||
|
||||
def replacement(
|
||||
q: torch.Tensor,
|
||||
k: torch.Tensor,
|
||||
v: torch.Tensor,
|
||||
output_attn: torch.Tensor,
|
||||
output_quant: torch.Tensor,
|
||||
scale: torch.Tensor,
|
||||
):
|
||||
# attn output in quant_dtype
|
||||
@ -164,13 +182,10 @@ class AttentionFp8StaticQuantPattern(AttentionQuantPattern):
|
||||
return RESHAPE_OP(at1[1], [-1, self.num_heads * self.head_size])
|
||||
|
||||
inputs = [
|
||||
self.empty(5, self.num_heads, self.head_size, dtype=self.dtype), # q
|
||||
self.empty(5, self.num_heads, self.head_size, dtype=self.dtype), # k
|
||||
self.empty(5, self.num_heads, self.head_size, dtype=self.dtype), # v
|
||||
self.empty(
|
||||
5, self.num_heads, self.head_size, dtype=self.dtype
|
||||
), # attn_output
|
||||
self.empty_quant(5, self.num_heads * self.head_size), # quant_output
|
||||
self.empty(5, self.num_heads, self.head_size), # q
|
||||
self.empty(5, self.num_heads, self.head_size), # k
|
||||
self.empty(5, self.num_heads, self.head_size), # v
|
||||
self.empty(5, self.num_heads, self.head_size), # attn_output
|
||||
empty_fp32(1, 1), # scale
|
||||
]
|
||||
|
||||
@ -179,7 +194,9 @@ class AttentionFp8StaticQuantPattern(AttentionQuantPattern):
|
||||
replacement,
|
||||
inputs,
|
||||
AttentionQuantPattern.wrap_trace_fn(
|
||||
AttentionQuantPattern.fx_view_to_reshape, pm.fwd_only
|
||||
pm.fwd_only,
|
||||
AttentionQuantPattern.fx_view_to_reshape,
|
||||
AttentionQuantPattern.remove_noop_permutes,
|
||||
),
|
||||
pm_pass,
|
||||
)
|
||||
@ -279,7 +296,9 @@ class AttentionNvfp4QuantPattern(AttentionQuantPattern):
|
||||
replacement,
|
||||
inputs,
|
||||
AttentionQuantPattern.wrap_trace_fn(
|
||||
AttentionQuantPattern.fx_view_to_reshape, pm.fwd_only
|
||||
pm.fwd_only,
|
||||
AttentionQuantPattern.fx_view_to_reshape,
|
||||
AttentionQuantPattern.remove_noop_permutes,
|
||||
),
|
||||
pm_pass,
|
||||
)
|
||||
|
||||
@ -6,7 +6,7 @@ from collections.abc import Iterable, Iterator
|
||||
|
||||
from torch import fx
|
||||
from torch._higher_order_ops.auto_functionalize import auto_functionalized
|
||||
from torch._ops import OpOverload
|
||||
from torch._ops import OpOverload, OpOverloadPacket
|
||||
|
||||
|
||||
def is_func(node: fx.Node, target) -> bool:
|
||||
@ -64,7 +64,17 @@ def find_getitem(node: fx.Node, idx: int) -> fx.Node:
|
||||
|
||||
|
||||
# An auto-functionalization-aware utility for finding nodes with a specific op
|
||||
def find_op_nodes(op: OpOverload, graph: fx.Graph) -> Iterator[fx.Node]:
|
||||
# Also handles op overload packets and finds all overloads
|
||||
def find_op_nodes(
|
||||
op: OpOverload | OpOverloadPacket, graph: fx.Graph
|
||||
) -> Iterator[fx.Node]:
|
||||
if isinstance(op, OpOverloadPacket):
|
||||
for overload in op.overloads():
|
||||
overload_op = getattr(op, overload)
|
||||
yield from find_op_nodes(overload_op, graph)
|
||||
return
|
||||
|
||||
assert isinstance(op, OpOverload)
|
||||
if not op._schema.is_mutable:
|
||||
yield from graph.find_nodes(op="call_function", target=op)
|
||||
|
||||
|
||||
208
vllm/compilation/matcher_utils.py
Normal file
208
vllm/compilation/matcher_utils.py
Normal file
@ -0,0 +1,208 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from abc import ABC, abstractmethod
|
||||
|
||||
import torch
|
||||
from torch._higher_order_ops import auto_functionalized
|
||||
from torch._ops import OpOverload
|
||||
|
||||
from vllm.config import get_current_vllm_config
|
||||
from vllm.model_executor.layers.layernorm import RMSNorm
|
||||
from vllm.model_executor.layers.quantization.input_quant_fp8 import QuantFP8
|
||||
from vllm.model_executor.layers.quantization.utils.quant_utils import (
|
||||
QuantKey,
|
||||
_normalize_quant_group_shape,
|
||||
kFp8DynamicTensorSym,
|
||||
kFp8DynamicTokenSym,
|
||||
kFp8StaticTensorSym,
|
||||
kNvfp4Quant,
|
||||
)
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
RMS_OP = torch.ops._C.rms_norm.default
|
||||
RMS_ADD_OP = torch.ops._C.fused_add_rms_norm.default
|
||||
|
||||
QUANT_OPS: dict[QuantKey, OpOverload] = {
|
||||
kFp8StaticTensorSym: torch.ops._C.static_scaled_fp8_quant.default, # noqa: E501
|
||||
kFp8DynamicTensorSym: torch.ops._C.dynamic_scaled_fp8_quant.default, # noqa: E501
|
||||
kFp8DynamicTokenSym: torch.ops._C.dynamic_per_token_scaled_fp8_quant.default, # noqa: E501
|
||||
}
|
||||
|
||||
if current_platform.is_cuda() and hasattr(torch.ops._C, "scaled_fp4_quant"):
|
||||
QUANT_OPS[kNvfp4Quant] = torch.ops._C.scaled_fp4_quant.default # noqa: E501
|
||||
|
||||
|
||||
class MatcherCustomOp(ABC):
|
||||
def __init__(self, enabled: bool):
|
||||
config = get_current_vllm_config()
|
||||
self.model_dtype = config.model_config.dtype if config.model_config else None
|
||||
self.device = config.device_config.device if config.device_config else None
|
||||
|
||||
self.enabled = enabled
|
||||
self.forward = self.forward_custom if enabled else self.forward_native
|
||||
|
||||
@abstractmethod
|
||||
def forward_custom(self, *args, **kws):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def forward_native(self, *args, **kws):
|
||||
pass
|
||||
|
||||
def __call__(self, *args, **kws):
|
||||
return self.forward(*args, **kws)
|
||||
|
||||
def empty(self, *args, **kws):
|
||||
return torch.empty(*args, dtype=self.model_dtype, device=self.device, **kws)
|
||||
|
||||
def empty_f32(self, *args, **kws):
|
||||
return torch.empty(*args, dtype=torch.float32, device=self.device, **kws)
|
||||
|
||||
def inputs(self) -> list[torch.Tensor]:
|
||||
"""Utility for inputs to the pattern"""
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class MatcherRMSNorm(MatcherCustomOp):
|
||||
def __init__(self, epsilon: float, enabled: bool | None = None):
|
||||
if enabled is None:
|
||||
enabled = RMSNorm.enabled()
|
||||
|
||||
super().__init__(enabled)
|
||||
self.epsilon = epsilon
|
||||
|
||||
def inputs(self):
|
||||
input = self.empty(5, 16) if self.enabled else self.empty_f32(5, 16)
|
||||
weight = self.empty(16)
|
||||
return [input, weight]
|
||||
|
||||
def forward_custom(
|
||||
self,
|
||||
input: torch.Tensor,
|
||||
weight: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
result = torch.empty_like(input)
|
||||
_, result = auto_functionalized(
|
||||
RMS_OP,
|
||||
result=result,
|
||||
input=input,
|
||||
weight=weight,
|
||||
epsilon=self.epsilon,
|
||||
)
|
||||
|
||||
return result
|
||||
|
||||
def forward_native(
|
||||
self,
|
||||
input: torch.Tensor,
|
||||
weight: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
return RMSNorm.forward_static(
|
||||
input, self.epsilon, input.size(-1), self.model_dtype, weight
|
||||
)
|
||||
|
||||
|
||||
class MatcherFusedAddRMSNorm(MatcherCustomOp):
|
||||
def __init__(self, epsilon: float, enabled: bool | None = None):
|
||||
if enabled is None:
|
||||
enabled = RMSNorm.enabled()
|
||||
|
||||
super().__init__(enabled)
|
||||
self.epsilon = epsilon
|
||||
|
||||
def inputs(self):
|
||||
input = self.empty(5, 16) if self.enabled else self.empty_f32(5, 16)
|
||||
weight = self.empty(16)
|
||||
residual = self.empty(5, 16)
|
||||
return [input, weight, residual]
|
||||
|
||||
def forward_custom(
|
||||
self,
|
||||
input: torch.Tensor,
|
||||
weight: torch.Tensor,
|
||||
residual: torch.Tensor,
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
_, result, residual = auto_functionalized(
|
||||
RMS_ADD_OP,
|
||||
input=input,
|
||||
residual=residual,
|
||||
weight=weight,
|
||||
epsilon=self.epsilon,
|
||||
)
|
||||
|
||||
return result, residual
|
||||
|
||||
def forward_native(
|
||||
self,
|
||||
input: torch.Tensor,
|
||||
weight: torch.Tensor,
|
||||
residual: torch.Tensor,
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
return RMSNorm.forward_static(
|
||||
input, self.epsilon, input.size(-1), self.model_dtype, weight, residual
|
||||
)
|
||||
|
||||
|
||||
class MatcherQuantFP8(MatcherCustomOp):
|
||||
def __init__(self, quant_key: QuantKey, enabled: bool | None = None):
|
||||
if enabled is None:
|
||||
enabled = QuantFP8.enabled()
|
||||
|
||||
super().__init__(enabled)
|
||||
self.quant_key = quant_key
|
||||
assert quant_key in QUANT_OPS, f"unsupported quantization scheme {quant_key}"
|
||||
self.QUANT_OP = QUANT_OPS[quant_key]
|
||||
|
||||
assert quant_key.dtype == current_platform.fp8_dtype(), (
|
||||
"Only QuantFP8 supported by"
|
||||
)
|
||||
assert quant_key.scale2 is None
|
||||
self.quant_fp8 = QuantFP8(quant_key.scale.static, quant_key.scale.group_shape)
|
||||
|
||||
def forward_custom(
|
||||
self,
|
||||
input: torch.Tensor,
|
||||
scale: torch.Tensor | None = None,
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
result = torch.empty(
|
||||
input.shape, device=input.device, dtype=self.quant_key.dtype
|
||||
)
|
||||
|
||||
if self.quant_key.scale.static:
|
||||
assert scale is not None
|
||||
_, result = auto_functionalized(
|
||||
self.QUANT_OP, result=result, input=input, scale=scale
|
||||
)
|
||||
return result, scale
|
||||
else:
|
||||
assert scale is None
|
||||
scale = self.make_scale(input)
|
||||
_, result, scale = auto_functionalized(
|
||||
self.QUANT_OP, result=result, input=input, scale=scale, scale_ub=None
|
||||
)
|
||||
return result, scale
|
||||
|
||||
def forward_native(
|
||||
self,
|
||||
input: torch.Tensor,
|
||||
scale: torch.Tensor | None = None,
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
return self.quant_fp8(input, scale)
|
||||
|
||||
def make_scale(self, input: torch.Tensor):
|
||||
normalized_group_shape = _normalize_quant_group_shape(
|
||||
input, self.quant_key.scale.group_shape
|
||||
)
|
||||
scale_shape = (
|
||||
input.shape[0] // normalized_group_shape[0],
|
||||
input.shape[1] // normalized_group_shape[1],
|
||||
)
|
||||
|
||||
return torch.empty(scale_shape, device=input.device, dtype=torch.float32)
|
||||
|
||||
def inputs(self) -> list[torch.Tensor]:
|
||||
input = self.empty(5, 16)
|
||||
if self.quant_key.scale.static:
|
||||
return [input, self.empty_f32(1, 1)]
|
||||
|
||||
return [input]
|
||||
@ -22,6 +22,7 @@ def start_monitoring_torch_compile(vllm_config: VllmConfig):
|
||||
import depyf
|
||||
|
||||
path.mkdir(parents=True, exist_ok=True)
|
||||
logger.debug("Dumping depyf output to %s", path)
|
||||
global context_manager
|
||||
context_manager = depyf.prepare_debug(path.as_posix())
|
||||
context_manager.__enter__()
|
||||
|
||||
@ -5,7 +5,7 @@ import functools
|
||||
from torch import fx as fx
|
||||
|
||||
from vllm import envs
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.config import VllmConfig, set_current_vllm_config
|
||||
from vllm.logger import init_logger
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.utils import set_env_var
|
||||
@ -88,27 +88,30 @@ class PostGradPassManager(CustomGraphPass):
|
||||
|
||||
def configure(self, config: VllmConfig):
|
||||
self.pass_config = config.compilation_config.pass_config
|
||||
if self.pass_config.enable_noop:
|
||||
self.passes += [NoOpEliminationPass(config)]
|
||||
|
||||
if self.pass_config.enable_sequence_parallelism:
|
||||
self.passes += [SequenceParallelismPass(config)]
|
||||
if self.pass_config.enable_async_tp:
|
||||
self.passes += [AsyncTPPass(config)]
|
||||
# Set the current vllm config to allow tracing CustomOp instances
|
||||
with set_current_vllm_config(config, check_compile=False):
|
||||
if self.pass_config.enable_noop:
|
||||
self.passes += [NoOpEliminationPass(config)]
|
||||
|
||||
if self.pass_config.enable_fi_allreduce_fusion:
|
||||
self.passes += [AllReduceFusionPass(config)]
|
||||
if self.pass_config.enable_sequence_parallelism:
|
||||
self.passes += [SequenceParallelismPass(config)]
|
||||
if self.pass_config.enable_async_tp:
|
||||
self.passes += [AsyncTPPass(config)]
|
||||
|
||||
if self.pass_config.enable_fusion:
|
||||
self.passes += [RMSNormQuantFusionPass(config)]
|
||||
self.passes += [ActivationQuantFusionPass(config)]
|
||||
if self.pass_config.enable_fi_allreduce_fusion:
|
||||
self.passes += [AllReduceFusionPass(config)]
|
||||
|
||||
if self.pass_config.enable_attn_fusion:
|
||||
self.passes += [AttnFusionPass(config)]
|
||||
if self.pass_config.enable_fusion:
|
||||
self.passes += [RMSNormQuantFusionPass(config)]
|
||||
self.passes += [ActivationQuantFusionPass(config)]
|
||||
|
||||
# needs a functional graph
|
||||
self.post_cleanup = PostCleanupPass(config)
|
||||
self.fix_functionalization = FixFunctionalizationPass(config)
|
||||
if self.pass_config.enable_attn_fusion:
|
||||
self.passes += [AttnFusionPass(config)]
|
||||
|
||||
# needs a functional graph
|
||||
self.post_cleanup = PostCleanupPass(config)
|
||||
self.fix_functionalization = FixFunctionalizationPass(config)
|
||||
|
||||
# [HACK: Bug with Inductor graph partition and torch.compile cache]
|
||||
# In PyTorch 2.9, torch.compile has a bug where the graph
|
||||
|
||||
@ -128,7 +128,8 @@ class VllmPatternMatcherPass(VllmInductorPass):
|
||||
f" please add to dump_patterns if there are any errors.\n\n"
|
||||
f"from torch._higher_order_ops.auto_functionalize import "
|
||||
f"auto_functionalized as auto_functionalized\n"
|
||||
f"from torch._inductor.pattern_matcher import *",
|
||||
f"from torch._inductor.pattern_matcher import *\n"
|
||||
f"vllm = torch.ops.vllm",
|
||||
file=f,
|
||||
)
|
||||
|
||||
|
||||
@ -1403,8 +1403,15 @@ class EngineArgs:
|
||||
"data_parallel_size_local must be set to use data_parallel_hybrid_lb."
|
||||
)
|
||||
|
||||
# Local DP size defaults to global DP size if not set.
|
||||
data_parallel_size_local = self.data_parallel_size
|
||||
if self.data_parallel_backend == "ray" and (
|
||||
envs.VLLM_RAY_DP_PACK_STRATEGY == "span"
|
||||
):
|
||||
# Data parallel size defaults to 1 if DP ranks are spanning
|
||||
# multiple nodes
|
||||
data_parallel_size_local = 1
|
||||
else:
|
||||
# Otherwise local DP size defaults to global DP size if not set
|
||||
data_parallel_size_local = self.data_parallel_size
|
||||
|
||||
# DP address, used in multi-node case for torch distributed group
|
||||
# and ZMQ sockets.
|
||||
|
||||
@ -110,7 +110,7 @@ class EngineClient(ABC):
|
||||
|
||||
@abstractmethod
|
||||
async def stop_profile(self) -> None:
|
||||
"""Start profiling the engine"""
|
||||
"""Stop profiling the engine"""
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
|
||||
@ -139,7 +139,7 @@ if TYPE_CHECKING:
|
||||
VLLM_DP_MASTER_PORT: int = 0
|
||||
VLLM_MOE_DP_CHUNK_SIZE: int = 256
|
||||
VLLM_RANDOMIZE_DP_DUMMY_INPUTS: bool = False
|
||||
VLLM_RAY_DP_PACK_STRATEGY: str = "strict"
|
||||
VLLM_RAY_DP_PACK_STRATEGY: Literal["strict", "fill", "span"] = "strict"
|
||||
VLLM_MARLIN_USE_ATOMIC_ADD: bool = False
|
||||
VLLM_MXFP4_USE_MARLIN: bool | None = None
|
||||
VLLM_V0_USE_OUTLINES_CACHE: bool = False
|
||||
@ -1039,6 +1039,9 @@ environment_variables: dict[str, Callable[[], Any]] = {
|
||||
# for non-master nodes, allocate as many DP ranks as can fit;
|
||||
# - "strict":
|
||||
# allocate exactly data-parallel-size-local DP ranks to each picked node;
|
||||
# - "span":
|
||||
# Should be used only when a single DP rank requires multiple nodes.
|
||||
# allocate one DP rank over as many nodes as required for set world_size;
|
||||
# This environment variable is ignored if data-parallel-backend is not Ray.
|
||||
"VLLM_RAY_DP_PACK_STRATEGY": lambda: os.getenv(
|
||||
"VLLM_RAY_DP_PACK_STRATEGY", "strict"
|
||||
|
||||
@ -0,0 +1,164 @@
|
||||
{
|
||||
"1": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 16,
|
||||
"BLOCK_SIZE_K": 256,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 2,
|
||||
"waves_per_eu": 0
|
||||
},
|
||||
"2": {
|
||||
"BLOCK_SIZE_M": 32,
|
||||
"BLOCK_SIZE_N": 16,
|
||||
"BLOCK_SIZE_K": 256,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 2,
|
||||
"waves_per_eu": 0
|
||||
},
|
||||
"4": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"BLOCK_SIZE_K": 256,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 2,
|
||||
"num_stages": 2,
|
||||
"waves_per_eu": 0
|
||||
},
|
||||
"8": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 256,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 2,
|
||||
"waves_per_eu": 0
|
||||
},
|
||||
"16": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 256,
|
||||
"GROUP_SIZE_M": 8,
|
||||
"num_warps": 4,
|
||||
"num_stages": 2,
|
||||
"waves_per_eu": 0
|
||||
},
|
||||
"24": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 256,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 2,
|
||||
"waves_per_eu": 0
|
||||
},
|
||||
"32": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 2,
|
||||
"num_stages": 2,
|
||||
"waves_per_eu": 0
|
||||
},
|
||||
"48": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 256,
|
||||
"GROUP_SIZE_M": 8,
|
||||
"num_warps": 4,
|
||||
"num_stages": 2,
|
||||
"waves_per_eu": 0
|
||||
},
|
||||
"64": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 256,
|
||||
"GROUP_SIZE_M": 8,
|
||||
"num_warps": 4,
|
||||
"num_stages": 2,
|
||||
"waves_per_eu": 0
|
||||
},
|
||||
"96": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 256,
|
||||
"GROUP_SIZE_M": 8,
|
||||
"num_warps": 4,
|
||||
"num_stages": 2,
|
||||
"waves_per_eu": 0
|
||||
},
|
||||
"128": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 256,
|
||||
"GROUP_SIZE_M": 8,
|
||||
"num_warps": 4,
|
||||
"num_stages": 2,
|
||||
"waves_per_eu": 0
|
||||
},
|
||||
"256": {
|
||||
"BLOCK_SIZE_M": 32,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 256,
|
||||
"GROUP_SIZE_M": 8,
|
||||
"num_warps": 4,
|
||||
"num_stages": 2,
|
||||
"waves_per_eu": 0
|
||||
},
|
||||
"512": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 8,
|
||||
"num_stages": 2,
|
||||
"waves_per_eu": 0
|
||||
},
|
||||
"1024": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 256,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 8,
|
||||
"num_stages": 2,
|
||||
"waves_per_eu": 0
|
||||
},
|
||||
"1536": {
|
||||
"BLOCK_SIZE_M": 128,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 8,
|
||||
"num_warps": 8,
|
||||
"num_stages": 2,
|
||||
"waves_per_eu": 0
|
||||
},
|
||||
"2048": {
|
||||
"BLOCK_SIZE_M": 128,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 4,
|
||||
"num_warps": 8,
|
||||
"num_stages": 2,
|
||||
"waves_per_eu": 0
|
||||
},
|
||||
"3072": {
|
||||
"BLOCK_SIZE_M": 256,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 64,
|
||||
"GROUP_SIZE_M": 32,
|
||||
"num_warps": 8,
|
||||
"num_stages": 2,
|
||||
"waves_per_eu": 0
|
||||
},
|
||||
"4096": {
|
||||
"BLOCK_SIZE_M": 256,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 4,
|
||||
"num_warps": 8,
|
||||
"num_stages": 2,
|
||||
"waves_per_eu": 0
|
||||
}
|
||||
}
|
||||
@ -0,0 +1,164 @@
|
||||
{
|
||||
"1": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 16,
|
||||
"BLOCK_SIZE_K": 256,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 2,
|
||||
"waves_per_eu": 0
|
||||
},
|
||||
"2": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"BLOCK_SIZE_K": 256,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 2,
|
||||
"num_stages": 2,
|
||||
"waves_per_eu": 0
|
||||
},
|
||||
"4": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 256,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 2,
|
||||
"waves_per_eu": 0
|
||||
},
|
||||
"8": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 256,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 2,
|
||||
"waves_per_eu": 0
|
||||
},
|
||||
"16": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 4,
|
||||
"num_warps": 1,
|
||||
"num_stages": 2,
|
||||
"waves_per_eu": 0
|
||||
},
|
||||
"24": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 2,
|
||||
"num_stages": 2,
|
||||
"waves_per_eu": 0
|
||||
},
|
||||
"32": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 2,
|
||||
"num_stages": 2,
|
||||
"waves_per_eu": 0
|
||||
},
|
||||
"48": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 4,
|
||||
"num_warps": 2,
|
||||
"num_stages": 2,
|
||||
"waves_per_eu": 0
|
||||
},
|
||||
"64": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 256,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 8,
|
||||
"num_stages": 2,
|
||||
"waves_per_eu": 0
|
||||
},
|
||||
"96": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 256,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 8,
|
||||
"num_stages": 2,
|
||||
"waves_per_eu": 0
|
||||
},
|
||||
"128": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 256,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 32,
|
||||
"num_warps": 8,
|
||||
"num_stages": 2,
|
||||
"waves_per_eu": 0
|
||||
},
|
||||
"256": {
|
||||
"BLOCK_SIZE_M": 32,
|
||||
"BLOCK_SIZE_N": 256,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 8,
|
||||
"num_warps": 4,
|
||||
"num_stages": 2,
|
||||
"waves_per_eu": 0
|
||||
},
|
||||
"512": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 256,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 8,
|
||||
"num_warps": 8,
|
||||
"num_stages": 2,
|
||||
"waves_per_eu": 0
|
||||
},
|
||||
"1024": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 256,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 8,
|
||||
"num_stages": 2,
|
||||
"waves_per_eu": 0
|
||||
},
|
||||
"1536": {
|
||||
"BLOCK_SIZE_M": 128,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 8,
|
||||
"num_warps": 8,
|
||||
"num_stages": 2,
|
||||
"waves_per_eu": 0
|
||||
},
|
||||
"2048": {
|
||||
"BLOCK_SIZE_M": 128,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 8,
|
||||
"num_warps": 8,
|
||||
"num_stages": 2,
|
||||
"waves_per_eu": 0
|
||||
},
|
||||
"3072": {
|
||||
"BLOCK_SIZE_M": 256,
|
||||
"BLOCK_SIZE_N": 256,
|
||||
"BLOCK_SIZE_K": 64,
|
||||
"GROUP_SIZE_M": 32,
|
||||
"num_warps": 8,
|
||||
"num_stages": 2,
|
||||
"waves_per_eu": 0
|
||||
},
|
||||
"4096": {
|
||||
"BLOCK_SIZE_M": 256,
|
||||
"BLOCK_SIZE_N": 256,
|
||||
"BLOCK_SIZE_K": 64,
|
||||
"GROUP_SIZE_M": 8,
|
||||
"num_warps": 8,
|
||||
"num_stages": 2,
|
||||
"waves_per_eu": 0
|
||||
}
|
||||
}
|
||||
@ -0,0 +1,164 @@
|
||||
{
|
||||
"1": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 16,
|
||||
"BLOCK_SIZE_K": 256,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 2,
|
||||
"waves_per_eu": 0
|
||||
},
|
||||
"2": {
|
||||
"BLOCK_SIZE_M": 32,
|
||||
"BLOCK_SIZE_N": 16,
|
||||
"BLOCK_SIZE_K": 256,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 2,
|
||||
"waves_per_eu": 0
|
||||
},
|
||||
"4": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"BLOCK_SIZE_K": 256,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 2,
|
||||
"num_stages": 2,
|
||||
"waves_per_eu": 0
|
||||
},
|
||||
"8": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 256,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 2,
|
||||
"waves_per_eu": 0
|
||||
},
|
||||
"16": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 2,
|
||||
"num_stages": 2,
|
||||
"waves_per_eu": 0
|
||||
},
|
||||
"24": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 8,
|
||||
"num_stages": 2,
|
||||
"waves_per_eu": 0
|
||||
},
|
||||
"32": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 1,
|
||||
"num_stages": 2,
|
||||
"waves_per_eu": 0
|
||||
},
|
||||
"48": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 1,
|
||||
"num_stages": 2,
|
||||
"waves_per_eu": 0
|
||||
},
|
||||
"64": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 256,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 32,
|
||||
"num_warps": 8,
|
||||
"num_stages": 2,
|
||||
"waves_per_eu": 0
|
||||
},
|
||||
"96": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 256,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 8,
|
||||
"num_warps": 8,
|
||||
"num_stages": 2,
|
||||
"waves_per_eu": 0
|
||||
},
|
||||
"128": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 256,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 8,
|
||||
"num_warps": 8,
|
||||
"num_stages": 2,
|
||||
"waves_per_eu": 0
|
||||
},
|
||||
"256": {
|
||||
"BLOCK_SIZE_M": 32,
|
||||
"BLOCK_SIZE_N": 256,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 8,
|
||||
"num_warps": 8,
|
||||
"num_stages": 2,
|
||||
"waves_per_eu": 0
|
||||
},
|
||||
"512": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 256,
|
||||
"GROUP_SIZE_M": 16,
|
||||
"num_warps": 8,
|
||||
"num_stages": 2,
|
||||
"waves_per_eu": 0
|
||||
},
|
||||
"1024": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 256,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 8,
|
||||
"num_stages": 2,
|
||||
"waves_per_eu": 0
|
||||
},
|
||||
"1536": {
|
||||
"BLOCK_SIZE_M": 128,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 8,
|
||||
"num_warps": 8,
|
||||
"num_stages": 2,
|
||||
"waves_per_eu": 0
|
||||
},
|
||||
"2048": {
|
||||
"BLOCK_SIZE_M": 128,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 4,
|
||||
"num_warps": 8,
|
||||
"num_stages": 2,
|
||||
"waves_per_eu": 0
|
||||
},
|
||||
"3072": {
|
||||
"BLOCK_SIZE_M": 256,
|
||||
"BLOCK_SIZE_N": 256,
|
||||
"BLOCK_SIZE_K": 64,
|
||||
"GROUP_SIZE_M": 8,
|
||||
"num_warps": 8,
|
||||
"num_stages": 2,
|
||||
"waves_per_eu": 0
|
||||
},
|
||||
"4096": {
|
||||
"BLOCK_SIZE_M": 256,
|
||||
"BLOCK_SIZE_N": 256,
|
||||
"BLOCK_SIZE_K": 64,
|
||||
"GROUP_SIZE_M": 16,
|
||||
"num_warps": 8,
|
||||
"num_stages": 2,
|
||||
"waves_per_eu": 0
|
||||
}
|
||||
}
|
||||
@ -46,6 +46,11 @@ def is_rocm_aiter_moe_enabled() -> bool:
|
||||
)
|
||||
|
||||
|
||||
@cache
|
||||
def use_mxfp4_aiter_moe() -> bool:
|
||||
return current_platform.is_rocm() and envs.VLLM_ROCM_USE_AITER
|
||||
|
||||
|
||||
@cache
|
||||
def is_rocm_aiter_fusion_shared_expert_enabled() -> bool:
|
||||
return (
|
||||
@ -487,6 +492,8 @@ def rocm_aiter_fused_experts(
|
||||
assert quant_config.w1_scale is not None
|
||||
assert quant_config.w2_scale is not None
|
||||
quant_method = QuantMethod.BLOCK_128x128.value
|
||||
elif quant_config.use_fp8_w8a8 and quant_config.per_out_ch_quant:
|
||||
quant_method = QuantMethod.PER_TOKEN.value
|
||||
elif quant_config.use_fp8_w8a8:
|
||||
# Currently only per tensor quantization method is enabled.
|
||||
quant_method = QuantMethod.PER_TENSOR.value
|
||||
|
||||
@ -58,22 +58,6 @@ def fused_add_rms_norm(
|
||||
return x, residual
|
||||
|
||||
|
||||
def poly_norm(
|
||||
x: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor, variance_epsilon: float
|
||||
) -> torch.Tensor:
|
||||
from vllm import _custom_ops as ops
|
||||
|
||||
out = torch.empty_like(x)
|
||||
ops.poly_norm(
|
||||
out,
|
||||
x,
|
||||
weight,
|
||||
bias,
|
||||
variance_epsilon,
|
||||
)
|
||||
return out
|
||||
|
||||
|
||||
def rocm_aiter_rms_norm_impl(
|
||||
x: torch.Tensor, weight: torch.Tensor, variance_epsilon: float
|
||||
) -> torch.Tensor:
|
||||
@ -178,14 +162,11 @@ class RMSNorm(CustomOp):
|
||||
self.variance_size_override = (
|
||||
None if var_hidden_size == hidden_size else var_hidden_size
|
||||
)
|
||||
weight_dtype = dtype or torch.get_default_dtype()
|
||||
self.has_weight = has_weight
|
||||
if dtype is not None:
|
||||
self.weight = torch.ones(hidden_size, dtype=dtype)
|
||||
else:
|
||||
self.weight = torch.ones(hidden_size)
|
||||
self.weight = torch.ones(hidden_size, dtype=weight_dtype)
|
||||
if self.has_weight:
|
||||
self.weight = nn.Parameter(self.weight)
|
||||
weight_dtype = self.weight.data.dtype
|
||||
|
||||
if current_platform.is_rocm():
|
||||
self.rocm_norm_func = dispatch_rocm_rmsnorm_func(
|
||||
@ -195,46 +176,68 @@ class RMSNorm(CustomOp):
|
||||
with_fused_add=True, dtype=weight_dtype
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def forward_static(
|
||||
x: torch.Tensor,
|
||||
variance_epsilon: float,
|
||||
hidden_size: int,
|
||||
orig_dtype: torch.dtype,
|
||||
weight: torch.Tensor | None = None,
|
||||
residual: torch.Tensor | None = None,
|
||||
variance_size_override: int | None = None,
|
||||
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
|
||||
"""PyTorch-native implementation equivalent to forward()."""
|
||||
x = x.to(torch.float32)
|
||||
if residual is not None:
|
||||
# residual promoted f16->f32 automatically,
|
||||
# otherwise Inductor eliminates the casts to and from f16,
|
||||
# increasing memory usage (and complicating pattern matching)
|
||||
x = x + residual
|
||||
residual = x.to(orig_dtype)
|
||||
|
||||
if x.shape[-1] != hidden_size:
|
||||
raise ValueError(
|
||||
f"Expected hidden_size to be {hidden_size}, but found: {x.shape[-1]}"
|
||||
)
|
||||
|
||||
if variance_size_override is None:
|
||||
x_var = x
|
||||
else:
|
||||
if hidden_size < variance_size_override:
|
||||
raise ValueError(
|
||||
"Expected hidden_size to be at least "
|
||||
f"{variance_size_override}, but found: {hidden_size}"
|
||||
)
|
||||
|
||||
x_var = x[:, :, :variance_size_override]
|
||||
|
||||
variance = x_var.pow(2).mean(dim=-1, keepdim=True)
|
||||
|
||||
x = x * torch.rsqrt(variance + variance_epsilon)
|
||||
x = x.to(orig_dtype)
|
||||
if weight is not None:
|
||||
x = x * weight
|
||||
if residual is None:
|
||||
return x
|
||||
else:
|
||||
return x, residual
|
||||
|
||||
def forward_native(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
residual: torch.Tensor | None = None,
|
||||
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
|
||||
"""PyTorch-native implementation equivalent to forward()."""
|
||||
orig_dtype = x.dtype
|
||||
x = x.to(torch.float32)
|
||||
if residual is not None:
|
||||
x = x + residual.to(torch.float32)
|
||||
residual = x.to(orig_dtype)
|
||||
|
||||
hidden_size = x.shape[-1]
|
||||
if hidden_size != self.hidden_size:
|
||||
raise ValueError(
|
||||
"Expected hidden_size to be "
|
||||
f"{self.hidden_size}, but found: {hidden_size}"
|
||||
)
|
||||
|
||||
if self.variance_size_override is None:
|
||||
x_var = x
|
||||
else:
|
||||
if hidden_size < self.variance_size_override:
|
||||
raise ValueError(
|
||||
"Expected hidden_size to be at least "
|
||||
f"{self.variance_size_override}, but found: {hidden_size}"
|
||||
)
|
||||
|
||||
x_var = x[:, :, : self.variance_size_override]
|
||||
|
||||
variance = x_var.pow(2).mean(dim=-1, keepdim=True)
|
||||
|
||||
x = x * torch.rsqrt(variance + self.variance_epsilon)
|
||||
x = x.to(orig_dtype)
|
||||
if self.has_weight:
|
||||
x = x * self.weight
|
||||
if residual is None:
|
||||
return x
|
||||
else:
|
||||
return x, residual
|
||||
return self.forward_static(
|
||||
x,
|
||||
self.variance_epsilon,
|
||||
self.hidden_size,
|
||||
x.dtype,
|
||||
self.weight.data if self.has_weight else None,
|
||||
residual,
|
||||
self.variance_size_override,
|
||||
)
|
||||
|
||||
def forward_cuda(
|
||||
self,
|
||||
@ -366,53 +369,6 @@ class GemmaRMSNorm(CustomOp):
|
||||
return self.forward_native(x, residual)
|
||||
|
||||
|
||||
@CustomOp.register("poly_norm")
|
||||
class PolyNorm(CustomOp):
|
||||
"""Polynomial normalization.
|
||||
|
||||
Computes x -> w_0 * RMSNorm(x^3) + w_1 * RMSNorm(x^2) + w_2 * RMSNorm(x) + b
|
||||
where w_n is the learned weight and b is the bias.
|
||||
Refer to https://arxiv.org/html/2411.03884v1
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
eps: float = 1e-6,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.weight = torch.nn.Parameter(torch.ones(3) / 3)
|
||||
self.bias = torch.nn.Parameter(torch.zeros(1))
|
||||
self.variance_epsilon = eps
|
||||
|
||||
def _norm(self, x):
|
||||
return x / torch.sqrt(x.pow(2).mean(-1, keepdim=True) + self.variance_epsilon)
|
||||
|
||||
def forward_native(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
"""PyTorch-native implementation equivalent to forward().
|
||||
|
||||
Refer to https://github.com/BryceZhuo/PolyCom?tab=readme-ov-file/README.md
|
||||
"""
|
||||
|
||||
orig_dtype = x.dtype
|
||||
x_float = x.to(torch.float32)
|
||||
output = (
|
||||
self.weight[0] * self._norm(x_float**3)
|
||||
+ self.weight[1] * self._norm(x_float**2)
|
||||
+ self.weight[2] * self._norm(x_float)
|
||||
+ self.bias
|
||||
)
|
||||
return output.to(orig_dtype)
|
||||
|
||||
def forward_cuda(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
return poly_norm(x, self.weight, self.bias, self.variance_epsilon)
|
||||
|
||||
|
||||
class LayerNorm(nn.Module):
|
||||
"""
|
||||
Layer Normalization.
|
||||
|
||||
@ -23,6 +23,7 @@ from vllm.model_executor.layers.fused_moe.config import (
|
||||
from vllm.model_executor.layers.fused_moe.fused_marlin_moe import fused_marlin_moe
|
||||
from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (
|
||||
is_rocm_aiter_moe_enabled,
|
||||
use_mxfp4_aiter_moe,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import (
|
||||
prepare_moe_fp8_layer_for_marlin,
|
||||
@ -341,7 +342,8 @@ class QuarkW8A8Fp8MoEMethod(QuarkMoEMethod):
|
||||
w2_scale=layer.w2_weight_scale,
|
||||
a1_scale=layer.w13_input_scale,
|
||||
a2_scale=layer.w2_input_scale,
|
||||
per_act_token_quant=self.weight_qscheme == "per_channel",
|
||||
per_act_token_quant=self.input_qscheme == "per_channel",
|
||||
per_out_ch_quant=self.weight_qscheme == "per_channel",
|
||||
)
|
||||
|
||||
def apply(
|
||||
@ -472,22 +474,22 @@ class QuarkOCP_MX_MoEMethod(QuarkMoEMethod):
|
||||
"not implemented. Please open an issue."
|
||||
)
|
||||
|
||||
if not current_platform.supports_mx():
|
||||
self.emulate = True
|
||||
self.emulate = not current_platform.supports_mx() or not (
|
||||
use_mxfp4_aiter_moe() and self.ocp_mx_scheme == "w_mxfp4_a_mxfp4"
|
||||
)
|
||||
if self.emulate:
|
||||
logger.warning_once(
|
||||
"The current platform does not support native MXFP4/MXFP6 "
|
||||
f"The current mode (supports_mx={current_platform.supports_mx()}, "
|
||||
f"use_mxfp4_aiter_moe={use_mxfp4_aiter_moe()}, "
|
||||
f"ocp_mx_scheme={self.ocp_mx_scheme}) "
|
||||
"does not support native MXFP4/MXFP6 "
|
||||
"computation. Simulated weight dequantization and activation "
|
||||
"QDQ (quantize and dequantize) will be used, with the linear "
|
||||
"layers computed in high precision."
|
||||
)
|
||||
else:
|
||||
self.emulate = True
|
||||
logger.warning_once(
|
||||
"The current platform supports native MXFP4/MXFP6 "
|
||||
"computation, but kernels are not yet integrated in vLLM. "
|
||||
"Simulated weight dequantization and activation "
|
||||
"QDQ (quantize and dequantize) will be used, with the linear "
|
||||
"layers computed in high precision."
|
||||
"The current mode supports native MoE MXFP4 computation"
|
||||
)
|
||||
|
||||
def get_packed_dim(self, dim: int, quant_dtype: str):
|
||||
@ -568,6 +570,24 @@ class QuarkOCP_MX_MoEMethod(QuarkMoEMethod):
|
||||
layer.register_parameter("w13_weight_scale", w13_weight_scale)
|
||||
layer.register_parameter("w2_weight_scale", w2_weight_scale)
|
||||
|
||||
def process_weights_after_loading(self, layer):
|
||||
if self.emulate:
|
||||
return
|
||||
|
||||
from aiter.utility.fp4_utils import e8m0_shuffle
|
||||
|
||||
# Pre-shuffle weight scales
|
||||
s0, s1, _ = layer.w13_weight_scale.shape
|
||||
w13_weight_scale = layer.w13_weight_scale.view(s0 * s1, -1)
|
||||
w13_weight_scale = e8m0_shuffle(w13_weight_scale)
|
||||
layer.w13_weight_scale.data = w13_weight_scale.view(s0, s1, -1)
|
||||
|
||||
s0, s1, _ = layer.w2_weight_scale.shape
|
||||
w2_weight_scale = layer.w2_weight_scale.view(s0 * s1, -1)
|
||||
w2_weight_scale = e8m0_shuffle(w2_weight_scale)
|
||||
layer.w2_weight_scale.data = w2_weight_scale.view(s0, s1, -1)
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
def get_fused_moe_quant_config(
|
||||
self, layer: torch.nn.Module
|
||||
) -> FusedMoEQuantConfig | None:
|
||||
@ -611,8 +631,6 @@ class QuarkOCP_MX_MoEMethod(QuarkMoEMethod):
|
||||
"EPLB not supported for `QuarkOCP_MX_MoEMethod` yet."
|
||||
)
|
||||
|
||||
from vllm.model_executor.layers.fused_moe import fused_experts
|
||||
|
||||
topk_weights, topk_ids, _ = FusedMoE.select_experts(
|
||||
hidden_states=x,
|
||||
router_logits=router_logits,
|
||||
@ -628,17 +646,44 @@ class QuarkOCP_MX_MoEMethod(QuarkMoEMethod):
|
||||
indices_type=self.topk_indices_dtype,
|
||||
)
|
||||
|
||||
out = fused_experts(
|
||||
x,
|
||||
layer.w13_weight,
|
||||
layer.w2_weight,
|
||||
topk_weights=topk_weights,
|
||||
topk_ids=topk_ids,
|
||||
inplace=True,
|
||||
activation=activation,
|
||||
global_num_experts=global_num_experts,
|
||||
apply_router_weight_on_input=apply_router_weight_on_input,
|
||||
expert_map=expert_map,
|
||||
quant_config=self.moe_quant_config,
|
||||
)
|
||||
if not self.emulate:
|
||||
from aiter import ActivationType, QuantType
|
||||
from aiter.fused_moe import fused_moe
|
||||
|
||||
aiter_acts = {
|
||||
ActivationType.No.name.lower(): ActivationType.No,
|
||||
ActivationType.Silu.name.lower(): ActivationType.Silu,
|
||||
ActivationType.Gelu.name.lower(): ActivationType.Gelu,
|
||||
}
|
||||
assert activation in aiter_acts, (
|
||||
f"Aiter CK fp4 MoE doesn't support activation {activation}"
|
||||
)
|
||||
out = fused_moe(
|
||||
x,
|
||||
layer.w13_weight,
|
||||
layer.w2_weight,
|
||||
topk_weights,
|
||||
topk_ids,
|
||||
quant_type=QuantType.per_1x32,
|
||||
w1_scale=layer.w13_weight_scale,
|
||||
w2_scale=layer.w2_weight_scale,
|
||||
activation=aiter_acts[activation],
|
||||
doweight_stage1=False,
|
||||
)
|
||||
else:
|
||||
from vllm.model_executor.layers.fused_moe import fused_experts
|
||||
|
||||
out = fused_experts(
|
||||
x,
|
||||
layer.w13_weight,
|
||||
layer.w2_weight,
|
||||
topk_weights=topk_weights,
|
||||
topk_ids=topk_ids,
|
||||
inplace=True,
|
||||
activation=activation,
|
||||
global_num_experts=global_num_experts,
|
||||
apply_router_weight_on_input=apply_router_weight_on_input,
|
||||
expert_map=expert_map,
|
||||
quant_config=self.moe_quant_config,
|
||||
)
|
||||
return out
|
||||
|
||||
@ -464,8 +464,16 @@ class Fp8LinearOp:
|
||||
else:
|
||||
qinput, x_scale = input_2d, input_scale
|
||||
|
||||
per_tensor_weights = weight_scale.numel() == 1
|
||||
per_tensor_activations = x_scale.numel() == 1
|
||||
# Must have dim() conditions
|
||||
# In per-token quant scenario, when the number of token is 1,
|
||||
# the scale will only have 1 elements.
|
||||
# Without checking the dim(),
|
||||
# we cannot distingushes between per-tensor and per-token quant.
|
||||
# Example:
|
||||
# When the number of token is 1, per-token scale is [[1]]
|
||||
# When per-tensor scale is [1] or ().
|
||||
per_tensor_weights = (weight_scale.numel() == 1) and weight_scale.dim() < 2
|
||||
per_tensor_activations = (x_scale.numel() == 1) and x_scale.dim() < 2
|
||||
|
||||
# TODO(luka) do this dispatch during init (after ScaledMM refactor)
|
||||
w8a8_scaled_mm_func = dispatch_w8a8_scaled_mm(
|
||||
|
||||
@ -735,9 +735,9 @@ class Qwen3VLProcessingInfo(Qwen2VLProcessingInfo):
|
||||
if do_sample_frames:
|
||||
# here video_fps is the fps of the sampled video, and
|
||||
# metadata["fps"] refers to the fps of the original video.
|
||||
video_fps = sampled_fps if sampled_fps else video_processor.fps
|
||||
sampled_fps = sampled_fps if sampled_fps else video_processor.fps
|
||||
total_num_frames = metadata["total_num_frames"]
|
||||
num_frames = int(total_num_frames / metadata["fps"] * video_fps)
|
||||
num_frames = int(total_num_frames / metadata["fps"] * sampled_fps)
|
||||
num_frames = min(
|
||||
min(
|
||||
max(num_frames, video_processor.min_frames),
|
||||
|
||||
@ -350,6 +350,14 @@ class Qwen3MoeLLMForCausalLM(Qwen3MoeForCausalLM):
|
||||
dummy_inputs=Qwen3VLDummyInputsBuilder,
|
||||
)
|
||||
class Qwen3VLMoeForConditionalGeneration(Qwen3VLForConditionalGeneration):
|
||||
packed_modules_mapping = {
|
||||
"qkv_proj": [
|
||||
"q_proj",
|
||||
"k_proj",
|
||||
"v_proj",
|
||||
],
|
||||
}
|
||||
|
||||
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
|
||||
super(Qwen3VLForConditionalGeneration, self).__init__()
|
||||
config: Qwen3VLMoeConfig = vllm_config.model_config.hf_config
|
||||
@ -376,6 +384,11 @@ class Qwen3VLMoeForConditionalGeneration(Qwen3VLForConditionalGeneration):
|
||||
self.language_model = Qwen3MoeLLMForCausalLM(
|
||||
vllm_config=vllm_config, prefix=maybe_prefix(prefix, "language_model")
|
||||
)
|
||||
# Whether to include the gate_up_proj mapping is determined by
|
||||
# the language model.
|
||||
self.packed_modules_mapping = (
|
||||
self.packed_modules_mapping | self.language_model.packed_modules_mapping
|
||||
)
|
||||
|
||||
self.make_empty_intermediate_tensors = (
|
||||
self.language_model.make_empty_intermediate_tensors
|
||||
|
||||
@ -558,6 +558,19 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
|
||||
self.dcp_world_size = 1
|
||||
self.dcp_rank = 0
|
||||
|
||||
if (
|
||||
self.dcp_world_size > 1
|
||||
and self.__class__.reorder_batch_threshold > 1
|
||||
and self.__class__.__name__ != "FlashAttnMLAMetadataBuilder"
|
||||
):
|
||||
logger.warning_once(
|
||||
"DCP is enabled but not FlashAttnMLA is used. "
|
||||
"Set query_len_support back to SINGLE_ONLY "
|
||||
"and reorder_batch_threshold back to 1."
|
||||
)
|
||||
self.__class__.query_len_support = QueryLenSupport.SINGLE_ONLY
|
||||
self.__class__.reorder_batch_threshold = 1
|
||||
|
||||
# Don't try to access the runner on AMD
|
||||
if self.aot_schedule:
|
||||
self.page_size = self.kv_cache_spec.block_size
|
||||
|
||||
@ -345,6 +345,7 @@ class CoreEngineActorManager:
|
||||
world_size = vllm_config.parallel_config.world_size
|
||||
placement_groups: list[PlacementGroup] = []
|
||||
local_dp_ranks: list[int] = []
|
||||
|
||||
dp_master_ip_key = f"node:{dp_master_ip}"
|
||||
nodes = sorted(
|
||||
available_resources.values(), key=lambda x: dp_master_ip_key not in x
|
||||
@ -355,9 +356,25 @@ class CoreEngineActorManager:
|
||||
dp_master_ip,
|
||||
)
|
||||
device_str = current_platform.ray_device_key
|
||||
n_node_devices: list[int] = [
|
||||
int(node_resources[device_str])
|
||||
for node_resources in nodes
|
||||
if device_str in node_resources
|
||||
]
|
||||
assert n_node_devices, f"No {device_str} found in Ray cluster."
|
||||
max_device_per_node = max(n_node_devices)
|
||||
|
||||
pack_strategy = envs.VLLM_RAY_DP_PACK_STRATEGY
|
||||
_supported_pack_strategies = ("strict", "fill", "span")
|
||||
if pack_strategy not in _supported_pack_strategies:
|
||||
raise ValueError(
|
||||
f"{envs.VLLM_RAY_DP_PACK_STRATEGY} is not supported. "
|
||||
"Make sure to set `VLLM_RAY_DP_PACK_STRATEGY` "
|
||||
f"to one of {_supported_pack_strategies}"
|
||||
)
|
||||
|
||||
all2all_backend = vllm_config.parallel_config.all2all_backend
|
||||
if envs.VLLM_RAY_DP_PACK_STRATEGY == "fill" and (
|
||||
if pack_strategy == "fill" and (
|
||||
all2all_backend == "deepep_high_throughput"
|
||||
or all2all_backend == "deepep_low_latency"
|
||||
):
|
||||
@ -367,12 +384,42 @@ class CoreEngineActorManager:
|
||||
"does not guarantee that. "
|
||||
"Please use VLLM_RAY_DP_PACK_STRATEGY=strict instead."
|
||||
)
|
||||
logger.info(
|
||||
"Using '%s' DP packing strategy based on VLLM_RAY_DP_PACK_STRATEGY",
|
||||
envs.VLLM_RAY_DP_PACK_STRATEGY,
|
||||
)
|
||||
strict_local_size = envs.VLLM_RAY_DP_PACK_STRATEGY == "strict"
|
||||
|
||||
if pack_strategy in ("strict", "fill"):
|
||||
placement_strategy = "STRICT_PACK"
|
||||
else:
|
||||
placement_strategy = "PACK"
|
||||
assert world_size > max_device_per_node, (
|
||||
f"World size {world_size} is smaller than the "
|
||||
"maximum number of devices per node "
|
||||
f"{max_device_per_node}. Make sure to set "
|
||||
"`VLLM_RAY_DP_PACK_STRATEGY` to `strict` or `fill`"
|
||||
)
|
||||
|
||||
# if we need multiple nodes per dp group, we require for now that
|
||||
# available nodes are homogenous
|
||||
assert set(n_node_devices) == {max_device_per_node}, (
|
||||
f"Nodes are not homogenous, {nodes}"
|
||||
)
|
||||
assert world_size % max_device_per_node == 0, (
|
||||
f"For multi-node data parallel groups, world_size ({world_size}) must "
|
||||
f"be a multiple of number of devices per node ({max_device_per_node})."
|
||||
)
|
||||
assert len(n_node_devices) * max_device_per_node >= world_size * dp_size, (
|
||||
f"Not enough total available nodes ({len(n_node_devices)}) "
|
||||
f"and devices per node ({max_device_per_node}) "
|
||||
f"to satisfy required world size {world_size} and data parallel size "
|
||||
f"{dp_size}"
|
||||
)
|
||||
assert dp_size_local == 1, (
|
||||
f"data-parallel-size-local {dp_size_local} should be set as the "
|
||||
"default (1) for VLLM_RAY_DP_PACK_STRATEGY=span. "
|
||||
"The actual data-parallel-size-local will be auto determined."
|
||||
)
|
||||
|
||||
# bundles collected for a single DP rank from multiple nodes,
|
||||
# for "span" pack strategy
|
||||
collected_bundles = []
|
||||
for node_resources in nodes:
|
||||
node_ip_keys = [
|
||||
key
|
||||
@ -386,14 +433,14 @@ class CoreEngineActorManager:
|
||||
node_ip_key = node_ip_keys[0]
|
||||
node_ip = node_ip_key.split(":")[1]
|
||||
|
||||
# For now, each DP rank can only be assigned to one node
|
||||
# TODO(rui): support allocating a single DP rank
|
||||
# to multiple nodes
|
||||
dp_size_available = (
|
||||
int(node_resources[device_str]) // world_size
|
||||
if device_str in node_resources
|
||||
else 0
|
||||
)
|
||||
n_device_on_node = int(node_resources.get(device_str, 0))
|
||||
if pack_strategy == "span" and n_device_on_node != 0:
|
||||
# Strictly speaking,
|
||||
# dp_size_available = n_device_on_node / world_size
|
||||
# and is a fraction, but we use 1 for easier processing
|
||||
dp_size_available = 1
|
||||
else:
|
||||
dp_size_available = n_device_on_node // world_size
|
||||
|
||||
if node_ip == dp_master_ip:
|
||||
if dp_size_available < dp_size_local:
|
||||
@ -405,7 +452,7 @@ class CoreEngineActorManager:
|
||||
dp_size_available,
|
||||
)
|
||||
dp_size_to_allocate = dp_size_local
|
||||
elif strict_local_size:
|
||||
elif pack_strategy == "strict":
|
||||
if dp_size_available < dp_size_local:
|
||||
logger.info(
|
||||
"Skipping node %s as %s DP ranks could not fit, "
|
||||
@ -417,15 +464,31 @@ class CoreEngineActorManager:
|
||||
continue
|
||||
dp_size_to_allocate = dp_size_local
|
||||
else:
|
||||
# for "pack_strategy" in "fill" and "span"
|
||||
# we always take everything that's available
|
||||
dp_size_to_allocate = dp_size_available
|
||||
|
||||
for i in range(dp_size_to_allocate):
|
||||
bundles = [{device_str: 1.0, "node:" + node_ip: 0.001}] * world_size + [
|
||||
{"CPU": 1.0}
|
||||
]
|
||||
device_bundle = [{device_str: 1.0, "node:" + node_ip: 0.001}]
|
||||
if pack_strategy == "span":
|
||||
collected_bundles += device_bundle * n_device_on_node
|
||||
assert len(collected_bundles) <= world_size, (
|
||||
"collected_bundles should be <= world_size, "
|
||||
f"but got {len(collected_bundles)=} and {world_size=}"
|
||||
)
|
||||
|
||||
# we only create a placement group if we collected enough devices
|
||||
if len(collected_bundles) < world_size:
|
||||
continue
|
||||
|
||||
bundles = collected_bundles + [{"CPU": 1.0}]
|
||||
collected_bundles = []
|
||||
else:
|
||||
bundles = device_bundle * world_size + [{"CPU": 1.0}]
|
||||
|
||||
pg = ray.util.placement_group(
|
||||
name=f"dp_rank_{len(placement_groups)}",
|
||||
strategy="STRICT_PACK",
|
||||
strategy=placement_strategy,
|
||||
bundles=bundles,
|
||||
)
|
||||
placement_groups.append(pg)
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user