vllm/csrc/quantization/w8a8/cutlass/c3x/scaled_mm_sm100_fp8.cu
lyrisz 97e3dda84b
[Perf] SM100 - add swap AB optimization to CUTLASS FP8 GEMM (#27284)
Signed-off-by: Faqin Zhong <faqin.zhong@gmail.com>
Co-authored-by: Faqin Zhong <zhofaqin@amazon.com>
Co-authored-by: Michael Goin <mgoin64@gmail.com>
2025-11-04 07:49:25 -08:00

24 lines
971 B
Plaintext

#include "scaled_mm_kernels.hpp"
#include "scaled_mm_sm100_fp8_dispatch.cuh"
namespace vllm {

// Validates the scale and bias arguments, then forwards everything to
// cutlass_scaled_mm_sm100_fp8_epilogue, picking the template argument by
// whether a bias tensor was supplied.
//
//   out, a, b  - forwarded unchanged to the epilogue dispatcher.
//   a_scales,
//   b_scales   - must both be contiguous (enforced via TORCH_CHECK).
//   bias       - optional; when present its dtype must equal out.dtype().
//
// NOTE(review): the bool template argument presumably selects a bias-fused
// epilogue; confirm against scaled_mm_sm100_fp8_dispatch.cuh.
void cutlass_scaled_mm_sm100_fp8(torch::Tensor& out, torch::Tensor const& a,
                                 torch::Tensor const& b,
                                 torch::Tensor const& a_scales,
                                 torch::Tensor const& b_scales,
                                 std::optional<torch::Tensor> const& bias) {
  // Condition text is kept verbatim: with no message supplied, TORCH_CHECK
  // reports the stringified expression in its error.
  TORCH_CHECK(a_scales.is_contiguous() && b_scales.is_contiguous());

  // Bias-free path: nothing else to validate.
  if (!bias.has_value()) {
    cutlass_scaled_mm_sm100_fp8_epilogue<false>(out, a, b, a_scales, b_scales);
    return;
  }

  // Bias path: the bias must share the output dtype before it is forwarded.
  torch::Tensor const& bias_tensor = bias.value();
  TORCH_CHECK(bias_tensor.dtype() == out.dtype(),
              "currently bias dtype must match output dtype ", out.dtype());
  cutlass_scaled_mm_sm100_fp8_epilogue<true>(out, a, b, a_scales, b_scales,
                                             bias_tensor);
}

}  // namespace vllm