Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
This commit is contained in:
Woosuk Kwon 2024-12-22 22:39:03 -08:00
parent 0420fb2c7b
commit 3fdbd8e2f5
2 changed files with 20 additions and 14 deletions

View File

@@ -23,31 +23,34 @@ __global__ void copy_subranges_kernel(const int* __restrict__ matrix_src,
void copy_subranges(torch::Tensor& matrix_src, torch::Tensor& matrix_diff,
torch::Tensor& matrix_tgt, int64_t n) {
// NOTE(woosuk): Here, we skip most of the error checking to minimize the
// CPU overheads. We assume that the caller will pass the correct inputs.
// Check tensor properties
TORCH_CHECK(matrix_src.is_cuda(), "matrix_src must be a CUDA tensor");
TORCH_CHECK(matrix_diff.is_cuda(), "matrix_diff must be a CUDA tensor");
TORCH_CHECK(matrix_tgt.is_cuda(), "matrix_tgt must be a CUDA tensor");
TORCH_CHECK(matrix_src.is_contiguous(), "matrix_src must be contiguous");
TORCH_CHECK(matrix_diff.is_contiguous(), "matrix_diff must be contiguous");
TORCH_CHECK(matrix_tgt.is_contiguous(), "matrix_tgt must be contiguous");
// TORCH_CHECK(matrix_src.is_cuda(), "matrix_src must be a CUDA tensor");
// TORCH_CHECK(matrix_diff.is_cuda(), "matrix_diff must be a CUDA tensor");
// TORCH_CHECK(matrix_tgt.is_cuda(), "matrix_tgt must be a CUDA tensor");
// TORCH_CHECK(matrix_src.is_contiguous(), "matrix_src must be contiguous");
// TORCH_CHECK(matrix_diff.is_contiguous(), "matrix_diff must be contiguous");
// TORCH_CHECK(matrix_tgt.is_contiguous(), "matrix_tgt must be contiguous");
auto src_sizes = matrix_src.sizes();
auto diff_sizes = matrix_diff.sizes();
auto tgt_sizes = matrix_tgt.sizes();
TORCH_CHECK(src_sizes.size() == 2, "matrix_src must be 2D");
TORCH_CHECK(diff_sizes.size() == 2, "matrix_diff must be 2D");
TORCH_CHECK(tgt_sizes.size() == 2, "matrix_tgt must be 2D");
// TORCH_CHECK(src_sizes.size() == 2, "matrix_src must be 2D");
// TORCH_CHECK(diff_sizes.size() == 2, "matrix_diff must be 2D");
// TORCH_CHECK(tgt_sizes.size() == 2, "matrix_tgt must be 2D");
int64_t N = src_sizes[0];
int64_t M = src_sizes[1];
TORCH_CHECK(diff_sizes[0] == N, "matrix_diff first dim must match N");
TORCH_CHECK(diff_sizes[1] == 2, "matrix_diff second dim must be 2");
TORCH_CHECK(tgt_sizes[0] == N && tgt_sizes[1] == M,
"matrix_tgt must have same shape as matrix_src");
// TORCH_CHECK(diff_sizes[0] == N, "matrix_diff first dim must match N");
// TORCH_CHECK(diff_sizes[1] == 2, "matrix_diff second dim must be 2");
// TORCH_CHECK(tgt_sizes[0] == N && tgt_sizes[1] == M,
// "matrix_tgt must have same shape as matrix_src");
TORCH_CHECK(n <= N, "n must be <= N");
// TORCH_CHECK(n <= N, "n must be <= N");
const int* d_matrix_src = matrix_src.data_ptr<int>();
const int* d_matrix_diff = matrix_diff.data_ptr<int>();

View File

@@ -256,6 +256,9 @@ def copy_subranges(
tgt_matrix: torch.Tensor,
num_subranges: int,
) -> None:
# NOTE(woosuk): We use `torch.ops._C.copy_subranges.default` instead of
# `torch.ops._C.copy_subranges` to avoid unnecessary CPU overheads from
# the dispatcher.
torch.ops._C.copy_subranges.default(src_matrix, diff_matrix, tgt_matrix,
num_subranges)