// vllm/csrc/quantization/cutlass_w8a8/c3x/scaled_mm_kernels.hpp

#pragma once

#include <optional>

#include <torch/all.h>

namespace vllm {

// FP8 scaled GEMM for Hopper (sm90).
void cutlass_scaled_mm_sm90_fp8(torch::Tensor& out, torch::Tensor const& a,
                                torch::Tensor const& b,
                                torch::Tensor const& a_scales,
                                torch::Tensor const& b_scales,
                                std::optional<torch::Tensor> const& bias);

// INT8 scaled GEMM for Hopper (sm90).
void cutlass_scaled_mm_sm90_int8(torch::Tensor& out, torch::Tensor const& a,
                                 torch::Tensor const& b,
                                 torch::Tensor const& a_scales,
                                 torch::Tensor const& b_scales,
                                 std::optional<torch::Tensor> const& bias);

// INT8 scaled GEMM with asymmetric zero-point (azp) adjustment for Hopper
// (sm90).
void cutlass_scaled_mm_azp_sm90_int8(torch::Tensor& out, torch::Tensor const& a,
                                     torch::Tensor const& b,
                                     torch::Tensor const& a_scales,
                                     torch::Tensor const& b_scales,
                                     torch::Tensor const& azp_adj,
                                     std::optional<torch::Tensor> const& azp,
                                     std::optional<torch::Tensor> const& bias);

// FP8 scaled GEMM with blockwise scales for Hopper (sm90); no bias epilogue.
void cutlass_scaled_mm_blockwise_sm90_fp8(torch::Tensor& out,
                                          torch::Tensor const& a,
                                          torch::Tensor const& b,
                                          torch::Tensor const& a_scales,
                                          torch::Tensor const& b_scales);

// FP8 scaled GEMM for Blackwell (sm100).
void cutlass_scaled_mm_sm100_fp8(torch::Tensor& out, torch::Tensor const& a,
                                 torch::Tensor const& b,
                                 torch::Tensor const& a_scales,
                                 torch::Tensor const& b_scales,
                                 std::optional<torch::Tensor> const& bias);

// FP8 scaled GEMM for Blackwell GeForce GPUs (sm120).
void cutlass_scaled_mm_sm120_fp8(torch::Tensor& out, torch::Tensor const& a,
                                 torch::Tensor const& b,
                                 torch::Tensor const& a_scales,
                                 torch::Tensor const& b_scales,
                                 std::optional<torch::Tensor> const& bias);

// FP8 scaled GEMM with blockwise scales for Blackwell (sm100); no bias
// epilogue.
void cutlass_scaled_mm_blockwise_sm100_fp8(torch::Tensor& out,
                                           torch::Tensor const& a,
                                           torch::Tensor const& b,
                                           torch::Tensor const& a_scales,
                                           torch::Tensor const& b_scales);
} // namespace vllm
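
// A minimal caller-side sketch (illustrative only, not part of this header):
// a scaled_mm entry point is expected to dispatch to one of the declarations
// above based on the device's compute capability. `get_sm_version_num()` is
// assumed here as the capability query; the actual dispatch logic lives in
// the .cu entry file, not in this header.
//
//   void cutlass_scaled_mm_fp8(torch::Tensor& c, torch::Tensor const& a,
//                              torch::Tensor const& b,
//                              torch::Tensor const& a_scales,
//                              torch::Tensor const& b_scales,
//                              std::optional<torch::Tensor> const& bias) {
//     int32_t version_num = get_sm_version_num();  // assumed helper
//     if (version_num >= 120) {         // Blackwell GeForce (sm120)
//       vllm::cutlass_scaled_mm_sm120_fp8(c, a, b, a_scales, b_scales, bias);
//     } else if (version_num >= 100) {  // Blackwell (sm100)
//       vllm::cutlass_scaled_mm_sm100_fp8(c, a, b, a_scales, b_scales, bias);
//     } else {                          // Hopper (sm90)
//       vllm::cutlass_scaled_mm_sm90_fp8(c, a, b, a_scales, b_scales, bias);
//     }
//   }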