[Quantization][1/N] MoE support BNB-Inflight Quantization (#20061)
Signed-off-by: Jee Jee Li <pandaleefree@gmail.com>
parent 762be26a8e
commit 8020e98c9f
@@ -14,7 +14,7 @@ from transformers import BitsAndBytesConfig
from tests.quantization.utils import is_quant_method_supported

from ...utils import compare_two_settings, multi_gpu_test
from ..utils import check_embeddings_close
from ..utils import check_embeddings_close, check_logprobs_close

models_4bit_to_test = [
    ("facebook/opt-125m", "quantize opt model inflight"),
@@ -26,6 +26,10 @@ models_4bit_to_embedding_test = [
    ("intfloat/e5-mistral-7b-instruct", "quantize embedding model inflight"),
]

models_4bit_to_moe_test = [
    ("allenai/OLMoE-1B-7B-0125-Instruct", "quantize moe model inflight"),
]

models_pre_qaunt_4bit_to_test = [
    ('PrunaAI/Einstein-v6.1-Llama3-8B-bnb-4bit-smashed',
     'read pre-quantized 4-bit FP4 model'),
@@ -115,6 +119,35 @@ def test_load_pp_4bit_bnb_model(model_name, description) -> None:
    compare_two_settings(model_name, common_args, pp_args)


@pytest.mark.skipif(not is_quant_method_supported("bitsandbytes"),
                    reason='bitsandbytes is not supported on this GPU type.')
@pytest.mark.parametrize("model_name, description", models_4bit_to_moe_test)
def test_4bit_bnb_moe_model(hf_runner, vllm_runner, example_prompts,
                            model_name, description) -> None:

    hf_model_kwargs = dict(quantization_config=BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_use_double_quant=True,
    ))
    with vllm_runner(model_name,
                     quantization='bitsandbytes',
                     enforce_eager=False) as llm:
        vllm_outputs = llm.generate_greedy_logprobs(example_prompts,
                                                    max_tokens=32,
                                                    num_logprobs=5)

    with hf_runner(model_name, model_kwargs=hf_model_kwargs) as llm:
        transformers_outputs = llm.generate_greedy_logprobs_limit(
            example_prompts, max_tokens=32, num_logprobs=5)
    check_logprobs_close(
        outputs_0_lst=transformers_outputs,
        outputs_1_lst=vllm_outputs,
        name_0="transformers",
        name_1="vllm",
    )


@pytest.mark.skipif(not is_quant_method_supported("bitsandbytes"),
                    reason='bitsandbytes is not supported on this GPU type.')
@pytest.mark.parametrize("model_name, description",
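The new test above drives inflight 4-bit quantization of an MoE checkpoint through the vllm_runner fixture and compares logprobs against Hugging Face. A minimal standalone sketch of the same flow through the public vllm.LLM API (illustrative only, not part of this commit; the model name is taken from models_4bit_to_moe_test):

    # Illustrative sketch: inflight BNB 4-bit quantization of an MoE model.
    from vllm import LLM, SamplingParams

    llm = LLM(
        model="allenai/OLMoE-1B-7B-0125-Instruct",
        quantization="bitsandbytes",  # quantize the bf16/fp16 checkpoint on the fly
        enforce_eager=False,
        # on older vLLM releases you may also need load_format="bitsandbytes"
    )
    outputs = llm.generate(
        ["The capital of France is"],
        SamplingParams(temperature=0.0, max_tokens=32),
    )
    print(outputs[0].outputs[0].text)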
@@ -182,7 +215,8 @@ def validate_generated_texts(hf_runner,
                             model_name,
                             pre_quant=False,
                             hf_model_kwargs=None,
                             vllm_tp_size=1):
                             vllm_tp_size=1,
                             max_tokens=8):

    # NOTE: run vLLM first, as it requires a clean process
    # when using distributed inference
@@ -190,7 +224,8 @@ def validate_generated_texts(hf_runner,
                     quantization=None if pre_quant else 'bitsandbytes',
                     tensor_parallel_size=vllm_tp_size,
                     enforce_eager=False) as llm:
        vllm_outputs = llm.generate_greedy(prompts, 8)

        vllm_outputs = llm.generate_greedy(prompts, max_tokens)
        vllm_logs = log_generated_texts(prompts, vllm_outputs, "VllmRunner")

    # Clean up the GPU memory for the next test
@@ -202,19 +237,17 @@ def validate_generated_texts(hf_runner,

    # Run with HF runner
    with hf_runner(model_name, model_kwargs=hf_model_kwargs) as llm:
        hf_outputs = llm.generate_greedy(prompts, 8)
        hf_outputs = llm.generate_greedy(prompts, max_tokens)
        hf_logs = log_generated_texts(prompts, hf_outputs, "HfRunner")

    # Clean up the GPU memory for the next test
    gc.collect()
    torch.cuda.empty_cache()

    # Compare the generated strings
    for hf_log, vllm_log in zip(hf_logs, vllm_logs):
        hf_str = hf_log["generated_text"]
        vllm_str = vllm_log["generated_text"]
        prompt = hf_log["prompt"]

        assert hf_str == vllm_str, (f"Model: {model_name}"
                                    f"Mismatch between HF and vLLM outputs:\n"
                                    f"Prompt: {prompt}\n"

@@ -883,13 +883,20 @@ class FusedMoE(torch.nn.Module):
                          expert_data=expert_data,
                          tp_rank=tp_rank)

    def _load_w13(self, expert_data: torch.Tensor, shard_dim: int,
                  shard_id: str, loaded_weight: torch.Tensor, tp_rank: int):
    def _load_w13(self,
                  expert_data: torch.Tensor,
                  shard_dim: int,
                  shard_id: str,
                  loaded_weight: torch.Tensor,
                  tp_rank: int,
                  load_full: bool = False):

        # Index the loaded weight for tp sharding.
        # gate_up_proj: "MergedColumnParallel", so tp sharding on output_dim
        shard_size = expert_data.shape[shard_dim] // 2
        loaded_weight = loaded_weight.narrow(shard_dim, shard_size * tp_rank,
        if not load_full:
            loaded_weight = loaded_weight.narrow(shard_dim,
                                                 shard_size * tp_rank,
                                                 shard_size)
        # Narrow parameter and load.
        # w1, gate_proj: Load into first logical weight of w13.
@@ -998,6 +1005,27 @@ class FusedMoE(torch.nn.Module):
            param.data.copy_(loaded_weight)
            return True if return_success else None

        # Case for BitsAndBytes
        use_bitsandbytes_4bit = getattr(param, "use_bitsandbytes_4bit", False)
        if use_bitsandbytes_4bit:
            shard_dim = 0

            expert_data = param.data[expert_id]
            if shard_id == "w2":
                expert_data.copy_(loaded_weight)
            elif shard_id in ("w1", "w3"):
                # BNB inflight quantization has already sharded the weights
                full_load = True
                self._load_w13(
                    shard_id=shard_id,
                    shard_dim=shard_dim,
                    loaded_weight=loaded_weight,
                    expert_data=expert_data,
                    tp_rank=self.tp_rank,
                    load_full=full_load,
                )
            return True if return_success else None

        # is_transposed: if the dim to shard the weight
        # should be flipped. Required by GPTQ, compressed-tensors
        # should be whatever dimension intermediate_size_per_partition is

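The new load_full flag lets the BNB path skip the tensor-parallel narrow, because inflight quantization hands the loader tensors that are already sliced per rank. A small standalone sketch of the slice that _load_w13 takes when load_full is False (not part of this commit, shapes are made up for illustration):

    # Illustrative sketch: the tp-sharding slice performed by _load_w13.
    import torch

    tp_size, tp_rank = 2, 1
    intermediate_size_full, hidden_size = 8, 4
    intermediate_size_per_partition = intermediate_size_full // tp_size

    # Checkpoint tensor for one expert's gate_proj (or up_proj).
    loaded_weight = torch.randn(intermediate_size_full, hidden_size)

    # Local w13 buffer for that expert holds the gate and up halves stacked.
    expert_data = torch.empty(2 * intermediate_size_per_partition, hidden_size)

    shard_dim = 0
    shard_size = expert_data.shape[shard_dim] // 2  # == intermediate_size_per_partition
    shard = loaded_weight.narrow(shard_dim, shard_size * tp_rank, shard_size)
    print(shard.shape)  # torch.Size([4, 4]); with load_full=True the whole tensor is copied instead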
@@ -1,10 +1,13 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from typing import Any, Optional
from typing import Any, Callable, Optional, Union

import torch

from vllm.model_executor.layers.fused_moe import fused_experts
from vllm.model_executor.layers.fused_moe.layer import (FusedMoE,
                                                        FusedMoEMethodBase)
from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase,
                                               UnquantizedLinearMethod,
                                               set_weight_attrs)
@@ -120,12 +123,15 @@ class BitsAndBytesConfig(QuantizationConfig):
            llm_int8_skip_modules=llm_int8_skip_modules,
            llm_int8_threshold=llm_int8_threshold)

    def get_quant_method(self, layer: torch.nn.Module,
                         prefix: str) -> Optional["LinearMethodBase"]:
    def get_quant_method(
            self, layer: torch.nn.Module, prefix: str
    ) -> Optional[Union["LinearMethodBase", "BitsAndBytesMoEMethod"]]:
        if isinstance(layer, LinearBase):
            if is_layer_skipped_bnb(prefix, self.llm_int8_skip_modules):
                return UnquantizedLinearMethod()
            return BitsAndBytesLinearMethod(self)
        elif isinstance(layer, FusedMoE):
            return BitsAndBytesMoEMethod(self)
        return None


@@ -146,6 +152,13 @@ def is_layer_skipped_bnb(prefix: str, llm_int8_skip_modules: list[str]):
    return substr_check or prefix_check


def calculate_quant_ratio(dtype):
    if dtype.is_floating_point:
        return torch.finfo(dtype).bits // torch.iinfo(torch.uint8).bits
    else:
        return torch.iinfo(dtype).bits // torch.iinfo(torch.uint8).bits


class BitsAndBytesLinearMethod(LinearMethodBase):
    """Linear method for BitsAndBytes.

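calculate_quant_ratio is hoisted to module level here (and removed from BitsAndBytesLinearMethod.create_weights in the next hunk) so the new MoE method can reuse it as the pack factor for sizing packed uint8 buffers. A quick check of what it evaluates to (illustrative, not part of the diff):

    # Illustrative check: pack factor for 16-bit parameter dtypes.
    import torch


    def calculate_quant_ratio(dtype):  # copied from the hunk above
        if dtype.is_floating_point:
            return torch.finfo(dtype).bits // torch.iinfo(torch.uint8).bits
        return torch.iinfo(dtype).bits // torch.iinfo(torch.uint8).bits


    print(calculate_quant_ratio(torch.bfloat16))  # 2
    print(calculate_quant_ratio(torch.float16))   # 2
    # With a pack factor of 2, an expert weight of N 16-bit elements is stored
    # as an (N // 2, 1) uint8 buffer once NF4-quantized, which is how
    # w13_qweight and w2_qweight are sized in _create_weights_4bit below.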
@@ -173,12 +186,6 @@ class BitsAndBytesLinearMethod(LinearMethodBase):
                       **extra_weight_attrs):
        from bitsandbytes.nn import Int8Params

        def calculate_quant_ratio(dtype):
            if dtype.is_floating_point:
                return torch.finfo(dtype).bits // torch.iinfo(torch.uint8).bits
            else:
                return torch.iinfo(dtype).bits // torch.iinfo(torch.uint8).bits

        def create_qweight_for_8bit():
            qweight = Int8Params(
                data=torch.empty(sum(output_partition_sizes),
@@ -394,3 +401,210 @@ try:

except AttributeError as error:
    raise error


class BitsAndBytesMoEMethod(FusedMoEMethodBase):
    """MoE method for BitsAndBytes.

    Args:
        quant_config: The BitsAndBytes quantization config.
    """

    def __init__(self, quant_config: BitsAndBytesConfig):
        try:
            import bitsandbytes
            if bitsandbytes.__version__ < "0.45.3":
                raise ImportError("bitsandbytes version is wrong. Please "
                                  "install bitsandbytes>=0.45.3.")
        except ImportError as err:
            raise ImportError("Please install bitsandbytes>=0.45.3 via "
                              "`pip install bitsandbytes>=0.45.3` to use "
                              "bitsandbytes quantizer.") from err
        self.topk_indices_dtype = None
        self.quant_config = quant_config

    def create_weights(
        self,
        layer: torch.nn.Module,
        num_experts: int,
        hidden_size: int,
        intermediate_size_per_partition: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        if self.quant_config.load_in_8bit:
            call_fun = self._create_weights_8bit
        else:
            call_fun = self._create_weights_4bit
        call_fun(
            layer,
            num_experts,
            hidden_size,
            intermediate_size_per_partition,
            params_dtype,
            **extra_weight_attrs,
        )

    def apply(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        router_logits: torch.Tensor,
        top_k: int,
        renormalize: bool,
        use_grouped_topk: bool = False,
        topk_group: Optional[int] = None,
        num_expert_group: Optional[int] = None,
        global_num_experts: int = -1,
        expert_map: Optional[torch.Tensor] = None,
        custom_routing_function: Optional[Callable] = None,
        scoring_func: str = "softmax",
        e_score_correction_bias: Optional[torch.Tensor] = None,
        apply_router_weight_on_input: bool = False,
        activation: str = "silu",
        enable_eplb: bool = False,
        expert_load_view: Optional[torch.Tensor] = None,
        logical_to_physical_map: Optional[torch.Tensor] = None,
        logical_replica_count: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:

        if enable_eplb:
            raise NotImplementedError(
                "EPLB not supported for `BitsAndBytesMoEMethod` yet.")
        topk_weights, topk_ids = FusedMoE.select_experts(
            hidden_states=x,
            router_logits=router_logits,
            use_grouped_topk=use_grouped_topk,
            top_k=top_k,
            renormalize=renormalize,
            topk_group=topk_group,
            num_expert_group=num_expert_group,
            custom_routing_function=custom_routing_function,
            scoring_func=scoring_func,
            e_score_correction_bias=e_score_correction_bias,
            indices_type=self.topk_indices_dtype)
        if self.quant_config.load_in_8bit:
            w13, w2 = self._apply_8bit_dequant(layer)
        else:
            w13, w2 = self._apply_4bit_dequnt(layer)
        return fused_experts(
            hidden_states=x,
            w1=w13,
            w2=w2,
            topk_weights=topk_weights,
            topk_ids=topk_ids,
            inplace=True,
            activation=activation,
            apply_router_weight_on_input=apply_router_weight_on_input,
            global_num_experts=global_num_experts,
            expert_map=expert_map,
        )

    def _create_weights_4bit(
        self,
        layer: torch.nn.Module,
        num_experts: int,
        hidden_size: int,
        intermediate_size_per_partition: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        quant_ratio = calculate_quant_ratio(params_dtype)
        # Fused gate_up_proj (column parallel)
        w13_total_size = (hidden_size * 2 *
                          intermediate_size_per_partition) // quant_ratio
        w13_qweight = torch.nn.Parameter(
            torch.empty(
                num_experts,
                w13_total_size,
                1,
                dtype=torch.uint8,
            ),
            requires_grad=False,
        )
        layer.register_parameter("w13_weight", w13_qweight)
        set_weight_attrs(w13_qweight, extra_weight_attrs)
        set_weight_attrs(
            w13_qweight,
            {
                "num_experts":
                num_experts,
                "input_dim":
                hidden_size,
                "output_dim":
                2 * intermediate_size_per_partition,
                "experts_shape": (
                    num_experts,
                    intermediate_size_per_partition * 2,
                    hidden_size,
                ),
                "pack_factor":
                quant_ratio,
                "use_bitsandbytes_4bit":
                True,
            },
        )
        # down_proj (row parallel)
        w2_total_size = (hidden_size *
                         intermediate_size_per_partition) // quant_ratio
        w2_qweight = torch.nn.Parameter(
            torch.empty(
                num_experts,
                w2_total_size,
                1,
                dtype=torch.uint8,
            ),
            requires_grad=False,
        )
        set_weight_attrs(
            w2_qweight,
            {
                "num_experts":
                num_experts,
                "input_dim":
                intermediate_size_per_partition,
                "output_dim":
                hidden_size,
                "experts_shape": (
                    num_experts,
                    hidden_size,
                    intermediate_size_per_partition,
                ),
                "pack_factor":
                quant_ratio,
                "use_bitsandbytes_4bit":
                True,
            },
        )
        layer.register_parameter("w2_weight", w2_qweight)
        set_weight_attrs(w2_qweight, extra_weight_attrs)

    def _create_weights_8bit(
        self,
        layer: torch.nn.Module,
        num_experts: int,
        hidden_size: int,
        intermediate_size_per_partition: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        raise NotImplementedError

    def _apply_4bit_dequnt(
            self, layer: torch.nn.Module) -> tuple[torch.Tensor, torch.Tensor]:
        from bitsandbytes.functional import dequantize_4bit
        w13 = dequantize_4bit(
            layer.w13_weight.reshape(-1, 1),
            layer.w13_weight.bnb_quant_state,
        )
        w2 = dequantize_4bit(
            layer.w2_weight.reshape(-1, 1),
            layer.w2_weight.bnb_quant_state,
        )
        w13 = w13.reshape(layer.w13_weight.experts_shape)
        w2 = w2.reshape(layer.w2_weight.experts_shape)
        return w13, w2

    def _apply_8bit_dequant(
            self, layer: torch.nn.Module) -> tuple[torch.Tensor, torch.Tensor]:
        raise NotImplementedError

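_apply_4bit_dequnt turns the packed expert buffers back into the compute dtype right before fused_experts runs, using the fused bnb_quant_state the loader attaches and the experts_shape attribute set in _create_weights_4bit. The underlying round trip is plain bitsandbytes behaviour; a standalone sketch for a single toy weight (not part of this commit, requires a CUDA device):

    # Illustrative round trip: NF4 packs two values per uint8, so a
    # (rows, cols) tensor becomes a (rows * cols // 2, 1) uint8 buffer plus a
    # QuantState, and dequantize_4bit recovers an approximate original.
    import torch
    from bitsandbytes.functional import quantize_4bit, dequantize_4bit

    w = torch.randn(64, 128, dtype=torch.bfloat16, device="cuda")
    packed, quant_state = quantize_4bit(w, quant_type="nf4")
    print(packed.shape, packed.dtype)  # torch.Size([4096, 1]) torch.uint8

    w_back = dequantize_4bit(packed, quant_state).reshape(64, 128)
    print((w - w_back).abs().max())    # small blockwise quantization error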
@@ -20,6 +20,7 @@ from vllm.distributed import (get_tensor_model_parallel_rank,
                              get_tensor_model_parallel_world_size)
# yapf: enable
from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe import FusedMoE
from vllm.model_executor.layers.linear import (LinearBase,
                                               MergedColumnParallelLinear,
                                               QKVParallelLinear,
@@ -411,9 +412,33 @@ class BitsAndBytesModelLoader(BaseModelLoader):
                    # in case model has a mixture of disk-merged and disk-split
                    # weights with same last name.
                    self.target_modules.append(name)
            elif (isinstance(module, FusedMoE)
                  and hasattr(module.quant_method, "quant_config")):
                if not hasattr(model, "get_expert_mapping"):
                    raise AttributeError(
                        f"MoE Model {type(model).__name__} does not support "
                        "BitsAndBytes quantization yet. Ensure this model has "
                        "'get_expert_mapping' method.")
                # TODO: support FusedMoE with prequant and 8bit.
                if self.pre_quant:
                    raise ValueError(
                        "Prequant BitsAndBytes models with FusedMoE is not "
                        "supported yet.")
                if self.load_8bit:
                    raise ValueError(
                        "BitsAndBytes 8bit quantization with FusedMoE is not "
                        "supported yet.")
                # Get the corresponding weight name using module name and
                # get_expert_mapping.
                expert_mapping = model.get_expert_mapping()
                for exp in expert_mapping:
                    weight_name = exp[1]
                    rep_name = name.replace("experts",
                                            "") + weight_name.removesuffix(".")
                    self.target_modules.append(rep_name)

        assert (self.target_modules
                ), "vllm currently does not support BNB quantization for"
                ), "vLLM currently does not support BNB quantization for"
        f" {type(model).__name__}"

    def _classify_module_sharding(self, model: nn.Module):
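The loader derives the per-expert checkpoint weight names it must quantize by combining the FusedMoE module name with the entries returned by the model's get_expert_mapping(). The concrete tuple values below are hypothetical (the real strings come from FusedMoE.make_expert_params_mapping), but they show the string manipulation the hunk above performs:

    # Illustrative sketch. Mapping tuple layout: (param_name, weight_name,
    # expert_id, shard_id); the strings here are hypothetical examples.
    expert_mapping = [
        ("experts.w13_weight", "experts.0.gate_proj.", 0, "w1"),
        ("experts.w2_weight", "experts.0.down_proj.", 0, "w2"),
        ("experts.w13_weight", "experts.0.up_proj.", 0, "w3"),
    ]

    name = "model.layers.0.mlp.experts"  # hypothetical FusedMoE module name
    target_modules = []
    for _, weight_name, _, _ in expert_mapping:
        rep_name = name.replace("experts", "") + weight_name.removesuffix(".")
        target_modules.append(rep_name)

    print(target_modules)
    # ['model.layers.0.mlp.experts.0.gate_proj',
    #  'model.layers.0.mlp.experts.0.down_proj',
    #  'model.layers.0.mlp.experts.0.up_proj']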
@@ -437,6 +462,14 @@ class BitsAndBytesModelLoader(BaseModelLoader):
            # dimension (dim=-1)
            elif isinstance(module, (RowParallelLinear, )):
                self.column_sharded_weights_modules.append(name)
            elif isinstance(module, FusedMoE):
                expert_mapping = model.get_expert_mapping()
                for exp in expert_mapping:
                    if exp[-1] == "w2":
                        weight_name = exp[1]
                        rep_name = name.replace(
                            "experts", "") + weight_name.removesuffix(".")
                        self.column_sharded_weights_modules.append(rep_name)

    def _verify_model_compatibility(self, model: nn.Module,
                                    model_config: ModelConfig) -> None:
@@ -490,34 +523,132 @@ class BitsAndBytesModelLoader(BaseModelLoader):
        self._get_bnb_target_modules(model)
        self._classify_module_sharding(model)

    def load_weights(self, model: nn.Module,
                     model_config: ModelConfig) -> None:
    def _dequantize_dq(self, quant_states: Any):
        """
        When BNB employs Double Quantization, we perform the dequantization of
        these constants during weight loading rather than at inference time,
        thereby avoiding this computational overhead during inference. This
        comes at the cost of increased memory usage.
        """
        from bitsandbytes.functional import QuantState, dequantize_blockwise

        self._verify_model_compatibility(model, model_config)
        self._initialize_loader_state(model, model_config)
        def _dequantize_single_state(quant_state):
            """Helper function to dequantize a single QuantState object."""
            if not (isinstance(quant_state, QuantState)
                    and quant_state.nested):
                return

        logger.info("Loading weights with BitsAndBytes quantization. "
                    "May take a while ...")
        qweight_iterator, quant_state_dict = (
            self._get_quantized_weights_iterator(
                model_config.model,
                model_config.revision,
            ))
        weights_to_load = {name for name, _ in model.named_parameters()}
        loaded_weights = model.load_weights(qweight_iterator)
        # Some models may have weights loading tracker unimplemented.
        if loaded_weights is not None:
            weights_not_loaded = weights_to_load - loaded_weights
            if weights_not_loaded:
                raise ValueError("Following weights were not initialized from "
                                 f"checkpoint: {weights_not_loaded}")
            # Copied from: https://github.com/bitsandbytes-foundation/bitsandbytes/blob/0.45.3/bitsandbytes/functional.py#L1352-#L1356
            absmax = dequantize_blockwise(quant_state.absmax,
                                          quant_state.state2)
            absmax += quant_state.offset

        param_dict = dict(model.named_parameters())
            # Ensure float32 dtype
            if absmax.dtype != torch.float32:
                absmax = absmax.float()

            quant_state.absmax = absmax
            quant_state.nested = False
            quant_state.offset = None
            quant_state.state2 = None

        if isinstance(quant_states, dict):
            for quant_state in quant_states.values():
                _dequantize_single_state(quant_state)
        else:
            _dequantize_single_state(quant_states)
        return quant_states

    def _fuse_moe_quant_states(self, model: nn.Module,
                               quant_states_dict: dict) -> dict:
        """

        This function consolidates individual expert quantization states into
        fused representations for w13 and w2.
        """
        from bitsandbytes.functional import QuantState

        if not hasattr(model, "get_expert_mapping"):
            return dict()

        expert_mapping = model.get_expert_mapping()
        expert_qs_dict = {}
        for name, module in model.named_modules():
            if not isinstance(module, FusedMoE):
                continue
            w1_states_lst = []
            w2_states_lst = []
            w3_states_lst = []
            for exp in expert_mapping:
                shard_id = exp[-1]
                if shard_id not in ("w1", "w2", "w3"):
                    raise ValueError(f"shard_id must be ['w1','w2','w3'] but "
                                     f"got {shard_id}.")
                layer_prefix = name.split("experts")[0]
                weight_qual_name = layer_prefix + exp[1] + "weight"
                quant_state = self._dequantize_dq(
                    quant_states_dict[weight_qual_name])
                if shard_id == "w1":
                    w1_states_lst.append(quant_state)
                elif shard_id == "w2":
                    w2_states_lst.append(quant_state)
                else:
                    w3_states_lst.append(quant_state)
                del quant_states_dict[weight_qual_name]
            assert (len(w1_states_lst) == len(w2_states_lst) ==
                    len(w3_states_lst))
            w13_absmax_lst = []
            w2_absmax_lst = []
            w13_total_dim0 = 0
            w2_total_dim0 = 0
            for w1_qs, w2_qs, w3_qs in zip(w1_states_lst, w2_states_lst,
                                           w3_states_lst):
                assert w1_qs.shape == w3_qs.shape
                assert w1_qs.blocksize == w2_qs.blocksize == w3_qs.blocksize
                assert w1_qs.dtype == w2_qs.dtype == w3_qs.dtype
                # w1 and w3 are interleaved in storage
                w13_absmax_lst.append(w1_qs.absmax)
                w13_absmax_lst.append(w3_qs.absmax)
                w2_absmax_lst.append(w2_qs.absmax)
                w13_total_dim0 += w1_qs.shape[0] + w3_qs.shape[0]
                w2_total_dim0 += w2_qs.shape[0]

            w13_absmax = torch.cat(w13_absmax_lst)
            w2_absmax = torch.cat(w2_absmax_lst)
            # Create fused quantization state for w13.
            w13_qs = QuantState(
                absmax=w13_absmax,
                shape=(w13_total_dim0, w1_states_lst[0].shape[1]),
                code=w1_states_lst[0].code,
                blocksize=w1_states_lst[0].blocksize,
                quant_type="nf4",
                dtype=w1_states_lst[0].dtype,
            )
            # Create fused quantization state for w2.
            w2_qs = QuantState(
                absmax=w2_absmax,
                shape=(w2_total_dim0, w2_states_lst[0].shape[1]),
                code=w2_states_lst[0].code,
                blocksize=w2_states_lst[0].blocksize,
                quant_type="nf4",
                dtype=w2_states_lst[0].dtype,
            )
            # The weight suffixes .w13_weight and .w2_weight are consistent
            # with the param in BitsAndBytesMoEMethod.
            w13_weight_name = name + ".w13_weight"
            w2_weight_name = name + ".w2_weight"
            expert_qs_dict[w13_weight_name] = w13_qs
            expert_qs_dict[w2_weight_name] = w2_qs
        return expert_qs_dict

    def _stack_quantization_states(
            self, model: nn.Module,
            quant_state_dict: dict) -> dict[str, dict[int, Any]]:
        stacked_quant_state_dict: dict[str, dict[int, Any]] = {}
        # TODO: Change this lazy import to normal import
        # after the checks are updated to run on a new version
        from vllm.model_executor.models.utils import is_pp_missing_parameter

        param_dict = dict(model.named_parameters())
        for quant_param_name in quant_state_dict:
            if is_pp_missing_parameter(quant_param_name, model):
                continue
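_fuse_moe_quant_states merges the per-expert QuantState objects produced during inflight quantization into one fused state per w13/w2 parameter, so a single dequantize_4bit call can reconstruct all experts at once. The key bookkeeping is concatenating the blockwise absmax vectors in the same order the packed bytes are laid out (w1 then w3 for each expert). A toy sketch of that step, with plain tensors standing in for QuantState fields (not part of the diff, sizes are arbitrary):

    # Toy sketch: fuse per-expert absmax vectors in w13 storage order.
    import torch

    num_experts = 2
    w1_absmax = [torch.full((4, ), float(e)) for e in range(num_experts)]
    w3_absmax = [torch.full((4, ), 10.0 + e) for e in range(num_experts)]

    w13_absmax_lst = []
    for e in range(num_experts):
        w13_absmax_lst.append(w1_absmax[e])  # w1 block scales for expert e
        w13_absmax_lst.append(w3_absmax[e])  # then w3 block scales for expert e

    w13_absmax = torch.cat(w13_absmax_lst)
    print(w13_absmax.shape)  # torch.Size([16]) -- one fused scale vector for w13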
@@ -558,14 +689,20 @@ class BitsAndBytesModelLoader(BaseModelLoader):

                stacked_quant_state_dict[quant_param_name][shard_index] = (
                    quant_state_dict[non_stacked_param_name])
        return stacked_quant_state_dict

    def _bind_quant_states_to_params(self, model: nn.Module,
                                     stacked_quant_state_dict: dict) -> None:
        # save quant_states and offsets as the attributes of the parameters
        param_dict = dict(model.named_parameters())
        for param_name, param in param_dict.items():
            if param_name in stacked_quant_state_dict:
                quant_states = stacked_quant_state_dict[param_name]
                # Dequantize double quantized values during weight loading.
                dequantize_dq(quant_states)
                self._dequantize_dq(quant_states)
                set_weight_attrs(param, {"bnb_quant_state": quant_states})
                if not isinstance(quant_states, dict):
                    continue

                pack_ratio = getattr(param, "pack_factor", -1)
                if pack_ratio == -1:
@@ -585,29 +722,40 @@ class BitsAndBytesModelLoader(BaseModelLoader):
                if self.load_8bit:
                    set_weight_attrs(
                        param, {"matmul_state": [None] * len(quant_states)})

    def load_weights(self, model: nn.Module,
                     model_config: ModelConfig) -> None:

        self._verify_model_compatibility(model, model_config)
        self._initialize_loader_state(model, model_config)

        logger.info("Loading weights with BitsAndBytes quantization. "
                    "May take a while ...")
        qweight_iterator, quant_state_dict = (
            self._get_quantized_weights_iterator(
                model_config.model,
                model_config.revision,
            ))
        weights_to_load = {name for name, _ in model.named_parameters()}
        loaded_weights = model.load_weights(qweight_iterator)
        # Some models may have weights loading tracker unimplemented.
        if loaded_weights is not None:
            weights_not_loaded = weights_to_load - loaded_weights
            if weights_not_loaded:
                raise ValueError("Following weights were not initialized from "
                                 f"checkpoint: {weights_not_loaded}")
        expert_quant_state_dict = self._fuse_moe_quant_states(
            model, quant_state_dict)

        stacked_quant_state_dict = self._stack_quantization_states(
            model, quant_state_dict)

        stacked_quant_state_dict = {
            **expert_quant_state_dict,
            **stacked_quant_state_dict
        }
        self._bind_quant_states_to_params(model, stacked_quant_state_dict)
        torch.cuda.empty_cache()

    def download_model(self, model_config: ModelConfig) -> None:
        self._prepare_weights(model_config.model, model_config.revision)


def dequantize_dq(quant_states: dict) -> None:
    """
    When BNB employs Double Quantization, we perform the dequantization of
    these constants during weight loading rather than at inference time,
    thereby avoiding this computational overhead during inference. This comes
    at the cost of increased memory usage.
    """
    from bitsandbytes.functional import QuantState, dequantize_blockwise
    for _, quant_state in quant_states.items():
        # Copied from: https://github.com/bitsandbytes-foundation/bitsandbytes/blob/0.45.3/bitsandbytes/functional.py#L1352-#L1356
        if isinstance(quant_state, QuantState) and quant_state.nested:
            absmax = dequantize_blockwise(quant_state.absmax,
                                          quant_state.state2)
            absmax += quant_state.offset
            if absmax.dtype != torch.float32:
                absmax = absmax.float()
            quant_state.absmax = absmax
            quant_state.nested = False
            quant_state.offset = None
            quant_state.state2 = None

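Both _dequantize_dq above and the module-level dequantize_dq unpack bitsandbytes' Double Quantization: when weights are quantized with compress_statistics=True, the per-block absmax scales are themselves stored 8-bit, and the nested QuantState carries the second-level state needed to recover them. A standalone illustration of what gets flattened at load time (not part of this commit, requires a CUDA device):

    # Illustrative: a "nested" QuantState and the one-time flattening the
    # loader performs so inference skips the extra dequantize step.
    import torch
    from bitsandbytes.functional import quantize_4bit, dequantize_blockwise

    w = torch.randn(256, 256, dtype=torch.float16, device="cuda")
    # compress_statistics=True enables Double Quantization of the absmax scales.
    _, qs = quantize_4bit(w, quant_type="nf4", compress_statistics=True)
    print(qs.nested)  # True: absmax is stored quantized, plus offset and state2

    # Same recovery the loader performs once, at weight-loading time:
    absmax = dequantize_blockwise(qs.absmax, qs.state2) + qs.offset
    qs.absmax = absmax.float()
    qs.nested, qs.offset, qs.state2 = False, None, None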
@@ -330,6 +330,15 @@ class OlmoeModel(nn.Module):
        hidden_states, _ = self.norm(hidden_states, residual)
        return hidden_states

    def get_expert_mapping(self) -> list[tuple[str, str, int, str]]:
        # Params for weights, fp8 weight scales, fp8 activation scales
        # (param_name, weight_name, expert_id, shard_id)
        return FusedMoE.make_expert_params_mapping(
            ckpt_gate_proj_name="gate_proj",
            ckpt_down_proj_name="down_proj",
            ckpt_up_proj_name="up_proj",
            num_experts=self.config.num_experts)

    def load_weights(self, weights: Iterable[tuple[str,
                                                   torch.Tensor]]) -> set[str]:
        stacked_params_mapping = [
@@ -341,14 +350,6 @@ class OlmoeModel(nn.Module):
            ("gate_up_proj", "up_proj", 1),
        ]

        # Params for weights, fp8 weight scales, fp8 activation scales
        # (param_name, weight_name, expert_id, shard_id)
        expert_params_mapping = FusedMoE.make_expert_params_mapping(
            ckpt_gate_proj_name="gate_proj",
            ckpt_down_proj_name="down_proj",
            ckpt_up_proj_name="up_proj",
            num_experts=self.config.num_experts)

        params_dict = dict(self.named_parameters())
        loaded_params: set[str] = set()
        for name, loaded_weight in weights:
@@ -379,7 +380,7 @@ class OlmoeModel(nn.Module):
                weight_loader(param, loaded_weight, shard_id)
                break
            else:
                for mapping in expert_params_mapping:
                for mapping in self.get_expert_mapping():
                    param_name, weight_name, expert_id, shard_id = mapping
                    if weight_name not in name:
                        continue
@@ -425,6 +426,17 @@ class OlmoeModel(nn.Module):


class OlmoeForCausalLM(nn.Module, SupportsPP):
    packed_modules_mapping = {
        "qkv_proj": [
            "q_proj",
            "k_proj",
            "v_proj",
        ],
        "gate_up_proj": [
            "gate_proj",
            "up_proj",
        ],
    }

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()
@@ -466,3 +478,6 @@ class OlmoeForCausalLM(nn.Module, SupportsPP):
                                                   torch.Tensor]]) -> set[str]:
        loader = AutoWeightsLoader(self)
        return loader.load_weights(weights)

    def get_expert_mapping(self) -> list[tuple[str, str, int, str]]:
        return self.model.get_expert_mapping()

@@ -516,6 +516,14 @@ class PhiMoEModel(nn.Module):
        hidden_states = self.norm(hidden_states)
        return hidden_states

    def get_expert_mapping(self) -> list[tuple[str, str, int, str]]:
        return FusedMoE.make_expert_params_mapping(
            ckpt_gate_proj_name="w1",
            ckpt_down_proj_name="w2",
            ckpt_up_proj_name="w3",
            num_experts=self.config.num_local_experts,
        )

    def load_weights(self, weights: Iterable[tuple[str,
                                                   torch.Tensor]]) -> set[str]:
        stacked_params_mapping = [
@@ -672,3 +680,6 @@ class PhiMoEForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
                                                   torch.Tensor]]) -> set[str]:
        loader = AutoWeightsLoader(self)
        return loader.load_weights(weights)

    def get_expert_mapping(self) -> list[tuple[str, str, int, str]]:
        return self.model.get_expert_mapping()

@@ -391,6 +391,15 @@ class Qwen2MoeModel(nn.Module):
        hidden_states, _ = self.norm(hidden_states, residual)
        return hidden_states

    def get_expert_mapping(self) -> list[tuple[str, str, int, str]]:
        # Params for weights, fp8 weight scales, fp8 activation scales
        # (param_name, weight_name, expert_id, shard_id)
        return FusedMoE.make_expert_params_mapping(
            ckpt_gate_proj_name="gate_proj",
            ckpt_down_proj_name="down_proj",
            ckpt_up_proj_name="up_proj",
            num_experts=self.config.num_experts)

    def load_weights(self, weights: Iterable[tuple[str,
                                                   torch.Tensor]]) -> set[str]:
        stacked_params_mapping = [
@@ -402,14 +411,6 @@ class Qwen2MoeModel(nn.Module):
            ("gate_up_proj", "up_proj", 1),
        ]

        # Params for weights, fp8 weight scales, fp8 activation scales
        # (param_name, weight_name, expert_id, shard_id)
        expert_params_mapping = FusedMoE.make_expert_params_mapping(
            ckpt_gate_proj_name="gate_proj",
            ckpt_down_proj_name="down_proj",
            ckpt_up_proj_name="up_proj",
            num_experts=self.config.num_experts)

        params_dict = dict(self.named_parameters())
        loaded_params: set[str] = set()
        for name, loaded_weight in weights:
@@ -441,11 +442,13 @@ class Qwen2MoeModel(nn.Module):
                weight_loader(param, loaded_weight, shard_id)
                break
            else:
                for mapping in expert_params_mapping:
                for mapping in self.get_expert_mapping():
                    param_name, weight_name, expert_id, shard_id = mapping
                    if weight_name not in name:
                        continue
                    name = name.replace(weight_name, param_name)
                    if "layers.13.mlp.experts.w2_weight" in name:
                        pass
                    # Skip layers on other devices.
                    if is_pp_missing_parameter(name, self):
                        continue
@@ -493,6 +496,17 @@ class Qwen2MoeModel(nn.Module):
class Qwen2MoeForCausalLM(nn.Module, SupportsPP):

    fall_back_to_pt_during_load = False
    packed_modules_mapping = {
        "qkv_proj": [
            "q_proj",
            "k_proj",
            "v_proj",
        ],
        "gate_up_proj": [
            "gate_proj",
            "up_proj",
        ],
    }

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()
@@ -538,3 +552,6 @@ class Qwen2MoeForCausalLM(nn.Module, SupportsPP):
                                                   torch.Tensor]]) -> set[str]:
        loader = AutoWeightsLoader(self)
        return loader.load_weights(weights)

    def get_expert_mapping(self) -> list[tuple[str, str, int, str]]:
        return self.model.get_expert_mapping()

@@ -375,6 +375,15 @@ class Qwen3MoeModel(nn.Module):
        hidden_states, _ = self.norm(hidden_states, residual)
        return hidden_states

    def get_expert_mapping(self) -> list[tuple[str, str, int, str]]:
        # Params for weights, fp8 weight scales, fp8 activation scales
        # (param_name, weight_name, expert_id, shard_id)
        return FusedMoE.make_expert_params_mapping(
            ckpt_gate_proj_name="gate_proj",
            ckpt_down_proj_name="down_proj",
            ckpt_up_proj_name="up_proj",
            num_experts=self.config.num_experts)

    def load_weights(self, weights: Iterable[tuple[str,
                                                   torch.Tensor]]) -> set[str]:
        stacked_params_mapping = [
@@ -393,12 +402,7 @@ class Qwen3MoeModel(nn.Module):

        # Params for weights, fp8 weight scales, fp8 activation scales
        # (param_name, weight_name, expert_id, shard_id)
        expert_params_mapping = FusedMoE.make_expert_params_mapping(
            ckpt_gate_proj_name="gate_proj",
            ckpt_down_proj_name="down_proj",
            ckpt_up_proj_name="up_proj",
            num_experts=self.config.num_experts)

        expert_params_mapping = self.get_expert_mapping()
        params_dict = dict(self.named_parameters())
        loaded_params: set[str] = set()
        for name, loaded_weight in weights:
@@ -539,3 +543,6 @@ class Qwen3MoeForCausalLM(nn.Module, SupportsPP):
                                                   torch.Tensor]]) -> set[str]:
        loader = AutoWeightsLoader(self)
        return loader.load_weights(weights)

    def get_expert_mapping(self) -> list[tuple[str, str, int, str]]:
        return self.model.get_expert_mapping()