mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-15 17:56:08 +08:00
611 lines
22 KiB
Python
611 lines
22 KiB
Python
# SPDX-License-Identifier: Apache-2.0
|
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
|
|
|
from typing import Any, Callable, Optional, Union
|
|
|
|
import torch
|
|
|
|
from vllm.model_executor.layers.fused_moe.layer import (FusedMoE,
|
|
FusedMoEMethodBase)
|
|
from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase,
|
|
UnquantizedLinearMethod,
|
|
set_weight_attrs)
|
|
from vllm.model_executor.layers.quantization import QuantizationMethods
|
|
from vllm.model_executor.layers.quantization.base_config import (
|
|
QuantizationConfig)
|
|
from vllm.platforms import current_platform
|
|
from vllm.utils import direct_register_custom_op
|
|
|
|
|
|
class BitsAndBytesConfig(QuantizationConfig):
    """Config class for BitsAndBytes Quantization.

    Reference: https://arxiv.org/abs/2305.14314

    The constructor arguments carry the same names as the keys that
    ``from_config`` reads from a quantization config dictionary.
    """

    def __init__(
        self,
        load_in_8bit: bool = False,
        load_in_4bit: bool = True,
        bnb_4bit_compute_dtype: str = "float32",
        bnb_4bit_quant_storage: str = "uint8",
        bnb_4bit_quant_type: str = "fp4",
        bnb_4bit_use_double_quant: bool = False,
        llm_int8_enable_fp32_cpu_offload: bool = False,
        llm_int8_has_fp16_weight: bool = False,
        llm_int8_skip_modules: Optional[list[str]] = None,
        llm_int8_threshold: float = 6.0,
    ) -> None:
        """Store the BitsAndBytes options on the instance.

        Raises:
            ValueError: If ``bnb_4bit_quant_storage`` is not ``"uint8"``.
        """
        super().__init__()
        self.load_in_8bit = load_in_8bit
        self.load_in_4bit = load_in_4bit
        self.bnb_4bit_compute_dtype = bnb_4bit_compute_dtype
        self.bnb_4bit_quant_storage = bnb_4bit_quant_storage
        self.bnb_4bit_quant_type = bnb_4bit_quant_type
        self.bnb_4bit_use_double_quant = bnb_4bit_use_double_quant
        self.llm_int8_enable_fp32_cpu_offload = llm_int8_enable_fp32_cpu_offload
        self.llm_int8_has_fp16_weight = llm_int8_has_fp16_weight
        # ``None`` (or any falsy value) is normalised to a fresh empty list.
        self.llm_int8_skip_modules = llm_int8_skip_modules or []
        self.llm_int8_threshold = llm_int8_threshold

        if self.bnb_4bit_quant_storage not in ["uint8"]:
            raise ValueError("Unsupported bnb_4bit_quant_storage: "
                             f"{self.bnb_4bit_quant_storage}")

    def __repr__(self) -> str:
        return (f"BitsAndBytesConfig(load_in_8bit={self.load_in_8bit}, "
                f"load_in_4bit={self.load_in_4bit}, "
                f"bnb_4bit_compute_dtype={self.bnb_4bit_compute_dtype}, "
                f"bnb_4bit_quant_storage={self.bnb_4bit_quant_storage}, "
                f"bnb_4bit_quant_type={self.bnb_4bit_quant_type}, "
                f"llm_int8_skip_modules={self.llm_int8_skip_modules})")

    @classmethod
    def get_name(cls) -> QuantizationMethods:
        """Return the name of this quantization method."""
        return "bitsandbytes"

    @classmethod
    def get_supported_act_dtypes(cls) -> list[torch.dtype]:
        """Return the activation dtypes accepted by this method."""
        return [torch.float32, torch.float16, torch.bfloat16]

    @classmethod
    def get_min_capability(cls) -> int:
        """Return the minimum capability value required (70)."""
        return 70

    @staticmethod
    def get_config_filenames() -> list[str]:
        """Return the extra config filenames to look for (none)."""
        return []

    @classmethod
    def from_config(cls, config: dict[str, Any]) -> "BitsAndBytesConfig":
        """Build a config from a dictionary of BitsAndBytes options.

        Every option is optional: when a key is missing (``get_from_keys``
        raises ``ValueError``) or its value is ``None``, the default listed
        below is used instead.

        Args:
            config: Mapping holding the quantization options.

        Returns:
            A new ``BitsAndBytesConfig`` populated from ``config``.
        """

        def get_safe_value(config, keys, default_value=None):
            try:
                value = cls.get_from_keys(config, keys)
            except ValueError:
                return default_value
            return value if value is not None else default_value

        # Option name -> fallback value. Built per call so the mutable
        # list default is never shared between configs.
        defaults: dict[str, Any] = {
            "load_in_8bit": False,
            "load_in_4bit": True,
            "bnb_4bit_compute_dtype": "float32",
            "bnb_4bit_quant_storage": "uint8",
            "bnb_4bit_quant_type": "fp4",
            "bnb_4bit_use_double_quant": False,
            "llm_int8_enable_fp32_cpu_offload": False,
            "llm_int8_has_fp16_weight": False,
            "llm_int8_skip_modules": [],
            "llm_int8_threshold": 6.0,
        }
        options = {
            key: get_safe_value(config, [key], default_value=default)
            for key, default in defaults.items()
        }
        return cls(**options)

    def get_quant_method(
        self, layer: torch.nn.Module, prefix: str
    ) -> Optional[Union["LinearMethodBase", "BitsAndBytesMoEMethod"]]:
        """Pick the quantization method for ``layer``.

        Args:
            layer: The module to be quantized.
            prefix: Dotted name of the layer, matched against
                ``llm_int8_skip_modules``.

        Returns:
            ``UnquantizedLinearMethod`` for a skipped ``LinearBase`` layer,
            ``BitsAndBytesLinearMethod`` for any other ``LinearBase`` layer,
            ``BitsAndBytesMoEMethod`` for a ``FusedMoE`` layer, and ``None``
            for every other layer type.
        """
        if isinstance(layer, LinearBase):
            if is_layer_skipped_bnb(prefix, self.llm_int8_skip_modules):
                return UnquantizedLinearMethod()
            return BitsAndBytesLinearMethod(self)
        elif isinstance(layer, FusedMoE):
            return BitsAndBytesMoEMethod(self)
        return None
|
|
|
|
|
|
def is_layer_skipped_bnb(prefix: str, llm_int8_skip_modules: list[str]):
    """Tell whether the layer named ``prefix`` should stay unquantized.

    A layer is skipped when a skip entry equals one of the dot-separated
    components of ``prefix``, or equals one of its leading dotted paths
    (``a``, ``a.b``, ``a.b.c`` ... for ``a.b.c``).
    """
    parts = prefix.split('.')
    skip_names = set(llm_int8_skip_modules)

    # A skip entry that is exactly one path component.
    if any(part in skip_names for part in parts):
        return True

    # A skip entry that is exactly a leading dotted path of the prefix.
    for end in range(1, len(parts) + 1):
        if ".".join(parts[:end]) in skip_names:
            return True
    return False
|
|
|
|
|
|
def calculate_quant_ratio(dtype):
    """Return the bit width of ``dtype`` divided by the bit width of uint8.

    Floating-point dtypes are measured with ``torch.finfo`` and all other
    dtypes with ``torch.iinfo``; the division is an integer division.
    """
    type_info = (torch.finfo(dtype)
                 if dtype.is_floating_point else torch.iinfo(dtype))
    uint8_bits = torch.iinfo(torch.uint8).bits
    return type_info.bits // uint8_bits
|
|
|
|
|
|
class BitsAndBytesLinearMethod(LinearMethodBase):
    """Linear method for BitsAndBytes.

    Args:
        quant_config: The BitsAndBytes quantization config.
    """

    @staticmethod
    def _parse_version(version: str) -> tuple[int, ...]:
        """Return the leading numeric components of a version string.

        At most the first three dot-separated pieces are read and parsing
        stops at the first piece that does not start with a digit, e.g.
        ``"0.46.1.dev0"`` -> ``(0, 46, 1)``.
        """
        import re

        release: list[int] = []
        for piece in version.split(".")[:3]:
            digits = re.match(r"\d+", piece)
            if digits is None:
                break
            release.append(int(digits.group()))
        return tuple(release)

    def __init__(self, quant_config: BitsAndBytesConfig):
        """Check that a recent enough bitsandbytes is installed.

        Raises:
            ImportError: If bitsandbytes is missing or older than 0.46.1.
        """
        try:
            import bitsandbytes

            # Compare numerically: a plain string comparison would order
            # "0.100.0" before "0.46.1".
            installed = self._parse_version(bitsandbytes.__version__)
            if installed < (0, 46, 1):
                raise ImportError("bitsandbytes version is wrong. Please "
                                  "install bitsandbytes>=0.46.1.")
        except ImportError as err:
            raise ImportError("Please install bitsandbytes>=0.46.1 via "
                              "`pip install bitsandbytes>=0.46.1` to use "
                              "bitsandbytes quantizer.") from err

        self.quant_config = quant_config

    def create_weights(self, layer: torch.nn.Module,
                       input_size_per_partition: int,
                       output_partition_sizes: list[int], input_size: int,
                       output_size: int, params_dtype: torch.dtype,
                       **extra_weight_attrs):
        """Create the quantized ``weight`` parameter and register it.

        An ``Int8Params`` tensor is created when ``load_in_8bit`` is set,
        otherwise a packed ``uint8`` column of shape
        ``(total_size // quant_ratio, 1)``.

        Raises:
            ValueError: In the 4-bit path, if the number of weight elements
                is not a multiple of the quantization ratio.
        """
        from bitsandbytes.nn import Int8Params

        def create_qweight_for_8bit():
            qweight = Int8Params(
                data=torch.empty(sum(output_partition_sizes),
                                 input_size_per_partition,
                                 dtype=torch.int8),
                has_fp16_weights=self.quant_config.llm_int8_has_fp16_weight,
                requires_grad=False)
            set_weight_attrs(
                qweight, {
                    "input_dim": 0,
                    "output_dim": 0,
                    "pack_factor": 1,
                    "use_bitsandbytes_8bit": True,
                    # Counts calls to _apply_8bit_weight for this weight.
                    "generation": 0
                })
            return qweight

        def create_qweight_for_4bit():
            quant_ratio = calculate_quant_ratio(params_dtype)

            total_size = input_size_per_partition * sum(output_partition_sizes)
            if total_size % quant_ratio != 0:
                raise ValueError(
                    "The input size is not aligned with the quantized "
                    "weight shape.")

            qweight = torch.nn.Parameter(torch.empty(total_size // quant_ratio,
                                                     1,
                                                     dtype=torch.uint8),
                                         requires_grad=False)
            set_weight_attrs(
                qweight, {
                    "input_dim": 0,
                    "output_dim": 0,
                    "pack_factor": quant_ratio,
                    "use_bitsandbytes_4bit": True
                })
            return qweight

        if self.quant_config.load_in_8bit:
            qweight = create_qweight_for_8bit()
        else:
            qweight = create_qweight_for_4bit()
        # Enable parameters to have the same name as in the BNB
        # checkpoint format.
        layer.register_parameter("weight", qweight)
        set_weight_attrs(qweight, extra_weight_attrs)

    def apply(self,
              layer: torch.nn.Module,
              x: torch.Tensor,
              bias: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Run the quantized linear layer on ``x``.

        Dispatches to the 8-bit or the 4-bit implementation depending on
        ``quant_config.load_in_8bit``.
        """
        if self.quant_config.load_in_8bit:
            return self._apply_8bit_weight(layer, x, bias)
        else:
            return self._apply_4bit_weight(layer, x, bias)

    def _apply_8bit_weight(
            self,
            layer: torch.nn.Module,
            x: torch.Tensor,
            bias: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Multiply ``x`` by the 8-bit weight, one shard at a time.

        Reads ``bnb_shard_offsets``, ``bnb_quant_state``, ``matmul_state``
        and ``generation`` from ``layer.weight`` (presumably attached by
        the weight loader; they are not set in this class except for
        ``generation``). ``generation`` is incremented on every call.

        Returns:
            The result cast back to the dtype of ``x``, with the leading
            dimensions of ``x`` restored and ``bias`` added when given.
        """
        # only load the bitsandbytes module when needed
        from bitsandbytes import MatmulLtState, matmul

        original_type = x.dtype
        original_shape = x.shape
        reshape_after_matmul = False
        if x.ndim > 2:
            # Flatten every leading dimension into one row dimension.
            x = x.reshape(-1, x.size(-1))
            reshape_after_matmul = True
        bf_x = x.to(torch.bfloat16)

        qweight = layer.weight
        offsets = qweight.bnb_shard_offsets
        quant_states = qweight.bnb_quant_state
        matmul_states = qweight.matmul_state
        generation = qweight.generation

        out_dim_0 = x.shape[0]
        out_dim_1 = sum(quant_state.shape[0]
                        for quant_state in quant_states.values())
        out = torch.empty(out_dim_0,
                          out_dim_1,
                          dtype=torch.float16,
                          device=x.device)

        # The matmul input is the same for every shard.
        new_x = bf_x.unsqueeze(0)

        current_index = 0
        for i in range(len(quant_states)):
            output_size = quant_states[i].shape[0]

            # in profile_run or the first generation of inference,
            # create new matmul_states
            if generation in (0, 1):
                matmul_states[i] = MatmulLtState()
                matmul_states[i].CB = qweight[offsets[i]:offsets[i + 1]]
                matmul_states[i].SCB = quant_states[i].to(x.device)
                matmul_states[i].threshold = (
                    self.quant_config.llm_int8_threshold)
                matmul_states[i].has_fp16_weights = (
                    self.quant_config.llm_int8_has_fp16_weight)
                matmul_states[i].is_training = False
                if matmul_states[i].threshold > 0.0 and not matmul_states[
                        i].has_fp16_weights:
                    matmul_states[i].use_pool = True

            out[:, current_index:current_index + output_size] = matmul(
                new_x,
                qweight[offsets[i]:offsets[i + 1]],
                state=matmul_states[i])

            current_index += output_size

            # only update the matmul_states if it is not profile_run
            if (generation > 0
                    and not self.quant_config.llm_int8_has_fp16_weight
                    and matmul_states[i].CB is not None
                    and matmul_states[i].CxB is not None):
                del matmul_states[i].CB
                qweight[offsets[i]:offsets[i + 1]] = matmul_states[i].CxB

        out = out.to(original_type)

        if reshape_after_matmul:
            out = out.view(*original_shape[:-1], out.size(-1))

        if bias is not None:
            out += bias

        qweight.generation += 1

        return out

    def _apply_4bit_weight(
            self,
            layer: torch.nn.Module,
            x: torch.Tensor,
            bias: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Multiply ``x`` by the 4-bit weight through ``apply_bnb_4bit``.

        Reads ``bnb_quant_state`` and ``bnb_shard_offsets`` from
        ``layer.weight`` (presumably attached by the weight loader).

        Returns:
            The result cast back to the dtype of ``x``, with the leading
            dimensions of ``x`` restored and ``bias`` added when given.
        """
        original_type = x.dtype
        original_shape = x.shape
        reshape_after_matmul = False
        if x.ndim > 2:
            # Flatten every leading dimension into one row dimension.
            x = x.reshape(-1, x.size(-1))
            reshape_after_matmul = True
        bf_x = x.to(torch.bfloat16)

        qweight = layer.weight
        quant_states = qweight.bnb_quant_state
        offsets = qweight.bnb_shard_offsets

        out_dim_0 = x.shape[0]
        out_dim_1 = sum(quant_state.shape[0]
                        for quant_state in quant_states.values())
        out = torch.empty(out_dim_0,
                          out_dim_1,
                          dtype=torch.bfloat16,
                          device=x.device)
        # The op writes its result into ``out`` in place.
        apply_bnb_4bit(bf_x, qweight, offsets, out)
        out = out.to(original_type)

        if reshape_after_matmul:
            out = out.view(*original_shape[:-1], out.size(-1))

        if bias is not None:
            out += bias

        return out
|
|
|
|
|
|
def _apply_bnb_4bit(
    x: torch.Tensor,
    weight: torch.Tensor,
    offsets: torch.Tensor,
    out: torch.Tensor,
) -> None:
    """Fill ``out`` with the 4-bit matmul of ``x`` against every shard.

    Shard ``i`` is ``weight[offsets[i]:offsets[i + 1]]`` with quant state
    ``weight.bnb_quant_state[i]``; its result is written to the next
    ``quant_state.shape[0]`` columns of ``out``. Nothing is returned.
    """
    # only load the bitsandbytes module when needed
    from bitsandbytes import matmul_4bit

    shard_states = weight.bnb_quant_state
    col_start = 0
    for shard_idx in range(len(shard_states)):
        state = shard_states[shard_idx]
        col_end = col_start + state.shape[0]
        shard = weight[offsets[shard_idx]:offsets[shard_idx + 1]]
        # It is more efficient to use out kwarg like
        # matmul_4bit(..., out = ...). Infeasible now due to the bug
        # https://github.com/TimDettmers/bitsandbytes/issues/1235.
        # Need to change after the bug is fixed.
        out[:, col_start:col_end] = matmul_4bit(x, shard.t(), state)
        col_start = col_end
|
|
|
|
|
|
def _apply_bnb_4bit_fake(
|
|
x: torch.Tensor,
|
|
weight: torch.Tensor,
|
|
offsets: torch.Tensor,
|
|
out: torch.Tensor,
|
|
) -> None:
|
|
return
|
|
|
|
|
|
# Register the 4-bit matmul as the custom op ``apply_bnb_4bit`` and expose it
# under a module-level name. ``out`` is declared as mutated because the op
# writes its result there instead of returning it.
# No try/except is needed: re-raising a caught AttributeError unchanged is
# the same as letting it propagate.
direct_register_custom_op(op_name="apply_bnb_4bit",
                          op_func=_apply_bnb_4bit,
                          mutates_args=["out"],
                          fake_impl=_apply_bnb_4bit_fake,
                          dispatch_key=current_platform.dispatch_key)
apply_bnb_4bit = torch.ops.vllm.apply_bnb_4bit
|
|
|
|
|
|
class BitsAndBytesMoEMethod(FusedMoEMethodBase):
    """MoE method for BitsAndBytes.

    Only the 4-bit path is implemented; the 8-bit helpers raise
    ``NotImplementedError``.

    Args:
        quant_config: The BitsAndBytes quantization config.
    """

    @staticmethod
    def _parse_version(version: str) -> tuple[int, ...]:
        """Return the leading numeric components of a version string.

        At most the first three dot-separated pieces are read and parsing
        stops at the first piece that does not start with a digit, e.g.
        ``"0.45.3.dev0"`` -> ``(0, 45, 3)``.
        """
        import re

        release: list[int] = []
        for piece in version.split(".")[:3]:
            digits = re.match(r"\d+", piece)
            if digits is None:
                break
            release.append(int(digits.group()))
        return tuple(release)

    def __init__(self, quant_config: BitsAndBytesConfig):
        """Check that a recent enough bitsandbytes is installed.

        Raises:
            ImportError: If bitsandbytes is missing or older than 0.45.3.
        """
        try:
            import bitsandbytes

            # Compare numerically: a plain string comparison would order
            # "0.100.0" before "0.45.3".
            installed = self._parse_version(bitsandbytes.__version__)
            if installed < (0, 45, 3):
                raise ImportError("bitsandbytes version is wrong. Please "
                                  "install bitsandbytes>=0.45.3.")
        except ImportError as err:
            raise ImportError("Please install bitsandbytes>=0.45.3 via "
                              "`pip install bitsandbytes>=0.45.3` to use "
                              "bitsandbytes quantizer.") from err
        self.topk_indices_dtype = None
        self.quant_config = quant_config

    def create_weights(
        self,
        layer: torch.nn.Module,
        num_experts: int,
        hidden_size: int,
        intermediate_size_per_partition: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        """Create the expert weights, 8-bit or 4-bit per the config."""
        if self.quant_config.load_in_8bit:
            call_fun = self._create_weights_8bit
        else:
            call_fun = self._create_weights_4bit
        call_fun(
            layer,
            num_experts,
            hidden_size,
            intermediate_size_per_partition,
            params_dtype,
            **extra_weight_attrs,
        )

    def apply(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        router_logits: torch.Tensor,
        top_k: int,
        renormalize: bool,
        use_grouped_topk: bool = False,
        topk_group: Optional[int] = None,
        num_expert_group: Optional[int] = None,
        global_num_experts: int = -1,
        expert_map: Optional[torch.Tensor] = None,
        custom_routing_function: Optional[Callable] = None,
        scoring_func: str = "softmax",
        e_score_correction_bias: Optional[torch.Tensor] = None,
        apply_router_weight_on_input: bool = False,
        activation: str = "silu",
        enable_eplb: bool = False,
        expert_load_view: Optional[torch.Tensor] = None,
        logical_to_physical_map: Optional[torch.Tensor] = None,
        logical_replica_count: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Route ``x`` to experts and run them on dequantized weights.

        Experts are chosen with ``FusedMoE.select_experts``, the expert
        weights are dequantized, and both are handed to ``fused_experts``.
        ``expert_load_view``, ``logical_to_physical_map`` and
        ``logical_replica_count`` are accepted but not used here.

        Raises:
            NotImplementedError: If ``enable_eplb`` is set, or (from the
                8-bit helper) if ``load_in_8bit`` is set.
        """
        from vllm.model_executor.layers.fused_moe import fused_experts

        if enable_eplb:
            raise NotImplementedError(
                "EPLB not supported for `BitsAndBytesMoEMethod` yet.")
        topk_weights, topk_ids = FusedMoE.select_experts(
            hidden_states=x,
            router_logits=router_logits,
            use_grouped_topk=use_grouped_topk,
            top_k=top_k,
            renormalize=renormalize,
            topk_group=topk_group,
            num_expert_group=num_expert_group,
            custom_routing_function=custom_routing_function,
            scoring_func=scoring_func,
            e_score_correction_bias=e_score_correction_bias,
            indices_type=self.topk_indices_dtype)
        if self.quant_config.load_in_8bit:
            w13, w2 = self._apply_8bit_dequant(layer)
        else:
            w13, w2 = self._apply_4bit_dequant(layer)
        return fused_experts(
            hidden_states=x,
            w1=w13,
            w2=w2,
            topk_weights=topk_weights,
            topk_ids=topk_ids,
            inplace=True,
            activation=activation,
            apply_router_weight_on_input=apply_router_weight_on_input,
            global_num_experts=global_num_experts,
            expert_map=expert_map,
        )

    def _create_weights_4bit(
        self,
        layer: torch.nn.Module,
        num_experts: int,
        hidden_size: int,
        intermediate_size_per_partition: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        """Register packed ``uint8`` ``w13_weight`` and ``w2_weight``.

        Each parameter has shape ``(num_experts, packed_size, 1)`` and
        records its unpacked shape in the ``experts_shape`` attribute,
        which ``_apply_4bit_dequant`` uses to reshape the result.
        """
        quant_ratio = calculate_quant_ratio(params_dtype)
        # Fused gate_up_proj (column parallel)
        w13_total_size = (hidden_size * 2 *
                          intermediate_size_per_partition) // quant_ratio
        w13_qweight = torch.nn.Parameter(
            torch.empty(
                num_experts,
                w13_total_size,
                1,
                dtype=torch.uint8,
            ),
            requires_grad=False,
        )
        layer.register_parameter("w13_weight", w13_qweight)
        set_weight_attrs(w13_qweight, extra_weight_attrs)
        set_weight_attrs(
            w13_qweight,
            {
                "num_experts": num_experts,
                "input_dim": hidden_size,
                "output_dim": 2 * intermediate_size_per_partition,
                "experts_shape": (
                    num_experts,
                    intermediate_size_per_partition * 2,
                    hidden_size,
                ),
                "pack_factor": quant_ratio,
                "use_bitsandbytes_4bit": True,
            },
        )
        # down_proj (row parallel)
        w2_total_size = (hidden_size *
                         intermediate_size_per_partition) // quant_ratio
        w2_qweight = torch.nn.Parameter(
            torch.empty(
                num_experts,
                w2_total_size,
                1,
                dtype=torch.uint8,
            ),
            requires_grad=False,
        )
        set_weight_attrs(
            w2_qweight,
            {
                "num_experts": num_experts,
                "input_dim": intermediate_size_per_partition,
                "output_dim": hidden_size,
                "experts_shape": (
                    num_experts,
                    hidden_size,
                    intermediate_size_per_partition,
                ),
                "pack_factor": quant_ratio,
                "use_bitsandbytes_4bit": True,
            },
        )
        layer.register_parameter("w2_weight", w2_qweight)
        set_weight_attrs(w2_qweight, extra_weight_attrs)

    def _create_weights_8bit(
        self,
        layer: torch.nn.Module,
        num_experts: int,
        hidden_size: int,
        intermediate_size_per_partition: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        """Not implemented: 8-bit expert weights are unsupported."""
        raise NotImplementedError

    def _apply_4bit_dequant(
            self, layer: torch.nn.Module) -> tuple[torch.Tensor, torch.Tensor]:
        """Dequantize ``w13_weight`` and ``w2_weight`` of ``layer``.

        Each packed weight is flattened to a single column, dequantized
        with its ``bnb_quant_state`` attribute (presumably attached by the
        weight loader) and reshaped to its ``experts_shape``.

        Returns:
            The dequantized ``(w13, w2)`` tensors.
        """
        from bitsandbytes.functional import dequantize_4bit
        w13 = dequantize_4bit(
            layer.w13_weight.reshape(-1, 1),
            layer.w13_weight.bnb_quant_state,
        )
        w2 = dequantize_4bit(
            layer.w2_weight.reshape(-1, 1),
            layer.w2_weight.bnb_quant_state,
        )
        w13 = w13.reshape(layer.w13_weight.experts_shape)
        w2 = w2.reshape(layer.w2_weight.experts_shape)
        return w13, w2

    # Backward-compatible alias for the previous, misspelled method name.
    _apply_4bit_dequnt = _apply_4bit_dequant

    def _apply_8bit_dequant(
            self, layer: torch.nn.Module) -> tuple[torch.Tensor, torch.Tensor]:
        """Not implemented: 8-bit expert weights are unsupported."""
        raise NotImplementedError
|