diff --git a/csrc/rocm/ops.h b/csrc/rocm/ops.h index edf7aff1abaac..8b80362583eec 100644 --- a/csrc/rocm/ops.h +++ b/csrc/rocm/ops.h @@ -6,11 +6,11 @@ torch::Tensor LLMM1(at::Tensor& in_a, at::Tensor& in_b, const int64_t rows_per_block); torch::Tensor wvSplitK(const at::Tensor& in_a, const at::Tensor& in_b, - const c10::optional& in_bias, + const std::optional& in_bias, const int64_t CuCount); void wvSplitKQ(const at::Tensor& in_a, const at::Tensor& in_b, - const c10::optional& in_bias, at::Tensor& out_c, + const std::optional& in_bias, at::Tensor& out_c, const at::Tensor& scale_a, const at::Tensor& scale_b, const int64_t CuCount); diff --git a/csrc/rocm/skinny_gemms.cu b/csrc/rocm/skinny_gemms.cu index bf2fe169c7114..2ef579a1b7537 100644 --- a/csrc/rocm/skinny_gemms.cu +++ b/csrc/rocm/skinny_gemms.cu @@ -1271,7 +1271,7 @@ int mindiv(int N, int div1, int div2) { } torch::Tensor wvSplitK(const at::Tensor& in_a, const at::Tensor& in_b, - const c10::optional& in_bias, + const std::optional& in_bias, const int64_t CuCount) { auto M_in = in_a.size(0); auto K_in = in_a.size(1); @@ -1729,7 +1729,7 @@ __global__ void wvSplitKQ_hf_(const int K, const int Kp, const int M, #endif // defined(__HIP__MI3XX__) TODO: Add NAVI support void wvSplitKQ(const at::Tensor& in_a, const at::Tensor& in_b, - const c10::optional& in_bias, at::Tensor& out_c, + const std::optional& in_bias, at::Tensor& out_c, const at::Tensor& scale_a, const at::Tensor& scale_b, const int64_t CuCount) { static c10::ScalarType kFp8Type = is_fp8_ocp()