mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-05-02 12:31:20 +08:00
Rewrite C++ meta funcs to Python (#28595)
Signed-off-by: Jane Xu <janeyx@meta.com>
This commit is contained in:
parent
d3387750f1
commit
06c4873d95
@ -247,22 +247,6 @@ torch::Tensor awq_marlin_repack(torch::Tensor& b_q_weight, int64_t size_k,
|
|||||||
return out;
|
return out;
|
||||||
}
|
}
|
||||||
|
|
||||||
// Meta ("fake") kernel for awq_marlin_repack: computes only the output
// shape/dtype/device so symbolic tracing (torch.compile) can reason about
// the op without running the real CUDA kernel. Sizes are SymInt so they
// may be symbolic at trace time.
torch::Tensor awq_marlin_repack_meta(torch::Tensor& b_q_weight,
                                     c10::SymInt size_k, c10::SymInt size_n,
                                     int64_t num_bits) {
  // One 32-bit word holds (32 / num_bits) quantized elements.
  int const pack_factor = 32 / num_bits;
  auto options = torch::TensorOptions()
                     .dtype(b_q_weight.dtype())
                     .device(b_q_weight.device());
  // Repacked layout: rows shrink by tile_size, columns grow by tile_size
  // then shrink by the packing factor.
  return torch::empty_symint(
      {size_k / marlin::tile_size, size_n * marlin::tile_size / pack_factor},
      options);
}
|
|
||||||
|
|
||||||
// Register the real CUDA implementation of awq_marlin_repack.
TORCH_LIBRARY_IMPL_EXPAND(TORCH_EXTENSION_NAME, CUDA, m) {
  m.impl("awq_marlin_repack", &awq_marlin_repack);
}

// Register the shape-only Meta implementation used for tracing/compile.
TORCH_LIBRARY_IMPL_EXPAND(TORCH_EXTENSION_NAME, Meta, m) {
  m.impl("awq_marlin_repack", &awq_marlin_repack_meta);
}
|
|
||||||
|
|||||||
@ -321,22 +321,6 @@ torch::Tensor gptq_marlin_repack(torch::Tensor& b_q_weight, torch::Tensor& perm,
|
|||||||
return out;
|
return out;
|
||||||
}
|
}
|
||||||
|
|
||||||
// Meta ("fake") kernel for gptq_marlin_repack: allocates an empty tensor
// with the repacked output shape, dtype and device, without doing any
// work. SymInt sizes allow symbolic shapes under torch.compile.
torch::Tensor gptq_marlin_repack_meta(torch::Tensor& b_q_weight,
                                      torch::Tensor& perm, c10::SymInt size_k,
                                      c10::SymInt size_n, int64_t num_bits) {
  // One 32-bit word holds (32 / num_bits) quantized elements.
  int const pack_factor = 32 / num_bits;
  auto options = torch::TensorOptions()
                     .dtype(b_q_weight.dtype())
                     .device(b_q_weight.device());
  // Same output geometry as the awq variant: tile rows, packed columns.
  return torch::empty_symint(
      {size_k / marlin::tile_size, size_n * marlin::tile_size / pack_factor},
      options);
}
|
|
||||||
|
|
||||||
// Register the real CUDA implementation of gptq_marlin_repack.
TORCH_LIBRARY_IMPL_EXPAND(TORCH_EXTENSION_NAME, CUDA, m) {
  m.impl("gptq_marlin_repack", &gptq_marlin_repack);
}

// Register the shape-only Meta implementation used for tracing/compile.
TORCH_LIBRARY_IMPL_EXPAND(TORCH_EXTENSION_NAME, Meta, m) {
  m.impl("gptq_marlin_repack", &gptq_marlin_repack_meta);
}
|
|
||||||
|
|||||||
@ -1174,13 +1174,50 @@ def gptq_marlin_repack(
|
|||||||
return torch.ops._C.gptq_marlin_repack(b_q_weight, perm, size_k, size_n, num_bits)
|
return torch.ops._C.gptq_marlin_repack(b_q_weight, perm, size_k, size_n, num_bits)
|
||||||
|
|
||||||
|
|
||||||
# gptq_marlin
if hasattr(torch.ops._C, "gptq_marlin_repack"):

    @register_fake("_C::gptq_marlin_repack")
    def _gptq_marlin_repack_fake(
        b_q_weight: torch.Tensor,
        perm: torch.Tensor,
        size_k: torch.SymInt,
        size_n: torch.SymInt,
        num_bits: int,
    ) -> torch.Tensor:
        """Shape-only (fake) impl of gptq_marlin_repack for torch.compile.

        Returns an empty tensor with the repacked output shape; no data is
        computed. Mirrors the C++ meta kernel it replaces.
        """
        # One 32-bit word packs (32 // num_bits) quantized values.
        pack_factor = 32 // num_bits
        marlin_tile_size = 16
        out_rows = size_k // marlin_tile_size
        out_cols = size_n * marlin_tile_size // pack_factor
        return torch.empty(
            (out_rows, out_cols),
            dtype=b_q_weight.dtype,
            device=b_q_weight.device,
        )
||||||
|
|
||||||
|
|
||||||
|
# awq_marlin
def awq_marlin_repack(
    b_q_weight: torch.Tensor, size_k: int, size_n: int, num_bits: int
) -> torch.Tensor:
    """Repack AWQ-quantized weights into the Marlin kernel layout.

    Thin wrapper that dispatches to the registered ``_C`` custom op.
    """
    return torch.ops._C.awq_marlin_repack(b_q_weight, size_k, size_n, num_bits)
|
||||||
|
|
||||||
|
|
||||||
|
if hasattr(torch.ops._C, "awq_marlin_repack"):

    @register_fake("_C::awq_marlin_repack")
    def _awq_marlin_repack_fake(
        b_q_weight: torch.Tensor,
        size_k: torch.SymInt,
        size_n: torch.SymInt,
        num_bits: int,
    ) -> torch.Tensor:
        """Shape-only (fake) impl of awq_marlin_repack for torch.compile.

        Returns an empty tensor with the repacked output shape; no data is
        computed. Mirrors the C++ meta kernel it replaces.
        """
        # One 32-bit word packs (32 // num_bits) quantized values.
        pack_factor = 32 // num_bits
        marlin_tile_size = 16
        out_rows = size_k // marlin_tile_size
        out_cols = size_n * marlin_tile_size // pack_factor
        return torch.empty(
            (out_rows, out_cols),
            dtype=b_q_weight.dtype,
            device=b_q_weight.device,
        )
|
||||||
|
|
||||||
|
|
||||||
def gptq_marlin_moe_repack(
|
def gptq_marlin_moe_repack(
|
||||||
b_q_weight: torch.Tensor,
|
b_q_weight: torch.Tensor,
|
||||||
perm: torch.Tensor,
|
perm: torch.Tensor,
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user