Fix INT8 quantization error on Blackwell GPUs (SM100+) (#25935)

Signed-off-by: padg9912 <phone.and.desktop@gmail.com>
Signed-off-by: yewentao256 <zhyanwentao@126.com>
This commit is contained in:
Param 2025-09-30 22:19:53 -04:00 committed by yewentao256
parent 2b6b859916
commit cd0bbf5de2
2 changed files with 9 additions and 2 deletions

View File

@ -25,7 +25,10 @@ void dispatch_scaled_mm(torch::Tensor& c, torch::Tensor const& a,
if constexpr (!std::is_same_v<Int8Func, std::nullptr_t>) {
int8_func(c, a, b, a_scales, b_scales, bias);
} else {
TORCH_CHECK(false, "Int8 not supported for this architecture");
int32_t version_num = get_sm_version_num();
TORCH_CHECK(
false, "Int8 not supported on SM", version_num,
". Use FP8 quantization instead, or run on older arch (SM < 100).");
}
}
} else {

View File

@ -6,7 +6,11 @@ This quantization method is particularly useful for reducing model size while ma
Please visit the HF collection of [quantized INT8 checkpoints of popular LLMs ready to use with vLLM](https://huggingface.co/collections/neuralmagic/int8-llms-for-vllm-668ec32c049dca0369816415).
!!! note
INT8 computation is supported on NVIDIA GPUs with compute capability > 7.5 (Turing, Ampere, Ada Lovelace, Hopper, Blackwell).
INT8 computation is supported on NVIDIA GPUs with compute capability > 7.5 (Turing, Ampere, Ada Lovelace, Hopper).
!!! warning
**Blackwell GPU Limitation**: INT8 is not supported on compute capability >= 100 (e.g., RTX 6000 Blackwell).
Use [FP8 quantization](fp8.md) instead, or run on Hopper/Ada/Ampere architectures.
## Prerequisites