mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-16 10:35:52 +08:00
Signed-off-by: kaln27 <liaojuncheng123@foxmail.com> Co-authored-by: mgoin <mgoin64@gmail.com>
51 lines
2.6 KiB
C++
51 lines
2.6 KiB
C++
#pragma once

#include <optional>

#include <torch/all.h>

namespace vllm {
|
|
|
|
void cutlass_scaled_mm_sm90_fp8(torch::Tensor& out, torch::Tensor const& a,
|
|
torch::Tensor const& b,
|
|
torch::Tensor const& a_scales,
|
|
torch::Tensor const& b_scales,
|
|
std::optional<torch::Tensor> const& bias);
|
|
|
|
void cutlass_scaled_mm_sm90_int8(torch::Tensor& out, torch::Tensor const& a,
|
|
torch::Tensor const& b,
|
|
torch::Tensor const& a_scales,
|
|
torch::Tensor const& b_scales,
|
|
std::optional<torch::Tensor> const& bias);
|
|
|
|
void cutlass_scaled_mm_azp_sm90_int8(torch::Tensor& out, torch::Tensor const& a,
|
|
torch::Tensor const& b,
|
|
torch::Tensor const& a_scales,
|
|
torch::Tensor const& b_scales,
|
|
torch::Tensor const& azp_adj,
|
|
std::optional<torch::Tensor> const& azp,
|
|
std::optional<torch::Tensor> const& bias);
|
|
|
|
void cutlass_scaled_mm_blockwise_sm90_fp8(torch::Tensor& out,
|
|
torch::Tensor const& a,
|
|
torch::Tensor const& b,
|
|
torch::Tensor const& a_scales,
|
|
torch::Tensor const& b_scales);
|
|
|
|
void cutlass_scaled_mm_sm100_fp8(torch::Tensor& out, torch::Tensor const& a,
|
|
torch::Tensor const& b,
|
|
torch::Tensor const& a_scales,
|
|
torch::Tensor const& b_scales,
|
|
std::optional<torch::Tensor> const& bias);
|
|
|
|
void cutlass_scaled_mm_sm120_fp8(torch::Tensor& out, torch::Tensor const& a,
|
|
torch::Tensor const& b,
|
|
torch::Tensor const& a_scales,
|
|
torch::Tensor const& b_scales,
|
|
std::optional<torch::Tensor> const& bias);
|
|
|
|
void cutlass_scaled_mm_blockwise_sm100_fp8(torch::Tensor& out,
|
|
torch::Tensor const& a,
|
|
torch::Tensor const& b,
|
|
torch::Tensor const& a_scales,
|
|
torch::Tensor const& b_scales);
|
|
} // namespace vllm
|