Rewrite C++ meta funcs to Python (#28595)

Signed-off-by: Jane Xu <janeyx@meta.com>
Author: Jane (Yuan) Xu, 2025-11-13 11:52:50 -05:00 (committed by GitHub)
parent d3387750f1
commit 06c4873d95
3 changed files with 38 additions and 33 deletions


@@ -247,22 +247,6 @@ torch::Tensor awq_marlin_repack(torch::Tensor& b_q_weight, int64_t size_k,
   return out;
 }
 
-torch::Tensor awq_marlin_repack_meta(torch::Tensor& b_q_weight,
-                                     c10::SymInt size_k, c10::SymInt size_n,
-                                     int64_t num_bits) {
-  int const pack_factor = 32 / num_bits;
-  auto options = torch::TensorOptions()
-                     .dtype(b_q_weight.dtype())
-                     .device(b_q_weight.device());
-  return torch::empty_symint(
-      {size_k / marlin::tile_size, size_n * marlin::tile_size / pack_factor},
-      options);
-}
-
 TORCH_LIBRARY_IMPL_EXPAND(TORCH_EXTENSION_NAME, CUDA, m) {
   m.impl("awq_marlin_repack", &awq_marlin_repack);
 }
-
-TORCH_LIBRARY_IMPL_EXPAND(TORCH_EXTENSION_NAME, Meta, m) {
-  m.impl("awq_marlin_repack", &awq_marlin_repack_meta);
-}
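The function removed above is a Meta-dispatch-key kernel: PyTorch calls it when the op runs on meta or fake tensors, so it only computes the output shape and dtype and never touches real data. torch.compile relies on such functions during tracing. This commit replaces the C++ versions with Python fake registrations (third file below), which presumably keeps the shape logic in Python and avoids a C++ rebuild when it changes. A minimal self-contained sketch of the pattern, assuming PyTorch 2.4+ and using a hypothetical demo::pack op that is not part of this PR:

import torch
from torch._subclasses.fake_tensor import FakeTensorMode

# Hypothetical stand-in op: packs 32 // num_bits values into each int32
# along the last dimension (illustration only, not vLLM's kernel).
@torch.library.custom_op("demo::pack", mutates_args=())
def pack(w: torch.Tensor, num_bits: int) -> torch.Tensor:
    pack_factor = 32 // num_bits
    return w.new_empty(w.shape[0], w.shape[1] // pack_factor)

# The fake impl: shape/dtype only, the Python analogue of the deleted
# C++ Meta-key kernel above.
@pack.register_fake
def _(w: torch.Tensor, num_bits: int) -> torch.Tensor:
    pack_factor = 32 // num_bits
    return w.new_empty(w.shape[0], w.shape[1] // pack_factor)

# Tracing machinery now gets output shapes without launching the real kernel:
with FakeTensorMode():
    w = torch.empty(64, 128, dtype=torch.int32)
    print(torch.ops.demo.pack(w, 4).shape)  # torch.Size([64, 16])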


@@ -321,22 +321,6 @@ torch::Tensor gptq_marlin_repack(torch::Tensor& b_q_weight, torch::Tensor& perm,
   return out;
 }
 
-torch::Tensor gptq_marlin_repack_meta(torch::Tensor& b_q_weight,
-                                      torch::Tensor& perm, c10::SymInt size_k,
-                                      c10::SymInt size_n, int64_t num_bits) {
-  int const pack_factor = 32 / num_bits;
-  auto options = torch::TensorOptions()
-                     .dtype(b_q_weight.dtype())
-                     .device(b_q_weight.device());
-  return torch::empty_symint(
-      {size_k / marlin::tile_size, size_n * marlin::tile_size / pack_factor},
-      options);
-}
-
 TORCH_LIBRARY_IMPL_EXPAND(TORCH_EXTENSION_NAME, CUDA, m) {
   m.impl("gptq_marlin_repack", &gptq_marlin_repack);
 }
-
-TORCH_LIBRARY_IMPL_EXPAND(TORCH_EXTENSION_NAME, Meta, m) {
-  m.impl("gptq_marlin_repack", &gptq_marlin_repack_meta);
-}
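Both removed meta functions compute the same output shape, (size_k / tile_size, size_n * tile_size / pack_factor). A quick worked example; the tile size of 16 matches the marlin_tile_size hard-coded by the Python fakes below, while the concrete dimensions are illustrative assumptions, not values from the PR:

num_bits = 4
pack_factor = 32 // num_bits  # 8 four-bit values per 32-bit word
tile_size = 16                # marlin::tile_size
size_k, size_n = 4096, 4096   # illustrative only
out_shape = (size_k // tile_size, size_n * tile_size // pack_factor)
print(out_shape)              # (256, 8192)

Note the element count is preserved: 256 * 8192 int32 words hold exactly the 4096 * 4096 four-bit values of the input layout.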


@@ -1174,13 +1174,50 @@ def gptq_marlin_repack(
     return torch.ops._C.gptq_marlin_repack(b_q_weight, perm, size_k, size_n, num_bits)
 
+
+# gptq_marlin
+if hasattr(torch.ops._C, "gptq_marlin_repack"):
+
+    @register_fake("_C::gptq_marlin_repack")
+    def _gptq_marlin_repack_fake(
+        b_q_weight: torch.Tensor,
+        perm: torch.Tensor,
+        size_k: torch.SymInt,
+        size_n: torch.SymInt,
+        num_bits: int,
+    ) -> torch.Tensor:
+        pack_factor = 32 // num_bits
+        marlin_tile_size = 16
+        return torch.empty(
+            (size_k // marlin_tile_size, size_n * marlin_tile_size // pack_factor),
+            dtype=b_q_weight.dtype,
+            device=b_q_weight.device,
+        )
+
+
 # awq_marlin
 def awq_marlin_repack(
     b_q_weight: torch.Tensor, size_k: int, size_n: int, num_bits: int
 ) -> torch.Tensor:
     return torch.ops._C.awq_marlin_repack(b_q_weight, size_k, size_n, num_bits)
 
+
+if hasattr(torch.ops._C, "awq_marlin_repack"):
+
+    @register_fake("_C::awq_marlin_repack")
+    def _awq_marlin_repack_fake(
+        b_q_weight: torch.Tensor,
+        size_k: torch.SymInt,
+        size_n: torch.SymInt,
+        num_bits: int,
+    ) -> torch.Tensor:
+        pack_factor = 32 // num_bits
+        marlin_tile_size = 16
+        return torch.empty(
+            (size_k // marlin_tile_size, size_n * marlin_tile_size // pack_factor),
+            dtype=b_q_weight.dtype,
+            device=b_q_weight.device,
+        )
+
+
 def gptq_marlin_moe_repack(
     b_q_weight: torch.Tensor,
     perm: torch.Tensor,
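With these fakes registered, output shapes for the real ops can be inferred without a GPU kernel launch, for example under FakeTensorMode. A sketch, assuming vLLM's _C extension is importable and reusing the illustrative sizes from the worked example above; note the fake derives the output purely from size_k, size_n, and num_bits, ignoring the input tensor's shape:

import torch
from torch._subclasses.fake_tensor import FakeTensorMode

with FakeTensorMode():
    # Illustrative AWQ-style packed weight; only its dtype and device
    # matter to the fake implementation.
    b_q_weight = torch.empty(4096, 512, dtype=torch.int32)
    out = torch.ops._C.awq_marlin_repack(b_q_weight, 4096, 4096, 4)
    print(out.shape)  # torch.Size([256, 8192])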