[Kernel/Quant] Remove AQLM (#22943)

Signed-off-by: mgoin <mgoin64@gmail.com>
Co-authored-by: Wentao Ye <44945378+yewentao256@users.noreply.github.com>
This commit is contained in:
Michael Goin 2025-08-16 15:38:21 -04:00 committed by GitHub
parent 3253ae765e
commit 4fc722eca4
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
16 changed files with 0 additions and 1534 deletions

View File

@ -121,7 +121,6 @@ fi
if [[ $commands == *" kernels/quantization"* ]]; then
commands="${commands} \
--ignore=kernels/quantization/test_int8_quant.py \
--ignore=kernels/quantization/test_aqlm.py \
--ignore=kernels/quantization/test_machete_mm.py \
--ignore=kernels/quantization/test_block_fp8.py \
--ignore=kernels/quantization/test_block_int8.py \

View File

@ -286,7 +286,6 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
FetchContent_MakeAvailable(cutlass)
list(APPEND VLLM_EXT_SRC
"csrc/quantization/aqlm/gemm_kernels.cu"
"csrc/quantization/awq/gemm_kernels.cu"
"csrc/permute_cols.cu"
"csrc/quantization/cutlass_w8a8/scaled_mm_entry.cu"

View File

@ -1,345 +0,0 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import os
import sys
from typing import Optional
import torch
import torch.nn.functional as F
from vllm import _custom_ops as ops
from vllm.model_executor.layers.quantization.aqlm import (
dequantize_weight,
generic_dequantize_gemm,
get_int_dtype,
optimized_dequantize_gemm,
)
from vllm.utils import FlexibleArgumentParser
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
def torch_mult(
# [..., in_features]
input: torch.Tensor,
weights: torch.Tensor,
# [num_out_groups, 1, 1, 1]
scales: torch.Tensor,
) -> torch.Tensor:
output = F.linear(input, weights)
return output
def dequant_out_scale(
# [..., in_features]
input: torch.Tensor,
# [num_out_groups, num_in_groups, num_codebooks]
codes: torch.IntTensor,
# [num_codebooks, codebook_size, out_group_size, in_group_size]
codebooks: torch.Tensor,
# [num_out_groups, 1, 1, 1]
scales: torch.Tensor,
output_partition_sizes: torch.IntTensor,
bias: Optional[torch.Tensor],
) -> torch.Tensor:
weights = ops.aqlm_dequant(codes, codebooks, output_partition_sizes)
if bias is None:
output = F.linear(input, weights, bias)
orig_shape = output.shape
flattened_output = output.view(-1, output.size(-1))
f_scales = scales.view(-1, scales.shape[0])
b_scales = f_scales.expand(flattened_output.shape[0], -1)
flattened_output *= b_scales
return flattened_output.view(orig_shape)
else:
b_scales = scales.view(scales.shape[:-3] + (-1,)).expand(-1, weights.shape[1])
weights *= b_scales
return F.linear(input, weights, bias)
def dequant_weight_scale(
# [..., in_features]
input: torch.Tensor,
# [num_out_groups, num_in_groups, num_codebooks]
codes: torch.IntTensor,
# [num_codebooks, codebook_size, out_group_size, in_group_size]
codebooks: torch.Tensor,
# [num_out_groups, 1, 1, 1]
scales: torch.Tensor,
output_partition_sizes: torch.IntTensor,
bias: Optional[torch.Tensor],
) -> torch.Tensor:
weights = ops.aqlm_dequant(codes, codebooks, output_partition_sizes)
b_scales = scales.view(scales.shape[:-3] + (-1,)).expand(-1, weights.shape[1])
weights *= b_scales
return F.linear(input, weights, bias)
def dequant_no_scale(
# [..., in_features]
input: torch.Tensor,
# [num_out_groups, num_in_groups, num_codebooks]
codes: torch.IntTensor,
# [num_codebooks, codebook_size, out_group_size, in_group_size]
codebooks: torch.Tensor,
# [num_out_groups, 1, 1, 1]
scales: torch.Tensor,
output_partition_sizes: torch.IntTensor,
bias: Optional[torch.Tensor],
) -> torch.Tensor:
weights = ops.aqlm_dequant(codes, codebooks, output_partition_sizes)
return F.linear(input, weights, bias)
# Compare the optimized 1x16 and 2x8 cuda decompression/dequant kernels against
# the generic pytorch version.
# Just visual comparison.
def dequant_test(k: int, parts: torch.Tensor, nbooks: int, bits: int) -> None:
n = int(parts.sum().item())
device = torch.device("cuda:0")
code_range = (1 << bits) // 2
ingroups = 8
codes = torch.randint(
-code_range,
code_range,
size=(n, k // ingroups, nbooks),
dtype=get_int_dtype(bits),
device=device,
)
codebooks = torch.randn(
size=(parts.shape[0] * nbooks, 1 << bits, 1, 8),
dtype=torch.float16,
device=device,
)
count = 0
for index in range(16):
for i in range(8):
for book in range(nbooks):
codebooks[book, index, 0, i] = count * (10**book)
count += 1
print("codes shape", codes.shape)
for i in range(16):
for book in range(nbooks):
codes[0, i, book] = i
codes[0, -i, book] = i
weights = dequantize_weight(codes, codebooks, None)
weights2 = ops.aqlm_dequant(codes, codebooks, parts)
print("weights shape:", weights.shape)
print("weights2 shape:", weights2.shape)
print("weights are:", weights)
print("weights2 are:", weights2)
print("first 128 weights are", weights[0, 0:128].to(torch.int32))
print("first 128 weights2 are:", weights2[0, 0:128].to(torch.int32))
print("last 128 weights are", weights[0, -128:])
print("last 128 weights2 are:", weights2[0, -128:])
def main():
parser = FlexibleArgumentParser(description="Benchmark aqlm performance.")
# Add arguments
parser.add_argument(
"--nbooks", type=int, default=1, help="Number of codebooks (default: 1)"
)
parser.add_argument(
"--bits",
type=int,
default=16,
help="Number of bits per code element (default: 16)",
)
parser.add_argument(
"--test",
type=bool,
default=False,
help="Run the decompression/dequant tester rather than benchmarking "
"(default: False)",
)
# Parse the arguments
args = parser.parse_args()
# Extract values
nbooks = args.nbooks
bits = args.bits
if args.test:
dequant_test(4096, torch.tensor((4096,)), nbooks, bits)
return
# Otherwise, benchmark.
methods = [
ops.aqlm_gemm,
dequant_out_scale,
generic_dequantize_gemm,
optimized_dequantize_gemm,
dequant_weight_scale,
torch_mult,
dequant_no_scale,
]
filename = f"./aqlm_benchmark_{nbooks}x{bits}.csv"
print(f"writing benchmarks to file {filename}")
with open(filename, "w") as f:
sys.stdout = f
print("m | k | n | n parts", end="")
for method in methods:
print(f" | {method.__name__.replace('_', ' ')} (µs)", end="")
print("")
# These are reasonable prefill sizes.
ksandpartions = (
(4096, (4096, 4096, 4096)),
(4096, (4096,)),
(4096, (11008, 11008)),
(11008, (4096,)),
)
# reasonable ranges for m.
for m in [
1,
2,
4,
8,
10,
12,
14,
16,
24,
32,
48,
52,
56,
64,
96,
112,
128,
256,
512,
1024,
1536,
2048,
3072,
4096,
]:
print(f"{m}", file=sys.__stdout__)
for ksp in ksandpartions:
run_grid(m, ksp[0], torch.tensor(ksp[1]), nbooks, bits, methods)
sys.stdout = sys.__stdout__
def run_grid(m: int, k: int, parts: torch.Tensor, nbooks: int, bits: int, methods):
# I didn't see visible improvements from increasing these, but feel free :)
num_warmup_trials = 1
num_trials = 1
num_calls = 100
# warmup.
for method in methods:
for _ in range(num_warmup_trials):
run_timing(
num_calls=num_calls,
m=m,
k=k,
parts=parts,
nbooks=nbooks,
bits=bits,
method=method,
)
n = parts.sum().item()
print(f"{m} | {k} | {n} | {parts.tolist()}", end="")
for method in methods:
best_time_us = 1e20
for _ in range(num_trials):
kernel_dur_ms = run_timing(
num_calls=num_calls,
m=m,
k=k,
parts=parts,
nbooks=nbooks,
bits=bits,
method=method,
)
kernel_dur_us = 1000 * kernel_dur_ms
if kernel_dur_us < best_time_us:
best_time_us = kernel_dur_us
print(f" | {kernel_dur_us:.0f}", end="")
print("")
def run_timing(
num_calls: int, m: int, k: int, parts: torch.Tensor, nbooks: int, bits: int, method
) -> float:
n = int(parts.sum().item())
device = torch.device("cuda:0")
input = torch.randn((1, m, k), dtype=torch.float16, device=device)
code_range = (1 << bits) // 2
ingroups = 8
codes = torch.randint(
-code_range,
code_range,
size=(n, k // ingroups, nbooks),
dtype=get_int_dtype(bits),
device=device,
)
codebooks = torch.randn(
size=(parts.shape[0] * nbooks, 1 << bits, 1, 8),
dtype=torch.float16,
device=device,
)
scales = torch.randn(size=(n, 1, 1, 1), dtype=torch.float16, device=device)
# for comparison to just a pytorch mult.
weights = torch.randn((n, k), dtype=torch.float16, device=device)
start_event = torch.cuda.Event(enable_timing=True)
end_event = torch.cuda.Event(enable_timing=True)
start_event.record()
if method is torch_mult:
for i in range(num_calls):
torch_mult(input, weights, scales)
else:
for i in range(num_calls):
method(input, codes, codebooks, scales, parts, None)
end_event.record()
end_event.synchronize()
dur_ms = start_event.elapsed_time(end_event) / num_calls
return dur_ms
if __name__ == "__main__":
sys.exit(main())

View File

@ -154,15 +154,6 @@ void cutlass_mla_decode(torch::Tensor const& out, torch::Tensor const& q_nope,
torch::Tensor get_cuda_view_from_cpu_tensor(torch::Tensor& cpu_tensor);
#ifndef USE_ROCM
torch::Tensor aqlm_gemm(const torch::Tensor& input, const torch::Tensor& codes,
const torch::Tensor& codebooks,
const torch::Tensor& scales,
const std::vector<int64_t>& codebook_partition_sizes,
const std::optional<torch::Tensor>& bias);
torch::Tensor aqlm_dequant(
const torch::Tensor& codes, const torch::Tensor& codebooks,
const std::vector<int64_t>& codebook_partition_sizes);
torch::Tensor awq_gemm(torch::Tensor _in_feats, torch::Tensor _kernel,
torch::Tensor _scaling_factors, torch::Tensor _zeros,

View File

@ -1,597 +0,0 @@
/*
* Modified by Neural Magic
* Adapted from https://github.com/Vahe1994/AQLM
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <cuda.h>
#include <cuda_fp16.h>
#include <cuda_runtime.h>
#include <torch/all.h>
#include <c10/cuda/CUDAStream.h>
#include <c10/cuda/CUDAGuard.h>
#include <iostream>
#include <cstdlib>
namespace vllm {
namespace aqlm {
__global__ void Code1x16MatVec(
const int4* __restrict__ A, const int4* __restrict__ B,
int4* __restrict__ C, const int4* __restrict__ codebook, const int prob_m,
const int prob_k,
const int4 codebook_a_sizes, // cumulative sizes of A spanning each
// codebook, at most 3 long.
const int codebook_stride // as int4.
) {
int a_gl_stride = prob_k / 8 / 8;
int a_gl_rd = (blockDim.x / 32) * blockIdx.x + (threadIdx.x / 32);
bool pred = a_gl_rd < prob_m;
if (pred) {
// advance to the correct codebook, this easy because we only multiply one
// column of the codebook.
auto codebook_size = &codebook_a_sizes.x;
while (a_gl_rd >= *codebook_size) {
codebook += codebook_stride;
++codebook_size;
}
}
int b_gl_rd = 0;
int c_gl_wr = a_gl_rd;
a_gl_rd = a_gl_stride * a_gl_rd + threadIdx.x % 32;
int a_gl_end = a_gl_rd + a_gl_stride - threadIdx.x % 32;
__shared__ int4 sh_b[32 * 9];
float res = 0;
int iters = (prob_k / 8 + 8 * 32 - 1) / (8 * 32);
while (iters--) {
// We pad shared memory to avoid bank conflicts during reads
__syncthreads();
for (int i = threadIdx.x; i < 32 * 8; i += blockDim.x) {
if (b_gl_rd + i < prob_k / 8) sh_b[9 * (i / 8) + i % 8] = B[b_gl_rd + i];
}
__syncthreads();
b_gl_rd += 32 * 8;
int b_sh_rd = 9 * (threadIdx.x % 32);
if (pred && a_gl_rd < a_gl_end) {
const uint16_t* enc = reinterpret_cast<const uint16_t*>(&A[a_gl_rd]);
#pragma unroll
for (int i = 0; i < 8; i++) {
uint32_t dec[4];
// We bypass the L1 cache to avoid massive amounts of memory streaming
// that doesn't actually help us; this brings > 2x speedup.
asm volatile("ld.cg.global.v4.u32 {%0, %1, %2, %3}, [%4];"
: "=r"(dec[0]), "=r"(dec[1]), "=r"(dec[2]), "=r"(dec[3])
: "l"((void*)&codebook[enc[i]]));
half2* a = reinterpret_cast<half2*>(&dec);
half2* b = reinterpret_cast<half2*>(&sh_b[b_sh_rd]);
half2 res2 = {};
#pragma unroll
for (int j = 0; j < 4; j++) res2 = __hfma2(a[j], b[j], res2);
res += __half2float(res2.x) + __half2float(res2.y);
b_sh_rd++;
}
a_gl_rd += 32;
}
}
if (pred) {
#pragma unroll
for (int i = 16; i > 0; i /= 2) res += __shfl_down_sync(0xffffffff, res, i);
if (threadIdx.x % 32 == 0)
reinterpret_cast<__half*>(C)[c_gl_wr] = __float2half(res);
}
}
__global__ void Code2x8MatVec(
const int4* __restrict__ A, const int4* __restrict__ B,
int4* __restrict__ C, const int4* __restrict__ codebook, int prob_m,
int prob_k,
const int4 codebook_a_sizes, // cumulative sizes of A spanning each
// codebook, at most 3 long.
const int codebook_stride // as int4.
) {
int a_gl_stride = prob_k / 8 / 8;
int a_gl_rd = (blockDim.x / 32) * blockIdx.x + (threadIdx.x / 32);
bool pred = a_gl_rd < prob_m;
if (pred) {
// advance to the correct codebook, this easy because we only multiply one
// column of the codebook.
auto codebook_size = &codebook_a_sizes.x;
while (a_gl_rd >= *codebook_size) {
codebook += codebook_stride;
++codebook_size;
}
}
int b_gl_rd = 0;
int c_gl_wr = a_gl_rd;
a_gl_rd = a_gl_stride * a_gl_rd + threadIdx.x % 32;
int a_gl_end = a_gl_rd + a_gl_stride - threadIdx.x % 32;
int lane = threadIdx.x % 8;
extern __shared__ int4 sh[];
int4* sh_b = sh;
int4* sh_code = sh_b + 32 * 9;
int4* sh_code0 = sh_code;
int4* sh_code1 = sh_code + 256 * 8;
for (int i = threadIdx.x; i < 2 * 256; i += blockDim.x) {
int4 dec = codebook[i];
#pragma unroll
for (int j = 0; j < 8; j++) sh_code[8 * i + (j + lane) % 8] = dec;
}
__syncthreads();
float res = 0;
int iters = (prob_k / 8 + 8 * 32 - 1) / (8 * 32);
while (iters--) {
// We pad shared memory to avoid bank conflicts during reads
__syncthreads();
for (int i = threadIdx.x; i < 32 * 8; i += blockDim.x) {
if (b_gl_rd + i < prob_k / 8) sh_b[9 * (i / 8) + i % 8] = B[b_gl_rd + i];
}
__syncthreads();
b_gl_rd += 32 * 8;
int b_sh_rd = 9 * (threadIdx.x % 32);
if (pred && a_gl_rd < a_gl_end) {
const uint8_t* enc = reinterpret_cast<const uint8_t*>(&A[a_gl_rd]);
#pragma unroll
for (int i = 0; i < 8; i++) {
half2* a0 =
reinterpret_cast<half2*>(&sh_code0[8 * enc[2 * i + 0] + lane]);
half2* a1 =
reinterpret_cast<half2*>(&sh_code1[8 * enc[2 * i + 1] + lane]);
half2* b = reinterpret_cast<half2*>(&sh_b[b_sh_rd]);
half2 res2 = {};
#pragma unroll
for (int j = 0; j < 4; j++)
res2 = __hfma2(__hadd2(a0[j], a1[j]), b[j], res2);
res += __half2float(res2.x) + __half2float(res2.y);
b_sh_rd++;
}
a_gl_rd += 32;
}
}
if (pred) {
#pragma unroll
for (int i = 16; i > 0; i /= 2) res += __shfl_down_sync(0xffffffff, res, i);
if (threadIdx.x % 32 == 0)
reinterpret_cast<__half*>(C)[c_gl_wr] = __float2half(res);
}
}
__global__ void Code1x16Dequant(
const int4* __restrict__ A, int4* __restrict__ C,
const int4* __restrict__ codebook, int prob_m, int prob_k,
const int4 codebook_a_sizes, // cumulative sizes of A spanning each
// codebook, at most 3 long, sums to m.
const int codebook_stride // as int4
) {
int a_gl_stride = prob_k / 8 / 8;
int a_gl_rd = (blockDim.x / 32) * blockIdx.x + (threadIdx.x / 32);
bool pred = a_gl_rd < prob_m;
if (pred) {
// advance to the correct codebook, this easy because we only multiply one
// column of the codebook.
auto codebook_size = &codebook_a_sizes.x;
while (a_gl_rd >= *codebook_size) {
codebook += codebook_stride;
++codebook_size;
}
}
a_gl_rd = a_gl_stride * a_gl_rd + threadIdx.x % 32;
int a_gl_end = a_gl_rd + a_gl_stride - threadIdx.x % 32;
int c_gl_stride = prob_k / 8;
int c_gl_wr = (blockDim.x / 32) * blockIdx.x + (threadIdx.x / 32);
c_gl_wr = c_gl_stride * c_gl_wr + (threadIdx.x % 32) * 8;
int iters = (prob_k / 8 - 1) / (8 * 32) + 1;
while (iters--) {
if (pred && a_gl_rd < a_gl_end) {
const uint16_t* enc = reinterpret_cast<const uint16_t*>(&A[a_gl_rd]);
#pragma unroll
for (int i = 0; i < 8; i++) {
int4 chunk;
auto dec = reinterpret_cast<uint32_t*>(&chunk);
// We bypass the L1 cache to avoid massive amounts of memory streaming
// that doesn't actually help us; this brings > 2x speedup.
asm volatile("ld.cg.global.v4.u32 {%0, %1, %2, %3}, [%4];"
: "=r"(dec[0]), "=r"(dec[1]), "=r"(dec[2]), "=r"(dec[3])
: "l"((void*)&codebook[enc[i]]));
C[a_gl_rd * 8 + i] = chunk;
}
}
a_gl_rd += 32;
}
}
__global__ void Code2x8Dequant(
const int4* __restrict__ A, int4* __restrict__ C,
const int4* __restrict__ codebook, int prob_m, int prob_k,
const int4
codebook_a_sizes, // cumulative sizes of A spanning each codebook, at
// most 3 long, corresponds to cols.
const int codebook_stride // as int4
) {
int a_gl_stride = prob_k / 8 / 8;
int a_gl_rd = (blockDim.x / 32) * blockIdx.x + (threadIdx.x / 32);
bool pred = a_gl_rd < prob_m;
if (pred) {
// advance to the correct codebook, this easy because we only multiply one
// column of the codebook.
auto codebook_size = &codebook_a_sizes.x;
while (a_gl_rd >= *codebook_size) {
codebook += codebook_stride;
++codebook_size;
}
}
a_gl_rd = a_gl_stride * a_gl_rd + threadIdx.x % 32;
int a_gl_end = a_gl_rd + a_gl_stride - threadIdx.x % 32;
int lane = threadIdx.x % 8;
int c_gl_stride = prob_k / 8;
int c_gl_wr = (blockDim.x / 32) * blockIdx.x + (threadIdx.x / 32);
c_gl_wr = c_gl_stride * c_gl_wr + (threadIdx.x % 32) * 8;
extern __shared__ int4 sh[];
int4* sh_code = sh;
int4* sh_code0 = sh_code;
int4* sh_code1 = sh_code + 256 * 8;
for (int i = threadIdx.x; i < 2 * 256; i += blockDim.x) {
int4 dec = codebook[i];
#pragma unroll
for (int j = 0; j < 8; j++) sh_code[8 * i + (j + lane) % 8] = dec;
}
__syncthreads();
int iters = (prob_k / 8 - 1) / (8 * 32) + 1;
while (iters--) {
if (pred && a_gl_rd < a_gl_end) {
const uint8_t* enc = reinterpret_cast<const uint8_t*>(&A[a_gl_rd]);
#pragma unroll
for (int i = 0; i < 8; i++) {
int4 chunk;
half2* a0 =
reinterpret_cast<half2*>(&sh_code0[8 * enc[2 * i + 0] + lane]);
half2* a1 =
reinterpret_cast<half2*>(&sh_code1[8 * enc[2 * i + 1] + lane]);
#pragma unroll
for (int j = 0; j < 4; j++)
reinterpret_cast<half2*>(&chunk)[j] = __hadd2(a0[j], a1[j]);
C[a_gl_rd * 8 + i] = chunk;
}
}
a_gl_rd += 32;
}
}
inline int ceildiv(int a, int b) { return (a + b - 1) / b; }
const int THREAD_M = 16;
void code1x16_matvec_cuda(const void* __restrict__ A,
const void* __restrict__ B, void* __restrict__ C,
const void* __restrict__ codebook, int prob_m,
int prob_k, const int4 codebook_a_sizes,
const int codebook_stride) {
int sms;
cudaDeviceGetAttribute(&sms, cudaDevAttrMultiProcessorCount, 0);
int waves = 0;
int thread_m;
do {
waves++;
thread_m = ceildiv(prob_m, waves * sms);
} while (thread_m > THREAD_M);
int blocks = ceildiv(prob_m, thread_m);
int threads = 32 * thread_m;
cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream();
Code1x16MatVec<<<blocks, threads, 16 * 32 * 9, stream>>>(
(const int4*)A, (const int4*)B, (int4*)C, (const int4*)codebook, prob_m,
prob_k, codebook_a_sizes, codebook_stride);
}
void code2x8_matvec_cuda(const void* __restrict__ A, const void* __restrict__ B,
void* __restrict__ C,
const void* __restrict__ codebook, int prob_m,
int prob_k, const int4 codebook_a_sizes,
const int codebook_stride) {
int sms;
cudaDeviceGetAttribute(&sms, cudaDevAttrMultiProcessorCount, 0);
int waves = 0;
int thread_m;
do {
waves++;
thread_m = ceildiv(prob_m, waves * sms);
} while (thread_m > THREAD_M);
int blocks = ceildiv(prob_m, thread_m);
int threads = 32 * thread_m;
int shared = 16 * (2 * 256 * 8 + 32 * 9);
cudaFuncSetAttribute(Code2x8MatVec,
cudaFuncAttributeMaxDynamicSharedMemorySize, shared);
cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream();
Code2x8MatVec<<<blocks, threads, shared, stream>>>(
(const int4*)A, (const int4*)B, (int4*)C, (const int4*)codebook, prob_m,
prob_k, codebook_a_sizes, codebook_stride);
}
void code1x16_dequant_cuda(
const void* __restrict__ A, void* __restrict__ C,
const void* __restrict__ codebook, int prob_m, int prob_k,
const int4 codebook_a_sizes, // cumulative sizes of A spanning each
// codebook, at most 3 long.
const int codebook_stride // as int4.
) {
int sms;
cudaDeviceGetAttribute(&sms, cudaDevAttrMultiProcessorCount, 0);
int waves = 0;
int thread_m;
do {
waves++;
thread_m = ceildiv(prob_m, waves * sms);
} while (thread_m > THREAD_M);
int blocks = ceildiv(prob_m, thread_m);
int threads = 32 * thread_m;
cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream();
Code1x16Dequant<<<blocks, threads, 0, stream>>>(
(const int4*)A, (int4*)C, (const int4*)codebook, prob_m, prob_k,
codebook_a_sizes, // cumulative sizes of A spanning each codebook, at
// most 3 long.
codebook_stride // as int4.
);
}
// Dequantizes the code and codebook into weights.
void code2x8_dequant_cuda(
const void* __restrict__ A, void* __restrict__ C,
const void* __restrict__ codebook, int prob_m, int prob_k,
const int4
codebook_a_sizes, // cumulative sizes of A spanning each codebook, at
// most 3 long, corresponds to cols.
const int codebook_stride // as int4
) {
int sms;
cudaDeviceGetAttribute(&sms, cudaDevAttrMultiProcessorCount, 0);
int waves = 0;
int thread_m;
do {
waves++;
thread_m = ceildiv(prob_m, waves * sms);
} while (thread_m > THREAD_M);
int blocks = ceildiv(prob_m, thread_m);
int threads = 32 * thread_m;
int shared = 16 * (2 * 256 * 8 + 32 * 9);
cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream();
cudaFuncSetAttribute(Code2x8Dequant,
cudaFuncAttributeMaxDynamicSharedMemorySize, shared);
Code2x8Dequant<<<blocks, threads, shared, stream>>>(
(const int4*)A, (int4*)C, (const int4*)codebook, prob_m, prob_k,
codebook_a_sizes, codebook_stride);
}
int codebook_stride(const torch::Tensor& codebooks) {
return codebooks.stride(0) * codebooks.element_size() / sizeof(int4);
}
void code1x16_matvec(
const torch::Tensor& A, const torch::Tensor& B, torch::Tensor& C,
const torch::Tensor& codebook,
const int4 codebook_a_sizes // cumulative sizes of A spanning each
// codebook, at most 3 long.
) {
const at::cuda::OptionalCUDAGuard device_guard(device_of(A));
int prob_m = C.size(0);
int prob_k = B.size(0);
code1x16_matvec_cuda(A.data_ptr(), B.data_ptr(), C.data_ptr(),
codebook.data_ptr(), prob_m, prob_k, codebook_a_sizes,
codebook_stride(codebook));
}
torch::Tensor code1x16_matmat(const torch::Tensor& input,
const torch::Tensor& codes,
const torch::Tensor& codebooks,
const torch::Tensor& scales,
const int4 codebook_a_sizes,
const std::optional<torch::Tensor>& bias) {
auto input_sizes = input.sizes();
auto out_features = codes.size(0) * codebooks.size(2);
auto flat_input = input.reshape({-1, input.size(-1)});
auto flat_output = torch::empty(
{flat_input.size(0), out_features},
torch::TensorOptions().dtype(input.dtype()).device(input.device()));
for (int i = 0; i < flat_input.size(0); ++i) {
auto input_vec = flat_input.index({i});
auto output_vec = flat_output.index({i});
code1x16_matvec(codes.squeeze(2), input_vec, output_vec, codebooks,
codebook_a_sizes);
}
flat_output *= scales.flatten().unsqueeze(0);
if (bias.has_value()) {
flat_output += bias->unsqueeze(0);
}
auto output_sizes = input_sizes.vec();
output_sizes.pop_back();
output_sizes.push_back(-1);
auto output = flat_output.reshape(output_sizes);
return output;
}
void code2x8_matvec(const torch::Tensor& A, const torch::Tensor& B,
torch::Tensor& C, const torch::Tensor& codebook,
const int4 codebook_a_sizes) {
const at::cuda::OptionalCUDAGuard device_guard(device_of(A));
int prob_m = C.size(0);
int prob_k = B.size(0);
code2x8_matvec_cuda(A.data_ptr(), B.data_ptr(), C.data_ptr(),
codebook.data_ptr(), prob_m, prob_k, codebook_a_sizes,
2 * codebook_stride(codebook));
}
torch::Tensor code2x8_matmat(const torch::Tensor& input,
const torch::Tensor& codes,
const torch::Tensor& codebooks,
const torch::Tensor& scales,
const int4 codebook_a_sizes,
const std::optional<torch::Tensor>& bias) {
auto input_sizes = input.sizes();
auto out_features = codes.size(0) * codebooks.size(2);
auto flat_input = input.reshape({-1, input.size(-1)});
auto flat_output = torch::empty(
{flat_input.size(0), out_features},
torch::TensorOptions().dtype(input.dtype()).device(input.device()));
for (int i = 0; i < flat_input.size(0); ++i) {
auto input_vec = flat_input.index({i});
auto output_vec = flat_output.index({i});
code2x8_matvec(codes.squeeze(2), input_vec, output_vec, codebooks,
codebook_a_sizes);
}
flat_output *= scales.flatten().unsqueeze(0);
if (bias.has_value()) {
flat_output += bias->unsqueeze(0);
}
auto output_sizes = input_sizes.vec();
output_sizes.pop_back();
output_sizes.push_back(-1);
auto output = flat_output.reshape(output_sizes);
return output;
}
// Accumulate the partition sizes.
int4 accumulate_sizes(const std::vector<int64_t>& codebook_partition_sizes) {
int4 cumulative_sizes;
auto cumulative_size = &cumulative_sizes.x;
size_t i = 0;
int last = 0;
assert(codebook_partition_sizes.size() <= 4);
for (; i < codebook_partition_sizes.size(); ++i, ++cumulative_size) {
*cumulative_size = codebook_partition_sizes[i] + last;
last = *cumulative_size;
}
// fill in the rest with unreachable.
for (; i < 4; ++i, ++cumulative_size) {
*cumulative_size = last * 10;
}
return cumulative_sizes;
}
} // namespace aqlm
} // namespace vllm
torch::Tensor aqlm_gemm(const torch::Tensor& input, const torch::Tensor& codes,
const torch::Tensor& codebooks,
const torch::Tensor& scales,
const std::vector<int64_t>& codebook_partition_sizes,
const std::optional<torch::Tensor>& bias) {
int4 cumulative_sizes =
vllm::aqlm::accumulate_sizes(codebook_partition_sizes);
int const nbooks = codebooks.size(0) / codebook_partition_sizes.size();
int const entries = codebooks.size(1);
if (nbooks == 1 && entries == (1 << 16)) {
return vllm::aqlm::code1x16_matmat(input, codes, codebooks, scales,
cumulative_sizes, bias);
}
if (nbooks == 2 && entries == (1 << 8)) {
return vllm::aqlm::code2x8_matmat(input, codes, codebooks, scales,
cumulative_sizes, bias);
}
TORCH_CHECK(false, "AQLM with ", nbooks, " codebooks and ", entries,
" entries is not currently supported.")
return {};
}
torch::Tensor aqlm_dequant(
const torch::Tensor& codes, const torch::Tensor& codebooks,
const std::vector<int64_t>& codebook_partition_sizes) {
int4 cumulative_sizes =
vllm::aqlm::accumulate_sizes(codebook_partition_sizes);
int const nbooks = codebooks.size(0) / codebook_partition_sizes.size();
int const entries = codebooks.size(1);
const at::cuda::OptionalCUDAGuard device_guard(device_of(codes));
int rows = codes.size(1);
int cols = codes.size(0);
auto in_features = codes.size(1) * 8;
auto out_features = codes.size(0);
assert(out_features == std::accumulate(codebook_partition_sizes.begin(),
codebook_partition_sizes.end(), 0));
auto weights = torch::empty({out_features, in_features},
torch::TensorOptions()
.dtype(codebooks.dtype())
.device(codebooks.device()));
if (nbooks == 1 && entries == (1 << 16)) {
vllm::aqlm::code1x16_dequant_cuda(codes.data_ptr(), weights.data_ptr(),
codebooks.data_ptr(), out_features,
in_features, cumulative_sizes,
vllm::aqlm::codebook_stride(codebooks));
// if you wanted to flip to scaling the weights, (though it's 30%-ish slower
// and not consistent with gemv implementation.) weights *=
// scales.index({"...", 0, 0});
return weights;
}
if (nbooks == 2 && entries == (1 << 8)) {
vllm::aqlm::code2x8_dequant_cuda(codes.data_ptr(), weights.data_ptr(),
codebooks.data_ptr(), out_features,
in_features, cumulative_sizes,
vllm::aqlm::codebook_stride(codebooks));
// if you wanted to flip to scaling the weights, (though it's 30%-ish slower
// and not consistent with gemv implementation) weights *=
// scales.index({"...", 0, 0});
return weights;
}
TORCH_CHECK(false, "AQLM with ", nbooks, " codebooks and ", entries,
" entries is not currently supported.")
return {};
}

View File

@ -207,21 +207,6 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
// Quantization ops
#ifndef USE_ROCM
// Quantized GEMM for AQLM.
ops.def(
"aqlm_gemm(Tensor input, Tensor codes, Tensor codebooks, "
"Tensor scales, int[] codebook_partition_sizes, Tensor? bias) "
"-> Tensor",
{stride_tag});
ops.impl("aqlm_gemm", torch::kCUDA, &aqlm_gemm);
// Decompression method for AQLM.
ops.def(
"aqlm_dequant(Tensor codes, Tensor codebooks, "
"int[] codebook_partition_sizes) -> Tensor",
{stride_tag});
ops.impl("aqlm_dequant", torch::kCUDA, &aqlm_dequant);
// Quantized GEMM for AWQ.
ops.def(
"awq_gemm(Tensor _in_feats, Tensor _kernel, Tensor _scaling_factors, "

View File

@ -17,7 +17,6 @@ th {
| INT8 (W8A8) | ❌ | ✅︎ | ✅︎ | ✅︎ | ✅︎ | ❌ | ❌ | ❌ | ✅︎ | ✅︎ | ✅︎ |
| FP8 (W8A8) | ❌ | ❌ | ❌ | ✅︎ | ✅︎ | ✅︎ | ❌ | ❌ | ❌ | ✅︎ | ❌ |
| BitBLAS (GPTQ) | ✅︎ | ✅︎ | ✅︎ | ✅︎ | ✅︎ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ |
| AQLM | ✅︎ | ✅︎ | ✅︎ | ✅︎ | ✅︎ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ |
| bitsandbytes | ✅︎ | ✅︎ | ✅︎ | ✅︎ | ✅︎ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ |
| DeepSpeedFP | ✅︎ | ✅︎ | ✅︎ | ✅︎ | ✅︎ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ |
| GGUF | ✅︎ | ✅︎ | ✅︎ | ✅︎ | ✅︎ | ✅︎ | ❌ | ❌ | ❌ | ❌ | ❌ |

View File

@ -24,7 +24,6 @@ def fix_case(text: str) -> str:
"llm": "LLM",
"mae": "MAE",
"tpu": "TPU",
"aqlm": "AQLM",
"gguf": "GGUF",
"lora": "LoRA",
"rlhf": "RLHF",

View File

@ -52,20 +52,6 @@ Try it yourself with the following argument:
### Quantization
#### AQLM
vLLM supports models that are quantized using AQLM.
Try one yourself by passing one of the following models to the `--model` argument:
- `ISTA-DASLab/Llama-2-7b-AQLM-2Bit-1x16-hf`
- `ISTA-DASLab/Llama-2-7b-AQLM-2Bit-2x8-hf`
- `ISTA-DASLab/Llama-2-13b-AQLM-2Bit-1x16-hf`
- `ISTA-DASLab/Mixtral-8x7b-AQLM-2Bit-1x16-hf`
- `BlackSamorez/TinyLlama-1_1B-Chat-v1_0-AQLM-2Bit-1x16-hf`
> Some of these models are likely to be too large for a single GPU. You can split them across multiple GPUs by setting `--tensor-parallel-size` to the number of required GPUs.
#### GGUF
vLLM supports models that are quantized using GGUF.

View File

@ -31,10 +31,6 @@ def models_list(*, all: bool = True, keywords: Optional[list[str]] = None):
]
if all:
if is_quant_method_supported("aqlm"):
TEST_MODELS.append(("ISTA-DASLab/Llama-2-7b-AQLM-2Bit-1x16-hf", {
"quantization": "aqlm"
}))
# TODO: figure out why this fails.
if False and is_quant_method_supported("gguf"): # noqa: SIM223

View File

@ -1,40 +0,0 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import torch
from tests.kernels.utils import opcheck
from vllm import _custom_ops as ops # noqa: F401
def test_aqlm_dequant_opcheck():
codes = torch.randint(-32768,
32767, (22016, 512, 1),
device='cuda',
dtype=torch.int16)
codebooks = torch.rand((2, 65536, 1, 8),
device='cuda',
dtype=torch.float16)
codebook_partition_sizes = [11008, 11008]
opcheck(torch.ops._C.aqlm_dequant,
(codes, codebooks, codebook_partition_sizes))
def test_aqlm_gemm_opcheck():
input = torch.rand((4, 4096), device='cuda', dtype=torch.float16)
codes = torch.randint(-32768,
32767, (12288, 512, 1),
device='cuda',
dtype=torch.int16)
codebooks = torch.rand((3, 65536, 1, 8),
device='cuda',
dtype=torch.float16)
scales = torch.rand((12288, 1, 1, 1), device='cuda', dtype=torch.float16)
codebook_partition_sizes = [4096, 4096, 4096]
bias = None
opcheck(torch.ops._C.aqlm_gemm,
(input, codes, codebooks, scales, codebook_partition_sizes, None))
opcheck(torch.ops._C.aqlm_gemm,
(input, codes, codebooks, scales, codebook_partition_sizes, bias))

View File

@ -1,68 +0,0 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import pytest
from tests.quantization.utils import is_quant_method_supported
from vllm.platforms import current_platform
# These ground truth generations were generated using `transformers==4.38.1
# aqlm==1.1.0 torch==2.2.0`
# and the below code:
# ```python
# from transformers import AutoTokenizer, AutoModelForCausalLM
# model_id = "ISTA-DASLab/Llama-2-7b-AQLM-2Bit-1x16-hf"
# quantized_model = AutoModelForCausalLM.from_pretrained(model_id,
# torch_dtype="auto", device_map="cuda").cuda()
# tokenizer = AutoTokenizer.from_pretrained(model_id)
# outputs = []
# for prompt in example_prompts:
# input_ids = tokenizer(prompt, return_tensors="pt")["input_ids"].to("cuda")
# hf_outputs = quantized_model.generate(input_ids, max_new_tokens=32)
# outputs.append(tokenizer.decode(hf_outputs[0][input_ids.shape[1]:]))
# print(outputs)
# ```
ground_truth_generations = [
'\n### Features\n\n- **High-throughput**: v',
'The major milestones in the development of artificial intelligence from '
'195',
'Compare and contrast artificial intelligence with human intelligence in '
'terms of processing information. The',
'Explain the difference between supervised and unsupervised learning.'
'\nExplain',
'Write a short story about a robot that dreams for the first time. The',
'Analyze the impact of the COVID-19 pandemic on global economic',
'The Mona Lisa is a painting by Leonardo da Vinci, and it',
'The early bird catches the worm.\nThe early bird catches the'
]
@pytest.mark.skipif(not is_quant_method_supported("aqlm")
or current_platform.is_rocm()
or not current_platform.is_cuda(),
reason="AQLM is not supported on this GPU type.")
@pytest.mark.parametrize("model", ["ISTA-DASLab/Llama-2-7b-AQLM-2Bit-1x16-hf"])
@pytest.mark.parametrize("dtype", ["half"])
@pytest.mark.parametrize("max_tokens", [16])
@pytest.mark.parametrize("num_logprobs", [1])
def test_models(
vllm_runner,
example_prompts,
model: str,
dtype: str,
max_tokens: int,
num_logprobs: int,
) -> None:
with vllm_runner(model, dtype=dtype) as vllm_model:
vllm_outputs = vllm_model.generate_greedy_logprobs(
example_prompts, max_tokens, num_logprobs)
# loop through the prompts to compare against the ground truth generations
for prompt_idx in range(len(example_prompts)):
vllm_output_ids, vllm_output_str, vllm_logprobs = vllm_outputs[
prompt_idx]
print("Prompt: ", repr(example_prompts[prompt_idx]))
print("Reference output:", repr(ground_truth_generations[prompt_idx]))
print("Output output: ", repr(vllm_output_str))
assert vllm_output_str == ground_truth_generations[prompt_idx]

View File

@ -476,32 +476,6 @@ if hasattr(torch.ops._C, "gptq_marlin_24_gemm"):
dtype=input.dtype,
device=input.device).sum(0)
@register_fake("_C::aqlm_gemm")
def _aqlm_gemm_fake(input: torch.Tensor, codes: torch.Tensor,
codebooks: torch.Tensor, scales: torch.Tensor,
codebook_partition_sizes: list[int],
bias: Optional[torch.Tensor]) -> torch.Tensor:
out_features = codes.size(0) * codebooks.size(2)
flat_input = input.reshape((-1, input.size(-1)))
flat_output = torch.empty((flat_input.size(0), out_features),
dtype=input.dtype,
device=input.device)
output_sizes = list(input.shape)
output_sizes.pop()
output_sizes.append(-1)
return flat_output.reshape(tuple(output_sizes))
@register_fake("_C::aqlm_dequant")
def _aqlm_dequant_fake(
codes: torch.Tensor, codebooks: torch.Tensor,
codebook_partition_sizes: list[int]) -> torch.Tensor:
in_features = codes.size(1) * 8
out_features = codes.size(0)
return torch.empty((out_features, in_features),
dtype=codebooks.dtype,
device=codebooks.device)
@register_fake("_C::machete_mm")
def machete_mm_fake(
a: torch.Tensor,
@ -957,21 +931,6 @@ def cutlass_fp4_moe_mm(out_tensors: torch.Tensor, a_tensors: torch.Tensor,
sf_offsets)
# aqlm
def aqlm_gemm(input: torch.Tensor, codes: torch.Tensor,
codebooks: torch.Tensor, scales: torch.Tensor,
codebook_partition_sizes: list[int],
bias: Optional[torch.Tensor]) -> torch.Tensor:
return torch.ops._C.aqlm_gemm(input, codes, codebooks, scales,
codebook_partition_sizes, bias)
def aqlm_dequant(codes: torch.Tensor, codebooks: torch.Tensor,
codebook_partition_sizes: list[int]) -> torch.Tensor:
return torch.ops._C.aqlm_dequant(codes, codebooks,
codebook_partition_sizes)
# gptq_marlin
def gptq_marlin_repack(b_q_weight: torch.Tensor, perm: torch.Tensor,
size_k: int, size_n: int,

View File

@ -692,8 +692,6 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
param_data = param.data
output_dim = getattr(param, "output_dim", None)
# Special case for AQLM codebooks.
is_metadata = getattr(param, "is_metadata", False)
# Special case for per-tensor scale to load scalar into fused array.
needs_scalar_to_array = getattr(param, "needs_scalar_to_array", False)
@ -781,13 +779,6 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
if not is_sharded_weight:
loaded_weight = loaded_weight.narrow(output_dim, start_idx,
shard_size)
# Special case for AQLM codebooks.
elif is_metadata:
# metadata indicates fixed size concatenated along dim 0
shard_size = loaded_weight.shape[0]
shard_offset = loaded_shard_id * shard_size
param_data = param_data.narrow(0, shard_offset, shard_size)
# Special case for per-tensor scales in fused case.
elif needs_scalar_to_array:
param_data, loaded_weight = adjust_scalar_to_fused_array(
@ -1081,8 +1072,6 @@ class QKVParallelLinear(ColumnParallelLinear):
param_data = param.data
output_dim = getattr(param, "output_dim", None)
# Special case for AQLM codebooks.
is_metadata = getattr(param, "is_metadata", False)
# Special case for per-tensor scales in fused case.
needs_scalar_to_array = getattr(param, "needs_scalar_to_array", False)
@ -1204,13 +1193,6 @@ class QKVParallelLinear(ColumnParallelLinear):
loaded_weight = loaded_weight.narrow(output_dim, start_idx,
shard_size)
# Special case for for AQLM codebooks.
elif is_metadata:
# metadata indicates fixed size concatenated along dim 0
shard_size = loaded_weight.shape[0]
shard_index = ["q", "k", "v"].index(loaded_shard_id)
param_data = param_data.narrow(0, shard_index * shard_size,
shard_size)
# Special case for per-tensor scales in fused case.
elif needs_scalar_to_array:
param_data, loaded_weight = adjust_scalar_to_fused_array(

View File

@ -7,7 +7,6 @@ from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
QuantizationMethods = Literal[
"aqlm",
"awq",
"deepspeedfp",
"tpu_int8",
@ -88,7 +87,6 @@ def get_quantization_config(quantization: str) -> type[QuantizationConfig]:
# lazy import to avoid triggering `torch.compile` too early
from vllm.model_executor.layers.quantization.quark.quark import QuarkConfig
from .aqlm import AQLMConfig
from .auto_round import AutoRoundConfig
from .awq import AWQConfig
from .awq_marlin import AWQMarlinConfig
@ -120,7 +118,6 @@ def get_quantization_config(quantization: str) -> type[QuantizationConfig]:
from .tpu_int8 import Int8TpuConfig
method_to_config: dict[str, type[QuantizationConfig]] = {
"aqlm": AQLMConfig,
"awq": AWQConfig,
"deepspeedfp": DeepSpeedFPConfig,
"tpu_int8": Int8TpuConfig,

View File

@ -1,376 +0,0 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# Supports AQLM compression, see https://github.com/Vahe1994/AQLM
# and https://arxiv.org/pdf/2401.06118.pdf
import math
from typing import Any, Optional
import torch
import torch.nn.functional as F
from torch.nn.parameter import Parameter
from vllm import _custom_ops as ops
from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase
from vllm.model_executor.layers.quantization import QuantizationMethods
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
from vllm.model_executor.utils import set_weight_attrs
def get_int_dtype(nbits: int) -> torch.dtype:
if nbits <= 8:
return torch.int8
if nbits <= 16:
return torch.int16
if nbits <= 32:
return torch.int32
if nbits <= 64:
return torch.int64
raise ValueError(f"No dtype available for {nbits}-bit codebooks")
@torch.inference_mode()
def unpack_int_data(data: torch.IntTensor, nbits: int) -> torch.IntTensor:
return data.to(torch.int64) % (2**nbits)
def dequantize_weight(codes: torch.Tensor,
codebooks: torch.Tensor,
scales: Optional[torch.Tensor] = None) -> torch.Tensor:
"""
Decode float weights from quantization codes. Differentiable.
:param codes: tensor of integer quantization codes, shape
[*dims, num_out_groups, num_in_groups, num_codebooks]
:param codebooks: tensor of vectors for each quantization code,
[num_codebooks, codebook_size, out_group_size, in_group_size]
:param scales: weight will be multiplied by this factor, must be
broadcastble with
[*dims, out_groups, num_in_groups, out_group_size, in_group_size]
:return: reconstructed weight tensor of shape
[*dims, num_in_groups*group_size]
"""
num_out_groups, num_in_groups, num_codebooks = codes.shape[-3:]
num_codebooks, codebook_size, out_group_size, in_group_size = \
codebooks.shape
out_features = num_out_groups * out_group_size
in_features = num_in_groups * in_group_size
codebook_offsets = torch.arange(
0, num_codebooks * codebook_size, codebook_size,
device=codes.device) # shape: [num_codebooks]
reconstructed_weight_flat = F.embedding_bag(
codes.flatten(0, -2) + codebook_offsets,
codebooks.flatten(0, 1).flatten(-2, -1),
mode="sum"
) # [prod(dims) * num_out_groups * num_in_groups, out_group_size
# * in_group_size]
reconstructed_weight_groupwise = reconstructed_weight_flat.view(
list(codes.shape[:-3]) +
[num_out_groups, num_in_groups, out_group_size, in_group_size])
if scales is not None:
reconstructed_weight_groupwise = reconstructed_weight_groupwise.mul(
scales)
return reconstructed_weight_groupwise.swapaxes(
-3, -2).reshape(list(codes.shape[:-3]) + [out_features, in_features])
def dequantize_gemm(
input: torch.Tensor, # [..., in_features]
codes: torch.IntTensor, # [num_out_groups, num_in_groups, num_codebooks]
codebooks: torch.
Tensor, # [num_codebooks, codebook_size, out_group_size, in_group_size]
scales: torch.Tensor, # [num_out_groups, 1, 1, 1]
bias: Optional[torch.Tensor],
) -> torch.Tensor:
dequantized_weight = dequantize_weight(
unpack_int_data(codes, codebooks.shape[1].bit_length() - 1),
codebooks,
scales,
)
return F.linear(input, dequantized_weight, bias)
# Generic dequantization, slow but flexible.
def generic_dequantize_gemm(
input: torch.Tensor, # [..., in_features]
codes: torch.IntTensor, # [num_out_groups, num_in_groups, num_codebooks]
codebooks: torch.
Tensor, # [num_codebooks, codebook_size, out_group_size, in_group_size]
scales: torch.Tensor, # [num_out_groups, 1, 1, 1]
output_partition_sizes: list[int],
bias: Optional[torch.Tensor],
) -> torch.Tensor:
output_shape = input.shape[:-1] + (scales.shape[0], )
output = torch.empty(output_shape, dtype=input.dtype, device=input.device)
num_outputs = len(output_partition_sizes)
# break the inputs and codebooks apart then combine the outputs.
# Surprisingly (to me) this is faster than doing 3 de-quants and 1 big
# multiply at the end.
num_codebooks = codebooks.shape[0] // num_outputs
assert (scales.shape[0] == codes.shape[0])
assert (sum(output_partition_sizes) == scales.shape[0])
output_offset = 0
codebooks_offset = 0
for output_size in output_partition_sizes:
shard_output = dequantize_gemm(
input, codes.narrow(0, output_offset, output_size),
codebooks.narrow(0, codebooks_offset, num_codebooks),
scales.narrow(0, output_offset, output_size), None
if bias is None else bias.narrow(0, output_offset, output_size))
output_slice = output.narrow(-1, output_offset, output_size)
assert (output_slice.shape == shard_output.shape)
output_slice.copy_(shard_output)
output_offset += output_size
codebooks_offset += num_codebooks
return output
# Optimized dequnantize/decompression kernels, supports 1x16 and 2x8
# at 6 and 9 times faster than the generic version above, respectively.
def optimized_dequantize_gemm(
input: torch.Tensor, # [..., in_features]
codes: torch.IntTensor, # [num_out_groups, num_in_groups, num_codebooks]
codebooks: torch.
Tensor, # [num_codebooks, codebook_size, out_group_size, in_group_size]
scales: torch.Tensor, # [num_out_groups, 1, 1, 1]
output_partition_sizes: list[int],
bias: Optional[torch.Tensor],
) -> torch.Tensor:
weights = ops.aqlm_dequant(codes, codebooks, output_partition_sizes)
if bias is None:
# scaling the output is fastest, so we do that when possible.
output = F.linear(input, weights, bias)
orig_shape = output.shape
flattened_output = output.view(-1, output.size(-1))
f_scales = scales.view(-1, scales.shape[0])
b_scales = f_scales.expand(flattened_output.shape[0], -1)
flattened_output *= b_scales
return output.view(orig_shape)
else:
b_scales = scales.view(scales.shape[:-3] + (-1, )).expand(
-1, weights.shape[1])
weights *= b_scales
return F.linear(input, weights, bias)
class AQLMConfig(QuantizationConfig):
"""Config class for AQLM.
Reference: https://github.com/Vahe1994/AQLM
"""
def __init__(
self,
in_group_size: int,
nbits_per_codebook: int,
num_codebooks: int,
out_group_size: int,
) -> None:
super().__init__()
self.in_group_size = in_group_size
self.nbits_per_codebook = nbits_per_codebook
self.num_codebooks = num_codebooks
self.out_group_size = out_group_size
# out_group_size > 1 is untested, and probably won't work as-is.
assert (self.out_group_size == 1)
self.pack_factor = (self.in_group_size * self.out_group_size)
def __repr__(self) -> str:
return (f"AQLMConfig(in_group_size={self.in_group_size}, "
f"nbits_per_codebook={self.nbits_per_codebook}, "
f"num_codebooks={self.num_codebooks}, "
f"out_group_size={self.out_group_size})")
@classmethod
def get_name(cls) -> QuantizationMethods:
return "aqlm"
@classmethod
def get_supported_act_dtypes(cls) -> list[torch.dtype]:
return [torch.half]
@classmethod
def get_min_capability(cls) -> int:
return 60
@classmethod
def get_config_filenames(cls) -> list[str]:
return [] # no extra configs.
@classmethod
def from_config(cls, config: dict[str, Any]) -> "AQLMConfig":
in_group_size = cls.get_from_keys(config, ["in_group_size"])
nbits_per_codebook = cls.get_from_keys(config, ["nbits_per_codebook"])
num_code_books = cls.get_from_keys(config, ["num_codebooks"])
out_group_size = cls.get_from_keys(config, ["out_group_size"])
return cls(in_group_size, nbits_per_codebook, num_code_books,
out_group_size)
def get_quant_method(self, layer: torch.nn.Module,
prefix: str) -> Optional["AQLMLinearMethod"]:
if isinstance(layer, LinearBase):
return AQLMLinearMethod(self)
return None
class AQLMLinearMethod(LinearMethodBase):
"""Linear method for AQLM.
Args:
quant_config: The AQLM quantization config.
"""
def __init__(self, quant_config: AQLMConfig):
self.quant_config = quant_config
def create_weights(self, layer: torch.nn.Module,
input_size_per_partition: int,
output_partition_sizes: list[int], input_size: int,
output_size: int, params_dtype: torch.dtype,
**extra_weight_attrs):
del output_size # Unused.
del input_size # Unused.
if params_dtype != torch.half:
raise ValueError("Only half is currently supported by aqlm")
if input_size_per_partition % self.quant_config.in_group_size != 0:
raise ValueError(
"The input size is not aligned with the quantized "
"weight shape. This can be caused by too large "
"tensor parallel size.")
output_size_per_partition = sum(output_partition_sizes)
if output_size_per_partition % self.quant_config.out_group_size != 0:
raise ValueError(
"The output size is not aligned with the quantized "
"weight shape. This can be caused by too large "
"tensor parallel size.")
codes = Parameter(
torch.empty(
# There could actually be two pack factors, one along input and
# one along output, but we don't currently support
# out_group_size, and only the one along output needs to be
# marked with "packed_dim" in order for QKVLinear to work.
output_size_per_partition,
input_size_per_partition // self.quant_config.pack_factor,
self.quant_config.num_codebooks,
dtype=get_int_dtype(self.quant_config.nbits_per_codebook),
),
requires_grad=False,
)
set_weight_attrs(
codes,
{
"input_dim": 1,
"output_dim": 0,
"packed_dim": 1,
"pack_factor": self.quant_config.pack_factor,
},
)
codebooks = Parameter(
torch.empty(
self.quant_config.num_codebooks * len(output_partition_sizes),
2**self.quant_config.nbits_per_codebook,
self.quant_config.out_group_size,
self.quant_config.in_group_size,
dtype=params_dtype,
),
requires_grad=False,
)
set_weight_attrs(
codebooks,
{
# metadata indicates fixed size concatenated along dim 0
"is_metadata": True,
"output_partition_sizes": output_partition_sizes
},
)
scales = Parameter(
torch.empty(
(
output_size_per_partition //
self.quant_config.out_group_size,
1,
1,
1,
),
dtype=params_dtype,
),
requires_grad=False,
)
set_weight_attrs(
scales,
{
"output_dim": 0,
"packed_dim": 0,
"pack_factor": self.quant_config.out_group_size
},
)
layer.register_parameter("codes", codes)
set_weight_attrs(codes, extra_weight_attrs)
layer.register_parameter("codebooks", codebooks)
set_weight_attrs(codebooks, extra_weight_attrs)
layer.register_parameter("scales", scales)
set_weight_attrs(scales, extra_weight_attrs)
def apply(
self,
layer: torch.nn.Module,
x: torch.Tensor,
bias: Optional[torch.Tensor] = None,
) -> torch.Tensor:
codebooks = layer.codebooks
codes = layer.codes
scales = layer.scales
output_partition_sizes = getattr(codebooks, "output_partition_sizes",
[])
nbooks = codes.shape[2]
ingroups = codebooks.shape[3]
outgroups = codebooks.shape[2]
bits = codebooks.shape[1]
# We support these formats with dedicated gemm and decompression
# kernels.
if ingroups == 8 and outgroups == 1 and (
(bits == 256 and nbooks == 2) or (bits == 65536 and nbooks == 1)):
# thresholds determined by timings on an A6000, one GPU
use_gemv = math.prod(x.shape[:-1]) <= 6
return ops.aqlm_gemm(
x,
codes,
codebooks,
scales,
output_partition_sizes,
bias,
) if use_gemv else optimized_dequantize_gemm(
x,
codes,
codebooks,
scales,
output_partition_sizes,
bias,
)
# fall back all unoptimized formats
return generic_dequantize_gemm(
x,
codes,
codebooks,
scales,
output_partition_sizes,
bias,
)