Rewrite C++ meta funcs to Python (#28595)
Signed-off-by: Jane Xu <janeyx@meta.com>
parent d3387750f1
commit 06c4873d95
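Context for the diff below: the commit moves shape inference for the Marlin repack ops from C++ meta kernels (registered under the Meta dispatch key) to Python fake implementations registered with torch.library.register_fake. As a minimal, self-contained sketch of that pattern, assuming PyTorch 2.4+ where torch.library.custom_op is available: the op name mylib::toy_repack and its shape math are hypothetical, not part of vLLM.

import torch
from torch.library import custom_op, register_fake


@custom_op("mylib::toy_repack", mutates_args=())
def toy_repack(x: torch.Tensor, tile: int) -> torch.Tensor:
    # Eager implementation; a plain reshape stands in for a real CUDA kernel.
    return x.reshape(x.shape[0] // tile, x.shape[1] * tile)


@register_fake("mylib::toy_repack")
def _toy_repack_fake(x: torch.Tensor, tile: int) -> torch.Tensor:
    # Fake/meta implementation: compute only shape, dtype, and device,
    # mirroring what the removed C++ *_meta functions did with empty_symint.
    return torch.empty(
        (x.shape[0] // tile, x.shape[1] * tile),
        dtype=x.dtype,
        device=x.device,
    )


# Calling the op on meta tensors hits the fake impl, so no kernel ever runs:
out = toy_repack(torch.empty(64, 8, device="meta"), 16)
assert out.shape == (4, 128)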
@@ -247,22 +247,6 @@ torch::Tensor awq_marlin_repack(torch::Tensor& b_q_weight, int64_t size_k,
   return out;
 }
 
-torch::Tensor awq_marlin_repack_meta(torch::Tensor& b_q_weight,
-                                     c10::SymInt size_k, c10::SymInt size_n,
-                                     int64_t num_bits) {
-  int const pack_factor = 32 / num_bits;
-  auto options = torch::TensorOptions()
-                     .dtype(b_q_weight.dtype())
-                     .device(b_q_weight.device());
-  return torch::empty_symint(
-      {size_k / marlin::tile_size, size_n * marlin::tile_size / pack_factor},
-      options);
-}
-
 TORCH_LIBRARY_IMPL_EXPAND(TORCH_EXTENSION_NAME, CUDA, m) {
   m.impl("awq_marlin_repack", &awq_marlin_repack);
 }
-
-TORCH_LIBRARY_IMPL_EXPAND(TORCH_EXTENSION_NAME, Meta, m) {
-  m.impl("awq_marlin_repack", &awq_marlin_repack_meta);
-}
@@ -321,22 +321,6 @@ torch::Tensor gptq_marlin_repack(torch::Tensor& b_q_weight, torch::Tensor& perm,
   return out;
 }
 
-torch::Tensor gptq_marlin_repack_meta(torch::Tensor& b_q_weight,
-                                      torch::Tensor& perm, c10::SymInt size_k,
-                                      c10::SymInt size_n, int64_t num_bits) {
-  int const pack_factor = 32 / num_bits;
-  auto options = torch::TensorOptions()
-                     .dtype(b_q_weight.dtype())
-                     .device(b_q_weight.device());
-  return torch::empty_symint(
-      {size_k / marlin::tile_size, size_n * marlin::tile_size / pack_factor},
-      options);
-}
-
 TORCH_LIBRARY_IMPL_EXPAND(TORCH_EXTENSION_NAME, CUDA, m) {
   m.impl("gptq_marlin_repack", &gptq_marlin_repack);
 }
-
-TORCH_LIBRARY_IMPL_EXPAND(TORCH_EXTENSION_NAME, Meta, m) {
-  m.impl("gptq_marlin_repack", &gptq_marlin_repack_meta);
-}
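Both removed meta functions compute the same output shape, (size_k / tile_size, size_n * tile_size / pack_factor) with marlin::tile_size == 16, which the Python fakes in the next hunk reproduce. A small worked check; the sizes and num_bits are arbitrary illustration values:

# Worked instance of the repack shape math (values are arbitrary examples).
num_bits = 4
pack_factor = 32 // num_bits       # 8 four-bit values per packed int32
marlin_tile_size = 16              # matches marlin::tile_size in the C++ code
size_k, size_n = 4096, 4096        # hypothetical problem sizes
out_shape = (size_k // marlin_tile_size, size_n * marlin_tile_size // pack_factor)
assert out_shape == (256, 8192)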
@@ -1174,13 +1174,50 @@ def gptq_marlin_repack(
     return torch.ops._C.gptq_marlin_repack(b_q_weight, perm, size_k, size_n, num_bits)
 
 
+# gptq_marlin
+if hasattr(torch.ops._C, "gptq_marlin_repack"):
+
+    @register_fake("_C::gptq_marlin_repack")
+    def _gptq_marlin_repack_fake(
+        b_q_weight: torch.Tensor,
+        perm: torch.Tensor,
+        size_k: torch.SymInt,
+        size_n: torch.SymInt,
+        num_bits: int,
+    ) -> torch.Tensor:
+        pack_factor = 32 // num_bits
+        marlin_tile_size = 16
+        return torch.empty(
+            (size_k // marlin_tile_size, size_n * marlin_tile_size // pack_factor),
+            dtype=b_q_weight.dtype,
+            device=b_q_weight.device,
+        )
+
+
 # awq_marlin
 def awq_marlin_repack(
     b_q_weight: torch.Tensor, size_k: int, size_n: int, num_bits: int
 ) -> torch.Tensor:
     return torch.ops._C.awq_marlin_repack(b_q_weight, size_k, size_n, num_bits)
 
 
+if hasattr(torch.ops._C, "awq_marlin_repack"):
+
+    @register_fake("_C::awq_marlin_repack")
+    def _awq_marlin_repack_fake(
+        b_q_weight: torch.Tensor,
+        size_k: torch.SymInt,
+        size_n: torch.SymInt,
+        num_bits: int,
+    ) -> torch.Tensor:
+        pack_factor = 32 // num_bits
+        marlin_tile_size = 16
+        return torch.empty(
+            (size_k // marlin_tile_size, size_n * marlin_tile_size // pack_factor),
+            dtype=b_q_weight.dtype,
+            device=b_q_weight.device,
+        )
+
+
 def gptq_marlin_moe_repack(
     b_q_weight: torch.Tensor,
     perm: torch.Tensor,
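A hedged usage sketch of the new Python path: assuming vLLM is installed (so torch.ops._C.gptq_marlin_repack exists and importing vllm._custom_ops runs the register_fake calls above), calling the op on meta-device tensors should dispatch to the fake impl, computing the output shape without launching any kernel. The tensor sizes, the packed-weight layout, and num_bits are illustration assumptions, not taken from this diff:

import torch
from vllm import _custom_ops as ops  # importing this module registers the fakes

size_k, size_n, num_bits = 4096, 4096, 4
# Assumed GPTQ layout: weights packed along k, (size_k // pack_factor, size_n) int32.
b_q_weight = torch.empty(size_k // 8, size_n, dtype=torch.int32, device="meta")
perm = torch.empty(0, dtype=torch.int32, device="meta")  # ignored by the fake impl
out = ops.gptq_marlin_repack(b_q_weight, perm, size_k, size_n, num_bits)
print(out.shape)  # torch.Size([256, 8192]), from _gptq_marlin_repack_fake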