From cd0bbf5de2f0558327b094d81b45d6d5a8dfd44f Mon Sep 17 00:00:00 2001
From: Param
Date: Tue, 30 Sep 2025 22:19:53 -0400
Subject: [PATCH] Fix INT8 quantization error on Blackwell GPUs (SM100+) (#25935)

Signed-off-by: padg9912
Signed-off-by: yewentao256
---
 csrc/quantization/w8a8/cutlass/c3x/scaled_mm_helper.hpp | 5 ++++-
 docs/features/quantization/int8.md                      | 6 +++++-
 2 files changed, 9 insertions(+), 2 deletions(-)

diff --git a/csrc/quantization/w8a8/cutlass/c3x/scaled_mm_helper.hpp b/csrc/quantization/w8a8/cutlass/c3x/scaled_mm_helper.hpp
index 3af59267bd60c..2204a49257b08 100644
--- a/csrc/quantization/w8a8/cutlass/c3x/scaled_mm_helper.hpp
+++ b/csrc/quantization/w8a8/cutlass/c3x/scaled_mm_helper.hpp
@@ -25,7 +25,10 @@ void dispatch_scaled_mm(torch::Tensor& c, torch::Tensor const& a,
       if constexpr (!std::is_same_v<Int8Func, std::nullptr_t>) {
         int8_func(c, a, b, a_scales, b_scales, bias);
       } else {
-        TORCH_CHECK(false, "Int8 not supported for this architecture");
+        int32_t version_num = get_sm_version_num();
+        TORCH_CHECK(
+            false, "Int8 not supported on SM", version_num,
+            ". Use FP8 quantization instead, or run on older arch (SM < 100).");
       }
     }
   } else {
diff --git a/docs/features/quantization/int8.md b/docs/features/quantization/int8.md
index 247d0cbdd3f14..af3650e701ad0 100644
--- a/docs/features/quantization/int8.md
+++ b/docs/features/quantization/int8.md
@@ -6,7 +6,11 @@ This quantization method is particularly useful for reducing model size while ma
 
 Please visit the HF collection of [quantized INT8 checkpoints of popular LLMs ready to use with vLLM](https://huggingface.co/collections/neuralmagic/int8-llms-for-vllm-668ec32c049dca0369816415).
 
 !!! note
-    INT8 computation is supported on NVIDIA GPUs with compute capability > 7.5 (Turing, Ampere, Ada Lovelace, Hopper, Blackwell).
+    INT8 computation is supported on NVIDIA GPUs with compute capability >= 7.5 (Turing, Ampere, Ada Lovelace, Hopper).
+
+!!! warning
+    **Blackwell GPU Limitation**: INT8 is not supported on compute capability >= 10.0 (SM100 and newer, e.g., RTX 6000 Blackwell).
+    Use [FP8 quantization](fp8.md) instead, or run on Hopper/Ada/Ampere architectures.
 
 ## Prerequisites
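
For context on why the fix lands in the `else` branch of an `if constexpr`: `dispatch_scaled_mm` takes its INT8 kernel as a template parameter, and builds with no INT8 kernel for the target architecture pass `nullptr`, so `Int8Func` deduces to `std::nullptr_t` and only the error branch is compiled in. Below is a minimal, self-contained sketch of that sentinel pattern; the names `dispatch` and `is_int8` are illustrative, not vLLM's API, and the torch types are simplified away.

```cpp
// Minimal sketch of the nullptr_t-sentinel dispatch pattern (illustrative
// names, not vLLM's API). Builds lacking an INT8 kernel pass `nullptr`;
// `if constexpr` then instantiates only the error branch.
#include <cstdio>
#include <type_traits>

template <typename Int8Func>
void dispatch(bool is_int8, Int8Func int8_func) {
  if (is_int8) {
    if constexpr (!std::is_same_v<Int8Func, std::nullptr_t>) {
      int8_func();  // a real INT8 kernel was supplied for this arch
    } else {
      // No INT8 kernel compiled in (e.g., SM100+): report and redirect.
      std::puts("Int8 not supported on this SM. Use FP8 instead.");
    }
  }
}

int main() {
  dispatch(true, [] { std::puts("running int8 kernel"); });  // pre-SM100 build
  dispatch(true, nullptr);  // SM100+ build: sentinel hits the error branch
  return 0;
}
```

The `if constexpr` is what makes the sentinel work: the discarded branch is never instantiated, so the `int8_func()` call compiles cleanly even when `Int8Func` is `std::nullptr_t`.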
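The improved error message reports a value like `100` from `get_sm_version_num()`. Assuming the conventional encoding of major compute capability × 10 + minor (which this sketch uses; vLLM's actual helper may be implemented differently), Blackwell's compute capability 10.0 maps to SM version 100, which is what the new `SM < 100` guidance refers to. A hedged sketch using the CUDA runtime API:

```cpp
// Hedged sketch of deriving an SM version number like 100 from the
// device's compute capability (major * 10 + minor); vLLM's actual
// get_sm_version_num() may differ.
#include <cuda_runtime.h>
#include <cstdio>

int sm_version_num(int device = 0) {
  int major = 0, minor = 0;
  cudaDeviceGetAttribute(&major, cudaDevAttrComputeCapabilityMajor, device);
  cudaDeviceGetAttribute(&minor, cudaDevAttrComputeCapabilityMinor, device);
  return major * 10 + minor;  // Hopper 9.0 -> 90, Blackwell 10.0 -> 100
}

int main() {
  int sm = sm_version_num();
  std::printf("SM%d: %s\n", sm,
              sm >= 100 ? "INT8 cutlass path unavailable, use FP8"
                        : "INT8 path available");
  return 0;
}
```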