From 06c4873d959feb0d4cb062ef17cdd0dd09dbf10f Mon Sep 17 00:00:00 2001
From: "Jane (Yuan) Xu" <31798555+janeyx99@users.noreply.github.com>
Date: Thu, 13 Nov 2025 11:52:50 -0500
Subject: [PATCH] Rewrite C++ meta funcs to Python (#28595)

Signed-off-by: Jane Xu
---
 .../gptq_marlin/awq_marlin_repack.cu          | 16 --------
 .../gptq_marlin/gptq_marlin_repack.cu         | 16 --------
 vllm/_custom_ops.py                           | 39 ++++++++++++++++++-
 3 files changed, 38 insertions(+), 33 deletions(-)

diff --git a/csrc/quantization/gptq_marlin/awq_marlin_repack.cu b/csrc/quantization/gptq_marlin/awq_marlin_repack.cu
index 8ba617a9e6555..e607107b3e77c 100644
--- a/csrc/quantization/gptq_marlin/awq_marlin_repack.cu
+++ b/csrc/quantization/gptq_marlin/awq_marlin_repack.cu
@@ -247,22 +247,6 @@ torch::Tensor awq_marlin_repack(torch::Tensor& b_q_weight, int64_t size_k,
   return out;
 }
 
-torch::Tensor awq_marlin_repack_meta(torch::Tensor& b_q_weight,
-                                     c10::SymInt size_k, c10::SymInt size_n,
-                                     int64_t num_bits) {
-  int const pack_factor = 32 / num_bits;
-  auto options = torch::TensorOptions()
-                     .dtype(b_q_weight.dtype())
-                     .device(b_q_weight.device());
-  return torch::empty_symint(
-      {size_k / marlin::tile_size, size_n * marlin::tile_size / pack_factor},
-      options);
-}
-
 TORCH_LIBRARY_IMPL_EXPAND(TORCH_EXTENSION_NAME, CUDA, m) {
   m.impl("awq_marlin_repack", &awq_marlin_repack);
 }
-
-TORCH_LIBRARY_IMPL_EXPAND(TORCH_EXTENSION_NAME, Meta, m) {
-  m.impl("awq_marlin_repack", &awq_marlin_repack_meta);
-}
diff --git a/csrc/quantization/gptq_marlin/gptq_marlin_repack.cu b/csrc/quantization/gptq_marlin/gptq_marlin_repack.cu
index 7c2d089a70d95..ad80d51ece94e 100644
--- a/csrc/quantization/gptq_marlin/gptq_marlin_repack.cu
+++ b/csrc/quantization/gptq_marlin/gptq_marlin_repack.cu
@@ -321,22 +321,6 @@ torch::Tensor gptq_marlin_repack(torch::Tensor& b_q_weight, torch::Tensor& perm,
   return out;
 }
 
-torch::Tensor gptq_marlin_repack_meta(torch::Tensor& b_q_weight,
-                                      torch::Tensor& perm, c10::SymInt size_k,
-                                      c10::SymInt size_n, int64_t num_bits) {
-  int const pack_factor = 32 / num_bits;
-  auto options = torch::TensorOptions()
-                     .dtype(b_q_weight.dtype())
-                     .device(b_q_weight.device());
-  return torch::empty_symint(
-      {size_k / marlin::tile_size, size_n * marlin::tile_size / pack_factor},
-      options);
-}
-
 TORCH_LIBRARY_IMPL_EXPAND(TORCH_EXTENSION_NAME, CUDA, m) {
   m.impl("gptq_marlin_repack", &gptq_marlin_repack);
 }
-
-TORCH_LIBRARY_IMPL_EXPAND(TORCH_EXTENSION_NAME, Meta, m) {
-  m.impl("gptq_marlin_repack", &gptq_marlin_repack_meta);
-}
diff --git a/vllm/_custom_ops.py b/vllm/_custom_ops.py
index 7d70c01cefbb6..096266c9764e8 100644
--- a/vllm/_custom_ops.py
+++ b/vllm/_custom_ops.py
@@ -1174,13 +1174,50 @@ def gptq_marlin_repack(
     return torch.ops._C.gptq_marlin_repack(b_q_weight, perm, size_k, size_n, num_bits)
 
 
-# gptq_marlin
+if hasattr(torch.ops._C, "gptq_marlin_repack"):
+
+    @register_fake("_C::gptq_marlin_repack")
+    def _gptq_marlin_repack_fake(
+        b_q_weight: torch.Tensor,
+        perm: torch.Tensor,
+        size_k: torch.SymInt,
+        size_n: torch.SymInt,
+        num_bits: int,
+    ) -> torch.Tensor:
+        pack_factor = 32 // num_bits
+        marlin_tile_size = 16
+        return torch.empty(
+            (size_k // marlin_tile_size, size_n * marlin_tile_size // pack_factor),
+            dtype=b_q_weight.dtype,
+            device=b_q_weight.device,
+        )
+
+
+# awq_marlin
 def awq_marlin_repack(
     b_q_weight: torch.Tensor, size_k: int, size_n: int, num_bits: int
 ) -> torch.Tensor:
     return torch.ops._C.awq_marlin_repack(b_q_weight, size_k, size_n, num_bits)
 
 
+if hasattr(torch.ops._C, "awq_marlin_repack"):
+
+    @register_fake("_C::awq_marlin_repack")
+    def _awq_marlin_repack_fake(
+        b_q_weight: torch.Tensor,
+        size_k: torch.SymInt,
+        size_n: torch.SymInt,
+        num_bits: int,
+    ) -> torch.Tensor:
+        pack_factor = 32 // num_bits
+        marlin_tile_size = 16
+        return torch.empty(
+            (size_k // marlin_tile_size, size_n * marlin_tile_size // pack_factor),
+            dtype=b_q_weight.dtype,
+            device=b_q_weight.device,
+        )
+
+
 def gptq_marlin_moe_repack(
     b_q_weight: torch.Tensor,
     perm: torch.Tensor,