# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# Copyright © 2025, Oracle and/or its affiliates.

import os
from collections.abc import Callable
from typing import Any, Optional

import numpy as np
import torch
from torch.nn.parameter import Parameter

from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe.config import (
    FusedMoEConfig,
    FusedMoEQuantConfig,
)
from vllm.model_executor.layers.fused_moe.fused_marlin_moe import fused_marlin_moe
from vllm.model_executor.layers.fused_moe.layer import FusedMoE, FusedMoEMethodBase
from vllm.model_executor.layers.linear import (
    LinearBase,
    LinearMethodBase,
    set_weight_attrs,
)
from vllm.model_executor.layers.quantization import QuantizationMethods
from vllm.model_executor.layers.quantization.base_config import (
    QuantizationConfig,
    QuantizeMethodBase,
)
from vllm.model_executor.layers.quantization.utils import replace_parameter
from vllm.model_executor.layers.quantization.utils.marlin_utils import (
    apply_rtn_marlin_linear,
    marlin_make_workspace_new,
)
from vllm.scalar_type import scalar_types

logger = init_logger(__name__)
"""By default, use 8 bit as target precision, but it can be
overridden by setting the RTN_NUM_BITS envvar
"""
NUM_BITS = os.getenv("RTN_NUM_BITS", "8")
"""By default, use group size of 128 parameters, but it can be
overridden by setting the RTN_GROUP_SIZE envvar
"""
GROUP_SIZE = os.getenv("RTN_GROUP_SIZE", "128")
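# Example (illustrative): 4-bit weights with 64-element groups can be selected
# by exporting RTN_NUM_BITS=4 and RTN_GROUP_SIZE=64 before starting vLLM; both
# variables are read once, when this module is imported.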
"""Global Marlin workspace shared by all modules
|
|
"""
|
|
workspace = None
|
|
|
|
|
|
class RTNConfig(QuantizationConfig):
    """Config class for RTN."""

    def __init__(
        self,
        weight_bits: int = int(NUM_BITS),
        group_size: int = int(GROUP_SIZE),
    ) -> None:
        self.weight_bits = weight_bits
        self.group_size = group_size

        if self.weight_bits != 4 and self.weight_bits != 8:
            raise ValueError(
                "Currently, only 4-bit or 8-bit weight quantization is "
                f"supported for RTN, but got {self.weight_bits} bits."
            )

        self.quant_type = (
            scalar_types.uint8b128 if self.weight_bits == 8 else scalar_types.uint4b8
        )

    def __repr__(self) -> str:
        return (
            f"RTNConfig(weight_bits={self.weight_bits}, group_size={self.group_size})"
        )

    @classmethod
    def get_name(cls) -> QuantizationMethods:
        return "rtn"

    @classmethod
    def get_supported_act_dtypes(cls) -> list[torch.dtype]:
        return [torch.bfloat16, torch.half]

    @classmethod
    def get_min_capability(cls) -> int:
        return 80

    @classmethod
    def get_config_filenames(cls) -> list[str]:
        return []

    @classmethod
    def from_config(cls, config: dict[str, Any]) -> "RTNConfig":
        weight_bits = cls.get_from_keys(config, ["bits"])
        group_size = cls.get_from_keys(config, ["group_size"])
        return cls(weight_bits, group_size)

    def get_quant_method(
        self, layer: torch.nn.Module, prefix: str
    ) -> Optional["QuantizeMethodBase"]:
        if isinstance(layer, LinearBase):
            return RTNLinearMethod(self)
        elif isinstance(layer, FusedMoE):
            return RTNMoEMethod(self, layer.moe_config)
        return None


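# Example (illustrative): a quantization config section such as
# {"bits": 4, "group_size": 128} is turned by RTNConfig.from_config above into
# RTNConfig(weight_bits=4, group_size=128); without such a section, the
# RTN_NUM_BITS/RTN_GROUP_SIZE defaults apply.
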
class RTNTensor:
    """A wrapper over Tensor that enables quantization on-the-fly by
    overloading the copy_ method.
    """

    def __init__(
        self, data: torch.Tensor, scale: torch.Tensor, quant_config: RTNConfig
    ) -> None:
        self.data = data
        self.scale = scale
        self.quant_config = quant_config

    def narrow(self, dim, start, length):
        factor = 1 if self.quant_config.weight_bits == 8 else 2
        return RTNTensor(
            self.data.narrow(dim, start // factor, length // factor),
            self.scale.narrow(dim, start, length),
            self.quant_config,
        )

    def __getitem__(self, key):
        return RTNTensor(self.data[key], self.scale[key], self.quant_config)

    @property
    def shape(self):
        shape = self.data.shape
        factor = 1 if self.quant_config.weight_bits == 8 else 2
        batch_present = len(shape) == 3
        if batch_present:
            return torch.Size((shape[0], shape[1] * factor, shape[2]))
        else:
            return torch.Size((shape[0] * factor, shape[1]))

    def copy_(self, loaded_weight: torch.Tensor) -> None:
        qweight, weight_scale = rtn_quantize(
            loaded_weight.cuda(),
            self.quant_config.weight_bits,
            self.quant_config.group_size,
        )

        self.data.copy_(qweight)
        self.scale.data.copy_(weight_scale)


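# Load-time flow (illustrative sketch): a weight loader typically does
#   param.data.narrow(output_dim, shard_offset, shard_size).copy_(loaded_weight)
# With RTNParameter below, accessing `.data` yields an RTNTensor, so narrow()
# accounts for the packed 4-bit storage and the final copy_() quantizes the
# full-precision shard via rtn_quantize() before writing it into the uint8 buffer.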
class RTNParameter(Parameter):
    """A wrapper over Parameter that returns RTNTensor (a wrapper over Tensor)
    when its data is accessed. We need this wrapper for the data loading phase
    only, so we can intercept a weight copying function (torch.Tensor.copy_)
    and apply quantization on-the-fly.
    """

    def __new__(cls, data: torch.Tensor, **kwargs):
        return super().__new__(cls, data=data, requires_grad=False)

    def __init__(
        self, data: torch.Tensor, scale: torch.Tensor, quant_config: RTNConfig
    ) -> None:
        self.scale = scale
        self.quant_config = quant_config

    @property
    def data(self):
        return RTNTensor(super().data, self.scale, self.quant_config)


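# Storage created by RTNLinearMethod.create_weights, illustrated for a
# 4096x4096 layer with group_size=128:
#   8-bit: weight uint8 [4096, 4096], scale [4096, 32]
#   4-bit: weight uint8 [2048, 4096] (two 4-bit codes per byte), scale [4096, 32]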
class RTNLinearMethod(LinearMethodBase):
    """Linear method for RTN.

    Args:
        quant_config: The RTN quantization config.
    """

    def __init__(self, quant_config: RTNConfig):
        self.quant_config = quant_config

    def create_weights(
        self,
        layer: torch.nn.Module,
        input_size_per_partition: int,
        output_partition_sizes: list[int],
        input_size: int,
        output_size: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        output_size_per_partition = sum(output_partition_sizes)
        num_groups_per_col = (
            input_size_per_partition // self.quant_config.group_size
            if self.quant_config.group_size != -1
            else 1
        )

        scale = Parameter(
            torch.empty(
                output_size_per_partition, num_groups_per_col, dtype=params_dtype
            ),
            requires_grad=False,
        )
        factor = 1 if self.quant_config.weight_bits == 8 else 2

        weight = RTNParameter(
            data=torch.empty(
                output_size_per_partition // factor,
                input_size_per_partition,
                dtype=torch.uint8,
            ),
            scale=scale,
            quant_config=self.quant_config,
        )

        layer.register_parameter("weight", weight)
        set_weight_attrs(
            weight,
            {
                **extra_weight_attrs,
                "input_dim": 1,
                "output_dim": 0,
            },
        )

        layer.register_parameter("scale", scale)
        layer.output_size_per_partition = output_size_per_partition

    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        """Repack weights and scales for Marlin kernels."""
        weight_bits = self.quant_config.weight_bits

        weight, scale = repack_weights(layer.weight, layer.scale, weight_bits)

        replace_parameter(layer, "weight", weight)
        replace_parameter(layer, "scale", scale)

        init_workspace(layer.weight.device)

    def apply(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        bias: torch.Tensor | None = None,
    ) -> torch.Tensor:
        return apply_rtn_marlin_linear(
            input=x,
            weight=layer.weight,
            weight_scale=layer.scale,
            workspace=workspace,
            quant_type=self.quant_config.quant_type,
            output_size_per_partition=layer.output_size_per_partition,
            input_size_per_partition=layer.input_size_per_partition,
            bias=bias,
        )


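# The MoE path mirrors the linear one, per expert: with E experts, hidden size H,
# per-partition intermediate size I and group size g, the fused gate/up projection
# uses w13_weight uint8 [E, 2*I//factor, H] with w13_scale [E, 2*I, H//g], and the
# down projection uses w2_weight uint8 [E, H//factor, I] with w2_scale [E, H, I//g].
# Both are quantized on load and repacked for the fused Marlin MoE kernel below.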
class RTNMoEMethod(FusedMoEMethodBase):
    def __init__(self, quant_config: RTNConfig, moe: FusedMoEConfig):
        super().__init__(moe)
        self.quant_config = quant_config

    def create_weights(
        self,
        layer: torch.nn.Module,
        num_experts: int,
        hidden_size: int,
        intermediate_size_per_partition: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        factor = 1 if self.quant_config.weight_bits == 8 else 2

        # Fused gate_up_proj (column parallel)
        num_groups_per_col = (
            hidden_size // self.quant_config.group_size
            if self.quant_config.group_size != -1
            else 1
        )
        w13_scale = Parameter(
            torch.empty(
                num_experts,
                2 * intermediate_size_per_partition,
                num_groups_per_col,
                dtype=params_dtype,
            ),
            requires_grad=False,
        )
        layer.register_parameter("w13_scale", w13_scale)

        w13_weight = RTNParameter(
            data=torch.empty(
                num_experts,
                2 * intermediate_size_per_partition // factor,
                hidden_size,
                dtype=torch.uint8,
            ),
            scale=w13_scale,
            quant_config=self.quant_config,
        )
        layer.register_parameter("w13_weight", w13_weight)
        set_weight_attrs(w13_weight, extra_weight_attrs)

        # down_proj (row parallel)
        num_groups_per_col = (
            intermediate_size_per_partition // self.quant_config.group_size
            if self.quant_config.group_size != -1
            else 1
        )
        w2_scale = Parameter(
            torch.zeros(
                num_experts, hidden_size, num_groups_per_col, dtype=params_dtype
            ),
            requires_grad=False,
        )
        layer.register_parameter("w2_scale", w2_scale)

        w2_weight = RTNParameter(
            data=torch.empty(
                num_experts,
                hidden_size // factor,
                intermediate_size_per_partition,
                dtype=torch.uint8,
            ),
            scale=w2_scale,
            quant_config=self.quant_config,
        )
        layer.register_parameter("w2_weight", w2_weight)
        set_weight_attrs(w2_weight, extra_weight_attrs)

    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        """Repack weights and scales for Marlin kernels."""
        weight_bits = self.quant_config.weight_bits

        w13_weight, w13_scale = repack_weights(
            layer.w13_weight, layer.w13_scale, weight_bits
        )
        replace_parameter(layer, "w13_weight", w13_weight)
        replace_parameter(layer, "w13_scale", w13_scale)

        w2_weight, w2_scale = repack_weights(
            layer.w2_weight, layer.w2_scale, weight_bits
        )
        replace_parameter(layer, "w2_weight", w2_weight)
        replace_parameter(layer, "w2_scale", w2_scale)

        init_workspace(layer.w13_weight.device)

    def get_fused_moe_quant_config(
        self, layer: torch.nn.Module
    ) -> FusedMoEQuantConfig | None:
        return None

    def apply(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        router_logits: torch.Tensor,
        top_k: int,
        renormalize: bool,
        use_grouped_topk: bool = False,
        topk_group: int | None = None,
        num_expert_group: int | None = None,
        global_num_experts: int = -1,
        expert_map: torch.Tensor | None = None,
        custom_routing_function: Callable | None = None,
        scoring_func: str = "softmax",
        routed_scaling_factor: float = 1.0,
        e_score_correction_bias: torch.Tensor | None = None,
        apply_router_weight_on_input: bool = False,
        activation: str = "silu",
        enable_eplb: bool = False,
        expert_load_view: torch.Tensor | None = None,
        logical_to_physical_map: torch.Tensor | None = None,
        logical_replica_count: torch.Tensor | None = None,
    ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
        assert self.fused_experts is None

        if enable_eplb:
            raise NotImplementedError("EPLB not supported for `RTNMoEMethod` yet.")

        topk_weights, topk_ids, _ = FusedMoE.select_experts(
            hidden_states=x,
            router_logits=router_logits,
            use_grouped_topk=use_grouped_topk,
            top_k=top_k,
            renormalize=renormalize,
            topk_group=topk_group,
            num_expert_group=num_expert_group,
            custom_routing_function=custom_routing_function,
            scoring_func=scoring_func,
            routed_scaling_factor=routed_scaling_factor,
            e_score_correction_bias=e_score_correction_bias,
            indices_type=self.topk_indices_dtype,
        )

        return fused_marlin_moe(
            x,
            layer.w13_weight,
            layer.w2_weight,
            getattr(layer, "w13_bias", None),
            getattr(layer, "w2_bias", None),
            layer.w13_scale,
            layer.w2_scale,
            router_logits,
            topk_weights,
            topk_ids,
            quant_type_id=self.quant_config.quant_type.id,
            apply_router_weight_on_input=apply_router_weight_on_input,
            global_num_experts=global_num_experts,
            expert_map=expert_map,
            workspace=workspace,
        )


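# Worked example (illustrative): with num_bits=8 and group_size=128, a
# [4096, 4096] tensor is viewed as 131072 groups of 128 values. For a group
# whose largest magnitude is 0.51, scale = 0.51 * 2 / 255 = 0.004; each value
# is divided by the scale, rounded, shifted by 128 and clamped to [0, 255],
# giving uint8 codes plus one scale per group. With num_bits=4 the range is
# [0, 15], the shift is 8, and pairs of codes are packed into single bytes.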
def rtn_quantize(
    tensor: torch.Tensor, num_bits: int, group_size: int
) -> tuple[torch.Tensor, torch.Tensor]:
    """Quantize a tensor using per-group static scaling factor.

    Args:
        tensor: The input tensor.
        num_bits: Target precision for the result (supported values are
            8 or 4).
        group_size: Quantization granularity.
            If equal to -1, each row in the input tensor is treated
            as one group.
    """
    batch_present = len(tensor.shape) == 3
    if not batch_present:
        tensor = tensor.unsqueeze(0)

    q_range = 2**num_bits
    num_groups = (
        tensor.shape[1] * tensor.shape[2] // group_size
        if group_size != -1
        else tensor.shape[1]
    )
    """Calculate a scaling factor per input group.
    """
    input_flat = tensor.reshape(tensor.shape[0], num_groups, -1)
    input_min = torch.min(input_flat, dim=2, keepdim=True)[0]
    input_max = torch.max(input_flat, dim=2, keepdim=True)[0]
    input_max_abs = torch.max(input_min.abs(), input_max.abs())
    scale = input_max_abs * 2.0 / (q_range - 1)
    """Scale each input group, round to the nearest integer, shift
    the range and truncate.
    """
    scaled_input = input_flat / scale
    scaled_input = scaled_input.round()
    scaled_input += q_range // 2
    scaled_input = scaled_input.clamp(0, q_range - 1)

    scale = scale.reshape(tensor.shape[0], tensor.shape[1], -1).contiguous()
    inputs_q = scaled_input.reshape(tensor.shape).to(torch.uint8)
    inputs_q = inputs_q.contiguous()

    if num_bits == 4:
        """Pack two 4-bit values into each byte.
        """
        inputs_q = (inputs_q[:, :, 1::2] << 4) | (inputs_q[:, :, ::2] & 0xF)
        inputs_q = inputs_q.reshape(
            tensor.shape[0], tensor.shape[1] // 2, tensor.shape[2]
        )
        inputs_q = inputs_q.contiguous()

    if not batch_present:
        inputs_q = inputs_q.squeeze(0)
        scale = scale.squeeze(0)

    return inputs_q, scale


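# The inverse transform infers the precision from the shapes: if the stored
# input dimension matches scale.size(1) the codes are 8-bit, otherwise they are
# 4-bit and two codes are unpacked from each byte before the per-group scales
# are applied.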
def rtn_dequantize(tensor: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
    """Dequantize a tensor using per-group static scaling factors.

    Args:
        tensor: The input tensor.
        scale: The tensor with per-group scale factors.
    """
    batch_present = len(tensor.shape) == 3
    if not batch_present:
        tensor = tensor.unsqueeze(0)
        scale = scale.unsqueeze(0)

    num_groups = scale.size(1) * scale.size(2)
    batch, input_dim, output_dim = tensor.shape

    num_bits = 8 if input_dim == scale.size(1) else 4
    q_range = 2**num_bits
    if num_bits == 4:
        input_dim *= 2

    data = torch.empty(
        (batch, input_dim, output_dim), dtype=scale.dtype, device=tensor.device
    )

    if num_bits == 8:
        data.copy_(tensor)
        data -= q_range // 2
    else:
        """Unpack two 4-bit values from each byte.
        """
        tensor = tensor.reshape(batch, input_dim, output_dim // 2)
        for i in range(2):
            data[:, :, i::2] = ((tensor << 4 * (1 - i)) >> 4).to(
                torch.int8
            ) - q_range // 2
    """Scale each input group with its scaling factor.
    """
    scale = scale.reshape(batch, num_groups, -1)
    data = data.reshape(batch, num_groups, -1)
    data = torch.mul(data, scale)

    input_deq = data.reshape((batch, input_dim, output_dim)).contiguous()
    if not batch_present:
        input_deq = input_deq.squeeze(0)

    return input_deq


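# _get_perms() builds the permutation tables used when repacking for Marlin:
# `_perm` reorders each 16x16 weight tile into the interleaved layout the
# kernel's tensor-core fragments expect, while `_scale_perm` (grouped scales)
# and `_scale_perm_single` (one group per output channel) reorder the scales
# to match.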
def _get_perms():
    perm = []
    for i in range(32):
        perm1 = []
        col = i // 4
        for block in [0, 1]:
            for row in [
                2 * (i % 4),
                2 * (i % 4) + 1,
                2 * (i % 4 + 4),
                2 * (i % 4 + 4) + 1,
            ]:
                perm1.append(16 * row + col + 8 * block)
        for j in range(4):
            perm.extend([p + 256 * j for p in perm1])

    perm_arr = np.array(perm)
    interleave = np.array([0, 2, 4, 6, 1, 3, 5, 7])
    perm_arr = perm_arr.reshape((-1, 8))[:, interleave].ravel()
    perm_tensor = torch.from_numpy(perm_arr)
    scale_perm = []
    for i in range(8):
        scale_perm.extend([i + 8 * j for j in range(8)])
    scale_perm_single = []
    for i in range(4):
        scale_perm_single.extend([2 * i + j for j in [0, 1, 8, 9, 16, 17, 24, 25]])
    return perm_tensor, scale_perm, scale_perm_single


_perm, _scale_perm, _scale_perm_single = _get_perms()


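# pack_for_marlin() takes the unpacked per-group codes produced by rtn_quantize
# (one code per byte, shaped [batch, out, in]) plus their scales, transposes
# them to Marlin's in x out orientation, applies the tile and scale permutations
# above, and packs 4-bit codes two to a byte (8-bit codes are only reordered).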
def pack_for_marlin(weight, scale, qbits):
    batch = weight.shape[0]

    n = weight.size(1)
    k = weight.size(2)
    groupsize = k // scale.size(2)

    tile = 16
    s = scale.permute(0, 2, 1)  # transpose
    w = weight.permute(0, 2, 1)  # transpose
    if groupsize != k:
        w = w.reshape((batch, -1, groupsize, n))
        w = w.permute(0, 2, 1, 3)
        w = w.reshape((batch, groupsize, -1))
        s = s.reshape((batch, 1, -1))

    if groupsize != k:
        w = w.reshape((batch, groupsize, -1, n))
        w = w.permute(0, 2, 1, 3)
        w = w.reshape((batch, k, n)).contiguous()
        s = s.reshape((batch, -1, len(_scale_perm)))[:, :, _scale_perm]
    else:
        s = s.reshape((batch, -1, len(_scale_perm_single)))[:, :, _scale_perm_single]
    s = s.reshape((batch, -1, n)).contiguous()
    w = w.reshape((batch, k // tile, tile, n // tile, tile))
    w = w.permute((0, 1, 3, 2, 4))
    w = w.reshape((batch, k // tile, n * tile))
    res = w
    res = res.reshape((batch, -1, _perm.numel()))[:, :, _perm].reshape(res.shape)
    if qbits == 4:
        q = torch.zeros(
            (batch, res.shape[1], res.shape[2] // 2), dtype=torch.int8, device=w.device
        )
        for i in range(2):
            q |= res[:, :, i::2] << 4 * i
        q = q.reshape(batch, -1, n).contiguous()
    else:
        q = res.clone()
        q[:, :, 2::8] = res[:, :, 4::8]
        q[:, :, 3::8] = res[:, :, 5::8]
        q[:, :, 4::8] = res[:, :, 2::8]
        q[:, :, 5::8] = res[:, :, 3::8]
        q = q.reshape(batch, -1, n).to(torch.int8).contiguous()

    return q, s


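# repack_8bit_into_32bit() folds every four consecutive bytes along the last
# dimension into one little-endian int32 word; illustratively, bytes
# [0x01, 0x02, 0x03, 0x04] become 0x04030201.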
def repack_8bit_into_32bit(input):
    output = torch.zeros(
        (input.shape[0], input.shape[1], input.shape[2] // 4),
        dtype=torch.int32,
        device=input.device,
    )
    for i in range(4):
        output |= (input[:, :, i::4] & 0xFF).to(torch.int32) << 8 * i

    return output


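# repack_weights() is the glue used by process_weights_after_loading(): it
# unpacks RTN's 4-bit storage back to one code per byte when needed, runs
# pack_for_marlin(), and packs the result into the int32 layout (one row per
# 16-element input tile) consumed by the Marlin kernels.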
def repack_weights(qweight, scale, weight_bits):
    batch_present = len(qweight.shape) == 3
    if not batch_present:
        qweight = qweight.unsqueeze(0)
        scale = scale.unsqueeze(0)

    if weight_bits == 4:
        """Unpack two 4-bit values from each byte.
        """
        qweight_unpacked = torch.empty(
            (qweight.shape[0], qweight.shape[1] * 2, qweight.shape[2]),
            dtype=torch.uint8,
            device=qweight.device,
        )
        for i in range(2):
            qweight_unpacked[:, :, i::2] = ((qweight << 4 * (1 - i)) >> 4).reshape(
                qweight.shape[0], qweight.shape[1] * 2, qweight.shape[2] // 2
            )
    else:
        qweight_unpacked = qweight

    qweight_packed, scale_packed = pack_for_marlin(qweight_unpacked, scale, weight_bits)
    """Marlin kernels expect tensors in int32 format in a certain shape
    """
    qweight_repacked = repack_8bit_into_32bit(qweight_packed.to(torch.uint8))
    qweight_reshaped = qweight_repacked.reshape(
        qweight.shape[0], qweight.shape[2] // 16, -1
    )
    if not batch_present:
        qweight_reshaped = qweight_reshaped.squeeze(0)
        scale_packed = scale_packed.squeeze(0)

    return qweight_reshaped, scale_packed


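# A single Marlin scratch workspace is allocated lazily on first use and shared
# by all RTN layers via the module-level `workspace` above.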
def init_workspace(device):
    global workspace
    if workspace is None:
        workspace = marlin_make_workspace_new(device, 4)